Transformer 模型结构详解:从词嵌入到解码器
在逐一梳理了各个核心组件之后,接下来将它们完整组装起来,便构成了 Transformer 模型的整体架构。上图源自论文《Attention is all you need》的经典配图,但需要特别说明的是:图中 LayerNorm 放置在 Attention 之后,即所谓的“Post-Norm”结构;而在论文官方发布的源码中,LayerNorm 实际上位于 Attention 之前,即采用的“Pre-Norm”方案。从工程实践来看,Pre-Norm 能够使 loss 更加稳定,因此当前主流的大语言模型(LLM)普遍采用 Pre-Norm 策略——先完成层归一化,再进入 Attention 层处理,这样输入信号的稳定性更高。相比之下,Post-Norm 容易导致 Attention 层的输出波动较大。
class Transformer(nn.Module): '''整体模型'''def __init__(self, args):super().__init__()# 必须输入词表大小和 block sizeassert args.vocab_size is not Noneassert args.block_size is not Noneself.args = argsself.transformer = nn.ModuleDict(dict(wte = nn.Embedding(args.vocab_size, args.n_embd),wpe = PositionalEncoding(args),drop = nn.Dropout(args.dropout),encoder = Encoder(args),decoder = Decoder(args),))# 最后的线性层,输入是 n_embd,输出是词表大小self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)# 初始化所有的权重self.apply(self._init_weights)# 查看所有参数的数量print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))'''统计所有参数的数量'''def get_num_params(self, non_embedding=False):# non_embedding: 是否统计 embedding 的参数n_params = sum(p.numel() for p in self.parameters())# 如果不统计 embedding 的参数,就减去if non_embedding:n_params -= self.transformer.wte.weight.numel()return n_params'''初始化权重'''def _init_weights(self, module):# 线性层和 Embedding 层初始化为正则分布if isinstance(module, nn.Linear):torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)if module.bias is not None:torch.nn.init.zeros_(module.bias)elif isinstance(module, nn.Embedding):torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)'''前向计算函数'''def forward(self, idx, targets=None):# 输入为 idx,维度为 (batch size, sequence length, 1);targets 为目标序列,用于计算 lossdevice = idx.deviceb, t = idx.size()assert t <= self.args.block_size, f"不能计算该序列,该序列长度为 {t}, 最大序列长度只有 {self.args.block_size}"# 通过 self.transformer# 首先将输入 idx 通过 Embedding 层,得到维度为 (batch size, sequence length, n_embd)print("idx",idx.size())# 通过 Embedding 层tok_emb = self.transformer.wte(idx)print("tok_emb",tok_emb.size())# 然后通过位置编码pos_emb = self.transformer.wpe(tok_emb) # 再进行 Dropoutx = self.transformer.drop(pos_emb)# 然后通过 Encoderprint("x after wpe:",x.size())enc_out = self.transformer.encoder(x)print("enc_out:",enc_out.size())# 再通过 Decoderx = self.transformer.decoder(x, enc_out)print("x after decoder:",x.size())if targets is not None:# 训练阶段,如果我们给了 targets,就计算 loss# 先通过最后的 Linear 层,得到维度为 (batch size, sequence length, vocab size)logits = self.lm_head(x)# 再跟 targets 计算交叉熵loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)else:# 推理阶段,我们只需要 logits,loss 为 None# 取 -1 是只取序列中的最后一个作为输出logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dimloss = Nonereturn logits, loss
上述代码完整实现了 Transformer 的 __init__、forward 等核心方法,并包含了参数统计与权重初始化功能。下面我们将逐一拆解其关键结构。
=== Transformer整体架构概览 ===1. Transformer的核心组成部分:输入层: - wte: 词嵌入层(Token Embedding) - wpe: 位置编码层(Positional Encoding) - drop: Dropout层编码器: - encoder: 编码器(Encoder)解码器: - decoder: 解码器(Decoder)输出层: - lm_head: 线性层(输出词表概率)2. 数据流动路径:输入 idx (batch_size, seq_len) ↓ wte: 词嵌入 (batch_size, seq_len, n_embd) ↓ wpe: 位置编码 (batch_size, seq_len, n_embd) ↓ drop: Dropout (batch_size, seq_len, n_embd) ↓ encoder: 编码器 (batch_size, seq_len, n_embd) ↓ decoder: 解码器 (batch_size, seq_len, n_embd) ↓ lm_head: 线性层 (batch_size, seq_len, vocab_size) ↓ logits (batch_size, seq_len, vocab_size)
首先观察 __init__ 方法。第一步是参数校验,确保 vocab_size(词表大小)和 block_size(最大序列长度)均已正确配置。随后使用 ModuleDict 按字典形式创建各组件:wte(词嵌入层)、wpe(位置编码层)、drop(Dropout 层)、encoder(编码器)与 decoder(解码器)。紧跟其后的是一个线性层 lm_head,其作用是将 Transformer 的隐状态映射为词表大小的概率分布,且 bias 参数设置为 False。所有可学习参数均采用正态分布初始化(mean=0.0, std=0.02)。最后打印总参数量,通常以百万(M)为单位呈现。
=== __init__方法详细解析 ===1. 参数检查:assert args.vocab_size is not None assert args.block_size is not None作用: - 确保 vocab_size(词汇表大小)已设置 - 确保 block_size(最大序列长度)已设置 - 若未设置,会触发 AssertionError 报错2. 构建组件:self.transformer = nn.ModuleDict(dict( wte = nn.Embedding(args.vocab_size, args.n_embd), wpe = PositionalEncoding(args), drop = nn.Dropout(args.dropout), encoder = Encoder(args), decoder = Decoder(args), ))各组件说明: - wte: 词嵌入层 * 输入: token索引 (batch_size, seq_len) * 输出: 词向量 (batch_size, seq_len, n_embd) * 参数量: vocab_size × n_embd- wpe: 位置编码层 * 输入: 词向量 (batch_size, seq_len, n_embd) * 输出: 添加位置编码后的向量 (batch_size, seq_len, n_embd)- drop: Dropout层 * 输入: 向量 (batch_size, seq_len, n_embd) * 输出: 经 Dropout 后的向量 (batch_size, seq_len, n_embd) * 作用: 防止过拟合- encoder: 编码器 * 输入: 向量 (batch_size, seq_len, n_embd) * 输出: 编码后的特征向量 (batch_size, seq_len, n_embd)- decoder: 解码器 * 输入: 向量 + 编码器输出 * 输出: 解码后的向量 (batch_size, seq_len, n_embd)3. 输出层:self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)说明: - 输入: 解码器输出 (batch_size, seq_len, n_embd) - 输出: 词表概率 (batch_size, seq_len, vocab_size) - bias=False: 不包含偏置项 - 参数量: n_embd × vocab_size4. 权重初始化:self.apply(self._init_weights)说明: - 对所有线性层与 Embedding 层执行初始化 - 采用正态分布,mean=0.0, std=0.025. 参数统计:print('number of parameters: %.2fM' % (self.get_num_params()/1e6,))说明: - 统计全部参数数量 - 除以 1e6 转换为百万单位(M) - 例如: 10M 表示 1000 万个参数
接下来分析 forward 方法。输入 idx 为 token 索引序列,形状为 (batch_size, seq_len)。若提供了 targets,则在训练阶段计算交叉熵损失;推理阶段仅保留最后一个位置的 logits 用于生成。整个数据流动流程清晰:词嵌入 → 位置编码 → Dropout → 编码器 → 解码器 → 输出层。
=== forward方法详细解析 ===1. 输入参数:def forward(self, idx, targets=None):参数: - idx: 输入序列 * 形状: (batch_size, seq_len) * 内容: token索引 - targets: 目标序列(可选) * 形状: (batch_size, seq_len) * 内容: 目标token索引 * 用途: 计算训练损失2. 参数检查:device = idx.device b, t = idx.size() assert t <= self.args.block_size说明: - device: 获取当前设备(CPU或GPU) - b: batch_size(批次大小) - t: seq_len(序列长度) - 断言序列长度不超过最大限制3. 词嵌入:tok_emb = self.transformer.wte(idx)数据变化: - 输入: idx (batch_size, seq_len) - 输出: tok_emb (batch_size, seq_len, n_embd)示例: - idx: [[1, 2, 3], [4, 5, 6]] - tok_emb: [[[0.1, 0.2, ...], [0.3, 0.4, ...], [0.5, 0.6, ...]], ...] - 每个 token 索引被转换为 n_embd 维的稠密向量4. 位置编码:pos_emb = self.transformer.wpe(tok_emb)数据变化: - 输入: tok_emb (batch_size, seq_len, n_embd) - 输出: pos_emb (batch_size, seq_len, n_embd)作用: - 为每个位置注入位置信息 - 实现方式: tok_emb + 位置编码向量 = pos_emb5. Dropout:x = self.transformer.drop(pos_emb)数据变化: - 输入: pos_emb (batch_size, seq_len, n_embd) - 输出: x (batch_size, seq_len, n_embd)作用: - 随机丢弃部分神经元 - 有效防止过拟合6. 编码器:enc_out = self.transformer.encoder(x)数据变化: - 输入: x (batch_size, seq_len, n_embd) - 输出: enc_out (batch_size, seq_len, n_embd)作用: - 对输入序列进行编码 - 提取深层次特征7. 解码器:x = self.transformer.decoder(x, enc_out)数据变化: - 输入1: x (batch_size, seq_len, n_embd) - 输入2: enc_out (batch_size, seq_len, n_embd) - 输出: x (batch_size, seq_len, n_embd)作用: - 解码输入序列 - 融合编码器的输出信息8. 输出层:if targets is not None: # 训练阶段 logits = self.lm_head(x) loss = F.cross_entropy(...) else: # 推理阶段 logits = self.lm_head(x[:, [-1], :]) loss = None训练阶段: - 输入: x (batch_size, seq_len, n_embd) - 输出: logits (batch_size, seq_len, vocab_size) - 计算损失: 交叉熵(Cross-Entropy)推理阶段: - 输入: x[:, [-1], :] (batch_size, 1, n_embd) - 仅取最后一个时间步 - 输出: logits (batch_size, 1, vocab_size) - loss = None9. 返回值:return logits, losslogits: 词表概率分布 loss: 损失值(训练阶段有效,推理阶段为 None)
为了更加直观,下面给出一个完整示例:batch_size=2,seq_len=5,vocab_size=10000,n_embd=512。数据经过每一步的维度变化如下。
=== 完整数据流动示例 ===1. 参数设置: batch_size: 2 seq_len: 5 vocab_size: 10000 n_embd: 5122. 数据流动过程:步骤1: 输入 - idx: (2, 5) - 例如: [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]步骤2: 词嵌入 - tok_emb: (2, 5, 512) - 每个 token 索引映射为 512 维向量步骤3: 位置编码 - pos_emb: (2, 5, 512) - 在词向量上叠加位置编码步骤4: Dropout - x: (2, 5, 512) - 随机丢弃部分神经元步骤5: 编码器 - enc_out: (2, 5, 512) - 编码输入序列步骤6: 解码器 - x: (2, 5, 512) - 解码的同时融合编码器信息步骤7: 输出层 - logits: (2, 5, 10000) - 每个位置输出词表概率3. 训练模式 vs 推理模式:训练阶段: - 输出: logits (2, 5, 10000) - loss: 交叉熵损失值 - 用途: 反向传播更新模型参数推理阶段: - 输出: logits (2, 1, 10000) - loss: None - 用途: 生成下一个词4. 关键组件梳理:- wte: 词嵌入层,将 token 索引转换为连续向量 - wpe: 位置编码层,为序列注入位置信息 - encoder: 编码器,提取特征 - decoder: 解码器,生成输出 - lm_head: 输出层,输出词表概率分布- 训练时: 输出全部位置的 logits 并计算损失 - 推理时: 仅输出最后一个位置的 logits 用于逐词生成
最后将数据流动的流程图再次列出,便于对照查阅。
输入 idx (batch_size, seq_len)↓wte: 词嵌入 (batch_size, seq_len, n_embd)↓wpe: 位置编码 (batch_size, seq_len, n_embd)↓drop: Dropout (batch_size, seq_len, n_embd)↓encoder: 编码器 (batch_size, seq_len, n_embd)↓decoder: 解码器 (batch_size, seq_len, n_embd)↓lm_head: 线性层 (batch_size, seq_len, vocab_size)↓logits (batch_size, seq_len, vocab_size)
Transformer 关键组件解析
训练与推理的差异:训练阶段计算所有位置上的损失,推理阶段则仅取最后一个位置的输出用于生成。
各模块核心作用:
1、词嵌入层:将输入的 token 索引映射为稠密向量表示
2、位置编码:为序列中各 token 提供位置信息
3、编码器:通过自注意力与前馈网络提取深层特征
4、解码器:基于编码器输出与自身注意力生成目标序列
5、输出层:将隐状态映射为词表大小的概率分布
6、权重初始化:采用正态分布(均值0,标准差0.02)
