Neural Networks: Learning

内容较多,故分成上下两篇文章。

一、内容概要

  • Cost Function and Backpropagation

    • Cost Function
    • Backpropagation Algorithm
    • Backpropagation Intuition
  • Backpropagation in Practice

    • Implementation Note:Unroll Parameters
    • Gradient Checking
    • Random Initialization
    • Putting it Together
  • Application of Neural Networks

    • Autonomous Driving

二、重点&难点

1.Cost Function and Backpropagation

1) Cost Function

首先定义一下后面会提到的变量

L: 神经网络总层数

Sl:l层单元个数(不包括bias unit)

k:输出层个数

回顾正则化逻辑回归中的损失函数:

\[J(\theta) = - \frac{1}{m} \sum_{i=1}^m [ y^{(i)}\ \log (h_\theta (x^{(i)})) + (1 - y^{(i)})\ \log (1 - h_\theta(x^{(i)}))] + \frac{\lambda}{2m}\sum_{j=1}^n \theta_j^2\]

在神经网络中损失函数略微复杂了些,但是也比较好理解,就是把所有层都算进去了。

\[
\begin{gather*} J(\Theta) = - \frac{1}{m} \sum_{i=1}^m \sum_{k=1}^K \left[y^{(i)}_k \log ((h_\Theta (x^{(i)}))_k) + (1 - y^{(i)}_k)\log (1 - (h_\Theta(x^{(i)}))_k)\right] + \frac{\lambda}{2m}\sum_{l=1}^{L-1} \sum_{i=1}^{s_l} \sum_{j=1}^{s_{l+1}} ( \Theta_{j,i}^{(l)})^2\end{gather*}
\]

2)BackPropagation反向传播

更详细的公式推导可以参考http://ufldl.stanford.edu--反向传导算法

下面给出我自己对BP算法的理解以及ufldl上的推导:

假设神经网络结构如下

- 1. FP

  1. 利用前向传导公式(FP)计算\(2,3……\) 直到 \({n_l}\)层(输出层)的激活值。

    计算过程如下:

- 2. BP

  • 权值更新

首先需要知道的是BP算法是干嘛的?它是用来让神经网络自动更新权重\(W\)的。

这里权重\(W\)与之前线性回归权值更新形式上是一样:

那现在要做的工作就是求出后面的偏导,在求之前进一步变形:

注意\(J(W,b;x^{(i)},y^{(i)})\)表示的是单个样例的代价函数,而\(J(W,b)\)表示的是整体的代价函数。

所以接下来的工作就是求出\(\frac{∂J(W,b;x,y)}{∂W_{ij^{(l)}}}\),求解这个需要用到微积分中的链式法则,即

\[
\begin{align*}
\frac{∂J(W,b;x,y)}{∂W_{ij^{(l)}}} = \frac{∂J(W,b;x,y)}{∂a_{i^{(l)}}} \frac{∂a_{i^{(l)}}}{∂z_{i^{(l)}}} \frac{∂z_{i^{(l)}}}{∂w_{ij^{(l)}}} = a_j^{(l)}δ_i^{(l+1)}
\end{align*}
\]

更加详细运算过程可以参考[一文弄懂神经网络中的反向传播法——BackPropagation],这篇文章详细的介绍了BP算法的每一步骤。

上面的公式中出现了\(δ\)(误差error),所以后续的目的就是求出每层每个node的\(δ\),具体过程如下:

  • 计算δ

对于第 \(n_l\)层(输出层)的每个输出单元\(i\),我们根据以下公式计算残差:

对 \(l = n_l-1, n_l-2, ……,3,2\)的各个层,第 \(l\) 层的第 \(i\) 个节点的残差计算方法如下:

将上面的结果带入权值更新的表达式中便可顺利的执行BackPropagation啦~~~

