本次查询:Activation Recomputation
中文解释:激活重计算
常见场景:大模型训练 / 深度学习框架优化
一句话解释
Activation Recomputation(激活重计算)是一种显存优化技术,在前向传播时只保留部分关键激活值,丢弃大部分中间结果,然后在反向传播时重新计算被丢弃的激活值,从而大幅度降低显存占用。
为什么会被关注
随着大模型参数规模达到千亿甚至万亿级别,显存瓶颈成为训练时的主要限制。传统做法需要保存所有前向计算的中间激活值用于反向传播,这会使显存需求随模型深度线性增长。激活重计算通过用计算换内存,使得单卡可以训练更大的模型或使用更大的批量,成为各大AI框架(如PyTorch、Megatron-LM)的标准功能。
核心逻辑
在前向传播过程中,每个网络层都会产生激活值(中间特征)。为了反向传播计算梯度,通常需要这些激活值。激活重计算策略会标记某些层或计算区域,在前向完成后立即释放其激活值。反向传播时,从最近的检查点(Checkpoint)重新执行正向计算,恢复需要的激活值。这样做增加了约30%至50%的计算开销,但可能减少60%至80%的显存占用。
常见场景
大规模Transformer模型训练(如GPT、LLaMA、BERT)中,激活重计算常与张量并行、流水线并行结合使用。在PyTorch中可通过`torch.utils.checkpoint`接口轻松启用。训练超长序列(如8192以上)时,激活重计算几乎必不可少。此外,在显存有限的消费级显卡上微调大模型时,此技术也被广泛采用。
容易混淆的点
激活重计算(Activation Recomputation)和梯度检查点(Gradient Checkpointing)常被混用,实际上两者是同一概念的不同叫法。容易与“显存交换”(将数据换到CPU内存)混淆,但后者依靠PCIe传输,延迟更高。另外,它并非减少计算量,反而是增加计算量,只是让原本不能训练的大模型变得可训练。
