AMP:Automatic mixed precision,自动混合精度,可以在神经网络推理过程中,针对不同的层,采用不同的数据精度进行计算,从而实现节省显存和加快速度的目的。

在Pytorch 1.5版本及以前,通过NVIDIA出品的插件apex,可以实现amp功能。

从Pytorch 1.6版本以后,Pytorch将amp的功能吸收入官方库,位于torch.cuda.amp模块下。

本文为针对官方文档主要内容的简要翻译和自己的理解。

1. Introduction

torch.cuda.amp提供了对混合精度的支持。为实现自动混合精度训练,需要结合使用如下两个模块:

2. Typical Mixed Precision Training

一个典型的amp应用示例如下:

# 定义模型和优化器
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...) # 在训练最开始定义GradScalar的实例
scaler = GradScaler() for epoch in epochs:
for input, target in data:
optimizer.zero_grad() # 利用with语句,在autocast实例的上下文范围内,进行模型的前向推理和loss计算
with autocast():
output = model(input)
loss = loss_fn(output, target) # 对loss进行缩放,针对缩放后的loss进行反向传播
# (此部分计算在autocast()作用范围以外)
scaler.scale(loss).backward() # 将梯度值缩放回原尺度后,优化器进行一步优化
scaler.step(optimizer) # 更新scalar的缩放信息
scaler.update()

3. Working with Unscaled Gradients

待更新

4. Working with Scaled Gradients

待更新

5. Working with Multiple Models, Losses, and Optimizers

如果模型的Loss计算部分输出多个loss,需要对每一个loss值执行scaler.scale

如果网络具有多个优化器,对任一个优化器执行scaler.unscale_,并对每一个优化器执行scaler.step

scaler.update只在最后执行一次。

应用示例如下:

scaler = torch.cuda.amp.GradScaler()

for epoch in epochs:
for input, target in data:
optimizer0.zero_grad()
optimizer1.zero_grad()
with autocast():
output0 = model0(input)
output1 = model1(input)
loss0 = loss_fn(2 * output0 + 3 * output1, target)
loss1 = loss_fn(3 * output0 - 5 * output1, target) scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward() # 选择其中一个优化器执行显式的unscaling
scaler.unscale_(optimizer0)
# 对每一个优化器执行scaler.step
scaler.step(optimizer0)
scaler.step(optimizer1)
# 完成所有梯度更新后,执行一次scaler.update
scaler.update()

6. Working with Multiple GPUs

针对多卡训练的情况,只影响autocast的使用方法,GradScaler的用法与之前一致。

6.1 DataParallel in a single process

在每一个不同的cuda设备上,torch.nn.DataParallel在不同的进程中执行前向推理,而autocast只在当前进程中生效,因此,如下方式的调用是不生效的:

model = MyModel()
dp_model = nn.DataParallel(model) # 在主进程中设置autocast
with autocast():
# dp_model的内部进程并不会对autocast生效
output = dp_model(input)
# loss的计算在主进程中执行,autocast可以生效,但由于前面执行推理时已经失效,因此整体上是不正确的
loss = loss_fn(output)

有效的调用方式如下所示:

# 方法1:在模型构建中,定义forwar函数时,采用装饰器方式
MyModel(nn.Module):
...
@autocast()
def forward(self, input):
... # 方法2:在模型构建中,定义forwar函数时,采用上下文管理器方式
MyModel(nn.Module):
...
def forward(self, input):
with autocast():
... # DataParallel的使用方式不变
model = MyModel().cuda()
dp_model = nn.DataParallel(model) # 在模型执行推理时,由于前面模型定义时的修改,在各cuda设备上的子进程中autocast生效
# 在执行loss计算是,在主进程中,autocast生效
with autocast():
output = dp_model(input)
loss = loss_fn(output)

6.2 DistributedDataParallel, one GPU per process

torch.nn.parallel.DistributedDataParallel在官方文档中推荐每个GPU执行一个实例的方法,以达到最好的性能表现。

在这种模式下,DistributedDataParallel内部并不会再启动子进程,因此对于autocastGradScaler的使用都没有影响,与典型示例保持一致。

6.3 DistributedDataParallel, multiple GPUs per process

DataParallel 的使用相同,在模型构建时,对forward函数的定义方式进行修改,保证autocast在进程内部生效。

