欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 会展 > pytorch 的pth格式模型转onnx格式模型 - python 实现

pytorch 的pth格式模型转onnx格式模型 - python 实现

2024/11/30 8:56:21 来源:https://blog.csdn.net/weixin_42140236/article/details/144034626  浏览:    关键词:pytorch 的pth格式模型转onnx格式模型 - python 实现

pytorch 的pth格式模型转onnx格式模型 - python 实现

#-*-coding:utf-8-*-
# date:2021-10-5
# Author: DataBall - XIAN
# function: pytorch model 2 onnximport os
import argparse
import torch
import torch.nn as nn
import numpy as npfrom network.resnet import resnet18,resnet50if __name__ == "__main__":parser = argparse.ArgumentParser(description=' Project handpose x')parser.add_argument('--model_path', type=str, default = r'ckpt\resnet_18_epoch-275-x96.pth',help = 'model_path') # 模型路径parser.add_argument('--model', type=str, default = 'resnet_18',help = '''model : resnet_34,resnet_50,resnet_101,squeezenet1_0,squeezenet1_1,shufflenetv2,shufflenet,mobilenetv2shufflenet_v2_x1_5 ,shufflenet_v2_x1_0 , shufflenet_v2_x2_0''') # 模型类型parser.add_argument('--GPUS', type=str, default = '0',help = 'GPUS') # GPU选择parser.add_argument('--test_path', type=str, default = './image/',help = 'test_path') # 测试图片路径parser.add_argument('--img_size', type=tuple , default = (96,96),help = 'img_size') # 输入模型图片尺寸print('\n/******************* {} ******************/\n'.format(parser.description))#--------------------------------------------------------------------------ops = parser.parse_args()# 解析添加参数#--------------------------------------------------------------------------print('----------------------------------')unparsed = vars(ops) # parse_args()方法的返回值为namespace,用vars()内建函数化为字典for key in unparsed.keys():print('{} : {}'.format(key,unparsed[key]))#---------------------------------------------------------------------------os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUStest_path =  ops.test_path # 测试图片文件夹路径#---------------------------------------------------------------- 构建模型print('use model : %s'%(ops.model))if ops.model == 'resnet_50':model_ = resnet50(img_size=ops.img_size[0])elif ops.model == 'resnet_18':model_ = resnet18(img_size=ops.img_size[0])use_cuda = torch.cuda.is_available()device = torch.device("cuda:0" if use_cuda else "cpu")model_ = model_.to(device)model_.eval() # 设置为前向推断模式# 加载测试模型if os.access(ops.model_path,os.F_OK):# checkpointchkpt = torch.load(ops.model_path, map_location=device)model_.load_state_dict(chkpt)print('load test model : {}'.format(ops.model_path))input_size = ops.img_size[0]batch_size = 1  #批处理大小input_shape = (3, input_size,input_size)   #输入数据,改成自己的输入shapeprint("input_size : ",input_size)x = torch.randn(batch_size, *input_shape)   # 生成张量x = x.to(device)export_onnx_file = "{}_size-{}.onnx".format(ops.model,input_size)		# 目的ONNX文件名torch.onnx.export(model_,x,export_onnx_file,opset_version=9,do_constant_folding=True,	# 是否执行常量折叠优化input_names=["input"],	# 输入名output_names=["output2d"],	# 输出名#dynamic_axes={"input":{0:"batch_size"},  # 批处理变量#                "output":{0:"batch_size"}})

脚本对应输出结果如下:


/*******************  Project handpose x ******************/----------------------------------
model_path : ckpt\resnet_18_epoch-275-x96.pth
model : resnet_18
GPUS : 0
test_path : ./image/
img_size : (96, 96)
use model : resnet_18
load test model : ckpt\resnet_18_epoch-275-x96.pth
input_size :  96

 ​​​

助力快速掌握数据集的信息和使用方式。

数据可以如此美好!

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com