欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 财经 > 创投人物 > 数据集类不平衡的处理方法

数据集类不平衡的处理方法

2025/2/25 11:52:02 来源:https://blog.csdn.net/qq_41990294/article/details/145123564  浏览:    关键词:数据集类不平衡的处理方法

最近在处理一个类不平均的数据集,这里记录一些注意事项,方便以后查阅。

数据集类不平衡的处理方法

  • 数据集类不平衡的处理方法
    • 1. 数据处理方法
    • 2. 模型改进方法
    • 3. 测试与评估方法
    • 4. 综合策略
    • 5. 示例代码
    • 6. 注意事项
  • 模型评估指标
    • 1. 混淆矩阵(Confusion Matrix)
    • 2. 准确率(Accuracy)
    • 3. 精确率(Precision)
    • 4. 召回率(Recall)
    • 5. F1-Score
    • 6. Kappa值(Cohen's Kappa)
    • 7. 特异度(Specificity)
    • 8. 假正率(FPR)
    • 9. ROC曲线与AUC
    • 10. PR曲线与AUC

数据集类不平衡的处理方法

对于类别不平衡的数据集,模型训练与测试的效果可能受到数据分布的影响,因此需要采取一些方法来缓解类别不平衡问题,从而提高模型的性能和泛化能力。以下是常见的解决策略:

1. 数据处理方法

(1) 过采样(Oversampling)

  • 定义:在训练集中增加少数类别的样本数量。
  • 实现
    • 随机复制少数类别样本(Random Oversampling)。
    • 使用合成数据生成技术(如SMOTE)。
  • 优点:缓解数据不平衡,增加模型对少数类别的学习能力。
  • 注意:过度过采样可能导致过拟合。

(2) 欠采样(Undersampling)

  • 定义:减少多数类别的样本数量。
  • 实现:随机移除多数类别样本。
  • 优点:加速训练并平衡数据。
  • 注意:可能丢失重要信息,导致模型性能下降。

(3) 数据增强

  • 定义:通过旋转、缩放等操作对少数类别样本进行扩增。
  • 适用场景:图像、音频等数据。

2. 模型改进方法

(1) 使用加权损失函数

  • 定义:为不同类别分配不同的损失权重,让模型更关注少数类别。
  • 实现
    • 权重比例通常与类别分布反比。
    • 常用损失函数:加权交叉熵损失(Weighted Cross-Entropy Loss)、Focal Loss。
  • 优点:无需改变数据分布。

(2) 平衡采样器(Balanced Sampler)

  • 定义:在每个mini-batch中,按类别比例采样数据。
  • 实现:调整DataLoader或批量生成策略。

(3) 使用特殊模型架构

  • 定义:采用适合处理不平衡数据的模型。
  • 例如:集成学习(如随机森林、XGBoost等)可以较好地处理类别不平衡问题。

3. 测试与评估方法

(1) 选择合适的评估指标

  • 问题:传统准确率指标在类别不平衡情况下可能误导结果。
  • 替代指标
    • 精确率(Precision)、召回率(Recall)、F1-Score。
    • ROC曲线与AUC。
    • PR曲线与AUC。

(2) 分层采样(Stratified Sampling)

  • 定义:在训练集和测试集中保持类别分布一致。
  • 实现:划分数据集时,按照类别比例分层。

(3) 混淆矩阵分析

  • 定义:观察模型对不同类别的预测表现。
  • 作用:确定哪些类别需要改进。

4. 综合策略

(1) 混合采样

  • 定义:结合过采样与欠采样。
  • 适用场景:同时提高训练效率和模型性能。

(2) 使用非对称阈值

  • 定义:针对少数类别设置较低的决策阈值,增加召回率。
  • 实现:通过调整predict_proba的概率阈值。

(3) 分阶段训练

  • 定义:先用平衡数据训练模型,再用真实分布数据进行微调。
  • 优点:提高模型的实际适用性。

5. 示例代码

(1) 使用加权损失函数(以PyTorch为例)

import torch
import torch.nn as nn# 假设类别0的样本数为1000,类别1的样本数为100
weights = torch.tensor([1/1000, 1/100], dtype=torch.float32)
criterion = nn.CrossEntropyLoss(weight=weights)

(2) 使用SMOTE进行过采样

from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split# 分割数据
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)# 使用SMOTE
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)

6. 注意事项

  1. 测试集保持真实分布:测试数据应反映真实世界的数据分布,不应进行平衡处理。
  2. 防止数据泄漏:避免数据增强或采样方法导致数据泄漏(如同一条样本的不同增强版本出现在训练集和验证集中)。
  3. 动态调整:根据数据特性选择适合的策略,灵活调整。

模型评估指标

1. 混淆矩阵(Confusion Matrix)

定义
混淆矩阵是一种评估分类模型性能的工具,显示模型在测试数据上的分类结果与实际结果的对比情况。它将分类结果划分为四种类型:TP(真正例)、TN(真负例)、FP(假正例)和 FN(假负例)。混淆矩阵通常用于二分类问题,但也可扩展到多分类问题。

二分类混淆矩阵

实际类别 / 预测类别预测为正例预测为负例
实际为正例TPFN
实际为负例FPTN

元素解释

  1. TP(真正例,True Positive)
    • 实际类别为正例,且模型预测为正例。
    • 示例:病人被正确诊断为患病。
  2. TN(真负例,True Negative)
    • 实际类别为负例,且模型预测为负例。
    • 示例:健康人被正确诊断为健康。
  3. FP(假正例,False Positive)
    • 实际类别为负例,但模型预测为正例(“误报”)。
    • 示例:健康人被错误诊断为患病。
  4. FN(假负例,False Negative)
    • 实际类别为正例,但模型预测为负例(“漏报”)。
    • 示例:病人被错误诊断为健康。

