Focal Loss 是一种专门设计用于处理类别不平衡问题的损失函数。与标准的 CrossEntropyLoss
不同,Focal Loss 通过引入一个调节因子,减少了模型在容易区分的样本上的损失,专注于难分类的样本。它尤其适合在正负样本分布严重不均衡的场景中使用。
公式为:
Focal Loss = − α t ( 1 − p t ) γ log ( p t ) \text{Focal Loss} = -\alpha_t (1 - p_t)^\gamma \log(p_t) Focal Loss=−αt(1−pt)γlog(pt)
其中:
- p t p_t pt 是模型对于正确类别的预测概率。
- α t \alpha_t αt 是权重因子,通常用于平衡正负样本。
- γ \gamma γ 是一个调节因子,用于控制容易分类样本对总损失的影响,常取值为 2。
使用 Focal Loss 的步骤
- 计算 Cross Entropy Loss:这是 Focal Loss 的基础。
- 计算调节因子:根据模型预测的概率,计算难易度因子 ( 1 − p t ) γ (1 - p_t)^\gamma (1−pt)γ。
- 组合计算 Focal Loss:将 Cross Entropy 和调节因子结合得到最终的损失。
代码实现
你可以通过自定义 Focal Loss 来替换标准的 CrossEntropyLoss
,具体实现如下:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass FocalLoss(nn.Module):def __init__(self, gamma=2, alpha=None, reduction='mean'):""":param gamma: Focusing parameter. Default is 2.:param alpha: Weighting factor for class imbalance. Default is None.:param reduction: Specifies the reduction to apply to the output: 'none', 'mean' or 'sum'."""super(FocalLoss, self).__init__()self.gamma = gammaself.alpha = alpha # Can be a scalar (for binary) or a tensor (for multi-class)self.reduction = reductiondef forward(self, inputs, targets):# Cross entropy loss (without reduction, so we can apply custom weight)ce_loss = F.cross_entropy(inputs, targets, reduction='none')# Compute the probabilities of the correct classp_t = torch.exp(-ce_loss) # Equivalent to exp(log(p_t)) = p_t# Calculate the focal weight (1 - p_t)^gammafocal_weight = (1 - p_t) ** self.gamma# Apply alpha balancing if providedif self.alpha is not None:alpha_t = self.alpha[targets]focal_weight = alpha_t * focal_weight# Combine focal weight and cross entropy lossfocal_loss = focal_weight * ce_lossif self.reduction == 'mean':return focal_loss.mean()elif self.reduction == 'sum':return focal_loss.sum()else:return focal_loss# Example usage for a multi-class classification problem
if __name__ == "__main__":num_classes = 10batch_size = 5# Random predictions (logits) and targetsinputs = torch.randn(batch_size, num_classes, requires_grad=True) # Model outputstargets = torch.randint(0, num_classes, (batch_size,)) # True labels# Initialize Focal Loss with gamma=2.0 and no alpha balancingfocal_loss_fn = FocalLoss(gamma=2.0)# Calculate lossloss = focal_loss_fn(inputs, targets)print(f"Focal Loss: {loss.item()}")
解释:
inputs
:形状为(batch_size, num_classes)
,代表模型的预测 logits。targets
:形状为(batch_size,)
,代表真实的类别索引。F.cross_entropy
:计算交叉熵损失,但不应用 reduction,因此我们可以在之后手动计算并应用 Focal Loss。p_t
:计算模型对于正确类别的预测概率。focal_weight
:使用调节因子 ( 1 − p t ) γ (1 - p_t)^\gamma (1−pt)γ 调整容易分类的样本的影响。alpha
:用于应对类别不平衡问题。可以是一个标量或者一个与类别数量相同的张量,给每个类别赋予不同的权重。reduction
:你可以选择如何缩减损失值:mean
(取均值)、sum
(取总和)、none
(返回每个样本的损失值)。
优化思路:
gamma
的调整:gamma
越大,对容易分类样本的抑制越强;gamma
越小,损失函数趋向于标准的交叉熵损失。alpha
的使用:如果有类别不平衡问题,可以根据每个类别的样本比例设置alpha
,使得稀少类别的损失权重更高。
总结:
Focal Loss 在多分类问题中,通过对难分类样本赋予更高的损失权重,来减少容易分类样本对模型训练的干扰,常用于类别不平衡的任务,如目标检测等。