A BERT-Based Sequence-to-Sequence (Seq2Seq) Model for Generating Text Summaries or Titles

Source: https://blog.csdn.net/2301_76444133/article/details/146484364

  1. Data preprocessing

    • A DataGenerator class loads and preprocesses the data, handling padding of variable-length sequences.
    • The model input is the article body (content); the target is the title (title). See the teacher-forcing sketch just after this list.
  2. Model construction

    • A Seq2Seq model is built on top of BERT and trained with a cross-entropy loss.
    • Generation uses beam search, with Top-K candidate expansion at each step.
  3. Training and evaluation

    • Training uses the Adam optimizer.
    • At the end of each epoch, an Evaluate callback generates sample titles to monitor quality.
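
Before the full script, here is a minimal sketch of the teacher-forcing layout that the data generator produces. The token IDs are hypothetical stand-ins for whatever the BERT vocabulary actually assigns:

# Teacher-forcing sketch (hypothetical token IDs: 101 = [CLS], 102 = [SEP]).
# tokenizer.encode(title) yields [CLS] t1 t2 t3 [SEP]:
y = [101, 7342, 2110, 3152, 102]

decoder_input = [101] + y[:-1]   # [101, 101, 7342, 2110, 3152]
decoder_output = y               # [101, 7342, 2110, 3152, 102]

# At step i the decoder reads decoder_input[:i+1] and is trained to emit
# decoder_output[i]; zero-padded positions are masked out of the loss.
for inp, out in zip(decoder_input, decoder_output):
    print(inp, '->', out)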
import numpy as np
import pandas as pd
from tqdm import tqdm
from bert4keras.bert import build_bert_model
from bert4keras.tokenizer import Tokenizer, load_vocab
from keras.layers import *
from keras.models import Model
from keras import backend as K
from bert4keras.snippets import parallel_apply
from keras.optimizers import Adam
import keras
import math
from sklearn.model_selection import train_test_split
from rouge import Rouge  # requires: pip install rouge

# Configuration parameters
config_path = 'bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = 'bert/chinese_L-12_H-768_A-12/vocab.txt'

max_input_len = 256
max_output_len = 32
batch_size = 16
epochs = 10
beam_size = 3
learning_rate = 2e-5
val_split = 0.1

# Enhanced data preprocessing
class DataGenerator(keras.utils.Sequence):
    def __init__(self, data, batch_size=8, mode='train'):
        self.batch_size = batch_size
        self.mode = mode
        self.data = data
        self.indices = np.arange(len(data))

    def __len__(self):
        return math.ceil(len(self.data) / self.batch_size)

    def __getitem__(self, index):
        batch_indices = self.indices[index * self.batch_size : (index + 1) * self.batch_size]
        batch = self.data.iloc[batch_indices]
        return self._process_batch(batch)

    def on_epoch_end(self):
        if self.mode == 'train':
            np.random.shuffle(self.indices)

    def _process_batch(self, batch):
        batch_x, batch_y = [], []
        for _, row in batch.iterrows():
            content = row['content'][:max_input_len]
            title = row['title'][:max_output_len - 2]  # reserve room for [CLS] and [SEP]
            # Encoder input
            x, _ = tokenizer.encode(content, max_length=max_input_len)
            # Decoder input/output (teacher forcing: shift right by one token)
            y, _ = tokenizer.encode(title, max_length=max_output_len)
            decoder_input = [tokenizer._token_start_id] + y[:-1]
            decoder_output = y
            batch_x.append(x)
            batch_y.append({'decoder_input': decoder_input, 'decoder_output': decoder_output})
        # Dynamic padding
        padded_x = self._pad_sequences(batch_x, maxlen=max_input_len)
        padded_decoder_input = self._pad_sequences(
            [y['decoder_input'] for y in batch_y], maxlen=max_output_len, padding='post')
        padded_decoder_output = self._pad_sequences(
            [y['decoder_output'] for y in batch_y], maxlen=max_output_len, padding='post')
        return [padded_x, padded_decoder_input], padded_decoder_output

    def _pad_sequences(self, sequences, maxlen, padding='pre'):
        padded = np.zeros((len(sequences), maxlen))
        for i, seq in enumerate(sequences):
            if len(seq) > maxlen:
                seq = seq[:maxlen]
            if padding == 'pre':
                padded[i, -len(seq):] = seq
            else:
                padded[i, :len(seq)] = seq
        return padded

# Improved model architecture
def build_seq2seq_model():
    # Encoder
    encoder_inputs = Input(shape=(None,), name='Encoder-Input')
    encoder = build_bert_model(
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        model='encoder',
        return_keras_model=False,
    )
    encoder_outputs = encoder(encoder_inputs)
    # Decoder
    decoder_inputs = Input(shape=(None,), name='Decoder-Input')
    decoder = build_bert_model(
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        model='decoder',
        application='lm',
        return_keras_model=False,
    )
    decoder_outputs = decoder([decoder_inputs, encoder_outputs])
    # Wire up the full model
    model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

    # Custom loss that ignores padding positions
    def seq2seq_loss(y_true, y_pred):
        y_mask = K.cast(K.not_equal(y_true, 0), K.floatx())
        loss = K.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
        return K.sum(loss * y_mask) / K.sum(y_mask)

    model.compile(Adam(learning_rate), loss=seq2seq_loss)
    return model

# Improved beam search
def beam_search(model, input_seq, beam_size=3):
    encoder_input = tokenizer.encode(input_seq)[0]
    encoder_output = model.get_layer('bert').predict(np.array([encoder_input]))
    sequences = [[[tokenizer._token_start_id], 0.0]]
    for _ in range(max_output_len):
        all_candidates = []
        for seq, score in sequences:
            if seq[-1] == tokenizer._token_end_id:
                all_candidates.append((seq, score))
                continue
            decoder_input = np.array([seq])
            decoder_output = model.get_layer('bert_1').predict(
                [decoder_input, encoder_output])[:, -1, :]
            top_k = np.argsort(decoder_output[0])[-beam_size:]
            for token in top_k:
                new_seq = seq + [token]
                new_score = score + np.log(decoder_output[0][token])
                all_candidates.append((new_seq, new_score))
        # Length normalization
        ordered = sorted(all_candidates, key=lambda x: x[1] / (len(x[0]) + 1e-8), reverse=True)
        sequences = ordered[:beam_size]
    best_seq = sequences[0][0]
    return tokenizer.decode(best_seq[1:-1])  # strip [CLS] and [SEP]

# Enhanced evaluation callback
class AdvancedEvaluate(keras.callbacks.Callback):
    def __init__(self, val_data, sample_size=5):
        self.val_data = val_data
        self.rouge = Rouge()
        self.samples = val_data.sample(sample_size)

    def on_epoch_end(self, epoch, logs=None):
        # Print a few generated examples
        print("\nGenerated samples:")
        for _, row in self.samples.iterrows():
            generated = beam_search(self.model, row['content'], beam_size)
            print(f"Reference title: {row['title']}")
            print(f"Generated title: {generated}\n")
        # Compute ROUGE scores on the validation set
        references = []
        hypotheses = []
        for _, row in self.val_data.iterrows():
            generated = beam_search(self.model, row['content'], beam_size=1)
            references.append(row['title'])
            hypotheses.append(generated)
        scores = self.rouge.get_scores(hypotheses, references, avg=True)
        print(f"Validation ROUGE-L: {scores['rouge-l']['f']:.4f}")

# Main pipeline
if __name__ == "__main__":
    # Load data
    full_data = pd.read_csv('train.tsv', sep='\t', names=['title', 'content'])
    train_data, val_data = train_test_split(full_data, test_size=val_split)
    # Initialize the tokenizer
    tokenizer = Tokenizer(dict_path, do_lower_case=True)
    # Build the model
    model = build_seq2seq_model()
    model.summary()
    # Data generators
    train_gen = DataGenerator(train_data, batch_size, mode='train')
    val_gen = DataGenerator(val_data, batch_size, mode='val')
    # Training callbacks
    callbacks = [
        AdvancedEvaluate(val_data),
        keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=2, verbose=1),
        keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True),
    ]
    # Start training
    model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks,
        workers=4,
        use_multiprocessing=True,
    )
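
To sanity-check the metric in isolation, the rouge package used by AdvancedEvaluate can be exercised directly. This is a minimal sketch with made-up strings; since the package splits on whitespace, the Chinese text is space-separated into characters here (an assumption worth verifying against your own preprocessing):

from rouge import Rouge

hyp = ' '.join('基于BERT的摘要生成')  # hypothesis, one character per token
ref = ' '.join('基于BERT的标题生成')  # reference

scores = Rouge().get_scores([hyp], [ref], avg=True)
print(f"ROUGE-L F1: {scores['rouge-l']['f']:.4f}")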
