欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 健康 > 养生 > 解决:bert_score无法加载本地模型

解决:bert_score无法加载本地模型

2025/2/22 16:28:37 来源:https://blog.csdn.net/nlhkfcdxb/article/details/140063671  浏览:    关键词:解决:bert_score无法加载本地模型

相信很多小伙伴平时都使用内网进行工作,这些网络是无法连接huggingface的,使用魔塔加载模型网络断断续续的很容易失败。但是bert_score只接收一个模型名,然后自动在huggingface下载或在本地缓存加载。这个缓存跟huggingface官方缓存是不同的。

解决办法1:修改bert_score源码。bert_score虽说是只接受模型名,但内部还是通过AutoTokenizer.from_pretrained和AutoModel.from_pretrained这两个方法加载模型,相信这两个方法大家都很熟悉了。因此只需要在源码中添加自己的model_path,并且把源码中的model_type这个参数改为model_path

源码:

def get_model(model_type, num_layers, all_layers=None):if model_type.startswith("scibert"):model = AutoModel.from_pretrained(cache_scibert(model_type))elif "t5" in model_type:from transformers import T5EncoderModelmodel = T5EncoderModel.from_pretrained(model_type)else:model = AutoModel.from_pretrained(model_type)model.eval()def get_tokenizer(model_type, use_fast=False):if model_type.startswith("scibert"):model_type = cache_scibert(model_type)if version.parse(trans_version) >= version.parse("4.0.0"):tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=use_fast)else:assert not use_fast, "Fast tokenizer is not available for version < 4.0.0"tokenizer = AutoTokenizer.from_pretrained(model_type)return tokenizer

改为:

def get_model(model_type, num_layers, all_layers=None):model_path = 'xxx'if model_type.startswith("scibert"):model = AutoModel.from_pretrained(cache_scibert(model_type))elif "t5" in model_type:from transformers import T5EncoderModelmodel = T5EncoderModel.from_pretrained(model_path)else:model = AutoModel.from_pretrained(model_path)model.eval()def get_tokenizer(model_type, use_fast=False):if model_type.startswith("scibert"):model_type = cache_scibert(model_type)model_path = 'xxx'if version.parse(trans_version) >= version.parse("4.0.0"):tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast)else:assert not use_fast, "Fast tokenizer is not available for version < 4.0.0"tokenizer = AutoTokenizer.from_pretrained(model_path)return tokenizer

解决办法2:让bert_score找到缓存模型,相信聪明的小伙伴已经在前面的代码中看到bert_score是如何加载缓存模型的。如果要加载缓存模型,model_type字段加载的模型前要加scibert-前缀。并且需要把本地模型放在指定的目录下。可以看出这个函数下载的模型有它自己的命名规则,需要根据它的规则对自己的模型文件做出相应修改。

下面给出部分源码

def cache_scibert(model_type, cache_folder="~/.cache/torch/transformers"):if not model_type.startswith("scibert"):return model_typeunderscore_model_type = model_type.replace("-", "_")cache_folder = os.path.abspath(os.path.expanduser(cache_folder))filename = os.path.join(cache_folder, underscore_model_type)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词