本文带你一步步理解 Transformer 中最核心的模块:多头注意力机制(Multi-Head Attention)。从原理到实现,配图 + 举例 + PyTorch 代码,一次性说清楚!
什么是 Multi-Head Attention?
简单说,多头注意力就是一种让模型在多个角度“看”一个序列的机制。
在自然语言中,一个词的含义往往依赖于上下文,比如:
“我把苹果给了她”
模型在处理“苹果”时,需要关注“我”“她”“给了”等词,多头注意力就是这样一种机制——从多个角度理解上下文关系。
Self-Attention 是什么?为什么还要多头?
在讲“多头”之前,咱们先回顾一下基础的 Self-Attention。
Self-Attention(自注意力)机制的目标是:
让每个词都能“关注”整个句子里的其他词,融合上下文。
它的核心步骤是:
-
对每个词生成 Query、Key、Value 向量
-
用 Query 和所有 Key 做点积,算出每个词对其他词的关注度(打分)
-
用 Softmax 得到权重,对 Value 加权平均,生成当前词的新表示
这样做的好处是:词的语义表示不再是孤立的,而是上下文相关的。
Self-Attention vs Multi-Head Attention
但问题是——单头 Self-Attention 视角有限。就像一个老师只能从一种角度讲课。
于是,Multi-Head Attention 应运而生!
特性 | Self-Attention(单头) | Multi-Head Attention(多头) |
---|---|---|
输入映射矩阵 | 一组 Q/K/V 线性变换 | 多组 Q/K/V,每个头一组 |
学习角度 | 单一视角 | 多角度并行理解 |
表达能力 | 有限 | 更丰富、强大 |
结构 | 简单 | 并行多个头 + 合并输出 |
一句话总结:
Multi-Head Attention = 多个不同“视角”的 Self-Attention 并行处理 + 合并结果
多头注意力:8个脑袋一起思考!
多头 = 多个“单头注意力”并行处理!
每个头使用不同的线性变换矩阵,所以能从不同视角处理数据:
-
第1个头可能专注短依赖(like 动词和主语)
-
第2个头可能专注实体关系(我 vs 她)
-
第3个头可能关注时间顺序(“给了”前后)
-
……共用同一个输入,学习到不同特征!
多头的步骤:
-
将输入向量(如512维)拆成多个头(比如8个,每个64维)
-
每个头独立进行 attention
-
所有头的输出拼接
-
再过一次线性变换,融合成最终输出
PyTorch 实现(简洁版)
我们来看下 PyTorch 中的简化实现:
import torch
import torch.nn as nn
import copydef clones(module, N):return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])def attention(query, key, value, mask=None, dropout=None):d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = torch.softmax(scores, dim=-1)if dropout:p_attn = dropout(p_attn)return torch.matmul(p_attn, value), p_attnclass MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):super().__init__()assert d_model % h == 0self.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.dropout = nn.Dropout(dropout)def forward(self, query, key, value, mask=None):if mask is not None:mask = mask.unsqueeze(1)nbatches = query.size(0)query, key, value = [lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for lin, x in zip(self.linears, (query, key, value))]x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)
举个例子:多头在实际模型中的作用
假设输入是句子:
"The animal didn't cross the street because it was too tired."
多头注意力的不同头可能会:
-
🧠 头1:关注“animal”和“it”之间的指代关系;
-
📐 头2:识别“because”和“tired”之间的因果联系;
-
📚 头3:注意句子的结构层次……
所以说,多头注意力本质上是一个“并行注意力专家系统”!
总结
项目 | 解释 |
---|---|
目的 | 提升模型表达能力,从多个角度理解输入 |
核心机制 | 将向量分头 → 每头独立 attention → 合并输出 |
技术关键 | view , transpose , matmul , softmax , 拼接线性层 |
推荐学习路径
-
🔹 理解 Self-Attention 的点积公式
-
🔹 搞懂
view
,transpose
等张量操作 -
🔹 看 Transformer 整体结构,关注每层作用