摘要:本文旨在分享Pytorch->Caffe->om模型转换流程。

标准网络

BaselinePytorchToCaffe

主要功能代码在:

PytorchToCaffe
+-- Caffe
| +-- caffe.proto
| +-- layer_param.py
+-- example
| +-- resnet_pytorch_2_caffe.py
+-- pytorch_to_caffe.py

直接使用可以参考resnet_pytorch_2_caffe.py,如果网络中的操作Baseline中都已经实现,则可以直接转换到Caffe模型。

添加自定义操作

如果遇到没有实现的操作,则要分为两种情况来考虑。

Caffe中有对应操作

以arg_max为例分享一下添加操作的方式。

首先要查看Caffe中对应层的参数:caffe.proto为对应版本caffe层与参数的定义,可以看到ArgMax定义了out_max_val、top_k、axis三个参数:

message ArgMaxParameter {
// If true produce pairs (argmax, maxval)
optional bool out_max_val = 1 [default = false];
optional uint32 top_k = 2 [default = 1];
// The axis along which to maximise -- may be negative to index from the
// end (e.g., -1 for the last axis).
// By default ArgMaxLayer maximizes over the flattened trailing dimensions
// for each index of the first / num dimension.
optional int32 axis = 3;
}

Caffe算子边界中的参数是一致的。

layer_param.py构建了具体转换时参数类的实例,实现了操作参数从Pytorch到Caffe的传递:

def argmax_param(self, out_max_val=None, top_k=None, dim=1):
argmax_param = pb.ArgMaxParameter()
if out_max_val is not None:
argmax_param.out_max_val = out_max_val
if top_k is not None:
argmax_param.top_k = top_k
if dim is not None:
argmax_param.axis = dim
self.param.argmax_param.CopyFrom(argmax_param)

pytorch_to_caffe.py中定义了Rp类,用来实现Pytorch操作到Caffe操作的变换:

class Rp(object):
def __init__(self, raw, replace, **kwargs):
self.obj = replace
self.raw = raw

def __call__(self, *args, **kwargs):
if not NET_INITTED:
return self.raw(*args, **kwargs)
for stack in traceback.walk_stack(None):
if 'self' in stack[0].f_locals:
layer = stack[0].f_locals['self']
if layer in layer_names:
log.pytorch_layer_name = layer_names[layer]
print('984', layer_names[layer])
break
out = self.obj(self.raw, *args, **kwargs)
return out

在添加操作时,要使用Rp类替换操作:

torch.argmax = Rp(torch.argmax, torch_argmax)

接下来,要具体实现该操作:

def torch_argmax(raw, input, dim=1):
x = raw(input, dim=dim)
layer_name = log.add_layer(name='argmax')
top_blobs = log.add_blobs([x], name='argmax_blob'.format(type))
layer = caffe_net.Layer_param(name=layer_name, type='ArgMax',
bottom=[log.blobs(input)], top=top_blobs)
layer.argmax_param(dim=dim)
log.cnet.add_layer(layer)
return x

即实现了argmax操作Pytorch到Caffe的转换。

Caffe中无直接对应操作

如果要转换的操作在Caffe中无直接对应的层实现,解决思路主要有两个:

1)在Pytorch中将不支持的操作分解为支持的操作:

如nn.InstanceNorm2d,实例归一化在转换时是用BatchNorm做的,不支持 affine=True 或者track_running_stats=True,默认use_global_stats:false,但om转换时use_global_stats必须为true,所以可以转到Caffe,但再转om不友好。

InstanceNorm是在featuremap的每个Channel上进行归一化操作,因此,可以实现nn.InstanceNorm2d为:

class InstanceNormalization(nn.Module):
def __init__(self, dim, eps=1e-5):
super(InstanceNormalization, self).__init__()
self.gamma = nn.Parameter(torch.FloatTensor(dim))
self.beta = nn.Parameter(torch.FloatTensor(dim))
self.eps = eps
self._reset_parameters()

