欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 文化 > PyTorch深度学习模型训练流程:(一、分类)

PyTorch深度学习模型训练流程:(一、分类)

2024/10/24 2:00:18 来源:https://blog.csdn.net/moyao_miao/article/details/141466047  浏览:    关键词:PyTorch深度学习模型训练流程:(一、分类)

自己写了个封装PyTorch深度学习训练流程的函数,实现了根据输入参数训练模型并可视化训练过程的功能,可以方便快捷地检验一个模型的效果,有助于提高选择模型架构、优化超参数等工作的效率。发出来供大家参考,如有不足之处,欢迎批评讨论。

分类是人工智能的一个非常重要的应用,这篇文章分享的函数适用于实现分类的深度学习模型,包括以下功能:

  1. 根据输入的数据集、模型、优化器、损失函数等参数训练一个分类模型;
  2. 使用visdom可视化训练过程,实时输出精确度曲线、损失曲线、混淆矩阵和ROC曲线;
  3. 支持二分类和多分类;
  4. 输入数据集支持形如(X,y)的np.ndarray类型,及形如(train_data,test_data)的torch.utils.data.Dataset类型,可以方便灵活地调用torch内置数据集或自定义数据集;
  5. 支持使用GPU加速深度学习模型的训练。

废话不多说,先来看下输出效果:

二分类
多分类

 深度学习的完整流程通常包括以下几个步骤:

  1. 收集数据
  2. 数据预处理
  3. 选择模型
  4. 训练模型
  5. 评估模型
  6. 超参数调优
  7. 测试模型

本函数封装了训练模型和评估模型的步骤,包括:

  1. 若数据集为(X,y)形式则分离训练集和测试集(测试集占20%),数据标准化,封装训练集和测试集;
  2. 将训练集和测试集设置为加载器;
  3. 遍历训练集加载器,计算每一批次的输出和损失,并反向传播更新神经网络参数;
  4. 每迭代100次评估一下模型,用测试集数据计算并画出精确度曲线、损失曲线、混淆矩阵和ROC曲线。

代码如下:

