欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 汽车 > 维修 > 【NLP 26、实践 ⑥ 引入bert,判断文本中是否有特定字符出现】

【NLP 26、实践 ⑥ 引入bert,判断文本中是否有特定字符出现】

2025/2/26 0:09:33 来源:https://blog.csdn.net/m0_73983707/article/details/145702544  浏览:    关键词:【NLP 26、实践 ⑥ 引入bert,判断文本中是否有特定字符出现】

目录

引入bert,判断文本中特定字符出现 

1.设计模型

2.前馈运算

3.建立词表

4.生成样本

5.建立数据集

6.建立模型

7.测试模型结果

8.模型训练

9.用训练好的模型预测

10.完整代码 


我欲挑灯见你,可是梦怕火

我泪眼婆娑,坐实你来过

                                —— 25.1.21

引入bert,判断文本中特定字符出现 

在实践 ② 中,我们使用了神经网络模型做这个任务

模型结构:嵌入层 ——> 循环神经网络层 ——>  线性层 ——> 交叉熵损失函数

【NLP 13、实践 ② 判断文本中是否有特定字符出现】_使用大模型判断一个词中是否包含另一个词-CSDN博客

1.设计模型

BertModel.from_pretrained():用于从预训练模型目录或 Hugging Face 模型库加载 BERT 模型的权重及配置。

参数名称类型是否必填说明
pretrained_model_name_or_path字符串模型名称(如 bert-base-uncased)或本地路径。
configBertConfig对象自定义配置类,用于覆盖默认配置。
state_dict字典预训练权重字典,用于部分加载模型。
cache_dir字符串缓存目录,用于存储下载的模型文件。
from_tf布尔值是否从 TensorFlow 模型加载权重,默认为 False
ignore_mismatched_sizes布尔值是否忽略权重大小不匹配的错误,默认为 False
local_files_only布尔值是否仅从本地文件加载模型,默认为 False

返回值:

  • last_hidden_state: 形状为 (batch_size, sequence_length, hidden_size) 的张量,表示输入序列经过 BERT 编码器后的最后一个隐藏层的输出。
  • pooler_output: 形状为 (batch_size, hidden_size) 的张量,表示经过池化层(通常是 [CLS] 标记的输出)处理后的结果,通常用于分类任务。
  • hidden_states(可选): 如果设置了 output_hidden_states=True,则返回一个包含所有层隐藏状态的元组,形状为 (batch_size, sequence_length, hidden_size)
  • attentions(可选): 如果设置了 output_attentions=True,则返回一个包含所有注意力权重的元组。

return_dict参数:

  • 当 return_dict 设置为 True 时,forward() 方法返回一个 BaseModelOutput 对象,该对象包含了模型的各种输出,如最后一层的隐藏状态、[CLS] 标记的输出等。
  • 当 return_dict 设置为 False 时,forward() 方法返回一个元组,包含与 BaseModelOutput 对象相同的元素,但不包含对象结构。

nn.Linear():PyTorch 中用于创建全连接层(线性层)的类。它将输入张量转换为输出张量,通过矩阵乘法和偏置加法实现。

参数名称类型是否必填说明
in_featuresint输入特征的数量
out_featuresint输出特征的数量
biasbool是否使用偏置,默认为 True

torch.sigmoid:PyTorch 中的一个激活函数,用于将输入张量的每个元素映射到区间 (0, 1) 之间。它在神经网络中常用于二分类问题的输出层,将模型的输出转换为概率值。

参数名称类型是否必填说明
inputTensor输入的张量,可以是任意形状和类型。
outTensor, 可选指定输出结果的张量。如果提供,结果将存储在此张量中。

nn.Dropout():PyTorch 中用于防止神经网络过拟合的一种正则化技术。它通过在训练过程中随机丢弃一部分神经元的输出,来减少神经元之间的依赖性,从而提高模型的泛化能力。

训练过程中随机地将输入张量的一部分元素置零,而在评估(或测试)过程中则不会进行任何操作。这样做可以防止模型过拟合,提高模型的泛化能力

参数名称类型是否必填说明
pfloat被舍弃的概率,默认为 0.5
inplacebool是否在原地操作,默认为 False

nn.functional.cross_entropy:PyTorch 中用于计算交叉熵损失的函数。交叉熵损失常用于多分类任务,衡量模型预测的概率分布与真实标签之间的差异。它是衡量分类模型性能的重要指标之一。