def _reset_parameters(self):
self.gamma.data.uniform_()
self.beta.data.zero_()

def __call__(self, x):
n = x.size(2) * x.size(3)
t = x.view(x.size(0), x.size(1), n)
mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x)
var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x)
gamma_broadcast = self.gamma.unsqueeze(1).unsqueeze(1).unsqueeze(0).expand_as(x)
beta_broadcast = self.beta.unsqueeze(1).unsqueeze(1).unsqueeze(0).expand_as(x)
out = (x - mean) / torch.sqrt(var + self.eps)
out = out * gamma_broadcast + beta_broadcast
return out

但在验证HiLens Caffe算子边界中发现,om模型转换不支持Channle维度之外的求和或求均值操作,为了规避这个操作,我们可以通过支持的算子重新实现nn.InstanceNorm2d:

class InstanceNormalization(nn.Module):
def __init__(self, dim, eps=1e-5):
super(InstanceNormalization, self).__init__()
self.gamma = torch.FloatTensor(dim)
self.beta = torch.FloatTensor(dim)
self.eps = eps
self.adavg = nn.AdaptiveAvgPool2d(1)

def forward(self, x):
n, c, h, w = x.shape
mean = nn.Upsample(scale_factor=h)(self.adavg(x))
var = nn.Upsample(scale_factor=h)(self.adavg((x - mean).pow(2)))
gamma_broadcast = self.gamma.unsqueeze(1).unsqueeze(1).unsqueeze(0).expand_as(x)
beta_broadcast = self.beta.unsqueeze(1).unsqueeze(1).unsqueeze(0).expand_as(x)
out = (x - mean) / torch.sqrt(var + self.eps)
out = out * gamma_broadcast + beta_broadcast
return out

经过验证,与原操作等价,可以转为Caffe模型

2)在Caffe中通过利用现有操作实现:

在Pytorch转Caffe的过程中发现,如果存在featuremap + 6这种涉及到常数的操作,转换过程中会出现找不到blob的问题。我们首先查看pytorch_to_caffe.py中add操作的具体转换方法:

def _add(input, *args):
x = raw__add__(input, *args)
if not NET_INITTED:
return x
layer_name = log.add_layer(name='add')
top_blobs = log.add_blobs([x], name='add_blob')
if log.blobs(args[0]) == None:
log.add_blobs([args[0]], name='extra_blob')
else:
layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
bottom=[log.blobs(input),log.blobs(args[0])], top=top_blobs)
layer.param.eltwise_param.operation = 1 # sum is 1
log.cnet.add_layer(layer)
return x

可以看到对于blob不存在的情况进行了判断,我们只需要在log.blobs(args[0]) == None条件下进行修改,一个自然的想法是利用Scale层实现add操作:

def _add(input, *args):
x = raw__add__(input, *args)
if not NET_INITTED:
return x
layer_name = log.add_layer(name='add')
top_blobs = log.add_blobs([x], name='add_blob')
if log.blobs(args[0]) == None:
layer = caffe_net.Layer_param(name=layer_name, type='Scale',
bottom=[log.blobs(input)], top=top_blobs)
layer.param.scale_param.bias_term = True
weight = torch.ones((input.shape[1]))
bias = torch.tensor(args[0]).squeeze().expand_as(weight)
layer.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy())
log.cnet.add_layer(layer)
else:
layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
layer.param.eltwise_param.operation = 1 # sum is 1
log.cnet.add_layer(layer)
return x

类似的,featuremap * 6这种简单乘法也可以通过同样的方法实现。

