知识蒸馏的概念
可以参照NeurIPS2015的论文“Distilling the Knowledge in a Neural Network”了解知识蒸馏的概念。
知识蒸馏的狭义概念就是从复杂模型中迁移知识来提升简单模型的性能。复杂模型称之为教师模型,简单模型称之为学生模型。最近,笔者重温了知识蒸馏的概念,并在CIFAR100数据集上对知识蒸馏进行了验证和实验。
logits,硬目标,软目标的概念:logits指的是网络最后一层在softmax之前的原始输出,硬目标指的是真值标签的one-hot编码,软目标指的是对logits进行softmax之后得到的概率分布。
加入温度系数的软目标:为了让softmax之后的概率分布更加软化,Hinton提出了使用温度参数对logits进行softmax的软化处理,
T为温度,T越大,概率分布更加平缓。
数据集 CIFAR100,是一个经典的图像分类数据集,有100个图像类别
数据集直接采用Pytorch定义的官方数据集进行加载
# CIFAR-100 data loading: normalization statistics, train-time augmentation,
# and DataLoaders for the train/test splits.
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader  # BUGFIX: was used below but never imported

# Per-channel mean/std of the CIFAR-100 training split (RGB order).
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Training pipeline: random crop with padding, horizontal flip, small rotation,
# then normalize with the training-set statistics.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

# Test pipeline: no augmentation, only tensor conversion + normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

train_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/", train=True, transform=transform_train, download=True
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/", train=False, transform=transform_test, download=True
)

train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)
分类模型:采用ResNet50作为教师模型,VGG16作为学生模型。
VGG16网络定义代码
"""vgg in pytorch[1] Karen Simonyan, Andrew ZissermanVery Deep Convolutional Networks for Large-Scale Image Recognition.https://arxiv.org/abs/1409.1556v6
"""
'''VGG11/13/16/19 in Pytorch.'''import torch
import torch.nn as nncfg = {'A' : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'B' : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'D' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'E' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}class VGG(nn.Module):def __init__(self, features, num_class=100):super().__init__()self.features = featuresself.classifier = nn.Sequential(nn.Linear(512, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, num_class))def forward(self, x):output = self.features(x)output = output.view(output.size()[0], -1)output = self.classifier(output)return outputdef make_layers(cfg, batch_norm=False):layers = []input_channel = 3for l in cfg:if l == 'M':layers += [nn.MaxPool2d(kernel_size=2, stride=2)]continuelayers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]if batch_norm:layers += [nn.BatchNorm2d(l)]layers += [nn.ReLU(inplace=True)]input_channel = lreturn nn.Sequential(*layers)def vgg16_bn():return VGG(make_layers(cfg['D'], batch_norm=True))
ResNet50网络定义代码
"""resnet in pytorch[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.Deep Residual Learning for Image Recognitionhttps://arxiv.org/abs/1512.03385v1
"""import torch
import torch.nn as nnclass BasicBlock(nn.Module):"""Basic Block for resnet 18 and resnet 34"""#BasicBlock and BottleNeck block#have different output size#we use class attribute expansion#to distinctexpansion = 1def __init__(self, in_channels, out_channels, stride=1):super().__init__()#residual functionself.residual_function = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))#shortcutself.shortcut = nn.Sequential()#the shortcut output dimension is not the same with residual function#use 1*1 convolution to match the dimensionif stride != 1 or in_channels != BasicBlock.expansion * out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))def forward(self, x):return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))class BottleNeck(nn.Module):"""Residual block for resnet over 50 layers"""expansion = 4def __init__(self, in_channels, out_channels, stride=1):super().__init__()self.residual_function = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),nn.BatchNorm2d(out_channels * BottleNeck.expansion),)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * BottleNeck.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),nn.BatchNorm2d(out_channels * 
BottleNeck.expansion))def forward(self, x):return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))class ResNet(nn.Module):def __init__(self, block, num_block, num_classes=100):super().__init__()self.in_channels = 64self.conv1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True))#we use a different inputsize than the original paper#so conv2_x's stride is 1self.conv2_x = self._make_layer(block, 64, num_block[0], 1)self.conv3_x = self._make_layer(block, 128, num_block[1], 2)self.conv4_x = self._make_layer(block, 256, num_block[2], 2)self.conv5_x = self._make_layer(block, 512, num_block[3], 2)self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride):"""make resnet layers(by layer i didnt mean this 'layer' was thesame as a neuron netowork layer, ex. conv layer), one layer maycontain more than one residual blockArgs:block: block type, basic block or bottle neck blockout_channels: output depth channel number of this layernum_blocks: how many blocks per layerstride: the stride of the first block of this layerReturn:return a resnet layer"""# we have num_block blocks per layer, the first block# could be 1 or 2, other blocks would always be 1strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self, x):output = self.conv1(x)output = self.conv2_x(output)output = self.conv3_x(output)output = self.conv4_x(output)output = self.conv5_x(output)output = self.avg_pool(output)output = output.view(output.size(0), -1)output = self.fc(output)return outputdef resnet50():""" return a ResNet 50 object"""return ResNet(BottleNeck, [3, 4, 6, 3])
先单独训练教师模型和学生模型,分别统计教师模型学生模型的精度
损失函数 nn.CrossEntropyLoss()
优化器 torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
学习率曲线 torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
epochs = 200
教师模型训练代码
import torch
from torch import nn
from tqdm import tqdm
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from my_resnet import resnet50


def TeacherModel():
    """Return a fresh ResNet-50 used as the teacher network."""
    return resnet50()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel mean/std of the CIFAR-100 training split (RGB order).
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Train-time augmentation; test-time only normalizes.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

train_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/", train=True, transform=transform_train, download=True
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/", train=False, transform=transform_test, download=True
)

train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)

if __name__ == "__main__":
    # Train the teacher model (ResNet-50) from scratch.
    model = TeacherModel().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[60, 120, 160], gamma=0.2
    )  # learning-rate decay

    epochs = 200
    best_acc = 0.0

    for epoch in range(epochs):
        # ---- training ----
        model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            prediction = model(data)
            loss = criterion(prediction, targets)
            loss.backward()
            optimizer.step()
        # BUGFIX: scheduler.step() must run AFTER the epoch's optimizer updates
        # (PyTorch >= 1.1), and the `epoch` argument is deprecated.
        train_scheduler.step()

        # ---- evaluation on the test split ----
        model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                prediction = model(x).max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
        acc = (num_correct / num_samples).item()

        # Checkpoint only when accuracy improves on the best seen so far.
        if acc > best_acc:
            torch.save(model.state_dict(), './weights/teacher_cifar100/teacher_{}.pth'.format(acc))
            best_acc = acc
        print("Epoch {}: 当前模型最佳精度为:{:.4f}".format(epoch, best_acc))
    # Recorded result (teacher, ResNet-50): Epoch 199, best accuracy 0.7840.
教师模型的分类精度为78.40%
学生模型的训练代码
import torch
from torch import nn
from tqdm import tqdm
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from my_vgg import vgg16_bn


def StudentModel():
    """Return a fresh VGG16-BN used as the student network."""
    return vgg16_bn()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel mean/std of the CIFAR-100 training split (RGB order).
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Train-time augmentation; test-time only normalizes.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

train_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/", train=True, transform=transform_train, download=True
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/", train=False, transform=transform_test, download=True
)

train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)

if __name__ == "__main__":
    # Train the student model (VGG16-BN) from scratch, without distillation.
    model = StudentModel().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[60, 120, 160], gamma=0.2
    )  # learning-rate decay

    epochs = 200
    best_acc = 0.0

    for epoch in range(epochs):
        # ---- training ----
        model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            prediction = model(data)
            loss = criterion(prediction, targets)
            loss.backward()
            optimizer.step()
        # BUGFIX: scheduler.step() must run AFTER the epoch's optimizer updates
        # (PyTorch >= 1.1), and the `epoch` argument is deprecated.
        train_scheduler.step()

        # ---- evaluation on the test split ----
        model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                prediction = model(x).max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
        acc = (num_correct / num_samples).item()

        # Checkpoint only when accuracy improves on the best seen so far.
        if acc > best_acc:
            torch.save(model.state_dict(), './weights/student_cifar100_vgg16/student_{}.pth'.format(acc))
            best_acc = acc
        print("Epoch {}: 当前模型最佳精度为:{:.4f}".format(epoch, best_acc))
    # Recorded result (student, VGG16): Epoch 199, best accuracy 0.7121.
学生模型的训练精度为71.21%
教师-学生模型蒸馏训练,学生损失为CE交叉熵损失,蒸馏损失为KL散度损失
重点一:蒸馏学生损失loss=(1-alpha) * T * T * soft_loss + alpha * hard_loss,其中alpha为权重参数,T为Temperature温度参数,用于软目标化
具体可参见 bilibili视频
重点二:蒸馏损失的计算方式,student_predictions需要除以温度参数后进行F.log_softmax变成软目标,teacher_predictions需要除以温度参数后进行F.softmax变成软目标
distillation_loss = soft_loss(F.log_softmax(student_predictions / Temp, dim=1), F.softmax(teacher_predictions / Temp, dim=1))
重点三:教师模型需要eval(), 得到教师模型输出需要 with torch.no_grad()和.detach()
with torch.no_grad():teacher_predictions = teacher_model(data)teacher_predictions = teacher_predictions.detach()
重点四:损失权重参数alpha和温度系数T的设定,笔者参照bilibili视频的设定,设置alpha为0.3,温度系数T为4
蒸馏训练代码
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
from teacher_cifar100 import TeacherModel
from vgg_student_cifar100 import StudentModel

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

# Per-channel mean/std of the CIFAR-100 training split (RGB order).
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Train-time augmentation; test-time only normalizes.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

# Load the CIFAR-100 datasets (the original comment wrongly said MNIST).
train_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/", train=True, transform=transform_train, download=True
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/", train=False, transform=transform_test, download=True
)

train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)

if __name__ == "__main__":
    # Distill the pretrained ResNet-50 teacher into a VGG16 student.
    # The teacher is frozen (eval mode, no gradients) throughout training.
    teacher_model = TeacherModel().to(device)
    teacher_model.load_state_dict(torch.load("./weights/teacher_cifar100/teacher_0.7839999794960022.pth"))
    teacher_model.eval()

    student_model = StudentModel().to(device)

    Temp = 4     # distillation temperature T: larger T -> softer distributions
    alpha = 0.3  # weight of the hard (cross-entropy) loss term

    hard_loss = nn.CrossEntropyLoss()
    # KLDivLoss expects log-probabilities as input and probabilities as target.
    soft_loss = nn.KLDivLoss(reduction='batchmean')

    optimizer = torch.optim.SGD(student_model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[60, 120, 160], gamma=0.2
    )  # learning-rate decay

    epochs = 200
    best_acc = 0.0

    for epoch in range(epochs):
        # ---- distillation training ----
        student_model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            # Teacher predictions carry no gradient.
            with torch.no_grad():
                teacher_predictions = teacher_model(data).detach()

            student_predictions = student_model(data)

            # Hard loss: standard CE against ground-truth labels.
            student_loss = hard_loss(student_predictions, targets)
            # Soft loss: KL between temperature-softened student log-probs
            # and temperature-softened teacher probs.
            distillation_loss = soft_loss(
                F.log_softmax(student_predictions / Temp, dim=1),
                F.softmax(teacher_predictions / Temp, dim=1),
            )
            # Scale the soft term by T^2 so its gradient magnitude matches
            # the hard term (Hinton et al., 2015).
            loss = (1 - alpha) * Temp * Temp * distillation_loss + alpha * student_loss

            loss.backward()
            optimizer.step()
        # BUGFIX: scheduler.step() must run AFTER the epoch's optimizer updates
        # (PyTorch >= 1.1), and the `epoch` argument is deprecated.
        train_scheduler.step()

        # ---- evaluation of the student on the test split ----
        student_model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                prediction = student_model(x).max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
        acc = (num_correct / num_samples).item()

        # Checkpoint only when accuracy improves on the best seen so far.
        if acc > best_acc:
            torch.save(student_model.state_dict(), './weights/knowledge_distillation_cifar100_vgg16/student_{}.pth'.format(acc))
            best_acc = acc
        print("Epoch {}: 当前模型最佳精度为:{:.4f}".format(epoch, best_acc))
    # Recorded results (ResNet-50 -> VGG16 distillation, Temp=4, alpha=0.3):
    # teacher 0.7840, plain student 0.7121, distilled student 0.7388 (epoch 199).
知识蒸馏实验对比结果
模型 | 网络结构 | 分类精度 |
---|---|---|
学生模型 | VGG16 | 71.21% |
教师模型 | ResNet50 | 78.40% |
蒸馏学生模型 | VGG16 | 73.88% |
实验总结分析
通过在CIFAR100数据集上的从ResNet50到VGG16的教师-学生模型的蒸馏实验,表明了Hinton等人提出的知识蒸馏的有效性。同时,通过实验的细节设置,笔者注意到了知识蒸馏的几个设置,soft_loss的计算有F.softmax和F.log_softmax的区别,教师模型需要eval和detach消除梯度,温度参数T和损失平衡系数alpha的选择,soft_loss需要乘以T2的系数,都是需要注意的细节问题。
致谢
[1] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean, “Distilling the Knowledge in a Neural Network,” in NeurIPS 2015.
[2] https://github.com/weiaicunzai/pytorch-cifar100
[3] https://www.bilibili.com/video/BV1Go4y1u72L/?spm_id_from=333.337.search-card.all.click&vd_source=e71c4eae27444c44f2de6239f04c4757