在PyTorch中,集中于所有神经网络的是autograd包。首先,我们简要地看一下此工具包,然后我们将训练第一个神经网络。

autograd包为张量的所有操作提供了自动微分。它是一个运行式定义的框架,这意味着你的后向传播是由你的代码运行方式来定义的,并且每一个迭代都可以是不同的。

下面,让我们使用一些更简单的术语和例子来解释这个问题。

0x01 变量(Variable)

autograd.Variableautograd包的核心类,它封装了一个张量,并支持几乎所有在该张量上定义的操作。一旦完成了你的计算,你可以调用.backward(),它会自动计算所有梯度。

你可以通过.data属性访问原始的张量,而梯度w.r.t.这个变量被累积到.grad

还有一个类对于autograd的实现非常重要——一个函数。

变量和函数是相互联系的,并建立一个非循环图,它编码了计算的一个完整历史。每个变量都有一个.grad_fn属性,该属性引用了一个创建了该变量的函数(除了由用户创建的变量之外,它们的grad_fnNone)。

如果你想计算导数,你可以在一个变量上调用.backward()。如果变量是一个标量(也就是说它包含一个元素数据),那么你不需要为backward()指定任何参数,但是如果它有更多元素,那么你需要指定一个grad_output参数,该参数是一个匹配形状的张量。

import torch
from torch.autograd import Variable

创建一个变量:

x = Variable(torch.ones(2, 2), requires_grad=True)
print(x)

输出结果:

Variable containing:
1 1
1 1
[torch.FloatTensor of size 2x2]

做一个变量操作:

y = x + 2
print(y)

输出结果:

Variable containing:
3 3
3 3
[torch.FloatTensor of size 2x2]

y是由于操作而创建的,所以它有一个grad_fn

print(y.grad_fn)

输出结果:

<AddBackward0 object at 0x7ff91b4f0908>

y做更多操作:

z = y * y * 3
out = z.mean() print(z, out)

输出结果:

Variable containing:
27 27
27 27
[torch.FloatTensor of size 2x2]
Variable containing:
27
[torch.FloatTensor of size 1]

0x02 梯度(Gradients)

现在我们介绍后向传播,out.backward()等效于做out.backward(torch.Tensor([1.0]))

out.backward()

打印梯度d(out)/dx:

print(x.grad)

输出结果:

Variable containing:
4.5000 4.5000
4.5000 4.5000
[torch.FloatTensor of size 2x2]

你应该得到一个元素为4.5的矩阵。我们将这个变量叫做"o"。此时,我们有:

你可以利用梯度做很多疯狂的事情!

x = torch.randn(3)
x = Variable(x, requires_grad=True) y = x * 2
while y.data.norm() < 1000:
y = y * 2 print(y)

输出结果:

Variable containing:
164.9539
-511.5981
-1356.4794
[torch.FloatTensor of size 3]
gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients) print(x.grad)

输出结果:

Variable containing:
204.8000
2048.0000
0.2048
[torch.FloatTensor of size 3]

扩展阅读: 变量和函数的文档在这里http://pytorch.org/docs/autograd

以上脚本的总运行时间为:0分0.009秒。

 

本文中所使用的Python代码:autograd_tutorial.py

