Daily Attention Study 10: Scale-Aware Modulation

2024/10/24 23:28:22 Source: https://blog.csdn.net/qq_40714949/article/details/140380765
Module source

[ICCV 23] [link] [code] Scale-Aware Modulation Meet Transformer


Module name

Scale-Aware Modulation (SAM)


Module purpose

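An improved alternative to self-attention. It is a convolutional modulation: multi-scale depthwise convolutions produce a modulator, which is multiplied element-wise with a value projection. No N×N attention map is computed.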
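The main structural difference from self-attention shows up in a back-of-the-envelope count of multiply-adds for the token-mixing step alone. The sketch uses the demo shapes from the module code (N = 32 × 32 tokens, C = 256, 4 heads). It rests on the following assumptions:

- For self-attention, only QKᵀ and the attention-weighted sum over V are counted.
- For SAM, only the MHMC depthwise convolutions and the element-wise product are counted.
- Linear projections, BatchNorm, GELU and biases are ignored.
- The helper names are hypothetical.

```python
# Back-of-the-envelope multiply-add (MAC) counts for the token-mixing step only.
# Assumption: linear projections, BatchNorm, GELU and biases are ignored.


def attention_mixing_macs(n_tokens: int, dim: int) -> int:
    """Self-attention: Q @ K^T (N x N x C) plus attn @ V (N x N x C)."""
    return 2 * n_tokens * n_tokens * dim


def mhmc_mixing_macs(n_tokens: int, dim: int, num_heads: int) -> int:
    """MHMC depthwise convs: head i applies a (3 + 2i) x (3 + 2i) kernel
    to dim // num_heads channels, matching kernel_size=(3 + i * 2) in SAM."""
    ch = dim // num_heads
    return sum(n_tokens * ch * (3 + 2 * i) ** 2 for i in range(num_heads))


def modulation_macs(n_tokens: int, dim: int) -> int:
    """Element-wise product s_out * v: one multiply per token and channel."""
    return n_tokens * dim


if __name__ == "__main__":
    N, C, heads = 32 * 32, 256, 4  # same shapes as the demo in the module code
    attn = attention_mixing_macs(N, C)
    sam = mhmc_mixing_macs(N, C, heads) + modulation_macs(N, C)
    print(f"attention map + weighting: {attn:,}")
    print(f"MHMC + modulation:         {sam:,}")
```

Attention cost grows with N², because every token is compared with every other token. The convolutional path grows linearly in N, with a per-token cost set by the kernel sizes.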


Module structure

[Figure: SAM module structure]


Module code
import torch
import torch.nn as nn
import torch.nn.functional as F


class SAM(nn.Module):
    def __init__(self, dim, ca_num_heads=4, sa_num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., expand_ratio=2):
        super().__init__()
        # Note: sa_num_heads, qk_scale, attn_drop and ca_attention are not used in forward().
        self.ca_attention = 1
        self.dim = dim
        self.ca_num_heads = ca_num_heads
        self.sa_num_heads = sa_num_heads
        assert dim % ca_num_heads == 0, f"dim {dim} should be divided by num_heads {ca_num_heads}."
        assert dim % sa_num_heads == 0, f"dim {dim} should be divided by num_heads {sa_num_heads}."
        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.split_groups = self.dim // ca_num_heads  # channels per head
        self.v = nn.Linear(dim, dim, bias=qkv_bias)  # value branch
        self.s = nn.Linear(dim, dim, bias=qkv_bias)  # branch that becomes the modulator

        # MHMC: one depthwise conv per head; head i uses a (3+2i)x(3+2i) kernel
        # with padding 1+i, so the spatial size H x W is preserved.
        for i in range(self.ca_num_heads):
            local_conv = nn.Conv2d(dim // self.ca_num_heads, dim // self.ca_num_heads,
                                   kernel_size=(3 + i * 2), padding=(1 + i), stride=1,
                                   groups=dim // self.ca_num_heads)
            setattr(self, f"local_conv_{i + 1}", local_conv)

        # SAA: grouped 1x1 conv (expand) -> BN -> GELU -> 1x1 conv (project back to dim)
        self.proj0 = nn.Conv2d(dim, dim * expand_ratio, kernel_size=1, padding=0, stride=1,
                               groups=self.split_groups)
        self.bn = nn.BatchNorm2d(dim * expand_ratio)
        self.proj1 = nn.Conv2d(dim * expand_ratio, dim, kernel_size=1, padding=0, stride=1)

    def forward(self, x, H, W):
        # In: x has shape (B, N, C) with N = H * W
        B, N, C = x.shape
        v = self.v(x)
        # Split s into heads: (heads, B, C // heads, H, W)
        s = self.s(x).reshape(B, H, W, self.ca_num_heads, C // self.ca_num_heads).permute(3, 0, 4, 1, 2)

        # Multi-Head Mixed Convolution (MHMC)
        for i in range(self.ca_num_heads):
            local_conv = getattr(self, f"local_conv_{i + 1}")
            s_i = s[i]
            s_i = local_conv(s_i).reshape(B, self.split_groups, -1, H, W)
            if i == 0:
                s_out = s_i
            else:
                s_out = torch.cat([s_out, s_i], 2)
        # Interleave heads: channel g of head i ends up at index g * heads + i
        s_out = s_out.reshape(B, C, H, W)

        # Scale-Aware Aggregation (SAA)
        s_out = self.proj1(self.act(self.bn(self.proj0(s_out))))
        self.modulator = s_out  # side effect: keeps the last modulator on the module
        s_out = s_out.reshape(B, C, N).permute(0, 2, 1)

        # Modulation: element-wise product with the value branch
        x = s_out * v

        # Out
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


if __name__ == '__main__':
    x = torch.randn([3, 1024, 256])  # B, N, C
    sam = SAM(dim=256)
    out = sam(x, H=32, W=32)  # N = H * W = 1024
    print(out.shape)  # torch.Size([3, 1024, 256])
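The link between MHMC and SAA lives in the reshape/cat bookkeeping of `forward`, which is easy to misread. The pure-Python sketch below needs no PyTorch, and its function names are hypothetical. It tracks only channel indices through the same concatenation and reshape, assuming PyTorch's row-major (contiguous) layout. It then shows which head supplies each input channel of every `proj0` group.

```python
# Pure-Python model of the channel layout in SAM.forward.
# Only channel indices are tracked; batch and spatial dims are dropped.


def mhmc_output_channel_order(dim: int, num_heads: int) -> list:
    """For each channel of s_out after reshape(B, C, H, W), return the
    (head, channel-within-head) pair it came from."""
    split_groups = dim // num_heads
    # After the MHMC loop, s_out has shape (B, split_groups, num_heads, H, W):
    # entry [g][i] holds channel g of head i (torch.cat along dim 2).
    stacked = [[(i, g) for i in range(num_heads)] for g in range(split_groups)]
    # reshape(B, C, H, W) flattens (split_groups, num_heads) in row-major order.
    return [pair for row in stacked for pair in row]


def proj0_groups(dim: int, num_heads: int) -> list:
    """proj0 uses groups=split_groups, so each group sees a contiguous block
    of dim // split_groups (== num_heads) input channels."""
    order = mhmc_output_channel_order(dim, num_heads)
    per_group = dim // (dim // num_heads)
    return [order[k * per_group:(k + 1) * per_group]
            for k in range(len(order) // per_group)]


if __name__ == "__main__":
    for k, group in enumerate(proj0_groups(dim=16, num_heads=4)[:2]):
        print(f"proj0 group {k}: (head, channel) = {group}")
```

With `dim=16` and 4 heads, group 0 receives channel 0 of every head and group 1 receives channel 1 of every head. Each small group of the grouped 1×1 convolution therefore mixes features from all kernel sizes.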

Original description (from the paper)

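We propose a novel convolutional modulation, termed Scale-Aware Modulation (SAM), which incorporates two new modules: Multi-Head Mixed Convolution (MHMC) and Scale-Aware Aggregation (SAA). The MHMC module is designed to enlarge the receptive field and capture multi-scale features simultaneously. The SAA module is designed to effectively aggregate features across different heads while keeping the architecture lightweight.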
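Both properties described above can be checked with plain arithmetic at the module code's default settings (dim=256, ca_num_heads=4, expand_ratio=2). This sketch uses hypothetical helper names and checks two things:

- Each MHMC head's kernel/padding pair preserves a 32×32 feature map while using a larger kernel than the previous head.
- SAA's grouped `proj0` has far fewer weights than a hypothetical dense 1×1 convolution of the same shape. Biases are excluded, and `proj1`, which is dense in both cases, is not counted.

```python
# Arithmetic checks for MHMC kernel sizes and SAA's grouped projection,
# using the defaults from the module code (dim=256, 4 heads, expand_ratio=2).


def mhmc_kernels(num_heads: int) -> list:
    """(kernel_size, padding) per head: k = 3 + 2i, p = 1 + i, as in SAM.__init__."""
    return [(3 + 2 * i, 1 + i) for i in range(num_heads)]


def conv_out_size(size: int, kernel: int, padding: int, stride: int = 1) -> int:
    """Standard Conv2d output-size formula with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv1x1_weight_count(c_in: int, c_out: int, groups: int) -> int:
    """Weights of a 1x1 Conv2d, bias excluded: c_out * (c_in // groups)."""
    return c_out * (c_in // groups)


if __name__ == "__main__":
    dim, heads, ratio = 256, 4, 2
    for i, (k, p) in enumerate(mhmc_kernels(heads)):
        print(f"head {i}: {k}x{k} kernel, 32 -> {conv_out_size(32, k, p)}")
    grouped = conv1x1_weight_count(dim, dim * ratio, groups=dim // heads)
    dense = conv1x1_weight_count(dim, dim * ratio, groups=1)
    print(f"proj0 weights: grouped {grouped:,} vs dense {dense:,}")
```

With these defaults, the grouped `proj0` holds 2,048 weights, compared with 131,072 for a dense 1×1 convolution mapping 256 to 512 channels.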
