欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > 能源 > BERT 多分类实战:从训练到评估的完整指南

BERT 多分类实战:从训练到评估的完整指南

2025/2/23 17:05:02 来源:https://blog.csdn.net/sjxgghg/article/details/142658366  浏览:    关键词:BERT 多分类实战:从训练到评估的完整指南

BERT 多分类实战

项目介绍

在自然语言处理(NLP)领域,BERT 已成为强大的预训练模型,适用于各类文本分类任务。在这篇技术文章中,我们将展示如何使用 Hugging Face 的 transformers.Trainer,构建并训练一个基于 BERT 的多分类模型,并对其进行评估。

准备工作

在开始动手前,我们需要准备以下内容:

  1. 必要的库transformers(Hugging Face提供的BERT库)、torch(PyTorch框架)、sklearn(常用的机器学习工具库)

  2. 数据集:为了演示,我们使用一个包含新闻文本的多分类数据集。每条新闻需要被分类为['World', 'Sports', 'Business', 'Sci/Tech']等类别。

dataset = load_dataset("fancyzhx/ag_news")

数据预处理

该数据集为英文数据集,故使用bert-base-uncased 模型。

model_name = "bert-base-uncased"tokenizer = AutoTokenizer.from_pretrained(model_name)def tokenize_func(tokenizer):def _func(item):return tokenizer(item["text"],max_length=512,truncation=True,)return _funcnew_dataset = dataset.map(tokenize_func(tokenizer=tokenizer))

训练

构建基于BERT的多分类模型。Hugging Face的transformers库提供了预训练的BERT模型,并且我们可以很方便地在这个基础上微调。

bert = AutoModelForSequenceClassification.from_pretrained(model_name,trust_remote_code=True,num_labels=4,
)

在这里,我们使用了BertForSequenceClassification,它是一个预训练的BERT模型,同时包含一个序列分类头,用于文本分类任务。

args = TrainingArguments(output_dir=output_dir,evaluation_strategy="epoch",save_strategy="epoch",save_total_limit=3,learning_rate=2e-5,num_train_epochs=epoch,weight_decay=0.01,per_device_train_batch_size=32,per_device_eval_batch_size=16,# logging_steps=16,save_safetensors=True,overwrite_output_dir=True,load_best_model_at_end=True,
)trainer = Trainer(model=self.model,args=args,train_dataset=self.train_dataset,eval_dataset=self.eval_dataset,data_collator=self.data_collator,compute_metrics=self.compute_metrics,tokenizer=tokenizer,
)

使用了Trainer来简化训练流程,定义了训练的相关参数,比如训练轮次、batch size等。

评估

在测试集上的评估如下所示:

from sklearn.metrics import classification_report, confusion_matrix

使用 classification_report 评估模型的预测情况。

在这里插入图片描述

[注]: support 代表数据样本的数量。

使用 confusion_matrix 计算混淆矩阵。最后在测试集上,预测的混淆矩阵如下所示,利用 confusion_matrix计算混淆矩阵。

在这里插入图片描述

在多分类问题中,模型评估是非常重要的环节,混淆矩阵(Confusion Matrix)作为一种直观的评估工具,被广泛应用于分类问题。

混淆矩阵是一个 n x n 的矩阵,其中 n 是分类任务中的类别数量。它展示了模型在每个分类上的真实标签预测标签的分布情况。矩阵的每一行代表实际的类,列代表模型预测的类。

对角线元素表示模型正确分类的数量,它们表示真实类别和预测类别完全一致的样本数。

因此它直观展示出了每个类别的分类情况,帮助我们识别模型在不同类别上的表现差异。

开源

项目代码:https://github.com/JieShenAI/wechat/blob/main/24/09/多分类实战/bert_多分类实战.ipynb

总结

在这篇文章中,展示了如何基于BERT进行多分类任务的完整流程。从数据预处理到模型训练再到评估。

接下来,你可以尝试用其他数据集来训练你的模型,甚至调整BERT的预训练模型或使用不同的超参数进行微调。无论是文本分类、情感分析,还是其他NLP任务。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词