
Grouped Query Attention (GQA): A PyTorch Implementation

2025/4/24 15:46:08  Source: https://blog.csdn.net/qq_45812220/article/details/147353323

The GQA implementations I have come across online all look rather convoluted and lack a clean, concise feel, so here is a readable way to implement GQA:

import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        super().__init__()
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = embed_dim // num_heads
        self.group_dim = self.num_groups * self.head_dim  # Correct: num_groups * head_dim
        self.scale = self.head_dim ** -0.5

        # Projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)       # Query: full embed_dim for num_heads
        self.k_proj = nn.Linear(embed_dim, self.group_dim)  # Key: group_dim for num_groups
        self.v_proj = nn.Linear(embed_dim, self.group_dim)  # Value: group_dim for num_groups
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # Project inputs to q, k, v
        q = self.q_proj(x)  # Shape: (batch_size, seq_len, embed_dim)
        k = self.k_proj(x)  # Shape: (batch_size, seq_len, group_dim)
        v = self.v_proj(x)  # Shape: (batch_size, seq_len, group_dim)

        # Reshape query for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Reshape key and value for grouped attention
        k = k.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_groups, seq_len, head_dim)
        v = v.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_groups, seq_len, head_dim)

        # Repeat k and v to match the number of query heads
        heads_per_group = self.num_heads // self.num_groups
        k = k.repeat_interleave(heads_per_group, dim=1)
        # Shape: (batch_size, num_heads, seq_len, head_dim)
        v = v.repeat_interleave(heads_per_group, dim=1)
        # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Shape: (batch_size, num_heads, seq_len, seq_len)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Reshape and project output
        out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        out = self.out_proj(out)  # Shape: (batch_size, seq_len, embed_dim)
        return out


# Test the model
embed_dim = 64
num_heads = 8
num_groups = 4
model = GroupedQueryAttention(embed_dim, num_heads, num_groups)
x = torch.randn(2, 10, embed_dim)  # Input shape: (batch_size, seq_len, embed_dim)
output = model(x)
print(output.shape)  # Expected output: torch.Size([2, 10, 64])
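
As a quick sanity check (my own addition, not part of the original article), the same class covers the two limiting cases: num_groups == num_heads recovers standard multi-head attention, and num_groups == 1 recovers multi-query attention.

# Optional, illustrative sanity check of the limiting cases
mha_like = GroupedQueryAttention(embed_dim=64, num_heads=8, num_groups=8)  # behaves like standard MHA
mqa_like = GroupedQueryAttention(embed_dim=64, num_heads=8, num_groups=1)  # behaves like MQA
x = torch.randn(2, 10, 64)
print(mha_like(x).shape, mqa_like(x).shape)  # torch.Size([2, 10, 64]) torch.Size([2, 10, 64])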

To follow the GQA code, readers are encouraged to first look at how Multi-Query Attention (MQA) is implemented; reading through the implementation above will then feel much more natural. A minimal MQA sketch is given below for reference.
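
The following MQA sketch is my own illustrative addition, not from the original article: a single key/value head of size head_dim is shared by all query heads, which is exactly the num_groups == 1 case of the class above.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiQueryAttention(nn.Module):
    """Minimal MQA sketch: one shared K/V head for all query heads."""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim)      # per-head queries
        self.k_proj = nn.Linear(embed_dim, self.head_dim)  # single shared key head
        self.v_proj = nn.Linear(embed_dim, self.head_dim)  # single shared value head
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        b, s, e = x.shape
        q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        # K/V keep a single head and broadcast against all query heads in matmul
        k = self.k_proj(x).view(b, s, 1, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, s, 1, self.head_dim).transpose(1, 2)
        attn = F.softmax(torch.matmul(q, k.transpose(-2, -1)) * self.scale, dim=-1)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(b, s, e)
        return self.out_proj(out)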

Once MQA is understood, the idea behind GQA is almost exactly the same; the only extra piece is a less commonly used function, tensor.repeat_interleave. For that function, the author's related article (linked) explains it and is quite easy to follow.
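
As a tiny standalone illustration (my own addition), repeat_interleave repeats each element consecutively along a dimension, which is exactly what maps each K/V group onto its heads_per_group query heads:

import torch

groups = torch.tensor([10, 20, 30, 40])     # pretend: 4 K/V groups
heads = groups.repeat_interleave(2, dim=0)  # 8 query heads, 2 per group
print(heads.tolist())                       # [10, 10, 20, 20, 30, 30, 40, 40]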
