PyTorch torchvision 是 PyTorch 生态中专门用于计算机视觉任务的库,提供了以下核心功能:
1. 核心功能概览
功能类别 | 主要内容 |
---|---|
数据集 | 内置经典数据集(MNIST/CIFAR/ImageNet/COCO等) |
预训练模型 | 主流CV模型(ResNet/VGG/ViT/Mask R-CNN等) |
数据增强 | 图像变换工具(裁剪/翻转/归一化等) |
工具函数 | 图像读写、视频处理、检测评估工具等 |
2. 关键组件详解
(1) 数据集加载(Datasets)
from torchvision import datasets# 示例1:加载MNIST数据集
mnist = datasets.MNIST(root='./data', # 存储路径train=True, # 训练集download=True, # 自动下载transform=transforms.ToTensor() # 转为张量
)# 示例2:加载自定义图像数据集
custom_data = datasets.ImageFolder(root='path/to/images', # 按文件夹分类transform=transforms.Compose([transforms.Resize(256),transforms.ToTensor()])
)
(2) 预训练模型(Models)
from torchvision import models# 加载模型(带预训练权重)
resnet50 = models.resnet50(weights='ResNet50_Weights.DEFAULT')# 修改分类层(迁移学习)
resnet50.fc = torch.nn.Linear(2048, 10) # 改为10分类# 查看所有可用模型
print(dir(models))
(3) 数据增强(Transforms)
from torchvision import transforms# 典型图像预处理流程
transform = transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪缩放transforms.RandomHorizontalFlip(), # 水平翻转transforms.ColorJitter(0.4, 0.4, 0.4), # 颜色扰动transforms.ToTensor(), # 转为张量transforms.Normalize( # 标准化mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
3. 典型应用示例
(1) 图像分类流程
import torch
from torchvision import datasets, models, transforms# 数据准备
train_transform = transforms.Compose([...])
train_data = datasets.ImageFolder('./train', transform=train_transform)# 模型加载
model = models.efficientnet_b0(pretrained=True)# 训练循环
for images, labels in dataloader:outputs = model(images)loss = criterion(outputs, labels)...
(2) 目标检测(Faster R-CNN)
from torchvision.models.detection import fasterrcnn_resnet50_fpnmodel = fasterrcnn_resnet50_fpn(pretrained=True)
# 输入格式:List[Tensor(C, H, W)], 每张图需归一化到0-1
images = [torch.rand(3, 300, 400)]
predictions = model(images)
4. 扩展功能
-
视频处理:
torchvision.io
(视频读写/抽帧) -
模型量化:支持INT8量化(
torchvision.quantization
) -
ONNX导出:
torch.onnx.export(model, ...)