from functools import partial
import numpy as np
import pandas as pd
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, r2_scoreimport torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from visdom import Visdomfrom typing import Union, Optional
from sklearn.base import TransformerMixin
from torch.optim.optimizer import Optimizerdef classify(data: tuple[Union[np.ndarray, Dataset], Union[np.ndarray, Dataset]],model: nn.Module,optimizer: Optimizer,criterion: nn.Module,scaler: Optional[TransformerMixin] = None,batch_size: int = 64,epochs: int = 10,device: Optional[torch.device] = None
) -> nn.Module:"""分类任务的训练函数。:param data: 形如(X,y)的np.ndarray类型,及形如(train_data,test_data)的torch.utils.data.Dataset类型:param model: 分类模型:param optimizer: 优化器:param criterion: 损失函数:param scaler: 数据标准化器:param batch_size: 批大小:param epochs: 训练轮数:param device: 训练设备:return: 训练好的分类模型"""if isinstance(data[0], np.ndarray):X, y = data# 处理类别classes = np.unique(y)classes_str = [str(i) for i in classes]num_classes = len(classes)# 分离训练集和测试集,指定随机种子以便复现X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化if scaler is not None:X_train = scaler.fit_transform(X_train)X_test = scaler.transform(X_test)# 转换为tensorX_train = torch.from_numpy(X_train.astype(np.float32))X_test = torch.from_numpy(X_test.astype(np.float32))y_train = torch.from_numpy(y_train.astype(np.int64))y_test = torch.from_numpy(y_test.astype(np.int64))# 将X和y封装成TensorDatasettrain_dataset = TensorDataset(X_train, y_train)test_dataset = TensorDataset(X_test, y_test)elif isinstance(data[0], Dataset):train_dataset, test_dataset = dataclasses = list(train_dataset.class_to_idx.values())classes_str = train_dataset.classesnum_classes = len(classes)else:raise ValueError('Unsupported data type')train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=2,)test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True,num_workers=2,)model.to(device)vis = Visdom()# 训练模型for epoch in range(epochs):for step, (batch_x_train, batch_y_train) in enumerate(train_loader):batch_x_train = batch_x_train.to(device)batch_y_train = batch_y_train.to(device)# 前向传播output = model(batch_x_train)loss = criterion(output, batch_y_train)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()niter = epoch * len(train_loader) + step + 1  # 计算迭代次数if niter % 100 == 0:# 评估模型model.eval()with torch.no_grad():eval_dict = {'test_loss': [],'test_acc': [],'test_cm': [],'test_roc': [],}for batch_x_test, batch_y_test in test_loader:batch_x_test = batch_x_test.to(device)batch_y_test = batch_y_test.to(device)test_output = model(batch_x_test)predicted = torch.argmax(test_output, 1)test_predicted_tuple = (batch_y_test.numpy(), predicted.numpy())# 计算并记录损失、精确度、混淆矩阵、ROC曲线eval_dict['test_loss'].append(criterion(test_output, batch_y_test))eval_dict['test_acc'].append(accuracy_score(*test_predicted_tuple))eval_dict['test_cm'].append(confusion_matrix(*test_predicted_tuple, labels=classes))if num_classes == 2:# eval_dict['test_roc']形状为(len,(fpr,tpr),3)eval_dict['test_roc'].append(roc_curve(*test_predicted_tuple)[:2])else:# 多分类ROC曲线需要one-hot编码y_test_one_hot, predicted_one_hot = map(partial(label_binarize, classes=classes), test_predicted_tuple)fpr_list = []tpr_list = []for i in range(num_classes):fpr, tpr, _ = roc_curve(y_test_one_hot[:, i], predicted_one_hot[:, i])# 无(fpr,tpr)数据点时,插值填充(0,0)数据点if len(fpr) != 3:fpr = np.insert(fpr, 0, 0)tpr = np.insert(tpr, 0, 0)fpr_list.append(fpr)tpr_list.append(tpr)# eval_dict['test_roc']形状为(len,(fpr,tpr),num_classes,3)eval_dict['test_roc'].append((fpr_list, tpr_list))# 画出损失曲线vis.line(X=torch.ones((1, 2)) * (niter // 100),Y=torch.stack((loss, torch.mean(torch.tensor(eval_dict['test_loss'])))).unsqueeze(0),win='loss',update='append',opts=dict(title='Loss', legend=['train_loss', 'test_loss']),)# 画出精确度曲线train_acc = accuracy_score(batch_y_train.numpy(), torch.argmax(output, 1).numpy())vis.line(X=torch.ones((1, 2)) * (niter // 100),Y=torch.tensor((train_acc, np.mean(eval_dict['test_acc']))).unsqueeze(0),win='accuracy',update='append',opts=dict(title='Accuracy', legend=['train_acc', 'test_acc'], ytickmin=0, ytickmax=1),)# 画出混淆矩阵vis.heatmap(X=np.add.reduce(eval_dict['test_cm']),win='confusion_matrix',opts=dict(title='Confusion Matrix', columnnames=classes_str, rownames=classes_str),)# 画出ROC曲线test_roc_arr = np.array(eval_dict['test_roc'])zeros_df = pd.DataFrame({'fpr': [0], 'tpr': [0]})  # 用于填充的(0,0)数据点ones_df = pd.DataFrame({'fpr': [1], 'tpr': [1]})  # 用于填充的(1,1)数据点if num_classes == 2:plot_arr = test_roc_arr[:, :, 1]  # 提取(fpr,tpr)数据点,形状为(len,(fpr,tpr))cats = pd.qcut(plot_arr[:, 0], q=10, labels=False, duplicates='drop')  # 按fpr大小分成10个数据一样多的区间groups = pd.DataFrame(plot_arr, columns=['fpr', 'tpr']).groupby(cats).mean()  # 计算每个区间的平均值,形状为(10,(fpr,tpr))plot_df = pd.concat([zeros_df, groups, ones_df])  # 头添加(0,0),尾添加(1,1)数据点,形状为(12,(fpr,tpr))x = plot_df['fpr']Y = plot_df['tpr']else:plot_df_list = []plot_arr = test_roc_arr[:, :, :, 1].swapaxes(1, 2)  # 提取(fpr,tpr)数据点并换轴,形状为(len,num_classes,(fpr,tpr))for i in range(num_classes):cats = pd.qcut(plot_arr[:, i, 0], q=10, labels=False, duplicates='drop')groups = pd.DataFrame(plot_arr[:, i, :], columns=['fpr', 'tpr']).groupby(cats).mean()  # 形状为(10,(fpr,tpr))plot_df = pd.concat([zeros_df, groups, ones_df])  # 形状为(12,(fpr,tpr))add_num = 12 - len(plot_df)# 长度不足12时,插值填充(0,0)数据点if add_num > 0:plot_df = pd.concat([zeros_df] * add_num + [plot_df])plot_df_list.append(plot_df)  # 形状为(num_classes,12,(fpr,tpr))plot_arr_sum = np.stack(plot_df_list, axis=1)  # 形状为(12,num_classes,(fpr,tpr))x = plot_arr_sum[:, :, 0]Y = plot_arr_sum[:, :, 1]vis.line(X=x,Y=Y,win='ROC',opts=dict(title='ROC', legend=classes_str),)return model

注意:代码运行前要先在命令行输入python -m visdom.server,在浏览器中打开提供的链接:

 成功运行的效果如下:

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com