PyTorch深度学习实践——反向传播
反向传播
课程来源:PyTorch深度学习实践——河北工业大学
《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
笔记
在之前课程中介绍的线性模型就是一个最简单的神经网络的结构,其内部参数的更新过程如下:

对于简单的模型来说可以直接使用表达式的方式来更新权重,但是如果网络结构比较复杂(如下图),直接使用解析式的方式来更新显然有些复杂且不太可能实现。

反向传播就是为了解决这种问题。反向传播的基本思想就是将网络看成一张图,在图上传播梯度,从而使用链式传播来计算梯度。首先介绍两层的网络的计算图的方式表示如下图所示:

矩阵求导参考书籍链接如下:https://bicmr.pku.edu.cn/~wenzw/bigdata/matrix-cook-book.pdf
如果把式子展开,将会有如下结果:

也就是多层线性模型的叠加是可以用一个线性模型来实现的。因此为了提高模型的复杂程度,对于每一层的输出增加一个非线性的变化函数,如sigmoid等函数,如下图所示:

反向传播的链式求导的过程一个实例如下图所示:

得到相应导数之后就可以对于权重进行更新,如果x也只是一个中间结果,则可以继续向前传导。
接下来可以看一个完整的线性模型的计算图示例,过程就是先进行前馈过程,在前馈到loss之后进行反向传播,从而完成计算:

接下来介绍在PyTorch中如何进行前馈和反馈计算。
首先需要介绍的是Tensor,这是PyTorch中构建动态图的一个重要组成部分,Tensor中主要元素的是Data(数据)和Grad(导数),分别用于保存权重值和损失函数对权重的导数。
使用PyTorch实现上述的线性模型的代码如下:
import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = torch.tensor([1.0]) #初值为1.0
w.requires_grad = True # 需要计算梯度
def forward(x):
return x*w # 返回tensor
def loss(x, y):
y_pred = forward(x)
return (y_pred - y)**2
print("predict (before training)", 4, forward(4).item())
for epoch in range(100):
for x, y in zip(x_data, y_data):
l =loss(x,y) #l是一个张量
l.backward() #将计算链路上需要梯度的地方计算出梯度,这一步之后计算图释放,每一次更新都创建新的计算图
print('\tgrad:', x, y, w.grad.item())#item是为了把梯度中的数值取出为标量
w.data = w.data - 0.01 * w.grad.data # 权重更新时,使用标量,使用data的时候不会建立新的计算图,注意grad也是一个tensor
w.grad.data.zero_() # 更新之后将梯度数据清零
print('progress:', epoch, l.item())
print("predict (after training)", 4, forward(4).item())
作业
1、手动推导线性模型y=w*x,损失函数loss=(ŷ-y)²下,当数据集x=2,y=4的时候,反向传播的过程。

2、手动推导线性模型 y=w*x+b,损失函数loss=(ŷ-y)²下,当数据集x=1,y=2的时候,反向传播的过程。

3、画出二次模型y=w1x²+w2x+b,损失函数loss=(ŷ-y)²的计算图,并且手动推导反向传播的过程,最后用pytorch的代码实现。

代码如下:
import torch
import matplotlib.pyplot as plt
import numpy as np
x_data=[1.0,2.0,3.0]
y_data=[2.0,4.0,6.0]
w1=torch.tensor([1.0],requires_grad=True)
w2=torch.tensor([1.0],requires_grad=True)
b=torch.tensor([1.0],requires_grad=True)
epoch_list=[]
loss_list=[]
def forward(x):
return w1*x**2+w2*x+b
def loss(x,y):
y_pred=forward(x)
return (y_pred-y)**2
print('Predict (befortraining)',4,forward(4))
for epoch in range(100):
for x,y in zip(x_data,y_data):
l=loss(x,y)
l.backward()
print('\tgrad:',x,y,w1.grad.item(),w2.grad.item(),b.grad.item())
w1.data=w1.data-0.01*w1.grad.data
w2.data = w2.data - 0.01 * w2.grad.data
b.data = b.data - 0.01 * b.grad.data
w1.grad.data.zero_()
w2.grad.data.zero_()
b.grad.data.zero_()
print('Epoch:', epoch, l.item())
epoch_list.append(epoch)
loss_list.append(l.data)
print('Predict(after training)', 4, forward(4).item())
print('predict (after training)', 4, forward(4))
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
可视化loss如下:

