论文介绍
论文链接:https://arxiv.org/pdf/2307.14010
- 研究背景:高光谱图像(HSI)的超分辨率(SR)任务中,传统方法存在光谱信息利用不充分和上采样后产生伪影的问题。
- 研究目的:提出ESSAformer模型,旨在解决上述问题,提高超分辨率任务的性能。
- 主要内容:介绍了ESSAformer模型的整体结构、创新点、具体方法以及实验验证结果。
创新点
- 迭代精炼结构:通过迭代下采样和上采样策略,捕捉不同尺度下的全局和局部信息,编码高光谱图像的详细内容。
- 光谱相关系数:提出使用光谱相关系数替代传统的点积(余弦相似度),使模型对光谱信息更加友好。
- 高效注意力方法:专为HSI设计的ESSA注意力方法,考虑HSI特性,通过kernelizable技术实现乘法交换,提高计算效率。
方法
- 整体结构:ESSAformer模型采用迭代精炼结构,通过不同阶段的编码和解码过程,逐步精炼图像特征。
- SCC自注意力:作为Transformer的核心之一,自注意力通过关注每个位置的特征来扩大依赖距离。ESSAformer引入了考虑HSI特性的SCC自注意力,以提高数据效率和表示能力。
- 高效SCC核基自注意力:基于SCC自注意力,提出高效SCC核基自注意力(ESSA),以减轻注意力计算的负担。通过核函数计算,实现快速且有效的特征提取。
模块作用
- ESSAttn模块(即ESSA):
- 作用:该模块是ESSAformer模型的核心,通过引入光谱友好的注意力机制,提高模型对HSI的处理能力。
- 细节:ESSAttn模块利用光谱相关系数计算特征之间的相似性,并通过kernelizable技术实现高效的注意力计算。这不仅提高了计算效率,还使模型能够更好地捕捉HSI中的光谱信息。
- 效果:实验结果表明,ESSAttn模块能够显著提高超分辨率任务的性能,恢复出更清晰的图像细节,同时保持较低的计算成本。
总结而言,该论文提出的ESSAformer模型通过引入迭代精炼结构、光谱相关系数以及高效的ESSAttn模块,成功解决了高光谱图像超分辨率任务中的关键问题。实验结果表明,ESSAformer模型在多个公开数据集上取得了领先的性能,证明了其有效性和实用性。代码如下:
class ESSAttn(nn.Module):def __init__(self, dim):super().__init__()self.lnqkv = nn.Linear(dim, dim * 3)self.ln = nn.Linear(dim, dim)def forward(self, x):b, N, C = x.shapeqkv = self.lnqkv(x)qkv = torch.split(qkv, C, 2)q, k, v = qkv[0], qkv[1], qkv[2]a = torch.mean(q, dim=2, keepdim=True)q = q - aa = torch.mean(k, dim=2, keepdim=True)k = k - aq2 = torch.pow(q, 2)q2s = torch.sum(q2, dim=2, keepdim=True)k2 = torch.pow(k, 2)k2s = torch.sum(k2, dim=2, keepdim=True)t1 = vk2 = torch.nn.functional.normalize((k2 / (k2s + 1e-7)), dim=-2)q2 = torch.nn.functional.normalize((q2 / (q2s + 1e-7)), dim=-1)t2 = q2 @ (k2.transpose(-2, -1) @ v) / math.sqrt(N)# t2 = self.norm1(t2)*0.3# print(torch.mean(t1),torch.std(t1))# print(torch.mean(t2), torch.std(t2))# t2 = self.norm1(t2)*0.1# t2 = ((q2 / (q2s+1e-7)) @ t2)# q3 = torch.pow(q,4)# q3s = torch.pow(q2s,2)# k3 = torch.pow(k, 4)# k3s = torch.sum(k2,dim=2).unsqueeze(2).repeat(1, 1, C)# t3 = ((k3 / k3s)*16).transpose(-2, -1) @ v# t3 = ((q3 / q3s)*16) @ t3# print(torch.max(t1))# print(torch.max(t2))# t3 = (((torch.pow(q,4))/24) @ (((torch.pow(k,4).transpose(-2,-1))/24)@v)*16/math.sqrt(N))attn = t1 + t2attn = self.ln(attn)return attndef is_same_matrix(self, m1, m2):rows, cols = len(m1), len(m1[0])for i in range(rows):for j in range(cols):if m1[i][j] != m2[i][j]:return Falsereturn Trueif __name__ == '__main__':input = torch.randn(1, 400, 32)essa = ESSAttn(32)output = essa(input)print(essa)print(input.size())print(output.size())
输出结果:
torch.Size([1, 400, 32])
torch.Size([1, 400, 32])
进一步改进,让其符合卷积的输入,代码如下:
import torch
import torch.nn as nn
import mathclass ESSAttn(nn.Module):def __init__(self, dim):super().__init__()self.lnqkv = nn.Linear(dim, dim * 3)self.ln = nn.Linear(dim, dim)def forward(self, x):B, C, H, W = x.shapex = x.reshape(B, C, H * W).permute(0, 2, 1)B, N, C = x.shapeqkv = self.lnqkv(x)qkv = torch.split(qkv, C, 2)q, k, v = qkv[0], qkv[1], qkv[2]a = torch.mean(q, dim=2, keepdim=True)q = q - aa = torch.mean(k, dim=2, keepdim=True)k = k - aq2 = torch.pow(q, 2)q2s = torch.sum(q2, dim=2, keepdim=True)k2 = torch.pow(k, 2)k2s = torch.sum(k2, dim=2, keepdim=True)t1 = vk2 = torch.nn.functional.normalize((k2 / (k2s + 1e-7)), dim=-2)q2 = torch.nn.functional.normalize((q2 / (q2s + 1e-7)), dim=-1)t2 = q2 @ (k2.transpose(-2, -1) @ v) / math.sqrt(N)attn = t1 + t2attn = self.ln(attn)x = attn.reshape(B, H, W, C).permute(0, 3, 1, 2)return xif __name__ == '__main__':input = torch.randn(1, 32, 640, 480)essa = ESSAttn(32)output = essa(input)print(essa)print(input.size())print(output.size())
输出结果:
torch.Size([1, 32, 640, 480])
torch.Size([1, 32, 640, 480])
代码逐行讲解
这段代码定义了一个名为 ESSAttn
的类,它继承自 PyTorch 的 nn.Module
类,用于构建一个自定义的注意力机制模块。下面是对这段代码的逐行解释:
class ESSAttn(nn.Module):
- 定义了一个名为
ESSAttn
的类,它继承自 PyTorch 的nn.Module
类。nn.Module
是所有神经网络模块的基类。
def __init__(self, dim):
- 这是类的初始化方法,
dim
参数指定了输入特征的维度。
super().__init__()
- 调用父类
nn.Module
的初始化方法。
self.lnqkv = nn.Linear(dim, dim * 3)
- 定义了一个线性层
lnqkv
,它将输入特征从dim
维映射到3 * dim
维。这个线性层用于生成查询(query)、键(key)和值(value)的投影。
self.ln = nn.Linear(dim, dim)
- 定义了一个线性层
ln
,用于在注意力机制之后对输出进行变换,保持输入输出维度一致。
def forward(self, x):
- 定义了前向传播函数,
x
是输入的特征图。
B, C, H, W = x.shape
- 获取输入特征图的批次大小(
B
)、通道数(C
)、高度(H
)和宽度(W
)。
x = x.reshape(B, C, H * W).permute(0, 2, 1)
- 将特征图从
(B, C, H, W)
形状重塑为(B, C, H*W)
,然后交换第二和第三维度,得到(B, H*W, C)
的形状,以便后续处理。
B, N, C = x.shape
- 重新获取变换后的特征图的形状,其中
N = H * W
表示展平后的像素总数。
qkv = self.lnqkv(x)
- 通过线性层
lnqkv
投影输入,得到查询、键和值的组合表示。
qkv = torch.split(qkv, C, 2)
- 将
qkv
沿着最后一个维度(特征维度)分割成三部分,分别对应查询(q)、键(k)和值(v)。
q, k, v = qkv[0], qkv[1], qkv[2]
- 提取查询、键和值。
a = torch.mean(q, dim=2, keepdim=True)q = q - aa = torch.mean(k, dim=2, keepdim=True)k = k - a
- 分别计算查询和键在每个样本上的均值,并从每个查询和键中减去对应的均值,实现中心化。
q2 = torch.pow(q, 2)q2s = torch.sum(q2, dim=2, keepdim=True)k2 = torch.pow(k, 2)k2s = torch.sum(k2, dim=2, keepdim=True)
- 计算查询和键的平方,然后沿着特征维度求和,得到每个查询和键的平方和。
t1 = v
- 将值(v)作为
t1
,准备后续与注意力权重相乘。
k2 = torch.nn.functional.normalize((k2 / (k2s + 1e-7)), dim=-2)q2 = torch.nn.functional.normalize((q2 / (q2s + 1e-7)), dim=-1)
- 对键和查询的平方进行 L2 归一化,以避免数值不稳定。
t2 = q2 @ (k2.transpose(-2, -1) @ v) / math.sqrt(N)
- 计算注意力权重,并通过矩阵乘法将其应用于值(v)。除以
sqrt(N)
是为了缩放注意力分数,防止其过大。
attn = t1 + t2
- 将原始的值(v)和通过注意力机制加权后的值(
t2
)相加,得到最终的注意力输出。
attn = self.ln(attn)
- 通过线性层
ln
对注意力输出进行变换。
x = attn.reshape(B, H, W, C).permute(0, 3, 1, 2)
- 将注意力输出重塑回原始的空间维度
(B, C, H, W)
,并交换维度以匹配输入格式。
return x
- 返回处理后的特征图。