0201-PyTorch0.4.0迁移指南以及代码兼容

pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html

一、概要

  • TensorsVariables已合并
  • Tensors支持0维(标量)
  • 弃用volatile标签
  • dtypesdevicesNumPy风格的创作功能
  • 编写device-agnostic代码
  • nn.Module中子模块名称,参数和缓冲区中的新边界约束

二、合并Tensor和Variable和类

torch.autograd.Variabletorch.Tensor现在类相同。确切地说,torch.Tensor能够像Variable一样自动求导; Variable继续像以前一样工作但返回一个torch.Tensor类型的对象。意味着你在代码中不再需要Variable包装器。

2.1 Tensor中的type()改变了

type()不再反映张量的数据类型。使用isinstance()x.type()替代:

>>> x = torch.DoubleTensor([1, 1, 1])
>>> print(type(x)) # was torch.DoubleTensor
"<class 'torch.Tensor'>"
>>> print(x.type()) # OK: 'torch.DoubleTensor'
'torch.DoubleTensor'
>>> print(isinstance(x, torch.DoubleTensor)) # OK: True
True

2.2 什么时候autograd开始自动求导?

equires_gradautograd的核心标志,现在是Tensors上的一个属性。让我们看看在代码中如何体现的。

autograd使用以前用于Variables的相同规则。当张量定义了requires_grad=True就可以自动求导了。例如,

>>> x = torch.ones(1)  # create a tensor with requires_grad=False (default)
>>> x.requires_grad
False
>>> y = torch.ones(1) # another tensor with requires_grad=False
>>> z = x + y
>>> # both inputs have requires_grad=False. so does the output
>>> z.requires_grad
False
>>> # then autograd won't track this computation. let's verify!
>>> z.backward()
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
>>>
>>> # now create a tensor with requires_grad=True
>>> w = torch.ones(1, requires_grad=True)
>>> w.requires_grad
True
>>> # add to the previous result that has require_grad=False
>>> total = w + z
>>> # the total sum now requires grad!
>>> total.requires_grad
True
>>> # autograd can compute the gradients as well
>>> total.backward()
>>> w.grad
tensor([ 1.])
>>> # and no computation is wasted to compute gradients for x, y and z, which don't require grad
>>> z.grad == x.grad == y.grad == None
True

2.3 操作requires_grad标志

除了直接设置属性之外,您可以使用my_tensor.requiresgrad(requires_grad=True)直接修改此标志,或者如上例所示,在创建时将其作为参数传递(默认为False),例如:

>>> existing_tensor.requires_grad_()
>>> existing_tensor.requires_grad
True
>>> my_tensor = torch.zeros(3, 4, requires_grad=True)
>>> my_tensor.requires_grad
True

2.4 关于.data

.data是从Variable中获取Tensor的方法。合并后,调用y = x.data仍然具有类似的语义。因此y将是与x共享同的Tensor相数据,x与计算历史无关,并具有requires_grad=False

但是,.data在某些情况下可能不安全。x.data上的任何变化都不会被autograd跟踪,并且x在向后传递中计算梯度将不正确。一种更安全的替代方法是使用x.detach(),它也返回一个Tensorrequires_grad=False共享数据的数据,但是如果x需要反向传播那就会使用autograd直接改变报告。

下面是一个.datax.detach()(以及为什么我们建议detach一般使用)之间的区别的例子。

如果你使用Tensor.detach(),保证梯度计算是正确的。

>>> a = torch.tensor([1,2,3.], requires_grad = True)
>>> out = a.sigmoid()
>>> c = out.detach()
>>> c.zero_()
tensor([ 0., 0., 0.]) >>> out # modified by c.zero_() !!
tensor([ 0., 0., 0.]) >>> out.sum().backward() # Requires the original value of out, but that was overwritten by c.zero_()
RuntimeError: one of the variables needed for gradient computation has been modified by an

