游乐游手机版
首页/AI教程/文章详情

从零复现Attention Is All You Need模型

时间:2026-06-02 18:46
2017年,Google那篇《Attention Is All You Need》横空出世,直接把Transformer推到了前台。这篇文章做了一个在当时相当激进的决定——彻底抛弃RNN和LSTM的递归结构,全靠一种叫Attention的机制来捕捉语义关系。结果我们都看到了:GPT、BERT、LLa

2017年,Google那篇《Attention Is All You Need》横空出世,直接把Transformer推到了前台。这篇文章做了一个在当时相当激进的决定——彻底抛弃RNN和LSTM的递归结构,全靠一种叫Attention的机制来捕捉语义关系。结果我们都看到了:GPT、BERT、LLaMA……几乎整个大模型时代都是从这里开始的。

这篇文章会带着你,用一份干净的、大约350行的纯PyTorch代码,把Transformer的每一个组件拆开来看。不是泛泛地讲概念,是切切实实从零搭到完整模型,代码可以直接跑,非常适合拿来理解原理、调试或者自己二次修改。

1. 组件一览

在深入代码之前,我们先快速扫一眼整个架构的核心模块,知道每个部分干的活和大概的计算量,心里有个底。

组件功能复杂度
Scaled Dot-Product AttentionQ/K/V 相似度计算与聚合O(n·dₖ)
Multi-Head Attention多个表示空间并行注意力O(n·d_model)
Position-wise FFN每个位置非线性变换O(d_model·d_ff)
Positional Encoding引入位置信息O(max_len·d_model)
Layer Norm维度规范化O(d_model)
Encoder Layer自注意力 + FFNO(n²·d_model)
Decoder Layer带掩码 + 交叉注意力O(n²·d_model)

2. Scaled Dot-Product Attention

公式与直观

Scaled Dot-Product Attention是Transformer里最核心的计算单元。它的本质说白了就一句话:用“查询”(Query)和“键”(Key)的相似度,来决定怎么去加权聚合“值”(Value)。

Attention(Q,K,V)=softmax(QKdk)V

为什么要除以根号dk?这个问题其实很重要。当dk比较大的时候,点积的结果会随着维度增加而变大,这会把softmax推到非常极端的区域,梯度基本就消失了。缩放一下,让方差保持稳定,训练才能稳得住。

代码解析

class ScaledDotProductAttention(nn.Module):
    """Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V"""
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights

这里有几个关键点值得留意:

  • masked_fill会把padding位置的值变成负无穷,softmax之后权重就变成0了,不会影响输出结果。
  • 计算完attention权重之后才做dropout,这是一个很重要的正则化手段。
  • 函数里同时返回了attn_weights,主要就是为了可视化调试用的。

这个模块的计算瓶颈在QK转置的矩阵乘法上,复杂度是O(n²·dk),这里的n就是序列长度。这也是整个Transformer最吃性能的地方。

3. Multi-Head Attention

从单头到多头

单头注意力有一个天然的局限:它只能在一种表示空间里看问题。Multi-Head Attention的做法是把Q、K、V分别投影到h个不同的表示空间里,各自独立算一遍注意力,最后把结果拼起来,再投影回到原来的维度。

MultiHead(Q,K,V)=Concat(head1,...,headh)WO

代码实现里有一个很聪明的设计决策:先一次性投影,再拆成多个头。数学上定义h个独立的线性层是等价的,但代码里只需要4个线性层而不是3h个,效率高得多。

代码解析

class MultiHeadAttention(nn.Module):
    """MultiHead(Q, K, V) = Concat(head_1, ..., head_h) @ W_O
    其中 head_i = Attention(Q @ W_Q_i, K @ W_K_i, V @ W_V_i)"""
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model 必须能被 n_heads 整除"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        self.attention = ScaledDotProductAttention(dropout)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        # 1) 线性投影 → (batch, seq_len, d_model)
        Q = self.W_Q(Q)
        K = self.W_K(K)
        V = self.W_V(V)
        # 2) 拆成多头 → (batch, n_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # 3) Scaled Dot-Product Attention
        attn_output, attn_weights = self.attention(Q, K, V, mask)
        # 4) 拼接多头 → (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # 5) 最终线性投影
        output = self.W_O(attn_output)
        return output

