dropout是在训练神经网络模型时,样本数据过少,防止过拟合而采用的trick。那它是怎么做到防止过拟合的呢?

  首先,想象我们现在只训练一个特定的网络,当迭代次数增多的时候,可能出现网络对训练集拟合的很好(在训练集上loss很小),但是对验证集的拟合程度很差的情况。所以,我们有了这样的想法:可不可以让每次跌代随机的去更新网络参数(weights),引入这样的随机性就可以增加网络generalize 的能力。所以就有了dropout 。

  在训练的时候,我们只需要按一定的概率(retaining probability)p 来对weight layer 的参数进行随机采样,将这个子网络作为此次更新的目标网络。可以想象,如果整个网络有n个参数,那么我们可用的子网络个数为 2^n 。 并且,当n很大时,每次迭代更新 使用的子网络基本上不会重复,从而避免了某一个网络被过分的拟合到训练集上。

  那么测试的时候怎么办呢? 一种最naive的方法是,我们把 2^n 个子网络都用来做测试,然后以某种 voting 机制将所有结果结合一下(比如说平均一下下),然后得到最终的结果。但是,由于n实在是太大了,这种方法实际中完全不可行!所以有人提出,那我做一个大致的估计不就得了,我从2^n个网络中随机选取 m 个网络做测试,最后在用某种voting 机制得到最终的预测结果。这种想法当然可行,当m很大时但又远小于2^n时,能够很好的逼近原2^n个网络结合起来的预测结果。但是,有没有更好的办法呢? of course!那就是dropout 自带的功能,能够通过一次测试得到逼近于原2^n个网络组合起来的预测能力!

  虽然训练的时候我们使用了dropout, 但是在测试时,我们不使用dropout (不对网络的参数做任何丢弃,这时dropout layer相当于进来什么就输出什么)。然后,把测试时dropout layer的输出乘以训练时使用的retaining probability  p (这时dropout layer相当于把进来的东东乘以p)。仔细想想这里面的意义在哪里呢??? 事实上,由于我们在测试时不做任何的参数丢弃,如上面所说,dropout layer 把进来的东西原样输出,导致在统计意义下,测试时 每层 dropout layer的输出比训练时的输出多加了【(1 - p)*100】%  units 的输出。 即 【p*100】% 个units 的和  是同训练时随机采样得到的子网络的输出一致,另【(1 - p)*100】%  的units的和  是本来应该扔掉但是又在测试阶段被保留下来的。所以,为了使得dropout layer 下一层的输入和训练时具有相同的“意义”和“数量级”,我们要对测试时的伪dropout layer的输出(即下层的输入)做 rescale: 乘以一个p,表示最后的sum中只有这么大的概率,或者这么多的部分被保留。这样以来,只要一次测试,将原2^n个子网络的参数全部考虑进来了,并且最后的 rescale 保证了后面一层的输入仍然符合相应的物理意义和数量级。

  假设x是dropout layer的输入,y是dropout layer的输出,W是上一层的所有weight parameters, 是以retaining probability 为p 采样得到的weight parameter子集。把上面的东西用公式表示(忽略bias):

    train:  

    test:

  

  但是一般写程序的时候,我们想直接在test时用   , 这种表达式。(where  ) 因此我们就在训练的时候就直接训练  。 所以训练时,第一个公式修正为    。 即把dropout的输入乘以p 再进行训练,这样得到的训练得到的weight 参数就是  ,测试的时候除了不使用dropout外,不需要再做任何rescale。Caffe 和Lasagne 里面的代码就是这样写的。

转自http://blog.csdn.net/u012702874/article/details/45030991

