欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 财经 > 金融 > 【深度学习】DataLoader自定义数据集制作

【深度学习】DataLoader自定义数据集制作

2025/2/7 8:17:26 来源:https://blog.csdn.net/Father_Of_Soft/article/details/145477993  浏览:    关键词:【深度学习】DataLoader自定义数据集制作

第一步 导包

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

第二步 自定义数据集

# Root folder of the flower data set; the per-split image folders live inside it.
data_dir = "./flower_data/"
# data_dir already ends with "/", so plain string concatenation with
# "/train_filelist" produced a doubled separator ("./flower_data//...").
# os.path.join builds the same location without it.
train_dir = os.path.join(data_dir, "train_filelist")
valid_dir = os.path.join(data_dir, "val_filelist")
from torch.utils.data import Dataset,DataLoader
class FlowerDataset(Dataset):
    """Image data set whose samples are listed in an annotation text file.

    Every non-blank line of the annotation file holds an image file name
    followed by an integer class label, separated by whitespace.

    Args:
        root_dir: Directory that the file names in the annotation file are
            joined onto.
        ann_file: Path of the annotation text file.
        transform: Optional callable applied to the opened PIL image before
            it is returned.
    """

    def __init__(self, root_dir, ann_file, transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        # Mapping of file name -> label (0-d int64 array), in file order.
        self.img_label = self.load_annotations()
        # Parallel lists: full image paths and their labels.
        self.img = [os.path.join(self.root_dir, name) for name in self.img_label]
        self.label = list(self.img_label.values())
        self.transform = transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.img)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened with PIL and passed through ``transform`` when
        one was given; the label is returned as a torch tensor.
        """
        image = Image.open(self.img[idx])
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(label))
        return image, label

    def load_annotations(self):
        """Parse ``ann_file`` into a ``{file name: label}`` dictionary.

        Blank lines are skipped and any run of whitespace is accepted as the
        separator between the file name and the label.

        Raises:
            ValueError: If a non-blank line does not contain both a file
                name and an integer label.
        """
        data_infos = {}
        with open(self.ann_file) as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # A trailing newline or an empty separator line is not a sample.
                    continue
                # Split once from the right so the label is always the last field.
                fields = line.rsplit(None, 1)
                if len(fields) != 2:
                    raise ValueError(
                        "{}:{}: expected '<filename> <label>', got {!r}".format(
                            self.ann_file, line_no, line
                        )
                    )
                filename, gt_label = fields
                data_infos[filename] = np.array(int(gt_label), dtype=np.int64)
        return data_infos

注:ann_file 为文本文件,每行对应一张图片,格式为“图片文件名 类别标签”,两者之间用一个空格分隔,标签为整数。例如:image_0001.jpg 0

第三步 自定义transform

# Per-channel mean / std handed to Normalize by both pipelines.
# NOTE(review): these look like the usual ImageNet statistics - confirm.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random rotation of up to 45 degrees, 64x64
# centre crop, random horizontal / vertical flips, then tensor conversion
# and normalisation.
_train_steps = [
    transforms.Resize(64),
    transforms.RandomRotation(45),
    transforms.CenterCrop(64),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]

# Validation pipeline: the same resize / crop / normalisation without any
# of the random augmentation steps.
_valid_steps = [
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]

data_transforms = {
    "train": transforms.Compose(_train_steps),
    "valid": transforms.Compose(_valid_steps),
}

第四步 根据自定义Dataset实例化DataLoader

①实例化Dataset

# Training split: images under train_dir, labels listed in train.txt.
train_dataset = FlowerDataset(
    root_dir=train_dir,
    ann_file="./flower_data/train.txt",
    transform=data_transforms["train"],
)
# Validation split: the files listed in val.txt are resolved against
# valid_dir. The original passed train_dir here, which left valid_dir unused
# and pointed the validation file names at the training folder.
valid_dataset = FlowerDataset(
    root_dir=valid_dir,
    ann_file="./flower_data/val.txt",
    transform=data_transforms["valid"],
)

②实例化DataLoader

# Both loaders serve shuffled mini-batches of 64 samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
val_loader = DataLoader(
    valid_dataset,
    batch_size=64,
    shuffle=True,
)

③验证图片是否加载正确

# Fetch one batch from the training loader. Python 3 iterators are advanced
# with the built-in next(); calling a .next() method on the loader iterator
# raises AttributeError.
image, label = next(iter(train_loader))
sample = image[0].squeeze()
# Move the first axis to the end so matplotlib receives the channel axis last.
sample = sample.permute((1, 2, 0)).numpy()
# Undo the Normalize transform: multiply by the std, then add the mean.
sample = sample * np.array([0.229, 0.224, 0.225])
sample = sample + np.array([0.485, 0.456, 0.406])
# Rounding can leave values marginally outside [0, 1]; clip them so imshow
# does not warn about out-of-range float data.
sample = np.clip(sample, 0.0, 1.0)
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))


第五步 训练

①前置准备

# Phase name -> loader, as consumed by train_model.
dataloaders = {"train": train_loader, "valid": val_loader}

model_name = "resnet"
feature_extract = True

# Train on the first GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model: a resnet18 whose final fully connected layer is replaced by a
# linear layer with 102 outputs.
model_ft = models.resnet18()
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102))

# Optimizer settings: Adam over all model parameters, with the learning rate
# multiplied by 0.1 every 7 scheduler steps.
optimizer_ft = optim.Adam(model_ft.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
criterion = nn.CrossEntropyLoss()

②自定义训练函数

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,
                is_inception=False, filename="best.pth", scheduler=None):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Each epoch runs a "train" phase followed by a "valid" phase. Whenever the
    validation accuracy improves, the model state, the accuracy and the
    optimizer state are saved to ``filename``.

    Args:
        model: Network to train; it is moved to the module-level ``device``.
        dataloaders: Mapping with "train" and "valid" keys, each yielding
            ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating the model parameters.
        num_epochs: Number of epochs to run.
        is_inception: Unused; kept so existing callers keep working.
        filename: Path the best checkpoint is written to.
        scheduler: Optional learning-rate scheduler stepped once per epoch
            after validation. When omitted, a module-level ``scheduler`` is
            used if one exists, which is what the function relied on before
            this parameter was added.

    Returns:
        A tuple ``(model, val_acc_history, train_acc_history, valid_losses,
        train_losses, LRs)``; ``model`` has the best validation weights
        loaded, and ``LRs`` holds the initial learning rate followed by the
        rate after every epoch.
    """
    if scheduler is None:
        # Backward compatibility: fall back to the module-level scheduler.
        scheduler = globals().get("scheduler")
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau

    since = time.time()
    best_acc = 0
    model.to(device)

    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]["lr"]]

    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        # Training pass followed by validation pass.
        for phase in ["train", "valid"]:
            if phase == "train":
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            # Iterate over every batch of the current phase.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Reset the gradients accumulated by the previous batch.
                optimizer.zero_grad()

                # Gradients are only tracked and applied while training.
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # loss.item() is multiplied by the batch size so that dividing
                # by the data set size below gives a per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Remember and checkpoint the best validation result so far.
            if phase == "valid" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    "state_dict": model.state_dict(),
                    "best_acc": best_acc,
                    "optimizer": optimizer.state_dict(),
                }
                torch.save(state, filename)

            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # Learning-rate decay. Only ReduceLROnPlateau takes the
                # monitored metric; other schedulers read a positional
                # argument as an epoch index, so they are stepped without one.
                if scheduler is not None:
                    if isinstance(scheduler, plateau_scheduler):
                        scheduler.step(epoch_loss)
                    else:
                        scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Return the model with the best validation weights loaded, ready for testing.
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs

③训练模型

# Run 20 epochs of training; the best checkpoint is written to best.pth and
# the per-epoch accuracy / loss / learning-rate histories are kept.
(
    model_ft,
    val_acc_history,
    train_acc_history,
    valid_losses,
    train_losses,
    LRs,
) = train_model(
    model=model_ft,
    dataloaders=dataloaders,
    criterion=criterion,
    optimizer=optimizer_ft,
    num_epochs=20,
    filename='best.pth',
)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com