微调BERT-base模型,构建层次化分类器,Top-3准确率达97.2%,并自动识别出问题的关键类别
1. 具体微调的BERT-base模型是什么模型?
BERT-base模型是一个预训练的Transformer模型,包含12个Transformer块、12个自注意头和隐藏大小为768。该模型在大规模文本数据上进行了预训练,能够捕捉文本的上下文信息和语义特征。
2. 如何微调的,微调步骤?
微调BERT-base模型的步骤如下:
-
加载预训练模型和分词器:
from transformers import BertTokenizer, BertForSequenceClassificationmodel_name = 'bert-base-uncased' tokenizer = BertTokenizer.from_pretrained(model_name) model = BertForSequenceClassification.from_pretrained(model_name, num_labels=10)
-
准备训练数据:
from torch.utils.data import Dataset, DataLoaderclass TextClassificationDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_length=128):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_length = max_lengthdef __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]label = self.labels[idx]encoding = self.tokenizer.encode_plus(text,add_special_tokens=True,max_length=self.max_length,return_token_type_ids=False,padding='max_length',truncation=True,return_attention_mask=True,return_tensors='pt',)return {'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'label': torch.tensor(label, dtype=torch.long)}# 示例数据 texts = ["网络连接失败", "无法登录账户", "软件安装失败"] labels = [0, 1, 2] # 0: 网络故障, 1: 账户权限, 2: 软件安装dataset = TextClassificationDataset(texts, labels, tokenizer) dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
-
定义训练参数:
from torch.optim import AdamW from transformers import get_scheduleroptimizer = AdamW(model.parameters(), lr=2e-5) num_epochs = 3 num_training_steps = num_epochs * len(dataloader) lr_scheduler = get_scheduler(name="linear",optimizer=optimizer,num_warmup_steps=0,num_training_steps=num_training_steps )device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model.to(device)
-
训练模型:
from tqdm.auto import tqdmprogress_bar = tqdm(range(num_training_steps))model.train() for epoch in range(num_epochs):for batch in dataloader:batch = {k: v.to(<