『PyTorch』第三弹_自动求导

torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现

Varibale包含三个属性:

  • data:存储了Tensor,是本体的数据
  • grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
  • grad_fn:指向Function对象,用于反向传播的梯度计算之用

data

import torch as t
from torch.autograd import Variable x = Variable(t.ones(2, 2), requires_grad = True)
x # 实际查询的是.data,是个Tensor

实际上查询x和查询x.data返回结果一致,

Variable containing:

1 1

1 1

[torch.FloatTensor of size 2x2]

梯度求解

构建一个简单的方程:y = x[0,0] + x[0,1] + x[1,0] + x[1,1],Variable的运算结果也是Variable,但是,中间结果反向传播中不会被求导()

这和TensorFlow不太一致,TensorFlow中中间运算果数据结构均是Tensor,

y = x.sum()

y
"""
  Variable containing:
   4
  [torch.FloatTensor of size 1]
"""

可以查看目标函数的.grad_fn方法,它用来求梯度,

y.grad_fn
"""
    <SumBackward0 at 0x18bcbfcdd30>
""" y.backward() # 反向传播
x.grad # Variable的梯度保存在Variable.grad中
"""
  Variable containing:
   1 1
   1 1
  [torch.FloatTensor of size 2x2]
"""

grad属性保存在Variable中,新的梯度下来会进行累加,可以看到再次求导后结果变成了2,

y.backward()
x.grad # 可以看到变量梯度是累加的
"""
Variable containing:
2 2
2 2
[torch.FloatTensor of size 2x2]
"""

所以要归零,

x.grad.data.zero_()  # 归零梯度,注意,在torch中所有的inplace操作都是要带下划线的,虽然就没有.data.zero()方法

"""
0 0
0 0
[torch.FloatTensor of size 2x2]
"""

对比Variable和Tensor的接口,相差无两,

Variable和Tensor的接口近乎一致,可以无缝切换

x = Variable(t.ones(4, 5))

y = t.cos(x)                         # 传入Variable
x_tensor_cos = t.cos(x.data) # 传入Tensor print(y)
print(x_tensor_cos) """
Variable containing:
0.5403 0.5403 0.5403 0.5403 0.5403
0.5403 0.5403 0.5403 0.5403 0.5403
0.5403 0.5403 0.5403 0.5403 0.5403
0.5403 0.5403 0.5403 0.5403 0.5403
[torch.FloatTensor of size 4x5] 0.5403 0.5403 0.5403 0.5403 0.5403
0.5403 0.5403 0.5403 0.5403 0.5403
0.5403 0.5403 0.5403 0.5403 0.5403
0.5403 0.5403 0.5403 0.5403 0.5403
[torch.FloatTensor of size 4x5]
"""

『PyTorch』第三弹重置_Variable对象的更多相关文章

  1. 『PyTorch』第三弹_自动求导

    torch.autograd 包提供Tensor所有操作的自动求导方法. 数据结构介绍 autograd.Variable 这是这个包中最核心的类. 它包装了一个Tensor,并且几乎支持所有的定义在 ...

  2. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  3. 关于『HTML』:第三弹

    关于『HTML』:第三弹 建议缩放90%食用 盼望着, 盼望着, 第三弹来了, HTML基础系列完结了!! 一切都像刚睡醒的样子(包括我), 欣欣然张开了眼(我没有) 敬请期待Markdown语法系列 ...

  4. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  5. 『PyTorch』第五弹_深入理解autograd_下:函数扩展&高阶导数

    一.封装新的PyTorch函数 继承Function类 forward:输入Variable->中间计算Tensor->输出Variable backward:均使用Variable 线性 ...

  6. 『PyTorch』第五弹_深入理解Tensor对象_下:从内存看Tensor

    Tensor存储结构如下, 如图所示,实际上很可能多个信息区对应于同一个存储区,也就是上一节我们说到的,初始化或者普通索引时经常会有这种情况. 一.几种共享内存的情况 view a = t.arang ...

  7. 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法

    在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...

  8. 『MXNet』第三弹_Gluon模型参数

    MXNet中含有init包,它包含了多种模型初始化方法. from mxnet import init, nd from mxnet.gluon import nn net = nn.Sequenti ...

  9. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上

    总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...

随机推荐

  1. 防DNS劫持教程,手动修复本地DNS教程

    防DNS劫持教程,手动修复本地DNS教程 该如何避免DNS劫持的问题呢?1. 请不要轻易连接陌生网络.2. 可以通过手动指定DNS(DNS用于将域名正确转换为您想访问的网站的作用),修改后你的网络应用 ...

  2. 精力管理 | 迅速恢复精力的N个技巧,四个关键词以及自我管理的方法和工具列表

    精力管理 | 迅速恢复精力的N个技巧,所谓坚持,是坚定的“持有”,这个“持”字很值得琢磨——不是扛.不是顶,而是“持”这样一个半放松的状态.如果你没做好自己该做的事情,如果你自己没有成长起来,随着年龄 ...

  3. Python Web学习笔记之TCP、UDP、ICMP、IGMP的解释和区别

    TCP与UDP解释 TCP---传输控制协议,提供的是面向连接.可靠的字节流服务.当客户和服务器彼此交换数据前,必须先在双方之间建立一个TCP连接,之后才能传输数据.TCP提供超时重发,丢弃重复数据, ...

  4. pythoy的configparser模块

    生成配置文件的模块 DEFAULT块,在以块为单位取块的值时,都会出现 import configparser config = configparser.ConfigParser() #相当于生成了 ...

  5. c++标准库多线程入门

    从c++ 11开始,语言核心和标准库开始引入了对多线程的原生支持.如下所示: int doSth(char c) { default_random_engine dre(c); uniform_int ...

  6. htm5之视频音频(shit IE10都不支持)

    <p style="color: red; background-color: black;"> 视频<br /> autoplay autoplay 如果 ...

  7. cogs 1962. [HAOI2015]树上染色

    ★★☆   输入文件:haoi2015_t1.in   输出文件:haoi2015_t1.out   简单对比 时间限制:1 s   内存限制:256 MB [题目描述] 有一棵点数为N的树,树边有边 ...

  8. Metasploit应用举例

    本篇文章包含以下几方面内容: 1.Metasploit端口扫描: 2.用其他模块 3.metasploit smb获取系统信息 4.Metsploit服务识别 5.ftp识别: 6.metasploi ...

  9. Https流程,openssl本地自建证书,抓包

    HTTPS:超文本安全传输协议,和HTTP相比,多了一个SSL/TSL的认证过程,端口为443在http(超文本传输协议)基础上提出的一种安全的http协议,因此可以称为安全的超文本传输协议.http ...

  10. chromedriver下载安装

    博主开发平台是win10,Python版本是3.6.最近需要用到chromedriver+selenium,下载好selenium后,pip install chromedriver,直接安装到pyt ...