
Beginner's Training Log: 2025/4/22

2025/4/29 16:43:34  Source: https://blog.csdn.net/benbenbai/article/details/147428814

Experiment Description

Add the GlobalM module to stage 2 of ChangerEx.
The configuration is as follows:

model = dict(
    backbone=dict(
        interaction_cfg=(
            None,
            dict(type='GlobalM', embed_dim=128, num_heads=32, axial_strategy='row'),
            dict(type='ChannelExchange', p=1/2),
            dict(type='ChannelExchange', p=1/2))),
    decode_head=dict(
        num_classes=2,
        sampler=dict(type='mmseg.OHEMPixelSampler', thresh=0.7, min_kept=100000)),
    # test_cfg=dict(mode='slide', crop_size=crop_size, stride=(crop_size[0]//2, crop_size[1]//2)),
)
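As a quick sanity check on this config, the plain-Python sketch below (the helper `globalm_stage_info` is hypothetical) assumes that entry i of `interaction_cfg` is applied after ResNet stage i + 1, as the "stage 2" wording implies, and uses the standard ResNet-18 stage widths of 64/128/256/512 channels. It confirms that `embed_dim=128` matches stage 2 and shows that 32 heads leave only 4 channels per head.

```python
# Hypothetical helper: locate the GlobalM entry in interaction_cfg and derive
# its per-head width. Assumes entry i is applied after ResNet stage i + 1.
RESNET18_STAGE_CHANNELS = (64, 128, 256, 512)  # BasicBlock widths of ResNet-18

interaction_cfg = (
    None,
    dict(type='GlobalM', embed_dim=128, num_heads=32, axial_strategy='row'),
    dict(type='ChannelExchange', p=1/2),
    dict(type='ChannelExchange', p=1/2),
)


def globalm_stage_info(cfgs, stage_channels=RESNET18_STAGE_CHANNELS):
    """Return (stage number, stage channels, head_dim) for the first GlobalM entry."""
    for stage, (cfg, channels) in enumerate(zip(cfgs, stage_channels), start=1):
        if cfg is not None and cfg['type'] == 'GlobalM':
            # embed_dim must equal the channel count of the stage it follows
            if cfg['embed_dim'] != channels:
                raise ValueError(
                    f"embed_dim={cfg['embed_dim']} but stage {stage} has {channels} channels")
            return stage, channels, cfg['embed_dim'] // cfg['num_heads']
    return None


print(globalm_stage_info(interaction_cfg))  # (2, 128, 4)
```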

GlobalM is defined as follows:

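import torch.nn as nn

# Interaction-layer registry; import path assumed from open-cd, adjust to your version
from opencd.models.utils.builder import ITERACTION_LAYERS


@ITERACTION_LAYERS.register_module()
class GlobalM(nn.Module):
    """Global spatial multi-head interaction self-attention module (Global-M).

    Purpose: capture global spatial correlations in high-dimensional
    spatial-spectral features.

    Args:
        embed_dim: embedding dimension (number of feature channels).
        num_heads: number of attention heads.
        axial_strategy: axial split strategy, 'row' or 'column'
            (any other value falls through to the column branch).
    """

    def __init__(self, embed_dim, num_heads, axial_strategy):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.axial_strategy = axial_strategy

        # 1. QKV projection (1x1 convolution)
        #    input B×C×H×W -> output B×3C×H×W (Q, K, V)
        self.qkv_proj = nn.Conv2d(embed_dim, embed_dim * 3, kernel_size=1)

        # 2. Multi-head interaction convolution (the ω^{3×3} operation in the paper),
        #    used to fuse the features of the different attention heads
        self.mh_interaction = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1)

        # 3. Feed-forward network: 1x1 conv expand -> GELU -> 1x1 conv project back
        self.ffn = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim * 4, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(embed_dim * 4, embed_dim, kernel_size=1))

        # 4. Layer normalization (over the channel dimension)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Attention scaling factor 1/sqrt(head_dim)
        self.scale = self.head_dim ** -0.5

    def forward(self, x1, x2):
        """Forward pass, adapted to the dual-input, dual-output interface of IA_ResNet.

        Args:
            x1: primary input feature (B, C, H, W).
            x2: secondary input (not used by this module; kept only for
                interface compatibility).

        Returns:
            (out, x2): the enhanced feature map (B, C, H, W) and the original x2
            (x1 if x2 is None).
        """
        B, C, H, W = x1.shape
        residual = x1  # residual connection

        # === Part 1: multi-head self-attention ===
        # 1. Layer norm (move channels last for nn.LayerNorm, then back)
        x_norm = self.norm1(x1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        # 2. Generate Q, K, V
        qkv = self.qkv_proj(x_norm)    # B×3C×H×W
        q, k, v = qkv.chunk(3, dim=1)  # each B×C×H×W

        # 3. Global axial split (GAS strategy)
        if self.axial_strategy == 'row':
            # Row split: the H rows are processed independently and, within each
            # row, the W pixels attend to one another (attention along the width)
            q = q.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 3, 1, 4, 2)  # B×H×Nh×W×Dh
            k = k.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 3, 1, 4, 2)
            v = v.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 3, 1, 4, 2)
        else:  # column
            # Column split: within each of the W columns, the H pixels attend to one another
            q = q.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 4, 1, 3, 2)  # B×W×Nh×H×Dh
            k = k.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 4, 1, 3, 2)
            v = v.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 4, 1, 3, 2)

        # 4. Attention weights softmax(QK^T / sqrt(d_k))
        attn = (q @ k.transpose(-2, -1)) * self.scale  # row: B×H×Nh×W×W, column: B×W×Nh×H×H
        attn = attn.softmax(dim=-1)

        # 5. Attention-weighted sum of V
        out = attn @ v  # row: B×H×Nh×W×Dh, column: B×W×Nh×H×Dh

        # 6. Restore the original layout
        if self.axial_strategy == 'row':
            out = out.permute(0, 2, 4, 1, 3)  # B×Nh×Dh×H×W
        else:
            out = out.permute(0, 2, 4, 3, 1)  # B×Nh×Dh×H×W
        out = out.reshape(B, C, H, W)

        # 7. Multi-head interaction (3x3 conv fuses the heads)
        out = self.mh_interaction(out)
        out += residual  # residual connection

        # === Part 2: feed-forward network ===
        residual = out
        out = self.norm2(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        out = self.ffn(out)
        out += residual

        # Return (out, x2) even though x2 is not modified; use x1 as a placeholder if x2 is None
        return out, (x2 if x2 is not None else x1)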
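Steps 3 and 6 of `forward` must be exact inverses, and it is easy to get a permute index wrong. The standalone PyTorch check below copies those reshape/permute lines into two hypothetical helpers, `split_axial` and `merge_axial`, confirms the round trip is lossless, and prints the attention-matrix shape. It shows that in 'row' mode the attention matrix is W×W inside each row (not H×H across rows).

```python
import torch


def split_axial(t, num_heads, strategy):
    """Copy of step 3: B×C×H×W -> B×H×Nh×W×Dh ('row') or B×W×Nh×H×Dh ('column')."""
    B, C, H, W = t.shape
    t = t.reshape(B, num_heads, C // num_heads, H, W)
    if strategy == 'row':
        return t.permute(0, 3, 1, 4, 2)
    return t.permute(0, 4, 1, 3, 2)


def merge_axial(t, strategy):
    """Copy of step 6: back to B×C×H×W."""
    if strategy == 'row':
        t = t.permute(0, 2, 4, 1, 3)  # B×Nh×Dh×H×W
    else:
        t = t.permute(0, 2, 4, 3, 1)  # B×Nh×Dh×H×W
    B, Nh, Dh, H, W = t.shape
    return t.reshape(B, Nh * Dh, H, W)


x = torch.randn(2, 128, 16, 24)  # H=16, W=24 so the two spatial axes are distinguishable
for strategy in ('row', 'column'):
    tokens = split_axial(x, num_heads=32, strategy=strategy)
    attn = tokens @ tokens.transpose(-2, -1)
    # The round trip is exact, so steps 3 and 6 are consistent
    assert torch.equal(merge_axial(tokens, strategy), x)
    print(strategy, tuple(attn.shape))
# row (2, 16, 32, 24, 24)
# column (2, 24, 32, 16, 16)
```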

Experimental Results

After training on S2Looking, the results on the validation set (validated every 8k iterations) were:

2025/04/16 01:10:45 - mmengine - INFO - Iter(val) [ 50/500]    eta: 0:00:17  time: 0.0331  data_time: 0.0058  memory: 578  
2025/04/16 01:10:47 - mmengine - INFO - Iter(val) [100/500]    eta: 0:00:14  time: 0.0402  data_time: 0.0130  memory: 578  
2025/04/16 01:10:49 - mmengine - INFO - Iter(val) [150/500]    eta: 0:00:12  time: 0.0317  data_time: 0.0045  memory: 578  
2025/04/16 01:10:50 - mmengine - INFO - Iter(val) [200/500]    eta: 0:00:10  time: 0.0318  data_time: 0.0046  memory: 578  
2025/04/16 01:10:52 - mmengine - INFO - Iter(val) [250/500]    eta: 0:00:08  time: 0.0319  data_time: 0.0046  memory: 578  
2025/04/16 01:10:54 - mmengine - INFO - Iter(val) [300/500]    eta: 0:00:06  time: 0.0321  data_time: 0.0047  memory: 578  
2025/04/16 01:10:55 - mmengine - INFO - Iter(val) [350/500]    eta: 0:00:05  time: 0.0318  data_time: 0.0045  memory: 578  
2025/04/16 01:10:57 - mmengine - INFO - Iter(val) [400/500]    eta: 0:00:03  time: 0.0361  data_time: 0.0087  memory: 578  
2025/04/16 01:10:59 - mmengine - INFO - Iter(val) [450/500]    eta: 0:00:01  time: 0.0320  data_time: 0.0047  memory: 578  
2025/04/16 01:11:00 - mmengine - INFO - Iter(val) [500/500]    eta: 0:00:00  time: 0.0310  data_time: 0.0041  memory: 578  
2025/04/16 01:11:00 - mmengine - INFO - per class results:
2025/04/16 01:11:00 - mmengine - INFO - 
+-----------+--------+-----------+--------+-------+-------+
|   Class   | Fscore | Precision | Recall |  IoU  |  Acc  |
+-----------+--------+-----------+--------+-------+-------+
| unchanged | 98.96  |   99.07   | 98.85  | 97.94 | 98.85 |
|  changed  | 30.29  |    28.2   | 32.71  | 17.85 | 32.71 |
+-----------+--------+-----------+--------+-------+-------+
2025/04/16 01:11:00 - mmengine - INFO - Iter(val) [500/500]    aAcc: 97.9500  mFscore: 64.6200  mPrecision: 63.6300  mRecall: 65.7800  mIoU: 57.8900  mAcc: 65.7800  data_time: 0.0062  time: 0.0340
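The per-class numbers in these tables are linked by two identities: F1 = 2PR/(P + R), and, because IoU = TP/(TP + FP + FN) while F1 = 2TP/(2TP + FP + FN), IoU = F1/(2 - F1). The minimal sketch below recomputes the changed-class validation scores from its precision and recall; small residuals would come from the two-decimal rounding in the log.

```python
def f1_score(precision, recall):
    """F1 (in %) from precision and recall (in %)."""
    return 2 * precision * recall / (precision + recall)


def iou_from_f1(f1):
    """IoU (in %) from F1 (in %), using IoU = F1 / (2 - F1) on fractions."""
    f = f1 / 100
    return 100 * f / (2 - f)


f1_changed = f1_score(28.2, 32.71)  # changed-class precision and recall on validation
print(f"{f1_changed:.2f} {iou_from_f1(f1_changed):.2f}")  # 30.29 17.85
```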

Results on the test set:

2025/04/22 19:40:01 - mmengine - WARNING - The prefix is not set in metric class IoUMetric.
2025/04/22 19:40:01 - mmengine - INFO - Load checkpoint from changer_r18_globalM_stage2/iter_80000.pth
2025/04/22 19:40:17 - mmengine - INFO - Iter(test) [  50/1000]    eta: 0:04:56  time: 0.1928  data_time: 0.1650  memory: 17217  
2025/04/22 19:40:26 - mmengine - INFO - Iter(test) [ 100/1000]    eta: 0:03:40  time: 0.1764  data_time: 0.1476  memory: 479  
2025/04/22 19:40:35 - mmengine - INFO - Iter(test) [ 150/1000]    eta: 0:03:10  time: 0.1709  data_time: 0.1438  memory: 479  
2025/04/22 19:40:43 - mmengine - INFO - Iter(test) [ 200/1000]    eta: 0:02:46  time: 0.1655  data_time: 0.1373  memory: 479  
2025/04/22 19:40:51 - mmengine - INFO - Iter(test) [ 250/1000]    eta: 0:02:27  time: 0.1496  data_time: 0.1225  memory: 479  
2025/04/22 19:40:59 - mmengine - INFO - Iter(test) [ 300/1000]    eta: 0:02:14  time: 0.1564  data_time: 0.1296  memory: 479  
2025/04/22 19:41:07 - mmengine - INFO - Iter(test) [ 350/1000]    eta: 0:02:02  time: 0.1574  data_time: 0.1296  memory: 479  
2025/04/22 19:41:16 - mmengine - INFO - Iter(test) [ 400/1000]    eta: 0:01:51  time: 0.1477  data_time: 0.1202  memory: 479  
2025/04/22 19:41:24 - mmengine - INFO - Iter(test) [ 450/1000]    eta: 0:01:41  time: 0.1874  data_time: 0.1595  memory: 479  
2025/04/22 19:41:33 - mmengine - INFO - Iter(test) [ 500/1000]    eta: 0:01:32  time: 0.1641  data_time: 0.1348  memory: 479  
2025/04/22 19:41:41 - mmengine - INFO - Iter(test) [ 550/1000]    eta: 0:01:21  time: 0.1415  data_time: 0.1143  memory: 479  
2025/04/22 19:41:49 - mmengine - INFO - Iter(test) [ 600/1000]    eta: 0:01:12  time: 0.1890  data_time: 0.1607  memory: 479  
2025/04/22 19:41:58 - mmengine - INFO - Iter(test) [ 650/1000]    eta: 0:01:02  time: 0.1698  data_time: 0.1416  memory: 479  
2025/04/22 19:42:06 - mmengine - INFO - Iter(test) [ 700/1000]    eta: 0:00:53  time: 0.1320  data_time: 0.1043  memory: 479  
2025/04/22 19:42:13 - mmengine - INFO - Iter(test) [ 750/1000]    eta: 0:00:43  time: 0.1528  data_time: 0.1256  memory: 479  
2025/04/22 19:42:22 - mmengine - INFO - Iter(test) [ 800/1000]    eta: 0:00:35  time: 0.1697  data_time: 0.1419  memory: 479  
2025/04/22 19:42:31 - mmengine - INFO - Iter(test) [ 850/1000]    eta: 0:00:26  time: 0.1735  data_time: 0.1456  memory: 479  
2025/04/22 19:42:39 - mmengine - INFO - Iter(test) [ 900/1000]    eta: 0:00:17  time: 0.1672  data_time: 0.1386  memory: 479  
2025/04/22 19:42:49 - mmengine - INFO - Iter(test) [ 950/1000]    eta: 0:00:08  time: 0.1929  data_time: 0.1649  memory: 479  
2025/04/22 19:42:58 - mmengine - INFO - Iter(test) [1000/1000]    eta: 0:00:00  time: 0.1733  data_time: 0.1454  memory: 479  
2025/04/22 19:42:58 - mmengine - INFO - per class results:
2025/04/22 19:42:58 - mmengine - INFO - 
+-----------+--------+-----------+--------+-------+-------+
|   Class   | Fscore | Precision | Recall |  IoU  |  Acc  |
+-----------+--------+-----------+--------+-------+-------+
| unchanged | 99.39  |   98.79   | 100.0  | 98.79 | 100.0 |
|  changed  |  0.08  |   90.77   |  0.04  |  0.04 |  0.04 |
+-----------+--------+-----------+--------+-------+-------+
2025/04/22 19:42:58 - mmengine - INFO - Iter(test) [1000/1000]    aAcc: 98.7900  mFscore: 49.7400  mPrecision: 94.7800  mRecall: 50.0200  mIoU: 49.4100  mAcc: 50.0200  data_time: 0.1437  time: 0.1761
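The logs also reveal the class balance. aAcc is the pixel-weighted average of the per-class Acc (recall), so aAcc = p_u·Acc_u + p_c·Acc_c with p_u + p_c = 1, which gives p_c = (Acc_u - aAcc)/(Acc_u - Acc_c). The sketch below applies this to both logs and, assuming the test share, computes the mIoU of a model that predicts "unchanged" everywhere (IoU_unchanged = p_u, IoU_changed = 0). Inputs are the rounded logged values, so results are approximate.

```python
def changed_share(acc_unchanged, acc_changed, a_acc):
    """Share (in %) of changed pixels implied by aAcc = p_u*Acc_u + p_c*Acc_c."""
    return 100 * (acc_unchanged - a_acc) / (acc_unchanged - acc_changed)


def all_unchanged_miou(changed_pct):
    """mIoU (in %) of a model predicting 'unchanged' for every pixel."""
    iou_unchanged = 100 - changed_pct  # TP = all unchanged pixels, FP = all changed pixels
    iou_changed = 0.0                  # no changed pixel is ever predicted
    return (iou_unchanged + iou_changed) / 2


val_share = changed_share(98.85, 32.71, 97.95)
test_share = changed_share(100.0, 0.04, 98.79)
print(f"{val_share:.2f} {test_share:.2f}")        # 1.36 1.21
print(f"{all_unchanged_miou(test_share):.2f}")   # 49.39, close to the logged test mIoU of 49.41
```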

Analysis of Results

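Key observations:
1. Large gap between the two classes:
Validation: unchanged Fscore = 98.96 vs changed Fscore = 30.29
Test: unchanged Fscore = 99.39 vs changed Fscore = 0.08
2. Performance collapse on the test set:
Recall for the changed class drops from 32.71 on validation to 0.04 on test, so the model detects almost no changed regions (its 90.77 precision on that class comes from very few positive predictions)
3. Validation-to-test generalization gap:
mIoU falls from 57.89 on validation to 49.41 on test, roughly what predicting "unchanged" everywhere would score ((98.79 + 0) / 2 ≈ 49.4), which points to poor generalization and a risk of overfitting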

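Possible causes:

    1. Extreme class imbalance
      The per-class accuracies imply that unchanged pixels make up about 98.6% of the validation pixels (about 98.8% on the test set)
      The model can reach high overall accuracy with the trivial "always predict unchanged" strategy
      The test set may contain even fewer changed pixels, or their distribution may differ more
    2. GlobalM suitability
      Axial attention (row/column splitting) may disrupt local spatial relationships, which hurts the fine localization change detection needs
      32 heads are too many for embed_dim=128: each head gets only 128/32 = 4 channels, whereas a head_dim of at least 32 is usually recommended, so attention may be too fragmented
      The 3x3 multi-head interaction convolution may introduce unnecessary positional bias
    3. Training strategy flaws
      The OHEM setting may be unsuitable (thresh=0.7 is high): mmseg's OHEMPixelSampler treats pixels whose predicted probability for their ground-truth class is below thresh as hard examples (keeping at least min_kept pixels), so a high threshold still counts many easy unchanged pixels as hard and dilutes the changed ones
      No class-balanced loss (such as Focal Loss) is used
      Effective data augmentation may be missing (especially augmentation aimed at changed regions)
    4. Feature interaction design
      Using GlobalM in stage 2 and ChannelExchange in the later stages may mix features in a confusing way
      The two streams are fused poorly: GlobalM only processes x1 and passes x2 through unchanged, so the two temporal branches are treated asymmetrically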
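One way to address the asymmetry in point 4 is to run the same single-input block on both temporal features, so both phases pass through identical weights. The PyTorch sketch below is hypothetical (`SharedDualStream` and `FirstOutput` are illustrative names, not open-cd classes, and registering them in open-cd's registry is not shown); a plain convolution stands in for GlobalM in the demo.

```python
import torch
from torch import nn


class SharedDualStream(nn.Module):
    """Hypothetical interaction layer: one single-input block shared by both streams."""

    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def forward(self, x1, x2):
        # Same weights for both temporal phases -> swapping inputs swaps outputs
        return self.block(x1), self.block(x2)


class FirstOutput(nn.Module):
    """Adapter turning a GlobalM-style (x1, x2) -> (out, x2) module into x -> out."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return self.module(x, None)[0]


# Stand-in block for the demo; with the module above one might try
# SharedDualStream(FirstOutput(GlobalM(128, 32, 'row')))
layer = SharedDualStream(nn.Conv2d(8, 8, kernel_size=3, padding=1))
x1, x2 = torch.randn(1, 8, 16, 16), torch.randn(1, 8, 16, 16)
a, b = layer(x1, x2)
c, d = layer(x2, x1)
print(torch.allclose(a, d), torch.allclose(b, c))  # True True
```

Weight sharing keeps the Siamese symmetry that the original `return out, x2` breaks; whether it improves the changed-class scores would still need to be measured.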

Suggested Improvements

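1. Data level:

  • Rebalance the dataset:
    Use weighted sampling (oversample images containing the changed class)
    Use copy-paste augmentation to synthesize additional changed regions

  • Augmentation strategy:
    Design spatial-transform augmentations for change detection (e.g. asymmetric deformation)
    Use MixUp to balance the classes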
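A minimal sketch of the weighted-sampling idea, in plain Python: each image gets a weight of 1 + boost × (its fraction of changed pixels), and indices are drawn with replacement in proportion to those weights. The weighting rule, the `boost` value, and the helper names are illustrative assumptions, not a tuned recipe.

```python
import random


def oversampling_weights(changed_fractions, boost=5.0):
    """Weight each image by 1 + boost * (fraction of changed pixels in its mask)."""
    return [1.0 + boost * frac for frac in changed_fractions]


def draw_indices(weights, k, seed=0):
    """Draw k image indices with replacement, proportionally to `weights`."""
    rng = random.Random(seed)
    return rng.choices(range(len(weights)), weights=weights, k=k)


fracs = [0.0, 0.0, 0.25, 0.5]  # hypothetical per-image changed-pixel fractions
weights = oversampling_weights(fracs, boost=4.0)
print(weights)  # [1.0, 1.0, 2.0, 3.0]
epoch_order = draw_indices(weights, k=len(fracs))
```

In plain PyTorch, the same per-image weights could be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)`; wiring that into an mmengine config would need a custom sampler registration.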

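2. Model level:

  • Improve the attention mechanism:
# Modified GlobalM config. axial_strategy='hybrid' and use_local_attn are proposed
# options that the current GlobalM does not implement (it would need code changes)
dict(type='GlobalM',
     embed_dim=128,
     num_heads=8,              # fewer attention heads
     axial_strategy='hybrid',  # alternate row and column splitting
     use_local_attn=True)      # new local-attention branch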
  • Optimize feature interaction:
    Add a spatial attention gate after GlobalM
    Replace ChannelExchange with a learnable ChannelAttention
  • Improve the decode head:
decode_head=dict(
    num_classes=2,
    loss_decode=dict(type='mmseg.FocalLoss', class_weight=[1.0, 5.0]),  # per-class weights (loss_weight is a single scalar)
    sampler=dict(type='mmseg.OHEMPixelSampler', thresh=0.5, min_kept=150000)  # adjusted threshold
)
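One possible reading of the proposed 'hybrid' strategy is a row pass followed by a column pass, as in axial-attention designs. The PyTorch sketch below is hypothetical and parameter-free (q = k = v = x, no projections) so that only the receptive field matters: perturbing one input pixel changes just its row after a row pass, but every output pixel after row + column.

```python
import torch


def axial_self_attention(x, num_heads, along):
    """Parameter-free multi-head axial self-attention on a B×C×H×W map.

    q = k = v = x (no projections), so only the attention pattern matters.
    along='row': each pixel attends within its row; 'column': within its column.
    """
    B, C, H, W = x.shape
    head_dim = C // num_heads
    t = x.reshape(B, num_heads, head_dim, H, W)
    t = t.permute(0, 3, 1, 4, 2) if along == 'row' else t.permute(0, 4, 1, 3, 2)
    attn = (t @ t.transpose(-2, -1)) * head_dim ** -0.5
    out = attn.softmax(dim=-1) @ t
    out = out.permute(0, 2, 4, 1, 3) if along == 'row' else out.permute(0, 2, 4, 3, 1)
    return out.reshape(B, C, H, W)


def hybrid_axial_attention(x, num_heads):
    """Row pass followed by a column pass (hypothetical 'hybrid' strategy)."""
    return axial_self_attention(axial_self_attention(x, num_heads, 'row'), num_heads, 'column')


torch.manual_seed(0)
x = 0.1 * torch.randn(1, 8, 5, 6, dtype=torch.float64)
x_pert = x.clone()
x_pert[..., 0, 0] += 1.0  # perturb a single pixel


def changed_pixels(fn):
    diff = (fn(x_pert) - fn(x)).abs().sum(dim=1)[0]  # H×W map of output change
    return int((diff > 1e-12).sum())


print(changed_pixels(lambda t: axial_self_attention(t, 2, 'row')))  # 6  (only row 0)
print(changed_pixels(lambda t: hybrid_axial_attention(t, 2)))       # 30 (all 5×6 pixels)
```

The current GlobalM runs a single 'row' pass, so its attention stays inside each row; global context there comes only from the 3x3 convolution and later layers.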

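3. Training strategy:

  • Progressive training:
    Stage 1: freeze GlobalM and train the base feature extractor first
    Stage 2: unfreeze the attention modules layer by layer
  • Improve the loss function:
    Combine Dice Loss and Focal Loss
    Introduce a boundary-aware loss
  • Optimize post-processing: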
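test_cfg=dict(
    mode='slide',
    crop_size=(512, 512),
    stride=(256, 256),
    # Hypothetical: mmseg's test_cfg has no post_process key, so CRF
    # post-processing would need a custom implementation
    post_process=dict(type='CRFPostProcess',  # add conditional random field post-processing
                      win_size=5,
                      iter_max=10)
)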
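A minimal PyTorch sketch of the Dice + Focal combination, assuming a single "changed" logit per pixel. For a 2-class softmax head, P(changed) = sigmoid(z_changed - z_unchanged), so the logit difference can serve as that logit. The focal term follows the sigmoid focal-loss formula (the same one `torchvision.ops.sigmoid_focal_loss` uses); alpha=0.25 and gamma=2 are the usual RetinaNet defaults and would need tuning for rare changed pixels. Function names are illustrative.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss(logits, target, alpha=0.25, gamma=2.0):
    """Sigmoid focal loss on the 'changed' logit; target is a 0/1 float tensor."""
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
    p_t = p * target + (1 - p) * (1 - target)              # probability of the true class
    alpha_t = alpha * target + (1 - alpha) * (1 - target)  # class-balancing weight
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()      # down-weights easy pixels


def soft_dice_loss(logits, target, eps=1.0):
    """1 - soft Dice coefficient of the changed class, averaged over the batch."""
    p = torch.sigmoid(logits).flatten(1)
    t = target.flatten(1)
    inter = (p * t).sum(dim=1)
    dice = (2 * inter + eps) / (p.sum(dim=1) + t.sum(dim=1) + eps)
    return (1 - dice).mean()


def dice_focal_loss(logits, target, dice_weight=1.0, focal_weight=1.0):
    return dice_weight * soft_dice_loss(logits, target) + focal_weight * binary_focal_loss(logits, target)


seg_logits = torch.randn(2, 2, 64, 64)           # B×2×H×W output of a 2-class head
labels = (torch.rand(2, 64, 64) > 0.98).long()   # B×H×W, 1 = changed
changed_logit = seg_logits[:, 1] - seg_logits[:, 0]
loss = dice_focal_loss(changed_logit, labels.float())
```

mmseg decode heads also accept a list of loss configs in `loss_decode`, and mmseg ships `DiceLoss` and `FocalLoss`, so a config-only route may be possible; check the argument names against your installed version.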

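Verification directions:

  • Visualization:
    Use Grad-CAM to visualize the regions GlobalM attends to
    Compare feature response maps with and without GlobalM
  • Ablation studies:
    Remove components one at a time to verify each one's contribution
    Test the effect of different head counts (8/16/32)
  • Data diagnosis:
    Count the changed samples in the test set
    Check label quality (there may be annotation noise)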
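For the data-diagnosis step, a minimal numpy sketch that summarizes changed pixels over a set of label masks. It assumes binary masks in which any nonzero value marks a changed pixel; check S2Looking's actual label encoding before relying on it. The helper name is hypothetical.

```python
import numpy as np


def change_statistics(masks):
    """Summarize changed pixels over binary change masks (nonzero = changed)."""
    total = changed = images_with_change = 0
    per_image = []
    for mask in masks:
        mask = np.asarray(mask)
        n = int(np.count_nonzero(mask))
        per_image.append(n)
        total += mask.size
        changed += n
        images_with_change += int(n > 0)
    return {
        'changed_pixel_ratio': changed / total if total else 0.0,
        'images_with_change': images_with_change,
        'changed_pixels_per_image': per_image,
    }


masks = [np.zeros((4, 4), dtype=np.uint8), np.array([[0, 255], [0, 0]], dtype=np.uint8)]
print(change_statistics(masks))
# {'changed_pixel_ratio': 0.05, 'images_with_change': 1, 'changed_pixels_per_image': [0, 1]}
```

For real label files, each mask could be loaded with something like `np.array(Image.open(path))` from Pillow and fed to the function in a generator.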