然而,使用Tensor.data可能是不安全的,并且当梯度计算需要张量但直接修改时可能容易导致不正确的梯度。

>>> a = torch.tensor([1,2,3.], requires_grad = True)
>>> out = a.sigmoid()
>>> c = out.data
>>> c.zero_()
tensor([ 0., 0., 0.]) >>> out # out was modified by c.zero_()
tensor([ 0., 0., 0.]) >>> out.sum().backward()
>>> a.grad # The result is very, very wrong because `out` changed!
tensor([ 0., 0., 0.])

三、现在一些操作返回0维(标量)Tensors

以前,Tensor向量(1维张量)的索引返回一个Python数字,但是Variable的索引向量返回一个(1,)的向量!即tensor.sum()返回一个Python数字,但variable.sum()会返回一个大小为(1,)的向量。

幸运的是,此版本在PyTorch中引入了适当的标量(0维张量)支持!可以使用新torch.tensor函数来创建标量(稍后会对其进行更详细的解释;现在只需将它看作PyTorch中与numpy.array的等价物)。现在你可以做这样的事情:

>>> torch.tensor(3.1416)         # create a scalar directly
tensor(3.1416)
>>> torch.tensor(3.1416).size() # scalar is 0-dimensional
torch.Size([])
>>> torch.tensor([3]).size() # compare to a vector of size 1
torch.Size([1])
>>>
>>> vector = torch.arange(2, 6) # this is a vector
>>> vector
tensor([ 2., 3., 4., 5.])
>>> vector.size()
torch.Size([4])
>>> vector[3] # indexing into a vector gives a scalar
tensor(5.)
>>> vector[3].item() # .item() gives the value as a Python number
5.0
>>> mysum = torch.tensor([2, 3]).sum()
>>> mysum
tensor(5)
>>> mysum.size()
torch.Size([])

3.1 积累损失

考虑到经常使用的total_loss += loss.data[0]0.4.0之前。loss(1,)张量的Variable包装器,但在0.4.0loss现在是一个0尺寸标量。标量索引是没有意义的(目前只提出一个警告,但在0.5.0中将会报错)。loss.item()用于从标量中获取Python数字。

请注意,如果您在累积损失时未将其转换为Python数字,则可能会发现程序中的内存使用量增加。这是因为上面表达式的右侧曾经是一个Python浮点数,而现在它是一个0的张量。因此,总损失累积了张量和它们的历史梯度,可能导致巨大的autograd图形不必要的保存大量时间。

四、弃用volatile标签

volatile标签现在已被弃用,不起作用。以前,任何涉及Variablewith的计算volatile=True都不会被跟踪autograd。这已经被换成了一套更加灵活的上下文管理的,包括torch.no_grad()torch.set_grad_enabled(grad_mode)及其他。

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>>
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
... y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True) # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

五、dtypes,devices和NumPy风格的创作功能

在以前的PyTorch版本中,我们用来指定的数据类型(例如float vs double),设备类型(cpu vs cuda)和layoutdense vs sparse)作为"张量类型"。例如,torch.cuda.sparse.DoubleTensorTensordouble数据类型,在CUDA设备只能够,以及配备COO稀疏张量layout

在此版本中,我们引入torch.dtypetorch.device以及torch.layout类,允许通过NumPy的风格创建这些属性的功能进行更好的管理。

具体内容参考:pytorch使用torch.dtype、torch.device和torch.layout管理数据类型属性

5.1 创建 Tensors

创造一个方法Tensor,现在也可使用dtypedevicelayout,和requires_grad选项来指定返回所需的Tensor属性。例如:

>>> device = torch.device("cuda:1")
>>> x = torch.randn(3, 3, dtype=torch.float64, device=device)
tensor([[-0.6344, 0.8562, -1.2758],
[ 0.8414, 1.7962, 1.0589],
[-0.1369, -1.0462, -0.4373]], dtype=torch.float64, device='cuda:1')
>>> x.requires_grad # default is False
False
>>> x = torch.zeros(3, requires_grad=True)
>>> x.requires_grad
True

