欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 财经 > 产业 > 图解多头注意力机制:维度变化一镜到底

图解多头注意力机制:维度变化一镜到底

2025/3/17 10:29:51 来源:https://blog.csdn.net/qq_22866291/article/details/146304009  浏览:    关键词:图解多头注意力机制:维度变化一镜到底



一、多头注意力机制概述

多头注意力(Multi-Head Attention)是Transformer模型的核心组件,其核心思想是通过 ‌并行处理多个子空间‌ 来捕捉序列中不同位置间的复杂依赖关系。主要特点:

  • 并行计算:将高维向量拆分为多个低维子空间
  • 多视角学习:每个注意力头关注不同特征模式
  • 高效性:矩阵运算高度可并行化

在这里插入图片描述

二、代码实现

1. pyTorch 实现
import math
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):"""Args:embed_dim: 词向量维度(如512)num_heads: 注意力头数量(如8)"""super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_heads  # 每个头的维度(如512//8=64)assert self.head_dim * num_heads == embed_dim, "维度不可整除"# 定义线性变换层self.query = nn.Linear(embed_dim, embed_dim)  # Q矩阵self.key = nn.Linear(embed_dim, embed_dim)    # K矩阵self.value = nn.Linear(embed_dim, embed_dim)  # V矩阵self.out = nn.Linear(embed_dim, embed_dim)    # 输出层def transpose_for_scores(self, x):"""拆分多头并调整维度顺序输入: [batch_size, seq_len, embed_dim]输出: [batch_size, num_heads, seq_len, head_dim]"""new_shape = x.size()[:-1] + (self.num_heads, self.head_dim)x = x.view(*new_shape)  # 新增头维度return x.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]def forward(self, query, key, value, mask=None):"""前向传播流程输入形状: [batch_size, seq_len, embed_dim]输出形状: [batch_size, seq_len, embed_dim]"""batch_size = query.size(0)# 1. 线性变换Q = self.query(query)  # [N, seq, D]K = self.key(key)      # [N, seq, D]V = self.value(value)  # [N, seq, D]# 2. 拆分多头Q = self.transpose_for_scores(Q)  # [N, h, seq, d]K = self.transpose_for_scores(K)  # [N, h, seq, d] V = self.transpose_for_scores(V)  # [N, h, seq, d]# 3. 计算注意力分数scores = torch.matmul(Q, K.transpose(-2, -1))  # [N, h, seq_q, seq_k]scores /= math.sqrt(self.head_dim)  # 缩放# 4. 应用掩码(可选)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)# 5. 计算注意力权重attn_weights = F.softmax(scores, dim=-1)  # [N, h, seq_q, seq_k]# 6. 应用权重到Valueout = torch.matmul(attn_weights, V)  # [N, h, seq_q, d]# 7. 合并多头out = out.permute(0, 2, 1, 3).contiguous()  # [N, seq_q, h, d]out = out.view(batch_size, -1, self.embed_dim)  # [N, seq, D]# 8. 输出层return self.out(out), attn_weights
2. tensorFlow实现
# TensorFlow (兼容TF2.x)import tensorflow as tf
from tensorflow.keras.layers import Layer, Denseclass MultiHeadAttention(Layer):def __init__(self, embed_dim, num_heads):"""Args:embed_dim: 词向量维度(如512)num_heads: 注意力头数量(如8)"""super(MultiHeadAttention, self).__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_headsassert self.head_dim * num_heads == embed_dim, "维度不可整除"# 定义线性变换层self.query_dense = Dense(embed_dim)self.key_dense = Dense(embed_dim)self.value_dense = Dense(embed_dim)self.output_dense = Dense(embed_dim)def split_heads(self, x, batch_size):"""拆分多头并调整维度顺序输入: [batch_size, seq_len, embed_dim]输出: [batch_size, num_heads, seq_len, head_dim]"""x = tf.reshape(x, (batch_size, -1, self.num_heads, self.head_dim))return tf.transpose(x, perm=[0, 2, 1, 3])def call(self, query, key, value, mask=None):batch_size = tf.shape(query)# 1. 线性变换Q = self.query_dense(query)  # [N, seq, D]K = self.key_dense(key)      # [N, seq, D]V = self.value_dense(value)  # [N, seq, D]# 2. 拆分多头Q = self.split_heads(Q, batch_size)  # [N, h, seq, d]K = self.split_heads(K, batch_size)  # [N, h, seq, d]V = self.split_heads(V, batch_size)  # [N, h, seq, d]# 3. 计算注意力分数matmul_qk = tf.matmul(Q, K, transpose_b=True)  # [N, h, seq_q, seq_k]scaled_attention_logits = matmul_qk / tf.math.sqrt(tf.cast(self.head_dim, tf.float32))# 4. 应用掩码(可选)if mask is not None:scaled_attention_logits += (mask * -1e9)  # 添加极大负值# 5. 计算注意力权重attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)# 6. 应用权重到Valueoutput = tf.matmul(attention_weights, V)  # [N, h, seq_q, d]# 7. 合并多头output = tf.transpose(output, perm=[0, 2, 1, 3])  # [N, seq_q, h, d]concat_attention = tf.reshape(output, (batch_size, -1, self.embed_dim))# 8. 输出层return self.output_dense(concat_attention), attention_weights

