【导读】神经网络的初始化是训练流程的重要基础环节,会对模型的性能、收敛性、收敛速度等产生重要的影响。本文是deeplearning.ai的一篇技术博客,文章指出,对初始化值的大小选取不当,  可能造成梯度爆炸或梯度消失等问题,并提出了针对性的解决方法。

初始化会对深度神经网络模型的训练时间和收敛性产生重大影响。简单的初始化方法可以加速训练,但使用这些方法需要注意小心常见的陷阱。本文将解释如何有效地对神经网络参数进行初始化。

有效的初始化对构建模型至关重要

要构建机器学习算法,通常要定义一个体系结构(例如逻辑回归,支持向量机,神经网络)并对其进行训练来学习参数。下面是训练神经网络的一些常见流程:

初始化参数

选择优化算法

然后重复以下步骤:

1、向前传播输入

2、计算成本函数

3、使用反向传播计算与参数相关的成本梯度

4、根据优化算法,利用梯度更新每个参数

然后,给定一个新的数据点,使用模型来预测其类型。

初始化值太大\太小会导致梯度爆炸或梯度消失

初始化这一步对于模型的最终性能至关重要,需要采用正确的方法。比如对于下面的三层神经网络。可以尝试使用不同的方法初始化此网络,并观察对学习的影响。

在优化循环的每次迭代(前向,成本,后向,更新)中,我们观察到当从输出层向输入层移动时,反向传播的梯度要么被放大,要么被最小化。

假设所有激活函数都是线性的(恒等函数)。 则输出激活为:

其中 L=10 ,且W[1]、W[2]…W[L-1]都是2*2矩阵,因为从第1层到L-1层都是2个神经元,接收2个输入。为了方便分析,如果假设W[1]=W[2]=…=W[L-1]=W,那么输出预测为

如果初始化值太大或太小会造成什么结果?

情况1:初始化值过大会导致梯度爆炸

如果每个权重的初始化值都比单位矩阵稍大,即:

可简化表示为

且a[l]的值随l值呈指数级增长。当这些激活用于向后传播时,会导致梯度爆炸。也就是说,与参数相关的成本梯度太大。 这导致成本围绕其最小值振荡。

初始化值太大导致成本围绕其最小值震荡

情况2:初始化值过小会导致梯度消失

类似地,如果每个权重的初始化值都比单位矩阵稍小,即:

可简化表示为

且a[l]的值随l值减少呈指数级下降。当这些激活用于后向传播时,可能会导致梯度消失。也就是说,与参数相关的成本梯度太小。这会导致成本在达到最小值之前收敛。

初始化值太小导致模型过早收敛

总而言之,使用大小不合适的值对权重进行将导致神经网络的发散或训练速度下降。 虽然我们用的是简单的对称权重矩阵来说明梯度爆炸/消失的问题,但这一现象可以推广到任何不合适的初始化值。

如何确定合适的初始化值

为了防止以上问题的出现,我们可以坚持以下经验原则:

1.激活的平均值应为零。

2.激活的方差应该在每一层保持不变。

在这两个假设下,反向传播的梯度信号不应该在任何层中乘以太小或太大的值。梯度应该可以移动到输入层,而不会爆炸或消失。

更具体地说,对于层l,其前向传播是:

我们想让下式成立:

确保均值为零,并保持每层输入方差值不变,可以保证信号不会爆炸或消失。该方法既适用于前向传播(用于激活),也适用于向后传播(用于关于激活的成本梯度)。这里建议使用Xavier初始化(或其派生初始化方法),对于每个层l,有:

层l中的所有权重均自正态分布中随机挑选,其中均值 μ=0 ,方差E= 1/( n[l−1]),其中n[l−1] 是第l-1层网络中的神经元数量。偏差已初始化为零。

下图说明了Xavier初始化对五层全连接神经网络的影响。数据集为MNIST中选取的10000个手写数字,分类结果的红色方框表示错误分类,蓝色表示正确分类。

结果显示,Xavier初始化的模型性能显著高于uniform和标准正态分布(从上至下分别为uniform、标准正态分布、Xavier)。

结论

在实践中,使用Xavier初始化的机器学习工程师会将权重初始化为N(0,1/( n[l−1]))或N(0,2/(n[l-1]+n[1])),其中后一个分布的方差是n[l-1]和n[1]的调和平均。

Xavier初始化可以与tanh激活一起使用。此外,还有大量其他初始化方法。 例如,如果你正在使用ReLU,则通常的初始化是He初始化,其初始化权重通过乘以Xavier初始化的方差2来初始化。 虽然这种初始化证明稍微复杂一些,但其思路与tanh是相同的。

参考链接:

https://www.deeplearning.ai/ai-notes/initialization/

