具体示例:“数字位置分类任务”
我们设计一个简单的任务来对比 CNN 和 Transformer 对位置变化的处理能力:
任务设定
- 输入:28x28 灰度图像,包含一个手写数字(0~9),但数字位置可能出现在图像任意位置(而非固定居中)。
- 目标:模型需要同时完成两个任务:
- 分类:识别数字类别(0~9)。
- 定位:预测数字的中心坐标(x, y,取值范围 [0, 27])。
- 训练数据:仅包含数字出现在图像左侧半区的样本(x ≤ 13)。
- 测试数据:数字出现在图像右侧半区(x > 13),测试模型对未见过位置的泛化能力。
1. CNN 模型设计
import torch
import torch.nn as nnclass SimpleCNN(nn.Module):def __init__(self):super().__init__()# 特征提取(隐含平移不变性)self.features = nn.Sequential(nn.Conv2d(1, 16, 3, padding=1), # 3x3卷积核,滑动检测局部特征nn.ReLU(),nn.MaxPool2d(2), # 14x14nn.Conv2d(16, 32, 3, padding=1),nn.ReLU(),nn.MaxPool2d(2) # 7x7)# 分类头self.classifier = nn.Sequential(nn.Linear(32*7*