扩展到多分类问题

对于多分类问题,混淆矩阵是一个 n × n 的方阵,n 是类别数。矩阵中的第 i i i 行表示实际类别为 i i i 的样本,第 j j j 列表示模型预测类别为 j j j 的样本数量。对角线上的值表示正确分类的样本数,非对角线表示错误分类的情况。

2. 准确率(Accuracy)

定义

  • 表示模型预测正确的样本占总样本的比例。

公式
Accuracy = TP + TN TP + TN + FP + FN \text{Accuracy} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}} Accuracy=TP+TN+FP+FNTP+TN

应用场景

  • 类别平衡的数据集,如正常分类任务。

适用性

  • 不适用于类不平衡问题,容易被多数类样本主导,无法反映模型对少数类的性能。

优缺点

  • 优点:直观、简单,适合初步评估。
  • 缺点:类别不平衡时误导性强。

物理意义

  • 反映了模型总体上的正确预测比例。

实现代码

from sklearn.metrics import accuracy_score# y_true: 实际标签, y_pred: 预测标签
accuracy = accuracy_score(y_true, y_pred)

3. 精确率(Precision)

定义

  • 模型预测为正例的样本中,实际为正例的比例。

公式
Precision = TP TP + FP \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} Precision=TP+FPTP

应用场景

  • 误报代价高的场景,如垃圾邮件过滤、金融欺诈检测。

适用性

  • 适用于类不平衡问题,更关注正例的预测质量。

优缺点

  • 优点:有效避免过多误报。
  • 缺点:可能忽略对召回率的考虑。

物理意义

  • 衡量模型对正例预测的可信度。

实现代码

from sklearn.metrics import precision_scoreprecision = precision_score(y_true, y_pred)

4. 召回率(Recall)

定义

  • 实际为正例的样本中,模型正确预测为正例的比例。

公式
Recall = TP TP + FN \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} Recall=TP+FNTP

应用场景

  • 漏报代价高的场景,如医疗诊断、安防监控。

适用性

  • 适用于类不平衡问题,更关注正例的覆盖能力。

优缺点

  • 优点:减少漏报。
  • 缺点:可能导致更多误报。

物理意义

  • 衡量模型对正例的捕获能力。

实现代码

from sklearn.metrics import recall_scorerecall = recall_score(y_true, y_pred)

5. F1-Score

定义

  • 精确率和召回率的调和平均值,综合考虑两者的性能。

公式
F1-Score = 2 × Precision × Recall Precision + Recall \text{F1-Score} = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}} F1-Score=2×Precision+RecallPrecision×Recall

应用场景

  • 需要平衡精确率与召回率的场景。

适用性

  • 适用于类不平衡问题,可作为主要指标。

优缺点

  • 优点:综合性强。
  • 缺点:对高精确率或高召回率的情况敏感。

物理意义

  • 衡量模型整体性能的均衡性。

实现代码

from sklearn.metrics import f1_scoref1 = f1_score(y_true, y_pred)

6. Kappa值(Cohen’s Kappa)

定义

  • 评估模型分类与随机分类之间的一致性。

公式
κ = p o − p e 1 − p e \kappa = \frac{p_o - p_e}{1 - p_e} κ=1pepope

  • p o p_o po:观察到的准确率。
  • p e p_e pe:随机一致的概率。

应用场景

  • 强调分类一致性的场景。

适用性

  • 适用于类不平衡问题

优缺点

  • 优点:剔除随机预测的影响。
  • 缺点:对概率分布的敏感性较高。

物理意义

  • 衡量模型分类的一致性。

实现代码

from sklearn.metrics import cohen_kappa_scorekappa = cohen_kappa_score(y_true, y_pred)

7. 特异度(Specificity)

定义

  • 实际为负例的样本中,模型正确预测为负例的比例。

公式
Specificity = TN TN + FP \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} Specificity=TN+FPTN

应用场景

  • 关注负例分类准确性的场景。

适用性

  • 适用于类不平衡问题

优缺点

  • 优点:关注负例性能。
  • 缺点:容易被少数类忽略。

物理意义

  • 衡量负例预测的能力。

实现代码

from sklearn.metrics import confusion_matrixtn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
specificity = tn / (tn + fp)

8. 假正率(FPR)

定义

  • 实际为负例的样本中,模型错误预测为正例的比例。

公式
FPR = FP FP + TN \text{FPR} = \frac{\text{FP}}{\text{FP} + \text{TN}} FPR=FP+TNFP

应用场景

  • 分析误报对结果的影响。

适用性

  • 可辅助其他指标分析。

优缺点

  • 优点:有效评估误报。
  • 缺点:无法独立衡量模型性能。

9. ROC曲线与AUC

定义

  • ROC曲线:以假正率(FPR)为横轴,真正率(TPR)为纵轴,反映分类性能。
  • AUC:ROC曲线下面积。

应用场景

  • 评估分类器整体性能。

适用性

  • 类不平衡时稳定

优缺点

  • 优点:评估全面。
  • 缺点:受类别分布影响。

实现代码

from sklearn.metrics import roc_auc_scoreauc = roc_auc_score(y_true, y_pred_proba)

10. PR曲线与AUC

定义

  • PR曲线:以召回率为横轴,精确率为纵轴。
  • PR-AUC:PR曲线下面积。

应用场景

  • 类不平衡问题,如少数类检测。

优缺点

  • 优点:对正例的性能更敏感。
  • 缺点:忽略负例表现。

实现代码

from sklearn.metrics import precision_recall_curve, aucprecision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
pr_auc = auc(recall, precision)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词