参数名称类型是否必填说明
inputTensor输入张量,通常是模型的输出 logits
targetTensor目标标签,整数形式,表示每个样本的类别
weightTensor样本权重,默认为 None
size_averagebool是否对损失进行平均,默认为 None
ignore_indexint忽略的标签索引,默认为 -100
reducebool是否对损失进行缩减,默认为 None
reductionstr损失缩减方式,默认为 'mean'
class TorchModel(nn.Module):def __init__(self, input_dim, sentence_length, vocab):super(TorchModel, self).__init__()# 原始代码# self.embedding = nn.Embedding(len(vocab) + 1, input_dim)# self.layer = nn.Linear(input_dim, input_dim)# self.pool = nn.MaxPool1d(sentence_length)self.bert = BertModel.from_pretrained(r"F:\人工智能NLP\NLP资料\week6 语言模型\bert-base-chinese", return_dict=False)self.classify = nn.Linear(input_dim, 3)self.activation = torch.sigmoid     #sigmoid做激活函数self.dropout = nn.Dropout(0.5)self.loss = nn.functional.cross_entropy

2.前馈运算

    #当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None):# 原始代码# x = self.embedding(x)  #input shape:(batch_size, sen_len) (10,6)# x = self.layer(x)      #input shape:(batch_size, sen_len, input_dim) (10,6,20)# x = self.dropout(x)    #input shape:(batch_size, sen_len, input_dim)# x = self.activation(x) #input shape:(batch_size, sen_len, input_dim)# x = self.pool(x.transpose(1,2)).squeeze() #input shape:(batch_size, sen_len, input_dim)sequence_output, pooler_output = self.bert(x)x = self.classify(pooler_output)y_pred = self.activation(x)if y is not None:return self.loss(y_pred, y.squeeze())else:return y_pred

3.建立词表

enumerate():Python 的内置函数,用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

参数名称类型是否必填说明
iterableobject一个可迭代对象,如列表、元组或字符串
startint下标起始位置,默认为 0

len(): Python 的内置函数,用于返回对象(字符、列表、元组等)的长度或项目数量。

参数名称类型是否必填说明
objobject要计算长度的对象,如字符串、列表、元组等
#字符集随便挑了一些汉字,实际上还可以扩充
#为每个字生成一个标号
#{"a":1, "b":2, "c":3...}
#abc -> [1,2,3]
def build_vocab():chars = "abcdefghijklmnopqrstuvwxyz"  #字符集vocab = {}for index, char in enumerate(chars):vocab[char] = index + 1   #每个字对应一个序号vocab['unk'] = len(vocab)+1return vocab

4.生成样本

list():将可迭代对象转换为列表,可以将字符串、元组、集合等可迭代对象转换为列表。

参数名称类型是否必填说明
iterable可迭代对象一个可迭代对象,如字符串、元组或列表

vocab,keys():返回字典中所有的键。

参数名称类型是否必填说明
vocab字典对象要获取键的字典对象

random.choice():从非空序列中随机选择一个元素。适用于从列表、元组、字符串等序列中随机选取一个元素。

参数名称类型是否必填说明
sequence序列类型一个非空序列,如列表、元组或字符串

range():生成一个包含指定范围内整数的序列。常用于循环中控制循环次数,可以指定起始值、结束值和步长。

参数名称类型是否必填说明
start整数序列的起始值,默认为 0
stop整数序列的结束值,但不包括这个值
step整数序列中相邻两个数字之间的差值,默认为 1

vocab.get():从字典中获取指定键的值,如果键不存在,返回默认值。提供了一种安全的方式来访问字典中的值,避免因键不存在而引发异常。

参数名称类型是否必填说明
key要检索的键
default默认值如果键不存在时返回的默认值,默认为 None

set():创建一个无序且不重复的元素集合,可以将列表、元组、字符串等可迭代对象转换为集合,集合中的元素唯一且无序。

参数名称类型是否必填说明
iterable可迭代对象一个可迭代对象,如列表、元组或字符串,默认为空集合
#随机生成一个样本
#从所有字中选取sentence_length个字
#反之为负样本
def build_sample(vocab, sentence_length):#随机从字表选取sentence_length个字,可能重复x = [random.choice(list(vocab.keys())) for _ in range(sentence_length)]#A类样本if set("abc") & set(x) and not set("xyz") & set(x):y = 0#B类样本elif not set("abc") & set(x) and set("xyz") & set(x):y = 1#C类样本else:y = 2x = [vocab.get(word, vocab['unk']) for word in x]   #将字转换成序号,为了做embeddingreturn x, y

