.简介
torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现
Variable和tensor的区别和联系
Variable是篮子,而tensor是鸡蛋,鸡蛋应该放在篮子里才能方便拿走(定义variable时一个参数就是tensor)
Variable这个篮子里除了装了tensor外还有requires_grad参数,表示是否需要对其求导,默认为False
Variable这个篮子呢,自身有一些属性
比如grad,梯度variable.grad是d(y)/d(variable)保存的是变量y对variable变量的梯度值,如果requires_grad参数为False,所以variable.grad返回值为None,如果为True,返回值就为对variable的梯度值
比如grad_fn,对于用户自己创建的变量(Variable())grad_fn是为none的,也就是不能调用backward函数,但对于由计算生成的变量,如果存在一个生成中间变量的requires_grad为true,那其的grad_fn不为none,反则为none
比如data,这个就很简单,这个属性就是装的鸡蛋(tensor)
Varibale包含三个属性:
data:存储了Tensor,是本体的数据
grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
grad_fn:指向Function对象,用于反向传播的梯度计算之用
 
具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生片花,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化。
那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了。(也就是说,pytorch都是有tensor计算的,而tensor里面的参数都是Variable的形式)。
如果用Variable计算的话,那返回的也是一个同类型的Variable。
【tensor 是一个多维矩阵】
用一个例子说明,Variable的定义:
import torch
from torch.autograd import Variable # torch 中 Variable 模块
tensor = torch.FloatTensor([[,],[,]])
# 把鸡蛋放到篮子里, requires_grad是参不参与误差反向传播, 要不要计算梯度
variable = Variable(tensor, requires_grad=True) print(tensor)
""" [torch.FloatTensor of size 2x2]
""" print(variable)
"""
Variable containing: [torch.FloatTensor of size 2x2]
"""

 注:tensor不能反向传播,variable可以反向传播

二、Variable求梯度

Variable计算时,它会逐渐地生成计算图。这个图就是将所有的计算节点都连接起来,最后进行误差反向传递的时候,一次性将所有Variable里面的梯度都计算出来,而tensor就没有这个能力。

v_out.backward()    # 模拟 v_out 的误差反向传递

print(variable.grad)    # 初始 Variable 的梯度
'''
0.5000 1.0000
1.5000 2.0000
'''

 

三、获取Variable里面的数据

直接print(Variable) 只会输出Variable形式的数据,在很多时候是用不了的。所以需要转换一下,将其变成tensor形式。

print(variable)     #  Variable 形式
"""
Variable containing: [torch.FloatTensor of size 2x2]
""" print(variable.data) # 将variable形式转为tensor 形式
""" [torch.FloatTensor of size 2x2]
"""
print(variable.data.numpy()) # numpy 形式
"""
[[ . .]
[ . .]]
"""

四:关于require_grad对variable的作用

代码一:

import numpy as np
import torch
from torch.autograd import Variable
x = Variable(torch.ones(,),requires_grad = False)
temp = Variable(torch.zeros(,),requires_grad = True)
y = x + temp +
y = y.mean() #求平均数
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(x.grad) # d(y)/d(x)

none

(因为requires_grad=False)

代码二:

import numpy as np
import torch
from torch.autograd import Variable
x = Variable(torch.ones(,),requires_grad = False)
temp = Variable(torch.zeros(,),requires_grad = True)
y = x + temp +
y = y.mean() #求平均数
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(temp.grad) # d(y)/d(temp)

tensor([[0.2500, 0.2500],
        [0.2500, 0.2500]])

代码三:

import numpy as np
import torch
from torch.autograd import Variable
x = Variable(torch.ones(,),requires_grad = False)
temp = Variable(torch.zeros(,),requires_grad = True)
y = x +
y = y.mean() #求平均数
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(x.grad) # d(y)/d(x)
Traceback (most recent call last):
  File "path", line 12, in <module>
    y.backward()
(报错了,因为生成变量y的中间变量只有x,而x的requires_grad是False,所以y的grad_fn是none)
 
代码四:
import numpy as np
import torch
from torch.autograd import Variable
x = Variable(torch.ones(,),requires_grad = False)
temp = Variable(torch.zeros(,),requires_grad = True)
y = x +
y = y.mean() #求平均数
#y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(y.grad_fn) # d(y)/d(x)