Pytorch原生AMP支持使用方法(1.6版本)的更多相关文章

  1. 原生JS添加节点方法与jQuery添加节点方法的比较及总结

    一.首先构建一个简单布局,来供下边讲解使用 1.HTML部分代码: <div id="div1">div1</div> <div id="d ...

  2. 原生JavaScript支持6种方式获取元素

    一.原生JavaScript支持6种方式获取元素 document.getElementById('id'); document.getElementsByName('name'); document ...

  3. thinkPHP框架中执行原生SQL语句的方法

    这篇文章主要介绍了thinkPHP框架中执行原生SQL语句的方法,结合实例形式分析了thinkPHP中执行原生SQL语句的相关操作技巧,并简单分析了query与execute方法的使用区别,需要的朋友 ...

  4. 现有语言不支持XXX方法

    史上最强大的IDE也会有bug的时候哈,今天遇到这个问题特别郁闷,百度了下,果然也有人遇到过这个问题 解决方法: 1.调用的时候参数和接口声明的参数不一致(检查修改) 2.继承接口中残留一个废弃的方法 ...

  5. 原生JS事件绑定方法以及jQuery绑定事件方法bind、live、on、delegate的区别

    一.原生JS事件绑定方法: 1.通过HTML属性进行事件处理函数的绑定如: <a href="#" onclick="f()"> 2.通过JavaS ...

  6. 原生JS中apply()方法的一个值得注意的用法

    今天在学习vue.js的render时,遇到需要重复构造多个同类型对象的问题,在这里发现原生JS中apply()方法的一个特殊的用法: var ary = Array.apply(null, { &q ...

  7. PHPnow开启PHP扩展里openssl支持的方法

    PHPnow 是 Win32 下绿色的 Apache + PHP + MySQL 环境套件包.简易安装.快速搭建支持虚拟主机的 PHP 环境.更多介绍<PHP服务套件 PHPnow1.5.6&g ...

  8. 扩展原生js的一些方法

    扩展原生js的Array类 Array.prototype.add = function(item){ this.push(item); } Array.prototype.addRange = fu ...

  9. 原生Js 两种方法实现页面关键字高亮显示

    原生Js 两种方法实现页面关键字高亮显示 上网看了看别人写的,不是兼容问题就是代码繁琐,自己琢磨了一下用两种方法都可以实现,各有利弊. 方法一 依靠正则表达式修改 1.获取obj的html2.统一替换 ...

随机推荐

  1. 2020牛客暑期多校训练营 第二场 I Interval 最大流 最小割 平面图对偶图转最短路

    LINK:Interval 赛时连题目都没看. 观察n的范围不大不小 而且建图明显 考虑跑最大流最小割. 图有点稠密dinic不太行. 一个常见的trick就是对偶图转最短路. 建图有点复杂 不过建完 ...

  2. luogu P6088 [JSOI2015]字符串树 可持久化trie 线段树合并 树链剖分 trie树

    LINK:字符串树 先说比较简单的正解.由于我没有从最简单的考虑答案的角度思考 所以... 下次还需要把所有角度都考察到. 求x~y的答案 考虑 求x~根+y~根-2*lca~根的答案. 那么问题变成 ...

  3. 机器学习 | 详解GBDT梯度提升树原理,看完再也不怕面试了

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是机器学习专题的第30篇文章,我们今天来聊一个机器学习时代可以说是最厉害的模型--GBDT. 虽然文无第一武无第二,在机器学习领域并没有 ...

  4. Java 多态 接口继承等学习笔记

    Super关键字 1.子类可以调用父类声明的构造方法 : 语法:在子类的构造方法中使用super关键字  super(参数列表) 2.操作被隐藏的成员变量(子类的成员变量和父类的成员变量重名的说法)和 ...

  5. 数据结构C++使用邻接表实现图

    定义邻接表存储的图类.[实验要求] (1)创建一个邻接表存储的图:(2)返回图中指定边的权值:(3)插入操作:向图中插入一个顶点,插入一条边:(4)删除操作:从图中删除一个顶点,删除一条边:(5)图的 ...

  6. JavaScript package.json里添加git-cz

    git-cz官网 0.目的 => 替代git commit, 丰富提交的内容 1.安装包 npm install commitizen cz-conventional-changelog --s ...

  7. mysql基础测试题

    mysql基础测试题:https://www.cnblogs.com/wupeiqi/articles/5729934.html 如何创建表? 就这样类推?如何提取我们想要的元素呢? 综合提取呢?

  8. Vue CLI3 移动端适配 【px2rem 或 postcss-plugin-px2rem】

    Vue CLI3 移动端适配 [px2rem 或 postcss-plugin-px2rem] 今天,我们使用Vue CLI3 做一个移动端适配 . 前言 首先确定你的项目是Vue CLI3版本以上的 ...

  9. javaSE总结(转+总结)

    一:java概述: 1,JDK:Java Development Kit,java的开发和运行环境,java的开发工具和jre. 2,JRE:Java Runtime Environment,java ...

  10. 一键打开 jupyter

    一般打开jupyter notebook 是以下步骤: 打开cmd-----输入:jupyter notebook-----按Enter键 为了省事,写了一个.py文件实现上述步骤,代码如下: imp ...