5.建立数据集

range(): Python 的内置函数,用于生成一个不可变的数字序列。通常用于 for 循环中控制循环次数。

参数名称类型是否必填说明
start整数序列的起始值,默认为 0
stop整数序列的结束值(不包含该值)。
step整数序列中相邻两个数之间的差值,默认为 1

append():列表(list)对象的方法,用于在列表末尾添加一个元素。

参数名称类型是否必填说明
object任意类型要添加到列表末尾的对象。

torch.LongTensor(): PyTorch 库中的函数,用于创建一个包含长整型(64位整数)数据的新张量。

参数名称类型是否必填说明
data可迭代对象或标量用于初始化张量的数据。可以是列表、元组、NumPy 数组或其他张量。
dtypetorch.dtype指定张量的数据类型,默认为 torch.long
devicestr 或 torch.device指定张量所在的设备(如 'cpu' 或 'cuda:0'),默认为当前默认设备。
requires_gradbool是否需要计算梯度,默认为 False
sizetorch.Size 或 int...指定张量的形状。如果提供 data,则此参数会被忽略。
其他参数-其他与张量创建相关的参数,如 layoutpin_memory 等。
#建立数据集
#输入需要的样本数量。需要多少生成多少
def build_dataset(sample_length, vocab, sentence_length):dataset_x = []dataset_y = []for i in range(sample_length):x, y = build_sample(vocab, sentence_length)dataset_x.append(x)dataset_y.append([y])return torch.LongTensor(dataset_x), torch.LongTensor(dataset_y)

6.建立模型

#建立模型
def build_model(vocab, char_dim, sentence_length):model = TorchModel(char_dim, sentence_length, vocab)return model

7.测试模型结果

model.eval():将模型设置为评估模式,禁用 Dropout 和 BatchNorm 的训练行为,确保模型在推理时的稳定性。

squeeze():从数组或张量中删除所有大小为1的维度,简化数据结构。

参数名称类型是否必填说明
inputTensor输入的张量
dimint 或 None指定要移除的维度,如果不指定,则移除所有大小为1的维度

torch.no_grad():用于临时禁用梯度计算,通常用于推理阶段

zip():将多个可迭代对象中的元素一一配对,返回元组列表,便于同时处理多个序列。

参数名称类型是否必填说明
*iterables可迭代对象将多个可迭代对象中的元素一一配对,返回元组列表

int():将输入的数值或字符串转换为整数类型,舍去小数部分。

参数名称类型是否必填说明
x数值或字符串将输入转换为整数类型
#测试代码
#用来测试每轮模型的准确率
def evaluate(model, vocab, sample_length):model.eval()total = 200 #测试样本数量x, y = build_dataset(total, vocab, sample_length)   #建立200个用于测试的样本y = y.squeeze()print("A类样本数量:%d, B类样本数量:%d, C类样本数量:%d"%(y.tolist().count(0), y.tolist().count(1), y.tolist().count(2)))correct, wrong = 0, 0with torch.no_grad():y_pred = model(x)      #模型预测for y_p, y_t in zip(y_pred, y):  #与真实标签进行对比if int(torch.argmax(y_p)) == int(y_t):correct += 1   #正样本判断正确else:wrong += 1print("正确预测个数:%d / %d, 正确率:%f"%(correct, total, correct/(correct+wrong)))return correct/(correct+wrong)

8.模型训练

range():生成一个包含指定范围内整数的序列。

参数名称类型是否必填说明
start整数序列的起始值,默认为 0
stop整数序列的结束值(不包含该值)。
step整数序列中相邻两个数之间的差值,默认为 1

optim.zero_grad():将模型参数的梯度清零,通常在每个训练步骤开始时调用。

backward():计算张量的梯度,通常在反向传播过程中使用。

参数名称类型是否必填说明
gradientTensor 或 None指定梯度的权重,默认为 None
retain_graphbool是否保留计算图,默认为 False
create_graphbool是否创建计算图以计算高阶导数,默认为 False

optim.step():根据计算出的梯度更新模型参数。

append():在列表末尾添加一个元素。

参数名称类型是否必填说明
object任意类型要添加到列表末尾的对象。

item():将张量转换为 Python 标量。

np.mean():计算数组或矩阵的算术平均值。