CNN中dropout层的理解的更多相关文章

  1. 由浅入深:CNN中卷积层与转置卷积层的关系

    欢迎大家前往腾讯云+社区,获取更多腾讯海量技术实践干货哦~ 本文由forrestlin发表于云+社区专栏 导语:转置卷积层(Transpose Convolution Layer)又称反卷积层或分数卷 ...

  2. CNN中卷积层 池化层反向传播

    参考:https://blog.csdn.net/kyang624823/article/details/78633897 卷积层 池化层反向传播: 1,CNN的前向传播 a)对于卷积层,卷积核与输入 ...

  3. CNN中卷积层的计算细节

    原文链接: https://zhuanlan.zhihu.com/p/29119239 卷积层尺寸的计算原理 输入矩阵格式:四个维度,依次为:样本数.图像高度.图像宽度.图像通道数 输出矩阵格式:与输 ...

  4. 深度学习中dropout策略的理解

    现在有空整理一下关于深度学习中怎么加入dropout方法来防止测试过程的过拟合现象. 首先了解一下dropout的实现原理: 这些理论的解释在百度上有很多.... 这里重点记录一下怎么实现这一技术 参 ...

  5. 对faster rcnn 中rpn层的理解

    1.介绍 图为faster rcnn的rpn层,接自conv5-3 图为faster rcnn 论文中关于RPN层的结构示意图 2 关于anchor: 一般是在最末层的 feature map 上再用 ...

  6. 理解CNN中的通道 channel

    在深度学习的算法学习中,都会提到 channels 这个概念.在一般的深度学习框架的 conv2d 中,如 tensorflow .mxnet ,channels 都是必填的一个参数. channel ...

  7. javaEE中关于dao层和services层的理解

    javaEE中关于dao层和services层的理解 入职已经一个多月了,作为刚毕业的新人,除了熟悉公司的项目,学习公司的框架,了解项目的一些业务逻辑之外,也就在没学到什么:因为刚入职, 带我的那个师 ...

  8. caffe中关于(ReLU层,Dropout层,BatchNorm层,Scale层)输入输出层一致的问题

    在卷积神经网络中.常见到的激活函数有Relu层 layer { name: "relu1" type: "ReLU" bottom: "pool1&q ...

  9. 深度学习中Dropout原理解析

    1. Dropout简介 1.1 Dropout出现的原因 在机器学习的模型中,如果模型的参数太多,而训练样本又太少,训练出来的模型很容易产生过拟合的现象. 在训练神经网络的时候经常会遇到过拟合的问题 ...

随机推荐

  1. java反序列化漏洞的检测

    1.首先下载常用的工具ysoserial 这边提供下载地址:https://jitpack.io/com/github/frohoff/ysoserial/master-v0.0.5-gb617b7b ...

  2. 【BZOJ4553】[Tjoi2016&Heoi2016]序列 cdq分治+树状数组

    [BZOJ4553][Tjoi2016&Heoi2016]序列 Description 佳媛姐姐过生日的时候,她的小伙伴从某宝上买了一个有趣的玩具送给他.玩具上有一个数列,数列中某些项的值可能 ...

  3. 【BZOJ2946】[Poi2000]公共串 后缀数组+二分

    [BZOJ2946][Poi2000]公共串 Description        给出几个由小写字母构成的单词,求它们最长的公共子串的长度. 任务: l        读入单词 l        计 ...

  4. Android 通过Socket 和服务器通讯

    Extends:(http://www.cnblogs.com/likwo/p/3641135.html) Android 通过Socket 和服务器通讯,是一种比较常用的通讯方式,时间比较紧,说下大 ...

  5. 微信小程序 --- https请求

    wx.request发起的是 https 请求,而不是 http 请求.一个小程序 同时 只能有 5个 网络请求. 参数: url:开发者服务器接口地址: data:请求的参数: header:设置请 ...

  6. Linux--vim编辑器和文件恢复

    第五章  Vim编辑器和恢复ext4下误删除的文件-Xmanager工具 本节所讲内容: 5.1  vim的使用 5.2  实战:恢复ext4文件系统下误删除的文件 5.3  实战:使用xmanage ...

  7. poj3345 Bribing FIPA【树形DP】【背包】

    Bribing FIPA Time Limit: 2000MS   Memory Limit: 65536K Total Submissions: 5910   Accepted: 1850 Desc ...

  8. 括号匹配问题(区间dp)

    简单的检查括号是否配对正确使用的是栈模拟,这个不必再说,现在将这个问题改变一下:如果给出一个括号序列,问需要把他补全成合法最少需要多少步? 这是一个区间dp问题,我们可以利用区间dp来解决,直接看代码 ...

  9. Nmap介绍

    1.Nmap介绍 Nmap用于列举网络主机清单.管理服务升级调度.监控主机或服务运行状况.Nmap可以检测目标机是否在线.端口开放情况.侦测运行的服务类型及版本信息.侦测操作系统与设备类型等信息. 1 ...

  10. php基础:面向对象

    一.public.private.protected访问修饰符 public:任何都可以访问(本类.子类.外部都可以访问) protected:本类.子类都可以访问(本类.子类均可访问) privat ...