torch 深度学习 (2)

torch
ConvNet

前面我们完成了数据的下载和预处理,接下来就该搭建网络模型了,CNN网络的东西可以参考博主 zouxy09的系列文章Deep Learning (深度学习) 学习笔记整理系列之 (七)

  1. 加载包

  1. require 'torch' 

  2. require 'image' 

  3. require 'nn' 

  1. 函数运行参数的设置

  1. if not opt then 

  2. print "==> processing options" 

  3. cmd = torch.CmdLine() 

  4. cmd:text() 

  5. cmd:text('options:') 

  6. -- 选择构建何种结构:线性|MLP|ConvNet。默认:convnet 

  7. cmd:option('-model','convnet','type of model to construct: linear | mlp | convnet') 

  8. -- 是否需要可视化 

  9. cmd:option('-visualize',true,'visualize input data and weights during training') 

  10. -- 参数 

  11. opt = cmd:parse(arg or {}) 

  12. end 

  1. 设置网络模型用到的一些参数

  1. -- 输出类别数,也就是输出节点个数 

  2. noutputs =10 

  3. -- 输入节点的个数 

  4. nfeats = 3 -- YUV三个通道,可以认为是3个features map 

  5. width =32 

  6. height =32 

  7. -- Linear 和 mlp model下的输入节点个数,就是将输入图像拉成列向量 

  8. ninputs = nfeats*width*height 


  9. -- 为mlp定义隐层节点的个数 

  10. nhiddens = ninputs/2  


  11. -- 为convnet定义隐层feature maps的个数以及滤波器的尺寸 

  12. nstates = {16,256,128} --第一个隐层有16个feature map,第二个隐层有256个特征图,第三个隐层有128个节点 

  13. fanin = {1,4} -- 定义了卷积层的输入和输出对应关系,以fanin[2]举例,表示该卷积层有16个map输入,256个map输出,每个输出map是有fanin[2]个输入map对应filters卷积得到的结果 

  14. filtsize =5 --滤波器的大小,方形滤波器 

  15. poolsize = 2 -- 池化池尺寸 

  16. normkernel = image.gaussian1D(7) --长度为7的一维高斯模板,用来local contrast normalization 

  1. 构建模型

  1. if opt.model == linear then  

  2. -- 线性模型 

  3. model = nn.Sequntial() 

  4. model:add(nn.Reshape(ninputs)) -- 输入层 

  5. model:add(nn.Linear(ninputs,noutputs)) -- 线性模型 y=Wx+b 

  6. elseif opt.model == mlp then  

  7. -- 多层感知器 

  8. model = nn.Sequential() 

  9. model:add(nn.Reshape(ninputs)) --输入层 

  10. model:add(nn.Linear(ninputs,nhiddens)) --线性层 

  11. model:add(nn.Tanh()) -- 非线性层 

  12. model:add(nn.Linear(nhiddens,noutputs)) -- 线性层 

  13. -- MLP 目标: `!$y=W_2 f(W_1X+b) + b $` 这里的激活函数采用的是Tanh(),MLP后面还可以接一层输出层Tanh() 

  14. elseif opt.model == convnet then 

  15. -- 卷积神经网络 

  16. model = nn.Sequential() 

  17. -- 第一阶段 

  18. model:add(nn.SpatialConvolutionMap(nn.tables.random(nfeats,nstates[1],fanin[1]),filtsize,filtsize)) 

  19. -- 这一步直接输入的是图像进行卷积,所以没有了 nn.Reshape(ninputs)输入层。 参数:nn.tables.random(nfeats,nstates[1],fanin[1])指定了卷积层中输入maps和输出maps之间的对应关系,这里表示bstates[1]个输出maps的每一map都是由fanin[1]个输入maps得到的。filtsize则是卷积算子的大小 

  20. -- 所以该层的连接个数为(filtsize*filtsize*fanin[1]+1)*nstates[1],1是偏置。这里的fanin[1]连接是随机的,也可以采用全连接 nn.tables.full(nfeats,nstates[1]), 当输入maps和输出maps个数相同时,还可以采用一对一连接 nn.tables.oneToOne(nfeats). 

  21. -- 参见解释文档 [Convolutional layers](https://github.com/torch/nn/blob/master/doc/concolution.md#nn.convlayers.dok) 


  22. model:add(nn.Tanh()) --非线性变换层 

  23. model:SpatialLPPooling(nstates[1],2,poolsize,poolsize,poolsize,poolsize) 

  24. -- 参数(feature maps个数,Lp范数,池化尺寸大小(w,h), 滑动窗步长(dw,dh)) 

  25. model:SpatialSubtractiveNormalization(nstates[1],normalkernel) 

  26. -- local contrast normalization 

  27. -- 具体操作是先在每个map的local邻域进行减法归一化,然后在不同的feature map上进行除法归一化。类似与图像点的均值化和方差归一化。参考[1^x][Nonlinear Image Representation Using Divisive Normalization], [Gaussian Scale Mixtures](stats.stackexchange.com/174502/what-are gaussian-scale-mixtures-and-how-to-generate-samples-of-gaussian-scale),还有解释文档 [Convolutional layers](https://github.com/torch/nn/blob/master/doc/concolution.md#nn.convlayers.dok) 


  28. --[[ 

  29. 这里需要说的一点是传统的CNN一般是先卷积再池化再非线性变换,但这两年CNN一般都是先非线性变换再池化了 

  30. --]] 

  31. -- 第二阶段 

  32. model:add(nn.SpatialConvolutionMap(nn.tables.random(nstates[1],nstates[2],fanin[2]),filtsize,filtsize)) 

  33. model:add(nn.Tanh()) 

  34. model:add(nn.SpatialLPPooling(nstates[2],2,poolsize,poolsize)) 

  35. model:add(nn.SpatialSubtractiveNormalization(nstates[2],kernel)) 


  36. --第三阶段 

  37. model:add(nn.Reshape(nstates[2]*filtsize*filtsize)) --矢量化,全连接 

  38. model:add(nn.Linear(nstates[2]*filtsize*filtsize,nstates[3])) 

  39. model:add(nn.Tanh()) 

  40. model:add(nn.Linear(nstates[3],noutputs)) 

  41. else 

  42. error('unknown -model') 

  43. end 

  1. 显示网络结构以及参数

  1. print('==> here is the model') 

  2. print(model) 

结果如下图

model.png

可以发现,可训练参数分别在1,5部分,所以可以观察权重矩阵的大小

  1. print('==> 权重矩阵的大小 ') 

  2. print(model:get(1).weight:size()) 

  3. print('==> 偏置的大小') 

  4. print(model:get(1).bias:numel()) 

weights numel.png
  1. 参数的可视化

  1. if opt.visualize then 

  2. image.display(image=model:get(1).weight, padding=2,zoom=4,legend='filters@ layer 1') 

  3. image.diaplay(image=model:get(5).weight,padding=2,zoom=4,legend='filters @ layer 2') 

  4. end 

weights visualization.png

torch 深度学习 (2)的更多相关文章

  1. torch 深度学习(5)

    torch 深度学习(5) mnist torch siamese deep-learning 这篇文章主要是想使用torch学习并理解如何构建siamese network. siamese net ...

  2. torch 深度学习(4)

    torch 深度学习(4) test doall files 经过数据的预处理.模型创建.损失函数定义以及模型的训练,现在可以使用训练好的模型对测试集进行测试了.测试模块比训练模块简单的多,只需调用模 ...

  3. torch 深度学习(3)

    torch 深度学习(3) 损失函数,模型训练 前面我们已经完成对数据的预处理和模型的构建,那么接下来为了训练模型应该定义模型的损失函数,然后使用BP算法对模型参数进行调整 损失函数 Criterio ...

  4. 深度学习菜鸟的信仰地︱Supervessel超能云服务器、深度学习环境全配置

    并非广告~实在是太良心了,所以费时间给他们点赞一下~ SuperVessel云平台是IBM中国研究院和中国系统与技术中心基于POWER架构和OpenStack技术共同构建的, 支持开发者远程开发的免费 ...

  5. 深度学习框架caffe/CNTK/Tensorflow/Theano/Torch的对比

    在单GPU下,所有这些工具集都调用cuDNN,因此只要外层的计算或者内存分配差异不大其性能表现都差不多. Caffe: 1)主流工业级深度学习工具,具有出色的卷积神经网络实现.在计算机视觉领域Caff ...

  6. 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())

    在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...

  7. [深度学习] Pytorch学习(一)—— torch tensor

    [深度学习] Pytorch学习(一)-- torch tensor 学习笔记 . 记录 分享 . 学习的代码环境:python3.6 torch1.3 vscode+jupyter扩展 #%% im ...

  8. 【深度学习Deep Learning】资料大全

    最近在学深度学习相关的东西,在网上搜集到了一些不错的资料,现在汇总一下: Free Online Books  by Yoshua Bengio, Ian Goodfellow and Aaron C ...

  9. [深度学习大讲堂]从NNVM看2016年深度学习框架发展趋势

    本文为微信公众号[深度学习大讲堂]特约稿,转载请注明出处 虚拟框架杀入 从发现问题到解决问题 半年前的这时候,暑假,我在SIAT MMLAB实习. 看着同事一会儿跑Torch,一会儿跑MXNet,一会 ...

