反向传播

课程来源:PyTorch深度学习实践——河北工业大学

《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

笔记

在之前课程中介绍的线性模型就是一个最简单的神经网络的结构,其内部参数的更新过程如下:

对于简单的模型来说可以直接使用表达式的方式来更新权重,但是如果网络结构比较复杂(如下图),直接使用解析式的方式来更新显然有些复杂且不太可能实现。

反向传播就是为了解决这种问题。反向传播的基本思想就是将网络看成一张图,在图上传播梯度,从而使用链式传播来计算梯度。首先介绍两层的网络的计算图的方式表示如下图所示:

矩阵求导参考书籍链接如下:https://bicmr.pku.edu.cn/~wenzw/bigdata/matrix-cook-book.pdf

如果把式子展开,将会有如下结果:

也就是多层线性模型的叠加是可以用一个线性模型来实现的。因此为了提高模型的复杂程度,对于每一层的输出增加一个非线性的变化函数,如sigmoid等函数,如下图所示:

反向传播的链式求导的过程一个实例如下图所示:

得到相应导数之后就可以对于权重进行更新,如果x也只是一个中间结果,则可以继续向前传导。

接下来可以看一个完整的线性模型的计算图示例,过程就是先进行前馈过程,在前馈到loss之后进行反向传播,从而完成计算:

接下来介绍在PyTorch中如何进行前馈和反馈计算。

首先需要介绍的是Tensor,这是PyTorch中构建动态图的一个重要组成部分,Tensor中主要元素的是Data(数据)和Grad(导数),分别用于保存权重值和损失函数对权重的导数。

使用PyTorch实现上述的线性模型的代码如下:

import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0] w = torch.tensor([1.0]) #初值为1.0
w.requires_grad = True # 需要计算梯度 def forward(x):
return x*w # 返回tensor def loss(x, y):
y_pred = forward(x)
return (y_pred - y)**2 print("predict (before training)", 4, forward(4).item()) for epoch in range(100):
for x, y in zip(x_data, y_data):
l =loss(x,y) #l是一个张量
l.backward() #将计算链路上需要梯度的地方计算出梯度,这一步之后计算图释放,每一次更新都创建新的计算图
print('\tgrad:', x, y, w.grad.item())#item是为了把梯度中的数值取出为标量
w.data = w.data - 0.01 * w.grad.data # 权重更新时,使用标量,使用data的时候不会建立新的计算图,注意grad也是一个tensor
w.grad.data.zero_() # 更新之后将梯度数据清零
print('progress:', epoch, l.item())
print("predict (after training)", 4, forward(4).item())

作业

1、手动推导线性模型y=w*x,损失函数loss=(ŷ-y)²下,当数据集x=2,y=4的时候,反向传播的过程。

2、手动推导线性模型 y=w*x+b,损失函数loss=(ŷ-y)²下,当数据集x=1,y=2的时候,反向传播的过程。

3、画出二次模型y=w1x²+w2x+b,损失函数loss=(ŷ-y)²的计算图,并且手动推导反向传播的过程,最后用pytorch的代码实现。

代码如下:

import torch
import matplotlib.pyplot as plt
import numpy as np
x_data=[1.0,2.0,3.0]
y_data=[2.0,4.0,6.0]
w1=torch.tensor([1.0],requires_grad=True)
w2=torch.tensor([1.0],requires_grad=True)
b=torch.tensor([1.0],requires_grad=True)
epoch_list=[]
loss_list=[]
def forward(x):
return w1*x**2+w2*x+b
def loss(x,y):
y_pred=forward(x)
return (y_pred-y)**2
print('Predict (befortraining)',4,forward(4))
for epoch in range(100):
for x,y in zip(x_data,y_data):
l=loss(x,y)
l.backward()
print('\tgrad:',x,y,w1.grad.item(),w2.grad.item(),b.grad.item())
w1.data=w1.data-0.01*w1.grad.data
w2.data = w2.data - 0.01 * w2.grad.data
b.data = b.data - 0.01 * b.grad.data
w1.grad.data.zero_()
w2.grad.data.zero_()
b.grad.data.zero_()
print('Epoch:', epoch, l.item())
epoch_list.append(epoch)
loss_list.append(l.data)
print('Predict(after training)', 4, forward(4).item())
print('predict (after training)', 4, forward(4))
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

可视化loss如下:

PyTorch深度学习实践——反向传播的更多相关文章

  1. PyTorch深度学习实践——多分类问题

    多分类问题 目录 多分类问题 Softmax 在Minist数据集上实现多分类问题 作业 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩 ...

  2. PyTorch深度学习实践——处理多维特征的输入

    处理多维特征的输入 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 这一讲介绍输入为多维数据时的分类. 一个数据集 ...

  3. 深度学习梯度反向传播出现Nan值的原因归类

    症状:前向计算一切正常.梯度反向传播的时候就出现异常,梯度从某一层开始出现Nan值(Nan: Not a number缩写,在numpy中,np.nan != np.nan,是唯一个不等于自身的数). ...

  4. 深度学习之反向传播算法(BP)代码实现

    反向传播算法实战 本文仅仅是反向传播算法的实现,不涉及公式推导,如果对反向传播算法公式推导不熟悉,强烈建议查看另一篇文章神经网络之反向传播算法(BP)公式推导(超详细) 我们将实现一个 4 层的全连接 ...

  5. PyTorch深度学习实践-Overview

    Overview 1.PyTorch简介 ​ PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强 ...

  6. 深度学习实践系列(2)- 搭建notMNIST的深度神经网络

    如果你希望系统性的了解神经网络,请参考零基础入门深度学习系列,下面我会粗略的介绍一下本文中实现神经网络需要了解的知识. 什么是深度神经网络? 神经网络包含三层:输入层(X).隐藏层和输出层:f(x) ...

  7. 使用PyTorch构建神经网络以及反向传播计算

    使用PyTorch构建神经网络以及反向传播计算 前一段时间南京出现了疫情,大概原因是因为境外飞机清洁处理不恰当,导致清理人员感染.话说国外一天不消停,国内就得一直严防死守.沈阳出现了一例感染人员,我在 ...

  8. 深度学习实践系列(3)- 使用Keras搭建notMNIST的神经网络

    前期回顾: 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型 深度学习实践系列(2)- 搭建notMNIST的深度神经网络 在第二篇系列中,我们使用了TensorFlow搭建了第一个深度 ...

  9. 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码

    PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...

随机推荐

  1. 手把手教你用Strace诊断问题

    手把手教你用Strace诊断问题 发表于2015-10-16 早些年,如果你知道有个 strace 命令,就很牛了,而现在大家基本都知道 strace 了,如果你遇到性能问题求助别人,十有八九会建议你 ...

  2. 字体替换 re.sub

    dic={'hqo3r': '迎', 'hq6ic': '名', 'hq7yw': '头', 'hq1lk': '新', 'hqpe1': '肇'} content=''' 总体hqo3r则,错的注( ...

  3. linux 常用命令。

    /* Linux常用命令? 1 查看 ls             展示当前目录下的可见文件   ls -a         展示当前目录下所有的文件(包括隐藏的文件)   ls -l(ll)     ...

  4. RTSP实例解析

    以下是某地IPTV的RTSP协商过程: 1.DESCRIBE 请求: //方法和媒体URL DESCRIBE rtsp://118.122.89.27:554/live/ch1008312159479 ...

  5. .net core部署到ubuntu 上传文件超过30MB

    默认的上传文件不能超过30MB,需要修改几个地方 一.web.config中添加配置 <requestLimits maxAllowedContentLength="214748364 ...

  6. Idea中新建package包,却变成了Directory

    问题描述 今天在IdeaJava工程src下新建了一个名字叫包implements的包,右键在里面新建类时,却发现根本就没有Class这个选项,然后发现implements这个包的图标也和其他包的图标 ...

  7. jar包冲突时怎么办

    因为项目中会依赖许多jar包,免不得就会有冲突,那怎么解决呢? 使用 mvn dependency:tree 可以看到各个包的依赖关系 [INFO] | +- commons-cli:commons- ...

  8. PHP版的猴子选大王算法

    猴子选大王 这个算法可能是目前我看到的最简洁都算法吧,而且很好理解.它不同于其他算法,其他算法都是判断这个猴子能不能被选中,而他只是找出不能被选中的猴子,然后将其塞到数组模拟的环状队列中,参与下次选. ...

  9. lua语言:string

    转载请注明来源:https://www.cnblogs.com/hookjc/ 字符串库函数string.len(s)          返回字符串s的长度:string.rep(s, n)      ...

  10. undefined index: php中提示Undefined ...

    我们经常接收表单POST过来的数据时报Undefined index错误,如下:$act=$_POST['action'];用以上代码总是提示Notice: Undefined index: act ...