一个拥有 2 亿参数规模的深度学习模型,按照 fp32 精度计算,理论上仅需 800 MB 显存。然而,为什么你手头那块 24 GB 的 GPU 转眼间就被占满?原因其实并不复杂:模型参数仅仅是训练期间消耗 GPU 显存的七种关键因素之一。只有搞清楚这七个要素,你才能从“凭感觉猜测”转变为“依据工程原理”进行精准判断。
GPU 显存的七大消耗来源
当你执行 loss.backward() 与 optimizer.step() 时,GPU 内部究竟存储了哪些数据?
- 模型参数——即网络权重本身
- 梯度——与参数数量一致,每个参数对应一个梯度值
- 优化器状态——例如 Adam 优化器会为每个参数额外存储 2 个张量(m 和 v)
- 激活值——每一层的输出结果,反向传播时需保留输入数据
- 输入批次——加载到 GPU 上的训练数据
- CUDA 工作区——内核临时空间与 cuDNN 选择的缓存区域
- 显存碎片——已分配但因块间间隙而无法有效利用的显存空间
以使用 Adam 优化器训练的 2 亿参数 fp32 模型为例,我们来算一笔明细账:
- 参数:800 MB
- 梯度:800 MB(与参数大小相同)
- Adam 状态(m 和 v):1600 MB(参数量的 2 倍)
- 激活值:差异较大,通常为参数量的 2–10 倍
- 输入批次:取决于批量大小设置
- CUDA 工作区:500 MB–1 GB
- 显存碎片:占总量的 5%–20%
因此,保守估计下,一个“理论上”仅需 800 MB 显存的模型,实际占用往往达到 5–8 GB。这就是理论值与实际值之间巨大差距的根源所在。
如何准确测量显存使用情况
PyTorch 提供了相当精确的显存可见性机制,关键在于知道从何处查看。
import torch
# PyTorch 为张量实际分配的 GPU 显存量
allocated = torch.cuda.memory_allocated() / 1024**3 # GB
# PyTorch 从 CUDA 预留的显存量(包含未使用部分)
reserved = torch.cuda.memory_reserved() / 1024**3 # GB
# 上次重置以来的峰值分配量
peak = torch.cuda.max_memory_allocated() / 1024**3 # GB
# 重置峰值计数器
torch.cuda.reset_peak_memory_stats()
allocated 与 reserved 之间的差值即为显存碎片量。如果 allocated 为 5 GB、reserved 为 8 GB,就意味着有 3 GB 显存是 PyTorch 已申请但无法高效利用的。
print(torch.cuda.memory_summary())
这条命令能够按照分配器内存池输出完整的显存分类统计信息——大小分配对比、当前值与峰值,各项明细一目了然。在完成一步训练后调用,可以清晰看出显存究竟流向了哪里。
大多数人不知道的杀手级调试功能
PyTorch 还支持记录每次显存分配,并以时间线形式进行可视化呈现:
torch.cuda.memory._record_memory_history(max_entries=100_000)
# 执行一步训练
output = model(x)
loss = criterion(output, y)
loss.backward()
optimizer.step()
# 保存快照
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
将生成的 pickle 文件上传至 https://pytorch.org/memory_viz,你会看到一个交互式可视化界面,清晰展示每次分配、每次释放以及触发它们的完整调用栈。借助这一工具,只需几分钟就能定位到用 print 语句排查需要耗费数天的 OOM 错误。
三种行之有效的显存优化方法
能够测量,才能进行优化。以下按影响程度从大到小排列:
1. 梯度检查点(Gradient Checkpointing)——以计算时间换取显存空间
激活值通常是显存消耗的最大来源。梯度检查点技术在反向传播时重新计算激活值,而非将其全部存储下来。
from torch.utils.checkpoint import checkpoint
class MyBlock(nn.Module):
def forward(self, x):
return checkpoint(self._forward, x, use_reentrant=False)
def _forward(self, x):
# 此处为耗时操作
return x
典型节省幅度:激活值显存减少 40%–60%。代价是反向传播速度降低 20%–30%。
2. 混合精度训练(Mixed Precision Training)——显存减半,精度几乎无损
from torch.amp import autocast, GradScaler
scaler = GradScaler('cuda')
with autocast('cuda', dtype=torch.float16):
output = model(x)
loss = criterion(output, y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
激活值、梯度以及大部分运算使用 fp16(每个值占用 2 字节,而非 4 字节),参数和优化器状态则保持 fp32 以保证数值稳定性。典型节省幅度:总显存减少 30%–50%。fp16 运算在现代 GPU 上速度更快,训练过程通常也会随之加速。
3. 优化器的合理选择
Adam 优化器为每个参数额外存储 2 个张量。对于 fp32 精度的 10 亿参数模型,仅优化器状态就需要占用 8 GB 显存。以下是一些替代方案:
- SGD with momentum:每个参数额外存储 1 个张量(Adam 开销的一半)
- AdamW with bnb.optim.AdamW8bit:以 8 位精度存储优化器状态,显存占用减少 4 倍,精度损失极小
- Lion:显存占用与 SGD 相当,收敛效果通常接近 Adam
对于超过 10 亿参数规模的大模型,优化器的选择可能直接决定训练能否在现有硬件上顺利跑起来。
分布式系统领域有句经典名言:无法测量的东西,就无法优化。然而,大多数 PyTorch 团队往往完全跳过了测量步骤:遇到 OOM 就简单粗暴地缩小批量大小,然后继续训练。但 GPU 显存资源十分昂贵,如果你认真分析过实际的显存使用情况,就能将显存占用减半,同时把批量大小翻倍——这通常意味着更快的训练速度与更优的梯度估计质量。
