torch 深度学习 (2)

torch
ConvNet

前面我们完成了数据的下载和预处理,接下来就该搭建网络模型了,CNN网络的东西可以参考博主 zouxy09的系列文章Deep Learning (深度学习) 学习笔记整理系列之 (七)

  1. 加载包

  1. require 'torch' 

  2. require 'image' 

  3. require 'nn' 

  1. 函数运行参数的设置

  1. if not opt then 

  2. print "==> processing options" 

  3. cmd = torch.CmdLine() 

  4. cmd:text() 

  5. cmd:text('options:') 

  6. -- 选择构建何种结构:线性|MLP|ConvNet。默认:convnet 

  7. cmd:option('-model','convnet','type of model to construct: linear | mlp | convnet') 

  8. -- 是否需要可视化 

  9. cmd:option('-visualize',true,'visualize input data and weights during training') 

  10. -- 参数 

  11. opt = cmd:parse(arg or {}) 

  12. end 

  1. 设置网络模型用到的一些参数

  1. -- 输出类别数,也就是输出节点个数 

  2. noutputs =10 

  3. -- 输入节点的个数 

  4. nfeats = 3 -- YUV三个通道,可以认为是3个features map 

  5. width =32 

  6. height =32 

  7. -- Linear 和 mlp model下的输入节点个数,就是将输入图像拉成列向量 

  8. ninputs = nfeats*width*height 


  9. -- 为mlp定义隐层节点的个数 

  10. nhiddens = ninputs/2  


  11. -- 为convnet定义隐层feature maps的个数以及滤波器的尺寸 

  12. nstates = {16,256,128} --第一个隐层有16个feature map,第二个隐层有256个特征图,第三个隐层有128个节点 

  13. fanin = {1,4} -- 定义了卷积层的输入和输出对应关系,以fanin[2]举例,表示该卷积层有16个map输入,256个map输出,每个输出map是有fanin[2]个输入map对应filters卷积得到的结果 

  14. filtsize =5 --滤波器的大小,方形滤波器 

  15. poolsize = 2 -- 池化池尺寸 

  16. normkernel = image.gaussian1D(7) --长度为7的一维高斯模板,用来local contrast normalization 

  1. 构建模型

  1. if opt.model == linear then  

  2. -- 线性模型 

  3. model = nn.Sequntial() 

  4. model:add(nn.Reshape(ninputs)) -- 输入层 

  5. model:add(nn.Linear(ninputs,noutputs)) -- 线性模型 y=Wx+b 

  6. elseif opt.model == mlp then  

  7. -- 多层感知器 

  8. model = nn.Sequential() 

  9. model:add(nn.Reshape(ninputs)) --输入层 

  10. model:add(nn.Linear(ninputs,nhiddens)) --线性层 

  11. model:add(nn.Tanh()) -- 非线性层 

  12. model:add(nn.Linear(nhiddens,noutputs)) -- 线性层 

  13. -- MLP 目标: `!$y=W_2 f(W_1X+b) + b $` 这里的激活函数采用的是Tanh(),MLP后面还可以接一层输出层Tanh() 

  14. elseif opt.model == convnet then 

  15. -- 卷积神经网络 

  16. model = nn.Sequential() 

  17. -- 第一阶段 

  18. model:add(nn.SpatialConvolutionMap(nn.tables.random(nfeats,nstates[1],fanin[1]),filtsize,filtsize)) 

  19. -- 这一步直接输入的是图像进行卷积,所以没有了 nn.Reshape(ninputs)输入层。 参数:nn.tables.random(nfeats,nstates[1],fanin[1])指定了卷积层中输入maps和输出maps之间的对应关系,这里表示bstates[1]个输出maps的每一map都是由fanin[1]个输入maps得到的。filtsize则是卷积算子的大小 

  20. -- 所以该层的连接个数为(filtsize*filtsize*fanin[1]+1)*nstates[1],1是偏置。这里的fanin[1]连接是随机的,也可以采用全连接 nn.tables.full(nfeats,nstates[1]), 当输入maps和输出maps个数相同时,还可以采用一对一连接 nn.tables.oneToOne(nfeats). 

  21. -- 参见解释文档 [Convolutional layers](https://github.com/torch/nn/blob/master/doc/concolution.md#nn.convlayers.dok) 


  22. model:add(nn.Tanh()) --非线性变换层 

  23. model:SpatialLPPooling(nstates[1],2,poolsize,poolsize,poolsize,poolsize) 

  24. -- 参数(feature maps个数,Lp范数,池化尺寸大小(w,h), 滑动窗步长(dw,dh)) 

  25. model:SpatialSubtractiveNormalization(nstates[1],normalkernel) 

  26. -- local contrast normalization 

  27. -- 具体操作是先在每个map的local邻域进行减法归一化,然后在不同的feature map上进行除法归一化。类似与图像点的均值化和方差归一化。参考[1^x][Nonlinear Image Representation Using Divisive Normalization], [Gaussian Scale Mixtures](stats.stackexchange.com/174502/what-are gaussian-scale-mixtures-and-how-to-generate-samples-of-gaussian-scale),还有解释文档 [Convolutional layers](https://github.com/torch/nn/blob/master/doc/concolution.md#nn.convlayers.dok) 


  28. --[[ 

  29. 这里需要说的一点是传统的CNN一般是先卷积再池化再非线性变换,但这两年CNN一般都是先非线性变换再池化了 

  30. --]] 

  31. -- 第二阶段 

  32. model:add(nn.SpatialConvolutionMap(nn.tables.random(nstates[1],nstates[2],fanin[2]),filtsize,filtsize)) 

  33. model:add(nn.Tanh()) 

  34. model:add(nn.SpatialLPPooling(nstates[2],2,poolsize,poolsize)) 

  35. model:add(nn.SpatialSubtractiveNormalization(nstates[2],kernel)) 


  36. --第三阶段 

  37. model:add(nn.Reshape(nstates[2]*filtsize*filtsize)) --矢量化,全连接 

  38. model:add(nn.Linear(nstates[2]*filtsize*filtsize,nstates[3])) 

  39. model:add(nn.Tanh()) 

  40. model:add(nn.Linear(nstates[3],noutputs)) 

  41. else 

  42. error('unknown -model') 

  43. end 

  1. 显示网络结构以及参数

  1. print('==> here is the model') 

  2. print(model) 

结果如下图

model.png

可以发现,可训练参数分别在1,5部分,所以可以观察权重矩阵的大小

  1. print('==> 权重矩阵的大小 ') 

  2. print(model:get(1).weight:size()) 

  3. print('==> 偏置的大小') 

  4. print(model:get(1).bias:numel()) 

weights numel.png
  1. 参数的可视化

  1. if opt.visualize then 

  2. image.display(image=model:get(1).weight, padding=2,zoom=4,legend='filters@ layer 1') 

  3. image.diaplay(image=model:get(5).weight,padding=2,zoom=4,legend='filters @ layer 2') 

  4. end 

weights visualization.png

torch 深度学习 (2)的更多相关文章

  1. torch 深度学习(5)

    torch 深度学习(5) mnist torch siamese deep-learning 这篇文章主要是想使用torch学习并理解如何构建siamese network. siamese net ...

  2. torch 深度学习(4)

    torch 深度学习(4) test doall files 经过数据的预处理.模型创建.损失函数定义以及模型的训练,现在可以使用训练好的模型对测试集进行测试了.测试模块比训练模块简单的多,只需调用模 ...

  3. torch 深度学习(3)

    torch 深度学习(3) 损失函数,模型训练 前面我们已经完成对数据的预处理和模型的构建,那么接下来为了训练模型应该定义模型的损失函数,然后使用BP算法对模型参数进行调整 损失函数 Criterio ...

  4. 深度学习菜鸟的信仰地︱Supervessel超能云服务器、深度学习环境全配置

    并非广告~实在是太良心了,所以费时间给他们点赞一下~ SuperVessel云平台是IBM中国研究院和中国系统与技术中心基于POWER架构和OpenStack技术共同构建的, 支持开发者远程开发的免费 ...

  5. 深度学习框架caffe/CNTK/Tensorflow/Theano/Torch的对比

    在单GPU下,所有这些工具集都调用cuDNN,因此只要外层的计算或者内存分配差异不大其性能表现都差不多. Caffe: 1)主流工业级深度学习工具,具有出色的卷积神经网络实现.在计算机视觉领域Caff ...

  6. 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())

    在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...

  7. [深度学习] Pytorch学习(一)—— torch tensor

    [深度学习] Pytorch学习(一)-- torch tensor 学习笔记 . 记录 分享 . 学习的代码环境:python3.6 torch1.3 vscode+jupyter扩展 #%% im ...

  8. 【深度学习Deep Learning】资料大全

    最近在学深度学习相关的东西,在网上搜集到了一些不错的资料,现在汇总一下: Free Online Books  by Yoshua Bengio, Ian Goodfellow and Aaron C ...

  9. [深度学习大讲堂]从NNVM看2016年深度学习框架发展趋势

    本文为微信公众号[深度学习大讲堂]特约稿,转载请注明出处 虚拟框架杀入 从发现问题到解决问题 半年前的这时候,暑假,我在SIAT MMLAB实习. 看着同事一会儿跑Torch,一会儿跑MXNet,一会 ...

