欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 资讯 > 【计算机视觉基础CV-图像分类】03-深度学习图像分类实战:鲜花数据集加载与预处理详解

【计算机视觉基础CV-图像分类】03-深度学习图像分类实战:鲜花数据集加载与预处理详解

2024/12/24 10:37:40 来源:https://blog.csdn.net/weixin_41645791/article/details/144583368  浏览:    关键词:【计算机视觉基础CV-图像分类】03-深度学习图像分类实战:鲜花数据集加载与预处理详解

本文将深入介绍鲜花分类数据集的加载与处理方式,同时详细解释代码的每一步骤并给出更丰富的实践建议和拓展思路。以实用为导向,为读者提供从数据组织、预处理、加载到可视化展示的完整过程,并为后续模型训练打下基础。


前言

在计算机视觉的深度学习实践中,数据加载和预处理是至关重要的一步。无论你是初学者,还是有一定经验的从业者,都需要深刻理解如何将原始数据转化为神经网络可接受的输入。PyTorch中的torchvision.datasetstorchvision.transforms为我们提供了极大的便利,使图像数据的加载和处理更加高效与简洁。

本文将以“鲜花分类数据集”(一个包含5种不同花卉类别的图像数据集)为例,详细讲述如何使用ImageFolder类进行数据加载,并通过transforms对图像进行预处理和数据增强。我们还会深入讨论数据集结构、训练/验证集划分、代码注释和实践建议,并给出详细说明。


数据集简介与结构

本例使用的鲜花分类数据集共包含5种花:雏菊(daisy)、蒲公英(dandelion)、玫瑰(roses)、向日葵(sunflowers)和郁金香(tulips)。数据量约为:

  • 训练集(train):3306张图像

  • 验证集(val):364张图像

数据已按类别分好目录,每个类别对应一个文件夹,文件夹中存放若干图片文件。结构示意如下:

dataset/flower_datas/├─ train/│   ├─ daisy/       # 雏菊类图像若干张│   ├─ dandelion/   # 蒲公英类图像若干张│   ├─ roses/       # 玫瑰类图像若干张│   ├─ sunflowers/   # 向日葵类图像若干张│   └─ tulips/       # 郁金香类图像若干张└─ val/├─ daisy/├─ dandelion/├─ roses/├─ sunflowers/└─ tulips/

这种目录结构非常适合ImageFolder数据集类,它会根据子文件夹的名称自动分配类别标签,从0开始编号。例如:

  • daisy -> 0

  • dandelion -> 1

  • roses -> 2

  • sunflowers -> 3

  • tulips -> 4

这样无需手动编码类别映射,简化了流程。


ImageFolder和transform

ImageFolder简介

ImageFoldertorchvision.datasets中的一个实用数据类,它假设数据按如下规则组织:

  • root/class_x/xxx.png

  • root/class_x/xxy.png

  • root/class_y/xxz.png

  • ...

其中class_xclass_y是类名(字符串),ImageFolder会根据这些类名自动生成类别索引。加载后,每个样本是一个(image, label)二元组,image通常会通过transform转换为Tensorlabel为整数索引。


transforms的数据预处理功能

torchvision.transforms提供多种图像处理方法,用来改变图像格式、尺寸、颜色空间和进行数据增强。例如:

  • ToTensor():将PIL图像或Numpy数组转换为(C,H,W)格式的张量,并将像素值归一化到[0,1]之间。

  • Resize((224,224)):将图像缩放到224x224大小,这通常是预训练模型如ResNet、VGG的标准输入尺寸。

  • RandomHorizontalFlip():随机水平翻转图像,用于数据增强,提高模型对翻转不敏感。

  • Normalize(mean, std):对图像的每个通道进行归一化,使训练更稳定。

你可以根据需求灵活组合多个变换操作,使用transforms.Compose将其串联成流水线。


加载鲜花分类数据集的示例代码

下面的代码示例中,我将详细注释每个步骤,为读者提供清晰的思路。该示例以最基本的ToTensor和Resize为主,读者可按需添加更多transform。

