欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 建筑 > 【深度学习】用Pytorch完成MNIST手写数字数据集的训练和测试

【深度学习】用Pytorch完成MNIST手写数字数据集的训练和测试

2024/11/30 14:46:42 来源:https://blog.csdn.net/qq_57315305/article/details/140997720  浏览:    关键词:【深度学习】用Pytorch完成MNIST手写数字数据集的训练和测试

模型训练相关

思路:

  1. 导入数据集(对数据集转换为张量)
  2. 加载数据集(使数据集成为可以进行迭代)
  3. 搭建卷积模型
  4. 进行模型训练(每训练一轮查看一次在测试集上的准确率)
  5. 使用tensorboard进行可视化
  6. 保存训练后的模型
  7. 加载训练好的模型进行测试.

选择的模型结构
imagepng

训练模型和评估模型

在conda命令行输入
tensorboard --logdir=“tensorboard --logdir=D:\student\ai-study\02框架学习\logs\mnist”
打开可视化面板
imagepng
imagepng
在测试集上的准确率不断上升

训练和评估完整代码

"""
@author:Lunau
@file:022_mnist.py
@time:2024/08/07
@任务:使用pytorch对mnist数据集进行训练和测试
"""
import torch
import torchvision
import time
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter"""
@root:存放数据集的目录
@train:为True表示是作为训练集
@transforms:导入数据集的同时进行预处理
@download:为True表示从网络下载模型
"""
# 导入数据集 这里没有归一化
train_data = torchvision.datasets.MNIST("./dataset/MNIST", train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.MNIST("./dataset/MNIST", train=False, transform=torchvision.transforms.ToTensor(), download=True)# # 加载数据集,方便进行迭代
train_dataloader = DataLoader(dataset=train_data, batch_size=64)
test_dataloader = DataLoader(dataset=test_data, batch_size=64)
# img, target = test_data[0]
# print(img)
# print(target)# 构建卷积层
class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.model = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, padding=0, stride=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5, padding=0, stride=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(in_features=320, out_features=10))def forward(self, x):x = self.model(x)return x# 创建卷积模型
model = Model()# 损失函数 交叉熵
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 0.01
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)# 可视化
writer = SummaryWriter("./logs/mnist")# 训练网络的参数
total_train_step = 0  # 训练次数
total_test_step = 0  # 测试次数# 训练
def train():# 训练步骤开始model.train()global total_train_stepfor data in train_dataloader:imgs, targets = dataoutputs = model(imgs)loss = loss_fn(outputs, targets)  # 计算当前损失# 优化器进行优化optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 ==0:writer.add_scalar("train_loss", loss, total_train_step) # 可视化每轮的损失print(f"训练次数:{total_train_step}, Loss:{loss}")def test():global total_test_stepmodel.eval()total_test_loss = 0total_accuracy = 0test_data_len = len(test_data)with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutput = model(imgs)loss = loss_fn(output, targets)total_test_loss +=lossaccuracy = (output.argmax(1) == targets).sum().item() # 计算出正确的次数total_accuracy+=accuracytotal_accuracy = total_accuracy / test_data_len# 在整体测试集上的损失print(f"整体测试损失Loss:{total_test_loss}")# 整体测试的正确率print(f"整体测试的正确率acc:{total_accuracy}")writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_acc", total_accuracy, total_test_step)total_test_step +=1return total_accuracyif __name__ == '__main__':epoch = 10  # 训练的轮数for i in range(epoch):print(f"第{i + 1}轮训练开始")train()test()torch.save(model, "mnist1.pth")

测试模型

测试的照片
mnist3jpgmnist5jpgmnist9jpg

测试代码

"""
@author:Lunau
@file:023_mnist_test.py
@time:2024/08/07
"""
import cv2
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import nn# 测试模型class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.model = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, padding=0, stride=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5, padding=0, stride=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(in_features=320, out_features=10))def forward(self, x):x = self.model(x)return x
# 测试单张照片
def test_one_image():image_path = "./images/mnist_3.jpg"image = Image.open(image_path)print(image)image = image.convert('L')image.show()# 图片进行转换transform = torchvision.transforms.Compose([torchvision.transforms.Resize((28, 28)),torchvision.transforms.ToTensor()])image = transform(image)print(image.shape)# 加载模型 若模型是在gpu训练出来,需要在cpu上运行需要进行一个映射model = torch.load("./mnist1.pth")image = torch.reshape(image, (1, 1, 28, 28))  # 转换一下尺寸,为输入要求的尺寸# 测试model.eval()with torch.no_grad():output = model(image)print(output)print(f"手写数字是:{output.argmax(1).item()}")test_one_image()

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com