Experiment Description
The GlobalM module is added to stage 2 of changerEx.
The relevant config snippet is shown below:
model = dict(
    backbone=dict(
        interaction_cfg=(
            None,
            dict(type='GlobalM', embed_dim=128, num_heads=32, axial_strategy='row'),
            dict(type='ChannelExchange', p=1/2),
            dict(type='ChannelExchange', p=1/2),
        )
    ),
    decode_head=dict(
        num_classes=2,
        sampler=dict(type='mmseg.OHEMPixelSampler', thresh=0.7, min_kept=100000),
    ),
    # test_cfg=dict(mode='slide', crop_size=crop_size, stride=(crop_size[0]//2, crop_size[1]//2)),
)
GlobalM is defined as follows:
import torch.nn as nn


@ITERACTION_LAYERS.register_module()
class GlobalM(nn.Module):
    """Global spatial multi-head interaction self-attention module (Global-M).

    Purpose: mine global spatial correlations in high-dimensional spatial-spectral features.

    Args:
        embed_dim: embedding dimension (number of feature channels)
        num_heads: number of attention heads
        axial_strategy: axial splitting strategy ('row' for row split / 'column' for column split)
    """

    def __init__(self, embed_dim, num_heads, axial_strategy):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.axial_strategy = axial_strategy

        # 1. QKV projection (implemented with a 1x1 convolution)
        #    Input: B×C×H×W → output: B×3C×H×W (Q, K, V)
        self.qkv_proj = nn.Conv2d(embed_dim, embed_dim * 3, kernel_size=1)

        # 2. Multi-head interaction convolution (the ω^{3×3} operation in the paper)
        #    Fuses the features of the different attention heads
        self.mh_interaction = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1)

        # 3. Feed-forward network (FFN)
        #    Structure: 1x1 conv expansion → GELU → 1x1 conv compression
        self.ffn = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim * 4, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(embed_dim * 4, embed_dim, kernel_size=1),
        )

        # 4. Layer normalization
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Attention scaling factor
        self.scale = self.head_dim ** -0.5

    def forward(self, x1, x2):
        """Forward pass.

        Modified to be compatible with the dual-input structure of IA_ResNet and to
        return two outputs.

        Args:
            x1: main input features (B, C, H, W)
            x2: secondary input (unused by this module, kept only for interface compatibility)

        Returns:
            (out, x2): the enhanced features and the original x2 (dual-output structure preserved)
        """
        B, C, H, W = x1.shape
        residual = x1  # residual connection

        # === Stage 1: multi-head self-attention ===
        # 1. Layer normalization (LayerNorm expects channels last)
        x_norm = self.norm1(x1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        # 2. Generate Q, K, V
        qkv = self.qkv_proj(x_norm)    # B×3C×H×W
        q, k, v = qkv.chunk(3, dim=1)  # each B×C×H×W

        # 3. Global axial splitting (GAS strategy)
        if self.axial_strategy == 'row':
            # Row split: treat each of the H rows as a token sequence of length W
            q = q.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 3, 1, 4, 2)  # B×H×Nh×W×Dh
            k = k.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 3, 1, 4, 2)
            v = v.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 3, 1, 4, 2)
        else:  # column
            # Column split: treat each of the W columns as a token sequence of length H
            q = q.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 4, 1, 3, 2)  # B×W×Nh×H×Dh
            k = k.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 4, 1, 3, 2)
            v = v.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 4, 1, 3, 2)

        # 4. Attention weights (QK^T / sqrt(d_k))
        attn = (q @ k.transpose(-2, -1)) * self.scale  # B×L×Nh×S×S (L, S = H, W for 'row'; W, H for 'column')
        attn = attn.softmax(dim=-1)

        # 5. Attention-weighted aggregation
        out = attn @ v  # B×L×Nh×S×Dh

        # 6. Restore the original shape
        if self.axial_strategy == 'row':
            out = out.permute(0, 2, 4, 1, 3)  # B×Nh×Dh×H×W
        else:
            out = out.permute(0, 2, 4, 3, 1)  # B×Nh×Dh×H×W
        out = out.reshape(B, C, H, W)

        # 7. Multi-head interaction (3x3 conv fuses the heads)
        out = self.mh_interaction(out)
        out += residual  # residual connection

        # === Stage 2: feed-forward network ===
        residual = out
        out = self.norm2(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        out = self.ffn(out)
        out += residual

        # Return (out, x2) even though x2 is not modified;
        # if x2 is None, return x1 as a placeholder.
        return out, x2 if x2 is not None else x1
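As a quick interface check, a minimal smoke-test sketch (the tensor sizes are arbitrary and only illustrate that the module keeps the B×C×H×W shape and the dual input/output structure; this snippet is not part of the training code):

import torch

# embed_dim=128 and num_heads=32 match the config above
m = GlobalM(embed_dim=128, num_heads=32, axial_strategy='row')
x1 = torch.randn(2, 128, 64, 64)   # main-branch stage-2 features
x2 = torch.randn(2, 128, 64, 64)   # secondary-branch features (passed through)

out1, out2 = m(x1, x2)
assert out1.shape == x1.shape      # attention output keeps the input shape
assert out2 is x2                  # x2 is returned unchanged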
Experiment Results
After training on S2Looking, the results on the validation set (evaluated every 8k iterations) are:
2025/04/16 01:10:45 - mmengine - INFO - Iter(val) [ 50/500] eta: 0:00:17 time: 0.0331 data_time: 0.0058 memory: 578
2025/04/16 01:10:47 - mmengine - INFO - Iter(val) [100/500] eta: 0:00:14 time: 0.0402 data_time: 0.0130 memory: 578
2025/04/16 01:10:49 - mmengine - INFO - Iter(val) [150/500] eta: 0:00:12 time: 0.0317 data_time: 0.0045 memory: 578
2025/04/16 01:10:50 - mmengine - INFO - Iter(val) [200/500] eta: 0:00:10 time: 0.0318 data_time: 0.0046 memory: 578
2025/04/16 01:10:52 - mmengine - INFO - Iter(val) [250/500] eta: 0:00:08 time: 0.0319 data_time: 0.0046 memory: 578
2025/04/16 01:10:54 - mmengine - INFO - Iter(val) [300/500] eta: 0:00:06 time: 0.0321 data_time: 0.0047 memory: 578
2025/04/16 01:10:55 - mmengine - INFO - Iter(val) [350/500] eta: 0:00:05 time: 0.0318 data_time: 0.0045 memory: 578
2025/04/16 01:10:57 - mmengine - INFO - Iter(val) [400/500] eta: 0:00:03 time: 0.0361 data_time: 0.0087 memory: 578
2025/04/16 01:10:59 - mmengine - INFO - Iter(val) [450/500] eta: 0:00:01 time: 0.0320 data_time: 0.0047 memory: 578
2025/04/16 01:11:00 - mmengine - INFO - Iter(val) [500/500] eta: 0:00:00 time: 0.0310 data_time: 0.0041 memory: 578
2025/04/16 01:11:00 - mmengine - INFO - per class results:
2025/04/16 01:11:00 - mmengine - INFO -
+-----------+--------+-----------+--------+-------+-------+
| Class | Fscore | Precision | Recall | IoU | Acc |
+-----------+--------+-----------+--------+-------+-------+
| unchanged | 98.96 | 99.07 | 98.85 | 97.94 | 98.85 |
| changed | 30.29 | 28.2 | 32.71 | 17.85 | 32.71 |
+-----------+--------+-----------+--------+-------+-------+
2025/04/16 01:11:00 - mmengine - INFO - Iter(val) [500/500] aAcc: 97.9500 mFscore: 64.6200 mPrecision: 63.6300 mRecall: 65.7800 mIoU: 57.8900 mAcc: 65.7800 data_time: 0.0062 time: 0.0340
The results on the test set are:
2025/04/22 19:40:01 - mmengine - WARNING - The prefix is not set in metric class IoUMetric.
2025/04/22 19:40:01 - mmengine - INFO - Load checkpoint from changer_r18_globalM_stage2/iter_80000.pth
2025/04/22 19:40:17 - mmengine - INFO - Iter(test) [ 50/1000] eta: 0:04:56 time: 0.1928 data_time: 0.1650 memory: 17217
2025/04/22 19:40:26 - mmengine - INFO - Iter(test) [ 100/1000] eta: 0:03:40 time: 0.1764 data_time: 0.1476 memory: 479
2025/04/22 19:40:35 - mmengine - INFO - Iter(test) [ 150/1000] eta: 0:03:10 time: 0.1709 data_time: 0.1438 memory: 479
2025/04/22 19:40:43 - mmengine - INFO - Iter(test) [ 200/1000] eta: 0:02:46 time: 0.1655 data_time: 0.1373 memory: 479
2025/04/22 19:40:51 - mmengine - INFO - Iter(test) [ 250/1000] eta: 0:02:27 time: 0.1496 data_time: 0.1225 memory: 479
2025/04/22 19:40:59 - mmengine - INFO - Iter(test) [ 300/1000] eta: 0:02:14 time: 0.1564 data_time: 0.1296 memory: 479
2025/04/22 19:41:07 - mmengine - INFO - Iter(test) [ 350/1000] eta: 0:02:02 time: 0.1574 data_time: 0.1296 memory: 479
2025/04/22 19:41:16 - mmengine - INFO - Iter(test) [ 400/1000] eta: 0:01:51 time: 0.1477 data_time: 0.1202 memory: 479
2025/04/22 19:41:24 - mmengine - INFO - Iter(test) [ 450/1000] eta: 0:01:41 time: 0.1874 data_time: 0.1595 memory: 479
2025/04/22 19:41:33 - mmengine - INFO - Iter(test) [ 500/1000] eta: 0:01:32 time: 0.1641 data_time: 0.1348 memory: 479
2025/04/22 19:41:41 - mmengine - INFO - Iter(test) [ 550/1000] eta: 0:01:21 time: 0.1415 data_time: 0.1143 memory: 479
2025/04/22 19:41:49 - mmengine - INFO - Iter(test) [ 600/1000] eta: 0:01:12 time: 0.1890 data_time: 0.1607 memory: 479
2025/04/22 19:41:58 - mmengine - INFO - Iter(test) [ 650/1000] eta: 0:01:02 time: 0.1698 data_time: 0.1416 memory: 479
2025/04/22 19:42:06 - mmengine - INFO - Iter(test) [ 700/1000] eta: 0:00:53 time: 0.1320 data_time: 0.1043 memory: 479
2025/04/22 19:42:13 - mmengine - INFO - Iter(test) [ 750/1000] eta: 0:00:43 time: 0.1528 data_time: 0.1256 memory: 479
2025/04/22 19:42:22 - mmengine - INFO - Iter(test) [ 800/1000] eta: 0:00:35 time: 0.1697 data_time: 0.1419 memory: 479
2025/04/22 19:42:31 - mmengine - INFO - Iter(test) [ 850/1000] eta: 0:00:26 time: 0.1735 data_time: 0.1456 memory: 479
2025/04/22 19:42:39 - mmengine - INFO - Iter(test) [ 900/1000] eta: 0:00:17 time: 0.1672 data_time: 0.1386 memory: 479
2025/04/22 19:42:49 - mmengine - INFO - Iter(test) [ 950/1000] eta: 0:00:08 time: 0.1929 data_time: 0.1649 memory: 479
2025/04/22 19:42:58 - mmengine - INFO - Iter(test) [1000/1000] eta: 0:00:00 time: 0.1733 data_time: 0.1454 memory: 479
2025/04/22 19:42:58 - mmengine - INFO - per class results:
2025/04/22 19:42:58 - mmengine - INFO -
+-----------+--------+-----------+--------+-------+-------+
| Class | Fscore | Precision | Recall | IoU | Acc |
+-----------+--------+-----------+--------+-------+-------+
| unchanged | 99.39 | 98.79 | 100.0 | 98.79 | 100.0 |
| changed | 0.08 | 90.77 | 0.04 | 0.04 | 0.04 |
+-----------+--------+-----------+--------+-------+-------+
2025/04/22 19:42:58 - mmengine - INFO - Iter(test) [1000/1000] aAcc: 98.7900 mFscore: 49.7400 mPrecision: 94.7800 mRecall: 50.0200 mIoU: 49.4100 mAcc: 50.0200 data_time: 0.1437 time: 0.1761
Analysis of Results
Key observations:
1. Significant performance gap between classes:
   Validation set: unchanged Fscore = 98.96 vs changed Fscore = 30.29
   Test set: unchanged Fscore = 99.39 vs changed Fscore = 0.08
2. Performance collapse on the test set:
   The recall of the changed class drops from 32.71 on the validation set to 0.04 on the test set, meaning the model essentially fails to detect changed regions.
3. Train-to-test generalization gap:
   Validation mIoU = 57.89 → test mIoU = 49.41, indicating an overfitting risk. Note that these headline numbers are macro averages over the two classes, so they hide the collapse of the changed class (see the quick check after this list).
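A quick arithmetic check of the macro averaging, using the per-class values from the test-set table above:

iou_unchanged, iou_changed = 98.79, 0.04
f_unchanged, f_changed = 99.39, 0.08

print((iou_unchanged + iou_changed) / 2)  # 49.415 ≈ reported mIoU 49.41
print((f_unchanged + f_changed) / 2)      # 49.735 ≈ reported mFscore 49.74

The seemingly reasonable mIoU on the test set is almost entirely contributed by the unchanged class.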
Possible causes:
- Extreme class imbalance
  The validation results suggest that unchanged pixels dominate the data distribution (likely well over 99%).
  A trivial "always predict unchanged" strategy is therefore enough to obtain a high overall accuracy.
  The test set may contain even fewer changed samples, or its distribution may differ more strongly from the training data.
- Adaptation issues of the GlobalM module
  Axial attention (row/column splitting) may break the local spatial relations that precise change localization relies on.
  32 attention heads are likely too many (head_dim ≥ 32 is usually recommended; here head_dim = 128 / 32 = 4), which can make the attention overly diffuse.
  The 3x3 multi-head interaction convolution may introduce unnecessary positional bias.
- Training strategy deficiencies
  The OHEM setting is inappropriate (thresh = 0.7 is too high), so hard-example mining fails to pick up changed samples.
  No class-balanced loss (e.g. Focal Loss) is used.
  Data augmentation is probably insufficient, especially for changed regions.
- Feature interaction design issues
  Alternating GlobalM and ChannelExchange within stage 2 may confuse the features.
  The information fusion of the two streams is not well designed (x2 is not effectively used by GlobalM).
Improvement Suggestions
1. Data level:
   - Rebalance the dataset (see the sampler sketch after this list):
     Use weighted sampling (oversample the changed class).
     Use copy-paste augmentation to artificially add changed regions.
   - Augmentation strategy:
     Design spatial-transform augmentations for change detection (e.g. asymmetric deformation).
     Use MixUp to balance the classes.
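A minimal sketch of the oversampling idea using PyTorch's WeightedRandomSampler; the per-sample change ratios and the weight values are illustrative assumptions, not part of the current training pipeline:

import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical per-sample change ratios (fraction of changed pixels per training pair),
# precomputed from the label masks; the values here are only for illustration.
change_ratios = [0.0, 0.0, 0.12, 0.0, 0.03, 0.0]

# Oversample pairs that actually contain changes.
weights = torch.tensor([10.0 if r > 0 else 1.0 for r in change_ratios])
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

# The sampler would then be passed to the training DataLoader, e.g.
#   DataLoader(train_dataset, batch_size=8, sampler=sampler)
for idx in sampler:
    print(idx)  # indices of pairs with changes appear far more often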
2. Model level:
   - Attention mechanism improvements:
     # Modified GlobalM configuration
     dict(type='GlobalM',
          embed_dim=128,
          num_heads=8,              # fewer attention heads
          axial_strategy='hybrid',  # alternate row/column splitting
          use_local_attn=True)      # add a local attention branch
   - Feature interaction optimization (a gating sketch follows this list):
     Add a spatial attention gating mechanism after GlobalM.
     Replace ChannelExchange with a learnable ChannelAttention.
   - Decode head improvements:
     decode_head=dict(
         num_classes=2,
         loss_decode=dict(type='FocalLoss', loss_weight=[1.0, 5.0]),  # class weighting
         sampler=dict(type='mmseg.OHEMPixelSampler', thresh=0.5, min_kept=150000)  # adjusted threshold
     )
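For the "spatial attention gating after GlobalM" idea, a minimal CBAM-style sketch; the module name SpatialGate and its placement are assumptions, not something already in the codebase:

import torch
import torch.nn as nn

class SpatialGate(nn.Module):
    """CBAM-style spatial attention gate: pools over channels and predicts a per-pixel gate."""

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2)

    def forward(self, x):
        avg_pool = x.mean(dim=1, keepdim=True)        # B×1×H×W
        max_pool = x.max(dim=1, keepdim=True).values  # B×1×H×W
        gate = torch.sigmoid(self.conv(torch.cat([avg_pool, max_pool], dim=1)))
        return x * gate                               # re-weight each spatial position

# Usage idea: wrap the GlobalM output, e.g. out = SpatialGate()(out),
# before it is passed on to the next interaction layer.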
3. Training strategy:
   - Progressive training:
     Stage one: freeze GlobalM and train the base feature extractor first.
     Stage two: unfreeze the attention modules layer by layer.
   - Loss function improvements (a combined-loss sketch follows this list):
     Combine Dice Loss and Focal Loss.
     Introduce a boundary-aware loss.
   - Post-processing optimization:
     test_cfg=dict(
         mode='slide',
         crop_size=(512, 512),
         stride=(256, 256),
         post_process=dict(
             type='CRFPostProcess',  # add conditional random field post-processing
             win_size=5,
             iter_max=10)
     )
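A minimal plain-PyTorch sketch of the Dice + Focal combination for a binary change mask; the weights and gamma are illustrative, and this is not the exact loss that would be wired into the mmseg decode head:

import torch
import torch.nn.functional as F

def dice_focal_loss(logits, target, gamma=2.0, dice_weight=1.0, focal_weight=1.0, eps=1e-6):
    """Combined Dice + Focal loss; logits and target are B×1×H×W, target in {0, 1}."""
    prob = torch.sigmoid(logits)

    # Dice loss: 1 - 2|P∩G| / (|P| + |G|)
    inter = (prob * target).sum()
    dice = 1 - (2 * inter + eps) / (prob.sum() + target.sum() + eps)

    # Focal loss: down-weights easy (mostly unchanged) pixels
    bce = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
    p_t = prob * target + (1 - prob) * (1 - target)
    focal = ((1 - p_t) ** gamma * bce).mean()

    return dice_weight * dice + focal_weight * focal

# Illustrative call with random tensors (~1% changed pixels):
logits = torch.randn(2, 1, 64, 64)
target = (torch.rand(2, 1, 64, 64) > 0.99).float()
print(dice_focal_loss(logits, target))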
Directions for further verification:
- Visualization analysis:
  Use Grad-CAM to visualize the attention regions of GlobalM.
  Compare feature response maps with and without GlobalM.
- Ablation studies:
  Remove each component in turn to verify its contribution.
  Test the effect of different numbers of attention heads (8/16/32).
- Data diagnosis (see the counting snippet below):
  Count the number of changed samples in the test set.
  Check annotation quality (there may be label noise).
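For the data diagnosis step, a small counting sketch; the label directory layout and the convention that changed pixels are stored as non-zero values are assumptions about the local S2Looking copy:

import glob
import numpy as np
from PIL import Image

changed_pixels, total_pixels, images_with_change = 0, 0, 0
# Assumed path layout; adjust to the local directory structure.
for path in glob.glob('data/s2looking/test/label/*.png'):
    mask = np.array(Image.open(path))
    n_changed = int((mask > 0).sum())  # assumes changed pixels are non-zero (e.g. 255)
    changed_pixels += n_changed
    total_pixels += mask.size
    images_with_change += int(n_changed > 0)

print('changed pixel ratio:', changed_pixels / max(total_pixels, 1))
print('images containing changes:', images_with_change)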