自回归模型的局限性与DeepSeek-V3中的多令牌预测
为什么单令牌预测限制了模型能力
传统语言模型的训练目标说起来很直观:给定前t个令牌,预测第t+1个令牌。这种自回归分解方式固然优雅且高效,但一个根本性的问题始终存在——模型只接收即时下一个令牌的预测训练信号,从未明确学习过提前规划多个步骤的能力。

举个例子,生成句子“猫坐在垫子上,因为它很舒服”。当模型预测到“因为”这个词时,其实已经在潜意识里考虑句子将如何完成——包括从句结构、代词指代以及最终结论。但问题在于,仅靠单令牌预测,并没有明确的梯度信号去鼓励这种前向规划。直观地说,模型缺少一种“远见”的激励机制。
这种局限在需要长期连贯性的任务中尤为突出,比如故事生成、多段落推理或者代码生成。模型很容易写出局部流畅但全局自相矛盾的内容——局部看着很顺,整体却成了逻辑上的“拆东墙补西墙”。
DeepSeek-V3中的多令牌预测:提前预测多个令牌
多令牌预测(Multi-Token Prediction,简称MTP)的思路非常清晰:通过添加辅助预测头,让模型能够同时预测未来的多个令牌。除了标准的位置t+1预测,还要同时预测位置t+2、t+3等后续令牌。
完整的训练目标函数如下:
L = Lmain + Σ{k=1}^{n} λ_k * L_k
其中n是预测的未来令牌数,λ_k是加权系数(通常随距离增大而递减)。
DeepSeek-V3架构:多令牌预测头详解
要实现多令牌预测,架构上需要做出相应补充。不能直接复用主语言建模头去预测未来,关键是要依赖中间令牌的信息。
预测头的结构思路:对于预测k个令牌后的位置,需要组合两个信息源:
- Transformer在位置i的隐藏表示h_i
- 位置i+k-1处令牌的嵌入表示e_{i+k-1}
组合方式很简单:combined = Wcombine([norm(h_i), norm(e{i+k-1})]) + b
然后通过一个轻量级的Transformer(注意力层和前馈层)处理,再投影到词汇表,生成预测logits。
梯度视角下的多令牌预测
从优化角度看,MTP能提供更丰富的梯度信号。标准训练中只有隐藏表示hi接收来自预测x{i+1}的梯度。而使用MTP后,hi还会接收来自预测x{i+k}的梯度。这些额外梯度会鼓励h_i编码的信息不仅跟下一个令牌相关,还跟多个未来令牌相关。
这相当于给训练过程增加了一个隐式正则化器,约束学习到的表示更加结构化、更具前瞻性,全局连贯性自然也就随之提升。
训练与推理阶段的差异
训练阶段:所有预测并行计算,使用真实令牌信息,误差不会累积。
推理阶段:MTP头通常不用于自回归生成——它的核心作用是在训练阶段改善学习到的表示。推理时仍然使用标准单令牌预测方式,这样能保证部署时的计算效率。
损失权重设置
对于预测深度k,权重通常采用指数衰减:λ_k = γ^(k-1),其中γ∈(0,1)。举个例子,γ=0.5时,深度1权重1.0,深度2权重0.5,深度3权重0.25。离当前越远,权重越小。
多令牌预测头的代码实现
class MultiTokenPredictionHead(nn.Module):
"""多令牌预测头
每个头预测特定未来位置的令牌
组合前一个隐藏状态与未来令牌嵌入
"""
def __init__(self, config: DeepSeekConfig, depth: int):
super().__init__()
self.depth = depth
self.n_embd = config.n_embd
# 组合前一个隐藏状态与未来令牌嵌入
self.combine_proj = nn.Linear(2 * config.n_embd, config.n_embd, bias=config.bias)
# 归一化层
self.norm1 = RMSNorm(config.n_embd)
self.norm2 = RMSNorm(config.n_embd)
# Transformer组件(每个头的轻量级Transformer)
self.attn = MultiheadLatentAttention(config)
self.mlp = MixtureOfExperts(config)
self.attn_norm = RMSNorm(config.n_embd)
self.mlp_norm = RMSNorm(config.n_embd)
def forward(self, prev_hidden, future_token_embed):
"""
参数:
prev_hidden: [B, T, D] - 前一层的隐藏状态
future_token_embed: [B, T, D] - 未来令牌的嵌入
返回:
hidden: [B, T, D] - 处理后的隐藏状态
"""
# 归一化输入
prev_norm = self.norm1(prev_hidden)
future_norm = self.norm2(future_token_embed)
# 组合表示
combined = torch.cat([prev_norm, future_norm], dim=-1)
hidden = self.combine_proj(combined)
# 通过轻量级Transformer处理
hidden = hidden + self.attn(self.attn_norm(hidden))
moe_out, _ = self.mlp(self.mlp_norm(hidden))
hidden = hidden + moe_out
return hidden
核心Transformer中集成多令牌预测
训练时,MTP头会被集成到主模型里,操作流程大致如下:
- 主预测:将最终隐藏状态投影到词汇表,预测下一个令牌
- 深度1预测:获取真实令牌嵌入,通过头1处理,再投影预测
- 深度2预测:基于头1输出继续处理
关键洞察是,头之间存在链式依赖关系,形成一种层次化结构。就像搭积木,每一层都依赖前一层的结果。
多令牌预测的优势
研究已经表明,MTP带来了多个实证收益:
- 改进的连贯性:生成更全局连贯的文本
- 更好的规划能力:在故事写作或代码生成等任务中,帮助模型做出前向兼容的选择
- 更快的收敛速度:额外训练信号加速了学习过程
- 正则化效果:防止过拟合,鼓励表示支持多个相关目标
总结
传统自回归模型依赖单令牌预测,这种策略虽然有效,但多少有些“短视”。MTP通过让模型同时预测多个令牌,巧妙地解决了这一局限——不仅加速了训练和推理,还丰富了上下文理解能力。这项创新不光是效率上的提升,更在理论和实证层面为语言模型的能力边界拓展提供了全新思路。
