一、分类头查找
在 PyTorch 的预训练模型(如 ResNet
、VGG
、DenseNet
等)中,model.fc
通常表示模型的最后一层全连接层(Fully Connected Layer),也就是分类头(Classifier Head)。它的作用是将提取的特征映射到最终的分类结果上。不同模型的分类头名称可能不同(如 classifier
),需通过 print(model)
确认。
1. model.fc
的具体含义
-
fc
是 "Fully Connected" 的缩写,即全连接层。 -
在分类任务中,这一层的作用是将网络提取的全局特征转换为类别概率分布。
-
例如:
-
ResNet
的model.fc
默认输出 1000 维(对应 ImageNet 的 1000 类)。 -
如果改成自己的任务(比如 10 分类),需要替换这一层:
model.fc = nn.Linear(in_features, 10) # 改为输出10类
-
2. 为什么是 model.fc
?
PyTorch 的预训练模型通常将全连接层命名为 fc
,例如:
-
ResNet:
model.fc
-
DenseNet:
model.classifier
(注意命名可能不同) -
VGG:
model.classifier[6]
(VGG 的全连接层是一个序列,最后一层是分类头)
可以通过打印模型结构确认:
from torchvision.models import resnet50
model = resnet50()
print(model) # 查看最后一层的名称
3. 如何修改 model.fc
?
场景1:直接替换分类头(适应新类别数)
import torch.nn as nn
from torchvision.models import resnet50model = resnet50(pretrained=True)
num_ftrs = model.fc.in_features # 获取输入特征维度(如2048