为何Keras中的CNN是有问题的,如何修复它们?
在训练了 50 个 epoch 之后,本文作者惊讶地发现模型什么都没学到,于是开始深挖背后的问题,并最终从恺明大神论文中得到的知识解决了问题。
上个星期我做了一些实验,用了在 CIFAR10 数据集上训练的 VGG16。我需要从零开始训练模型,所以没有使用在 ImageNet 上预训练的版本。
我开始了 50 个 epoch 的训练,然后去喝了个咖啡,回来就看到了这些学习曲线:
模型什么都没学到!
我见过网络收敛得极其缓慢、振荡、过拟合、发散,但这是我第一次发现这种行为——模型根本就没有起任何作用。
因此我就深挖了一下,看看究竟发生了什么。
实验
这是我创建模型的方法。它遵循了 VGG16 的原始结构,但是,大多数全连接层被移除了,所以只留下了相当多的卷积层。
现在让我们了解一下是什么导致了我在文章开头展示的训练曲线。
学习模型过程中出现错误时,检查一下梯度的表现通常是一个好主意。我们可以使用下面的方法得到每层梯度的平均值和标准差:
然后将它们画出来,我们就得到了以下内容:
使用 Glorot 函数初始化的 VGG16 梯度的统计值
呀... 我的模型中根本就没有梯度,或许应该检查一下激活值是如何逐层变化的。我们可以试用下面的方法得到激活值的平均值和标准差:
然后将它们画出来:
使用 Glorot 函数进行初始化的 VGG16 模型的激活值
这就是问题所在!
提醒一下,每个卷积层的梯度是通过以下公式计算的:
其中Δx 和Δy 用来表示梯度∂L/∂x 和∂L/∂y。梯度是通过和链式法则计算的,这意味着我们是从最后一层开始,反向传递到较浅的层。但当最后一层的激活值接近零时会发生什么呢?这正是我们面临的情况,梯度到处都是零,所以不能反向传播,导致网络什么都学不到。
由于我的网络是相当简约的:没有,没有 Dropout,没有数据增强,所以我猜问题可能来源于比较糟糕的初始化,因此我拜读了何恺明的论文——《Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification》
论文链接:https://arxiv.org/pdf/1502.01852.pdf
下面简要描述一下论文内容。
初始化方法
初始化始终是深度学习研究中的一个重要领域,尤其是结构和非线性经常变化的时候。实际上一个好的初始化是我们能够训练深度神经网络的原因。
以下是何恺明论文中的关键思想,他们展示了初始化应该具备的条件,以便使用 ReLU 激活函数正确初始化 CNN。这里会需要一些数学知识,但是不必担心,你只需抓住整体思路。
我们将一个卷积层 l 的输出写成下面的形式:
接下来,如果偏置值被初始化为 0,再假设权重 w 和元素 x 相互独立并且共享相同的分布,我们就得到了:
其中 n 是一层的权重数目(例如 n=k²c)。通过独立变量的乘积的方差公式:
它变成了:
然后,如果我们让权重 w 的均值为 0,就会得到:
通过 König-Huygens 性质:
最终得到:
然而,由于我们使用的是 ReLU 激活函数,所以就有了:
因此:
这就是一个单独卷积层的输出的方差,到那时如果我们想考虑所有层的情况,就必须将它们乘起来,这就得到了:
由于我们做了乘积,所以现在很容易看到如果每一层的方差不接近于 1,网络就会快速衰减。实际上,如果它比 1 小,就会快速地朝着零消散,如果比 1 大,激活的值就会急剧增长,甚至变成一个你的计算机都无法表示的数字(NaN)。因此,为了拥有表现良好的 ReLU CNN,下面的问题必须被重视:
作者比较了使用标准初始化(Xavier/Glorot)[2] 和使用它们自己的解初始化深度 CNN 时的情况:
在一个 22 层的 ReLU CNN 上使用 Glorot(蓝色)初始化和 Kaiming 的初始化方法进行训练时的对比。使用 Glorot 初始化的模型没有学到任何东西。
这幅图是不是很熟悉?这就是我在文章开始向你们展示的图形!使用 Xavier/Glorot 初始化训练的网络没有学到任何东西。
现在猜一下 Keras 中默认的初始化是哪一种?
没错!在 Keras 中,卷积层默认是以 Glorot Uniform 分布进行初始化的:
所以如果我们将初始化方法改成 Kaiming Uniform 分布会怎么样呢?
使用 Kaiming 的初始化方法
现在来创建我们的 VGG16 模型,但是这次将初始化改成 he_uniform。
在训练模型之前,让我们来检查一下激活值和梯度。
所以现在,使用 Kaiming 的初始化方法时,我们的激活拥有 0.5 左右的均值,以及 0.8 左右的标准差。
可以看到,现在我们有一些梯度,如果希望模型能够学到一些东西,这种梯度就是一种好现象了。
现在,如果我们训练一个新的模型,就会得到下面的学习曲线:
我们可能需要增加一些正则化,但是现在,哈哈,已经比之前好很多了,不是吗?
结论
在这篇文章中,我们证明,初始化是模型中特别重要的一件事情,这一点你可能经常忽略。此外,文章还证明,即便像 Keras 这种卓越的库中的默认设置,也不能想当然拿来就用。
参考文献和扩展阅读:
[1]: Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification:https://arxiv.org/pdf/1502.01852.pdf
[2]: Understanding the difficulty of training deep feedforward neural networks:http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
[3]: 吴恩达课程:https://www.youtube.com/watch?v=s2coXdufOzE
原文地址:https://towardsdatascience.com/why-default-cnn-are-broken-in-keras-and-how-to-fix-them-ce295e5e5f2
欢迎关注磐创博客资源汇总站:
http://docs.panchuang.net/
欢迎关注PyTorch官方中文教程站:
http://pytorch.panchuang.net/
为何Keras中的CNN是有问题的,如何修复它们?的更多相关文章
- [知乎作答]·关于在Keras中多标签分类器训练准确率问题
[知乎作答]·关于在Keras中多标签分类器训练准确率问题 本文来自知乎问题 关于在CNN中文本预测sigmoid分类器训练准确率的问题?中笔者的作答,来作为Keras中多标签分类器的使用解析教程. ...
- 在Keras中可视化LSTM
作者|Praneet Bomma 编译|VK 来源|https://towardsdatascience.com/visualising-lstm-activations-in-keras-b5020 ...
- 探索学习率设置技巧以提高Keras中模型性能 | 炼丹技巧
学习率是一个控制每次更新模型权重时响应估计误差而调整模型程度的超参数.学习率选取是一项具有挑战性的工作,学习率设置的非常小可能导致训练过程过长甚至训练进程被卡住,而设置的非常大可能会导致过快学习到 ...
- keras中VGG19预训练模型的使用
keras提供了VGG19在ImageNet上的预训练权重模型文件,其他可用的模型还有VGG16.Xception.ResNet50.InceptionV3 4个. VGG19在keras中的定义: ...
- keras中的mini-batch gradient descent (转)
深度学习的优化算法,说白了就是梯度下降.每次的参数更新有两种方式. 一. 第一种,遍历全部数据集算一次损失函数,然后算函数对各个参数的梯度,更新梯度.这种方法每更新一次参数都要把数据集里的所有样本都看 ...
- keras中的模型保存和加载
tensorflow中的模型常常是protobuf格式,这种格式既可以是二进制也可以是文本.keras模型保存和加载与tensorflow不同,keras中的模型保存和加载往往是保存成hdf5格式. ...
- keras中的loss、optimizer、metrics
用keras搭好模型架构之后的下一步,就是执行编译操作.在编译时,经常需要指定三个参数 loss optimizer metrics 这三个参数有两类选择: 使用字符串 使用标识符,如keras.lo ...
- Keras中RNN不定长输入的处理--padding and masking
在使用RNN based model处理序列的应用中,如果使用并行运算batch sample,我们几乎一定会遇到变长序列的问题. 通常解决变长的方法主要是将过长的序列截断,将过短序列用0补齐到一个固 ...
- 深度学习基础系列(十一)| Keras中图像增强技术详解
在深度学习中,数据短缺是我们经常面临的一个问题,虽然现在有不少公开数据集,但跟大公司掌握的海量数据集相比,数量上仍然偏少,而某些特定领域的数据采集更是非常困难.根据之前的学习可知,数据量少带来的最直接 ...
随机推荐
- Typora+PicGo+Gitee笔记方案
前言:需要学习的知识太多,从一开始就在寻找一款能让我完全满意的编辑器,然而一直都没有令我满意的.在前两天Typora新版本更新后,总算是拥有了一套我认为很完美的笔记方案:使用Typora编写markd ...
- 性能测试之Mysql数据库调优
一.前言 性能调优前提:无监控不调优,对于mysql性能的监控前几天有文章提到过,有兴趣的朋友可以去看一下 二.Mysql性能指标及问题分析和定位 1.我们在监控图表中关注的性能指标大概有这么几个:C ...
- LeetCode--链表1-单链表
LeetCode--链表1-单链表 单链表模板 初始化 头部插入 尾部插入 删除节点 Index插入 Index返回对应的节点指针和val值 class MyLinkedList { private: ...
- 配置VSCode的C/C++语言功能
0. 前言 主要是在网上找的方法都没试成功过,在各种机缘巧合下终于成功了. 这篇文章基于个人经验,而且没有走寻常路. 1. 需要的软件和插件 软件: VSCode (https://code.visu ...
- session和el表达式
2015/1/21 ## 回顾昨天案例 ## # 模拟购物车: >> 基本步骤: |-- 显示所有的书籍: |-- 制作书记列表/模仿数据库: |-- 参见昨天示例: |-- 制作查看详情 ...
- echart 之实现温度计
百度这个图表支持不是很好,有的需要自己写,看大神们实现温度计都是用 水球特效实现的我这里雕虫小计啊但是满足我了我的项目需求特此分享出来,可惜自己不是专业的前端 这是我的实现结果 好了上代码html: ...
- vue 实现 裁切图片 同时有放大、缩小、旋转功能
实现效果: 裁切指定区域内的图片 旋转图片 放大图片 输出bolb 格式数据 提供给 formData 对象 效果图 大概原理: 利用h5 FileReader 对象, 获取 <input ty ...
- JAVA Integer值的范围
原文出处:http://hi.baidu.com/eduask%C9%BD%C8%AA/blog/item/227bf4d81c71ebf538012f53.html package com.test ...
- 对两个有序数组重新去重合并排序js实现
这里主要是要利用两个数组有序这个条件,所以只需两个指针分别指向两个数组,当其中一个小于另外一个就移动该指针,反之则移动另外一个指针,如果相等则均向后移动. 结束条件是,当任意一个数组的指针移到末尾则跳 ...
- form里面文件上传并预览
其实form里面是不能嵌套form的,如果form里面有图片上传和其他input框,我们希望上传图片并预览图片,然后将其他input框填写完毕,再提交整个表单的话,有两种方式! 方式一:点击上传按钮的 ...