PyTorch深度学习实践——反向传播的更多相关文章
- PyTorch深度学习实践——多分类问题
多分类问题 目录 多分类问题 Softmax 在Minist数据集上实现多分类问题 作业 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩 ...
- PyTorch深度学习实践——处理多维特征的输入
处理多维特征的输入 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 这一讲介绍输入为多维数据时的分类. 一个数据集 ...
- 深度学习梯度反向传播出现Nan值的原因归类
症状:前向计算一切正常.梯度反向传播的时候就出现异常,梯度从某一层开始出现Nan值(Nan: Not a number缩写,在numpy中,np.nan != np.nan,是唯一个不等于自身的数). ...
- 深度学习之反向传播算法(BP)代码实现
反向传播算法实战 本文仅仅是反向传播算法的实现,不涉及公式推导,如果对反向传播算法公式推导不熟悉,强烈建议查看另一篇文章神经网络之反向传播算法(BP)公式推导(超详细) 我们将实现一个 4 层的全连接 ...
- PyTorch深度学习实践-Overview
Overview 1.PyTorch简介 PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强 ...
- 深度学习实践系列(2)- 搭建notMNIST的深度神经网络
如果你希望系统性的了解神经网络,请参考零基础入门深度学习系列,下面我会粗略的介绍一下本文中实现神经网络需要了解的知识. 什么是深度神经网络? 神经网络包含三层:输入层(X).隐藏层和输出层:f(x) ...
- 使用PyTorch构建神经网络以及反向传播计算
使用PyTorch构建神经网络以及反向传播计算 前一段时间南京出现了疫情,大概原因是因为境外飞机清洁处理不恰当,导致清理人员感染.话说国外一天不消停,国内就得一直严防死守.沈阳出现了一例感染人员,我在 ...
- 深度学习实践系列(3)- 使用Keras搭建notMNIST的神经网络
前期回顾: 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型 深度学习实践系列(2)- 搭建notMNIST的深度神经网络 在第二篇系列中,我们使用了TensorFlow搭建了第一个深度 ...
- 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码
PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...
随机推荐
- Git配置用户信息和SSH免密
一.配置用户信息 1.查看配置信息 # 查看所有配置 $ git config -l/--list # 查看系统配置 $ git config --system -l/--list # 查看用户配置 ...
- 如何在pyqt中自定义SwitchButton
前言 网上有很多 SwitchButton 的实现方式,大部分是通过重写 paintEvent() 来实现的,感觉灵活性不是很好.所以希望实现一个可以联合使用 qss 来更换样式的 SwitchBut ...
- Codeforces Round #746 Div. 2
掉分快乐qwq C题代码以及分析(在注释里) /* * @Author: Nan97 * @Date: 2021-10-04 22:37:18 * @Last Modified by: Nan97 * ...
- Spring系列13:bean的生命周期
本文内容 bean的完整的生命周期 生命周期回调接口 Aware接口详解 Spring Bean的生命周期 面试热题:请描述下Spring的生命周期? 4大生命周期 从源码角度来说,简单分为4大阶段: ...
- Spring Boot配置多个DataSource (转)
使用Spring Boot时,默认情况下,配置DataSource非常容易.Spring Boot会自动为我们配置好一个DataSource. 如果在application.yml中指定了spring ...
- 为什么要配置path环境变量
因为在jdk下bin文件夹中有很多我们在开发中要使用的工具,如java.exe,javac.exe,jar.ex等,那么我们在使用时,想要在电脑的任意位置下使用这些java开发工具,那么我们就需有把这 ...
- NSPredicate类,指定过滤器的条件---董鑫
/* 比较和逻辑运算符 就像前面的例子中使用了==操作符,NSPredicate还支持>, >=, <, <=, !=, <>,还支持AND, OR, NOT(或写 ...
- ios 类别和扩展-赵小波
类别 @interface ClassName ( CategoryName ) // method declarations @end Category在iOS开发中使用非常频繁.尤其是在为系统类进 ...
- Zookeeper、Kafka集群与Filebeat+Kafka+ELK架构
Zookeeper.Kafka集群与Filebeat+Kafka+ELK架构 目录 Zookeeper.Kafka集群与Filebeat+Kafka+ELK架构 一.Zookeeper 1. Zook ...
- netty系列之:真正的平等–UDT中的Rendezvous
目录 简介 建立支持Rendezvous的服务器 处理不同的消息 节点之间的交互 总结 简介 在我们之前提到的所有netty知识中,netty好像都被分为客户端和服务器端两部分.服务器端监听连接,并对 ...