#mxnet# 权值共享
https://www.cnblogs.com/chenyliang/p/6847744.html
Note:后记
此权值共享非彼卷积共享。说的是layer实体间的参数共享。
Introduction
想将两幅图像”同时“经过同一模型,似乎之前有些听闻的shared model没有找到确凿的痕迹,单个构建Variable然后每层设置,对debug阶段(甚至使用阶段)来说是场噩梦。能够可行的只想到了,在set_params阶段进行指定,如果简单的将两个load的symbol进行Group,然后进行bind会提示出现多个名称。于是问题就是:如何生成同一结构内含指定符号名的symbol?
Exploration
此类非标准操作,更别指望mxnet的doc了,只有从dir()和src查起。
Change the name
首先想到的自然是改名:
本来是
a=mx.sym.Variable('x')
要改成与
a=mx.sym.Variable('y')
相同的效果。
关于名称的接口:
import mxnet as mx
d=mx.sym.Variable('data')
conv1_w=mx.sym.Variable('kw')
conv1=mx.sym.Convolution(data=d,weight=conv1_w,kernel=(3,3),num_filter=num_filter,no_bias=True,name='conv1')
conv1.name
#'conv1'
How to change it
怎么改呢?看起来只有*_set_attr*靠谱些,先看看都有那些属性:
conv1.list_attr()
#{'no_bias': 'True', 'kernel': '(3, 3)', 'num_filter': '1'}
。。。并没有什么好结果出现,看起来还有一个接口:
conv1.attr_dict()
#{'conv1': {'no_bias': 'True', 'kernel': '(3, 3)', 'num_filter': '1'}}
那就试试,'conv1'?
>>>conv1._set_attr(conv1='yy')
>>>conv1.name # 有戏?!赶紧看看
'conv1' # 那刚才的是什么?
>>> conv1.list_attr()
{'no_bias': 'True', 'kernel': '(3, 3)', 'conv1': 'yy', 'num_filter': '1'} # 呵呵,被骗了...
Check the Src
来看看名字是到哪取的(~当然是家里取的...)
# python/mxnet/symbol.py
@property
def name(self):
ret = ctypes.c_char_p()
success = ctypes.c_int()
check_call(_LIB.MXSymbolGetName(
self.handle, ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
return py_str(ret.value)
else:
return None
于是追寻MXSymbolGetName,虽然直觉告诉我很有可能不会有python接口了(很有可能是通过底层实现的名字获取),但还是得看看。
//src/c_api/c_api_symbolic.cc
int MXSymbolGetName(SymbolHandle symbol,
const char** out,
int* success) {
return NNSymbolGetAttr(symbol, "name", out, success);
}
这不禁让人浮想起来。。。赶紧试试:
>>> conv1._set_attr(name='yy')
>>> conv1.name
'yy'
被我发现了吧 :)
失败
失败的原因是,上述的操作只改变了node,但参数的名称并没有改变(可以.list_arguments()进行查看)。我当时想的是将参数名称保持相同,然后在set_params的时候就可以直接调用,然而实际调用时,会报错,提示检测出了多个相同的名称,所以此路基本封死。
从json入手
这是一个当时认为最惨的办法——每次都要先对文件进行操作(非常粗野)。但今早发现symbol中还有操作json的接口(当然说的不是save,laod之类的):
sn_epoch_load=0
model_prefix='nin'
sym1, arg_params, aux_params = mx.mod.module.load_checkpoint(model_prefix, n_epoch_load)
sym=sym1.get_internals()['conv4_1024_output'].__copy__()
ss=sym.__getstate__()['handle']
ss1=ss.replace('\"name\": \"','\"name\": \"sha-')
sym2 = sym.__copy__()
h={'handle':ss1}
sym2.__setstate__(h)
>>> sym2.list_arguments()
['sha-data', 'sha-conv1_weight', 'sha-conv1_bias', 'sha-cccp1_weight', 'sha-cccp1_bias', 'sha-cccp2_weight', 'sha-cccp2_bias', 'sha-conv2_weight', 'sha-conv2_bias', 'sha-cccp3_weight', 'sha-cccp3_bias', 'sha-cccp4_weight', 'sha-cccp4_bias', 'sha-conv3_weight', 'sha-conv3_bias', 'sha-cccp5_weight', 'sha-cccp5_bias', 'sha-cccp6_weight', 'sha-cccp6_bias', 'sha-conv4_1024_weight', 'sha-conv4_1024_bias']
>>> sym2.attr_dict()
{'sha-cccp3': {'no_bias': 'False', 'kernel': '(1,1)', 'num_group': '1', 'dilate': '(1,1)', 'num_filter': '256', 'stride': '(1,1)', 'cudnn_off': 'False', 'pad': '(0,0)', 'workspace': '1024', 'cudnn_tune': 'off'}, 'sha-cccp2': {'no_bias': 'False', 'kernel': '(1,1)', 'num_group': '1', 'dilate': '(1,1)', 'num_filter': '96', 'stride': '(1,1)', 'cudnn_off': 'False', 'pad': '(0,0)', 'workspace': '1024', 'cudnn_tune': 'off'}, 'sha-drop': {'p': '0.5'}, 'sha-conv2': {'no_bias': 'False', 'kernel': '(5,5)', 'num_group': '1', 'dilate': '(1,1)', 'num_filter': '256', 'stri
# 示意一下就可
这样看上去问题被解决了。
Solution
于是我们的答案就是:
import mxnet as mx
M,N=3,3
num_filter=1
kernel=mx.nd.array([ [1,2,3],[1,2,3],[1,2,3] ])
d=mx.sym.Variable('data')
conv1=mx.sym.Convolution(data=d,kernel=(3,3),num_filter=num_filter,no_bias=True,name='conv1')
loss=mx.sym.MakeLoss(data=conv1)
bch_kernel=kernel.reshape((1,1,M,N))
arg_params={'conv1_weight': bch_kernel}
def shareParams(sym,params):
sym1 = sym.__copy__()
new_params= {}
ss=sym1.__getstate__()['handle']
ss1=ss.replace('\"name\": \"','\"name\": \"sha-')
h={'handle':ss1}
sym1.__setstate__(h)
for i in params:
new_params['sha-'+i] = params[i]
new_params[i] = params[i]
return mx.sym.Group([sym,sym1]),new_params
sym,params = shareParams(loss,arg_params)
mod=mx.mod.Module(symbol=sym,data_names=('data','sha-data',))
mod.bind(data_shapes=[ ('data',[1,1,M,N]), ('sha-data',[1,1,M,N]),])
mod.init_params()
mod.set_params(arg_params=params, aux_params=[],allow_missing=True)
mod.init_optimizer()
mod.forward(mx.io.DataBatch([bch_kernel,bch_kernel],[]))
mod.get_outputs()[0].asnumpy()
#array([[[[ 42.]]]], dtype=float32)
mod.get_outputs()[1].asnumpy()
#array([[[[ 42.]]]], dtype=float32)
mod.backward()
mod.update()
mod.forward(mx.io.DataBatch([bch_kernel,bch_kernel],[]))
mod.get_outputs()[0].asnumpy()
#array([[[[ 41.57999802]]]], dtype=float32)
mod.get_outputs()[1].asnumpy()
#array([[[[ 41.57999802]]]], dtype=float32)
搞定 :)
22 Jul, 2017 记
关于这个问题,我后面还曾设想找段空闲时期,试着用mxnet内部机制进行封装。最近发现,自己也是傻得可以。。。
两张图先进行batch维的拼接,通过所需段后再拆分 (⊙﹏⊙)b
#mxnet# 权值共享的更多相关文章
- CNN中的局部连接(Sparse Connectivity)和权值共享
局部连接与权值共享 下图是一个很经典的图示,左边是全连接,右边是局部连接. 对于一个1000 × 1000的输入图像而言,如果下一个隐藏层的神经元数目为10^6个,采用全连接则有1000 × 1000 ...
- tensorflow-参数、超参数、卷积核权值共享
根据网上查询到的说法,参数就是在卷积神经网络中可以被训练的参数,比如卷积核的权值和偏移等等,而超参数是一些预先设定好并且无法改变的,比如说是卷积核的个数等. 另外还有一个最最基础的概念,就是卷积核的权 ...
- CARS: 华为提出基于进化算法和权值共享的神经网络结构搜索,CIFAR-10上仅需单卡半天 | CVPR 2020
为了优化进化算法在神经网络结构搜索时候选网络训练过长的问题,参考ENAS和NSGA-III,论文提出连续进化结构搜索方法(continuous evolution architecture searc ...
- 神经网络权值初始化方法-Xavier
https://blog.csdn.net/u011534057/article/details/51673458 https://blog.csdn.net/qq_34784753/article/ ...
- 51nod1459(带权值的dijkstra)
题目链接:https://www.51nod.com/onlineJudge/questionCode.html#!problemId=1459 题意:中文题诶- 思路:带权值的最短路,这道题数据也没 ...
- caffe中权值初始化方法
首先说明:在caffe/include/caffe中的 filer.hpp文件中有它的源文件,如果想看,可以看看哦,反正我是不想看,代码细节吧,现在不想知道太多,有个宏观的idea就可以啦,如果想看代 ...
- [NOIP2014]联合权值 题解
题目大意: 有一棵树,求距离为2的点权的乘积的和以及最大值. 思路: 枚举每一个点,则与其相邻的点互为距离为2的点.该部分的最大值为点权最大的两个点的积,和为点的权值和的平方减去每个点的平方,这样每条 ...
- Codevs 3728 联合权值
问题描述 无向连通图G有n个点,n-1条边.点从1到n依次编号,编号为i的点的权值为Wi ,每 条边的长度均为1.图上两点(u,v)的距离定义为u点到v点的最短距离.对于图G上的点 对(u,v),若它 ...
- css权值计算
外部样式表<内部样式表<内联样式: HTML 标签选择器的权值为 1: Class 类选择器的权值为 10: ID 选择器的权值为 100: 内联样式表的权值最高 1000: !impor ...
随机推荐
- python认知及六大标准数据类型
--- typora-root-url: assets --- ### -python的认知 ``` 89年开发的语言,创始人范罗苏姆(Guido van Rossum),别称:龟叔(Guido). ...
- [vue]基础篇stepbystep案例实践(废弃)
去看这个就好了 总结: 1.子组件可以触发父组件的方法,this.$emit() //(通知父组件干活) 2.父组件可以调用子组件的方法() // ref 如果放在组件上 获取的是组件的实例 并不是组 ...
- nginx反向代理 支持WebSocket
WebSocket(简称WS)协议的握手和HTTP是兼容的,通过HTTP/1.1中协议转换机制,客户端可以传递名为“Upgrade” 头部信息将连接从HTTP连接升级到WebSocket连接 那么反向 ...
- Shiro权限管理框架详解
1 权限管理1.1 什么是权限管理 基本上涉及到用户参与的系统都要进行权限管理,权限管理属于系统安全的范畴,权限管理实现对用户访问系统的控制,按照安全规则或者安全策略控制用户可以访问而且只能访问自己被 ...
- [LeetCode] 607. Sales Person_Easy tag: SQL
Description Given three tables: salesperson, company, orders.Output all the names in the table sales ...
- Hybrid设计--账号体系的建设
前后端分离:开发效率高,没有SEO 现在是重客户端设计:交互和业务逻辑是前端来写,适合做前后端分离.对前端更友好,提高了效率. 传统模式开发:整个业务逻辑是server端写,不适合做前后端分离.ser ...
- webpack使用五
一切皆模块 Webpack有一个不可不说的优点,它把所有的文件都都当做模块处理,JavaScript代码,CSS和fonts以及图片等等通过合适的loader都可以被处理. CSS webpack提供 ...
- SQL中的replace函数
REPLACE 用第三个表达式替换第一个字符串表达式中出现的所有第二个给定字符串表达式. 语法 REPLACE ( 'string_expression1' , 'string_expression2 ...
- python中安装并使用redis
数据缓存系统:1:mongodb:是直接持久化,直接存储于硬盘的缓存系统2:redis: 半持久化,存储于内存和硬盘3:memcache:数据只能存储在内存里的缓存系统 redis是一个key-val ...
- 深入解析Java反射(1) - 基础
深入解析Java反射(1) - 基础 最近正筹备Samsara框架的开发,而其中的IOC部分非常依靠反射,因此趁这个机会来总结一下关于Java反射的一些知识.本篇为基本篇,基于JDK 1.8. 一.回 ...