import torch
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
import matplotlib.pyplot as plt# 数据集存放路径,根据实际情况修改
flowers_train_path = '../01.图像分类/dataset/flower_datas/train/'
flowers_val_path = '../01.图像分类/dataset/flower_datas/val/'# 定义数据预处理
# 这里的transforms主要包括:
# 1. ToTensor():将PIL图片或numpy数组转为Tensor,并将像素值归一化到[0,1]区间。
# 2. Resize((224,224)):将所有图片大小统一为224x224,以匹配后续卷积神经网络的输入要求。
# 对于实际训练,更建议加入数据增强手段(如随机裁剪、翻转、归一化等),
# 但本例先展示基本流程。
dataset_transform = transforms.Compose([transforms.ToTensor(),transforms.Resize((224,224))
])# 使用ImageFolder加载训练集和验证集
# ImageFolder会扫描指定目录下的子文件夹,并以子文件夹名称作为类别。
flowers_train = ImageFolder(root=flowers_train_path, transform=dataset_transform)
flowers_val = ImageFolder(root=flowers_val_path, transform=dataset_transform)# 打印样本数量
print("训练集样本数:", len(flowers_train))
print("验证集样本数:", len(flowers_val))# flowers_train.classes属性包含类别名称列表,如['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
print("类别名称列表:", flowers_train.classes)# 获取单个样本进行查看
# __getitem__(index)返回(img, label),img是Tensor,label是int
sample_index = 3000
sample_img, sample_label = flowers_train[sample_index]print("样本索引:", sample_index)
print("类别标签索引:", sample_label, "类别名称:", flowers_train.classes[sample_label])
print("图像Tensor尺寸:", sample_img.shape)  # 期望为[3,224,224]# 可视化图像
# Matplotlib的imshow要求图像为(H,W,C),而Tensor是(C,H,W),需要permute调整维度顺序。
plt.imshow(sample_img.permute(1,2,0))
plt.title(flowers_train.classes[sample_label])
plt.show()

代码输出: 


关于训练集、验证集和测试集的说明

本数据集中已提前将数据分为trainval两个目录:

  • train/:训练集,用于模型训练过程中反向传播和参数更新。

  • val/:验证集,用于在训练中间进行性能评估,不参与参数更新,仅用于选择超参数或判断训练是否过拟合。

有些数据集还会提供test/测试集,用于最终评估模型在未知数据上的表现,但本例中未提供,如有需要可自行分割数据或从其他来源获取。


DataLoader的引入

仅有ImageFolder还不够,为了在训练时批量读取数据并进行迭代,我们通常会将数据集对象传入DataLoader中。

DataLoader的作用是:

  • 按指定的batch_size从Dataset中抽取样本构成mini-batch。

  • 可设置shuffle=True来随机打乱样本顺序,防止模型记住样本顺序。

  • 使用num_workers参数并行加速数据加载。

示例(可选代码):

