
BERT-Based Chinese Question Answering System (羲和 1.0)

2024/10/23 23:30:53 Source: https://blog.csdn.net/weixin_54366286/article/details/142851138

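To keep the project layout clear, the code is organized into a fixed directory structure. The project is split into a few main parts: data, models, logs, icons, and source code.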

Project Directory Structure
code

project_root/
├── data/
│   └── train_data.jsonl
├── models/
│   └── xihua_model.pth
├── logs/
│   └── <date_time>/
│       └── 羲和.txt
├── icons/
│   ├── xihe.png
│   └── ling.png
├── src/
│   ├── main.py
│   ├── xihua_dataset.py
│   ├── xihua_model.py
│   ├── xihua_gui.py
│   ├── utils.py
│   └── train.py
└── README.md
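The layout above can be created in one step with a small script. This is a minimal sketch using only `pathlib`; the function name `create_skeleton` and the choice to create empty placeholder files are assumptions, not part of the original project. Files that are produced later or supplied by you (the trained `xihua_model.pth`, the per-run `<date_time>` log folders, and the two icon images) are deliberately not created.

```python
from __future__ import annotations

from pathlib import Path

# Hypothetical helper: folders and empty placeholder files from the tree above.
# models/, logs/ and icons/ are created empty because their contents are
# produced by training, by the running app, or supplied by the user.
LAYOUT = {
    "data": ["train_data.jsonl"],
    "models": [],
    "logs": [],
    "icons": [],
    "src": [
        "main.py",
        "xihua_dataset.py",
        "xihua_model.py",
        "xihua_gui.py",
        "utils.py",
        "train.py",
    ],
}


def create_skeleton(root: str | Path) -> Path:
    """Create the project folders and empty files under root; safe to rerun."""
    root = Path(root)
    for folder, files in LAYOUT.items():
        folder_path = root / folder
        folder_path.mkdir(parents=True, exist_ok=True)
        for name in files:
            # touch() leaves existing files untouched (only updates mtime)
            (folder_path / name).touch(exist_ok=True)
    (root / "README.md").touch(exist_ok=True)
    return root
```

Calling `create_skeleton("project_root")` once produces the empty skeleton; running it again does not overwrite any file that already has content.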

Code Breakdown
1. main.py
The main entry point; it launches the GUI.

python

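import os
import sys
import tkinter as tk

# Project root: main.py lives in src/, so go up two levels
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
SRC_DIR = os.path.join(PROJECT_ROOT, 'src')

# Put the project root on sys.path so `src` is importable from any working directory
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Import the GUI module
from src.xihua_gui import XihuaChatbotGUI

if __name__ == "__main__":
    # Start the GUI
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()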
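Before `main.py` builds the window, it can help to confirm that the files the GUI is likely to load are in place. The sketch below is hypothetical: the helper `missing_paths` and the assumption that the GUI needs `models/xihua_model.pth` and both images under `icons/` at startup are inferred from the directory tree, not from the `XihuaChatbotGUI` code, which is not shown.

```python
from pathlib import Path
from typing import List

# Assumption: files the GUI needs at startup, relative to the project root.
REQUIRED_FILES = [
    "models/xihua_model.pth",
    "icons/xihe.png",
    "icons/ling.png",
]


def missing_paths(project_root) -> List[str]:
    """Return the entries of REQUIRED_FILES that are not files under project_root."""
    root = Path(project_root)
    return [rel for rel in REQUIRED_FILES if not (root / rel).is_file()]
```

One option is to call `missing_paths(...)` with the project root (the parent of `src/`) before `tk.Tk()` and exit with a readable message when the returned list is non-empty, instead of letting the GUI fail later with a file-not-found error.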

2. xihua_dataset.py
Defines the dataset class.

python

import os
import json
import logging

import jsonlines
from transformers import BertTokenizer


class XihuaDataset:
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                line_no = 0
                while True:
                    line_no += 1
                    try:
                        # InvalidLineError is raised by the read itself,
                        # so the read must sit inside the try block
                        item = reader.read()
                    except EOFError:
                        break
                    except jsonlines.InvalidLineError as e:
                        logging.warning(f"Skipping invalid line {line_no}: {e}")
                        continue
                    if self.validate_item(item):
                        data.append(item)
        elif file_path.endswith('.json'):
            # Explicit UTF-8 so Chinese text loads correctly on every platform
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = [item for item in json.load(f) if self.validate_item(item)]
                except json.JSONDecodeError as e:
                    logging.warning(f"Skipping invalid file {file_path}: {e}")
        return data

    def validate_item(self, item):
        required_keys = ['question', 'xihe_answers', 'ling_answers']
        if isinstance(item, dict) and all(key in item for key in required_keys):
            return True
        logging.warning(f"Skipping invalid item: missing required keys {required_keys}")
        return False

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        # Make sure both xihe_answers and ling_answers have a value:
        # if one list is empty, fall back to the other
        xihe_answer = item.get('xihe_answers', [])
        ling_answer = item.get('ling_answers', [])
        if not xihe_answer and ling_answer:
            xihe_answer = ling_answer
        elif not ling_answer and xihe_answer:
            ling_answer = xihe_answer
        # Keep only the first answer of each list ("" if there is none)
        xihe_answer = xihe_answer[0] if xihe_answer else ""
        ling_answer = ling_answer[0] if ling_answer else ""
        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length',
                                    truncation=True, max_length=self.max_length)
            xihe_inputs = self.tokenizer(xihe_answer, return_tensors='pt', padding='max_length',
                                         truncation=True, max_length=self.max_length)
            ling_inputs = self.tokenizer(ling_answer, return_tensors='pt', padding='max_length',
                                         truncation=True, max_length=self.max_length)
        except Exception as e:
            # On a tokenization failure, move on to the next item
            logging.warning(f"Skipping invalid item {idx}: {e}")
            return self.__getitem__((idx + 1) % len(self.data))
        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            # ... (the source listing is cut off at this point; the remaining entries are not shown)
        }
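The dataset expects every record to carry `question`, `xihe_answers` and `ling_answers`, and `__getitem__` falls back to the other answer list when one of them is empty. The stdlib-only sketch below reproduces that validation and fallback without a tokenizer, so the logic can be checked on its own. The sample lines are made up for illustration, the helper names `iter_valid_records` and `pick_answers` are hypothetical, and parsing each line with `json.loads` is a substitute for the `jsonlines` reader used in the class.

```python
from __future__ import annotations

import json
import logging
from typing import Iterable, Iterator

REQUIRED_KEYS = ("question", "xihe_answers", "ling_answers")

# Hypothetical lines from data/train_data.jsonl (one JSON object per line).
SAMPLE_LINES = [
    '{"question": "你好", "xihe_answers": ["我是羲和。"], "ling_answers": []}',
    '{"question": "broken line", ',   # malformed JSON: skipped
    '{"question": "没有答案字段"}',     # answer keys missing: skipped
]


def iter_valid_records(lines: Iterable[str]) -> Iterator[dict]:
    """Parse JSONL lines one at a time, logging and skipping bad ones."""
    for line_no, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        try:
            item = json.loads(line)
        except json.JSONDecodeError as e:
            logging.warning("Skipping invalid line %d: %s", line_no, e)
            continue
        if isinstance(item, dict) and all(k in item for k in REQUIRED_KEYS):
            yield item
        else:
            logging.warning("Skipping line %d: missing one of %s", line_no, REQUIRED_KEYS)


def pick_answers(item: dict) -> tuple[str, str]:
    """Return (xihe_answer, ling_answer) with the same fallback as __getitem__."""
    xihe = item.get("xihe_answers") or []
    ling = item.get("ling_answers") or []
    # If one answer list is empty, borrow the other one
    if not xihe and ling:
        xihe = ling
    elif not ling and xihe:
        ling = xihe
    # Keep only the first answer; use "" when both lists are empty
    return (xihe[0] if xihe else "", ling[0] if ling else "")
```

With the sample above, only the first line survives, and because its `ling_answers` list is empty, both returned answers are the `xihe_answers` entry. Parsing line by line keeps one malformed record from aborting the whole load.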
