torch 深度学习 (2)
torch 深度学习 (2)
前面我们完成了数据的下载和预处理,接下来就该搭建网络模型了,CNN网络的东西可以参考博主 zouxy09的系列文章Deep Learning (深度学习) 学习笔记整理系列之 (七)
加载包
- require 'torch'
- require 'image'
- require 'nn'
函数运行参数的设置
- if not opt then
- print "==> processing options"
- cmd = torch.CmdLine()
- cmd:text()
- cmd:text('options:')
- -- 选择构建何种结构:线性|MLP|ConvNet。默认:convnet
- cmd:option('-model','convnet','type of model to construct: linear | mlp | convnet')
- -- 是否需要可视化
- cmd:option('-visualize',true,'visualize input data and weights during training')
- -- 参数
- opt = cmd:parse(arg or {})
- end
设置网络模型用到的一些参数
- -- 输出类别数,也就是输出节点个数
- noutputs =10
- -- 输入节点的个数
- nfeats = 3 -- YUV三个通道,可以认为是3个features map
- width =32
- height =32
- -- Linear 和 mlp model下的输入节点个数,就是将输入图像拉成列向量
- ninputs = nfeats*width*height
- -- 为mlp定义隐层节点的个数
- nhiddens = ninputs/2
- -- 为convnet定义隐层feature maps的个数以及滤波器的尺寸
- nstates = {16,256,128} --第一个隐层有16个feature map,第二个隐层有256个特征图,第三个隐层有128个节点
- fanin = {1,4} -- 定义了卷积层的输入和输出对应关系,以fanin[2]举例,表示该卷积层有16个map输入,256个map输出,每个输出map是有fanin[2]个输入map对应filters卷积得到的结果
- filtsize =5 --滤波器的大小,方形滤波器
- poolsize = 2 -- 池化池尺寸
- normkernel = image.gaussian1D(7) --长度为7的一维高斯模板,用来local contrast normalization
构建模型
- if opt.model == linear then
- -- 线性模型
- model = nn.Sequntial()
- model:add(nn.Reshape(ninputs)) -- 输入层
- model:add(nn.Linear(ninputs,noutputs)) -- 线性模型 y=Wx+b
- elseif opt.model == mlp then
- -- 多层感知器
- model = nn.Sequential()
- model:add(nn.Reshape(ninputs)) --输入层
- model:add(nn.Linear(ninputs,nhiddens)) --线性层
- model:add(nn.Tanh()) -- 非线性层
- model:add(nn.Linear(nhiddens,noutputs)) -- 线性层
- -- MLP 目标: `!$y=W_2 f(W_1X+b) + b $` 这里的激活函数采用的是Tanh(),MLP后面还可以接一层输出层Tanh()
- elseif opt.model == convnet then
- -- 卷积神经网络
- model = nn.Sequential()
- -- 第一阶段
- model:add(nn.SpatialConvolutionMap(nn.tables.random(nfeats,nstates[1],fanin[1]),filtsize,filtsize))
- -- 这一步直接输入的是图像进行卷积,所以没有了 nn.Reshape(ninputs)输入层。 参数:nn.tables.random(nfeats,nstates[1],fanin[1])指定了卷积层中输入maps和输出maps之间的对应关系,这里表示bstates[1]个输出maps的每一map都是由fanin[1]个输入maps得到的。filtsize则是卷积算子的大小
- -- 所以该层的连接个数为(filtsize*filtsize*fanin[1]+1)*nstates[1],1是偏置。这里的fanin[1]连接是随机的,也可以采用全连接 nn.tables.full(nfeats,nstates[1]), 当输入maps和输出maps个数相同时,还可以采用一对一连接 nn.tables.oneToOne(nfeats).
- -- 参见解释文档 [Convolutional layers](https://github.com/torch/nn/blob/master/doc/concolution.md#nn.convlayers.dok)
- model:add(nn.Tanh()) --非线性变换层
- model:SpatialLPPooling(nstates[1],2,poolsize,poolsize,poolsize,poolsize)
- -- 参数(feature maps个数,Lp范数,池化尺寸大小(w,h), 滑动窗步长(dw,dh))
- model:SpatialSubtractiveNormalization(nstates[1],normalkernel)
- -- local contrast normalization
- -- 具体操作是先在每个map的local邻域进行减法归一化,然后在不同的feature map上进行除法归一化。类似与图像点的均值化和方差归一化。参考[1^x][Nonlinear Image Representation Using Divisive Normalization], [Gaussian Scale Mixtures](stats.stackexchange.com/174502/what-are gaussian-scale-mixtures-and-how-to-generate-samples-of-gaussian-scale),还有解释文档 [Convolutional layers](https://github.com/torch/nn/blob/master/doc/concolution.md#nn.convlayers.dok)
- --[[
- 这里需要说的一点是传统的CNN一般是先卷积再池化再非线性变换,但这两年CNN一般都是先非线性变换再池化了
- --]]
- -- 第二阶段
- model:add(nn.SpatialConvolutionMap(nn.tables.random(nstates[1],nstates[2],fanin[2]),filtsize,filtsize))
- model:add(nn.Tanh())
- model:add(nn.SpatialLPPooling(nstates[2],2,poolsize,poolsize))
- model:add(nn.SpatialSubtractiveNormalization(nstates[2],kernel))
- --第三阶段
- model:add(nn.Reshape(nstates[2]*filtsize*filtsize)) --矢量化,全连接
- model:add(nn.Linear(nstates[2]*filtsize*filtsize,nstates[3]))
- model:add(nn.Tanh())
- model:add(nn.Linear(nstates[3],noutputs))
- else
- error('unknown -model')
- end
显示网络结构以及参数
- print('==> here is the model')
- print(model)
结果如下图

可以发现,可训练参数分别在1,5部分,所以可以观察权重矩阵的大小
- print('==> 权重矩阵的大小 ')
- print(model:get(1).weight:size())
- print('==> 偏置的大小')
- print(model:get(1).bias:numel())

参数的可视化
- if opt.visualize then
- image.display(image=model:get(1).weight, padding=2,zoom=4,legend='filters@ layer 1')
- image.diaplay(image=model:get(5).weight,padding=2,zoom=4,legend='filters @ layer 2')
- end

torch 深度学习 (2)的更多相关文章
- torch 深度学习(5)
torch 深度学习(5) mnist torch siamese deep-learning 这篇文章主要是想使用torch学习并理解如何构建siamese network. siamese net ...
- torch 深度学习(4)
torch 深度学习(4) test doall files 经过数据的预处理.模型创建.损失函数定义以及模型的训练,现在可以使用训练好的模型对测试集进行测试了.测试模块比训练模块简单的多,只需调用模 ...
- torch 深度学习(3)
torch 深度学习(3) 损失函数,模型训练 前面我们已经完成对数据的预处理和模型的构建,那么接下来为了训练模型应该定义模型的损失函数,然后使用BP算法对模型参数进行调整 损失函数 Criterio ...
- 深度学习菜鸟的信仰地︱Supervessel超能云服务器、深度学习环境全配置
并非广告~实在是太良心了,所以费时间给他们点赞一下~ SuperVessel云平台是IBM中国研究院和中国系统与技术中心基于POWER架构和OpenStack技术共同构建的, 支持开发者远程开发的免费 ...
- 深度学习框架caffe/CNTK/Tensorflow/Theano/Torch的对比
在单GPU下,所有这些工具集都调用cuDNN,因此只要外层的计算或者内存分配差异不大其性能表现都差不多. Caffe: 1)主流工业级深度学习工具,具有出色的卷积神经网络实现.在计算机视觉领域Caff ...
- 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())
在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...
- [深度学习] Pytorch学习(一)—— torch tensor
[深度学习] Pytorch学习(一)-- torch tensor 学习笔记 . 记录 分享 . 学习的代码环境:python3.6 torch1.3 vscode+jupyter扩展 #%% im ...
- 【深度学习Deep Learning】资料大全
最近在学深度学习相关的东西,在网上搜集到了一些不错的资料,现在汇总一下: Free Online Books by Yoshua Bengio, Ian Goodfellow and Aaron C ...
- [深度学习大讲堂]从NNVM看2016年深度学习框架发展趋势
本文为微信公众号[深度学习大讲堂]特约稿,转载请注明出处 虚拟框架杀入 从发现问题到解决问题 半年前的这时候,暑假,我在SIAT MMLAB实习. 看着同事一会儿跑Torch,一会儿跑MXNet,一会 ...
随机推荐
- Mac打开应用提示已损坏的解决办法
相信很多升级了最新Mac系统的用户在打开一些应用的时候都会出现“应用XX已损坏”的系统提示,安装这些应用的时候总是提示“已损坏,移至废纸篓”这类信息,根本无法打开应用. Mac打开应用提示已损坏的解决 ...
- 关于Softnet的加密。方式是使用API函数。。关键是开发号
首先是获取 开发号. 类似于这个玩意 http://www.cnblogs.com/wenluderen/p/4853563.html 这个帖子里面有介绍关于开发号的完整资料. ××××××××××× ...
- centos7命令3
查看监听的端口 netstat -lntp 检查端口被哪个进程占用 netstat -lnp|grep 8080 查看当前文件夹大小 du -sh 查看当前文件夹各目录大小 du -sh ./* 查看 ...
- 一步一步学EF系列【5、升级篇 实体与数据库的映射】live writer真坑,第4次补发
前言 之前的几篇文章,被推荐到首页后,又被博客园下了,原因内容太少,那我要写多点呢,还是就按照这种频率进行写呢?本身我的意图这个系列就是想已最简单最容易理解的方式进行,每篇内容也不要太多,这样初学者容 ...
- HackerRank - candies 【贪心】
HackerRank - candies [贪心] Description Alice is a kindergarten teacher. She wants to give some candie ...
- Winter-1-E Let the Balloon Rise 解题报告及测试数据
Time Limit:1000MS Memory Limit:32768KB Description Contest time again! How excited it is to see ...
- Docker+.Net Core 的那些事儿-3.创建容器并运行
1.根据镜像运行容器 上篇文章建立了一个镜像: 我们以此开始,执行以下命令: docker run -d -p 5000:5000 hwapp:latest 如果返回以上结果表示建立成功. 此时如果你 ...
- linux在文件中包含某个关键词的指定行插入内容
1. 在包含某个关键字的行上面插入一行文字 sed -i '/wangzai/i\doubi' 1.txt 把内容doubi插入到包含wangzai关键字的上一行 2. 在包含某个关键字的行下面插入一 ...
- nginx rewrite规则last与break的区别
概要:break和last都能阻止继续执行后面的rewrite指令,last如果在location下的话,对于重写后的URI会重新匹配location,而break不会重新匹配location. 区别 ...
- C++之条形码,windows下zint库的编译及应用(一)
zint库是一个开源的第三方库,提供了生成条形码.二维码等功能.本文主要介绍zint库的生成及简单应用. 工具/原料 vs2012 代码文件下载 1 下载zint包 2 zint依赖另外两个库 ...