pytorch中 all_gather 操作是不进行梯度回传的。在计算图构建中如果需要经过all_gather操作后,仍需要将梯度回传给各个进程中的allgather前的对应变量,则需要重新继承torch.autograd.Function

https://pytorch.org/docs/stable/autograd.html 中对torch.autograd.Function进行了介绍

https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd 中举例介绍如何重新实现其子类

 下面代码是为了说明all_gather相关特性及如何实现梯度回传.

\(x,y,z\)都是2x2矩阵,其之间关系为\(y=x+2, z=y*y\)

接下来就需要MPI进行进程间数据传递,将z进行汇总到每个进程即all_gather操作。然后将汇总的矩阵进行相乘,然后求均值。

r对y的导数如下:

\(r=0.25({}_{g_0}y_{11}^2*{}_{g_1}y_{11}^2+{}_{g_0}y_{12}^2*{}_{g_1}y_{12}^2+
{}_{g_0}y_{21}^2*{}_{g_1}y_{21}^2+
{}_{g_0}y_{22}^2*{}_{g_1}y_{22}^2)\)

\(\frac{dr}{d{}_{g_0}y}=
\begin{Bmatrix}
0.5{}_{g_0}y_{11}*{}_{g_1}y_{11}^2 & 0.5{}_{g_0}y_{12}*{}_{g_1}y_{12}^2 \\
0.5{}_{g_0}y_{21}*{}_{g_1}y_{21}^2 & 0.5{}_{g_0}y_{22}*{}_{g_1}y_{22}^2)
\end{Bmatrix}\)

gpu0上x值为\(\begin{Bmatrix} 1 & 1 \\1 & 1 \end{Bmatrix}\),gpu1上x值为\(\begin{Bmatrix} 0 & 0 \\0 & 0 \end{Bmatrix}\).通过公式可以计算出,r关于gpu0上的y的导数为\(\begin{Bmatrix}6 & 6 \\6 & 6\end{Bmatrix}\),r关于gpu1上的y的导数为\(\begin{Bmatrix}9 & 9 \\9 & 9\end{Bmatrix}\)

import os
import torch
from torch import nn
import sys
sys.path.append('./')
import torch.distributed as dist
from torch.autograd import Variable
from utils import GatherLayer def test():
#torch.manual_seed(0)
torch.backends.cudnn.deterministic=True
torch.backends.cudnn.benchmark=True
dist.init_process_group(backend="nccl", init_method="env://")
rank = dist.get_rank()
local_rank = int(os.environ.get('LOCAL_RANK', 0))
world_size = dist.get_world_size()
torch.cuda.set_device(local_rank)
print('world_size: {}, rank: {}, local_rank: {}'.format(world_size, rank, local_rank)) if local_rank == 0:
x = Variable(torch.ones(2, 2), requires_grad=True).cuda()
else:
x = Variable(torch.zeros(2, 2), requires_grad=True).cuda()
y = x + 2
y.retain_grad()
z = y * y z_gather = [torch.zeros_like(z) for _ in range(world_size)]
dist.all_gather(z_gather, z)
#z_gather = GatherLayer.apply(z)
r = z_gather[0] * z_gather[1] out = r.mean()
out.backward()
if local_rank == 0:
print('rank:0', y.grad)
else:
print('rank:1', y.grad)

(1)上述述代码中,先使用pytorch中提供的all_gather操作,运行代码会提示错误。错误信息如下:

Traceback (most recent call last):
File "test/test_all_gather.py", line 46, in <module>
Traceback (most recent call last):
File "test/test_all_gather.py", line 46, in <module>
test()
File "test/test_all_gather.py", line 36, in test
out.backward()
File "/usr/local/lib/python3.6/dist-packages/torch/tensor.py", line 185, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py", line 127, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
test()

(2)参考https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/gather.py, 该函数就是继承torch.autograd.Function,实现了all_gather后,梯度也能回传。

上述代码,启用z_gather = GatherLayer.apply(z),就实现了梯度回传功能,打印对变量y的梯度

world_size: 2, rank: 0, local_rank: 0
world_size: 2, rank: 1, local_rank: 1
rank:0 tensor([[6., 6.],
[6., 6.]], device='cuda:0')
rank:1 tensor([[9., 9.],
[9., 9.]], device='cuda:1')

GatherLayer类实现如下:

class GatherLayer(torch.autograd.Function):
"""Gather tensors from all process, supporting backward propagation.""" @staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
dist.all_gather(output, input)
return tuple(output) @staticmethod
def backward(ctx, *grads):
(input,) = ctx.saved_tensors
grad_out = torch.zeros_like(input)
grad_out[:] = grads[dist.get_rank()]
return grad_out

下面网址有关all gather梯度传播的讨论

