欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 艺术 > PyTorch图像预处理--Compose

PyTorch图像预处理--Compose

2025/4/2 1:03:21 来源:https://blog.csdn.net/byxdaz/article/details/146540720  浏览:    关键词:PyTorch图像预处理--Compose

   torchvision.transforms.Compose 是 PyTorch 中用于图像预处理的核心工具,可将多个图像变换操作组合成一个顺序执行的流水线。

1. 定义与作用

  • 功能‌:将多个图像处理步骤(如缩放、裁剪、归一化等)串联为一个整体,简化代码并确保操作顺序正确‌。
  • 适用场景‌:数据预处理(训练/测试)、数据增强(如随机裁剪、翻转)‌。

2. 基本用法

通过 transforms.Compose() 按顺序传入变换列表:

from torchvision import transformstransform = transforms.Compose([transforms.Resize(256),          # 缩放图像短边至256像素transforms.CenterCrop(224),       # 中心裁剪224x224区域transforms.ToTensor(),            # 转换为张量(范围[0,1])transforms.Normalize(             # 标准化至[-1,1]mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

3. 常用变换操作

操作说明
transforms.Resize()调整图像尺寸(支持固定值或比例缩放)‌
transforms.RandomCrop()随机裁剪(常用于数据增强)‌
transforms.ToTensor()将 PIL 图像或 NumPy 数组转为张量,并归一化至 [0.0, 1.0]
transforms.Normalize()标准化处理(需先执行 ToTensor())‌

4. 标准化处理详解

假设输入为范围 [0,1] 的张量,Normalize 按以下公式处理:
image = (image - mean) / std

  • 示例‌:若 mean=0.5std=0.5,则数据范围被映射到 [-1, 1]‌。

5. 完整示例

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader# 定义变换
transform = transforms.Compose([transforms.RandomHorizontalFlip(),  # 随机水平翻转(数据增强)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载数据集并应用变换
train_set = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)# 训练循环
for images, labels in train_loader:# 输入模型训练...
  • 数据流‌:原始图像 → 随机翻转 → 张量转换 → 标准化 → 批处理输入模型‌。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词