none

五:grad属性

在每次backward后,grad值是会累加的,所以利用BP算法,每次迭代是需要将grad清零的。

x.grad.data.zero_()

(in-place操作需要加上_,即zero_)

六:扩展
在PyTorch中计算图的特点总结如下:

autograd根据用户对Variable的操作来构建其计算图。
requires_grad
variable默认是不需要被求导的,即requires_grad属性默认为False,如果某一个节点的requires_grad为True,那么所有依赖它的节点requires_grad都为True。
volatile
variable的volatile属性默认为False,如果某一个variable的volatile属性被设为True,那么所有依赖它的节点volatile属性都为True。volatile属性为True的节点不会求导,volatile的优先级比requires_grad高。
retain_graph
多次反向传播(多层监督)时,梯度是累加的。一般来说,单次反向传播后,计算图会free掉,也就是反向传播的中间缓存会被清空【这就是动态度的特点】。为进行多次反向传播需指定retain_graph=True来保存这些缓存。
.backward()
反向传播,求解Variable的梯度。放在中间缓存中。

莫烦pytorch学习笔记(二)——variable的更多相关文章

  1. 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )

    CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...

  2. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  3. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

  4. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  5. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  6. 莫烦 - Pytorch学习笔记 [ 一 ]

    1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...

  7. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  8. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  9. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

随机推荐

  1. Python分布式爬虫必学框架scrapy打造搜索引擎✍✍✍

    Python分布式爬虫必学框架scrapy打造搜索引擎  整个课程都看完了,这个课程的分享可以往下看,下面有链接,之前做java开发也做了一些年头,也分享下自己看这个视频的感受,单论单个知识点课程本身 ...

  2. Haar分类器方法

    一.Haar分类器的前世今生 二.人脸检测属于计算机视觉的范畴,早期人们的主要研究方向是人脸识别,即根据人脸来识别人物的身份,后来在复杂背景下的人脸检测需求越来越大,人脸检测也逐渐作为一个单独的研究方 ...

  3. Android Button.getWidth()为0的问题

    View在onCreate的时候,没有渲染组件,所以获取到的宽度和高度为0, 需要添加一个观察者,在layout渲染后再去取宽高.代码如下: private Button btn_icon; @Ove ...

  4. PHP算法之盛最多水的容器

    给定 n 个非负整数 a1,a2,...,an,每个数代表坐标中的一个点 (i, ai) .在坐标内画 n 条垂直线,垂直线 i 的两个端点分别为 (i, ai) 和 (i, 0).找出其中的两条线, ...

  5. apache 80 端口 反向代理 tomcat 8080端口

    最近有个jsp的项目要放到服务上,但服务器上已经有了XAMPP(apache + mysql + php), 已占用了80端口.但http默认是访问80端口的. 先把tomcat 环境搭建起来, 发现 ...

  6. Java Collection - 遍历map的几种方式

    作者:zhaoguhong(赵孤鸿) 出处:http://www.cnblogs.com/zhaoguhong/ 本文版权归作者和博客园共有,转载请注明出处 ---------------- 总结 如 ...

  7. 【JZOJ4811】排队

    description analysis 堆\(+\)树上倍增 考虑后序遍历搞出\(dfs\)序,那么要填肯定是从\(dfs\)序开始填 把每个点是序里第几位看成优先级,用小根堆来维护当前空着的优先级 ...

  8. 牛客多校第六场 B Shorten IPv6 Address 模拟

    题意: 给你一个二进制表示的IPv6地址,让你把它转换成8组4位的16进制,用冒号分组的表示法.单组的前导0可以省略,连续多组为0的可以用两个冒号替换,但是只允许替换一次.把这个地址通过这几种省略方式 ...

  9. Codeforces-GYM101873 G Water Testing 皮克定理

    题意: 给定一个多边形,这个多边形的点都在格点上,问你这个多边形里面包含了几个格点. 题解: 对于格点多边形有一个非常有趣的定理: 多边形的面积S,内部的格点数a和边界上的格点数b,满足如下结论: 2 ...

  10. winform的datagridview控件滚动更新数据

    范例源码下载地址:http://files.cnblogs.com/files/luoxiaozhao/PrintDemo.rar