torchvision.models.vgg 介绍
torchvision.models.vgg
是 PyTorch 中提供的 VGG 网络模型的模块,VGG 是一种经典的卷积神经网络架构。
通过 torchvision.models.vgg
,我们可以方便地加载各种 VGG 模型(如 VGG11、VGG13、VGG16 等),这些模型可以用于图像分类、特征提取等任务。
它有两个重要参数:
- weights 参数:指定是否加载预训练的权重。可以设置为
None
(不加载预训练权重)或者指定为某种权重(如VGG16_Weights.IMAGENET1K_V1
),这样可以使用在 ImageNet 数据集上预训练的权重,以提高模型的性能。 - progress 参数:在下载预训练权重时,是否显示下载进度。
使用 VGG 模型并不需要提前安装 ImageNet 数据集,但如果想使用预训练权重,这些权重是基于 ImageNet 数据集训练的,因此预训练模型的输出会有 1000 个类别。
vgg16_False与vgg16_True的解释
视频中的解释是准确的:
pretrained=False
:仅加载 VGG16 的网络结构,所有参数(权重和偏置)会使用默认的随机初始化。这不需要下载任何文件。pretrained=True
:加载在 ImageNet 数据集上训练好的参数,这会下载并加载每一层的权重参数。
更生动的解释: 可以把 VGG 模型比作一架新飞机:
- 加载网络模型(
pretrained=False
):就像是组装好飞机的框架,但里面的设备和油料都是空的,你需要自己加油和安装设备。 - 下载参数(
pretrained=True
):就像是加满油、安装好设备的飞机,已经可以飞行(即直接用于任务),因为它在 ImageNet 数据集上学到了许多特征。
print(vgg16_True) 和 print(vgg16_False) 的结果
无论 pretrained=True
还是 False
,调用 vgg16
都会得到同样的网络架构,因为 pretrained
只是决定是否加载预训练的权重,不会改变网络的结构。所以 print(vgg16_False)
和 print(vgg16_True)
输出的都是 VGG16 的结构。
VGG16 的输出类别修改
在图1的代码中,VGG16 的最后一层是 Linear(in_features=4096, out_features=1000, bias=True)
,表示输出 1000 个类别。CIFAR-10 数据集只有 10 个类别,因此我们需要修改模型的输出层。
使用 .add_module:
.add_module
是用于向模型的某个模块中添加新层的方法。
使用 .classifier.add_module:
- 在
classifier
部分添加新层
使用 .classifier[6]:
- 直接修改
classifier
的第7层