转载请注明出处:

http://www.cnblogs.com/darkknightzh/p/7839263.html

目前使用的torch模型转pytorch模型的程序为:

https://github.com/clcarwin/convert_torch_to_pytorch

该程序中,常见的模型都可以转换,但是对于torch中为BatchNormalization的则会提示出错:

Not Implement BatchNormalization

torch中的SpatialBatchNormalization对应于输入为4d的特征(batchsize*featdim*featHeight*featWidth),对应于pytorch中的nn.BatchNorm2d。

而torch中的BatchNormalization对应于输入为2d的特征(batchsize*featdim),对应于pytorch中的nn.BatchNorm1d。

因而修改方法很简单:

1. 在convert_torch.py的行(elif name == 'ReLU':)之前添加:

elif name == 'BatchNormalization':
n = nn.BatchNorm1d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
copy_param(m,n)
add_submodule(seq,n)

2. 在convert_torch.py的(未修改前的)行(elif name == 'ReLU':)之前添加:

elif name == 'BatchNormalization':
s += ['nn.BatchNorm1d({},{},{},{}),#BatchNorm1d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]

3. 在convert_torch.py的(未修改前的)行(s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d',')'),s))之前添加:

s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm1d',')'),s)
s = map(lambda x: x.replace('),#BatchNorm1d',')'),s)

经过上述修改后,torch模型中含有BatchNormalization,转换到pytorch后的模型性能和转换前的模型性能一致。

顺便说一下,2天前更新的该程序,添加了BatchNorm3d的支持,但是在243、244行之后,并没有增加BatchNorm3d的相关代码,不清楚是否会有问题。我这边没有用到BatchNorm3d,因而没有测试。

另一方面,上面的3步中,我是根据BatchNorm2d去修改,没有测试如果不修改某一步(如第3步),程序是否会有问题。反正都改了,模型没有问题。。。

(原)torch模型转pytorch模型的更多相关文章

  1. 生产与学术之Pytorch模型导出为安卓Apk尝试记录

    生产与学术 写于 2019-01-08 的旧文, 当时是针对一个比赛的探索. 觉得可能对其他人有用, 就放出来分享一下 生产与学术, 真实的对立... 这是我这两天对pytorch深度学习->a ...

  2. 将Pytorch模型从CPU转换成GPU

    1. 如何进行迁移 对模型和相应的数据进行.cuda()处理.通过这种方式,我们就可以将内存中的数据复制到GPU的显存中去.从而可以通过GPU来进行运算了. 1.1 判定使用GPU 下载了对应的GPU ...

  3. 使用C++调用pytorch模型(Linux)

    前言 模型转换思路通常为: Pytorch -> ONNX -> TensorRT Pytorch -> ONNX -> TVM Pytorch -> 转换工具 -> ...

  4. 使用C++调用并部署pytorch模型

    1.背景(Background) 上图显示了目前深度学习模型在生产环境中的方法,本文仅探讨如何部署pytorch模型! 至于为什么要用C++调用pytorch模型,其目的在于:使用C++及多线程可以加 ...

  5. DEX-6-caffe模型转成pytorch模型办法

    在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...

  6. PyTorch模型加载与保存的最佳实践

    一般来说PyTorch有两种保存和读取模型参数的方法.但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题. 第一种方案是保存整个模型: 1 torch.save(model_object, ...

  7. 从零搭建Pytorch模型教程(三)搭建Transformer网络

    ​ 前言 本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍. ...

  8. Pytorch模型量化

    在深度学习中,量化指的是使用更少的bit来存储原本以浮点数存储的tensor,以及使用更少的bit来完成原本以浮点数完成的计算.这么做的好处主要有如下几点: 更少的模型体积,接近4倍的减少: 可以更快 ...

  9. 计算机网络原理和OSI模型与TCP模型

    计算机网络原理和OSI模型与TCP模型 一.计算机网络的概述 1.计算机网络的定义 计算机网络是一组自治计算机的互连的集合 2.计算机网络的基本功能 a.资源共享 b.分布式处理与负载均衡 c.综合信 ...

随机推荐

  1. 浅谈提升C#正则表达式效率

     摘要:说到C#的Regex,谈到最多的应该就是RegexOptions.Compiled这个东西,传说中在匹配速度方面,RegexOptions.Compiled是可以提升匹配速度的,但在启动速度上 ...

  2. Android教材 | 第三章 Android界面事件处理(一)—— 杰瑞教育原创教材试读

      前  言 JRedu Android应用开发中,除了界面编程外,另一个重要的内容就是组件的事件处理.在Android系统中,存在多种界面事件,比如触摸事件.按键事件.点击事件等.在用户交互过程中, ...

  3. 【Java】Java-正则匹配-性能优化

    Java-正则匹配-性能优化 Java 正则 点_百度搜索 在Java类中如何用正则表达式表示小数点啊?_百度知道 使用Jakarta-ORO库的几个例子 - 小橡树 - ITeye博客 正则表达式以 ...

  4. android学习四(Activity的生命周期)

    要学好活动(Activity).就必需要了解android中Activity的声明周期.灵活的使用生命周期.能够开发出更好的程序,在android中是使用任务来管理活动的,一个任务就是一组存放在栈里的 ...

  5. Ubuntu 突然上不去网了怎么办

    到家了也想看看程序.打开WIN8上的虚拟机VM,然后启动Ubuntu.................................... 像往常一样等待着界面,输入password,然后改动程序. ...

  6. 【Python】由host得到IP

    代码: import socket host='www.163.com' ip=socket.gethostbyname(host) print('Ip of {} is {}'.format(hos ...

  7. 极域电子教室卸载或安装软件后windows7无法启用触摸板、键盘

    我今天在win7上装了个极域电子教室,卸载后重启触摸板,键盘都不能用了?连口令都是用屏幕键盘来输入的.进去后看设备管理器,键盘和触摸板,前面都有黄色的告警,而且就是出现了鼠标代码为10的情况?不过吧鼠 ...

  8. Android权限判断checkPermission

    判断本程序是否拥有某权限的方法: private static final String EXTERNAL_STORAGE_PERMISSION = "android.permission. ...

  9. springboot微信sdk方式进行微信支付

    https://blog.csdn.net/xsg6509/article/details/80342744

  10. poj2689 Prime Distance 有难度 埃拉托斯尼斯筛法的运用

    我承认这道很难(对我来说),搞脑子啊,搞了好久,数论刚开始没多久,还不是很强大,思路有点死,主要是我 天赋太差,太菜了,希望多做做有所改善 开始解析: 首先要将在 [ l,u]内的所有素数找出来,还好 ...