踩过的坑

  • Pooling:Pytorch默认 ceil_mode=false,Caffe默认 ceil_mode=true,可能会导致维度变化,如果出现尺寸不匹配的问题可以检查一下Pooling参数是否正确。另外,虽然文档上没有看到,但是 kernel_size > 32 后模型虽然可以转换,但推理会报错,这时可以分两层进行Pooling操作。
  • Upsample :om边界算子中的Upsample 层scale_factor参数必须是int,不能是size。如果已有模型参数为size也会正常跑完Pytorch转Caffe的流程,但此时Upsample参数是空的。参数为size的情况可以考虑转为scale_factor或用Deconvolution来实现。
  • Transpose2d:Pytorch中 output_padding 参数会加在输出的大小上,但Caffe不会,输出特征图相对会变小,此时反卷积之后的featuremap会变大一点,可以通过Crop层进行裁剪,使其大小与Pytorch对应层一致。另外,om中反卷积推理速度较慢,最好是不要使用,可以用Upsample+Convolution替代。
  • Pad:Pytorch中Pad操作很多样,但Caffe中只能进行H与W维度上的对称pad,如果Pytorch网络中有h = F.pad(x, (1, 2, 1, 2), "constant", 0)这种不对称的pad操作,解决思路为:
  1. 如果不对称pad的层不存在后续的维度不匹配的问题,可以先判断一下pad对结果的影响,一些任务受pad的影响很小,那么就不需要修改。
  2. 如果存在维度不匹配的问题,可以考虑按照较大的参数充分pad之后进行Crop,或是将前后两个(0, 0, 1, 1)与(1, 1, 0, 0)的pad合为一个(1, 1, 1, 1),这要看具体的网络结构确定。
  3. 如果是Channel维度上的pad如F.pad(x, (0, 0, 0, 0, 0, channel_pad), "constant", 0),可以考虑零卷积后cat到featuremap上:
zero = nn.Conv2d(in_channels, self.channel_pad, kernel_size=3, padding=1, bias=False)
nn.init.constant(self.zero.weight, 0)
pad_tensor = zero(x)
x = torch.cat([x, pad_tensor], dim=1)
  • 一些操作可以转到Caffe,但om并不支持标准Caffe的所有操作,如果要再转到om要对照文档确认好边界算子。

本文分享自华为云社区《Pytorch->Caffe模型转换》,原文作者:杜甫盖房子 。

点击关注,第一时间了解华为云新鲜技术~

一文带你熟悉Pytorch->Caffe->om模型转换流程的更多相关文章

  1. 一文带你熟悉JAVA IO这个看似很高冷的菇凉

    Java IO 是一个庞大的知识体系,很多人学着学着就会学懵了,包括我在内也是如此,所以本文将会从 Java 的 BIO 开始,一步一步深入学习,引出 JDK1.4 之后出现的 NIO 技术,对比 N ...

  2. 一文带你熟悉SpringIOC

    Spring的IOC: IOC是Spring的一个核心组件,理解IOC是迈向Spring大门的重要一步 现实生活中,我们写字用的笔会有多种颜色,为了做不同的标记,需要用不同颜色的笔.如果只是使用一两种 ...

  3. 数据可视化之powerBI基础(七)一文带你熟悉PowerBI建模视图中的功能

    https://zhuanlan.zhihu.com/p/67316729 PowerBI 3月的更新,正式发布了建模视图,而之前只是预览功能.新的建模视图到底有什么用,下面带你认识一下它的主要功能. ...

  4. Istio是啥?一文带你彻底了解!

    原标题:Istio是啥?一文带你彻底了解! " 如果你比较关注新兴技术的话,那么很可能在不同的地方听说过 Istio,并且知道它和 Service Mesh 有着牵扯. 这篇文章可以作为了解 ...

  5. 【转帖】Istio是啥?一文带你彻底了解!

    Istio是啥?一文带你彻底了解! http://www.sohu.com/a/270131876_463994 原始位置来源: https://cizixs.com 如果你比较关注新兴技术的话,那么 ...

  6. 【项目实践】一文带你搞定Spring Security + JWT

    以项目驱动学习,以实践检验真知 前言 关于认证和授权,R之前已经写了两篇文章: [项目实践]在用安全框架前,我想先让你手撸一个登陆认证 [项目实践]一文带你搞定页面权限.按钮权限以及数据权限 在这两篇 ...

  7. 一文带你读懂什么是vxlan网络

    一个执着于技术的公众号 一.背景 随着云计算.虚拟化相关技术的发展,传统网络无法满足大规模.灵活性要求高的云数据中心的要求,于是便有了overlay网络的概念.overlay网络中被广泛应用的就是vx ...

  8. 一文带你读懂zookeeper在大数据生态的应用

    一个执着于技术的公众号 一.简述 在一群动物掌管的世界中,动物没有人类聪明的思想,为了保持动物世界的生态平衡,这时,动物管理员-zookeeper诞生了. 打开Apache zookeeper的官网, ...

  9. 一文带您了解5G的价值与应用

    一文带您了解5G的价值与应用 5G最有趣的一点是:大多数产品都是先有明确应用场景而后千呼万唤始出来.而5G则不同,即将到来的5G不仅再一次印证了科学技术是第一生产力还给不少用户带来了迷茫——我们为什么 ...

  10. 一文带你了解elasticsearch

    一文带你了解elasticsearch cxf2102100人评论160人阅读2019-07-02 21:31:36   elasticsearch es基本概念 es术语介绍 文档Document ...