随机推荐

  1. 深入理解JS对象和原型链

    函数在整个js中是最复杂也是最重要的知识 一个函数中存在多面性: 1.它本身就是一个普通的函数,执行的时候形成的私有作用域(闭包),形参赋值,预解释,代码执行,执行完 成后栈内存销毁/不销毁. 2.& ...

  2. LinQ高级查询、组合查询、IQueryable集合类型

    LinQ高级查询: 1.模糊查询(包含) Repeater1.DataSource = con.car.Where(r =>r.name.Contains(s)).ToList(); 2.开头 ...

  3. eclipse 创建jsp报错

  4. 获取Json字符串中某个key对应的value

    JSONObject jsonObj= JSONObject.fromObject(jsonStr); String value= jsonObj.getString(key);

  5. playbook实现nginx安装

    1. 先在一台机器上编译安装好nginx,然后打包 tar -zcvf nginx.tar.gz /usr/local/nginx --exclude=conf/nginx.conf --exclud ...

  6. doc命令下查看java安装路径

    在doc窗口下使用命令:set  java_home 即可查看.

  7. maven-surefire-plugin

    本文参考自:https://www.cnblogs.com/qyf404/p/5013694.html surefire是maven里执行测试用例(包括testNG,Junit,pojo)的插件,他能 ...

  8. 20145302张薇《Java程序设计》实验二报告

    20145302张薇<Java程序设计>实验二:Java面向对象程序设计 使用TDD的方式设计实现复数类:Complex 测试代码 import org.junit.Test; publi ...

  9. 20135320赵瀚青LINUX第五章读书笔记

    第五章--系统调用 5.1 与内核通信 作用 1.为用户空间提供一种硬件的抽象接口 2.保证系统稳定和安全 3.除异常和陷入,是内核唯一的合法入口. API.POSIX和C库 关于Unix接口设计:提 ...

  10. XML常用标签的介绍

    1.引言 在使用Java时经常遇到使用XML的情况,而因为对XML不太了解,经常配置时粘贴复制,现在对它进行总结,以备以后使用. 2.XML常见的定义 (1)XML(Extensible Markup ...