I came across the idea that center loss can, much like clustering, be computed with different kernel functions (which you can think of as distance functions), so I am recording this demo. Center loss pulls each sample's feature x_i toward a learnable center c_{y_i} for its class; the demo below implements it as the batch mean of the squared Euclidean distance ||x_i - c_{y_i}||^2, and a cosine-distance sketch follows the demo.
import os
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
from torch import nn

# Select GPU 1 before any CUDA context is created.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'


class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """
    def __init__(self, num_classes, feat_dim):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # One learnable center per class; updated by its own optimizer below.
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size,).
        """
        # The module is moved to the GPU together with the rest of the model
        # (see .cuda() below), so x and self.centers are on the same device.
        batch_size = x.size(0)

        # Squared Euclidean distances via ||x||^2 + ||c||^2 - 2 x.c^T,
        # giving a (batch_size, num_classes) matrix.
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        # Keep only each sample's distance to the center of its own class.
        classes = torch.arange(self.num_classes, device=x.device)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=0, max=1e12).sum() / batch_size
        return loss
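A quick aside on the distance matrix in forward: it relies on the identity ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c. Here is a minimal sanity check (my own addition, not part of the original demo) comparing that trick against torch.cdist:

import torch

# Illustrative only: the expand/addmm trick used in CenterLoss.forward
# should match torch.cdist(...)**2 up to floating-point error.
x = torch.randn(4, 128)
centers = torch.randn(10, 128)

distmat = x.pow(2).sum(dim=1, keepdim=True).expand(4, 10) + \
          centers.pow(2).sum(dim=1, keepdim=True).expand(10, 4).t()
distmat = distmat.addmm(x, centers.t(), beta=1, alpha=-2)

print(torch.allclose(distmat, torch.cdist(x, centers).pow(2), atol=1e-4))  # True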
os.chdir(os.path.dirname(__file__) + "/../")
cLoss = CenterLoss(10, 128)
mnist_train = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
mnist_test = datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor(), download=True)

# A small MLP. Its 128-d output does double duty in this demo: it is used
# directly as the logits for cross entropy (only indices 0-9 are ever targets)
# and as the feature vector fed to the center loss.
net = nn.Sequential(
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 128)
)
net = net.cuda()
cLoss = cLoss.cuda()

# Separate optimizers: one for the network, one for the class centers
# (the centers use a much larger learning rate).
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
optimizer_center = torch.optim.SGD(cLoss.parameters(), lr=0.5)

train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=64, shuffle=False)

criterion = nn.CrossEntropyLoss()

for epoch in range(50):
    net.train()
    cLoss.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28 * 28).cuda()
        target = target.cuda()
        feature = net(data)
        loss1 = criterion(feature, target)  # softmax / cross-entropy loss
        loss2 = cLoss(feature, target)      # center loss on the same features
        loss = loss1 + 0.1 * loss2          # 0.1 weights the center-loss term
        optimizer.zero_grad()
        optimizer_center.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer_center.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    net.eval()
    cLoss.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(-1, 28 * 28).cuda()
            target = target.cuda()
            feature = net(data)
            test_loss += criterion(feature, target).item()
            pred = feature.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
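Finally, back to the opening point about kernels: nothing in CenterLoss forces squared Euclidean distance. Below is a minimal sketch (my own variant, not the version trained above) that swaps in cosine distance as the "kernel"; any differentiable distance between features and centers would slot in the same way.

import torch
from torch import nn

class CosineCenterLoss(nn.Module):
    """Sketch of a center loss with a cosine-distance kernel
    (hypothetical variant, not part of the original demo)."""
    def __init__(self, num_classes, feat_dim):
        super(CosineCenterLoss, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, x, labels):
        # Cosine distance = 1 - cosine similarity, shape (batch_size, num_classes).
        x_n = nn.functional.normalize(x, dim=1)
        c_n = nn.functional.normalize(self.centers, dim=1)
        distmat = 1 - x_n @ c_n.t()
        # Keep each sample's distance to its own class center and average.
        dist = distmat.gather(1, labels.unsqueeze(1)).squeeze(1)
        return dist.mean()

# Drop-in replacement in the demo above:
# cLoss = CosineCenterLoss(10, 128).cuda()

One design note: cosine distance lives in [0, 2] rather than growing with the feature norm, so the 0.1 weight and the 0.5 center learning rate used above would likely need retuning.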