3年前曾经写过关于分布式环境下batch normalization是否需要特殊实现的讨论:

batch normalization的multi-GPU版本该怎么实现? 【Tensorflow 分布式PS/Worker模式下异步更新的情况】

 

==============================================

 

当时我给出的观点就是在多卡环境下batch normalization使用每个step内的各显卡batch上的相关值进行同步的话会和单卡情况取得相似的结果,因此我给出的结论就是多卡情况下是没有必要针对batch normalization算子开发什么高深的替代版本,你不论是同步更新还是异步更新的情况下对每个显卡上运行得到的batch normalization算子中的参数进行同样的update就可以了,因为从我之前做的仿真使用中可以看出不论是单机情况还是多卡同/异步更新情况下都是对batch normalization算子中参数的估计,而这几种方法之间的差别其实不大,可以说极为相近,也正是如此在几年前我就得出了没有必要为多卡/分布式环境下设计特殊的batch normalization算子,不管是同步更新还是异步更新同时对batch normalization算子中的参数进行同样操作就和单卡情况下不会有太大的差距。几年前得到这个结论的时候只是考个人推断和仿真实验获得的,并没有在实际的代码上跑过,当时主要的原因就是省时、省力,同时也是对但是网上的各种针对多卡/分布式环境下开发出的特殊batch normalization算子的一种反对意见,最近看到一篇可以佐证我观点的文章这里给出相关链接并摘录出部分内容:

https://zhuanlan.zhihu.com/p/402198819

=========================================
 
 
在上面的那个文章中给出了讨论和实验:

-------------------------------------------------
假设batch_size=2,每个GPU计算的均值和方差都针对这两个样本而言的。而BN的特性是:batch_size越大,均值和方差越接近与整个数据集的均值和方差,效果越好。使用多块GPU时,会计算每个BN层在所有设备上输入的均值和方差。如果GPU1和GPU2都分别得到两个特征层,那么两块GPU一共计算2     4  个特征层的均值和方差,可以认为batch_size=4。注意:如果不用同步BN,而是每个设备计算自己的批次数据的均值方差,效果与单GPU一致,仅仅能提升训练速度;如果使用同步BN,效果会有一定提升,但是会损失一部分并行速度。

BN如何在不同设备之间同步?

下图为单GPU、以及是否使用同步BN训练的三种情况,可以看到使用同步BN(橙线)比不使用同步BN(蓝线)总体效果要好一些,不过训练时间也会更长。使用单GPU(黑线)和不使用同步BN的效果是差不多的。

-------------------------------------------------
 
 
我三年前的文章指出多卡/分布式情况下使用同步或异步的方式更新batch normalization算子中的参数会和单卡情况下的性能相似,而上面的这篇文章也同样验证了这个观点;甚至从上面的这个文章中可以看到多卡情况下同步更新batch normalization算子中的参数往往会得到更好的效果,当然这个性能相差的也不是十分的明显。
 
 
这里我甚至有个新的观点,那就是多卡情况下即使不对batch normalization算子在训练过程中更新(同步、异步更新都包括),而是在训练结束后再进行取均值的更新方式也不会有太大的性能差距,总结的来说就是我个人认为多卡/分布式环境下batch normalization算子的参数的计算使用下面三种方式都和单卡情况下相差不大:
1. 训练过程中同步更新batch normalization算子参数;
2. 训练过程中异步更新batch normalization算子参数;
3. 训练结束后再更新batch normalization算子参数;

不过这三种方式即使相差不大也必然虽然一个谁优谁劣的问题,而这个回答确实是难以给出的,因为这个定论需要对不同的数据集和任务进行计算,大量的获取各种情况下的最终性能指标才可以有个定论,不过这里也给出我的个人建议,那就是:
对性能要求较为严格的情况下建议使用第一种方式,即训练过程中同步更新batch normalization算子参数;而对性能要求的容忍度较大的情况下可以考虑使用第三种方式,也就是训练结束后再更新batch normalization算子;而对于第二种方式,也就是训练过程中异步更新batch normalization算子其实是要单独分析的,因为pytorch是本身不支持异步更新的,当然你可以自己来进行实现(官方只给了同步更新的code),而TensorFlow由于并不是像pytorch使用MPI而是使用自己公司的protobuffer因此可以完美的支持异步更新(异步更新需要考虑如何处理不同时延下的更新策略,需要单独设计分布式算法来决定何时合并参数何时抛弃参数),所以对于异步更新batch normalization算子参数的方式并不是很建议。
 
 

=========================================

