先搞清楚一个问题:为什么需要 Transformer?
在 Transformer 问世之前,NLP 领域基本由 RNN 和 LSTM 主导。然而,它们存在一个致命的缺陷——只能按顺序逐个词地处理文本。

设想这样一个场景:你正在阅读一句话。当 RNN 处理到"烤鸭"这个词时,它实际上已经"经过"了好几个词,对前面提到的"北京"印象早已变得模糊。虽然 LSTM 通过门控机制在一定程度上缓解了这一问题,但本质上仍然是逐个词按顺序处理。
Transformer 的解决方案则截然不同:它让每个词能够同时看到句子中的所有词,想关注哪些词就关注哪些词,一步到位。这正是 Self-Attention(自注意力机制)的核心思想。
一、Self-Attention:让每个词都具备全局视野
1.1 核心直觉
自注意力要解决的核心问题是:在一个序列中,每个元素究竟应该"关注"哪些其他元素?
通俗来说,当模型处理到"它"这个代词时,必须明确这个"它"指的是"猫"还是"垫子"。人类一眼就能看出 "it" 指代的是 "cat",而自注意力机制的目标,就是让模型习得这种语言理解能力。
1.2 Q、K、V 三剑客
自注意力机制的核心是三个矩阵:Query(查询)、Key(键)、Value(值)。
这个概念其实非常容易理解,拿搜索引擎来打个比方:
- Query:就像你在搜索框输入的关键词——"我想找什么?"
- Key:好比每篇文章的标题或标签——"我是什么?"
- Value:就是文章的实质内容——"这是我的信息"
整个匹配过程描述起来也很直观:先用 Query 与所有 Key 计算相似度,相似度高的 Value 就获得更多关注。
输入序列: [词1, 词2, 词3, ...]
↓ (乘以三个权重矩阵)
Q矩阵 K矩阵 V矩阵
↓ ↓ ↓
Q1,Q2.. K1,K2.. V1,V2..
↓
Q·K^T → 注意力分数 → softmax → 加权求和 Value → 输出
1.3 数学公式(别担心,其实很直观)
Scaled Dot-Product Attention 的公式非常经典:
Attention(Q, K, V) = softmax(QK^T / √d_k) · V
import torch
import torch.nn.functional as F
import math
def self_attention(Q, K, V):
"""
Q: (seq_len, d_k) 查询矩阵
K: (seq_len, d_k) 键矩阵
V: (seq_len, d_v) 值矩阵
"""
d_k = Q.size(-1)
# 第一步:Q 和 K 做点积,得到注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1))
# 第二步:除以 √d_k(缩放因子),防止分数过大导致 softmax 梯度消失
scores = scores / math.sqrt(d_k)
# 第三步:softmax 归一化,得到注意力权重(和为1)
attention_weights = F.softmax(scores, dim=-1)
# 第四步:用注意力权重对 V 加权求和
output = torch.matmul(attention_weights, V)
return output, attention_weights
1.4 一个直观的例子
假设我们要处理一句话:["我", "爱", "AI"]
经过 self-attention 后,可能会形成这样的注意力权重:
注意力权重矩阵(softmax后):
我 爱 AI
我 [0.7, 0.2, 0.1] ← "我"主要关注自己
爱 [0.3, 0.4, 0.3] ← "爱"同时关注"我"和"AI"
AI [0.1, 0.3, 0.6] ← "AI"主要关注自己和"爱"
看到没有?模型自动学会了"爱"这个动词连接了"我"和"AI"这个语义关系。
二、Multi-Head Attention:多视角观察,信息更丰富
2.1 为什么需要多头?
单头注意力只有一组 Q、K、V,这意味着每个词只能从一个维度去关注其他词。然而语言的复杂性远超想象——一个词可能同时与多个词存在关联,而且关联的类型也各不相同。多头注意力的思路就是:使用多组 Q、K、V,让模型从不同的"角度"同时捕捉信息。
2.2 结构图解
输入 X
│
├──→ Linear → Q₁, K₁, V₁ → Scaled Attention → Head₁ ──┐
├──→ Linear → Q₂, K₂, V₂ → Scaled Attention → Head₂ ──┤
├──→ Linear → Q₃, K₃, V₃ → Scaled Attention → Head₃ ──┼→ Concat → Linear → 输出
├──→ ... │
└──→ Linear → Q₈, K₈, V₈ → Scaled Attention → Head₈ ──┘
原论文采用了 8 个头(h=8),每个头的维度为 d_model/h。例如 d_model=512,则每个头的维度是 64。
2.3 代码实现
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, n_heads=8):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads # 每个头的维度 = 64
# Q、K、V 的线性变换层
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# 输出的线性变换层
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x):
batch_size, seq_len, _ = x.size()
# 1. 线性变换得到 Q、K、V
Q = self.W_q(x) # (batch, seq_len, d_model)
K = self.W_k(x)
V = self.W_v(x)
# 2. 拆分成多个头
Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
# 3. 计算缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = torch.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 4. 拼接所有头
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
# 5. 最终线性变换
output = self.W_o(attn_output)
return output
2.4 多头究竟学到了什么?
不同的头会聚焦不同类型的信息:
| 头编号 | 可能关注的模式 | 例子 |
|---|---|---|
| Head 1 | 相邻词的关系 | "The cat sat" 中的相邻词 |
| Head 2 | 语法依赖 | "The cat that ate the fish" 中的从句关系 |
| Head 3 | 指代关系 | "The cat because it was tired" 中的"it"指代"cat" |
| Head 4 | 位置关系 | 句首词与句尾词的远距离关联 |
这就是多头的精妙之处——不事先指定关注什么,而是让模型自主学习和发现。
三、Positional Encoding:为词语赋予"位置坐标"
3.1 一个关键问题
前面提到,Transformer 是并行处理所有词的。这带来了一个严峻挑战:缺乏顺序信息!
没错,你没有听错。如果将 "狗咬人" 和 "人咬狗" 分别输入 Transformer,如果不添加位置信息,模型会认为这两句话是完全相同的。RNN/LSTM 天然具有顺序性(逐个词处理),因此不存在这个问题。但 Transformer 通过并行计算换来的效率提升,需要以丢失位置信息为代价。
解决方案说起来也很直接:手动为每个词添加一个"位置编码"。
3.2 位置编码如何实现?
原论文采用了 正弦-余弦位置编码(Sinusoidal Positional Encoding):
位置编码的公式如下:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
对应代码实现:
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
def __init__(self, d_model=512, max_len=5000):
super().__init__()
# 创建位置编码矩阵
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float() # (max_len, 1)
# 计算分母项: 10000^(2i/d_model)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
# 偶数位置用 sin,奇数位置用 cos
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度
# (1, max_len, d_model) 方便广播
pe = pe.unsqueeze(0)
# 注册为 buffer(不参与梯度更新,但会保存到模型中)
self.register_buffer('pe', pe)
def forward(self, x):
# x: (batch, seq_len, d_model)
# 把位置编码加到输入上
x = x + self.pe[:, :x.size(1), :]
return x
3.3 为什么选用 sin/cos?
你可能会好奇:为什么不用简单的 0, 1, 2, 3... 来表示位置?
原因有三点:
① 能够处理比训练序列更长的文本
sin/cos 是周期函数,天然支持外推。相比之下,学习得到的位置编码(例如 BERT 所使用的)在处理超长文本时效果不太理想。
② 每个位置拥有独一无二的编码
不同位置的正弦波组合是唯一的,就像指纹一样,不会产生混淆。
③ 模型可以学到相对位置关系
这一点是 sin/cos 编码最精妙的地方。通过三角函数公式:
sin(α + β) = sin(α)cos(β) + cos(α)sin(β)
模型可以通过线性变换,从绝对位置编码中推导出相对位置关系。
3.4 位置编码可视化
位置编码热力图(每个位置 × 每个维度):
维度 →
0 1 2 3 4 5 6 7
位置 ↓
0 [ 0.01 1.00 0.01 1.00 0.01 1.00 0.01 1.00 ]
1 [ 0.84 0.54 0.01 1.00 0.01 1.00 0.01 1.00 ]
2 [ 0.91 -0.42 0.02 1.00 0.01 1.00 0.01 1.00 ]
3 [ 0.14 -0.99 0.03 1.00 0.01 1.00 0.01 1.00 ]
...
50 [... 变化越来越慢的低频部分 ...]
可以观察到清晰的规律:低维度变化较快(高频信号),高维度变化较慢(低频信号)。这种多尺度的设计让模型能够同时捕捉局部和全局的位置关系。
四、三大组件如何协同工作?
把上面三个部分串联起来,Transformer 编码器(Encoder)一层的完整流程如下:
输入词向量 (Word Embedding)
│
▼
┌─────────────────────┐
│ Positional Encoding │ ← 添加位置信息
└────────┬────────────┘
▼
┌──────────────────────┐
│ Multi-Head Attention │ ← 自注意力:词与词之间交互
└────────┬─────────────┘
▼
┌───────────────────────────┐
│ Add & Norm (残差+层归一化) │
└────────┬──────────────────┘
▼
┌──────────────────────┐
│ Feed Forward Network │ ← 全连接层:逐位置变换
└────────┬─────────────┘
▼
┌───────────────────────────┐
│ Add & Norm (残差+层归一化) │
└────────┬──────────────────┘
▼
输出到下一层
用代码串联起来就是:
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 1. Multi-Head Self-Attention + 残差连接 + LayerNorm
attn_out = self.self_attn(x)
x = self.norm1(x + self.dropout(attn_out))
# 2. Feed Forward + 残差连接 + LayerNorm
ffn_out = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_out))
return x
五、速查对照表
| 组件 | 解决什么问题 | 核心思想 | 一句话记忆 |
|---|---|---|---|
| Self-Attention | 词与词之间如何交互 | Q·K^T 计算相似度,加权获取 V | "你跟谁最相关?" |
| Multi-Head Attention | 单一视角不够全面 | 多组 QKV 并行,多角度观察 | "多看几眼,看得更全" |
| Positional Encoding | 并行计算丢失位置信息 | sin/cos 注入位置信息 | "告诉模型你在哪" |
