目录
引入bert,判断文本中特定字符出现
1.设计模型
2.前馈运算
3.建立词表
4.生成样本
5.建立数据集
6.建立模型
7.测试模型结果
8.模型训练
9.用训练好的模型预测
10.完整代码
我欲挑灯见你,可是梦怕火
我泪眼婆娑,坐实你来过
—— 25.1.21
引入bert,判断文本中特定字符出现
在实践 ② 中,我们使用了神经网络模型做这个任务
模型结构:嵌入层 ——> 循环神经网络层 ——> 线性层 ——> 交叉熵损失函数
【NLP 13、实践 ② 判断文本中是否有特定字符出现】_使用大模型判断一个词中是否包含另一个词-CSDN博客
1.设计模型
BertModel.from_pretrained():用于从预训练模型目录或 Hugging Face 模型库加载 BERT 模型的权重及配置。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
pretrained_model_name_or_path | 字符串 | 是 | 模型名称(如 bert-base-uncased )或本地路径。 |
config | BertConfig 对象 | 否 | 自定义配置类,用于覆盖默认配置。 |
state_dict | 字典 | 否 | 预训练权重字典,用于部分加载模型。 |
cache_dir | 字符串 | 否 | 缓存目录,用于存储下载的模型文件。 |
from_tf | 布尔值 | 否 | 是否从 TensorFlow 模型加载权重,默认为 False 。 |
ignore_mismatched_sizes | 布尔值 | 否 | 是否忽略权重大小不匹配的错误,默认为 False 。 |
local_files_only | 布尔值 | 否 | 是否仅从本地文件加载模型,默认为 False 。 |
返回值:
last_hidden_state
: 形状为(batch_size, sequence_length, hidden_size)
的张量,表示输入序列经过 BERT 编码器后的最后一个隐藏层的输出。pooler_output
: 形状为(batch_size, hidden_size)
的张量,表示经过池化层(通常是[CLS]
标记的输出)处理后的结果,通常用于分类任务。hidden_states
(可选): 如果设置了output_hidden_states=True
,则返回一个包含所有层隐藏状态的元组,形状为(batch_size, sequence_length, hidden_size)
。attentions
(可选): 如果设置了output_attentions=True
,则返回一个包含所有注意力权重的元组。
return_dict参数:
- 当
return_dict
设置为True
时,forward()
方法返回一个BaseModelOutput
对象,该对象包含了模型的各种输出,如最后一层的隐藏状态、[CLS] 标记的输出等。 - 当
return_dict
设置为False
时,forward()
方法返回一个元组,包含与BaseModelOutput
对象相同的元素,但不包含对象结构。
nn.Linear():PyTorch 中用于创建全连接层(线性层)的类。它将输入张量转换为输出张量,通过矩阵乘法和偏置加法实现。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
in_features | int | 是 | 输入特征的数量 |
out_features | int | 是 | 输出特征的数量 |
bias | bool | 否 | 是否使用偏置,默认为 True |
torch.sigmoid:PyTorch 中的一个激活函数,用于将输入张量的每个元素映射到区间 (0, 1) 之间。它在神经网络中常用于二分类问题的输出层,将模型的输出转换为概率值。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
input | Tensor | 是 | 输入的张量,可以是任意形状和类型。 |
out | Tensor, 可选 | 否 | 指定输出结果的张量。如果提供,结果将存储在此张量中。 |
nn.Dropout():PyTorch 中用于防止神经网络过拟合的一种正则化技术。它通过在训练过程中随机丢弃一部分神经元的输出,来减少神经元之间的依赖性,从而提高模型的泛化能力。
训练过程中随机地将输入张量的一部分元素置零,而在评估(或测试)过程中则不会进行任何操作。这样做可以防止模型过拟合,提高模型的泛化能力
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
p | float | 是 | 被舍弃的概率,默认为 0.5 |
inplace | bool | 否 | 是否在原地操作,默认为 False |
nn.functional.cross_entropy:PyTorch 中用于计算交叉熵损失的函数。交叉熵损失常用于多分类任务,衡量模型预测的概率分布与真实标签之间的差异。它是衡量分类模型性能的重要指标之一。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
input | Tensor | 是 | 输入张量,通常是模型的输出 logits |
target | Tensor | 是 | 目标标签,整数形式,表示每个样本的类别 |
weight | Tensor | 否 | 样本权重,默认为 None |
size_average | bool | 否 | 是否对损失进行平均,默认为 None |
ignore_index | int | 否 | 忽略的标签索引,默认为 -100 |
reduce | bool | 否 | 是否对损失进行缩减,默认为 None |
reduction | str | 否 | 损失缩减方式,默认为 'mean' |
class TorchModel(nn.Module):def __init__(self, input_dim, sentence_length, vocab):super(TorchModel, self).__init__()# 原始代码# self.embedding = nn.Embedding(len(vocab) + 1, input_dim)# self.layer = nn.Linear(input_dim, input_dim)# self.pool = nn.MaxPool1d(sentence_length)self.bert = BertModel.from_pretrained(r"F:\人工智能NLP\NLP资料\week6 语言模型\bert-base-chinese", return_dict=False)self.classify = nn.Linear(input_dim, 3)self.activation = torch.sigmoid #sigmoid做激活函数self.dropout = nn.Dropout(0.5)self.loss = nn.functional.cross_entropy
2.前馈运算
#当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None):# 原始代码# x = self.embedding(x) #input shape:(batch_size, sen_len) (10,6)# x = self.layer(x) #input shape:(batch_size, sen_len, input_dim) (10,6,20)# x = self.dropout(x) #input shape:(batch_size, sen_len, input_dim)# x = self.activation(x) #input shape:(batch_size, sen_len, input_dim)# x = self.pool(x.transpose(1,2)).squeeze() #input shape:(batch_size, sen_len, input_dim)sequence_output, pooler_output = self.bert(x)x = self.classify(pooler_output)y_pred = self.activation(x)if y is not None:return self.loss(y_pred, y.squeeze())else:return y_pred
3.建立词表
enumerate():Python 的内置函数,用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
iterable | object | 是 | 一个可迭代对象,如列表、元组或字符串 |
start | int | 否 | 下标起始位置,默认为 0 |
len(): Python 的内置函数,用于返回对象(字符、列表、元组等)的长度或项目数量。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
obj | object | 是 | 要计算长度的对象,如字符串、列表、元组等 |
#字符集随便挑了一些汉字,实际上还可以扩充
#为每个字生成一个标号
#{"a":1, "b":2, "c":3...}
#abc -> [1,2,3]
def build_vocab():chars = "abcdefghijklmnopqrstuvwxyz" #字符集vocab = {}for index, char in enumerate(chars):vocab[char] = index + 1 #每个字对应一个序号vocab['unk'] = len(vocab)+1return vocab
4.生成样本
list():将可迭代对象转换为列表,可以将字符串、元组、集合等可迭代对象转换为列表。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
iterable | 可迭代对象 | 是 | 一个可迭代对象,如字符串、元组或列表 |
vocab,keys():返回字典中所有的键。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
vocab | 字典对象 | 是 | 要获取键的字典对象 |
random.choice():从非空序列中随机选择一个元素。适用于从列表、元组、字符串等序列中随机选取一个元素。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
sequence | 序列类型 | 是 | 一个非空序列,如列表、元组或字符串 |
range():生成一个包含指定范围内整数的序列。常用于循环中控制循环次数,可以指定起始值、结束值和步长。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
start | 整数 | 否 | 序列的起始值,默认为 0 |
stop | 整数 | 是 | 序列的结束值,但不包括这个值 |
step | 整数 | 否 | 序列中相邻两个数字之间的差值,默认为 1 |
vocab.get():从字典中获取指定键的值,如果键不存在,返回默认值。提供了一种安全的方式来访问字典中的值,避免因键不存在而引发异常。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
key | 键 | 是 | 要检索的键 |
default | 默认值 | 否 | 如果键不存在时返回的默认值,默认为 None |
set():创建一个无序且不重复的元素集合,可以将列表、元组、字符串等可迭代对象转换为集合,集合中的元素唯一且无序。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
iterable | 可迭代对象 | 否 | 一个可迭代对象,如列表、元组或字符串,默认为空集合 |
#随机生成一个样本
#从所有字中选取sentence_length个字
#反之为负样本
def build_sample(vocab, sentence_length):#随机从字表选取sentence_length个字,可能重复x = [random.choice(list(vocab.keys())) for _ in range(sentence_length)]#A类样本if set("abc") & set(x) and not set("xyz") & set(x):y = 0#B类样本elif not set("abc") & set(x) and set("xyz") & set(x):y = 1#C类样本else:y = 2x = [vocab.get(word, vocab['unk']) for word in x] #将字转换成序号,为了做embeddingreturn x, y
5.建立数据集
range(): Python 的内置函数,用于生成一个不可变的数字序列。通常用于 for
循环中控制循环次数。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
start | 整数 | 否 | 序列的起始值,默认为 0 。 |
stop | 整数 | 是 | 序列的结束值(不包含该值)。 |
step | 整数 | 否 | 序列中相邻两个数之间的差值,默认为 1 。 |
append():列表(list
)对象的方法,用于在列表末尾添加一个元素。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
object | 任意类型 | 是 | 要添加到列表末尾的对象。 |
torch.LongTensor(): PyTorch 库中的函数,用于创建一个包含长整型(64位整数)数据的新张量。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
data | 可迭代对象或标量 | 否 | 用于初始化张量的数据。可以是列表、元组、NumPy 数组或其他张量。 |
dtype | torch.dtype | 否 | 指定张量的数据类型,默认为 torch.long 。 |
device | str 或 torch.device | 否 | 指定张量所在的设备(如 'cpu' 或 'cuda:0' ),默认为当前默认设备。 |
requires_grad | bool | 否 | 是否需要计算梯度,默认为 False 。 |
size | torch.Size 或 int... | 否 | 指定张量的形状。如果提供 data ,则此参数会被忽略。 |
其他参数 | - | 否 | 其他与张量创建相关的参数,如 layout 、pin_memory 等。 |
#建立数据集
#输入需要的样本数量。需要多少生成多少
def build_dataset(sample_length, vocab, sentence_length):dataset_x = []dataset_y = []for i in range(sample_length):x, y = build_sample(vocab, sentence_length)dataset_x.append(x)dataset_y.append([y])return torch.LongTensor(dataset_x), torch.LongTensor(dataset_y)
6.建立模型
#建立模型
def build_model(vocab, char_dim, sentence_length):model = TorchModel(char_dim, sentence_length, vocab)return model
7.测试模型结果
model.eval():将模型设置为评估模式,禁用 Dropout 和 BatchNorm 的训练行为,确保模型在推理时的稳定性。
squeeze():从数组或张量中删除所有大小为1的维度,简化数据结构。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
input | Tensor | 是 | 输入的张量 |
dim | int 或 None | 否 | 指定要移除的维度,如果不指定,则移除所有大小为1的维度 |
torch.no_grad():用于临时禁用梯度计算,通常用于推理阶段
zip():将多个可迭代对象中的元素一一配对,返回元组列表,便于同时处理多个序列。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
*iterables | 可迭代对象 | 是 | 将多个可迭代对象中的元素一一配对,返回元组列表 |
int():将输入的数值或字符串转换为整数类型,舍去小数部分。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
x | 数值或字符串 | 是 | 将输入转换为整数类型 |
#测试代码
#用来测试每轮模型的准确率
def evaluate(model, vocab, sample_length):model.eval()total = 200 #测试样本数量x, y = build_dataset(total, vocab, sample_length) #建立200个用于测试的样本y = y.squeeze()print("A类样本数量:%d, B类样本数量:%d, C类样本数量:%d"%(y.tolist().count(0), y.tolist().count(1), y.tolist().count(2)))correct, wrong = 0, 0with torch.no_grad():y_pred = model(x) #模型预测for y_p, y_t in zip(y_pred, y): #与真实标签进行对比if int(torch.argmax(y_p)) == int(y_t):correct += 1 #正样本判断正确else:wrong += 1print("正确预测个数:%d / %d, 正确率:%f"%(correct, total, correct/(correct+wrong)))return correct/(correct+wrong)
8.模型训练
range():生成一个包含指定范围内整数的序列。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
start | 整数 | 否 | 序列的起始值,默认为 0 。 |
stop | 整数 | 是 | 序列的结束值(不包含该值)。 |
step | 整数 | 否 | 序列中相邻两个数之间的差值,默认为 1 |
optim.zero_grad():将模型参数的梯度清零,通常在每个训练步骤开始时调用。
backward():计算张量的梯度,通常在反向传播过程中使用。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
gradient | Tensor 或 None | 否 | 指定梯度的权重,默认为 None 。 |
retain_graph | bool | 否 | 是否保留计算图,默认为 False 。 |
create_graph | bool | 否 | 是否创建计算图以计算高阶导数,默认为 False 。 |
optim.step():根据计算出的梯度更新模型参数。
append():在列表末尾添加一个元素。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
object | 任意类型 | 是 | 要添加到列表末尾的对象。 |
item():将张量转换为 Python 标量。
np.mean():计算数组或矩阵的算术平均值。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
a | array_like | 是 | 需要计算均值的数组。 |
axis | None 或 int 或 tuple of ints | 否 | 指定计算均值方向的轴,默认为 None 。 |
dtype | data-type | 否 | 输出均值的类型,默认为 float64 。 |
out | ndarray | 否 | 存放结果的备选输出数组,默认为 None 。 |
keepdims | bool | 否 | 是否保留减少的轴为尺寸为一的维度,默认为 False 。 |
torch.save():将对象保存到磁盘文件。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
obj | object | 是 | 要保存的对象。 |
f | Union[str, PathLike, BinaryIO, IO[bytes]] | 是 | 类似文件的对象或包含文件名的字符串。 |
pickle_module | Any | 否 | 用于 pickling 元数据和对象的模块,默认为 pickle 。 |
pickle_protocol | int | 否 | 可以指定 pickling 协议的版本,默认为 DEFAULT_PROTOCOL 。 |
_use_new_zipfile_serialization | bool | 否 | 是否使用新的 zipfile 序列化,默认为 True 。 |
state_dict():返回模型的状态字典,包含模型的参数和缓冲区。
open():打开一个文件并返回文件对象。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
file | Union[str, PathLike] | 是 | 文件名或文件对象。 |
mode | str | 否 | 文件打开模式,默认为 'r' 。 |
buffering | int | 否 | 缓冲策略,默认为 -1 。 |
encoding | str | 否 | 文件编码,默认为 None 。 |
errors | str | 否 | 指定如何处理编码错误,默认为 None 。 |
newline | str | 否 | 控制换行符的行为,默认为 None 。 |
closefd | bool | 否 | 如果为 True ,文件描述符将在关闭文件时关闭,默认为 True 。 |
opener | callable | 否 | 自定义打开器,默认为 None 。 |
write():将字符串写入文件。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
text | str | 是 | 要写入文件的字符串。 |
json.dumps():将 Python 对象序列化为 JSON 格式的字符串。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
obj | object | 是 | 要序列化的 Python 对象。 |
skipkeys | bool | 否 | 如果为 True ,则跳过不可序列化的键,默认为 False 。 |
ensure_ascii | bool | 否 | 如果为 True ,则所有非 ASCII 字符都将转义,默认为 True 。 |
check_circular | bool | 否 | 如果为 True ,则检查循环引用,默认为 True 。 |
allow_nan | bool | 否 | 如果为 True ,则允许 NaN 、Infinity 和 -Infinity ,默认为 False 。 |
cls | JSONEncoder | 否 | 自定义的 JSON 编码器,默认为 json.JSONEncoder 。 |
indent | int 或 str | 否 | 指定缩进,默认为 None 。 |
separators | tuple | 否 | 指定分隔符,默认为 (', ', ': ') 。 |
default | callable | 否 | 自定义函数,用于处理不可序列化的对象,默认为 None 。 |
sort_keys | bool | 否 | 如果为 True ,则按键排序,默认为 False 。 |
close():关闭文件对象。
def main():epoch_num = 15 #训练轮数batch_size = 20 #每次训练样本个数train_sample = 1000 #每轮训练总共训练的样本总数char_dim = 768 #每个字的维度sentence_length = 6 #样本文本长度vocab = build_vocab() #建立字表model = build_model(vocab, char_dim, sentence_length) #建立模型optim = torch.optim.Adam(model.parameters(), lr=1e-5) #建立优化器log = []for epoch in range(epoch_num):model.train()watch_loss = []for batch in range(int(train_sample / batch_size)):x, y = build_dataset(batch_size, vocab, sentence_length) #构建一组训练样本optim.zero_grad() #梯度归零loss = model(x, y) #计算lossloss.backward() #计算梯度optim.step() #更新权重watch_loss.append(loss.item())print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))acc = evaluate(model, vocab, sentence_length) #测试本轮模型结果log.append([acc, np.mean(watch_loss)])# plt.plot(range(len(log)), [l[0] for l in log]) #画acc曲线# plt.plot(range(len(log)), [l[1] for l in log]) #画loss曲线# plt.show()#保存模型torch.save(model.state_dict(), "model.pth")# 保存词表writer = open("vocab.json", "w", encoding="utf8")writer.write(json.dumps(vocab, ensure_ascii=False, indent=2))writer.close()return
9.用训练好的模型预测
json.load():将 JSON 格式的文件内容解析为 Python 对象(如字典、列表等)。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
fp | io.TextIOBase 或 str | 是 | 文件对象或文件路径。 |
cls | JSONDecoder 子类 | 否 | 自定义的 JSON 解码器,默认为 json.JSONDecoder 。 |
object_hook | callable | 否 | 用于自定义反序列化特定类型的对象。 |
parse_float | callable | 否 | 用于解析浮点数的函数,默认为 float 。 |
parse_int | callable | 否 | 用于解析整数的函数,默认为 int 。 |
parse_constant | callable | 否 | 用于解析 JSON 常量(如 null 、true 、false )。 |
object_pairs_hook | callable | 否 | 用于自定义反序列化键值对的对象。 |
buffers | list | 否 | 用于增量解析的缓冲区列表。 |
open():打开一个文件,并返回一个文件对象,以便进行读写操作。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
file | str 或 PathLike | 是 | 文件名或文件路径。 |
mode | str | 否 | 文件打开模式,默认为 'r' (读取)。 |
buffering | int | 否 | 缓冲策略,默认为 -1 (系统默认)。 |
encoding | str | 否 | 文件编码,默认为 None (使用系统默认)。 |
errors | str | 否 | 指定如何处理编码错误,默认为 None 。 |
newline | str | 否 | 控制换行符的行为,默认为 None 。 |
closefd | bool | 否 | 如果为 True ,文件描述符将在关闭文件时关闭,默认为 True 。 |
opener | callable | 否 | 自定义的文件打开器,默认为 None 。 |
torch.load():从磁盘加载一个序列化的对象(如张量、模型等),通常用于恢复训练好的模型或张量。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
f | str 或 PathLike 或 file-like object | 是 | 文件路径或文件对象。 |
map_location | str 或 Callable 或 dict | 否 | 指定张量的存储位置,用于在不同设备间加载模型。 |
pickle_module | module | 否 | 用于反序列化的 pickle 模块,默认为 pickle 。 |
**pickle_load_args | 任意关键字参数 | 否 | 传递给 pickle.load 的额外参数。 |
append():将一个元素添加到列表的末尾。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
object | 任意类型 | 是 | 要添加到列表末尾的对象。 |
model.eval():将模型设置为评估模式,禁用诸如 Dropout 和 BatchNorm 等层的训练特定行为,确保模型在推理时的稳定性。
torch.no_grad():临时禁用梯度计算,用于推理阶段以减少内存消耗和提高性能。
enumerate():将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
iterable | 可迭代对象 | 是 | 一个可迭代对象,如字符串、元组或列表。 |
start | int | 否 | 下标起始位置,默认为 0 。 |
torch.argmax():返回输入张量沿指定维度的最大值的索引。
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
input | Tensor | 是 | 输入的张量。 |
dim | int | 否 | 指定沿哪个维度查找最大值,默认为最后一个维度。 |
keepdim | bool | 否 | 是否保持输出张量的维度,默认为 False 。 |
out | Tensor | 否 | 存放结果的备选输出张量,默认为 None 。 |
#最终预测
def predict(model_path, vocab_path, input_strings):char_dim = 20 # 每个字的维度sentence_length = 6 # 样本文本长度vocab = json.load(open(vocab_path, "r", encoding="utf8"))model = build_model(vocab, char_dim, sentence_length) #建立模型model.load_state_dict(torch.load(model_path)) #加载训练好的权重x = []for input_string in input_strings:x.append([vocab[char] for char in input_string]) #将输入序列化model.eval() #测试模式,不使用dropoutwith torch.no_grad(): #不计算梯度result = model.forward(torch.LongTensor(x)) #模型预测for i, input_string in enumerate(input_strings):print(int(torch.argmax(result[i])), input_string, result[i]) #打印结果
10.完整代码
#coding:utf8import torch
import torch.nn as nn
import numpy as np
import random
import json
from transformers import BertModel"""基于pytorch的网络编写
实现一个网络完成一个简单nlp任务
判断文本中是否有某些特定字符出现week2的例子,修改引入bert
"""class TorchModel(nn.Module):def __init__(self, input_dim, sentence_length, vocab):super(TorchModel, self).__init__()# 原始代码# self.embedding = nn.Embedding(len(vocab) + 1, input_dim)# self.layer = nn.Linear(input_dim, input_dim)# self.pool = nn.MaxPool1d(sentence_length)self.bert = BertModel.from_pretrained(r"F:\人工智能NLP\NLP资料\week6 语言模型\bert-base-chinese", return_dict=False)self.classify = nn.Linear(input_dim, 3)self.activation = torch.sigmoid #sigmoid做激活函数self.dropout = nn.Dropout(0.5)self.loss = nn.functional.cross_entropy#当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None):# 原始代码# x = self.embedding(x) #input shape:(batch_size, sen_len) (10,6)# x = self.layer(x) #input shape:(batch_size, sen_len, input_dim) (10,6,20)# x = self.dropout(x) #input shape:(batch_size, sen_len, input_dim)# x = self.activation(x) #input shape:(batch_size, sen_len, input_dim)# x = self.pool(x.transpose(1,2)).squeeze() #input shape:(batch_size, sen_len, input_dim)sequence_output, pooler_output = self.bert(x)x = self.classify(pooler_output)y_pred = self.activation(x)if y is not None:return self.loss(y_pred, y.squeeze())else:return y_pred#字符集随便挑了一些汉字,实际上还可以扩充
#为每个字生成一个标号
#{"a":1, "b":2, "c":3...}
#abc -> [1,2,3]
def build_vocab():chars = "abcdefghijklmnopqrstuvwxyz" #字符集vocab = {}for index, char in enumerate(chars):vocab[char] = index + 1 #每个字对应一个序号vocab['unk'] = len(vocab)+1return vocab#随机生成一个样本
#从所有字中选取sentence_length个字
#反之为负样本
def build_sample(vocab, sentence_length):#随机从字表选取sentence_length个字,可能重复x = [random.choice(list(vocab.keys())) for _ in range(sentence_length)]#A类样本if set("abc") & set(x) and not set("xyz") & set(x):y = 0#B类样本elif not set("abc") & set(x) and set("xyz") & set(x):y = 1#C类样本else:y = 2x = [vocab.get(word, vocab['unk']) for word in x] #将字转换成序号,为了做embeddingreturn x, y#建立数据集
#输入需要的样本数量。需要多少生成多少
def build_dataset(sample_length, vocab, sentence_length):dataset_x = []dataset_y = []for i in range(sample_length):x, y = build_sample(vocab, sentence_length)dataset_x.append(x)dataset_y.append([y])return torch.LongTensor(dataset_x), torch.LongTensor(dataset_y)#建立模型
def build_model(vocab, char_dim, sentence_length):model = TorchModel(char_dim, sentence_length, vocab)return model#测试代码
#用来测试每轮模型的准确率
def evaluate(model, vocab, sample_length):model.eval()total = 200 #测试样本数量x, y = build_dataset(total, vocab, sample_length) #建立200个用于测试的样本y = y.squeeze()print("A类样本数量:%d, B类样本数量:%d, C类样本数量:%d"%(y.tolist().count(0), y.tolist().count(1), y.tolist().count(2)))correct, wrong = 0, 0with torch.no_grad():y_pred = model(x) #模型预测for y_p, y_t in zip(y_pred, y): #与真实标签进行对比if int(torch.argmax(y_p)) == int(y_t):correct += 1 #正样本判断正确else:wrong += 1print("正确预测个数:%d / %d, 正确率:%f"%(correct, total, correct/(correct+wrong)))return correct/(correct+wrong)def main():epoch_num = 15 #训练轮数batch_size = 20 #每次训练样本个数train_sample = 1000 #每轮训练总共训练的样本总数char_dim = 768 #每个字的维度sentence_length = 6 #样本文本长度vocab = build_vocab() #建立字表model = build_model(vocab, char_dim, sentence_length) #建立模型optim = torch.optim.Adam(model.parameters(), lr=1e-5) #建立优化器log = []for epoch in range(epoch_num):model.train()watch_loss = []for batch in range(int(train_sample / batch_size)):x, y = build_dataset(batch_size, vocab, sentence_length) #构建一组训练样本optim.zero_grad() #梯度归零loss = model(x, y) #计算lossloss.backward() #计算梯度optim.step() #更新权重watch_loss.append(loss.item())print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))acc = evaluate(model, vocab, sentence_length) #测试本轮模型结果log.append([acc, np.mean(watch_loss)])# plt.plot(range(len(log)), [l[0] for l in log]) #画acc曲线# plt.plot(range(len(log)), [l[1] for l in log]) #画loss曲线# plt.show()#保存模型torch.save(model.state_dict(), "model.pth")# 保存词表writer = open("vocab.json", "w", encoding="utf8")writer.write(json.dumps(vocab, ensure_ascii=False, indent=2))writer.close()return#最终预测
def predict(model_path, vocab_path, input_strings):char_dim = 20 # 每个字的维度sentence_length = 6 # 样本文本长度vocab = json.load(open(vocab_path, "r", encoding="utf8"))model = build_model(vocab, char_dim, sentence_length) #建立模型model.load_state_dict(torch.load(model_path)) #加载训练好的权重x = []for input_string in input_strings:x.append([vocab[char] for char in input_string]) #将输入序列化model.eval() #测试模式,不使用dropoutwith torch.no_grad(): #不计算梯度result = model.forward(torch.LongTensor(x)) #模型预测for i, input_string in enumerate(input_strings):print(int(torch.argmax(result[i])), input_string, result[i]) #打印结果if __name__ == "__main__":main()# test_strings = ["juvaee", "yrwfrg", "rtweqg", "nlhdww"]# predict("model.pth", "vocab.json", test_strings)