欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 明星 > 算法面试准备 - 手撕系列第七期 - MLP(利用FashionMNIST数据集)

算法面试准备 - 手撕系列第七期 - MLP(利用FashionMNIST数据集)

2025/2/21 17:43:29 来源:https://blog.csdn.net/m0_62030579/article/details/145177482  浏览:    关键词:算法面试准备 - 手撕系列第七期 - MLP(利用FashionMNIST数据集)

算法面试准备 - 手撕系列第七期 - MLP(利用FashionMNIST数据集)

目录

  • 算法面试准备 - 手撕系列第七期 - MLP(利用FashionMNIST数据集)
  • FashionMINIST 图像分类原理解析
    • 1. 全连接的原理图
    • 2. 背景介绍
    • 3.引入相关库函数
    • 4. 数据预处理
    • 5. 模型设计
    • 6. 初始化网络,损失函数与优化器
    • 7. 训练与测试
      • 7.1 训练过程
      • 7.2 测试过程
    • 8. 结论
  • 参考

FashionMINIST 图像分类原理解析

本文将详细解析基于 PyTorch 实现的 FashionMNIST 图像分类的原理及代码结构,适用于初学者理解深度学习图像分类任务的完整流程。


1. 全连接的原理图

在这里插入图片描述

全连接的原理图

2. 背景介绍

FashionMNIST 数据集是一个用于替代经典 MNIST 数据集的基准数据集。它包含 10 类不同的服装图像,每张图像大小为 28x28 像素,灰度图像。

类别编号类别名称
0T 恤/上衣
1裤子
2套衫
3连衣裙
4外套
5凉鞋
6衬衫
7运动鞋
8
9短靴

3.引入相关库函数

# 该模块主要是为了实现FashionMinist图像分类。图像的大小为(28,28),类别为同样为10类
'''
# Part1引入相关的库函数
'''
import torch
from torch import nn
from torch.utils import dataimport torchvision
from torchvision import transforms

4. 数据预处理

图像分类任务的第一步是数据加载和预处理。在代码中,通过 torchvision.datasets.FashionMNIST 加载数据集,并对图像进行以下处理:

  1. 转换为张量:使用 transforms.ToTensor() 将图像转换为 PyTorch 张量格式,并将像素值归一化到 [0, 1]。
  2. 数据分割:划分为训练集和测试集,使用 DataLoader 封装成可迭代的数据加载器。
'''
# Part2 数据集的加载,和dataloader的初始化
'''transforms_action = [transforms.ToTensor()]
transforms_action = transforms.Compose(transforms_action)Minist_train = torchvision.datasets.FashionMNIST(root='Minist', train=True, transform=transforms_action, download=True)
Minist_test = torchvision.datasets.FashionMNIST(root='Minist', train=False, transform=transforms_action, download=True)train_dataloader = data.DataLoader(dataset=Minist_train, batch_size=15, shuffle=True)
test_dataloader = data.DataLoader(dataset=Minist_test, batch_size=15, shuffle=True)

5. 模型设计

本例中使用的是多层感知机(MLP)模型,它由以下组件构成:

  1. 输入层:将输入的 28x28 图像展开为一维向量(大小为 784)。
  2. 隐藏层:一层全连接层,输出大小为 128,激活函数使用 ReLU。
  3. 输出层:全连接层输出大小为 10,对应 10 个类别。
class MLP(nn.Module):def __init__(self, image_size, num_kind,latent=128):super(MLP, self).__init__()self.Linear1 = nn.Linear(image_size, latent, bias=False)self.relu1 = nn.ReLU()# 因为最后一层常用于一些其他操作,进行信息传递,一般就不添加非线性的激活函数了,一般都是不需要的。self.Linear2 = nn.Linear(latent, num_kind, bias=False)# 计算CrossEntropyLoss时候会自动计算softmax所以不需要。# self.softmax = nn.Softmax(dim=-1)def forward(self, x):  # (batch,1,28,28)x = x.reshape(x.size()[0], -1)x = self.Linear1(x)x = self.relu1(x)x = self.Linear2(x)# x = self.softmax(x)return x  # (batch,10)

6. 初始化网络,损失函数与优化器

分类任务中使用交叉熵损失(CrossEntropyLoss),其原理是衡量预测类别分布与真实类别分布之间的差异。

优化器选择随机梯度下降(SGD),其更新公式为:
θ t + 1 = θ t − η ∇ L ( θ t ) \theta_{t+1} = \theta_t - \eta \nabla L(\theta_t) θt+1=θtηL(θt)
其中:

  • θ t \theta_t θt 为当前参数
  • η \eta η 为学习率
  • ∇ L ( θ t ) \nabla L(\theta_t) L(θt) 为损失函数的梯度
# 初始化网络
net = MLP(784, 10)
# 初始化loss
loss = nn.CrossEntropyLoss()
# 初始化优化器
optimizer = torch.optim.SGD(params=net.parameters(), lr=1e-3)

7. 训练与测试

7.1 训练过程

  1. 遍历每个批次的数据:
    • 将图像输入模型,计算预测结果。
    • 计算损失函数,反向传播计算梯度。
    • 使用优化器更新模型参数。
    • 清零梯度以避免累积。
  2. 每轮训练结束后保存模型状态。
'''
# Part4 循环训练计算损失
'''epochs = 10for epoch in range(epochs):for images, labels in train_dataloader:# 首先前向传播result = net(images)# 计算损失L = loss(result, labels)# 反向传播L.backward()# 参数更新optimizer.step()# 清除梯度optimizer.zero_grad()# 存储模型torch.save(net, 'checkpoint/module_epoch_{}.pth'.format(epoch))

7.2 测试过程

  1. 模型设置为评估模式,禁用梯度计算(torch.no_grad())。
  2. 遍历测试集,计算平均测试损失。
# 每个epoch在测试集跑一遍进行计算平均损失total_loss = 0total_batches = 0with torch.no_grad():for images_test, labels_test in Minist_test:# 形状是Batchsize*hanglabels_hat = net(images_test)L_test = loss(labels_hat, labels_test)total_loss += L_test.item()total_batches += 1# 计算平均测试损失并记录avg_test_loss = total_loss / total_batchesprint(f'第 {epoch + 1} 轮训练完成,平均测试损失为:{avg_test_loss}')

8. 结论

通过上述流程,我们成功实现了基于 FashionMNIST 数据集的分类模型。代码结构清晰,包含了数据加载、模型定义、训练与测试的完整过程,为深度学习图像分类任务提供了良好的实践基础。

参考

自己(好像会了好像又不会,容易忘记各种简单的操作,比如数据集存储的位置啥的):小菜鸟博士-CSDN博客

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词