欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > 能源 > PyTorch之torchvision

PyTorch之torchvision

2025/4/2 16:48:04 来源:https://blog.csdn.net/byxdaz/article/details/146899308  浏览:    关键词:PyTorch之torchvision

PyTorch torchvision 是 PyTorch 生态中专门用于计算机视觉任务的库,提供了以下核心功能:

1. 核心功能概览

功能类别主要内容
数据集内置经典数据集(MNIST/CIFAR/ImageNet/COCO等)
预训练模型主流CV模型(ResNet/VGG/ViT/Mask R-CNN等)
数据增强图像变换工具(裁剪/翻转/归一化等)
工具函数图像读写、视频处理、检测评估工具等

2. 关键组件详解

(1) 数据集加载(Datasets)
from torchvision import datasets# 示例1:加载MNIST数据集
mnist = datasets.MNIST(root='./data',          # 存储路径train=True,            # 训练集download=True,         # 自动下载transform=transforms.ToTensor()  # 转为张量
)# 示例2:加载自定义图像数据集
custom_data = datasets.ImageFolder(root='path/to/images',  # 按文件夹分类transform=transforms.Compose([transforms.Resize(256),transforms.ToTensor()])
)
(2) 预训练模型(Models)
from torchvision import models# 加载模型(带预训练权重)
resnet50 = models.resnet50(weights='ResNet50_Weights.DEFAULT')# 修改分类层(迁移学习)
resnet50.fc = torch.nn.Linear(2048, 10)  # 改为10分类# 查看所有可用模型
print(dir(models))
(3) 数据增强(Transforms)
from torchvision import transforms# 典型图像预处理流程
transform = transforms.Compose([transforms.RandomResizedCrop(224),      # 随机裁剪缩放transforms.RandomHorizontalFlip(),      # 水平翻转transforms.ColorJitter(0.4, 0.4, 0.4), # 颜色扰动transforms.ToTensor(),                  # 转为张量transforms.Normalize(                   # 标准化mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

3. 典型应用示例

(1) 图像分类流程
import torch
from torchvision import datasets, models, transforms# 数据准备
train_transform = transforms.Compose([...])
train_data = datasets.ImageFolder('./train', transform=train_transform)# 模型加载
model = models.efficientnet_b0(pretrained=True)# 训练循环
for images, labels in dataloader:outputs = model(images)loss = criterion(outputs, labels)...
(2) 目标检测(Faster R-CNN)
from torchvision.models.detection import fasterrcnn_resnet50_fpnmodel = fasterrcnn_resnet50_fpn(pretrained=True)
# 输入格式:List[Tensor(C, H, W)], 每张图需归一化到0-1
images = [torch.rand(3, 300, 400)]  
predictions = model(images)

4. 扩展功能

  • 视频处理torchvision.io(视频读写/抽帧)

  • 模型量化:支持INT8量化(torchvision.quantization

  • ONNX导出torch.onnx.export(model, ...)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词