欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 游戏 > 【AI知识】pytorch手写Attention之Self-Attention,Multi-Head-Attention

【AI知识】pytorch手写Attention之Self-Attention,Multi-Head-Attention

2025/3/22 16:14:52 来源:https://blog.csdn.net/qq_45791939/article/details/146394927  浏览:    关键词:【AI知识】pytorch手写Attention之Self-Attention,Multi-Head-Attention

pytorch手写Attention

  • Self-Attention
  • Multi-Head-Attention

Self-Attention

代码:

import torch
import torch.nn as nnclass SelfAttention(nn.Module):def __init__(self,embed_dim):super(SelfAttention,self).__init__()self.embed_dim=embed_dimself.WQ=nn.Linear(embed_dim,embed_dim)self.WK=nn.Linear(embed_dim,embed_dim)self.WV=nn.Linear(embed_dim,embed_dim)self.dropout=nn.Dropout(0.1)def forward(self,x,mask=None):"""输入序列x(batch_size,seq_len,embed_dim)"""# (batch_size,seq_len,embed_dim)Q=self.WQ(x) # (batch_size,seq_len,embed_dim)K=self.WK(x)# (batch_size,seq_len,embed_dim)V=self.WV(x)# K(batch_size,seq_len,embed_dim) ,K.transpose(-2,-1)交换张量的最后一个维度和倒数第二个维度attention_scores=torch.matmul(Q,K.transpose(-2,-1))/(self.embed_dim**0.5)# 被掩码的位置设为 -infif mask is not None:attention_scores=attention_scores.masked_fill(mask==0,float('-inf'))# 沿着哪个维度进行 Softmax 计算(dim=-1 表示最后一个维度),对 seq_len 维度计算,让每个 Query 的注意力总和为 1attention_weights=torch.softmax(attention_scores,-1)output=torch.matmul(attention_weights,V)return output,attention_weightsdef create_causal_mask(seq_len):"""生成一个 (seq_len, seq_len) 的上三角矩阵"""mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)  # 生成上三角部分为 1return mask == 0  # 0 表示被掩码,1 表示可用

测试:

batch_size=2
seq_len=5
embed_dim=5
x=torch.rand(batch_size,seq_len,embed_dim)mask_true=create_causal_mask(seq_len)
print("mask:")
print(mask_true)
print("-"*50)
mask_false=Noneself_attention=SelfAttention(embed_dim)output_with_mask,weights_with_mask=self_attention(x,mask_true)
print("output_with_mask:")
print(output_with_mask)
print("weights_with_mask:")
print(weights_with_mask)
print("-"*50)output_without_mask,weights_without_mask=self_attention(x,mask_false)
print("output_without_mask:")
print(output_without_mask)
print("weights_without_mask:")
print(weights_without_mask)

结果:

