为了账号安全,请及时绑定邮箱和手机立即绑定

PyTorch - 将 ProGAN 代理从 pth 转换为 onnx

PyTorch - 将 ProGAN 代理从 pth 转换为 onnx

人到中年有点甜 2022-09-13 10:01:46
我使用此 PyTorch 重新实现训练了一个 ProGAN 代理,并将该代理另存为 .现在我需要将代理转换为格式,我正在使用此scipt执行此操作:.pth.onnxfrom torch.autograd import Variableimport torch.onnximport torchvisionimport torchdevice = torch.device("cuda")dummy_input = torch.randn(1, 3, 64, 64)state_dict = torch.load("GAN_agent.pth", map_location = device)torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")一旦我运行它,我得到错误(下面的完整提示)。据我所知,问题在于将代理转换为.onnx需要更多信息。我错过了什么吗?AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'---------------------------------------------------------------------------AttributeError                            Traceback (most recent call last)<ipython-input-2-c64481d4eddd> in <module>     10 state_dict = torch.load("GAN_agent.pth", map_location = device)     11 ---> 12 torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)    146                         operator_export_type, opset_version, _retain_param_name,    147                         do_constant_folding, example_outputs,--> 148                         strip_doc_string, dynamic_axes, keep_initializers_as_inputs)    149     150 
查看完整描述

1 回答

?
慕婉清6462132

TA贡献1804条经验 获得超2个赞

您拥有的文件是 ,它们只是图层名称到权重偏差和类似值的映射(有关更全面的介绍,请参阅此处)。state_dicttensor

这意味着你需要一个模型,以便可以映射那些节省的权重和偏差,但首先要做的事情是:

1. 模型准备

克隆模型定义所在的存储库并打开文件 。我们需要进行一些修改才能使其与 . 导出器需要仅作为(或/个)传递,而类需要和参数)。/pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.pyonnxonnxinputtorch.tensorlistdictGeneratorintfloat

简单的解决方案是稍微修改一下函数(文件中的行,可以在GitHub上验证它)到以下内容:forward80

def forward(self, x, depth, alpha):

    """

    forward pass of the Generator

    :param x: input noise

    :param depth: current depth from where output is required

    :param alpha: value of alpha for fade-in effect

    :return: y => output

    """


    # THOSE TWO LINES WERE ADDED

    # We will pas tensors but unpack them here to `int` and `float`

    depth = depth.item()

    alpha = alpha.item()

    # THOSE TWO LINES WERE ADDED

    assert depth < self.depth, "Requested output depth cannot be produced"


    y = self.initial_block(x)


    if depth > 0:

        for block in self.layers[: depth - 1]:

            y = block(y)


        residual = self.rgb_converters[depth - 1](self.temporaryUpsampler(y))

        straight = self.rgb_converters[depth](self.layers[depth - 1](y))


        out = (alpha * straight) + ((1 - alpha) * residual)


    else:

        out = self.rgb_converters[0](y)


    return out

此处仅添加了解包方式。每个不属于类型的输入都应在函数定义中打包为一个,并在函数顶部尽快解压缩。它不会破坏您创建的检查点,所以不用担心,因为它只是映射。item()Tensorlayer-weight


2. 模型导出

将此脚本放在 (位置也位于):/pro_gan_pytorchREADME.md


import torch


from pro_gan_pytorch import PRO_GAN as pg


gen = torch.nn.DataParallel(pg.Generator(depth=9))

gen.load_state_dict(torch.load("GAN_GEN_SHADOW_8.pth"))


module = gen.module.to("cpu")


# Arguments like depth and alpha may need to be changed

dummy_inputs = (torch.randn(1, 512), torch.tensor([5]), torch.tensor([0.1]))

torch.onnx.export(module, dummy_inputs, "GAN_GEN8.onnx", verbose=True)

请注意以下几点:

  • 我们必须在加载权重之前创建模型,因为它是唯一的。state_dict

  • torch.nn.DataParallel是必需的,因为这是模型的训练对象(不确定您的情况,请相应地进行调整)。加载后,我们可以通过属性获取模块本身。module

  • 一切都被扔到了,我认为没有必要在这里。如果你坚持的话,你可以把一切都扔到。CPUGPUGPU

  • 生成器的虚拟输入不能是图像(我使用了存储库作者在其Google云端硬盘上提供的文件),它必须是带有元素的噪音。512

运行它,你的文件应该在那里。.onnx

哦,由于您遵循不同的检查点,您可能希望遵循类似的过程,尽管不能保证一切都会正常工作(尽管它看起来确实如此)。


查看完整回答
反对 回复 2022-09-13
  • 1 回答
  • 0 关注
  • 92 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号