Pytorch原生AMP支持使用方法(1.6版本)
AMP:Automatic mixed precision,自动混合精度,可以在神经网络推理过程中,针对不同的层,采用不同的数据精度进行计算,从而实现节省显存和加快速度的目的。
在Pytorch 1.5版本及以前,通过NVIDIA出品的插件apex,可以实现amp功能。
从Pytorch 1.6版本以后,Pytorch将amp的功能吸收入官方库,位于torch.cuda.amp模块下。
本文为针对官方文档主要内容的简要翻译和自己的理解。
1. Introduction
torch.cuda.amp提供了对混合精度的支持。为实现自动混合精度训练,需要结合使用如下两个模块:
torch.cuda.amp.autocast:autocast主要用作上下文管理器或者装饰器,来确定使用混合精度的范围。torch.cuda.amp.GradScalar:GradScalar主要用来完成梯度缩放。
2. Typical Mixed Precision Training
一个典型的amp应用示例如下:
# 定义模型和优化器
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始定义GradScalar的实例
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# 利用with语句,在autocast实例的上下文范围内,进行模型的前向推理和loss计算
with autocast():
output = model(input)
loss = loss_fn(output, target)
# 对loss进行缩放,针对缩放后的loss进行反向传播
# (此部分计算在autocast()作用范围以外)
scaler.scale(loss).backward()
# 将梯度值缩放回原尺度后,优化器进行一步优化
scaler.step(optimizer)
# 更新scalar的缩放信息
scaler.update()
3. Working with Unscaled Gradients
待更新
4. Working with Scaled Gradients
待更新
5. Working with Multiple Models, Losses, and Optimizers
如果模型的Loss计算部分输出多个loss,需要对每一个loss值执行scaler.scale。
如果网络具有多个优化器,对任一个优化器执行scaler.unscale_,并对每一个优化器执行scaler.step。
而scaler.update只在最后执行一次。
应用示例如下:
scaler = torch.cuda.amp.GradScaler()
for epoch in epochs:
for input, target in data:
optimizer0.zero_grad()
optimizer1.zero_grad()
with autocast():
output0 = model0(input)
output1 = model1(input)
loss0 = loss_fn(2 * output0 + 3 * output1, target)
loss1 = loss_fn(3 * output0 - 5 * output1, target)
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()
# 选择其中一个优化器执行显式的unscaling
scaler.unscale_(optimizer0)
# 对每一个优化器执行scaler.step
scaler.step(optimizer0)
scaler.step(optimizer1)
# 完成所有梯度更新后,执行一次scaler.update
scaler.update()
6. Working with Multiple GPUs
针对多卡训练的情况,只影响autocast的使用方法,GradScaler的用法与之前一致。
6.1 DataParallel in a single process
在每一个不同的cuda设备上,torch.nn.DataParallel在不同的进程中执行前向推理,而autocast只在当前进程中生效,因此,如下方式的调用是不生效的:
model = MyModel()
dp_model = nn.DataParallel(model)
# 在主进程中设置autocast
with autocast():
# dp_model的内部进程并不会对autocast生效
output = dp_model(input)
# loss的计算在主进程中执行,autocast可以生效,但由于前面执行推理时已经失效,因此整体上是不正确的
loss = loss_fn(output)
有效的调用方式如下所示:
# 方法1:在模型构建中,定义forwar函数时,采用装饰器方式
MyModel(nn.Module):
...
@autocast()
def forward(self, input):
...
# 方法2:在模型构建中,定义forwar函数时,采用上下文管理器方式
MyModel(nn.Module):
...
def forward(self, input):
with autocast():
...
# DataParallel的使用方式不变
model = MyModel().cuda()
dp_model = nn.DataParallel(model)
# 在模型执行推理时,由于前面模型定义时的修改,在各cuda设备上的子进程中autocast生效
# 在执行loss计算是,在主进程中,autocast生效
with autocast():
output = dp_model(input)
loss = loss_fn(output)
6.2 DistributedDataParallel, one GPU per process
torch.nn.parallel.DistributedDataParallel在官方文档中推荐每个GPU执行一个实例的方法,以达到最好的性能表现。
在这种模式下,DistributedDataParallel内部并不会再启动子进程,因此对于autocast和GradScaler的使用都没有影响,与典型示例保持一致。
6.3 DistributedDataParallel, multiple GPUs per process
与DataParallel 的使用相同,在模型构建时,对forward函数的定义方式进行修改,保证autocast在进程内部生效。
Pytorch原生AMP支持使用方法(1.6版本)的更多相关文章
- 原生JS添加节点方法与jQuery添加节点方法的比较及总结
一.首先构建一个简单布局,来供下边讲解使用 1.HTML部分代码: <div id="div1">div1</div> <div id="d ...
- 原生JavaScript支持6种方式获取元素
一.原生JavaScript支持6种方式获取元素 document.getElementById('id'); document.getElementsByName('name'); document ...
- thinkPHP框架中执行原生SQL语句的方法
这篇文章主要介绍了thinkPHP框架中执行原生SQL语句的方法,结合实例形式分析了thinkPHP中执行原生SQL语句的相关操作技巧,并简单分析了query与execute方法的使用区别,需要的朋友 ...
- 现有语言不支持XXX方法
史上最强大的IDE也会有bug的时候哈,今天遇到这个问题特别郁闷,百度了下,果然也有人遇到过这个问题 解决方法: 1.调用的时候参数和接口声明的参数不一致(检查修改) 2.继承接口中残留一个废弃的方法 ...
- 原生JS事件绑定方法以及jQuery绑定事件方法bind、live、on、delegate的区别
一.原生JS事件绑定方法: 1.通过HTML属性进行事件处理函数的绑定如: <a href="#" onclick="f()"> 2.通过JavaS ...
- 原生JS中apply()方法的一个值得注意的用法
今天在学习vue.js的render时,遇到需要重复构造多个同类型对象的问题,在这里发现原生JS中apply()方法的一个特殊的用法: var ary = Array.apply(null, { &q ...
- PHPnow开启PHP扩展里openssl支持的方法
PHPnow 是 Win32 下绿色的 Apache + PHP + MySQL 环境套件包.简易安装.快速搭建支持虚拟主机的 PHP 环境.更多介绍<PHP服务套件 PHPnow1.5.6&g ...
- 扩展原生js的一些方法
扩展原生js的Array类 Array.prototype.add = function(item){ this.push(item); } Array.prototype.addRange = fu ...
- 原生Js 两种方法实现页面关键字高亮显示
原生Js 两种方法实现页面关键字高亮显示 上网看了看别人写的,不是兼容问题就是代码繁琐,自己琢磨了一下用两种方法都可以实现,各有利弊. 方法一 依靠正则表达式修改 1.获取obj的html2.统一替换 ...
随机推荐
- Dynmaics 365 scale group
关于scale Groups的概念,在看Dynamics crm online的时候,一直不理解缩放组scale group的概念,后来查到GP也在用这个概念,想想不就是动态扩展嘛,马上顿悟了,原来如 ...
- Linux的VMWare中Centos7查看文件内容命令 (more-less-head-tail)
一.More分页查看文件 more 命令类似 cat ,不过会以一页一页的形式显示,更方便使用者逐页阅读, 而最基本的指令就是按空白键(space)就往下一页显示, 按 b 键就会往回(back)一页 ...
- 卷积神经网络 part1
[任务一]视频学习心得及问题总结 根据下面三个视频的学习内容,写一个总结,最后列出没有学明白的问题. [任务二]代码练习 在谷歌 Colab 上完成代码练习,关键步骤截图,并附一些自己的想法和解读. ...
- 014_go语言中的变参函数
代码演示 package main import "fmt" func sum(nums ...int) { fmt.Print(nums, " ") toto ...
- 搭建MyBatis开发环境及基本的CURD
目录 一.MyBatis概述 1. MyBatis 解决的主要问题 二.快速开始一个 MyBatis 1. 创建mysql数据库和表 2. 创建maven工程 3. 在pom.xml文件中添加信息 4 ...
- DB2 分组查询语句ROW_NUMBER() OVER() (转载)
说起 DB2 在线分析处理,可以用很好很强大来形容.这项功能特别适用于各种统计查询,这些查询用通常的SQL很难实现,或者根本就无发实现.首先,我们从一个简单的例子开始,来一步一步揭开它神秘的面纱,请看 ...
- 2020-04-24:Object obj = new Object()这句话在内存里占用了多少内存
福哥答案2020-04-25:这道题最好把对象和变量分开说明,否则容易产生误解.以下都是64位环境下.针对对象:压缩状态:MarkWord 8+klass 4+数据0+对齐4=16非压缩状态:Mark ...
- C#算法设计之知识储备
前言 该文章的最新版本已迁移至个人博客[比特飞],单击链接 https://www.byteflying.com/archives/669 访问. 算法的讨论具有一定的规则,其中也包含一些不成文的约定 ...
- NLTK库WordNet的使用方法实例
1.在代码中引入wordnet包 >>>from nltk.corpus import wordnet as wn 2.查询一个词所在的所有词集(synsets) >>& ...
- 手动SQL注入总结
1.基于报错与union的注入 注意:union联合查询注入一般要配合其他注入使用 A.判断是否存在注入,注入是字符型还是数字型,有没过滤了关键字,可否绕过 a.如何判断是否存在注入 一般有一下几种 ...