欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 财经 > 产业 > 深度学习入门-06

深度学习入门-06

2024/10/27 14:18:30 来源:https://blog.csdn.net/qhqlnannan/article/details/141565545  浏览:    关键词:深度学习入门-06

基于小土堆学习

如何把数据集和Transform结合袭来

https://pytorch.org/
上述网址是pytorch的官网
在这里插入图片描述
这里会有详细的使用介绍
在这里插入图片描述
下述是对图像处理的专门文档
在这里插入图片描述
单击后可查看详细介绍
在这里插入图片描述
选择CIFAR10数据集

在这里插入图片描述
CIFAR10 数据集是一个广泛使用的计算机视觉数据集,包含了60000张32x32的彩色图像,这些图像分为10个类别,每个类别6000张图像。这些数据集被分为50000张训练图像和10000张测试图像。

参数解释如下:

  • -root(str或pathlib.Path):数据集的根目录,其中应存在cifar-10-batches-py目录,或者如果设置download为True,则会在此目录下下载并保存数据集。
  • -train(bool,可选):如果为True,则从训练集创建数据集;否则,从测试集创建数据集。
  • -transform(callable,可选):一个函数/变换,它接受一个PIL图像并返回变换后的版本。例如,transforms.RandomCrop。
  • -target_transform(callable,可选):一个函数/变换,它接受目标(标签)并对其进行变换。
  • -download(bool,可选):如果为True,则从互联网下载数据集并将其放在根目录中。如果数据集已经下载,则不会再次下载。

在这里插入图片描述

