神经网络入门篇:神经网络的梯度下降(Gradient descent for neural networks)
神经网络的梯度下降
- 在这篇博客中,讲的是实现反向传播或者说梯度下降算法的方程组
单隐层神经网络会有\(W^{[1]}\),\(b^{[1]}\),\(W^{[2]}\),\(b^{[2]}\)这些参数,还有个\(n_x\)表示输入特征的个数,\(n^{[1]}\)表示隐藏单元个数,\(n^{[2]}\)表示输出单元个数。
在这个例子中,只介绍过的这种情况,那么参数:
矩阵\(W^{[1]}\)的维度就是(\(n^{[1]}, n^{[0]}\)),\(b^{[1]}\)就是\(n^{[1]}\)维向量,可以写成\((n^{[1]}, 1)\),就是一个的列向量。
矩阵\(W^{[2]}\)的维度就是(\(n^{[2]}, n^{[1]}\)),\(b^{[2]}\)的维度就是\((n^{[2]},1)\)维度。
还有一个神经网络的成本函数,假设在做二分类任务,那么的成本函数等于:
Cost function:
公式:
\(J(W^{[1]},b^{[1]},W^{[2]},b^{[2]}) = {\frac{1}{m}}\sum_{i=1}^mL(\hat{y}, y)\)
loss function和之前做logistic回归完全一样。
训练参数需要做梯度下降,在训练神经网络的时候,随机初始化参数很重要,而不是初始化成全零。当参数初始化成某些值后,每次梯度下降都会循环计算以下预测值:
\(\hat{y}^{(i)},(i=1,2,…,m)\)
公式1.28:
\(dW^{[1]} = \frac{dJ}{dW^{[1]}},db^{[1]} = \frac{dJ}{db^{[1]}}\)
公式1.29:
\({d}W^{[2]} = \frac{{dJ}}{dW^{[2]}},{d}b^{[2]} = \frac{dJ}{db^{[2]}}\)
其中
公式1.30:
\(W^{[1]}\implies{W^{[1]} - adW^{[1]}},b^{[1]}\implies{b^{[1]} -adb^{[1]}}\)
公式1.31:
\(W^{[2]}\implies{W^{[2]} - \alpha{\rm d}W^{[2]}},b^{[2]}\implies{b^{[2]} - \alpha{\rm d}b^{[2]}}\)
正向传播方程如下(之前讲过):
forward propagation:
(1)
\(z^{[1]} = W^{[1]}x + b^{[1]}\)
(2)
\(a^{[1]} = \sigma(z^{[1]})\)
(3)
\(z^{[2]} = W^{[2]}a^{[1]} + b^{[2]}\)
(4)
\(a^{[2]} = g^{[2]}(z^{[z]}) = \sigma(z^{[2]})\)
反向传播方程如下:
back propagation:
公式1.32:
$ dz^{[2]} = A^{[2]} - Y , Y = \begin{bmatrix}y^{[1]} & y^{[2]} & \cdots & y^{[m]}\ \end{bmatrix} $
公式1.33:
$ dW^{[2]} = {\frac{1}{m}}dz{[2]}A $
公式1.34:
$ {\rm d}b^{[2]} = {\frac{1}{m}}np.sum({d}z^{[2]},axis=1,keepdims=True)$
公式1.35:
$ dz^{[1]} = \underbrace{W^{[2]T}{\rm d}z{[2]}}_{(n,m)}\quad\underbrace{{g{[1]}}{'}}_{activation ; function ; of ; hidden ; layer}\quad\underbrace{(z{[1]})}_{(n,m)} $
公式1.36:
\(dW^{[1]} = {\frac{1}{m}}dz^{[1]}x^{T}\)
公式1.37:
\({\underbrace{db^{[1]}}_{(n^{[1]},1)}} = {\frac{1}{m}}np.sum(dz^{[1]},axis=1,keepdims=True)\)
上述是反向传播的步骤,注:这些都是针对所有样本进行过向量化,\(Y\)是\(1×m\)的矩阵;这里np.sum是python的numpy命令,axis=1表示水平相加求和,keepdims是防止python输出那些古怪的秩数\((n,)\),加上这个确保阵矩阵\(db^{[2]}\)这个向量输出的维度为\((n,1)\)这样标准的形式。
目前为止,计算的都和Logistic回归十分相似,但当开始计算反向传播时,需要计算,是隐藏层函数的导数,输出在使用sigmoid函数进行二元分类。这里是进行逐个元素乘积,因为\(W^{[2]T}dz^{[2]}\)和\((z^{[1]})\)这两个都为\((n^{[1]},m)\)矩阵;
还有一种防止python输出奇怪的秩数,需要显式地调用reshape把np.sum输出结果写成矩阵形式。
以上就是正向传播的4个方程和反向传播的6个方程,这里是直接给出的。
神经网络入门篇:神经网络的梯度下降(Gradient descent for neural networks)的更多相关文章
- 机器学习(1)之梯度下降(gradient descent)
机器学习(1)之梯度下降(gradient descent) 题记:最近零碎的时间都在学习Andrew Ng的machine learning,因此就有了这些笔记. 梯度下降是线性回归的一种(Line ...
- 梯度下降(Gradient Descent)小结 -2017.7.20
在求解算法的模型函数时,常用到梯度下降(Gradient Descent)和最小二乘法,下面讨论梯度下降的线性模型(linear model). 1.问题引入 给定一组训练集合(training se ...
- 梯度下降(gradient descent)算法简介
梯度下降法是一个最优化算法,通常也称为最速下降法.最速下降法是求解无约束优化问题最简单和最古老的方法之一,虽然现在已经不具有实用性,但是许多有效算法都是以它为基础进行改进和修正而得到的.最速下降法是用 ...
- 梯度下降(Gradient descent)
首先,我们继续上一篇文章中的例子,在这里我们增加一个特征,也即卧室数量,如下表格所示: 因为在上一篇中引入了一些符号,所以这里再次补充说明一下: x‘s:在这里是一个二维的向量,例如:x1(i)第i间 ...
- (二)深入梯度下降(Gradient Descent)算法
一直以来都以为自己对一些算法已经理解了,直到最近才发现,梯度下降都理解的不好. 1 问题的引出 对于上篇中讲到的线性回归,先化一个为一个特征θ1,θ0为偏置项,最后列出的误差函数如下图所示: 手动求解 ...
- 机器学习中的数学(1)-回归(regression)、梯度下降(gradient descent)
版权声明: 本文由LeftNotEasy所有,发布于http://leftnoteasy.cnblogs.com.如果转载,请注明出处,在未经作者同意下将本文用于商业用途,将追究其法律责任. 前言: ...
- CS229 2.深入梯度下降(Gradient Descent)算法
1 问题的引出 对于上篇中讲到的线性回归,先化一个为一个特征θ1,θ0为偏置项,最后列出的误差函数如下图所示: 手动求解 目标是优化J(θ1),得到其最小化,下图中的×为y(i),下面给出TrainS ...
- 回归(regression)、梯度下降(gradient descent)
本文由LeftNotEasy所有,发布于http://leftnoteasy.cnblogs.com.如果转载,请注明出处,在未经作者同意下将本文用于商业用途,将追究其法律责任. 前言: 上次写过一篇 ...
- 吴恩达深度学习:2.3梯度下降Gradient Descent
1.用梯度下降算法来训练或者学习训练集上的参数w和b,如下所示,第一行是logistic回归算法,第二行是成本函数J,它被定义为1/m的损失函数之和,损失函数可以衡量你的算法的效果,每一个训练样例都输 ...
- 梯度下降算法 Gradient Descent
梯度下降算法 Gradient Descent 梯度下降算法是一种被广泛使用的优化算法.在读论文的时候碰到了一种参数优化问题: 在函数\(F\)中有若干参数是不确定的,已知\(n\)组训练数据,期望找 ...
随机推荐
- 论文解读(DWL)《Dynamic Weighted Learning for Unsupervised Domain Adaptation》
[ Wechat:Y466551 | 付费咨询,非诚勿扰 ] 论文信息 论文标题:Dynamic Weighted Learning for Unsupervised Domain Adaptatio ...
- jsp+servlet实战项目
第一步:新建maven项目,项目中添加dao,entity,service,servlet,util包第二步:导入依赖 第三步:数据库建表 第四步:entity实体包(疯转) 第五步:在util工具包 ...
- 2023-08-14:用go语言写算法。给出两个长度相同的字符串 str1 和 str2 请你帮忙判断字符串 str1 能不能在 零次 或 多次 转化 后变成字符串 str2 每一次转化时,你可以将
2023-08-14:用go语言写算法.给出两个长度相同的字符串 str1 和 str2, 请你帮忙判断字符串 str1 能不能在 零次 或 多次 转化 后变成字符串 str2, 每一次转化时,你可以 ...
- Hugging News #0814: Llama 2 学习资源大汇总 🦙
每一周,我们的同事都会向社区的成员们发布一些关于 Hugging Face 相关的更新,包括我们的产品和平台更新.社区活动.学习资源和内容更新.开源库和模型更新等,我们将其称之为「Hugging Ne ...
- 知识图谱(Knowledge Graph)- Neo4j 5.10.0 使用 - CQL - 太极拳传承谱系表
删除数据库中以往的图 MATCH (n) DETACH DELETE n 创建节点 CREATE命令语法 Neo4j CQL"CREATE"命令用于创建没有属性的节点. 它只是创建 ...
- 1.JDK的安装与卸载
1.卸载: 卸载或更改程序,找到相应的JDK程序,删除 2.安装: 官网下载JDK程序:jdk-8u25-windows-i586.exe 双击安装程序,同意协议,更改安装路径:C:\jdk1.8.0 ...
- 在Godot 3.X中添加触屏摇杆
开源项目地址:https://github.com/shinneider/godot_touchJoyPad 效果图: 下载项目 方法一 直接从godot assets lib下载 如图,直接下载自动 ...
- ETL之apache hop系列3-hop Server环境部署与客户端发布管道工作流
前言 该文档主要是apache hop 2.5的 Windows 10和Linux docker环境部署和客户端发布工作流和管道的相关内容 不使用Docker直接使用应用程序包,下载压缩包文件后,需要 ...
- java与es8实战之六:用JSON创建请求对象(比builder pattern更加直观简洁)
欢迎访问我的GitHub 这里分类和汇总了欣宸的全部原创(含配套源码):https://github.com/zq2599/blog_demos 本篇概览 本文是<java与es8实战>系 ...
- grinder使用入门
安装配置 自动安装脚本编写 移步我的github下载脚本,只要修改下配置信息,即可完成安装 环境变量设置: