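Many of you probably work on an internal network that cannot reach Hugging Face, and loading models through ModelScope over an unstable connection fails easily. bert_score, however, only accepts a model name, and then either downloads the model from Hugging Face automatically or loads it from a local cache. This cache is not the same as the official Hugging Face cache.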
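Solution 1: modify the bert_score source code. Although bert_score only accepts a model name, internally it still loads the model through AutoTokenizer.from_pretrained and AutoModel.from_pretrained, two methods most readers will already know well. So all you need to do is add your own model_path in the source and pass it to these calls in place of the model_type argument.
Original source: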
def get_model(model_type, num_layers, all_layers=None):
    if model_type.startswith("scibert"):
        model = AutoModel.from_pretrained(cache_scibert(model_type))
    elif "t5" in model_type:
        from transformers import T5EncoderModel
        model = T5EncoderModel.from_pretrained(model_type)
    else:
        model = AutoModel.from_pretrained(model_type)
    model.eval()
    # (the rest of get_model is omitted in this excerpt)

def get_tokenizer(model_type, use_fast=False):
    if model_type.startswith("scibert"):
        model_type = cache_scibert(model_type)
    if version.parse(trans_version) >= version.parse("4.0.0"):
        tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=use_fast)
    else:
        assert not use_fast, "Fast tokenizer is not available for version < 4.0.0"
        tokenizer = AutoTokenizer.from_pretrained(model_type)
    return tokenizer
Changed to:
def get_model(model_type, num_layers, all_layers=None):
    model_path = 'xxx'  # replace with the local directory of your model
    if model_type.startswith("scibert"):
        model = AutoModel.from_pretrained(cache_scibert(model_type))
    elif "t5" in model_type:
        from transformers import T5EncoderModel
        model = T5EncoderModel.from_pretrained(model_path)
    else:
        model = AutoModel.from_pretrained(model_path)
    model.eval()
    # (the rest of get_model is omitted in this excerpt)

def get_tokenizer(model_type, use_fast=False):
    if model_type.startswith("scibert"):
        model_type = cache_scibert(model_type)
    model_path = 'xxx'  # replace with the local directory of your model
    if version.parse(trans_version) >= version.parse("4.0.0"):
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast)
    else:
        assert not use_fast, "Fast tokenizer is not available for version < 4.0.0"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    return tokenizer
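If you evaluate with more than one model, hard-coding a single 'xxx' path gets awkward. One possible variation, shown here only as a sketch, is to look the path up from a small table. The helper name resolve_model_path, the LOCAL_MODEL_DIRS table and the directory in it are all hypothetical and are not part of bert_score.

```python
# Hypothetical table mapping model names to local directories.
# The entry below is a placeholder; replace it with your own paths.
LOCAL_MODEL_DIRS = {
    "roberta-large": "/data/models/roberta-large",
}


def resolve_model_path(model_type, local_dirs=None):
    """Return the local directory registered for model_type.

    Falls back to model_type itself when no local directory is registered,
    so unmapped names behave exactly as before the patch.
    """
    if local_dirs is None:
        local_dirs = LOCAL_MODEL_DIRS
    return local_dirs.get(model_type, model_type)
```

In the patched get_model and get_tokenizer, the line model_path = 'xxx' could then become model_path = resolve_model_path(model_type).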
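Solution 2: let bert_score find a cached model. You may have already noticed in the code above how bert_score loads a cached model. To take that path, the name passed as model_type must begin with the scibert- prefix (the code checks model_type.startswith("scibert")), and the local model must be placed in the directory the function expects. The function follows its own naming rules: hyphens in the model name are replaced with underscores, and the result is looked up under ~/.cache/torch/transformers. You need to rename your own model files to match.
Part of the source code is shown below: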
def cache_scibert(model_type, cache_folder="~/.cache/torch/transformers"):
    if not model_type.startswith("scibert"):
        return model_type
    underscore_model_type = model_type.replace("-", "_")
    cache_folder = os.path.abspath(os.path.expanduser(cache_folder))
    filename = os.path.join(cache_folder, underscore_model_type)
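To check where a local model needs to go, the path logic from the excerpt above can be copied into a small standalone function. This is only a sketch that mirrors the lines shown. The function name expected_scibert_dir is made up for illustration, and whatever cache_scibert does after computing filename is not reproduced here.

```python
import os


def expected_scibert_dir(model_type, cache_folder="~/.cache/torch/transformers"):
    """Mirror the path computation in the cache_scibert excerpt.

    Returns the directory that would be looked up for a "scibert..." name,
    or model_type unchanged for any other name.
    """
    if not model_type.startswith("scibert"):
        return model_type
    # Hyphens in the model name become underscores in the directory name.
    underscore_model_type = model_type.replace("-", "_")
    cache_folder = os.path.abspath(os.path.expanduser(cache_folder))
    return os.path.join(cache_folder, underscore_model_type)


if __name__ == "__main__":
    # Prints the directory under your home folder where the model should be placed.
    print(expected_scibert_dir("scibert-scivocab-uncased"))
```

Running it with the name you plan to pass as model_type shows the directory your model files should be moved or renamed to.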