三、维度变化全流程详解

1. 参数设定
  • batch_size = 2
  • seq_len = 5
  • embed_dim = 512
  • num_heads = 8
  • head_dim = 512 // 8 = 64
2. 维度变化流程图
原始输入: [2, 5, 512]│├─线性变换───────保持形状→ [2, 5, 512]│├─拆分多头──────→ [2, 8, 5, 64]│                (拆分512为8个64维头)│├─计算注意力分数──→ [2, 8, 5, 5]│                (每个头计算5x5的注意力矩阵)│├─Softmax───────→ [2, 8, 5, 5]│                (最后一维归一化)│├─应用权重到Value→ [2, 8, 5, 64]│                (每个头输出新的序列表示)│├─合并多头───────→ [2, 5, 512]│                (拼接8个64维头恢复512维)│└─输出层────────→ [2, 5, 512]
3. 关键步骤维度变化

在这里插入图片描述

四、关键实现细节解析

1. 多头拆分与合并
# 拆分多头(核心代码)
new_shape = x.size()[:-1] + (num_heads, head_dim)
x = x.view(*new_shape).permute(0, 2, 1, 3)# 合并多头(逆过程)
x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, embed_dim)
  • 为什么要permute:将num_heads维度提前,便于后续矩阵乘法并行处理多个头
2. 注意力分数计算
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
  • 转置维度‌:将K的seq_len和head_dim维度交换,使矩阵乘法满足[seq_q, d] x [d, seq_k] → [seq_q, seq_k]
  • 缩放因子‌:防止点积结果过大导致softmax梯度消失
3. 掩码处理技巧

python

scores = scores.masked_fill(mask == 0, -1e9)
  • 作用‌:将填充位置(如)的注意力权重趋近于0
  • 为什么用-1e9‌:经过softmax后,exp(-1e9) ≈ 0

五、完整运行示例

# 测试用例
embed_dim = 512
num_heads = 8
model = MultiHeadAttention(embed_dim, num_heads)# 生成测试数据
batch_size = 2
seq_len = 5
inputs = torch.randn(batch_size, seq_len, embed_dim)# 前向传播
output, attn = model(inputs, inputs, inputs)# 验证输出形状
print(output.shape)  # torch.Size([2, 5, 512])
print(attn.shape)    # torch.Size([2, 8, 5, 5])

六、总结与常见问题

1. 核心优势
  • 并行计算效率‌:通过矩阵运算同时处理所有位置和注意力头
  • 多视角学习‌:不同注意力头可关注语法、语义等不同特征
  • 长距离依赖‌:直接计算任意两个位置间的关联
2. FAQ
  • Q1:为什么需要多个注意力头?‌

  • A:类比CNN中多个卷积核,不同头可以捕捉不同类型的特征依赖

  • Q2:head_dim为什么要设置为embed_dim/num_heads?‌

  • A:保持总参数量不变,确保拆分前后的维度乘积相等(num_heads * head_dim = embed_dim)

  • Q3:permute之后为什么要调用contiguous()?‌

  • A:确保张量在内存中连续存储,避免后续view操作报错

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词