参数名称类型是否必填说明
aarray_like需要计算均值的数组。
axisNone 或 int 或 tuple of ints指定计算均值方向的轴,默认为 None
dtypedata-type输出均值的类型,默认为 float64
outndarray存放结果的备选输出数组,默认为 None
keepdimsbool是否保留减少的轴为尺寸为一的维度,默认为 False

torch.save():将对象保存到磁盘文件。

参数名称类型是否必填说明
objobject要保存的对象。
fUnion[str, PathLike, BinaryIO, IO[bytes]]类似文件的对象或包含文件名的字符串。
pickle_moduleAny用于 pickling 元数据和对象的模块,默认为 pickle
pickle_protocolint可以指定 pickling 协议的版本,默认为 DEFAULT_PROTOCOL
_use_new_zipfile_serializationbool是否使用新的 zipfile 序列化,默认为 True

state_dict():返回模型的状态字典,包含模型的参数和缓冲区。

open():打开一个文件并返回文件对象。

参数名称类型是否必填说明
fileUnion[str, PathLike]文件名或文件对象。
modestr文件打开模式,默认为 'r'
bufferingint缓冲策略,默认为 -1
encodingstr文件编码,默认为 None
errorsstr指定如何处理编码错误,默认为 None
newlinestr控制换行符的行为,默认为 None
closefdbool如果为 True,文件描述符将在关闭文件时关闭,默认为 True
openercallable自定义打开器,默认为 None

write():将字符串写入文件。

参数名称类型是否必填说明
textstr要写入文件的字符串。

json.dumps():将 Python 对象序列化为 JSON 格式的字符串。

参数名称类型是否必填说明
objobject要序列化的 Python 对象。
skipkeysbool如果为 True,则跳过不可序列化的键,默认为 False
ensure_asciibool如果为 True,则所有非 ASCII 字符都将转义,默认为 True
check_circularbool如果为 True,则检查循环引用,默认为 True
allow_nanbool如果为 True,则允许 NaNInfinity 和 -Infinity,默认为 False
clsJSONEncoder自定义的 JSON 编码器,默认为 json.JSONEncoder
indentint 或 str指定缩进,默认为 None
separatorstuple指定分隔符,默认为 (', ', ': ')
defaultcallable自定义函数,用于处理不可序列化的对象,默认为 None
sort_keysbool如果为 True,则按键排序,默认为 False

close():关闭文件对象。

def main():epoch_num = 15        #训练轮数batch_size = 20       #每次训练样本个数train_sample = 1000   #每轮训练总共训练的样本总数char_dim = 768         #每个字的维度sentence_length = 6   #样本文本长度vocab = build_vocab()       #建立字表model = build_model(vocab, char_dim, sentence_length)    #建立模型optim = torch.optim.Adam(model.parameters(), lr=1e-5)   #建立优化器log = []for epoch in range(epoch_num):model.train()watch_loss = []for batch in range(int(train_sample / batch_size)):x, y = build_dataset(batch_size, vocab, sentence_length) #构建一组训练样本optim.zero_grad()    #梯度归零loss = model(x, y)   #计算lossloss.backward()      #计算梯度optim.step()         #更新权重watch_loss.append(loss.item())print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))acc = evaluate(model, vocab, sentence_length)   #测试本轮模型结果log.append([acc, np.mean(watch_loss)])# plt.plot(range(len(log)), [l[0] for l in log])  #画acc曲线# plt.plot(range(len(log)), [l[1] for l in log])  #画loss曲线# plt.show()#保存模型torch.save(model.state_dict(), "model.pth")# 保存词表writer = open("vocab.json", "w", encoding="utf8")writer.write(json.dumps(vocab, ensure_ascii=False, indent=2))writer.close()return

9.用训练好的模型预测

json.load():将 JSON 格式的文件内容解析为 Python 对象(如字典、列表等)。

参数名称类型是否必填说明
fpio.TextIOBase 或 str文件对象或文件路径。
clsJSONDecoder 子类自定义的 JSON 解码器,默认为 json.JSONDecoder
object_hookcallable用于自定义反序列化特定类型的对象。
parse_floatcallable用于解析浮点数的函数,默认为 float
parse_intcallable用于解析整数的函数,默认为 int
parse_constantcallable用于解析 JSON 常量(如 nulltruefalse)。
object_pairs_hookcallable用于自定义反序列化键值对的对象。
bufferslist用于增量解析的缓冲区列表。

