转载请注明出处:

http://www.cnblogs.com/darkknightzh/p/7839263.html

目前使用的torch模型转pytorch模型的程序为:

https://github.com/clcarwin/convert_torch_to_pytorch

该程序中,常见的模型都可以转换,但是对于torch中为BatchNormalization的则会提示出错:

Not Implement BatchNormalization

torch中的SpatialBatchNormalization对应于输入为4d的特征(batchsize*featdim*featHeight*featWidth),对应于pytorch中的nn.BatchNorm2d。

而torch中的BatchNormalization对应于输入为2d的特征(batchsize*featdim),对应于pytorch中的nn.BatchNorm1d。

因而修改方法很简单:

1. 在convert_torch.py的行(elif name == 'ReLU':)之前添加:

elif name == 'BatchNormalization':
n = nn.BatchNorm1d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
copy_param(m,n)
add_submodule(seq,n)

2. 在convert_torch.py的(未修改前的)行(elif name == 'ReLU':)之前添加:

elif name == 'BatchNormalization':
s += ['nn.BatchNorm1d({},{},{},{}),#BatchNorm1d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]

3. 在convert_torch.py的(未修改前的)行(s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d',')'),s))之前添加:

s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm1d',')'),s)
s = map(lambda x: x.replace('),#BatchNorm1d',')'),s)

经过上述修改后,torch模型中含有BatchNormalization,转换到pytorch后的模型性能和转换前的模型性能一致。

顺便说一下,2天前更新的该程序,添加了BatchNorm3d的支持,但是在243、244行之后,并没有增加BatchNorm3d的相关代码,不清楚是否会有问题。我这边没有用到BatchNorm3d,因而没有测试。

另一方面,上面的3步中,我是根据BatchNorm2d去修改,没有测试如果不修改某一步(如第3步),程序是否会有问题。反正都改了,模型没有问题。。。

(原)torch模型转pytorch模型的更多相关文章

  1. 生产与学术之Pytorch模型导出为安卓Apk尝试记录

    生产与学术 写于 2019-01-08 的旧文, 当时是针对一个比赛的探索. 觉得可能对其他人有用, 就放出来分享一下 生产与学术, 真实的对立... 这是我这两天对pytorch深度学习->a ...

  2. 将Pytorch模型从CPU转换成GPU

    1. 如何进行迁移 对模型和相应的数据进行.cuda()处理.通过这种方式,我们就可以将内存中的数据复制到GPU的显存中去.从而可以通过GPU来进行运算了. 1.1 判定使用GPU 下载了对应的GPU ...

  3. 使用C++调用pytorch模型(Linux)

    前言 模型转换思路通常为: Pytorch -> ONNX -> TensorRT Pytorch -> ONNX -> TVM Pytorch -> 转换工具 -> ...

  4. 使用C++调用并部署pytorch模型

    1.背景(Background) 上图显示了目前深度学习模型在生产环境中的方法,本文仅探讨如何部署pytorch模型! 至于为什么要用C++调用pytorch模型,其目的在于:使用C++及多线程可以加 ...

  5. DEX-6-caffe模型转成pytorch模型办法

    在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...

  6. PyTorch模型加载与保存的最佳实践

    一般来说PyTorch有两种保存和读取模型参数的方法.但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题. 第一种方案是保存整个模型: 1 torch.save(model_object, ...

  7. 从零搭建Pytorch模型教程(三)搭建Transformer网络

    ​ 前言 本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍. ...

  8. Pytorch模型量化

    在深度学习中,量化指的是使用更少的bit来存储原本以浮点数存储的tensor,以及使用更少的bit来完成原本以浮点数完成的计算.这么做的好处主要有如下几点: 更少的模型体积,接近4倍的减少: 可以更快 ...

  9. 计算机网络原理和OSI模型与TCP模型

    计算机网络原理和OSI模型与TCP模型 一.计算机网络的概述 1.计算机网络的定义 计算机网络是一组自治计算机的互连的集合 2.计算机网络的基本功能 a.资源共享 b.分布式处理与负载均衡 c.综合信 ...

随机推荐

  1. 2018.08.28 ali 梯度下降法实现最小二乘

    - 要理解梯度下降和牛顿迭代法的区别 #include<stdio.h> // 1. 线性多维函数原型是 y = f(x1,x2,x3) = a * x1 + b * x2 + c * x ...

  2. Pytorch 0.3加载0.4模型及其之间版本的变化

    1. 0.4中使用设备:.to(device) 2. 0.4中删除了Variable,直接tensor就可以 3. with torch.no_grad():的使用代替volatile:弃用volat ...

  3. 修改IP地址的PowerShell

    $wmi = Get-WmiObject win32_networkadapterconfiguration -filter "ipenabled = 'true'" $wmi.E ...

  4. Construct Binary Tree from Preorder and Inorder Traversal leetcode java

    题目: Given preorder and inorder traversal of a tree, construct the binary tree. Note: You may assume ...

  5. 【Spark】Spark Streaming 动态更新filter关注的内容

    Spark Streaming 动态更新filter关注的内容 spark streaming new thread on driver_百度搜索 (1 封私信)Spark Streaming 动态更 ...

  6. 【ElasticSearch】ES5新特性-keyword-text类型-查询区别

    ES5新特性-keyword-text类型-查询区别 elasticsearch-head Elasticsearch-sql client junneyang (JunneYang) es keyw ...

  7. (文档)流媒体资源 Streaming Assets

    Most assets in Unity are combined into the project when it is built. However, it is sometimes useful ...

  8. gson ajax 数字精度丢失

    ajax传输的json,gson会发生丢失,long > 15的时候会丢失0 解决方案:直接把属性为long的属性自动加上双引号成为js的字符串,这样就不会发生丢失了,ajax自动识别为字符串. ...

  9. GOOD BLOG URL

    1TEST http://www.cnblogs.com/Javame/p/3653509.html 综合 http://shiyanjun.cn/

  10. Persistent Netcat Backdoor

    In this example, instead of looking up information on the remote system, we will be installing a net ...