Pytorch | 从零构建MobileNet对CIFAR10进行分类
- CIFAR10数据集
- MobileNet
- 设计理念
- 网络结构
- 技术优势
- 应用领域
- MobileNet结构代码详解
- 结构代码
- 代码详解
- DepthwiseSeparableConv 类
- 初始化方法
- 前向传播 forward 方法
- MobileNet 类
- 初始化方法
- 前向传播 forward 方法
- 训练过程和测试结果
- 代码汇总
- mobilenet.py
- train.py
- test.py
前面文章我们构建了AlexNet、Vgg、GoogleNet对CIFAR10进行分类:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类
这篇文章我们来构建MobileNet.
CIFAR10数据集
CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:
- 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
- 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
- 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。
下面是一些示例样本:
MobileNet
MobileNet是由谷歌在2017年提出的一种轻量级卷积神经网络,主要用于移动端和嵌入式设备等资源受限的环境中进行图像识别和分类任务,以下是对其的详细介绍:
设计理念
- 深度可分离卷积:其核心创新是采用了深度可分离卷积(Depthwise Separable Convolution)来替代传统的卷积操作。深度可分离卷积将标准卷积分解为一个深度卷积(Depthwise Convolution)和一个逐点卷积(Pointwise Convolution),大大减少了计算量和模型参数,同时保持了较好的性能。
网络结构
- 标准卷积层:输入层为3通道的彩色图像,首先经过一个普通的卷积层
conv1
,将通道数从3变为32,同时进行了步长为2的下采样操作,以减小图像尺寸。 - 深度可分离卷积层:包含了一系列的深度可分离卷积层
dsconv1
至dsconv13
,这些层按照一定的规律进行排列,通道数逐渐增加,同时通过不同的步长进行下采样,以提取不同层次的特征。 - 池化层和全连接层:在深度可分离卷积层之后,通过一个自适应平均池化层
avgpool
将特征图转换为1x1的大小,然后通过一个全连接层fc
将特征映射到指定的类别数,完成分类任务。
技术优势
- 模型轻量化:通过深度可分离卷积的使用,大大减少了模型的参数量和计算量,使得模型更加轻量化,适合在移动设备和嵌入式设备上运行。
- 计算效率高:由于减少了计算量,MobileNet在推理时具有较高的计算效率,可以快速地对图像进行分类和识别,满足实时性要求较高的应用场景。
- 性能表现较好:尽管模型轻量化,但MobileNet在图像识别任务上仍然具有较好的性能表现,能够在保持较高准确率的同时,大大降低模型的复杂度。
应用领域
- 移动端视觉任务:广泛应用于各种移动端设备,如智能手机、平板电脑等,用于图像分类、目标检测、人脸识别等视觉任务。
- 嵌入式设备视觉:在嵌入式设备,如智能摄像头、自动驾驶汽车等领域,MobileNet可以为这些设备提供高效的视觉处理能力,实现实时的图像分析和决策。
- 物联网视觉应用:在物联网设备中,MobileNet可以帮助实现对图像数据的快速处理和分析,为智能家居、智能安防等应用提供支持。
MobileNet结构代码详解
结构代码
import torch
import torch.nn as nnclass DepthwiseSeparableConv(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):super(DepthwiseSeparableConv, self).__init__()self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=bias)self.bn1 = nn.BatchNorm2d(in_channels)self.relu1 = nn.ReLU6(inplace=True)self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)self.bn2 = nn.BatchNorm2d(out_channels)self.relu2 = nn.ReLU6(inplace=True)def forward(self, x):out = self.depthwise(x)out = self.bn1(out)out = self.relu1(out)out = self.pointwise(out)out = self.bn2(out)out = self.relu2(out)return outclass MobileNet(nn.Module):def __init__(self, num_classes):super(MobileNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(32)self.relu = nn.ReLU6(inplace=True)self.dsconv1 = DepthwiseSeparableConv(32, 64, stride=1)self.dsconv2 = DepthwiseSeparableConv(64, 128, stride=2)self.dsconv3 = DepthwiseSeparableConv(128, 128, stride=1)self.dsconv4 = DepthwiseSeparableConv(128, 256, stride=2)self.dsconv5 = DepthwiseSeparableConv(256, 256, stride=1)self.dsconv6 = DepthwiseSeparableConv(256, 512, stride=2)self.dsconv7 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv8 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv9 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv10 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv11 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv12 = DepthwiseSeparableConv(512, 1024, stride=2)self.dsconv13 = DepthwiseSeparableConv(1024, 1024, stride=1)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(1024, num_classes)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.dsconv1(out)out = self.dsconv2(out)out = self.dsconv3(out)out = self.dsconv4(out)out = self.dsconv5(out)out = self.dsconv6(out)out = self.dsconv7(out)out = self.dsconv8(out)out = self.dsconv9(out)out = self.dsconv10(out)out = self.dsconv11(out)out = self.dsconv12(out)out = self.dsconv13(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out
代码详解
以下是对上述代码的详细解释:
DepthwiseSeparableConv 类
这是一个自定义的深度可分离卷积层类,继承自 nn.Module
。
初始化方法
- 参数说明:
in_channels
:输入通道数,指定输入数据的通道数量。out_channels
:输出通道数,即卷积操作后输出特征图的通道数量。kernel_size
:卷积核大小,默认为3,用于定义卷积操作中卷积核的尺寸。stride
:步长,默认为1,控制卷积核在输入特征图上滑动的步长。padding
:填充大小,默认为1,在输入特征图周围添加的填充像素数量,以保持特征图尺寸在卷积过程中合适变化。bias
:是否使用偏置,默认为False
,决定卷积层是否添加偏置项。
- 构建的层及作用:
self.depthwise
:这是一个深度卷积层(nn.Conv2d
),通过设置groups=in_channels
,实现了深度可分离卷积中的深度卷积部分,它对每个输入通道分别进行卷积操作,有效地减少了计算量。self.bn1
:批归一化层(nn.BatchNorm2d
),用于对深度卷积后的输出进行归一化处理,加速模型收敛并提升模型的泛化能力。self.relu1
:激活函数层(nn.ReLU6
),采用ReLU6
激活函数(输出值限定在0到6之间),并且设置inplace=True
,意味着直接在输入的张量上进行修改,节省内存空间,增加非线性特性。self.pointwise
:逐点卷积层(nn.Conv2d
),卷积核大小为1,用于将深度卷积后的特征图在通道维度上进行融合,改变通道数到指定的out_channels
。self.bn2
:又是一个批归一化层,对逐点卷积后的输出进行归一化处理。self.relu2
:同样是ReLU6
激活函数层,进一步增加非线性,处理逐点卷积归一化后的结果。
前向传播 forward 方法
定义了数据在该层的前向传播过程:
- 首先将输入
x
通过深度卷积层self.depthwise
进行深度卷积操作,得到输出特征图。 - 然后将深度卷积的输出依次经过批归一化层
self.bn1
和激活函数层self.relu1
。 - 接着把经过处理后的特征图通过逐点卷积层
self.pointwise
进行逐点卷积,改变通道数等特征。 - 最后再经过批归一化层
self.bn2
和激活函数层self.relu2
,并返回最终的输出结果。
MobileNet 类
这是定义的 MobileNet 网络模型类,同样继承自 nn.Module
。
初始化方法
- 参数说明:
num_classes
:分类的类别数量,用于最后全连接层输出对应类别数的预测结果。
- 构建的层及作用:
self.conv1
:普通的二维卷积层(nn.Conv2d
),输入通道数为3(通常对应RGB图像的三个通道),输出通道数为32,卷积核大小为3,步长为2,用于对输入图像进行初步的特征提取和下采样,减少特征图尺寸同时增加通道数。self.bn1
:批归一化层,对conv1
卷积后的输出进行归一化。self.relu
:激活函数层,采用ReLU6
激活函数给特征图增加非线性。- 一系列的
self.dsconv
层(从dsconv1
到dsconv13
):都是前面定义的深度可分离卷积层DepthwiseSeparableConv
的实例,它们逐步对特征图进行更精细的特征提取、通道变换以及下采样等操作,不同的dsconv
层有着不同的输入输出通道数以及步长设置,以此构建出 MobileNet 网络的主体结构,不断提取和融合特征,逐步降低特征图尺寸并增加通道数来获取更高级、更抽象的特征表示。 self.avgpool
:自适应平均池化层(nn.AdaptiveAvgPool2d
),将输入特征图转换为指定大小(1, 1)
的输出,起到全局平均池化的作用,进一步压缩特征图信息,同时保持特征图的维度一致性,方便后续全连接层处理。self.fc
:全连接层(nn.Linear
),输入维度为1024(与前面网络结构最终输出的特征维度对应),输出维度为num_classes
,用于将经过前面卷积和池化等操作得到的特征向量映射到对应类别数量的预测分数上,实现分类任务。
前向传播 forward 方法
定义了 MobileNet 模型整体的前向传播流程:
- 首先将输入
x
通过conv1
进行初始卷积、bn1
进行归一化以及relu
激活。 - 然后依次通过各个深度可分离卷积层(
dsconv1
到dsconv13
),逐步提取和变换特征。 - 接着经过自适应平均池化层
self.avgpool
,将特征图压缩为(1, 1)
大小。 - 再通过
out.view(out.size(0), -1)
操作将特征图展平为一维向量(其中out.size(0)
表示批量大小,-1
表示自动计算剩余维度大小使其展平)。 - 最后将展平后的特征向量通过全连接层
self.fc
得到最终的分类预测结果并返回。
训练过程和测试结果
训练过程损失函数变化曲线:
训练过程准确率变化曲线:
测试结果:
代码汇总
项目github地址
项目结构:
|--data
|--models|--__init__.py|-mobilenet.py|--...
|--results
|--weights
|--train.py
|--test.py
mobilenet.py
import torch
import torch.nn as nnclass DepthwiseSeparableConv(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):super(DepthwiseSeparableConv, self).__init__()self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=bias)self.bn1 = nn.BatchNorm2d(in_channels)self.relu1 = nn.ReLU6(inplace=True)self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)self.bn2 = nn.BatchNorm2d(out_channels)self.relu2 = nn.ReLU6(inplace=True)def forward(self, x):out = self.depthwise(x)out = self.bn1(out)out = self.relu1(out)out = self.pointwise(out)out = self.bn2(out)out = self.relu2(out)return outclass MobileNet(nn.Module):def __init__(self, num_classes):super(MobileNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(32)self.relu = nn.ReLU6(inplace=True)self.dsconv1 = DepthwiseSeparableConv(32, 64, stride=1)self.dsconv2 = DepthwiseSeparableConv(64, 128, stride=2)self.dsconv3 = DepthwiseSeparableConv(128, 128, stride=1)self.dsconv4 = DepthwiseSeparableConv(128, 256, stride=2)self.dsconv5 = DepthwiseSeparableConv(256, 256, stride=1)self.dsconv6 = DepthwiseSeparableConv(256, 512, stride=2)self.dsconv7 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv8 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv9 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv10 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv11 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv12 = DepthwiseSeparableConv(512, 1024, stride=2)self.dsconv13 = DepthwiseSeparableConv(1024, 1024, stride=1)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(1024, num_classes)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.dsconv1(out)out = self.dsconv2(out)out = self.dsconv3(out)out = self.dsconv4(out)out = self.dsconv5(out)out = self.dsconv6(out)out = self.dsconv7(out)out = self.dsconv8(out)out = self.dsconv9(out)out = self.dsconv10(out)out = self.dsconv11(out)out = self.dsconv12(out)out = self.dsconv13(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out
train.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as pltimport ssl
ssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'MobileNet'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':model = MobileNet(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练轮次
epochs = 15def train(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(trainloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":loss_history, acc_history = [], []for epoch in range(epochs):train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')loss_history.append(train_loss)acc_history.append(train_acc)# 保存模型权重,每5轮次保存到weights文件夹下if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')# 绘制损失曲线plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_loss_curve.png')plt.close()# 绘制准确率曲线plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Training Accuracy Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_acc_curve.png')plt.close()
test.py
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'MobileNet'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':model = MobileNet(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = f"weights/{model_name}_epoch_15.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))def test(model, testloader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for data in testloader:inputs, labels = data[0].to(device), data[1].to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(testloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":test_loss, test_acc = test(model, testloader, criterion, device)print(f"================{model_name} Test================")print(f"Load Model Weights From: {weights_path}")print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')