2017年,Google那篇《Attention Is All You Need》横空出世,直接把Transformer推到了前台。这篇文章做了一个在当时相当激进的决定——彻底抛弃RNN和LSTM的递归结构,全靠一种叫Attention的机制来捕捉语义关系。结果我们都看到了:GPT、BERT、LLaMA……几乎整个大模型时代都是从这里开始的。
这篇文章会带着你,用一份干净的、大约350行的纯PyTorch代码,把Transformer的每一个组件拆开来看。不是泛泛地讲概念,是切切实实从零搭到完整模型,代码可以直接跑,非常适合拿来理解原理、调试或者自己二次修改。
1. 组件一览
在深入代码之前,我们先快速扫一眼整个架构的核心模块,知道每个部分干的活和大概的计算量,心里有个底。
| 组件 | 功能 | 复杂度 |
|---|---|---|
| Scaled Dot-Product Attention | Q/K/V 相似度计算与聚合 | O(n·dₖ) |
| Multi-Head Attention | 多个表示空间并行注意力 | O(n·d_model) |
| Position-wise FFN | 每个位置非线性变换 | O(d_model·d_ff) |
| Positional Encoding | 引入位置信息 | O(max_len·d_model) |
| Layer Norm | 维度规范化 | O(d_model) |
| Encoder Layer | 自注意力 + FFN | O(n²·d_model) |
| Decoder Layer | 带掩码 + 交叉注意力 | O(n²·d_model) |
2. Scaled Dot-Product Attention
公式与直观
Scaled Dot-Product Attention是Transformer里最核心的计算单元。它的本质说白了就一句话:用“查询”(Query)和“键”(Key)的相似度,来决定怎么去加权聚合“值”(Value)。
为什么要除以根号dk?这个问题其实很重要。当dk比较大的时候,点积的结果会随着维度增加而变大,这会把softmax推到非常极端的区域,梯度基本就消失了。缩放一下,让方差保持稳定,训练才能稳得住。
代码解析
class ScaledDotProductAttention(nn.Module):
"""Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V"""
def __init__(self, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
def forward(self, Q, K, V, mask=None):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
output = torch.matmul(attn_weights, V)
return output, attn_weights
这里有几个关键点值得留意:
masked_fill会把padding位置的值变成负无穷,softmax之后权重就变成0了,不会影响输出结果。- 计算完attention权重之后才做dropout,这是一个很重要的正则化手段。
- 函数里同时返回了
attn_weights,主要就是为了可视化调试用的。
这个模块的计算瓶颈在QK转置的矩阵乘法上,复杂度是O(n²·dk),这里的n就是序列长度。这也是整个Transformer最吃性能的地方。
3. Multi-Head Attention
从单头到多头
单头注意力有一个天然的局限:它只能在一种表示空间里看问题。Multi-Head Attention的做法是把Q、K、V分别投影到h个不同的表示空间里,各自独立算一遍注意力,最后把结果拼起来,再投影回到原来的维度。
代码实现里有一个很聪明的设计决策:先一次性投影,再拆成多个头。数学上定义h个独立的线性层是等价的,但代码里只需要4个线性层而不是3h个,效率高得多。
代码解析
class MultiHeadAttention(nn.Module):
"""MultiHead(Q, K, V) = Concat(head_1, ..., head_h) @ W_O
其中 head_i = Attention(Q @ W_Q_i, K @ W_K_i, V @ W_V_i)"""
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % n_heads == 0, "d_model 必须能被 n_heads 整除"
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
self.attention = ScaledDotProductAttention(dropout)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 1) 线性投影 → (batch, seq_len, d_model)
Q = self.W_Q(Q)
K = self.W_K(K)
V = self.W_V(V)
# 2) 拆成多头 → (batch, n_heads, seq_len, d_k)
Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
# 3) Scaled Dot-Product Attention
attn_output, attn_weights = self.attention(Q, K, V, mask)
# 4) 拼接多头 → (batch, seq_len, d_model)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 5) 最终线性投影
output = self.W_O(attn_output)
return output
几个值得注意的点:
d_model % n_heads == 0这个断言保证了每个head能分到整数维度。view + transpose这个操作就像"重新排片"一样,先把维度拆开再转置,让head维度排到前面。.contiguous()这一步不能漏,transpose只是改了视图的内存布局,不调contiguous的话,后续的view会报错。- 最后的
W_O是拼接后的最终投影,它负责把多头信息融合回到d_model维度。
mask处理这里有个小技巧:mask用(batch, 1, 1, seq_len)的格式,可以直接和scores的(batch, n_heads, seq_len, seq_len)做广播,省掉了额外的unsqueeze操作。
4. Position-wise Feed-Forward Network
非线性变换与容量
每个位置的表示在经过注意力层之后,还要过一个两层的全连接网络。这个FFN是position-wise的,意思是它对序列里的每个位置独立应用同样的参数,效果相当于卷积核大小为1的卷积。
代码解析
class PositionWiseFeedForward(nn.Module):
"""FFN(x) = ReLU(x @ W_1 + b_1) @ W_2 + b_2
内部维度从 d_model → d_ff → d_model"""
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.linear2(self.dropout(F.relu(self.linear1(x))))
这里的设计思路也很清晰:
- 内部的d_ff通常比d_model大得多,论文里用的是512到2048的扩展,这一下子就给了非线性变换充足的容量。
- dropout放在ReLU之后、第二次线性投影之前,这是目前比较主流的做法。
- 原始论文用的是ReLU,后来像GPT这些工作更多用GELU,这是一个值得留意的演进细节。
为什么偏偏是两层?论文里的实验表明,一层表达能力确实不够,但三层以上的收益又微乎其微。两层就是性能和资源之间最优的那个平衡点。
5. Positional Encoding
为序列引入位置信息
Self-Attention有一个特性一定会让你意外——它是对位置完全不敏感的。不管你把序列里的元素怎么打乱,输出结果都一样。为了引入位置信息,原始论文用的是正余弦编码。
为什么用正余弦而不用可学习的位置嵌入?这里有三个很实在的理由:
- 它可以处理比训练时更长的序列,也就是有外推能力。
- 不需要额外参数,省内存。
- 相对位置信息通过线性变换就能表达出来——因为sin(α+Δ) = sinα·cosΔ + cosα·sinΔ,这里天然存在线性关系。
代码解析
class PositionalEncoding(nn.Module):
"""PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))"""
def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
代码里有几个细节值得拿出来说说:
div_term用了指数形式而不是直接算10000的2i/d_model次方,这是为了数值稳定性。register_buffer这一步让pe可以跟着模型一起移动到CPU或GPU上,但不会作为参数被优化器更新。- forward里直接把pe和输入相加,通过broadcast机制对齐维度,这是最经典的做法了。
6. Layer Normalization
维度规范化
Layer Normalization做的事情是:对每个样本的所有维度做一次变换——减去均值,除以标准差,然后再做一个可学习的线性变换。跟Batch Norm不同,LN不依赖batch的大小,处理变长序列的时候要稳定得多。
代码解析
class LayerNorm(nn.Module):
"""LayerNorm(x) = gamma * (x - mean) / sqrt(var + eps) + beta
手写版,方便理解;实际可直接用 nn.LayerNorm"""
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones(d_model))
self.beta = nn.Parameter(torch.zeros(d_model))
self.eps = eps
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
std = x.std(dim=-1, keepdim=True, unbiased=False)
return self.gamma * (x - mean) / (std + self.eps) + self.beta
这段代码虽然手写了一遍,但实际用的时候直接用nn.LayerNorm就行。不过这个手写版能让你看得更清楚:unbiased=False用的是样本标准差而不是无偏估计,这和原始论文的做法一致。eps取1e-6是为了防止除零。
7. Encoder Layer
网络中的网络单元
Encoder层是Transformer的基本构建块。每层包含两个子层:多头自注意力和FFN,每个子层后面紧跟着一个残差连接和层规范化。
x → MultiHead Self-Attention → Add & Norm → FFN → Add & Norm
代码解析
class EncoderLayer(nn.Module):
"""一个 Encoder 层:x → MultiHead Self-Attention → Add & Norm → FFN → Add & Norm"""
def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Self-Attention + Add & Norm
attn_output = self.self_attn(x, x, x, mask)
x = x + self.dropout1(attn_output)
x = self.norm1(x)
# FFN + Add & Norm
ffn_output = self.ffn(x)
x = x + self.dropout2(ffn_output)
x = self.norm2(x)
return x
几个要点:
- 残差连接(x + sublayer(x))是解决深层网络梯度消失的关键设计——梯度可以直接通过这一条shortcut回传。
- 这里用的是Post-LN模式,先做残差再做规范化,和原始论文的做法保持一致。
self_attn的三个输入参数全都是x,这正好对应了"自注意力"的概念——Q、K、V都来自同一个序列。
8. Decoder Layer
带掩码的自注意力与交叉注意力
Decoder层比Encoder多了一个子层——Cross-Attention。在这个子层里,Encoder的输出作为K和V,Decoder的输入作为Q。同时,自注意力层需要用下三角mask来遮住后面的位置,防止信息泄露。
x → Masked Self-Attention → Add & Norm → Cross-Attention → Add & Norm → FFN → Add & Norm
代码解析
class DecoderLayer(nn.Module):
def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.norm3 = LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
# Self-Attention(带 look-ahead mask)
attn_output = self.self_attn(x, x, x, tgt_mask)
x = x + self.dropout1(attn_output)
x = self.norm1(x)
# Cross-Attention: Q 来自 Decoder, K/V 来自 Encoder
attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
x = x + self.dropout2(attn_output)
x = self.norm2(x)
# FFN
ffn_output = self.ffn(x)
x = x + self.dropout3(ffn_output)
x = self.norm3(x)
return x
这里有几个明显的区别:
- Self-Attention用了
tgt_mask,也就是下三角mask来遮住未来的位置;Cross-Attention则用src_mask来过滤掉Encoder那边padding的位置。 - Cross-Attention的K和V来自Encoder,Q来自Decoder——这是一种"引导"机制,Decoder每走一步都能看到输入序列的全部信息。
- Decoder比Encoder多了整整一套残差连接和LayerNorm,总共3个。
9. 完整 Transformer
拼装成网络
最后一步就是把N层Encoder和N层Decoder叠起来,再加上嵌入层、位置编码和最终的分类头,一个完整的Transformer就拼装完成了。
src → Embedding → Positional Encoding → N × EncoderLayer
↓
tgt → Embedding → Positional Encoding → N × DecoderLayer → Linear → output
代码解析
class Transformer(nn.Module):
def __init__(self, src_vocab, tgt_vocab, d_model=512, n_heads=8, d_ff=2048, n_layers=6, dropout=0.1, max_len=5000):
super().__init__()
self.encoder_embed = nn.Embedding(src_vocab, d_model)
self.decoder_embed = nn.Embedding(tgt_vocab, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
self.fc_out = nn.Linear(d_model, tgt_vocab)
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
# Encoder
src_emb = self.pos_encoding(self.encoder_embed(src))
for layer in self.encoder_layers:
src_emb = layer(src_emb, src_mask)
# Decoder
tgt_emb = self.pos_encoding(self.decoder_embed(tgt))
for layer in self.decoder_layers:
tgt_emb = layer(tgt_emb, src_emb, src_mask, tgt_mask)
return self.fc_out(tgt_emb)
最后这几个设计点值得记住:
nn.ModuleList保证了每一层的参数都能被正确注册。- Encoder和Decoder各自有独立的嵌入层和位置编码,互不干扰。
fc_out把d_model投影到词表的大小,负责输出下一个token的概率分布。
10. 总结
这份从零开始的实现,覆盖了Transformer从Scaled Dot-Product Attention到完整Encoder-Decoder架构的所有核心组件。每一行代码背后都有明确的动机和设计思考。
把这些基础组件吃透之后,再去看GPT系列那种只用Decoder的做法,或者BERT系列只用Encoder的方案,以及LLaMA这些现代变体,就能很快抓住它们各自的设计决策——知道它们在哪一步做了什么取舍,为什么那么做。