open():打开一个文件,并返回一个文件对象,以便进行读写操作。

参数名称类型是否必填说明
filestr 或 PathLike文件名或文件路径。
modestr文件打开模式,默认为 'r'(读取)。
bufferingint缓冲策略,默认为 -1(系统默认)。
encodingstr文件编码,默认为 None(使用系统默认)。
errorsstr指定如何处理编码错误,默认为 None
newlinestr控制换行符的行为,默认为 None
closefdbool如果为 True,文件描述符将在关闭文件时关闭,默认为 True
openercallable自定义的文件打开器,默认为 None

torch.load():从磁盘加载一个序列化的对象(如张量、模型等),通常用于恢复训练好的模型或张量。

参数名称类型是否必填说明
fstr 或 PathLike 或 file-like object文件路径或文件对象。
map_locationstr 或 Callable 或 dict指定张量的存储位置,用于在不同设备间加载模型。
pickle_modulemodule用于反序列化的 pickle 模块,默认为 pickle
**pickle_load_args任意关键字参数传递给 pickle.load 的额外参数。

append():将一个元素添加到列表的末尾。

参数名称类型是否必填说明
object任意类型要添加到列表末尾的对象。

model.eval():将模型设置为评估模式,禁用诸如 Dropout 和 BatchNorm 等层的训练特定行为,确保模型在推理时的稳定性。

torch.no_grad():临时禁用梯度计算,用于推理阶段以减少内存消耗和提高性能。

enumerate():将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

参数名称类型是否必填说明
iterable可迭代对象一个可迭代对象,如字符串、元组或列表。
startint下标起始位置,默认为 0

torch.argmax():返回输入张量沿指定维度的最大值的索引。

参数名称类型是否必填说明
inputTensor输入的张量。
dimint指定沿哪个维度查找最大值,默认为最后一个维度。
keepdimbool是否保持输出张量的维度,默认为 False
outTensor存放结果的备选输出张量,默认为 None
#最终预测
def predict(model_path, vocab_path, input_strings):char_dim = 20  # 每个字的维度sentence_length = 6  # 样本文本长度vocab = json.load(open(vocab_path, "r", encoding="utf8"))model = build_model(vocab, char_dim, sentence_length)    #建立模型model.load_state_dict(torch.load(model_path))       #加载训练好的权重x = []for input_string in input_strings:x.append([vocab[char] for char in input_string])  #将输入序列化model.eval()   #测试模式,不使用dropoutwith torch.no_grad():  #不计算梯度result = model.forward(torch.LongTensor(x))  #模型预测for i, input_string in enumerate(input_strings):print(int(torch.argmax(result[i])), input_string, result[i]) #打印结果

10.完整代码 

