https://blog.csdn.net/disen10/article/details/79376631

固定权重:https://www.cnblogs.com/chenyliang/p/6780019.html

固定权重:https://discuss.gluon.ai/t/topic/1164

查看权重

在训练过程中,有时候我们为了debug而需要查看中间某一步的权重信息,在mxnet中,我们可以很方便的调用get_params()方法来得到权重信息。

  1.  
    '''
  2.  
    查看权重示例代码
  3.  
    转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
  4.  
    '''
  5.  
    import mxnet as mx
  6.  
    sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型
  7.  
    mod = mx.mod.Module(symbol=sym,context=mx.gpu()) #创建Module
  8.  
    mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为False
  9.  
    mod.set_params(arg_params,aux_params)
  10.  
    import numpy as np
  11.  
    import cv2
  12.  
    def get_image(filename):
  13.  
    img = cv2.imread(filename)
  14.  
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
  15.  
    img = cv2.resize(img,(224,224))
  16.  
    img = np.swapaxes(img,0,2)
  17.  
    img = np.swapaxes(img,1,2)
  18.  
    img = img[np.newaxis,:]
  19.  
    return img
  20.  
    from collections import namedtuple
  21.  
    Batch = namedtuple('Batch',['data'])
  22.  
    img = get_image('val_1000/0.jpg') #获取图片
  23.  
    mod.forward(Batch([mx.nd.array(img)])) #预测结果
  24.  
    ################################################
  25.  
    #debug模式下,获取权重信息
  26.  
    keys = mod.get_params()[0].keys() # 列出所有权重名称
  27.  
    conv_w = mod.get_params()[0]['conv0_weight'] #获取想要查看的权重信息,如conv_weight
  28.  
    print conv_w.asnumpy() #查看具体数值
  29.  
    ################################################
  30.  
    prob = mod.get_outputs()[0].asnumpy()
  31.  
    y = np.argsort(np.squeeze(prob))[::-1]
  32.  
    print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

查看中间输出结果

由于mxnet的网络由symbol组成,而symbol又属于符号式编程,所以我们不能像上面查看权重一样直接查看,我们需要把我们想看的输出结果保存下来。

  1.  
    '''
  2.  
    方法一
  3.  
    查看中间结果代码
  4.  
    转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
  5.  
    '''
  6.  
    import mxnet as mx
  7.  
    net = mx.symbol.Variable('data')
  8.  
    fc1 = mx.symbol.FullyConnected(data=net, name='fc1', num_hidden=128)
  9.  
    net = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
  10.  
    net = mx.symbol.FullyConnected(data=net, name='fc2', num_hidden=64)
  11.  
    out = mx.symbol.SoftmaxOutput(data=net, name='softmax')
  12.  
    # 通过把两个输出组成一个group来得到自己需要查看的中间层输出结果
  13.  
    group = mx.symbol.Group([fc1, out])
  14.  
    print group.list_outputs()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  1.  
    '''
  2.  
    方法二
  3.  
    有时候我们使用别人的模型,所以无法像方法一一样在定义模型的时候就确定需要查看的中间层输出结果,
  4.  
    这时候我们使用get_internals()方法来查找自己需要查看的中间层
  5.  
    转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
  6.  
    '''
  7.  
    import mxnet as mx
  8.  
    sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型
  9.  
    ########################################################################
  10.  
    args = sym.get_internals().list_outputs() #获得所有中间输出
  11.  
    internals = model.symbol.get_internals()
  12.  
    fc1 = internals['fc1_output']
  13.  
    conv = internals['stage4_unit3_conv1_output']
  14.  
    group = mx.symbol.Group([fc1, sym, conv]) #把需要输出的结果按group方式组合起来,这样就可以得到中间层的输出
  15.  
    #########################################################################
  16.  
    mod = mx.mod.Module(symbol=group,context=mx.gpu()) #创建Module
  17.  
    mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为False
  18.  
    mod.set_params(arg_params,aux_params)
  19.  
    import numpy as np
  20.  
    import cv2
  21.  
    def get_image(filename):
  22.  
    img = cv2.imread(filename)
  23.  
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
  24.  
    img = cv2.resize(img,(224,224))
  25.  
    img = np.swapaxes(img,0,2)
  26.  
    img = np.swapaxes(img,1,2)
  27.  
    img = img[np.newaxis,:]
  28.  
    return img
  29.  
    from collections import namedtuple
  30.  
    Batch = namedtuple('Batch',['data'])
  31.  
    img = get_image('val_1000/0.jpg') #获取图片
  32.  
    mod.forward(Batch([mx.nd.array(img)])) #预测结果
  33.  
    prob = mod.get_outputs()[0].asnumpy()
  34.  
    y = np.argsort(np.squeeze(prob))[::-1]
  35.  
    print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))