但是!!!需要注意的是上面式子中反复出现的 \(f '(z_i^{(l)})\) ,表示激活函数的导数。这个在刚开始的确困惑到我了,因为视频里老师在演示计算\(δ\)的时候根本就乘以这一项,难道老师错了?其实不是的,解释如下:

常用的激活函数有好几种,但使用是分情况的:

  • 线性情况下:f(z) = z
  • 非线性情况下:(只举一些我知道的例子)
    • sigmoid
    • tanh
    • relu

所以这就是为什么老师在视频中没有乘以 \(f '(z_i^{(l)})\) 的原因了,就是因为是线性的,求导后为1,直接省略了。

另外sigmoid函数表达式为\(f(z)=\frac{1}{1+e^{-z}}\),很容易知道\(f'(z)=\frac{-e^{-z}}{ (1+e^{-z}) ^2 } = f(z)·(1-f(z))\)这也就解释了Coursera网站上讲义的公式是这样的了:


所以现在总结一下BP算法步骤

  1. 进行前馈传导计算,利用前向传导公式,得到\(L_2, L_3, \ldots\)直到输出层 \(\textstyle L_{n_l}\)的激活值。
  2. 对输出层(第 \(\textstyle n_l\)层),计算:

    \(\delta^{(n_l)}= - (y - a^{(n_l)}) \bullet f'(z^{(n_l)})\)
  3. 对于 \(\textstyle l = n_l-1, n_l-2, n_l-3, \ldots, 2\) 的各层,计算:

    \(\delta^{(l)} = \left((W^{(l)})^T \delta^{(l+1)}\right) \bullet f'(z^{(l)})\)
  4. 计算最终需要的偏导数值:

    \[
    \begin{align}
    \nabla_{W^{(l)}} J(W,b;x,y) &= \delta^{(l+1)} (a^{(l)})^T, \\
    \nabla_{b^{(l)}} J(W,b;x,y) &= \delta^{(l+1)}.
    \end{align}
    \]

使用批量梯度下降一次迭代过程:

  1. 对于所有\(\textstyle l\),令 \(\textstyle \Delta W^{(l)} := 0 , \textstyle \Delta b^{(l)} := 0\) (设置为全零矩阵或全零向量)
  2. 对于\(\textstyle i = 1\) 到\(\textstyle m\) ,

    使用反向传播算法计算\(\textstyle \nabla_{W^{(l)}} J(W,b;x,y)\) 和\(\textstyle \nabla_{b^{(l)}} J(W,b;x,y)\) 。

    计算\(\textstyle \Delta W^{(l)} := \Delta W^{(l)} + \nabla_{W^{(l)}} J(W,b;x,y)\) 。

    计算\(\textstyle \Delta b^{(l)} := \Delta b^{(l)} + \nabla_{b^{(l)}} J(W,b;x,y)\) 。
  3. 更新权重参数:

    \[
    \begin{align}
    W^{(l)} &= W^{(l)} - \alpha \left[ \left(\frac{1}{m} \Delta W^{(l)} \right) + \lambda W^{(l)}\right] \\
    b^{(l)} &= b^{(l)} - \alpha \left[\frac{1}{m} \Delta b^{(l)}\right]
    \end{align}
    \]

3) Backpropagation Intuition

本小节演示了具体如何操作BP,不再赘述。

具体可参考Coursera讲义


MARSGGBO♥原创







2017-8-3

Andrew Ng机器学习课程笔记--week5(上)的更多相关文章

  1. Andrew Ng机器学习课程笔记--week5(下)

    Neural Networks: Learning 内容较多,故分成上下两篇文章. 一.内容概要 Cost Function and Backpropagation Cost Function Bac ...

  2. Andrew Ng机器学习课程笔记--week9(上)(异常检测&推荐系统)

    本周内容较多,故分为上下两篇文章. 一.内容概要 1. Anomaly Detection Density Estimation Problem Motivation Gaussian Distrib ...

  3. Andrew Ng机器学习课程笔记--汇总

    笔记总结,各章节主要内容已总结在标题之中 Andrew Ng机器学习课程笔记–week1(机器学习简介&线性回归模型) Andrew Ng机器学习课程笔记--week2(多元线性回归& ...

  4. Andrew Ng机器学习课程笔记(五)之应用机器学习的建议

    Andrew Ng机器学习课程笔记(五)之 应用机器学习的建议 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7368472.h ...

  5. Andrew Ng机器学习课程笔记(六)之 机器学习系统的设计

    Andrew Ng机器学习课程笔记(六)之 机器学习系统的设计 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7392408.h ...

  6. Andrew Ng机器学习课程笔记(四)之神经网络

    Andrew Ng机器学习课程笔记(四)之神经网络 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7365730.html 前言 ...

  7. Andrew Ng机器学习课程笔记(三)之正则化

    Andrew Ng机器学习课程笔记(三)之正则化 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7365475.html 前言 ...

  8. Andrew Ng机器学习课程笔记(二)之逻辑回归

    Andrew Ng机器学习课程笔记(二)之逻辑回归 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7364636.html 前言 ...

  9. Andrew Ng机器学习课程笔记--week1(机器学习介绍及线性回归)

    title: Andrew Ng机器学习课程笔记--week1(机器学习介绍及线性回归) tags: 机器学习, 学习笔记 grammar_cjkRuby: true --- 之前看过一遍,但是总是模 ...

随机推荐

  1. Linux常用命令及shell技巧

    这里列出一些个人在工作中常使用的各种linux命令,每一个不详细讲参数,只写经常用的参数.希望快速获得在linux命令行工作的能力的朋友可以看看.本人一直觉的,不使用linux 图形界面,以xshel ...

  2. Python项目实战:福布斯系列之数据采集

    1 数据采集概述 开始一个数据分析项目,首先需要做的就是get到原始数据,获得原始数据的方法有多种途径.比如: 获取数据集(dataset)文件 使用爬虫采集数据 直接获得excel.csv及其他数据 ...

  3. linux shell变量$#,$@,$0,$1,$2的含义解释

    变量说明: $$ Shell本身的PID(ProcessID) $! Shell最后运行的后台Process的PID $? 最后运行的命令的结束代码(返回值) $- 使用Set命令设定的Flag一览  ...

  4. 自动化运维—tomcat服务起停(mysql+shell+django+bootstrap+jquery)

    项目简介: 项目介绍:自动化运维是未来的趋势,最近学了不少东西,正好通过这个小项目把这些学的东西串起来,练练手. 基础架构: 服务器端:web框架-Django 前端:html css jQuery ...

  5. C语言基础 - 实现动态数组并增加内存管理

    用C语言实现一个动态数组,并对外暴露出对数组的增.删.改.查函数 (可以存储任意类型的元素并实现内存管理) 这里我的编译器就是xcode 分析: 模拟存放 一个 People类 有2个属性 字符串类型 ...

  6. EasyUi+Spring Data 实现按条件分页查询

    Spring data 介绍 Spring data 出现目的 为了简化.统一 持久层 各种实现技术 API ,所以 spring data 提供一套标准 API 和 不同持久层整合技术实现 . 自己 ...

  7. ServletContext对象统计在线人数

    package com.zdsofe.servlet1; import java.io.IOException; import java.io.PrintWriter; import javax.se ...

  8. iOS 配置

    1.git的配置 使用Github,也许大家觉得比较麻烦的就是在每次push的时候,都需要输入用户名和密码.如果使用SSH,就可以记住用户名,并创建属于自己的密码来保证安全操作,还有神奇的一招可以“不 ...

  9. CSS3新增文本属性实现图片点击切换效果

    <!doctype html> <html lang="en"> <head> <meta charset="UTF-8&quo ...

  10. mybatis深入理解之 # 与 $ 区别以及 sql 预编译

    mybatis 中使用 sqlMap 进行 sql 查询时,经常需要动态传递参数,例如我们需要根据用户的姓名来筛选用户时,sql 如下: select * from user where name = ...