Llama 3.1 405B发布之后,圈子里最常听到的一句感慨是:跑不动,根本跑不动。模型效果固然惊艳,但无论是训练还是部署,对资源的消耗都让绝大多数团队望而却步。于是,知识蒸馏——把大模型的能力压缩到小模型里——几乎成了顺理成章的选择。
趁着这股热度做了一番调研,重点聚焦两篇很有代表性的工作:清华的《MiniLLM:Knowledge Distillation of Large Language Models》和Meta的《Distilling System2 into System1》。这两篇恰好分别代表了白盒蒸馏和黑盒蒸馏两大路线,放在一起看,能更清楚地理解当前蒸馏技术的走向和取舍。
MiniLLM蒸馏
Motivation
先聊聊前向KL散度和反向KL散度的区别。一句话概括:前向KL倾向于mean-seeking,反向KL则倾向于mode-seeking。
这个差异来自KL散度本身的非对称性——只有在两个分布完全相等时,前向和后向才等价。具体来看:

对于前向KL:当p(x)较大的时候,qθ(x)也必须比较大,而且要比p(x)相对更大,否则整项损失降不下去;但当p(x)较小的时候,p(x)在log外面趋于0,主导了整体损失,这时候qθ(x)怎么取都影响不大。所以优化FKL时,qθ(x)倾向于覆盖p(x)的所有mode,哪怕代价是高估p(x)中概率极低的部分——这就是图中橙色区域代表的问题。

再看反向KL:当qθ(x)较大时,为了降低损失,p(x)也必须足够大——也就是说,p(x)概率最高的那些mode,必须和qθ(x)概率最高的区域对齐;而p(x)概率很小的部分,qθ(x)必须把概率压到接近0。当qθ(x)本身为0时,p(x)取什么值都不影响优化。于是RKL最终拟合的是p(x)中概率最大的那一部分——这正是图中绿色区域的特点。
MiniLLM的整体方案如下图所示:

RKL与逆强化学习的等价推导
论文中另一个值得关注的视角,是将RKL与逆强化学习做了类比,并给出了严谨的数学推导。这里直接贴出论文中的关键公式(公式编号保持论文原序,方便对照阅读):
这个类比非常有意思:RKL在数学形式上约等于逆强化学习,而FKL则等价于模仿学习。在实践和理论两方面,逆强化学习的效果通常都优于模仿学习——虽然训练难度更大,但泛化性能和理论上限都更高。从这个角度看,MiniLLM选择RKL路线是有理论底气的。
实际训练怎么做
理论归理论,落地才是关键。MiniLLM的实际训练流程和RLHF非常相似:教师模型在训练过程中只做推理,提供分布信号来指导学生模型的优化。作者还提供了一种基于ranking loss的简化方案,对比传统BERT时代的蒸馏方法,都有稳定的提升。
不妨先看看核心的蒸馏损失计算代码(来自官方实现):
def get_distil_loss(args, tokenizer, model, teacher_model, model_batch, no_model_batch, logits):
with torch.no_grad():
teacher_model.eval()
teacher_outputs = teacher_model(**model_batch, use_cache=False)
teacher_logits = teacher_outputs.logits
if args.model_parallel:
distil_losses = mpu.parallel_soft_cross_entropy_loss(logits.float(), teacher_logits.float())
distil_losses = distil_losses.view(-1)
loss_mask = no_model_batch["loss_mask"].view(-1)
distil_loss = (distil_losses * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
inf_mask = torch.isinf(logits)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
mask = (no_model_batch["label"] != -100).int()
distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
return distil_loss
实际训练中通常还会加上SFT数据的loss,防止模型跑偏——这一点和RLHF中reference model的作用类似。
一个关键的实操注意事项:教师模型和学生模型必须是同源的,也就是说它们得共用相同的tokenizer。对国产模型来说,Qwen、DeepSeek、Yi等系列都有相同tokenizer、不同尺寸的模型可供选择,这点倒是不用太担心。
System2到System1蒸馏
人类认知系统很有意思,它包含两种推理模式:System1是无意识的、快速的直觉判断(快思考),System2则是处理复杂问题时需要深思熟虑的慢思考。大模型领域也是一样——CoT(思维链)、RaR(重述并回答)等中间推理步骤可以类比为System2的深思熟虑,效果好,但延迟高,难以直接用于生产。
其实,类似的工作早在2023年GPT-4发布后就有人在做了——大量团队通过黑盒蒸馏将GPT-4的能力迁移到自家小模型上。更早的还有Llama 2中的Ghost Attention,在每一轮对话中都加入System Prompt来提升多轮指令跟随能力。这些本质上都属于System2到System1的蒸馏。
这篇论文的独特之处在于:它显式地提出了将System2的推理能力蒸馏到System1中的范式,并做了大量实验来验证效果。可以把它理解为一种高质量的数据合成方法——通过这些合成数据进行指令微调,从而提升System1本身的推理能力。
论文用以下几个公式对System1和System2做了形式化定义:
公式3虽然能生成大量训练数据,但质量参差不齐。论文主要通过一致性标准来过滤:
- 输出一致性:输入不变,对输出做N次采样,投票选最优,少数服从多数
- 输入扰动下的一致性:输出不变,对输入增加扰动(比如调整选择题的选项顺序),如果答案不一致则过滤掉
当然,实际应用中应该还有更精细的过滤策略,这里只是抛砖引玉。
论文提出了四种具体的蒸馏方式,效果各异。直接看Prompt可能是最直观的理解方式:
RePhrase And Respond Distillation
Prompt示例:
"{question}"
Rephrase and expand the question, and respond.
让模型先把问题改写一遍——改写过程中会引入更丰富的上下文信息,然后再回答。这样模型能更好地用自己的知识内化问题、给出答案。
System2 Attention Distillation
让大模型先过滤掉输入中的无效信息——比如有偏见的描述、不相干的上下文——然后再在改写后的干净版本上回答问题。
Branch-Solve-Merge Distillation

Chain of Thought Distillation

论文对这四种System2方法逐一做蒸馏到System1的实验,结果就不一一贴了——毕竟效果不好也不会发论文。整体结论是:有效,但领域差异明显。比如RaR蒸馏在澄清指令相关的任务上表现出色,S2A蒸馏能有效处理有偏任务,Branch-Solve-Merge蒸馏则适合作为LLM-Judge评估任务。但一个尴尬的现实是:在复杂推理任务上的蒸馏,目前效果还不太理想。这恐怕也是整个行业的共识,值得持续投入研究。
说到底,无论是黑盒蒸馏还是白盒蒸馏,今天这项技术都是在干一件事:把更大模型的知识密度压缩到更小的模型里去。资源有限但需求旺盛的现实下,这可能是让大模型技术真正落地的最务实路径之一。期待这个方向后续会有更多有意思的突破。
