本次查询:Gradient Checkpointing
中文解释:梯度检查点
常见场景:大模型训练与推理 / 长序列处理 / 硬件资源受限环境
一句话解释
Gradient Checkpointing是一种深度学习训练策略,训练时只保存部分中间层的激活值(检查点),其余未保存的激活值在反向传播时根据已保存值和原始输入重新计算,从而用额外计算量换取显存占用的大幅降低。
为什么会被关注
随着Transformer、大语言模型等模型规模急速增长,GPU显存成为训练瓶颈。传统做法需存储所有中间激活值,显存极易耗尽。Gradient Checkpointing让工程师能在不增加硬件成本的前提下,训练更深的模型或处理更长的序列,因此成为大模型训练标配技术。
核心逻辑
在正向传播中,神经网络会产生活化值,用于反向传播计算梯度。Gradient Checkpointing将这些激活值分段保存,只保留关键“检查点”的完整状态。反向传播时,检查点之间的片段需要从最近检查点重新执行正向计算来恢复被丢弃的激活值。这本质是“用时间换空间”:每个检查点之间的计算量翻倍,但显存占用从O(L)降低到O(sqrt(L))或O(log L)。
常见场景
训练超长序列的Transformer模型(如GPT、BERT)时,大模型的层数深或序列长度超过4096,显存会快速占满。此外,在多卡并行训练中,每张卡的显存限制也通过Gradient Checkpointing缓解。推理时若需大batch,也可用类似策略。
容易混淆的点
Gradient Checkpointing不等于梯度累积(gradient accumulation),后者是多次小batch更新,缓解显存但增加通信;前者是单次前向中节约激活值存储。它也不等于模型并行或流水线并行,虽然常配合使用。另外,检查点策略会小幅延长训练时间(约20-30%),并非所有场景都适用,需权衡显存与速度。
