model.load_state_dict(torch.load(weight_path), strict=False)
当权重中的key和网络中匹配就加载,不匹配就跳过,如果strict是True,那必须完全匹配,不然就报错,默认是True
只加载部分参数权重,可以将state中不需要的参数删除,然后加载其他项
x = torch.load(self.weight)
del x['char_recognizer.classifier.bias']
del x['char_recognizer.classifier.weight']
self.load_state_dict(x, strict=False)
或者将
path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)
https://blog.csdn.net/hxxjxw/article/details/119491163
https://blog.csdn.net/qq_34914551/article/details/87871134