【PyTorch深度学习60分钟快速入门 】Part2:Autograd自动化微分的更多相关文章

  1. 【PyTorch深度学习60分钟快速入门 】Part1:PyTorch是什么?

      0x00 PyTorch是什么? PyTorch是一个基于Python的科学计算工具包,它主要面向两种场景: 用于替代NumPy,可以使用GPU的计算力 一种深度学习研究平台,可以提供最大的灵活性 ...

  2. 【PyTorch深度学习60分钟快速入门 】Part0:系列介绍

      说明:本系列教程翻译自PyTorch官方教程<Deep Learning with PyTorch: A 60 Minute Blitz>,基于PyTorch 0.3.0.post4 ...

  3. 【PyTorch深度学习60分钟快速入门 】Part4:训练一个分类器

      太棒啦!到目前为止,你已经了解了如何定义神经网络.计算损失,以及更新网络权重.不过,现在你可能会思考以下几个方面: 0x01 数据集 通常,当你需要处理图像.文本.音频或视频数据时,你可以使用标准 ...

  4. 【PyTorch深度学习60分钟快速入门 】Part5:数据并行化

      在本节中,我们将学习如何利用DataParallel使用多个GPU. 在PyTorch中使用多个GPU非常容易,你可以使用下面代码将模型放在GPU上: model.gpu() 然后,你可以将所有张 ...

  5. 【PyTorch深度学习60分钟快速入门 】Part3:神经网络

      神经网络可以通过使用torch.nn包来构建. 既然你已经了解了autograd,而nn依赖于autograd来定义模型并对其求微分.一个nn.Module包含多个网络层,以及一个返回输出的方法f ...

  6. pytorch深度学习60分钟闪电战

    https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 官方推荐的一篇教程 Tensors #Construct a ...

  7. Vue.js 60 分钟快速入门

    Vue.js 60 分钟快速入门 转载 作者:keepfool 链接:http://www.cnblogs.com/keepfool/p/5619070.html Vue.js介绍 Vue.js是当下 ...

  8. 不会几个框架,都不好意思说搞过前端: Vue.js - 60分钟快速入门

    Vue.js——60分钟快速入门   Vue.js是当下很火的一个JavaScript MVVM库,它是以数据驱动和组件化的思想构建的.相比于Angular.js,Vue.js提供了更加简洁.更易于理 ...

  9. Vue.js——60分钟快速入门(转)

    vue:Vue.js——60分钟快速入门 <!doctype html> <html lang="en"> <head> <meta ch ...

随机推荐

  1. python 包和模块

    一. 模块 使用内置函数vars()可以查看当前环境下有哪些对象(变量.函数.类) from 模块 import *: 不会导入以下划线开头的对象 只会导入__all__中定义了的对象(__all__ ...

  2. JQuery复习心得

    this === event.currentTarget    event.stopPropagation  阻止冒泡  http:www.css88.com JQ和原生JS入口函数的区别: 书写个数 ...

  3. mysql启动错误,提示crash 错误

    :: mysqld_safe Starting mysqld daemon with databases from /data/mysql_data -- :: [Note] Plugin 'FEDE ...

  4. boost学习 泛型编程之traits 学习

    traits使用的场景一般有三种  分发到不同处理流程 解决C++代码中某些无法编译的问题 比如一个图书馆的代码,接受书籍并收入到不同类别中 template<class T> // T表 ...

  5. Java:内部接口

    1.什么是内部接口 内部接口也称为嵌套接口,即在一个接口内部定义另一个接口.举个例子,Entry接口定义在Map接口里面,如下代码: public interface Map { interface ...

  6. (PMP)第12章-----项目采购管理

    B D 12.1 规划采购管理 输入 工具与技术 输出 1.项目章程 2.商业文件 (商业文件, 效益管理计划) 3.项目管理计划 (范围,质量,资源管理计划, 范围基准) 4.项目文件 (里程碑清单 ...

  7. Win7 VS2017编译bgfx图形API

    官方的编译指南在这个页面 https://bkaradzic.github.io/bgfx/build.html#quick-start 目前的版本编译比较简单,下载3个项目,放于同级目录下 http ...

  8. 异常与Final

    Throwable 类是 Java 语言中所有错误或异常的超类(这就是一切皆可抛的东西).它有两个子类:Error和Exception.Error:用于指示合理的应用程序不应该试图捕获的严重问题.这种 ...

  9. 23.HashMap

    HashMap也是我们使用非常多的Collection,它是基于哈希表的 Map 接口的实现,以key-value的形式存在.在HashMap中,key-value总是会当做一个整体来处理,系统会根据 ...

  10. RxSwift学习笔记8:filter/distinctUntilChanged/single/elementAt/ignoreElements/take/takeLast/skip/sample/debounce

    //filter:该操作符就是用来过滤掉某些不符合要求的事件. Observable.of(1,2,3,4,5,8).filter({ $0 % 2 == 0 }).subscribe { (even ...