
[center loss] demo

2024/10/24 20:12:12 Source: https://blog.csdn.net/weixin_51552032/article/details/141832768

I noticed that center loss, much like clustering, can compute the loss with different kernel functions (which can be viewed as distance functions), so I'm recording this demo here.
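To illustrate the "swappable kernel" idea, here is a minimal sketch (a hypothetical variant, not from the original post) that replaces the squared Euclidean distance with cosine distance to the class centers; the class name `CosineCenterLoss` is my own:

```python
import torch
from torch import nn


class CosineCenterLoss(nn.Module):
    """Center loss variant using cosine distance instead of squared Euclidean distance."""

    def __init__(self, num_classes, feat_dim):
        super().__init__()
        # One learnable center per class, same as the standard center loss.
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, x, labels):
        # Cosine similarity between every feature and every center: (batch, num_classes)
        x_n = nn.functional.normalize(x, dim=1)
        c_n = nn.functional.normalize(self.centers, dim=1)
        sim = x_n @ c_n.t()
        # Distance = 1 - similarity; keep only each sample's own class column.
        dist = 1.0 - sim.gather(1, labels.unsqueeze(1))
        return dist.mean()


# Usage on random data
loss_fn = CosineCenterLoss(10, 128)
feats = torch.randn(64, 128)
labels = torch.randint(0, 10, (64,))
loss = loss_fn(feats, labels)
```

Any other pairwise distance (e.g. Mahalanobis, an RBF kernel) can be dropped into `forward` the same way, as long as it stays differentiable with respect to both the features and the centers.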


import os

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms


class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """

    def __init__(self, num_classes, feat_dim):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # One learnable center per class; moved to the right device by cLoss.cuda() below.
        # (Reassigning self.centers inside forward() would silently replace the Parameter
        # with a plain tensor and stop the centers from being trained.)
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        device = x.device
        batch_size = x.size(0)
        # Squared Euclidean distance to every center: ||x||^2 + ||c||^2 - 2 x c^T
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        # Mask out every column except each sample's own class.
        classes = torch.arange(self.num_classes, device=device)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=0, max=1e12).sum() / batch_size
        return loss


os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.chdir(os.path.dirname(__file__) + "/../")

cLoss = CenterLoss(10, 128)

mnist_train = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
mnist_test = datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor(), download=True)

# The 128-dim output doubles as both the embedding for center loss and the
# logits for cross-entropy (only classes 0-9 ever appear as targets).
net = nn.Sequential(
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 128)
)
net = net.cuda()
cLoss = cLoss.cuda()

# Separate optimizers: a larger learning rate for the centers is common practice.
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
optimizer_center = torch.optim.SGD(cLoss.parameters(), lr=0.5)

train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=64, shuffle=False)

for epoch in range(50):
    net.train()
    cLoss.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28 * 28).cuda()
        target = target.cuda()
        feature = net(data)
        loss1 = nn.CrossEntropyLoss()(feature, target)
        loss2 = cLoss(feature, target)
        loss = loss1 + 0.1 * loss2
        optimizer.zero_grad()
        optimizer_center.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer_center.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    net.eval()
    cLoss.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(-1, 28 * 28).cuda()
            target = target.cuda()
            feature = net(data)
            test_loss += nn.CrossEntropyLoss()(feature, target).item()
            pred = feature.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
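The distance-matrix expansion used in `forward` (||x||² + ||c||² − 2xcᵀ, accumulated with `addmm`) can be sanity-checked against `torch.cdist` on small random tensors; the shapes below are arbitrary:

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 3)        # 5 features of dimension 3
centers = torch.randn(4, 3)  # 4 class centers

# Expansion used in the demo: ||x||^2 + ||c||^2 - 2 x c^T
distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(5, 4) + \
          torch.pow(centers, 2).sum(dim=1, keepdim=True).expand(4, 5).t()
distmat = distmat.addmm(x, centers.t(), beta=1, alpha=-2)

# Reference: pairwise squared Euclidean distance computed directly
reference = torch.cdist(x, centers) ** 2

print(torch.allclose(distmat, reference, atol=1e-5))  # True
```

Note that the positional `addmm_(1, -2, x, centers.t())` form sometimes seen in older code is a deprecated signature; in current PyTorch the scalars are passed as the `beta` and `alpha` keyword arguments.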
