torchvision.transforms.Compose
是 PyTorch 中用于图像预处理的核心工具,可将多个图像变换操作组合成一个顺序执行的流水线。
1. 定义与作用
- 功能:将多个图像处理步骤(如缩放、裁剪、归一化等)串联为一个整体,简化代码并确保操作顺序正确。
- 适用场景:数据预处理(训练/测试)、数据增强(如随机裁剪、翻转)。
2. 基本用法
通过 transforms.Compose()
按顺序传入变换列表:
from torchvision import transformstransform = transforms.Compose([transforms.Resize(256), # 缩放图像短边至256像素transforms.CenterCrop(224), # 中心裁剪224x224区域transforms.ToTensor(), # 转换为张量(范围[0,1])transforms.Normalize( # 标准化至[-1,1]mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
3. 常用变换操作
操作 | 说明 |
---|---|
transforms.Resize() | 调整图像尺寸(支持固定值或比例缩放) |
transforms.RandomCrop() | 随机裁剪(常用于数据增强) |
transforms.ToTensor() | 将 PIL 图像或 NumPy 数组转为张量,并归一化至 [0.0, 1.0] |
transforms.Normalize() | 标准化处理(需先执行 ToTensor() ) |
4. 标准化处理详解
假设输入为范围 [0,1]
的张量,Normalize
按以下公式处理:
image = (image - mean) / std
- 示例:若
mean=0.5
,std=0.5
,则数据范围被映射到[-1, 1]
。
5. 完整示例
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader# 定义变换
transform = transforms.Compose([transforms.RandomHorizontalFlip(), # 随机水平翻转(数据增强)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载数据集并应用变换
train_set = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)# 训练循环
for images, labels in train_loader:# 输入模型训练...
- 数据流:原始图像 → 随机翻转 → 张量转换 → 标准化 → 批处理输入模型。