欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 汽车 > 维修 > 深入浅出 Multi-Head Attention:原理 + 例子 + PyTorch 实现

深入浅出 Multi-Head Attention:原理 + 例子 + PyTorch 实现

2025/4/23 11:55:21 来源:https://blog.csdn.net/murphymeng2001/article/details/147320316  浏览:    关键词:深入浅出 Multi-Head Attention:原理 + 例子 + PyTorch 实现

本文带你一步步理解 Transformer 中最核心的模块:多头注意力机制(Multi-Head Attention)。从原理到实现,配图 + 举例 + PyTorch 代码,一次性说清楚!


什么是 Multi-Head Attention?

简单说,多头注意力就是一种让模型在多个角度“看”一个序列的机制。

在自然语言中,一个词的含义往往依赖于上下文,比如:

“我把苹果给了她”

模型在处理“苹果”时,需要关注“我”“她”“给了”等词,多头注意力就是这样一种机制——从多个角度理解上下文关系。


Self-Attention 是什么?为什么还要多头?

在讲“多头”之前,咱们先回顾一下基础的 Self-Attention

Self-Attention(自注意力)机制的目标是:

让每个词都能“关注”整个句子里的其他词,融合上下文。

它的核心步骤是:

  1. 对每个词生成 Query、Key、Value 向量

  2. 用 Query 和所有 Key 做点积,算出每个词对其他词的关注度(打分)

  3. 用 Softmax 得到权重,对 Value 加权平均,生成当前词的新表示

这样做的好处是:词的语义表示不再是孤立的,而是上下文相关的。


Self-Attention vs Multi-Head Attention

但问题是——单头 Self-Attention 视角有限。就像一个老师只能从一种角度讲课。

于是,Multi-Head Attention 应运而生

特性Self-Attention(单头)Multi-Head Attention(多头)
输入映射矩阵一组 Q/K/V 线性变换多组 Q/K/V,每个头一组
学习角度单一视角多角度并行理解
表达能力有限更丰富、强大
结构简单并行多个头 + 合并输出

一句话总结:

Multi-Head Attention = 多个不同“视角”的 Self-Attention 并行处理 + 合并结果


 多头注意力:8个脑袋一起思考!

多头 = 多个“单头注意力”并行处理!

每个头使用不同的线性变换矩阵,所以能从不同视角处理数据:

  • 第1个头可能专注短依赖(like 动词和主语)

  • 第2个头可能专注实体关系(我 vs 她)

  • 第3个头可能关注时间顺序(“给了”前后)

  • ……共用同一个输入,学习到不同特征!

多头的步骤:

  1. 将输入向量(如512维)拆成多个头(比如8个,每个64维)

  2. 每个头独立进行 attention

  3. 所有头的输出拼接

  4. 再过一次线性变换,融合成最终输出


 PyTorch 实现(简洁版)

我们来看下 PyTorch 中的简化实现:

import torch
import torch.nn as nn
import copydef clones(module, N):return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])def attention(query, key, value, mask=None, dropout=None):d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = torch.softmax(scores, dim=-1)if dropout:p_attn = dropout(p_attn)return torch.matmul(p_attn, value), p_attnclass MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):super().__init__()assert d_model % h == 0self.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.dropout = nn.Dropout(dropout)def forward(self, query, key, value, mask=None):if mask is not None:mask = mask.unsqueeze(1)nbatches = query.size(0)query, key, value = [lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for lin, x in zip(self.linears, (query, key, value))]x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)

举个例子:多头在实际模型中的作用

假设输入是句子:

"The animal didn't cross the street because it was too tired."

多头注意力的不同头可能会:

  • 🧠 头1:关注“animal”和“it”之间的指代关系;

  • 📐 头2:识别“because”和“tired”之间的因果联系;

  • 📚 头3:注意句子的结构层次……

所以说,多头注意力本质上是一个“并行注意力专家系统”!


 总结

项目解释
目的提升模型表达能力,从多个角度理解输入
核心机制将向量分头 → 每头独立 attention → 合并输出
技术关键view, transpose, matmul, softmax, 拼接线性层

推荐学习路径

  • 🔹 理解 Self-Attention 的点积公式

  • 🔹 搞懂 view, transpose 等张量操作

  • 🔹 看 Transformer 整体结构,关注每层作用

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词