>- **🍨 本文为[🔗365天深度学习训练营]中的学习记录博客**
>- **🍖 原作者:[K同学啊]**
📌本周任务:📌
– 1.请根据本文TensorFlow代码,编写出相应的Pytorch代码(建议使用上周的数据测试一下模型是否构建正确)
– 2.了解ResNetV2与ResNet(V1)的区别
– 3.改进思路是否可以迁移到其他地方呢(自由探索)
🏡 我的环境:
- 语言环境:Python3.8
- 编译器:Jupyter Notebook
- 深度学习环境:Pytorch
  - torch==2.3.1+cu118
  - torchvision==0.18.1+cu118
本文完全根据 第J2周:ResNet50V2算法实战与解析(TensorFlow版)中的内容转换为pytorch版本,所以前述性的内容不再一一重复,仅就pytorch版本中的内容进行叙述。
一、 前期准备
1. 设置GPU
如果设备上支持GPU就使用GPU,否则使用CPU
import warnings
warnings.filterwarnings("ignore")  # silence warnings so the notebook output stays readable

import torch

# Device type strings are lowercase: torch.device rejects "CPU", so the
# fallback branch must be spelled "cpu" to work on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
运行结果:
device(type='cuda')
2. 导入数据
import pathlib

# Root folder of the bird-photo dataset, wrapped in a Path so it can be globbed.
data_dir = r"D:\THE MNIST DATABASE\J-series\J1\bird_photos"
data_dir = pathlib.Path(data_dir)
查看数据集中图片的数量
# Count every entry located exactly one directory level below the dataset root.
image_count = sum(1 for _ in data_dir.glob('*/*'))
print("图片总数为:", image_count)
运行结果:
图片总数为: 565
3. 查看数据集分类
# Immediate children of the dataset root.
data_paths = list(data_dir.glob('*'))
# Take the last path component directly instead of indexing a "\\"-split
# string at position 5, which only worked for this exact directory depth
# and for Windows path separators.
classNames = [path.name for path in data_paths]
classNames
运行结果:
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
4. 随机查看图片
随机抽取数据集中的20张图片进行查看
import random
import PIL
import matplotlib.pyplot as plt
from PIL import Image

# Every file one level below the dataset root.
data_paths2 = list(data_dir.glob('*/*'))

plt.figure(figsize=(20, 4))
for i in range(20):
    plt.subplot(2, 10, i + 1)
    plt.axis("off")
    image = random.choice(data_paths2)  # pick one image path at random
    plt.title(image.parts[-2])  # the parent folder name is used as the class name
    plt.imshow(Image.open(str(image)))  # display the picture
运行结果:
5. 图片预处理
import torchvision.transforms as transforms
from torchvision import transforms, datasets

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # bring every image to the same 224x224 size
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    # The argument is an angle in DEGREES (not radians): the rotation is drawn
    # from [-0.2, 0.2] degrees, as the printed transform confirms.
    transforms.RandomRotation(0.2),
    transforms.ToTensor(),  # convert the image to a tensor
    transforms.Normalize(  # per-channel standardisation to help convergence
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# The same transform pipeline is attached to the whole dataset.
total_data = datasets.ImageFolder(
    r"D:\THE MNIST DATABASE\J-series\J1\bird_photos",
    transform=train_transforms,
)
total_data
运行结果:
Dataset ImageFolderNumber of datapoints: 565Root location: D:\THE MNIST DATABASE\J-series\J1\bird_photosStandardTransform
Transform: Compose(Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)RandomHorizontalFlip(p=0.5)RandomRotation(degrees=[-0.2, 0.2], interpolation=nearest, expand=False, fill=0)ToTensor()Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
将数据集分类情况进行映射输出:
# Display the class-name -> integer-label mapping of the dataset.
total_data.class_to_idx
运行结果:
{'Bananaquit': 0,'Black Skimmer': 1,'Black Throated Bushtiti': 2,'Cockatoo': 3}
6. 划分数据集
# 80% of the samples are used for training, the remainder for testing.
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(
    total_data, [train_size, test_size]
)
train_dataset, test_dataset
运行结果:
(<torch.utils.data.dataset.Subset at 0x1fc1cb43e90>,<torch.utils.data.dataset.Subset at 0x1fc1cb43f10>)
查看训练集和测试集的数据数量:
# Display the number of samples in the training and test splits.
train_size,test_size
运行结果:
(452, 113)
7. 加载数据集
batch_size = 16

# Training batches are reshuffled every epoch.
train_dl = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)
# Evaluation does not benefit from shuffling; a fixed order keeps the
# per-batch results of repeated evaluations comparable.
test_dl = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
)
查看训练集中一个批次的情况:
# Look at a single training batch to confirm tensor shapes and the label dtype.
for x, y in train_dl:
    print("Shape of x [N,C,H,W]:", x.shape)
    print("Shape of y:", y.shape, y.dtype)
    break