mxnet下如何查看中间结果的更多相关文章

  1. Linux下如何查看版本信息

    Linux下如何查看版本信息, 包括位数.版本信息以及CPU内核信息.CPU具体型号等等,整个CPU信息一目了然.   1.# uname -a   (Linux查看版本当前操作系统内核信息)   L ...

  2. Linux下怎么查看当前系统的版本

    Linux下怎么查看当前系统的版本:   uname -r 功能说明:uname用来获取电脑和操作系统的相关信息. 语 法:uname [-amnrsvpio][--help][--version] ...

  3. 在windows和linux下如何查看80端口占用情况?是被哪个进程占用?如何终止等

    一.在windows下如何查看80端口占用情况?是被哪个进程占用?如何终止等 这里主要是用到windows下的DOS工具,点击"开始"--"运行",输入&quo ...

  4. 在linux下,查看一个运行中的程序, 占用了多少内存

    1. 在linux下,查看一个运行中的程序, 占用了多少内存, 一般的命令有 (1). ps aux: 其中  VSZ(或VSS)列 表示,程序占用了多少虚拟内存. RSS列 表示, 程序占用了多少物 ...

  5. linux下如何查看mysql、apache是否安装,并卸载

    --linux下如何查看mysql.apache是否安装,并卸载? http://blog.163.com/dengxiuhua126@126/blog/static/1186077720137311 ...

  6. Linux 下实时查看日志

    Linux 下实时查看日志 cat /var/log/*.log 如果日志在更新,如何实时查看 tail -f /var/log/messages 还可以使用 watch -d -n 1 cat /v ...

  7. Linux下如何查看tomcat是否安装、启动、文件路径、进程ID

    Linux下如何查看tomcat是否安装.启动.文件路径.进程ID 在Linux系统下,Tomcat使用命令的操作! 检测是否有安装了Tomcat: rpm -qa|grep tomcat 查看Tom ...

  8. Linux下内存查看命令

    在Linux下面,我们常用top命令来查看系统进程,top也能显示系统内存.我们常用的Linux下查看内容的专用工具是free命令. Linux下内存查看命令free详解: 在Linux下查看内存我们 ...

  9. Linux之Ubuntu下如何查看已安装的软件/库文件【摘抄】

    本文属于实用性质,且属于摘抄别处,出自:[Ubuntu 下如何查看已安装的软件](http://blog.csdn.net/m1205979825/article/details/40855583) ...

随机推荐

  1. vue在页面嵌入别的页面或者是视频2

    vue在页面嵌入别的页面或者是视频 以下是嵌入页面 <iframe name="myiframe" id="myrame" src="http: ...

  2. 深入理解Lua的闭包一:概念、应用和实现原理

    本文首先通过具体的例子讲解了Lua中闭包的概念,然后总结了闭包的应用场合,最后探讨了Lua中闭包的实现原理.   闭包的概念 在Lua中,闭包(closure)是由一个函数和该函数会访问到的非局部变量 ...

  3. MYSQL: set names utf8是什么意思?

    set names utf8 是用于设置编码,可以再在建数据库的时候设置,也可以在创建表的时候设置,或只是对部分字段进行设置,而且在设置编码的时候,这些地方最好是一致的,这样能最大程度上避免数据记录出 ...

  4. MySQL 基础 DDL和DML

    DDL 数据库定义语句 创建数据库 create table if exits 数据库.表名( field1 数据类型 约束类型 commit 字段注释, field2 数据类型 约束类型 commi ...

  5. inner join, left join, right join 和 full join

    inner join:理解为“有效连接”,两张表中都有的数据才会显示left join:理解为“有左显示”,比如on a.field=b.field,则显示a表中存在的全部数据及a.b中都有的数据,a ...

  6. 对k8s service的一些理解

    服务service service是一个抽象概念,定义了一个服务的多个pod逻辑合集和访问pod的策略,一般把service称为微服务 举个例子一个a服务运行3个pod,b服务怎么访问a服务的pod, ...

  7. 虚拟IP技术

    虚拟IP技术在高可用领域像数据库SQLSERVER.web服务器等场景下使用很多,很疑惑它是怎么实现的,偶然,发现了一种方式可以实现虚拟ip.它的原理在于同一个物理网卡,是可以拥有多个ip地址的,至于 ...

  8. js中两个!!的理解

    在js中经常有两个!!出现,经常让人难以理解 (function () { var a = 10; var b = 20; function add(num1, num2) { var num1 = ...

  9. Web API 入门 二 媒体类型

    还是拿上面 那篇 Web API 入门 一  的那个来讲 在product类中加一个时间属性

  10. 记录一则FGA审计“A用户对B用户某张表的更新操作”需求

    环境:Oracle 11.2.0.4 我这里测试A用户为JINGYU,要审计的表为B用户SCOTT下的EMP表.通过FGA来实现. 1.添加审计策略 2.测试审计效果 3.控制审计策略 1.添加审计策 ...