随机推荐

  1. JDK21的虚拟线程是什么?和平台线程什么关系?

    虚拟线程(Virtual Thread)是 JDK 而不是 OS 实现的轻量级线程(Lightweight Process,LWP),由 JVM 调度.许多虚拟线程共享同一个操作系统线程,虚拟线程的数 ...

  2. ELK+ filebeat

    ELK 企业级日志分析系统 ELK 概述 1.ELK 简介 ELK平台是一套完整的日志集中处理解决方案,将 ElasticSearch.Logstash 和 Kiabana 三个开源工具配合使用, 完 ...

  3. 【Unity3D】UI Toolkit容器

    1 前言 ​ UI Toolkit简介 中介绍了 UI Builder.样式属性.UQuery.Debugger,UI Toolkit元素 中介绍了 Label.Button.TextField.To ...

  4. 一款简单漂亮的WPF UI - AduSkin

    前言 经常会有同学会问,有没有好看简单的WPF UI库推荐的.今天就给大家推荐一款简单漂亮的WPF UI,融合多个开源框架组件:AduSkin. WPF是什么? WPF 是一个强大的桌面应用程序框架, ...

  5. Chromium Command Buffer原理解析

    Command Buffer 是支撑 Chromium 多进程硬件加速渲染的核心技术之一.它基于 OpenGLES2.0 定义了一套序列化协议,这套协议规定了所有 OpenGLES2.0 命令的序列化 ...

  6. QT线程问题

    QT线程问题 (一)QThread (二)QMutex和QMutexLocker (end)后面会更新 (一)QThread 文章 (二)QMutex和QMutexLocker 通俗理解 QMutex ...

  7. QLabel自己总结(常用接口)

    继承关系:QLabel->QFrame->QWidget void setText(const QString &) [slots] 设置标签显示的内容,也可以使用构造函数设置. ...

  8. JVM Stack and Frame

    Overview Sharing a single thread within the district: PC Register/JVM Stack/Native Method Stack.All ...

  9. Jail 【Python沙箱逃逸问题合集】

    借助NSS平台题目,以2022年HNCTF为例展开分析 背景: 由于目前很多赛事有时候会出现一些pyjail的题目,因此在这里总结一下以便以后遇见可以轻松应对. 注:由于Python3中的unicod ...

  10. 微服务系列-使用WebFlux的WebClient进行Spring Boot 微服务通信示例

    公众号「架构成长指南」,专注于生产实践.云原生.分布式系统.大数据技术分享. 概述 在之前的教程中,我们看到了使用 RestTemplate 的 Spring Boot 微服务通信示例. 从 5.0 ...