反向传播

课程来源:PyTorch深度学习实践——河北工业大学

《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

笔记

在之前课程中介绍的线性模型就是一个最简单的神经网络的结构,其内部参数的更新过程如下:

对于简单的模型来说可以直接使用表达式的方式来更新权重,但是如果网络结构比较复杂(如下图),直接使用解析式的方式来更新显然有些复杂且不太可能实现。

反向传播就是为了解决这种问题。反向传播的基本思想就是将网络看成一张图,在图上传播梯度,从而使用链式传播来计算梯度。首先介绍两层的网络的计算图的方式表示如下图所示:

矩阵求导参考书籍链接如下:https://bicmr.pku.edu.cn/~wenzw/bigdata/matrix-cook-book.pdf

如果把式子展开,将会有如下结果:

也就是多层线性模型的叠加是可以用一个线性模型来实现的。因此为了提高模型的复杂程度,对于每一层的输出增加一个非线性的变化函数,如sigmoid等函数,如下图所示:

反向传播的链式求导的过程一个实例如下图所示:

得到相应导数之后就可以对于权重进行更新,如果x也只是一个中间结果,则可以继续向前传导。

接下来可以看一个完整的线性模型的计算图示例,过程就是先进行前馈过程,在前馈到loss之后进行反向传播,从而完成计算:

接下来介绍在PyTorch中如何进行前馈和反馈计算。

首先需要介绍的是Tensor,这是PyTorch中构建动态图的一个重要组成部分,Tensor中主要元素的是Data(数据)和Grad(导数),分别用于保存权重值和损失函数对权重的导数。

使用PyTorch实现上述的线性模型的代码如下:

import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0] w = torch.tensor([1.0]) #初值为1.0
w.requires_grad = True # 需要计算梯度 def forward(x):
return x*w # 返回tensor def loss(x, y):
y_pred = forward(x)
return (y_pred - y)**2 print("predict (before training)", 4, forward(4).item()) for epoch in range(100):
for x, y in zip(x_data, y_data):
l =loss(x,y) #l是一个张量
l.backward() #将计算链路上需要梯度的地方计算出梯度,这一步之后计算图释放,每一次更新都创建新的计算图
print('\tgrad:', x, y, w.grad.item())#item是为了把梯度中的数值取出为标量
w.data = w.data - 0.01 * w.grad.data # 权重更新时,使用标量,使用data的时候不会建立新的计算图,注意grad也是一个tensor
w.grad.data.zero_() # 更新之后将梯度数据清零
print('progress:', epoch, l.item())
print("predict (after training)", 4, forward(4).item())

作业

1、手动推导线性模型y=w*x,损失函数loss=(ŷ-y)²下,当数据集x=2,y=4的时候,反向传播的过程。

2、手动推导线性模型 y=w*x+b,损失函数loss=(ŷ-y)²下,当数据集x=1,y=2的时候,反向传播的过程。

3、画出二次模型y=w1x²+w2x+b,损失函数loss=(ŷ-y)²的计算图,并且手动推导反向传播的过程,最后用pytorch的代码实现。

代码如下:

import torch
import matplotlib.pyplot as plt
import numpy as np
x_data=[1.0,2.0,3.0]
y_data=[2.0,4.0,6.0]
w1=torch.tensor([1.0],requires_grad=True)
w2=torch.tensor([1.0],requires_grad=True)
b=torch.tensor([1.0],requires_grad=True)
epoch_list=[]
loss_list=[]
def forward(x):
return w1*x**2+w2*x+b
def loss(x,y):
y_pred=forward(x)
return (y_pred-y)**2
print('Predict (befortraining)',4,forward(4))
for epoch in range(100):
for x,y in zip(x_data,y_data):
l=loss(x,y)
l.backward()
print('\tgrad:',x,y,w1.grad.item(),w2.grad.item(),b.grad.item())
w1.data=w1.data-0.01*w1.grad.data
w2.data = w2.data - 0.01 * w2.grad.data
b.data = b.data - 0.01 * b.grad.data
w1.grad.data.zero_()
w2.grad.data.zero_()
b.grad.data.zero_()
print('Epoch:', epoch, l.item())
epoch_list.append(epoch)
loss_list.append(l.data)
print('Predict(after training)', 4, forward(4).item())
print('predict (after training)', 4, forward(4))
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

可视化loss如下:

PyTorch深度学习实践——反向传播的更多相关文章

  1. PyTorch深度学习实践——多分类问题

    多分类问题 目录 多分类问题 Softmax 在Minist数据集上实现多分类问题 作业 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩 ...

  2. PyTorch深度学习实践——处理多维特征的输入

    处理多维特征的输入 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 这一讲介绍输入为多维数据时的分类. 一个数据集 ...

  3. 深度学习梯度反向传播出现Nan值的原因归类

    症状:前向计算一切正常.梯度反向传播的时候就出现异常,梯度从某一层开始出现Nan值(Nan: Not a number缩写,在numpy中,np.nan != np.nan,是唯一个不等于自身的数). ...

  4. 深度学习之反向传播算法(BP)代码实现

    反向传播算法实战 本文仅仅是反向传播算法的实现,不涉及公式推导,如果对反向传播算法公式推导不熟悉,强烈建议查看另一篇文章神经网络之反向传播算法(BP)公式推导(超详细) 我们将实现一个 4 层的全连接 ...

  5. PyTorch深度学习实践-Overview

    Overview 1.PyTorch简介 ​ PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强 ...

  6. 深度学习实践系列(2)- 搭建notMNIST的深度神经网络

    如果你希望系统性的了解神经网络,请参考零基础入门深度学习系列,下面我会粗略的介绍一下本文中实现神经网络需要了解的知识. 什么是深度神经网络? 神经网络包含三层:输入层(X).隐藏层和输出层:f(x) ...

  7. 使用PyTorch构建神经网络以及反向传播计算

    使用PyTorch构建神经网络以及反向传播计算 前一段时间南京出现了疫情,大概原因是因为境外飞机清洁处理不恰当,导致清理人员感染.话说国外一天不消停,国内就得一直严防死守.沈阳出现了一例感染人员,我在 ...

  8. 深度学习实践系列(3)- 使用Keras搭建notMNIST的神经网络

    前期回顾: 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型 深度学习实践系列(2)- 搭建notMNIST的深度神经网络 在第二篇系列中,我们使用了TensorFlow搭建了第一个深度 ...

  9. 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码

    PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...

随机推荐

  1. 什么是VPC

    1 什么是私有网络(VPC) 私有网络是一块可用户自定义的网络空间,您可以在私有网络内部署云主机.负载均衡.数据库.Nosql快存储等云服务资源.您可自由划分网段.制定路由策略.私有网络可以配置公网网 ...

  2. Web安全防护(二)

    点击劫持 点击劫持,也称UI覆盖攻击 1.1 iframe覆盖攻击 黑客创建一个网页,用iframe包含了目标网站,并且把它隐藏起来.做一个伪装的页面或图片盖上去,且按钮与目标网站一致,诱导用户去点击 ...

  3. MySQL存储引擎(最全面的概括)

    目录 一:MySQL存储引擎 1.什么是存储引擎? 2.查看存储引擎信息 二:MySQL支持的存储引擎 1.存储引擎 三:innoDB存储引擎 1.特性 2.存储结构 3.优缺点.适用场景 四:MyI ...

  4. Asp-Net-Core学习笔记:身份认证入门

    前言 过年前我又来更新了~ 我就说了最近不是在偷懒吧,其实这段时间还是有积累一些东西的,不过还没去整理-- 所以只能发以前没写完的一些笔记出来 就当做是温习一下啦 PS:之前说的红包封面我还没搞,得抓 ...

  5. ApacheCN 大数据译文集 20211206 更新

    PySpark 大数据分析实用指南 零.前言 一.安装 Pyspark 并设置您的开发环境 二.使用 RDD 将您的大数据带入 Spark 环境 三.Spark 笔记本的大数据清理和整理 四.将数据汇 ...

  6. Nginx实现跨域配置详解

    主要给大家介绍了关于Nginx跨域使用字体文件的相关内容,分享出来供大家参考学习,下面来一起看看详细的介绍: 问题描述 今天在使用子域名访问根域名的CSS时,发现字体无法显示,在确保CSS和Font字 ...

  7. Spring学习七:ComponentScan注解

    今天主要从以下几个方面来介绍一下@ComponentScan注解: @ComponentScan注解是什么 @ComponentScan注解的详细使用 1.ComponentScan注解是什么 其实很 ...

  8. 入门-k8s查看Pods/Nodes (四)

    目标 了解Kubernetes Pods(容器组) 了解Kubernetes Nodes(节点) 排查故障 Kubernetes Pods 在 部署第一个应用程序 中创建 Deployment 后,k ...

  9. pageX的兼容性处理2

    <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...

  10. tomcat访问所有的资源,都是用Servlet来实现的

    感谢大佬:https://www.zhihu.com/question/57400909 tomcat访问所有的资源,都是用Servlet来实现的. 在Tomcat看来,资源分3种 静态资源,如css ...