运行结果:
Shape of x [N,C,H,W]: torch.Size([16, 3, 224, 224])
Shape of y: torch.Size([16]) torch.int64
二、手动搭建ResNet50V2模型
1.Residual Block
import torch.nn as nn#Residual Block
class Block2(nn.Module):
    """Pre-activation bottleneck residual block.

    The input first goes through BatchNorm + ReLU (``preact``). The main
    branch is 1x1 conv -> BN -> ReLU -> kxk conv (strided) -> BN -> ReLU ->
    1x1 conv, producing ``4 * filters`` channels, and is added to the shortcut.

    Args:
        in_channel: number of channels of the input tensor.
        filters: bottleneck width; the block outputs ``4 * filters`` channels.
        kernel_size: kernel size of the middle convolution. Use an odd value
            so that the main branch and the shortcut keep matching sizes.
        stride: stride of the middle convolution and of the shortcut.
        conv_shortcut: if True the shortcut is a 1x1 convolution applied to
            the pre-activated input; otherwise the raw input is passed through
            (sub-sampled by a 1x1 max-pool when ``stride > 1``).
    """

    def __init__(self, in_channel, filters, kernel_size=3, stride=1, conv_shortcut=False):
        super(Block2, self).__init__()
        self.preact = nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
        )

        self.shortcut = conv_shortcut
        if self.shortcut:
            self.short = nn.Conv2d(in_channel, 4 * filters, 1, stride=stride, padding=0, bias=False)
        elif stride > 1:
            # A kernel-1 max-pool only sub-samples, so the identity branch
            # matches the spatial size of the strided main branch.
            self.short = nn.MaxPool2d(kernel_size=1, stride=stride, padding=0)
        else:
            self.short = nn.Identity()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, filters, 1, stride=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            # Padding follows the kernel size (it was fixed to 1 before), so
            # kernels other than 3x3 no longer shrink the feature map and
            # break the residual addition.
            nn.Conv2d(filters, filters, kernel_size, stride=stride,
                      padding=kernel_size // 2, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(),
        )
        self.conv3 = nn.Conv2d(filters, 4 * filters, 1, stride=1, bias=False)

    def forward(self, x):
        """Return ``main_branch(preact(x)) + shortcut``."""
        x1 = self.preact(x)
        # The convolutional shortcut consumes the pre-activated tensor; the
        # identity / max-pool shortcut consumes the raw input.
        if self.shortcut:
            x2 = self.short(x1)
        else:
            x2 = self.short(x)
        x1 = self.conv1(x1)
        x1 = self.conv2(x1)
        x1 = self.conv3(x1)
        x = x1 + x2
        return x
2.堆叠Residual Block
class Stack2(nn.Module):
    """A stage made of ``blocks`` consecutive ``Block2`` modules.

    The first block uses a convolutional shortcut to go from ``in_channel``
    to ``4 * filters`` channels, the middle blocks keep the shape, and the
    last block applies ``stride``.
    """

    def __init__(self, in_channel, filters, blocks, stride=2):
        super(Stack2, self).__init__()
        final_index = blocks - 1
        # (position, input channels, extra Block2 keyword arguments)
        layout = [(0, in_channel, {"conv_shortcut": True})]
        layout += [(position, 4 * filters, {}) for position in range(1, final_index)]
        layout.append((final_index, 4 * filters, {"stride": stride}))

        self.conv = nn.Sequential()
        for position, channels, extra in layout:
            self.conv.add_module(str(position), Block2(channels, filters, **extra))

    def forward(self, x):
        """Run the input through all blocks of the stage in order."""
        return self.conv(x)
3.ResNet50V2架构复现
# Build ResNet50V2
class ResNet50V2(nn.Module):
    """ResNet50V2 assembled from a stem, four ``Stack2`` stages and a head.

    Args:
        include_top: if True, append global average pooling, flatten and a
            fully connected classification layer.
        preact: if True (pre-activation variant) the stem has no BN/ReLU and
            a final BN + ReLU is applied after the last stage instead.
        use_bias: whether the stem convolution has a bias term.
        input_shape: accepted for signature compatibility; not used by the
            module.
        classes: number of outputs of the classification layer. It defaults
            to 4, the size that was previously hard-coded, so ``ResNet50V2()``
            builds the same network as before while an explicit value is now
            honoured.
        pooling: only used when ``include_top`` is False: ``'avg'`` or
            ``'max'`` appends the corresponding global pooling layer, ``None``
            leaves the feature map as is.
    """

    def __init__(self,
                 include_top=True,
                 preact=True,
                 use_bias=True,
                 input_shape=(224, 224, 3),
                 classes=4,
                 pooling=None):
        super(ResNet50V2, self).__init__()

        # Stem: 7x7/2 convolution, optional BN + ReLU, then 3x3/2 max-pool.
        self.conv1 = nn.Sequential()
        self.conv1.add_module(
            'conv',
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=use_bias, padding_mode='zeros'))
        if not preact:
            self.conv1.add_module('bn', nn.BatchNorm2d(64))
            self.conv1.add_module('relu', nn.ReLU())
        self.conv1.add_module('max_pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # Four stages with 3, 4, 6 and 3 bottleneck blocks.
        self.conv2 = Stack2(64, 64, 3)
        self.conv3 = Stack2(256, 128, 4)
        self.conv4 = Stack2(512, 256, 6)
        self.conv5 = Stack2(1024, 512, 3, stride=1)

        self.post = nn.Sequential()
        if preact:
            self.post.add_module('bn', nn.BatchNorm2d(2048))
            self.post.add_module('relu', nn.ReLU())
        if include_top:
            self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
            self.post.add_module('flatten', nn.Flatten())
            self.post.add_module('fc', nn.Linear(2048, classes))
        else:
            if pooling == 'avg':
                self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
            elif pooling == 'max':
                # Was an average pool registered under the name 'max_pool'.
                self.post.add_module('max_pool', nn.AdaptiveMaxPool2d((1, 1)))

    def forward(self, x):
        """Apply the stem, the four stages and the head to ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.post(x)
        return x
4.查看模型结构
# State the number of output classes explicitly, taken from the dataset's
# class mapping, instead of relying on the size built into the model class.
model = ResNet50V2(classes=len(total_data.class_to_idx)).to(device)
model
运行结果:
ResNet50V2((conv1): Sequential((conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))(max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))(conv2): Stack2((conv): Sequential((0): Block2((preact): Sequential((0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(conv1): Sequential((0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False))(1): Block2((preact): Sequential((0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): Identity()(conv1): Sequential((0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False))(2): Block2((preact): Sequential((0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): MaxPool2d(kernel_size=1, stride=2, padding=0, dilation=1, ceil_mode=False)(conv1): Sequential((0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(1): BatchNorm2d(64, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False))))(conv3): Stack2((conv): Sequential((0): Block2((preact): Sequential((0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(conv1): Sequential((0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False))(1): Block2((preact): Sequential((0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): Identity()(conv1): Sequential((0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False))(2): Block2((preact): Sequential((0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): Identity()(conv1): Sequential((0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): 
Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False))(3): Block2((preact): Sequential((0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): MaxPool2d(kernel_size=1, stride=2, padding=0, dilation=1, ceil_mode=False)(conv1): Sequential((0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False))))(conv4): Stack2((conv): Sequential((0): Block2((preact): Sequential((0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(conv1): Sequential((0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False))(1): Block2((preact): Sequential((0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): Identity()(conv1): Sequential((0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(256, 1024, 
kernel_size=(1, 1), stride=(1, 1), bias=False))(2): Block2((preact): Sequential((0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): Identity()(conv1): Sequential((0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False))(3): Block2((preact): Sequential((0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): Identity()(conv1): Sequential((0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False))(4): Block2((preact): Sequential((0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): Identity()(conv1): Sequential((0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False))(5): Block2((preact): Sequential((0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)(1): ReLU())(short): MaxPool2d(kernel_size=1, stride=2, padding=0, dilation=1, ceil_mode=False)(conv1): Sequential((0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False))))(conv5): Stack2((conv): Sequential((0): Block2((preact): Sequential((0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(conv1): Sequential((0): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False))(1): Block2((preact): Sequential((0): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(1): ReLU())(short): Identity()(conv1): Sequential((0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False))(2): Block2((preact): Sequential((0): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)(1): ReLU())(short): Identity()(conv1): Sequential((0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv2): Sequential((0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU())(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False))))(post): Sequential((bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU()(avg_pool): AdaptiveAvgPool2d(output_size=(1, 1))(flatten): Flatten(start_dim=1, end_dim=-1)(fc): Linear(in_features=2048, out_features=4, bias=True))
)
5.网络结构打印
# Report the model's parameter counts and other statistics for a 3x224x224 input
import torchsummary as summary
summary.summary(model,(3,224,224))
运行结果:
----------------------------------------------------------------Layer (type) Output Shape Param #
================================================================Conv2d-1 [-1, 64, 112, 112] 9,472MaxPool2d-2 [-1, 64, 56, 56] 0BatchNorm2d-3 [-1, 64, 56, 56] 128ReLU-4 [-1, 64, 56, 56] 0Conv2d-5 [-1, 256, 56, 56] 16,384Conv2d-6 [-1, 64, 56, 56] 4,096BatchNorm2d-7 [-1, 64, 56, 56] 128ReLU-8 [-1, 64, 56, 56] 0Conv2d-9 [-1, 64, 56, 56] 36,864BatchNorm2d-10 [-1, 64, 56, 56] 128ReLU-11 [-1, 64, 56, 56] 0Conv2d-12 [-1, 256, 56, 56] 16,384Block2-13 [-1, 256, 56, 56] 0BatchNorm2d-14 [-1, 256, 56, 56] 512ReLU-15 [-1, 256, 56, 56] 0Identity-16 [-1, 256, 56, 56] 0Conv2d-17 [-1, 64, 56, 56] 16,384BatchNorm2d-18 [-1, 64, 56, 56] 128ReLU-19 [-1, 64, 56, 56] 0Conv2d-20 [-1, 64, 56, 56] 36,864BatchNorm2d-21 [-1, 64, 56, 56] 128ReLU-22 [-1, 64, 56, 56] 0Conv2d-23 [-1, 256, 56, 56] 16,384Block2-24 [-1, 256, 56, 56] 0BatchNorm2d-25 [-1, 256, 56, 56] 512ReLU-26 [-1, 256, 56, 56] 0MaxPool2d-27 [-1, 256, 28, 28] 0Conv2d-28 [-1, 64, 56, 56] 16,384BatchNorm2d-29 [-1, 64, 56, 56] 128ReLU-30 [-1, 64, 56, 56] 0Conv2d-31 [-1, 64, 28, 28] 36,864BatchNorm2d-32 [-1, 64, 28, 28] 128ReLU-33 [-1, 64, 28, 28] 0Conv2d-34 [-1, 256, 28, 28] 16,384Block2-35 [-1, 256, 28, 28] 0Stack2-36 [-1, 256, 28, 28] 0BatchNorm2d-37 [-1, 256, 28, 28] 512ReLU-38 [-1, 256, 28, 28] 0Conv2d-39 [-1, 512, 28, 28] 131,072Conv2d-40 [-1, 128, 28, 28] 32,768BatchNorm2d-41 [-1, 128, 28, 28] 256ReLU-42 [-1, 128, 28, 28] 0Conv2d-43 [-1, 128, 28, 28] 147,456BatchNorm2d-44 [-1, 128, 28, 28] 256ReLU-45 [-1, 128, 28, 28] 0Conv2d-46 [-1, 512, 28, 28] 65,536Block2-47 [-1, 512, 28, 28] 0BatchNorm2d-48 [-1, 512, 28, 28] 1,024ReLU-49 [-1, 512, 28, 28] 0Identity-50 [-1, 512, 28, 28] 0Conv2d-51 [-1, 128, 28, 28] 65,536BatchNorm2d-52 [-1, 128, 28, 28] 256ReLU-53 [-1, 128, 28, 28] 0Conv2d-54 [-1, 128, 28, 28] 147,456BatchNorm2d-55 [-1, 128, 28, 28] 256ReLU-56 [-1, 128, 28, 28] 0Conv2d-57 [-1, 512, 28, 28] 65,536Block2-58 [-1, 512, 28, 28] 0BatchNorm2d-59 [-1, 512, 28, 28] 1,024ReLU-60 [-1, 512, 28, 28] 0Identity-61 [-1, 512, 28, 28] 
0Conv2d-62 [-1, 128, 28, 28] 65,536BatchNorm2d-63 [-1, 128, 28, 28] 256ReLU-64 [-1, 128, 28, 28] 0Conv2d-65 [-1, 128, 28, 28] 147,456BatchNorm2d-66 [-1, 128, 28, 28] 256ReLU-67 [-1, 128, 28, 28] 0Conv2d-68 [-1, 512, 28, 28] 65,536Block2-69 [-1, 512, 28, 28] 0BatchNorm2d-70 [-1, 512, 28, 28] 1,024ReLU-71 [-1, 512, 28, 28] 0MaxPool2d-72 [-1, 512, 14, 14] 0Conv2d-73 [-1, 128, 28, 28] 65,536BatchNorm2d-74 [-1, 128, 28, 28] 256ReLU-75 [-1, 128, 28, 28] 0Conv2d-76 [-1, 128, 14, 14] 147,456BatchNorm2d-77 [-1, 128, 14, 14] 256ReLU-78 [-1, 128, 14, 14] 0Conv2d-79 [-1, 512, 14, 14] 65,536Block2-80 [-1, 512, 14, 14] 0Stack2-81 [-1, 512, 14, 14] 0BatchNorm2d-82 [-1, 512, 14, 14] 1,024ReLU-83 [-1, 512, 14, 14] 0Conv2d-84 [-1, 1024, 14, 14] 524,288Conv2d-85 [-1, 256, 14, 14] 131,072BatchNorm2d-86 [-1, 256, 14, 14] 512ReLU-87 [-1, 256, 14, 14] 0Conv2d-88 [-1, 256, 14, 14] 589,824BatchNorm2d-89 [-1, 256, 14, 14] 512ReLU-90 [-1, 256, 14, 14] 0Conv2d-91 [-1, 1024, 14, 14] 262,144Block2-92 [-1, 1024, 14, 14] 0BatchNorm2d-93 [-1, 1024, 14, 14] 2,048ReLU-94 [-1, 1024, 14, 14] 0Identity-95 [-1, 1024, 14, 14] 0Conv2d-96 [-1, 256, 14, 14] 262,144BatchNorm2d-97 [-1, 256, 14, 14] 512ReLU-98 [-1, 256, 14, 14] 0Conv2d-99 [-1, 256, 14, 14] 589,824BatchNorm2d-100 [-1, 256, 14, 14] 512ReLU-101 [-1, 256, 14, 14] 0Conv2d-102 [-1, 1024, 14, 14] 262,144Block2-103 [-1, 1024, 14, 14] 0BatchNorm2d-104 [-1, 1024, 14, 14] 2,048ReLU-105 [-1, 1024, 14, 14] 0Identity-106 [-1, 1024, 14, 14] 0Conv2d-107 [-1, 256, 14, 14] 262,144BatchNorm2d-108 [-1, 256, 14, 14] 512ReLU-109 [-1, 256, 14, 14] 0Conv2d-110 [-1, 256, 14, 14] 589,824BatchNorm2d-111 [-1, 256, 14, 14] 512ReLU-112 [-1, 256, 14, 14] 0Conv2d-113 [-1, 1024, 14, 14] 262,144Block2-114 [-1, 1024, 14, 14] 0BatchNorm2d-115 [-1, 1024, 14, 14] 2,048ReLU-116 [-1, 1024, 14, 14] 0Identity-117 [-1, 1024, 14, 14] 0Conv2d-118 [-1, 256, 14, 14] 262,144BatchNorm2d-119 [-1, 256, 14, 14] 512ReLU-120 [-1, 256, 14, 14] 0Conv2d-121 [-1, 256, 14, 14] 589,824BatchNorm2d-122 
[-1, 256, 14, 14] 512ReLU-123 [-1, 256, 14, 14] 0Conv2d-124 [-1, 1024, 14, 14] 262,144Block2-125 [-1, 1024, 14, 14] 0BatchNorm2d-126 [-1, 1024, 14, 14] 2,048ReLU-127 [-1, 1024, 14, 14] 0Identity-128 [-1, 1024, 14, 14] 0Conv2d-129 [-1, 256, 14, 14] 262,144BatchNorm2d-130 [-1, 256, 14, 14] 512ReLU-131 [-1, 256, 14, 14] 0Conv2d-132 [-1, 256, 14, 14] 589,824BatchNorm2d-133 [-1, 256, 14, 14] 512ReLU-134 [-1, 256, 14, 14] 0Conv2d-135 [-1, 1024, 14, 14] 262,144Block2-136 [-1, 1024, 14, 14] 0BatchNorm2d-137 [-1, 1024, 14, 14] 2,048ReLU-138 [-1, 1024, 14, 14] 0MaxPool2d-139 [-1, 1024, 7, 7] 0Conv2d-140 [-1, 256, 14, 14] 262,144BatchNorm2d-141 [-1, 256, 14, 14] 512ReLU-142 [-1, 256, 14, 14] 0Conv2d-143 [-1, 256, 7, 7] 589,824BatchNorm2d-144 [-1, 256, 7, 7] 512ReLU-145 [-1, 256, 7, 7] 0Conv2d-146 [-1, 1024, 7, 7] 262,144Block2-147 [-1, 1024, 7, 7] 0Stack2-148 [-1, 1024, 7, 7] 0BatchNorm2d-149 [-1, 1024, 7, 7] 2,048ReLU-150 [-1, 1024, 7, 7] 0Conv2d-151 [-1, 2048, 7, 7] 2,097,152Conv2d-152 [-1, 512, 7, 7] 524,288BatchNorm2d-153 [-1, 512, 7, 7] 1,024ReLU-154 [-1, 512, 7, 7] 0Conv2d-155 [-1, 512, 7, 7] 2,359,296BatchNorm2d-156 [-1, 512, 7, 7] 1,024ReLU-157 [-1, 512, 7, 7] 0Conv2d-158 [-1, 2048, 7, 7] 1,048,576Block2-159 [-1, 2048, 7, 7] 0BatchNorm2d-160 [-1, 2048, 7, 7] 4,096ReLU-161 [-1, 2048, 7, 7] 0Identity-162 [-1, 2048, 7, 7] 0Conv2d-163 [-1, 512, 7, 7] 1,048,576BatchNorm2d-164 [-1, 512, 7, 7] 1,024ReLU-165 [-1, 512, 7, 7] 0Conv2d-166 [-1, 512, 7, 7] 2,359,296BatchNorm2d-167 [-1, 512, 7, 7] 1,024ReLU-168 [-1, 512, 7, 7] 0Conv2d-169 [-1, 2048, 7, 7] 1,048,576Block2-170 [-1, 2048, 7, 7] 0BatchNorm2d-171 [-1, 2048, 7, 7] 4,096ReLU-172 [-1, 2048, 7, 7] 0Identity-173 [-1, 2048, 7, 7] 0Conv2d-174 [-1, 512, 7, 7] 1,048,576BatchNorm2d-175 [-1, 512, 7, 7] 1,024ReLU-176 [-1, 512, 7, 7] 0Conv2d-177 [-1, 512, 7, 7] 2,359,296BatchNorm2d-178 [-1, 512, 7, 7] 1,024ReLU-179 [-1, 512, 7, 7] 0Conv2d-180 [-1, 2048, 7, 7] 1,048,576Block2-181 [-1, 2048, 7, 7] 0Stack2-182 [-1, 2048, 7, 7] 
0BatchNorm2d-183 [-1, 2048, 7, 7] 4,096ReLU-184 [-1, 2048, 7, 7] 0
AdaptiveAvgPool2d-185 [-1, 2048, 1, 1] 0Flatten-186 [-1, 2048] 0Linear-187 [-1, 4] 8,196
================================================================
Total params: 23,508,612
Trainable params: 23,508,612
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 241.68
Params size (MB): 89.68
Estimated Total Size (MB): 331.93
----------------------------------------------------------------
三、 训练模型
1. 编写训练函数
# Training function
def train(dataloader, model, loss_fn, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        dataloader: iterable of ``(inputs, labels)`` batches; must expose
            ``.dataset`` so the number of samples can be read.
        model: network to optimise.
        loss_fn: callable computing the loss from ``(predictions, labels)``.
        optimizer: optimizer that updates the model's parameters.

    Returns:
        Tuple ``(train_acc, train_loss)``: the fraction of correctly
        classified samples and the loss averaged over the batches.
    """
    size = len(dataloader.dataset)  # number of training samples
    num_batches = len(dataloader)  # number of batches

    train_loss, train_acc = 0, 0  # running loss and number of correct predictions

    for x, y in dataloader:  # images and their labels
        x, y = x.to(device), y.to(device)

        # Forward pass and loss
        pred = model(x)
        loss = loss_fn(pred, y)

        # Backward pass
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()

        # Accumulate accuracy and loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss
2. 编写测试函数
测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器
# Evaluation function
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Args:
        dataloader: iterable of ``(inputs, labels)`` batches; must expose
            ``.dataset`` so the number of samples can be read.
        model: network to evaluate.
        loss_fn: callable computing the loss from ``(predictions, labels)``.

    Returns:
        Tuple ``(test_acc, test_loss)``: the fraction of correctly classified
        samples and the loss averaged over the batches.
    """
    size = len(dataloader.dataset)  # number of test samples
    num_batches = len(dataloader)  # number of batches
    test_loss, test_acc = 0, 0

    # No gradients are needed for evaluation; disabling them saves memory.
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # Forward pass and loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss
3. 正式训练
import copy

opt = torch.optim.Adam(model.parameters(), lr=1e-4)  # optimizer and learning rate
loss_fn = nn.CrossEntropyLoss()  # loss function

epochs = 10

train_loss = []
train_acc = []
test_loss = []
test_acc = []

best_acc = 0  # best test accuracy so far, used to select the best model
# Start from a copy of the current model so J2_model is always defined, even
# if the test accuracy never rises above 0.
J2_model = copy.deepcopy(model)

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # Keep a snapshot of the best model in J2_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        J2_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # Current learning rate
    lr = opt.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
                          epoch_test_acc*100, epoch_test_loss, lr))

# Save the best model to a file. The weights written are those of J2_model
# (the best snapshot), not of `model`, which holds the last epoch's weights.
PATH = r'D:\THE MNIST DATABASE\J-series\J2_model.pth'
torch.save(J2_model.state_dict(), PATH)
运行结果:
Epoch: 1,Train_acc:75.2%,Train_loss:0.674,Test_acc:70.8%,Test_loss:1.028,Lr:1.00E-04
Epoch: 2,Train_acc:76.8%,Train_loss:0.629,Test_acc:75.2%,Test_loss:0.868,Lr:1.00E-04
Epoch: 3,Train_acc:86.9%,Train_loss:0.398,Test_acc:79.6%,Test_loss:1.163,Lr:1.00E-04
Epoch: 4,Train_acc:91.4%,Train_loss:0.290,Test_acc:74.3%,Test_loss:0.954,Lr:1.00E-04
Epoch: 5,Train_acc:90.0%,Train_loss:0.296,Test_acc:76.1%,Test_loss:0.877,Lr:1.00E-04
Epoch: 6,Train_acc:88.1%,Train_loss:0.323,Test_acc:58.4%,Test_loss:1.141,Lr:1.00E-04
Epoch: 7,Train_acc:92.3%,Train_loss:0.243,Test_acc:74.3%,Test_loss:0.770,Lr:1.00E-04
Epoch: 8,Train_acc:94.5%,Train_loss:0.179,Test_acc:77.9%,Test_loss:1.187,Lr:1.00E-04
Epoch: 9,Train_acc:95.8%,Train_loss:0.139,Test_acc:77.9%,Test_loss:0.919,Lr:1.00E-04
Epoch:10,Train_acc:94.9%,Train_loss:0.159,Test_acc:84.1%,Test_loss:0.508,Lr:1.00E-04
四、 结果可视化
1. Loss与Accuracy图
import matplotlib.pyplot as plt
# Hide warnings
import warnings
warnings.filterwarnings("ignore")  # ignore warning messages
plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False  # display the minus sign correctly
plt.rcParams['figure.dpi'] = 300  # figure resolution

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))

# Left panel: accuracy curves
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# Right panel: loss curves
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
运行结果:
2. 指定图片进行预测
from PIL import Image

# Class names, in the order of the dataset's class-to-index mapping.
classes = list(total_data.class_to_idx)


def predict_one_image(image_path, model, transform, classes):
    """Show one image and print the class the model predicts for it.

    Args:
        image_path: path of the image file to classify.
        model: trained network; it is switched to eval mode.
        transform: callable turning the PIL image into a tensor that is then
            sent to ``device`` as the single sample of a batch.
        classes: sequence of class names indexed by the predicted label.
    """
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)  # display the image being classified

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)  # add the batch dimension

    model.eval()
    # Inference only: no gradient graph needs to be recorded.
    with torch.no_grad():
        output = model(img)

    _, pred = torch.max(output, 1)
    # Convert the one-element tensor to a plain int before indexing the list.
    pred_class = classes[pred.item()]
    print(f'预测结果是:{pred_class}')
预测图片:
# Predict one photo taken from the training data
# NOTE(review): train_transforms includes a random flip and rotation, so the
# image is randomly augmented before prediction — confirm this is intended.
predict_one_image(image_path=r'D:\THE MNIST DATABASE\J-series\J1\bird_photos\Black Skimmer\001.jpg',model=model,transform=train_transforms,classes=classes)
运行结果:
预测结果是:Black Skimmer
3. 模型评估
# Evaluate the best snapshot (J2_model) on the test loader and display
# its accuracy and loss.
J2_model.eval()
epoch_test_acc,epoch_test_loss=test(test_dl,J2_model,loss_fn)
epoch_test_acc,epoch_test_loss
运行结果:
(0.8495575221238938, 0.5142213940271176)
五、心得体会
本周项目训练中,在pytorch环境下手动搭建了ResNet50V2模型。与ResNet50模型相比,ResNet50V2的残差模块将BN和ReLU进行了前置(预激活),在一定程度上有效地提升了模型的准确率。但是在本项目中,模型结果不尽如人意,虽然也对数据集进行了增强处理,但效果仍不理想,可能是由于数据集过小引起的,留待今后验证。