import torchvision
train_set = torchvision.datasets.CIFAR10(root="./CIFAR",train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR",train=False,download=True)
#下载训练集和测试机print(test_set[0])#获取数据类型
print("test_set.classes",test_set.classes)#获取分类目标img,target = test_set[0]
print("img:",img)
print("target:",target)
#输出结果target: 3,对应类别0,1,2,3;也就是当前类别是猫cat
print("test_set.classesp[target]当前类型为",test_set.classes[target])
img.show()

运行结果为

C:\Anaconda3\envs\pytorch_test\python.exe H:\Python\Test\P10_dataset_transforms.py 
Files already downloaded and verified
Files already downloaded and verified
(<PIL.Image.Image image mode=RGB size=32x32 at 0x21F676692D0>, 3)
test_set.classes ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
img: <PIL.Image.Image image mode=RGB size=32x32 at 0x21F6A68E560>
target: 3
test_set.classesp[target]当前类型为 cat进程已结束,退出代码0

数据集全部转换为tensor数据类型

import torchvisiondataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()    ])train_set = torchvision.datasets.CIFAR10(root="./CIFAR",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR",train=False,transform=dataset_transform,download=True)
#transform=dataset_transform,将数据集中的每个数据都转换为Tensor格式
#下载训练集和测试机print(test_set[0])#获取数据类型
print("test_set.classes",test_set.classes)#获取分类目标img,target = test_set[0]
print("img:",img)
print("target:",target)
#输出结果target: 3,对应类别0,1,2,3;也就是当前类别是猫cat
print("test_set.classesp[target]当前类型为",test_set.classes[target])

输出结果为:

C:\Anaconda3\envs\pytorch_test\python.exe H:\Python\Test\P10_dataset_transforms.py 
Files already downloaded and verified
Files already downloaded and verified
(tensor([[[0.6196, 0.6235, 0.6471,  ..., 0.5373, 0.4941, 0.4549],[0.5961, 0.5922, 0.6235,  ..., 0.5333, 0.4902, 0.4667],[0.5922, 0.5922, 0.6196,  ..., 0.5451, 0.5098, 0.4706],...,[0.2667, 0.1647, 0.1216,  ..., 0.1490, 0.0510, 0.1569],[0.2392, 0.1922, 0.1373,  ..., 0.1020, 0.1137, 0.0784],[0.2118, 0.2196, 0.1765,  ..., 0.0941, 0.1333, 0.0824]],[[0.4392, 0.4353, 0.4549,  ..., 0.3725, 0.3569, 0.3333],[0.4392, 0.4314, 0.4471,  ..., 0.3725, 0.3569, 0.3451],[0.4314, 0.4275, 0.4353,  ..., 0.3843, 0.3725, 0.3490],...,[0.4863, 0.3922, 0.3451,  ..., 0.3804, 0.2510, 0.3333],[0.4549, 0.4000, 0.3333,  ..., 0.3216, 0.3216, 0.2510],[0.4196, 0.4118, 0.3490,  ..., 0.3020, 0.3294, 0.2627]],[[0.1922, 0.1843, 0.2000,  ..., 0.1412, 0.1412, 0.1294],[0.2000, 0.1569, 0.1765,  ..., 0.1216, 0.1255, 0.1333],[0.1843, 0.1294, 0.1412,  ..., 0.1333, 0.1333, 0.1294],...,[0.6941, 0.5804, 0.5373,  ..., 0.5725, 0.4235, 0.4980],[0.6588, 0.5804, 0.5176,  ..., 0.5098, 0.4941, 0.4196],[0.6275, 0.5843, 0.5176,  ..., 0.4863, 0.5059, 0.4314]]]), 3)
test_set.classes ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
img: tensor([[[0.6196, 0.6235, 0.6471,  ..., 0.5373, 0.4941, 0.4549],[0.5961, 0.5922, 0.6235,  ..., 0.5333, 0.4902, 0.4667],[0.5922, 0.5922, 0.6196,  ..., 0.5451, 0.5098, 0.4706],...,[0.2667, 0.1647, 0.1216,  ..., 0.1490, 0.0510, 0.1569],[0.2392, 0.1922, 0.1373,  ..., 0.1020, 0.1137, 0.0784],[0.2118, 0.2196, 0.1765,  ..., 0.0941, 0.1333, 0.0824]],[[0.4392, 0.4353, 0.4549,  ..., 0.3725, 0.3569, 0.3333],[0.4392, 0.4314, 0.4471,  ..., 0.3725, 0.3569, 0.3451],[0.4314, 0.4275, 0.4353,  ..., 0.3843, 0.3725, 0.3490],...,[0.4863, 0.3922, 0.3451,  ..., 0.3804, 0.2510, 0.3333],[0.4549, 0.4000, 0.3333,  ..., 0.3216, 0.3216, 0.2510],[0.4196, 0.4118, 0.3490,  ..., 0.3020, 0.3294, 0.2627]],[[0.1922, 0.1843, 0.2000,  ..., 0.1412, 0.1412, 0.1294],[0.2000, 0.1569, 0.1765,  ..., 0.1216, 0.1255, 0.1333],[0.1843, 0.1294, 0.1412,  ..., 0.1333, 0.1333, 0.1294],...,[0.6941, 0.5804, 0.5373,  ..., 0.5725, 0.4235, 0.4980],[0.6588, 0.5804, 0.5176,  ..., 0.5098, 0.4941, 0.4196],[0.6275, 0.5843, 0.5176,  ..., 0.4863, 0.5059, 0.4314]]])
target: 3
test_set.classesp[target]当前类型为 cat进程已结束,退出代码0

继续用Tensorboard进行图片的显示:显示前20张图片

import torchvision
from torch.utils.tensorboard import SummaryWriterdataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()    ])train_set = torchvision.datasets.CIFAR10(root="./CIFAR",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR",train=False,transform=dataset_transform,download=True)
#transform=dataset_transform,将数据集中的每个数据都转换为Tensor格式
#下载训练集和测试机# print(test_set[0])#获取数据类型
# print("test_set.classes",test_set.classes)#获取分类目标
#
# img,target = test_set[0]
# print("img:",img)
# print("target:",target)
# #输出结果target: 3,对应类别0,1,2,3;也就是当前类别是猫cat
# print("test_set.classesp[target]当前类型为",test_set.classes[target])
write = SummaryWriter("logs")
for i in range(20):img, target = test_set[i]write.add_image("img", img, i)
write.close()

结果为:

C:\Anaconda3\envs\pytorch_test\python.exe H:\Python\Test\P10_dataset_transforms.py 
Files already downloaded and verified
Files already downloaded and verified进程已结束,退出代码0

local的结果

**(pytorch_test) PS H:\Python\Test> tensorboard --logdir logs --port=6007 
TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.17.1 at http://localhost:6007/ (Press CTRL+C to quit)
**

在这里插入图片描述
拖动可以查看20张图片

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com