from torch.utils.data import DataLoaderbatch_size = 32
# 定义训练集和验证集的DataLoader
train_loader = DataLoader(flowers_train, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(flowers_val, batch_size=batch_size, shuffle=False, num_workers=2)# 测试一下加载结果
images, labels = next(iter(train_loader))
print("一个batch的图像尺寸:", images.shape)  # [batch_size, 3, 224, 224]
print("对应的标签:", labels)  # 张量形式,如tensor([0, 1, 3, ...])


有了DataLoader,我们在训练模型时,就可以轻松迭代数据:

for epoch in range(1):for batch_images, batch_labels in train_loader:# 在这里将batch_images, batch_labels输入模型进行训练print("一个batch的图像尺寸:", batch_images.shape)  # [batch_size, 3, 224, 224]print("对应的标签:", batch_labels)  # 张量形式,如tensor([0, 1, 3, ...])passbreak


我们可以打印一下第一个batch 和最后一个batch的标签

batch_count = 0
first_batch_images, first_batch_labels = None, None
last_batch_images, last_batch_labels = None, Nonefor epoch in range(1):for batch_images, batch_labels in train_loader:batch_count += 1# 保存第一个batchif batch_count == 1:first_batch_images, first_batch_labels = batch_images, batch_labelsprint("第一个batch的图像尺寸:", batch_images.shape)print("第一个batch的标签:", batch_labels)# 每次循环都会更新last_batchlast_batch_images, last_batch_labels = batch_images, batch_labelsbreak  # 只进行一次epoch的训练,移除这行会进行多个epoch的训练# 打印最后一个batch
print("最后一个batch的图像尺寸:", last_batch_images.shape)
print("最后一个batch的标签:", last_batch_labels)# 打印总共的batch数量
print("总共的batch数量:", batch_count)


数据增强策略的拓展

实际训练中,为提高模型的泛化能力,我们常加入数据增强操作。这些操作对训练集图像进行随机变换,如随机剪裁、翻转、颜色抖动、归一化等。这样模型不会过度记忆特定图像的像素分布,而会学习更有泛化性的特征。

一个常用的transform示例:

# 定义训练集的图像预处理流程
train_transform = transforms.Compose([# 随机裁剪并缩放图像到224x224的尺寸,裁剪的区域大小是随机的transforms.RandomResizedCrop(224),  # 随机进行水平翻转,用于数据增强,提升模型的泛化能力transforms.RandomHorizontalFlip(),# 将图像转换为Tensor类型,PyTorch要求输入为Tensor格式transforms.ToTensor(),# 进行图像的标准化处理。根据ImageNet数据集的均值和标准差进行归一化,# 使得不同的通道(RGB)具有相同的尺度,便于训练。transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])# 定义验证集的图像预处理流程
val_transform = transforms.Compose([# 将图像的最短边缩放到256像素,保持长宽比例不变transforms.Resize(256),  # 从缩放后的图像中进行中心裁剪,裁剪出224x224的区域,这样图像的尺寸就一致了transforms.CenterCrop(224),# 将图像转换为Tensor类型,PyTorch要求输入为Tensor格式transforms.ToTensor(),# 进行图像的标准化处理。根据ImageNet数据集的均值和标准差进行归一化,# 使得不同的通道(RGB)具有相同的尺度,便于训练。transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])# 使用定义的transform对训练集和验证集进行图像预处理
# flowers_train_path和flowers_val_path是训练集和验证集图像所在的路径
flowers_train = ImageFolder(flowers_train_path, transform=train_transform)  # 训练集
flowers_val = ImageFolder(flowers_val_path, transform=val_transform)  # 验证集

在此示例中,Normalize的参数是使用ImageNet数据集的均值和标准差,这在使用ImageNet预训练模型时是常规操作。对于自定义数据集,你也可以先统计本数据集的均值和方差,再进行归一化。


我们可以打印一下变化前后的图像区别

import os
import random
import numpy as np  # 需要导入numpy
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torchvision.datasets import ImageFolder# 定义训练集的图像预处理流程
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 定义图像数据集路径
train_image_folder = '/Users/coyi/PycharmProjects/coyi_pythonProject/01.图像分类/dataset/flower_datas/train/'# 使用ImageFolder加载数据集
dataset = ImageFolder(train_image_folder, transform=None)# 随机选取一张图片
random_idx = random.randint(0, len(dataset) - 1)
image, label = dataset[random_idx]# 显示原始图像
plt.figure(figsize=(5,5))
plt.title("Original Image")
plt.imshow(image)
plt.axis('off')  # 不显示坐标轴
plt.show()# 应用train_transform变换
transformed_image = train_transform(image)# 反标准化(Undo normalization)以恢复图片的原始视觉效果,因为训练的时候需要标准化
inv_normalize = transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1/0.229, 1/0.224, 1/0.225])
unnormalized_image = inv_normalize(transformed_image)# 将Tensor转回PIL图像进行显示
unnormalized_image = unnormalized_image.permute(1, 2, 0).numpy()  # 转换为HWC格式
unnormalized_image = np.clip(unnormalized_image, 0, 1)  # 限制值在[0, 1]之间,以符合视觉输出# 显示变换后的图像
plt.figure(figsize=(5,5))
plt.title("Transformed Image")
plt.imshow(unnormalized_image)
plt.axis('off')  # 不显示坐标轴
plt.show()

输出: 

备注: 为了显示图片,我对处理后的图片进行了反标准化,实际上训练的时候是不需要反标准化的


为什么要反标准化?

标准化是一个常见的预处理步骤,目的是让模型训练时更稳定,通常是将像素值转换到均值为0、标准差为1的范围。这可以帮助模型更好地收敛,并且消除不同通道(例如RGB)的尺度差异。

然而,标准化后的图像不适合直接用于可视化,因为它们的像素值已经不在[0, 1]的范围内,可能会变成负数或大于1。反标准化的目的是恢复图像的原始视觉效果,让它们的像素值回到原始的视觉范围。

不反标准化可以吗?

在可视化时不反标准化是可以的,但你会看到经过标准化后的图像没有直观的可视化效果,因为图像的像素值会偏离 [0, 1] 的可视化范围。这会导致显示的图像看起来可能是“失真”的,例如图像会变得非常暗、非常亮,或者有一些不自然的颜色。

简而言之:

反标准化是为了恢复图像的原始视觉效果,使得图像显示更符合人类的感知。

• **np.clip()**是为了确保图像的像素值在[0, 1]范围内,符合图像显示的要求。

示例:

假设标准化之后,你得到了一个像素值为 -0.5 或 1.5 的图像像素。这时,如果不进行 np.clip(),直接用 matplotlib 显示,可能会看到图像出现异常的颜色或显示不出来。而通过 np.clip(),将这些像素值限制在[0, 1]的范围内,可以确保图像能正确显示。


类别分布与标签可解释性

flowers_train.classesflowers_val.classes可以查看类名列表。例如:

这意味着模型预测结果中的label=0代表daisy,label=1代表dandelion,以此类推。当我们预测模型输出为label=3时,就可以将其解释为sunflowers。这种可读性非常有助于后期分析和调试。

如果想查看具体每类样本数量,可手动统计,例如:

 

通过查看类别分布,我们可了解数据是否偏斜(某些类样本过多或过少),从而采取相应措施(如类均衡采样、权重平衡等)。


实战建议和下一步计划

  1. 数据准备完成后做什么? 通常下一步就是定义和加载模型(如预训练的ResNet18),然后编写训练循环对模型进行微调或从头训练。在训练循环中,train_loader提供批数据,val_loader则用于评估模型在验证集上的表现。

  2. 调试DataLoader是否正确工作: 在正式训练前,尝试可视化几个batch的数据样本,确保图像大小、颜色正确,标签映射无误。如果出现图像显示不正确或标签偏移,及时检查目录结构和transform流程。

  3. 善用数据增强: 当验证集精度停滞不前或出现过拟合时,尝试加入更多数据增强手段(如RandomRotationColorJitterRandomGrayscale等)提升泛化性能。

  4. 硬件加速: 在加载大规模数据时,合理增加num_workers可以提高数据读取速度(依赖操作系统和硬件条件)。同时,如果是分布式训练,也需考虑分布式Sampler和合适的数据划分策略。

  5. 定制Dataset: 如果你的数据不遵循ImageFolder的结构,也可以自行定义Dataset类,通过实现__len____getitem__方法来自定义数据加载流程。但对像本例这样已按类分文件夹的数据集,ImageFolder无疑是最简单高效的方案。


小结

在本文中,我们从零出发,详细介绍了如何使用PyTorch的ImageFoldertransforms加载和预处理鲜花分类数据集。主要点包括:

  • 数据集组织结构:子文件夹命名为类名,便于ImageFolder自动识别类别。

  • 使用transforms对图像进行ToTensor和Resize等变换,以满足神经网络输入要求。

  • 通过可视化样本和打印类别信息确认数据加载的正确性。

  • 引入DataLoader批量采样和迭代数据,为后续训练循环奠定基础。

  • 展望数据增强、Normalize以及预训练模型迁移学习等实战技巧。

数据加载与预处理是深度学习项目不可或缺的步骤。掌握这些技能,能够让你在模型开发和实验中更加得心应手。未来你可以尝试更多高级技巧,如自定义transforms、对数据集进行统计分析、探索更复杂的增强策略和分布式数据加载方法。

达成这些基础后,你就可以开始定义模型(如使用torchvision.models.resnet18(pretrained=True)加载预训练模型)、设置损失函数(如CrossEntropyLoss)、选择优化器(如Adam或SGD),并在训练循环中快速迭代提升模型性能。

希望本文介绍,能为你对CV数据加载与预处理的理解添砖加瓦,帮助你在图像分类任务中迈出稳健的一步。


如果你遇到了什么问题,或者想了解某些方面的知识,欢迎在评论区留言

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com