几个值得注意的点:

  • d_model % n_heads == 0这个断言保证了每个head能分到整数维度。
  • view + transpose这个操作就像"重新排片"一样,先把维度拆开再转置,让head维度排到前面。
  • .contiguous()这一步不能漏,transpose只是改了视图的内存布局,不调contiguous的话,后续的view会报错。
  • 最后的W_O是拼接后的最终投影,它负责把多头信息融合回到d_model维度。

mask处理这里有个小技巧:mask用(batch, 1, 1, seq_len)的格式,可以直接和scores的(batch, n_heads, seq_len, seq_len)做广播,省掉了额外的unsqueeze操作。

4. Position-wise Feed-Forward Network

非线性变换与容量

每个位置的表示在经过注意力层之后,还要过一个两层的全连接网络。这个FFN是position-wise的,意思是它对序列里的每个位置独立应用同样的参数,效果相当于卷积核大小为1的卷积。

FFN(x)=ReLU(xW1+b1)W2+b2

代码解析

class PositionWiseFeedForward(nn.Module):
    """FFN(x) = ReLU(x @ W_1 + b_1) @ W_2 + b_2
    内部维度从 d_model → d_ff → d_model"""
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

这里的设计思路也很清晰:

  • 内部的d_ff通常比d_model大得多,论文里用的是512到2048的扩展,这一下子就给了非线性变换充足的容量。
  • dropout放在ReLU之后、第二次线性投影之前,这是目前比较主流的做法。
  • 原始论文用的是ReLU,后来像GPT这些工作更多用GELU,这是一个值得留意的演进细节。

为什么偏偏是两层?论文里的实验表明,一层表达能力确实不够,但三层以上的收益又微乎其微。两层就是性能和资源之间最优的那个平衡点。

5. Positional Encoding

为序列引入位置信息

Self-Attention有一个特性一定会让你意外——它是对位置完全不敏感的。不管你把序列里的元素怎么打乱,输出结果都一样。为了引入位置信息,原始论文用的是正余弦编码。

PE(pos,2i)=sin(pos100002i/dmodel)

PE(pos,2i+1)=cos(pos100002i/dmodel)

为什么用正余弦而不用可学习的位置嵌入?这里有三个很实在的理由:

  1. 它可以处理比训练时更长的序列,也就是有外推能力。
  2. 不需要额外参数,省内存。
  3. 相对位置信息通过线性变换就能表达出来——因为sin(α+Δ) = sinα·cosΔ + cosα·sinΔ,这里天然存在线性关系。

代码解析

class PositionalEncoding(nn.Module):
    """PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))"""
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维度
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维度
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

代码里有几个细节值得拿出来说说:

  • div_term用了指数形式而不是直接算10000的2i/d_model次方,这是为了数值稳定性。
  • register_buffer这一步让pe可以跟着模型一起移动到CPU或GPU上,但不会作为参数被优化器更新。
  • forward里直接把pe和输入相加,通过broadcast机制对齐维度,这是最经典的做法了。

6. Layer Normalization

维度规范化

Layer Normalization做的事情是:对每个样本的所有维度做一次变换——减去均值,除以标准差,然后再做一个可学习的线性变换。跟Batch Norm不同,LN不依赖batch的大小,处理变长序列的时候要稳定得多。

LayerNorm(x)=γxμσ2+ϵ+β

代码解析

class LayerNorm(nn.Module):
    """LayerNorm(x) = gamma * (x - mean) / sqrt(var + eps) + beta
    手写版,方便理解;实际可直接用 nn.LayerNorm"""
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