#coding:utf8import torch
import torch.nn as nn
import numpy as np
import random
import json
from transformers import BertModel"""基于pytorch的网络编写
实现一个网络完成一个简单nlp任务
判断文本中是否有某些特定字符出现week2的例子,修改引入bert
"""class TorchModel(nn.Module):def __init__(self, input_dim, sentence_length, vocab):super(TorchModel, self).__init__()# 原始代码# self.embedding = nn.Embedding(len(vocab) + 1, input_dim)# self.layer = nn.Linear(input_dim, input_dim)# self.pool = nn.MaxPool1d(sentence_length)self.bert = BertModel.from_pretrained(r"F:\人工智能NLP\NLP资料\week6 语言模型\bert-base-chinese", return_dict=False)self.classify = nn.Linear(input_dim, 3)self.activation = torch.sigmoid     #sigmoid做激活函数self.dropout = nn.Dropout(0.5)self.loss = nn.functional.cross_entropy#当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None):# 原始代码# x = self.embedding(x)  #input shape:(batch_size, sen_len) (10,6)# x = self.layer(x)      #input shape:(batch_size, sen_len, input_dim) (10,6,20)# x = self.dropout(x)    #input shape:(batch_size, sen_len, input_dim)# x = self.activation(x) #input shape:(batch_size, sen_len, input_dim)# x = self.pool(x.transpose(1,2)).squeeze() #input shape:(batch_size, sen_len, input_dim)sequence_output, pooler_output = self.bert(x)x = self.classify(pooler_output)y_pred = self.activation(x)if y is not None:return self.loss(y_pred, y.squeeze())else:return y_pred#字符集随便挑了一些汉字,实际上还可以扩充
#为每个字生成一个标号
#{"a":1, "b":2, "c":3...}
#abc -> [1,2,3]
def build_vocab():chars = "abcdefghijklmnopqrstuvwxyz"  #字符集vocab = {}for index, char in enumerate(chars):vocab[char] = index + 1   #每个字对应一个序号vocab['unk'] = len(vocab)+1return vocab#随机生成一个样本
#从所有字中选取sentence_length个字
#反之为负样本
def build_sample(vocab, sentence_length):#随机从字表选取sentence_length个字,可能重复x = [random.choice(list(vocab.keys())) for _ in range(sentence_length)]#A类样本if set("abc") & set(x) and not set("xyz") & set(x):y = 0#B类样本elif not set("abc") & set(x) and set("xyz") & set(x):y = 1#C类样本else:y = 2x = [vocab.get(word, vocab['unk']) for word in x]   #将字转换成序号,为了做embeddingreturn x, y#建立数据集
#输入需要的样本数量。需要多少生成多少
def build_dataset(sample_length, vocab, sentence_length):dataset_x = []dataset_y = []for i in range(sample_length):x, y = build_sample(vocab, sentence_length)dataset_x.append(x)dataset_y.append([y])return torch.LongTensor(dataset_x), torch.LongTensor(dataset_y)#建立模型
def build_model(vocab, char_dim, sentence_length):model = TorchModel(char_dim, sentence_length, vocab)return model#测试代码
#用来测试每轮模型的准确率
def evaluate(model, vocab, sample_length):model.eval()total = 200 #测试样本数量x, y = build_dataset(total, vocab, sample_length)   #建立200个用于测试的样本y = y.squeeze()print("A类样本数量:%d, B类样本数量:%d, C类样本数量:%d"%(y.tolist().count(0), y.tolist().count(1), y.tolist().count(2)))correct, wrong = 0, 0with torch.no_grad():y_pred = model(x)      #模型预测for y_p, y_t in zip(y_pred, y):  #与真实标签进行对比if int(torch.argmax(y_p)) == int(y_t):correct += 1   #正样本判断正确else:wrong += 1print("正确预测个数:%d / %d, 正确率:%f"%(correct, total, correct/(correct+wrong)))return correct/(correct+wrong)def main():epoch_num = 15        #训练轮数batch_size = 20       #每次训练样本个数train_sample = 1000   #每轮训练总共训练的样本总数char_dim = 768         #每个字的维度sentence_length = 6   #样本文本长度vocab = build_vocab()       #建立字表model = build_model(vocab, char_dim, sentence_length)    #建立模型optim = torch.optim.Adam(model.parameters(), lr=1e-5)   #建立优化器log = []for epoch in range(epoch_num):model.train()watch_loss = []for batch in range(int(train_sample / batch_size)):x, y = build_dataset(batch_size, vocab, sentence_length) #构建一组训练样本optim.zero_grad()    #梯度归零loss = model(x, y)   #计算lossloss.backward()      #计算梯度optim.step()         #更新权重watch_loss.append(loss.item())print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))acc = evaluate(model, vocab, sentence_length)   #测试本轮模型结果log.append([acc, np.mean(watch_loss)])# plt.plot(range(len(log)), [l[0] for l in log])  #画acc曲线# plt.plot(range(len(log)), [l[1] for l in log])  #画loss曲线# plt.show()#保存模型torch.save(model.state_dict(), "model.pth")# 保存词表writer = open("vocab.json", "w", encoding="utf8")writer.write(json.dumps(vocab, ensure_ascii=False, indent=2))writer.close()return#最终预测
def predict(model_path, vocab_path, input_strings):char_dim = 20  # 每个字的维度sentence_length = 6  # 样本文本长度vocab = json.load(open(vocab_path, "r", encoding="utf8"))model = build_model(vocab, char_dim, sentence_length)    #建立模型model.load_state_dict(torch.load(model_path))       #加载训练好的权重x = []for input_string in input_strings:x.append([vocab[char] for char in input_string])  #将输入序列化model.eval()   #测试模式,不使用dropoutwith torch.no_grad():  #不计算梯度result = model.forward(torch.LongTensor(x))  #模型预测for i, input_string in enumerate(input_strings):print(int(torch.argmax(result[i])), input_string, result[i]) #打印结果if __name__ == "__main__":main()# test_strings = ["juvaee", "yrwfrg", "rtweqg", "nlhdww"]# predict("model.pth", "vocab.json", test_strings)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词