mxnet下如何查看中间结果
https://blog.csdn.net/disen10/article/details/79376631
固定权重:https://www.cnblogs.com/chenyliang/p/6780019.html
固定权重:https://discuss.gluon.ai/t/topic/1164
查看权重
在训练过程中,有时候我们为了debug而需要查看中间某一步的权重信息,在mxnet中,我们可以很方便的调用get_params()方法来得到权重信息。
- '''
- 查看权重示例代码
- 转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
- '''
- import mxnet as mx
- sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型
- mod = mx.mod.Module(symbol=sym,context=mx.gpu()) #创建Module
- mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为False
- mod.set_params(arg_params,aux_params)
- import numpy as np
- import cv2
- def get_image(filename):
- img = cv2.imread(filename)
- img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
- img = cv2.resize(img,(224,224))
- img = np.swapaxes(img,0,2)
- img = np.swapaxes(img,1,2)
- img = img[np.newaxis,:]
- return img
- from collections import namedtuple
- Batch = namedtuple('Batch',['data'])
- img = get_image('val_1000/0.jpg') #获取图片
- mod.forward(Batch([mx.nd.array(img)])) #预测结果
- ################################################
- #debug模式下,获取权重信息
- keys = mod.get_params()[0].keys() # 列出所有权重名称
- conv_w = mod.get_params()[0]['conv0_weight'] #获取想要查看的权重信息,如conv_weight
- print conv_w.asnumpy() #查看具体数值
- ################################################
- prob = mod.get_outputs()[0].asnumpy()
- y = np.argsort(np.squeeze(prob))[::-1]
- print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
查看中间输出结果
由于mxnet的网络由symbol组成,而symbol又属于符号式编程,所以我们不能像上面查看权重一样直接查看,我们需要把我们想看的输出结果保存下来。
- '''
- 方法一
- 查看中间结果代码
- 转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
- '''
- import mxnet as mx
- net = mx.symbol.Variable('data')
- fc1 = mx.symbol.FullyConnected(data=net, name='fc1', num_hidden=128)
- net = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
- net = mx.symbol.FullyConnected(data=net, name='fc2', num_hidden=64)
- out = mx.symbol.SoftmaxOutput(data=net, name='softmax')
- # 通过把两个输出组成一个group来得到自己需要查看的中间层输出结果
- group = mx.symbol.Group([fc1, out])
- print group.list_outputs()
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- '''
- 方法二
- 有时候我们使用别人的模型,所以无法像方法一一样在定义模型的时候就确定需要查看的中间层输出结果,
- 这时候我们使用get_internals()方法来查找自己需要查看的中间层
- 转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
- '''
- import mxnet as mx
- sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型
- ########################################################################
- args = sym.get_internals().list_outputs() #获得所有中间输出
- internals = model.symbol.get_internals()
- fc1 = internals['fc1_output']
- conv = internals['stage4_unit3_conv1_output']
- group = mx.symbol.Group([fc1, sym, conv]) #把需要输出的结果按group方式组合起来,这样就可以得到中间层的输出
- #########################################################################
- mod = mx.mod.Module(symbol=group,context=mx.gpu()) #创建Module
- mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为False
- mod.set_params(arg_params,aux_params)
- import numpy as np
- import cv2
- def get_image(filename):
- img = cv2.imread(filename)
- img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
- img = cv2.resize(img,(224,224))
- img = np.swapaxes(img,0,2)
- img = np.swapaxes(img,1,2)
- img = img[np.newaxis,:]
- return img
- from collections import namedtuple
- Batch = namedtuple('Batch',['data'])
- img = get_image('val_1000/0.jpg') #获取图片
- mod.forward(Batch([mx.nd.array(img)])) #预测结果
- prob = mod.get_outputs()[0].asnumpy()
- y = np.argsort(np.squeeze(prob))[::-1]
- print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))
mxnet下如何查看中间结果的更多相关文章
- Linux下如何查看版本信息
Linux下如何查看版本信息, 包括位数.版本信息以及CPU内核信息.CPU具体型号等等,整个CPU信息一目了然. 1.# uname -a (Linux查看版本当前操作系统内核信息) L ...
- Linux下怎么查看当前系统的版本
Linux下怎么查看当前系统的版本: uname -r 功能说明:uname用来获取电脑和操作系统的相关信息. 语 法:uname [-amnrsvpio][--help][--version] ...
- 在windows和linux下如何查看80端口占用情况?是被哪个进程占用?如何终止等
一.在windows下如何查看80端口占用情况?是被哪个进程占用?如何终止等 这里主要是用到windows下的DOS工具,点击"开始"--"运行",输入&quo ...
- 在linux下,查看一个运行中的程序, 占用了多少内存
1. 在linux下,查看一个运行中的程序, 占用了多少内存, 一般的命令有 (1). ps aux: 其中 VSZ(或VSS)列 表示,程序占用了多少虚拟内存. RSS列 表示, 程序占用了多少物 ...
- linux下如何查看mysql、apache是否安装,并卸载
--linux下如何查看mysql.apache是否安装,并卸载? http://blog.163.com/dengxiuhua126@126/blog/static/1186077720137311 ...
- Linux 下实时查看日志
Linux 下实时查看日志 cat /var/log/*.log 如果日志在更新,如何实时查看 tail -f /var/log/messages 还可以使用 watch -d -n 1 cat /v ...
- Linux下如何查看tomcat是否安装、启动、文件路径、进程ID
Linux下如何查看tomcat是否安装.启动.文件路径.进程ID 在Linux系统下,Tomcat使用命令的操作! 检测是否有安装了Tomcat: rpm -qa|grep tomcat 查看Tom ...
- Linux下内存查看命令
在Linux下面,我们常用top命令来查看系统进程,top也能显示系统内存.我们常用的Linux下查看内容的专用工具是free命令. Linux下内存查看命令free详解: 在Linux下查看内存我们 ...
- Linux之Ubuntu下如何查看已安装的软件/库文件【摘抄】
本文属于实用性质,且属于摘抄别处,出自:[Ubuntu 下如何查看已安装的软件](http://blog.csdn.net/m1205979825/article/details/40855583) ...
随机推荐
- Oracle下SQL学习笔记
主机字符串:as sysdba alter user scott account unlock;//解锁scott,不会就谷歌检索 DML语句,增.删.查.改 select语句:熟悉表结构 desc ...
- Jmeter知识点
聚合报告说明 https://www.cnblogs.com/duanxz/p/5464993.html JMeter之Ramp-up Period(in seconds)说明(可同时并发) http ...
- Wix制作安装包
Wix制作安装包,找起资料来很费劲,记录一下: Product.wxs,该文件只能制作出msi形式的安装包,不能做到自动检测framework. <?xml version="1.0& ...
- spring + mybatis配置及网络异常设置
Spring引入mybatis <beans xmlns="http://www.springframework.org/schema/beans" xmlns:contex ...
- spring的面向切面实现的两种方式
面向切面:主要应用在日志记录方面.实现业务与日志记录分离开发. spring面向切面有两种实现方式:1.注解 2.xml配置. 1.注解实现如下: (1)配置如下: <?xml version= ...
- testNG中dataprovider使用的两种方式
testNG的参数化测试有两种方式:xml和dataprovider.个人更喜欢dataprovider,因为我喜欢把测试数据放在数据库里. 一.返回类型是Iterator<Object[]&g ...
- c# 利用MailKit.IMap 收取163邮件
最近我要做一个爬虫.这个爬虫需要如下几个步骤: 1 填写注册内容(需要邮箱注册) 2 过拖拽验证码(geetest) 3 注册成功会给邮箱发一封确认邮箱 4 点击确认邮箱中的链接 完成注册 我这里就采 ...
- 编译用到boost相关的东西,问题的解决;以及和googletest库
编译https://github.com/RAttab/reflect, 发现需要gcc4.7以上的版本才行.于是编译安装最新的gcc-6.2.0, 过程算顺利. http://www.linuxfr ...
- for in //for of //forEach //map三种对比
遍历Array可以采用下标循环,遍历Map和Set就无法使用下标.为了统一集合类型,ES6标准引入了新的iterable类型,Array.Map和Set都属于iterable类型. 具有iterabl ...
- LINQ以及LINQ to Object 和LINQ to Entities
LINQ的全称是Language Integrated Query,中文译成“语言集成查询”,是一种查询技术. LINQ查询通过提供一种跨各种数据源和数据格式使用数据的一致模型,简化了查询过程.LIN ...