这段代码虽然手写了一遍,但实际用的时候直接用nn.LayerNorm就行。不过这个手写版能让你看得更清楚:unbiased=False用的是样本标准差而不是无偏估计,这和原始论文的做法一致。eps取1e-6是为了防止除零。

7. Encoder Layer

网络中的网络单元

Encoder层是Transformer的基本构建块。每层包含两个子层:多头自注意力和FFN,每个子层后面紧跟着一个残差连接和层规范化。

x → MultiHead Self-Attention → Add & Norm → FFN → Add & Norm

代码解析

class EncoderLayer(nn.Module):
    """一个 Encoder 层:x → MultiHead Self-Attention → Add & Norm → FFN → Add & Norm"""
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-Attention + Add & Norm
        attn_output = self.self_attn(x, x, x, mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        # FFN + Add & Norm
        ffn_output = self.ffn(x)
        x = x + self.dropout2(ffn_output)
        x = self.norm2(x)
        return x

几个要点:

  • 残差连接(x + sublayer(x))是解决深层网络梯度消失的关键设计——梯度可以直接通过这一条shortcut回传。
  • 这里用的是Post-LN模式,先做残差再做规范化,和原始论文的做法保持一致。
  • self_attn的三个输入参数全都是x,这正好对应了"自注意力"的概念——Q、K、V都来自同一个序列。

8. Decoder Layer

带掩码的自注意力与交叉注意力

Decoder层比Encoder多了一个子层——Cross-Attention。在这个子层里,Encoder的输出作为K和V,Decoder的输入作为Q。同时,自注意力层需要用下三角mask来遮住后面的位置,防止信息泄露。

x → Masked Self-Attention → Add & Norm → Cross-Attention → Add & Norm → FFN → Add & Norm

代码解析

class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Self-Attention(带 look-ahead mask)
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        # Cross-Attention: Q 来自 Decoder, K/V 来自 Encoder
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = x + self.dropout2(attn_output)
        x = self.norm2(x)
        # FFN
        ffn_output = self.ffn(x)
        x = x + self.dropout3(ffn_output)
        x = self.norm3(x)
        return x

这里有几个明显的区别:

  • Self-Attention用了tgt_mask,也就是下三角mask来遮住未来的位置;Cross-Attention则用src_mask来过滤掉Encoder那边padding的位置。
  • Cross-Attention的K和V来自Encoder,Q来自Decoder——这是一种"引导"机制,Decoder每走一步都能看到输入序列的全部信息。
  • Decoder比Encoder多了整整一套残差连接和LayerNorm,总共3个。

9. 完整 Transformer

拼装成网络

最后一步就是把N层Encoder和N层Decoder叠起来,再加上嵌入层、位置编码和最终的分类头,一个完整的Transformer就拼装完成了。

src → Embedding → Positional Encoding → N × EncoderLayer
                                ↓
tgt → Embedding → Positional Encoding → N × DecoderLayer → Linear → output

代码解析

class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, n_heads=8, d_ff=2048, n_layers=6, dropout=0.1, max_len=5000):
        super().__init__()
        self.encoder_embed = nn.Embedding(src_vocab, d_model)
        self.decoder_embed = nn.Embedding(tgt_vocab, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.fc_out = nn.Linear(d_model, tgt_vocab)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Encoder
        src_emb = self.pos_encoding(self.encoder_embed(src))
        for layer in self.encoder_layers:
            src_emb = layer(src_emb, src_mask)
        # Decoder
        tgt_emb = self.pos_encoding(self.decoder_embed(tgt))
        for layer in self.decoder_layers:
            tgt_emb = layer(tgt_emb, src_emb, src_mask, tgt_mask)
        return self.fc_out(tgt_emb)

最后这几个设计点值得记住:

  • nn.ModuleList保证了每一层的参数都能被正确注册。
  • Encoder和Decoder各自有独立的嵌入层和位置编码,互不干扰。
  • fc_out把d_model投影到词表的大小,负责输出下一个token的概率分布。

10. 总结

这份从零开始的实现,覆盖了Transformer从Scaled Dot-Product Attention到完整Encoder-Decoder架构的所有核心组件。每一行代码背后都有明确的动机和设计思考。

把这些基础组件吃透之后,再去看GPT系列那种只用Decoder的做法,或者BERT系列只用Encoder的方案,以及LLaMA这些现代变体,就能很快抓住它们各自的设计决策——知道它们在哪一步做了什么取舍,为什么那么做。

来源:https://juejin.cn/post/7646272432072917043
上一篇2026年5月技术圈AI大模型、云原生与Rust生态复盘 下一篇如何用WPS AI编辑PDF提升文档效率
本站内容用于信息整理与展示,如有侵权或内容问题请及时联系处理。

相关推荐

补充同频道和同主题内容,方便继续浏览更多相关内容。

同类最新

继续查看同栏目最近更新的文章。

更多
2026实测解析GPT-5.5模型能力详解与国内合规使用规范
AI教程 · 2026-06-03

2026实测解析GPT-5.5模型能力详解与国内合规使用规范

2026年,AI大模型迎来了又一次迭代升级。GPT-5 5凭借在多模态精细化处理能力上的跨越式突破,正逐步成为职场办公、内容创作、代码开发以及数据优化等领域的核心生产力工具。然而,对国内多数用户而言,当前仍面临不少现实难题:渠道杂乱、合规边界模糊、账号频繁被封、数据泄露风险——各类非正规镜像站、共享

分时操作系统和实时操作系统的主要区别
AI教程 · 2026-06-03

分时操作系统和实时操作系统的主要区别

分时操作系统和实时操作系统区别 ?️ 操作系统家族里,有两类系统经常被放在一起比较:分时操作系统和实时操作系统。它们虽然都叫“操作系统”,但设计哲学、工作机制和应用场景可以说是天差地别。一个追求“公平共享”,一个追求“确定性响应”。这篇文章打算从定义、核心机制、调度策略、实际应用等维度,把这两者的本

企业AI智能体从零搭建实战踩坑经验全记录
AI教程 · 2026-06-03

企业AI智能体从零搭建实战踩坑经验全记录

去年开始用腾讯云智能体开发平台(ADP)跑了几个企业项目,从最基础的客服Bot一路干到多Agent协同系统,中间踩的坑不少,但积累下来的经验价值也相当可观。这篇文章就聊聊实际落地过程里的那些关键节点和教训,给同样在腾讯云上折腾AI Agent的朋友做个参考。为什么选腾讯云ADP而不是从零搭建做第一个

Selenium自动化测试入门:从环境搭建到首个可维护用例
AI教程 · 2026-06-03

Selenium自动化测试入门:从环境搭建到首个可维护用例

Selenium 入门的核心不在于记住多少 API,而在于把三件事想清楚:环境别装错版本、等待机制别用 sleep、用例结构别写成流水账。下面按照“装环境 → 跑通第一个脚本 → 理解等待 → 选对定位器 → 拆成 Page Object”的顺序走一遍,每一步都附上代码,踩过的坑直接标出来。 Sel

专业表格魔法师 QoderWork CN 让脏数据秒变仪表盘神器
AI教程 · 2026-06-03

专业表格魔法师 QoderWork CN 让脏数据秒变仪表盘神器

使用案例 今天聊聊怎么用阿里巴巴的 QoderWork CN 桌面应用智能体,把 Excel 里那堆乱糟糟的原始数据清洗干净,再做成可视化的看板。整个过程基本不需要写代码,全靠自然语言对话就能搞定。下面就用一个实际案例,把操作步骤拆开来讲。 步骤一:安装并注册 QoderWork CN 账号 先到