随机推荐

  1. 转载SQL_trace 和10046使用

    SQL_TRACE是Oracle提供的用于进行SQL跟踪的手段,是强有力的辅助诊断工具.在日常的数据库问题诊断和解决中,SQL_TRACE是非常常用的方法.本文就SQL_TRACE的使用作简单探讨,并 ...

  2. selenium 模块

    介绍 selenium最初是一个自动化测试工具,而爬虫中使用它主要是为了解决requests无法直接执行JavaScript代码的问题 selenium本质是通过驱动浏览器,完全模拟浏览器的操作,比如 ...

  3. vertical-align和line-height的深入应用

    vertical-align和line-height的深入应用 本文的重点是了解vertical-align和line-height的使用 涉及到的名词:基线,底端,行内框,行框,行间距,替换元素及非 ...

  4. php微信支付接口开发程序(流程已通)

    php微信支付接口开发程序(流程已通) 来源:未知    时间:2014-12-11 17:11   阅读数:11843   作者:xxadmin [导读] 微信支付接口现在也慢慢的像支付宝一个可以利 ...

  5. session和token的区别

    session的使用方式是客户端cookie里存id,服务端session存用户数据,客户端访问服务端的时候,根据id找用户数据 而token一般翻译成令牌,一般是用于验证表明身份的数据或是别的口令数 ...

  6. Mybatis 一对一、一对多、多对多

    一对一返回resultType <!-- 查询订单关联查询用户信息 resultType --> <select id="findOrderCustom" res ...

  7. ubuntu 可能的依赖包,安装过程中根据需要安装

    /*************依赖包安装****************/下面是可能的依赖包,安装过程中根据需要安装 build-essential - libglib2.-dev libpng3 li ...

  8. PHP分页及原理

    在看本文之前,请确保你已掌握了PHP的一些知识以及MYSQL的查询操作基础哦. 作为一个Web程序,经常要和不计其数的数据打交道,比如会员的数据,文章数据,假如只有几十个会员那很好办,在一页显示就可以 ...

  9. WCF基本知识

    1.开通WCF调试服务: 须在服务端的行为中作如下配置:includeExceptionDetailInFaults="true" 代码如下: <behaviors> ...

  10. wix toolset 用wixui 默认中文

    light.exe .\test.wixobj -ext WixUIExtension -ext WixUtilExtension -cultures:zh-CN