一文看懂神经网络初始化!吴恩达Deeplearning.ai最新干货的更多相关文章

  1. 吴恩达deepLearning.ai循环神经网络RNN学习笔记_看图就懂了!!!(理论篇)

    前言 目录: RNN提出的背景 - 一个问题 - 为什么不用标准神经网络 - RNN模型怎么解决这个问题 - RNN模型适用的数据特征 - RNN几种类型 RNN模型结构 - RNN block - ...

  2. 吴恩达deepLearning.ai循环神经网络RNN学习笔记_没有复杂数学公式,看图就懂了!!!(理论篇)

    本篇文章被Google中国社区组织人转发,评价: 条理清晰,写的很详细! 被阿里算法工程师点在看! 所以很值得一看! 前言 目录: RNN提出的背景 - 一个问题 - 为什么不用标准神经网络 - RN ...

  3. 用纯Python实现循环神经网络RNN向前传播过程(吴恩达DeepLearning.ai作业)

    Google TensorFlow程序员点赞的文章!   前言 目录: - 向量表示以及它的维度 - rnn cell - rnn 向前传播 重点关注: - 如何把数据向量化的,它们的维度是怎么来的 ...

  4. 吴恩达DeepLearning.ai的Sequence model作业Dinosaurus Island

    目录 1 问题设置 1.1 数据集和预处理 1.2 概览整个模型 2. 创建模型模块 2.1 在优化循环中梯度裁剪 2.2 采样 3. 构建语言模型 3.1 梯度下降 3.2 训练模型 4. 结论   ...

  5. 吴恩达《AI For Everyone》_练习英语翻译_待更新

    AI For Everyone https://www.coursera.org/learn/ai-for-everyone 讲师: Andrew Ng (吴恩达) CEO/Founder Landi ...

  6. 吴恩达deeplearning之CNN—卷积神经网络

    https://blog.csdn.net/ice_actor/article/details/78648780 个人理解: 卷积计算的过程其实是将原始的全连接换成了卷积全连接,每个kernel为对应 ...

  7. 吴恩达DeepLearning 第一课第四周随笔

    第四周 4.1深度神经网络符号约定 L=4______(神经网络层数)   4.2 校正矩阵的维数 校正要点:,, dZ,dA,dW,db都与它们被导数(Z,A,W,b)的维数相同 4.3 为什么使用 ...

  8. 2017年度好视频,吴恩达、李飞飞、Hinton、OpenAI、NIPS、CVPR、CS231n全都在

    我们经常被问:机器翻译迭代了好几轮,专业翻译的饭碗都端不稳了,字幕组到底还能做什么? 对于这个问题,我们自己感受最深,却又来不及解释,就已经边感受边做地冲出去了很远,摸爬滚打了一整年. 其实,现在看来 ...

  9. 一文弄懂神经网络中的反向传播法——BackPropagation【转】

    本文转载自:https://www.cnblogs.com/charlotte77/p/5629865.html 一文弄懂神经网络中的反向传播法——BackPropagation   最近在看深度学习 ...

随机推荐

  1. 这么香的Chrome插件,你都安装了吗?

    工欲善其事必先利其器,今天长话短说,介绍13个敏捷.高效的Chrome插件 根据使用方式,本人将其划分为三大类: 开发者工具 日常效率工具类 浏览器管理类 开发者工具 1. Web Developer ...

  2. 带你入门 CSS Grid 布局

    前言 三月中旬的时候,有一个对于 CSS 开发者来说很重要的消息,最新版的 Firefox 和 Chrome 已经正式支 CSS Grid 这一新特性啦.没错:我们现在就可以在最流行的两大浏览器上玩转 ...

  3. SPA那点事

    前端猿一天不学习就没饭吃了,后端猿三天不学习仍旧有白米饭摆于桌前.IT行业的快速发展一直在推动着前端技术栈在不断地更新换代,前端的发展成了互联网时代的一个缩影.而单页面应用的发展给前端猿分了一杯羹. ...

  4. 前端每日实战:111# 视频演示如何用纯 CSS 创作一只艺术的鸭子

    效果预览 按下右侧的"点击预览"按钮可以在当前页面预览,点击链接可以全屏预览. https://codepen.io/comehope/pen/aaoveW 可交互视频 此视频是可 ...

  5. linux 安装Mosquitto

    这篇博客讲的很好:https://www.cnblogs.com/chen1-kerr/p/7258487.html 此处简单摘抄部分内容 1.下载mosquitto安装包 源码地址:http://m ...

  6. 添加谷歌拓展程序 vue.js devtools过程中的问题

    在用vue做项目过程中,需要用到vue.js devtools,在从github上面clone下来代码,然后再npm install ,过程报错,然后更新npm包也是会有问题,以下是install的问 ...

  7. springmvc与swagger2

    首先呢我们导入相关的jar包文件 为了方便copy我copy一份 <!-- 导入java ee jar 包 -->        <dependency>           ...

  8. MySQL中INSERT INTO SELECT的使用

    1. 语法介绍      有三张表a.b.c,现在需要从表b和表c中分别查几个字段的值插入到表a中对应的字段.对于这种情况,可以使用如下的语句来实现: INSERT INTO db1_name (fi ...

  9. 这些MongoDB的隐藏操作你真的都掌握了吗?反正我是刚知道

    背景 最近公司系统还原用户时偶尔会出现部分用户信息未还原成功的问题,最为开发人员,最头疼的不是代码存在bug,而是测试发现了bug,但一旦我去重现,它就不见了.Are you kidding me? ...

  10. AspNetCore3.1源码解析_2_Hsts中间件

    title: "AspNetCore3.1源码解析_2_Hsts中间件" date: 2020-03-16T12:40:46+08:00 draft: false --- 概述 在 ...