https://discuss.pytorch.org/t/will-dist-all-gather-break-the-auto-gradient-graph/47350

大规模人脸分类—allgather操作(1)的更多相关文章

  1. 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    https://zhuanlan.zhihu.com/p/25928551 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文题目便是文本分类问题,趁此机会总结下文本分类 ...

  2. [转] 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    转自知乎上看到的一篇很棒的文章:用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文 ...

  3. 用keras的cnn做人脸分类

    keras介绍 Keras是一个简约,高度模块化的神经网络库.采用Python / Theano开发. 使用Keras如果你需要一个深度学习库: 可以很容易和快速实现原型(通过总模块化,极简主义,和可 ...

  4. wordpress搜索结果排除某个分类如何操作

    我们知道wordpress的搜索结果页search.php和分类页category.php是一样的,但是客户的网站是功能比较多的系统,有新闻又有产品,如果搜索结果只想展示产品要如何操作呢?随ytkah ...

  5. SQL分类-DDL_操作数据库_创建&查询

    SQL分类 1.DDL(Data Definition Language)数据定义语言 用来定义数据库对象:数据库,表,列等.关键字:create , drop, alter 等 2.DML(Data ...

  6. python集合的分类与操作

    如图: 集合的炒作分类: 确定大小 测试项的成员关系 遍历集合 获取一个字符串表示 测试相等性 连接两个集合 转换为另一种类型的集合 插入一项 删除一项 替换一项 访问或获取一项

  7. Python函数分类及操作

    为什么使用函数? 答:函数的返回值可以确切知道整个函数执行的结果   函数的定义:1.数学意义的函数:两个变量:自变量x和因变量y,二者的关系                      2.Pytho ...

  8. .NET做人脸识别并分类

    .NET做人脸识别并分类 在游乐场.玻璃天桥.滑雪场等娱乐场所,经常能看到有摄影师在拍照片,令这些经营者发愁的一件事就是照片太多了,客户在成千上万张照片中找到自己可不是件容易的事.在一次游玩等活动或家 ...

  9. face recognition[翻译][深度学习理解人脸]

    本文译自<Deep learning for understanding faces: Machines may be just as good, or better, than humans& ...

  10. face recognition[翻译][深度人脸识别:综述]

    这里翻译下<Deep face recognition: a survey v4>. 1 引言 由于它的非侵入性和自然特征,人脸识别已经成为身份识别中重要的生物认证技术,也已经应用到许多领 ...

随机推荐

  1. 肖sir____Apsara Clouder云计算专项技能认证题目收集

    Apsara Clouder云计算专项技能认证: Apsara Clouder云计算专项技能认证:云服务器ECS入门[认证考试真题分享](答案仅供参考) 单选13道题 1.下列哪一个不是重置ECS密码 ...

  2. MAC系统连接Windows共享文件的方法:

    MAC系统连接Windows共享文件的方法: 1.首先先确认Windows系统下已开启共享.并且两台电脑之间局域网已通. 2.苹果MAC系统,点击桌面.打开顶部菜单 "前往". 3 ...

  3. layui伸缩左侧菜单栏,已伸缩成功但是右侧主体部分不动

    <ul class="layui-nav layui-nav-tree" lay-filter="test" style="width:200p ...

  4. winform 中 label透明化

    label1.BackColor = Color.Transparent;//设置背景颜色为透明 label1.Parent = pictureBox1;//将pictureBox1设为标签的父控件, ...

  5. keil调试教程

    点击跳转 如果开启调试就提示弹框错误2k,说明你没有破解你的keil,网上自行下载注册机. 调试一定要对应自己板子的晶振,否则当你测试你的延时实际时间时,keil里的sec会不一样,甚至离谱.

  6. CH32F103C8T6调试口Disable后的修复办法

    1.问题描述 因为软件编程,将CH32F103的 debug disable了,无法通过仿真器下载程序. 2. 修复 2.1 解决思路 利用厂家给的串口ISP进行下载(HUSB或者COM) 2.2 硬 ...

  7. 基于excel的自动化框架

    设定项目文件大致结构 atp/: 项目名 conf/:存放配置文件 data/:存放sql文件 lib/: 存放项目的所有源代码. logs/:存放日志文件 uploads/:存放下载的文件 star ...

  8. kumquat

    今天准备做个解释型编程语言,名字就叫kumquat(金桔)因为我刚刚喝了口金桔柠檬茶,挺甜的 用python写把

  9. vue2 安装swiper

    npm install swiper@5.4.5 -D

  10. 【java数据结构与算法】插入排序

    [插入排序解析]起始:假设第一个元素为已经排好序那么我们就要从数组的第二个元素开始每一轮确定1一个元素的正确位置所以外层循环的控制变量为 [1,arr.length)的左闭右开区间外层循环控制比较轮次 ...