mask:
tensor([[ True, False, False, False, False],[ True,  True, False, False, False],[ True,  True,  True, False, False],[ True,  True,  True,  True, False],[ True,  True,  True,  True,  True]])
--------------------------------------------------
output_with_mask:
tensor([[[-0.0250, -0.0048,  0.1955,  0.1222,  0.3228],[ 0.0082, -0.1107,  0.2676,  0.1467,  0.4512],[ 0.0056, -0.1283,  0.3186,  0.1582,  0.3351],[ 0.0030, -0.0939,  0.2760,  0.1519,  0.3447],[ 0.0192, -0.1143,  0.3045,  0.1725,  0.3280]],[[-0.0036, -0.4164,  0.3813,  0.2492,  0.5639],[-0.1500, -0.3267,  0.2072,  0.0787,  0.4852],[-0.0660, -0.2758,  0.2731,  0.1216,  0.4643],[-0.0864, -0.2271,  0.2297,  0.0849,  0.4300],[-0.0653, -0.1743,  0.2279,  0.0985,  0.3965]]],grad_fn=<UnsafeViewBackward0>)
weights_with_mask:
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],[0.4885, 0.5115, 0.0000, 0.0000, 0.0000],[0.3256, 0.3318, 0.3427, 0.0000, 0.0000],[0.2447, 0.2653, 0.2568, 0.2332, 0.0000],[0.1945, 0.2031, 0.2080, 0.1886, 0.2058]],[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],[0.5361, 0.4639, 0.0000, 0.0000, 0.0000],[0.3654, 0.3014, 0.3332, 0.0000, 0.0000],[0.2744, 0.2371, 0.2544, 0.2340, 0.0000],[0.2194, 0.1978, 0.2025, 0.1885, 0.1918]]],grad_fn=<SoftmaxBackward0>)
--------------------------------------------------
output_without_mask:
tensor([[[ 0.0190, -0.1140,  0.3034,  0.1720,  0.3307],[ 0.0191, -0.1138,  0.3035,  0.1722,  0.3295],[ 0.0191, -0.1133,  0.3035,  0.1724,  0.3276],[ 0.0191, -0.1147,  0.3040,  0.1720,  0.3311],[ 0.0192, -0.1143,  0.3045,  0.1725,  0.3280]],[[-0.0623, -0.1730,  0.2303,  0.1011,  0.3960],[-0.0615, -0.1707,  0.2301,  0.1013,  0.3945],[-0.0601, -0.1736,  0.2327,  0.1035,  0.3966],[-0.0621, -0.1722,  0.2301,  0.1011,  0.3955],[-0.0653, -0.1743,  0.2279,  0.0985,  0.3965]]],grad_fn=<UnsafeViewBackward0>)
weights_without_mask:
tensor([[[0.1964, 0.2094, 0.2038, 0.1892, 0.2012],[0.1961, 0.2053, 0.2040, 0.1902, 0.2044],[0.1953, 0.1990, 0.2055, 0.1928, 0.2075],[-0.0653, -0.1743,  0.2279,  0.0985,  0.3965]]],grad_fn=<UnsafeViewBackward0>)
weights_without_mask:
tensor([[[0.1964, 0.2094, 0.2038, 0.1892, 0.2012],[0.1961, 0.2053, 0.2040, 0.1902, 0.2044],[0.1953, 0.1990, 0.2055, 0.1928, 0.2075],
tensor([[[0.1964, 0.2094, 0.2038, 0.1892, 0.2012],[0.1961, 0.2053, 0.2040, 0.1902, 0.2044],[0.1953, 0.1990, 0.2055, 0.1928, 0.2075],[0.1961, 0.2053, 0.2040, 0.1902, 0.2044],[0.1953, 0.1990, 0.2055, 0.1928, 0.2075],[0.1953, 0.1990, 0.2055, 0.1928, 0.2075],[0.1958, 0.2123, 0.2055, 0.1866, 0.1998],[0.1945, 0.2031, 0.2080, 0.1886, 0.2058]],[[0.2205, 0.1906, 0.2044, 0.1864, 0.1981],[0.2171, 0.1878, 0.2034, 0.1879, 0.2038],[0.2250, 0.1856, 0.2051, 0.1834, 0.2009],[0.2195, 0.1897, 0.2035, 0.1872, 0.2002],[0.2194, 0.1978, 0.2025, 0.1885, 0.1918]]],grad_fn=<SoftmaxBackward0>)

代码中的一些用法解释:

1)torch.nn.Linear ()

torch.nn.Linear() 是 PyTorch 最基础的全连接层(线性变换层),用于执行以下操作:
在这里插入图片描述

nn.Linear(in_features, out_features, bias=True),in_features 是输入特征维度,out_features是输出特征维度,bias表示是否使用偏置项,默认为 True

2)K.transpose(-2,-1)

在 PyTorch 中,torch.transpose() 用于交换张量的两个维度。参数 -2 和 -1 是指张量的倒数第二个维度和最后一个维度。

K.transpose(-2, -1) 和 K.transpose(-1, -2) 都是交换最后两个维度,它们的效果完全相同。

3)torch.triu(torch.ones(seq_len, seq_len), diagonal=1)

def create_causal_mask(seq_len):"""生成一个 (seq_len, seq_len) 的上三角矩阵"""mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)  # 生成上三角部分为 1return mask == 0  # 0 表示被掩码,1 表示可用

此函数解释:

  • torch.ones(seq_len, seq_len) 生成 seq_len × seq_len 矩阵,所有元素都是 1
  • torch.triu(…, diagonal=1) 取 上三角(不包括主对角线),上三角是 1,其余 0
  • mask == 0 把 0 变成 True(可用),把 1 变成 False(被屏蔽)