具体可以参考:Pytorch0.4.0 中文文档 Torch

六、编写device-agnostic代码

以前版本的PyTorch编写device-agnostic代码非常困难(即,在不修改代码的情况下在CUDA可以使用或者只能使用CPU的设备上运行)。

参考:Pytorch使用To方法编写代码在不同设备(CUDA/CPU)上兼容(device-agnostic)

七、nn.Module中子模块名称,参数和缓冲区中的新边界约束

name这是一个空字符串或包含"."不再被允许进入module.add_module(name, value)module.add_parameter(name, value)或者module.add_buffer(name, value)因为这些名称可能会在state_dict中导致数据丢失。如果您为包含这些名称的模块加载checkpoint,请在加载之前更新模块定义并进行修补state_dict

7.1 代码示例(将它们放在一起)

为了方便对比0.4.0中整体推荐的变化的特征,我们来看一个0.3.10.4.0中常见代码模式的简单例子:

0.3.1(旧):

model = MyRNN()
if use_cuda:
model = model.cuda() # train
total_loss = 0
for input, target in train_loader:
input, target = Variable(input), Variable(target)
hidden = Variable(torch.zeros(*h_shape)) # init hidden
if use_cuda:
input, target, hidden = input.cuda(), target.cuda(), hidden.cuda()
... # get loss and optimize
total_loss += loss.data[0] # evaluate
for input, target in test_loader:
input = Variable(input, volatile=True)
if use_cuda:
...
...

0.4.0(新):

# torch.device object used throughout this script
device = torch.device("cuda" if use_cuda else "cpu") model = MyRNN().to(device) # train
total_loss = 0
for input, target in train_loader:
input, target = input.to(device), target.to(device)
hidden = input.new_zeros(*h_shape) # has the same device &amp; dtype as `input`
... # get loss and optimize
total_loss += loss.item() # get Python number from 1-element Tensor # evaluate
with torch.no_grad(): # operations inside don't track history
for input, target in test_loader:
...

感谢您的阅读!有关更多详细信息,请参阅我们的文档和发行说明,Pytorch 0.4.0中文文档

快乐的PyTorch-ing!

原创文章,转载请注明 :PyTorch 0.4.0迁移指南以及代码兼容 - pytorch中文网

原文出处: https://ptorch.com/news/190.html