multi-GPU环境下的batch normalization需要特殊实现吗?的更多相关文章

  1. Batch Normalization的算法本质是在网络每一层的输入前增加一层BN层(也即归一化层),对数据进行归一化处理,然后再进入网络下一层,但是BN并不是简单的对数据进行求归一化,而是引入了两个参数λ和β去进行数据重构

    Batch Normalization Batch Normalization是深度学习领域在2015年非常热门的一个算法,许多网络应用该方法进行训练,并且取得了非常好的效果. 众所周知,深度学习是应 ...

  2. 手把手教你在win10下搭建pytorch GPU环境(Anaconda+Pycharm)

    Anaconda指的是一个开源的Python发行版本,其主要优点如下: Anaconda默认安装了常见的科学计算包,用它搭建起Python环境后不用再费时费力安装这些包: Anaconda可以创建互相 ...

  3. 从Bayesian角度浅析Batch Normalization

    前置阅读:http://blog.csdn.net/happynear/article/details/44238541——Batch Norm阅读笔记与实现 前置阅读:http://www.zhih ...

  4. 《RECURRENT BATCH NORMALIZATION》

    原文链接 https://arxiv.org/pdf/1603.09025.pdf Covariate 协变量:在实验的设计中,协变量是一个独立变量(解释变量),不为实验者所操纵,但仍影响实验结果. ...

  5. How Does Batch Normalization Help Optimization?

    1. 摘要 BN 是一个广泛应用的用于快速稳定地训练深度神经网络的技术,但是我们对其有效性的真正原因仍然所知甚少. 输入分布的稳定性和 BN 的成功之间关系很小,BN 对训练过程更根本的影响是:它让优 ...

  6. [C2W3] Improving Deep Neural Networks : Hyperparameter tuning, Batch Normalization and Programming Frameworks

    第三周:Hyperparameter tuning, Batch Normalization and Programming Frameworks 调试处理(Tuning process) 目前为止, ...

  7. 深度解析Droupout与Batch Normalization

    Droupout与Batch Normalization都是深度学习常用且基础的训练技巧了.本文将从理论和实践两个角度分布其特点和细节. Droupout 2012年,Hinton在其论文中提出Dro ...

  8. Win10环境下YOLO5 快速配置与测试

    目录 一.更换官方源 二.安装Pytorch+CUDA(python版本) 三.YOLO V5 配置与验证 四.数据集测试 五.小结 不想看前面,可以直接跳到标题: 一.更换官方源 在 YOLO V5 ...

  9. WIN7环境下CUDA7.5的安装、配置和测试(Visual Studio 2010)

    以下基于"WIN7(64位)+Visual Studio 2010+CUDA7.5". 系统:WIN7,64位 开发平台:Visual Studio 2010 显卡:NVIDIA ...

  10. [CS231n-CNN] Training Neural Networks Part 1 : activation functions, weight initialization, gradient flow, batch normalization | babysitting the learning process, hyperparameter optimization

    课程主页:http://cs231n.stanford.edu/   Introduction to neural networks -Training Neural Network ________ ...

随机推荐

  1. java8 lambda Predicate示例

    import java.util.Arrays; import java.util.List; import java.util.function.Predicate; public class Pr ...

  2. skywalking需要引入的背景(查询调用链),传统的日志查询方法, 引入EFK日志搜索重要性

    1.根据两次请求日志的关键点来截取日志,缩小日志的范围.tail -f orderApi.log | grep "orderKeyWordSubmit"     确定两次异常请求的 ...

  3. 将静态文件打包进nuget里 Net Core

    我之前写了一个.net core 生成验证码的小工具 需要使用者先单独下载字体文件到本地在 install-package 感觉这样很捞也很不方便,但当时忙着做其他需求现在更新下. 其实很简单 vis ...

  4. MYSQL8.0-JSON函数简单示例-JSON_EXTRACT|JSON_VALUE|JSON_TABLE

    JSON类型在日常应用开发中,用得很少,个人通常用于存储常常变化的配置参数. 它适用于什么业务场景,不好说.就好像许多年前读到的一篇文章,说有个国外公司利用ORACLE的CLOB/BLOB管理一些信息 ...

  5. springboot使用mail提示没有该类型的bean

    @Autowired private JavaMailSenderImpl javaMailSender; 自动注入时提示没有该类型的Bean. 原因 没有配置邮件发送相关的配置信息. spring: ...

  6. NLP与深度学习(三)Seq2Seq模型与Attention机制

    1.   Attention与Transformer模型 Attention机制与Transformer模型,以及基于Transformer模型的预训练模型BERT的出现,对NLP领域产生了变革性提升 ...

  7. SpringBoot 对接美团闪购,检验签名,获取推送订单参数,text转json

    接口文档地址 订单推送(已确定订单):https://open-shangou.meituan.com/home/docDetail/177 签名算法:https://opendj.meituan.c ...

  8. c++临时对象导致的生命周期问题

    对象的生命周期是c++中非常重要的概念,它直接决定了你的程序是否正确以及是否存在安全问题. 今天要说的临时变量导致的生命周期问题是非常常见的,很多时候没有一定经验甚至没法识别出来.光是我自己写.rev ...

  9. 如何优雅地使用Mybatis逆向工程生成类

    文/朱季谦 1.环境:SpringBoot 2.在pom.xml文件里引入相关依赖: 1 <plugin> 2 <groupId>org.mybatis.generator&l ...

  10. oeasy教您玩转vim - 86 - # 外部命令external Command

    ​ 外部命令 external 回忆 上次研究的是global :[range]global/{pattern}/{command} range 是执行的范围 pattern 是搜索的模式 comma ...