这样返回一个seq_len=5的mask矩阵:

tensor([[ True, False, False, False, False],[ True,  True, False, False, False],[ True,  True,  True, False, False],[ True,  True,  True,  True, False],[ True,  True,  True,  True,  True]])

4)attention_scores.masked_fill(mask==0,float('-inf'))

作用: 根据 mask矩阵 进行掩码处理,将 mask == 0 的位置填充为 -inf(负无穷),使其在 softmax 计算时权重变为 0

masked_fill 属于 PyTorch 张量(torch.Tensor)的方法,用于根据布尔掩码(mask)填充指定值。masked_fill 的语法: tensor.masked_fill(mask, value)

tensor:要修改的张量
mask:布尔掩码(True/False 或 0/1)
value:要填充的值(如 -inf)

解释:attention_scores.masked_fill(mask==0,float('-inf'))

mask == 0 选取 应该被屏蔽的位置(即 上三角部分)
masked_fill(mask == 0, -inf) 把上三角部分设为-inf
这样,Softmax 后被屏蔽的部分变成 0,不会影响注意力计算。

5) torch.softmax(attention_scores,-1)
作用: 对 attention_scores 进行 Softmax 归一化,确保注意力权重(attn_weights)的总和为 1,控制每个 Token 对序列中其他 Token 的关注程度

torch.softmax(input, dim) 语法:

input:要进行 Softmax 计算的张量
dim:沿着哪个维度进行 Softmax 计算(dim=-1 表示最后一个维度)

Multi-Head-Attention

import torch
import torch.nn as nn
import mathclass Multi_Head_Attention(nn.Module):def __init__(self,embed_dim,nums_heads):super(Multi_Head_Attention,self).__init__()assert embed_dim % nums_heads ==0,"embed_dim 必须能被 num_heads 整除"self.embed_dim=embed_dimself.nums_heads=nums_headsself.head_dim=embed_dim//nums_headsself.WQ=nn.Linear(embed_dim,embed_dim)self.WK=nn.Linear(embed_dim,embed_dim)self.WV=nn.Linear(embed_dim,embed_dim)self.fc=nn.Linear(embed_dim,embed_dim)self.scale=math.sqrt(embed_dim)def forward(self,x,mask=None):batch_size,seq_len,embed_dim=x.shape# Q,K,V: batch_size,seq_len,embed_dimQ=self.WQ(x)K=self.WK(x)V=self.WV(x)# Q,K,V: batch_size,seq_len,embed_dim -> batch_size,seq_len,self.nums_heads,self.head_dim -> batch_size,self.nums_heads,seq_len,self.head_dimQ=Q.view(batch_size,seq_len,self.nums_heads,self.head_dim).transpose(1,2)K=K.view(batch_size,seq_len,self.nums_heads,self.head_dim).transpose(1,2)V=V.view(batch_size,seq_len,self.nums_heads,self.head_dim).transpose(1,2)# batch_size,self.nums_heads,seq_len,seq_lenattn_scores=torch.matmul(Q,K.transpose(-1,-2))/self.scaleif mask is not None:attn_scores=attn_scores.masked_fill(mask==0,float('-inf'))attn_weights=torch.softmax(attn_scores,-1)# batch_size,self.nums_heads,seq_len,head_dimoutput=torch.matmul(attn_weights,V)# batch_size,self.nums_heads,seq_len,head_dim -> batch_size,seq_len,self.nums_heads,head_dim -> batch_size,seq_len,self.embed_dimoutput=output.transpose(1,2).contiguous().view(batch_size,seq_len,self.embed_dim)output=self.fc(output)return output,attn_weightsdef create_mask(seq_len):mask=torch.triu(torch.ones(seq_len,seq_len),diagonal=1)return mask==0

测试:

batch_size=2
seq_len=5
nums_heads=3
embed_dim=6
x=torch.randn(batch_size,seq_len,embed_dim)
mask=create_mask(seq_len)multiheadattention=Multi_Head_Attention(embed_dim,nums_heads)
output,weights=multiheadattention(x,mask)
print(output)
print(weights)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词