0201-PyTorch0.4.0迁移指南以及代码兼容的更多相关文章

  1. AFNetworking 3.0迁移指南

    AFNetworking是一款在OS X和iOS下都令人喜爱的网络库.为了迎合iOS新版本的升级, AFNetworking在3.0版本中删除了基于 NSURLConnection API的所有支持. ...

  2. [转]AFNetworking 3.0迁移指南

    http://www.jianshu.com/p/047463a7ce9b?utm_campaign=hugo&utm_medium=reader_share&utm_content= ...

  3. Spring Boot 2.0 迁移指南

    ![img](https://mmbiz.qpic.cn/mmbiz_jpg/1flHOHZw6Rs7yEJ6ItV43JZMS7AJWoMSZtxicnG0iaE0AvpUHI8oM7lxz1rRs ...

  4. pytorch 0.4.0迁移指南

    总说 由于pytorch 0.4版本更新实在太大了, 以前版本的代码必须有一定程度的更新. 主要的更新在于 Variable和Tensor的合并., 当然还有Windows的支持, 其他一些就是支持s ...

  5. Alamofire 4.0 迁移指南

    Alamofire 4.0 是 Alamofire 最新的一个大版本更新, 一个基于 Swift 的 iOS, tvOS, macOS, watchOS 的 HTTP 网络库. 作为一个大版本更新, ...

  6. Spring Cloud Alibaba迁移指南(二):零代码替换 Eureka

    自 Spring Cloud 官方宣布 Spring Cloud Netflix 进入维护状态后,我们开始制作<Spring Cloud Alibaba迁移指南>系列文章,向开发者提供更多 ...

  7. Spring Cloud Alibaba迁移指南(一):一行代码从 Hystrix 迁移到 Sentinel

    摘要: 本文对Hystrix.Resilience4j.Sentinel进行对比,并探讨如何使用一行代码这种极简的方式,将Hystrix迁移到Sentinel. Hystrix 自从前段时间 宣布停止 ...

  8. HTML5 语义元素、迁移、样式指南和代码约定

    语义元素是拥有语义的元素. 什么是语义元素? 语义元素清楚地向浏览器和开发者描述其意义. 非语义元素的例子:<div> 和 <span> - 无法提供关于其内容的信息. 语义元 ...

  9. ROS_Kinetic_02 ROS Kinetic 迁移指南及中文wiki指南(Migration guide)

    ROS_Kinetic_02 ROS Kinetic 迁移指南(Migration guide) 对于ROS Kinetic Kame有些功能包已经更新改变,提供关于这些包的迁移注意或教程.主要针对于 ...

  10. Spring Boot 2.0 升级指南

    Spring Boot 2.0 升级指南 前言 Spring Boot已经发布2.0有5个月多,多了很多新特性,一些坑也慢慢被填上,最近有空,就把项目中Spring Boot 版本做了升级,顺便整理下 ...

随机推荐

  1. 人形机器人(具身智能,Embodied Intelligence)—— 抓取动作(上半身动作规划)的各大公司技术路线

    视频地址: https://www.youtube.com/watch?v=UZBSXzNKB1Q

  2. 训练人形机器人时如何收集人类行为数据 —— 通过人来训练机器人(真人实际演示动作)or 仿真环境自动生成 —— 哪种方式更优、更可行呢

    特斯拉的老马,搞的optimus人形机器人就是通过人来训练机器人(真人实际演示动作),但是未来使用仿真环境自动生成数据是否可行呢,NVIDIA的老黄在2024 GTC上是大力推出自家的GROOT平台, ...

  3. Jax的加速层的伪代码/中间层代码的生成和查看

    地址: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-st ...

  4. python3.9的nogil版本编译pytorch2.0.1源码报错——失败

    关于python3.9的nogil版本参看: PEP 703作者给出的一种no-GIL的实现--python3.9的nogil版本 ================================== ...

  5. Apache DolphinScheduler 1.3.4升级至3.1.2版本过程中的踩坑记录

    因为在工作中需要推动Apache DolphinScheduler的升级,经过预研,从1.3.4到3.1.2有的体验了很大的提升,在性能和功能性有了很多的改善,推荐升级. 查看官方的升级文档,可知有提 ...

  6. 【VMware vCenter】一次性说清楚 vCenter Server 的 CLI 部署方式。

    VMware vCenter Server 是 VMware vSphere 解决方案的核心组件,用于管理多个 ESXi 主机并实现更多高级功能特性(如 HA.DRS 以及 FT 等),相信大家已经非 ...

  7. python增删查改实例

    本文介绍一个实例,即删除数据库中原有的表格TEST1,新建一个表格TEST2,并在TEST2中插入3行数据.插入数据以后,查询出ID=3的数据,读出,最后将其删除. 结果: 代码: ''' impor ...

  8. 前端界面显示当前时间的Vue代码

    <!DOCTYPE html> <html> <head> <meta charset="utf-8"> <title> ...

  9. games101 作业4及作业5 详解光线追踪框架

    games101 作业4及作业5 详解光线追踪框架 作业4 代码分析 作业四的代码整体比较简单 主要流程就是 通过鼠标事件 获取四个控制点的坐标 然后绘制贝塞尔曲线的内容就由我们来完成 理论分析 贝塞 ...

  10. Java——计算1~N之间所有奇数之和

    2024/07/15 1.题目 2.解题 1.题目 2.解题 import java.util.Scanner; public class Main { public static void main ...