『TensorFlow』分布式训练_其一_逻辑梳理
1,PS-worker架构
将模型维护和训练计算解耦合,将模型训练分为两个作业(job):
- 模型相关作业,模型参数存储、分发、汇总、更新,有由PS执行
- 训练相关作业,包含推理计算、梯度计算(正向/反向传播),由worker执行
该架构下,所有的woker共享PS上的参数,并按照相同的数据流图传播不同batch的数据,计算出不同的梯度,交由PS汇总、更新新的模型参数,大体逻辑如下:
- pull:各个woker根据数据流图拓扑结构从PS获取最新的模型参数
- feed:各个worker根据定义的规则填充各自batch的数据
- compute:各个worker使用第一步的模型参数计算各自的batch数据,求出各自batch的梯度
- push:各个worker将各自的梯度推送到PS
- update:PS汇总来自n个worker的n份梯度,求出平均值后更新模型参数
分布式经典架构PS-worker会重复上面步骤,直到损失到达阈值或者轮数到达阈值。
2,数据并行模式分类
根据数据流图构建模式分类:
- 图内复制:单进程、‘单机多卡’的数据并行训练,需要用户自己实现梯度汇总和均值计算。实例,models/tutorials/image/cifar10/cifer10_multi_gpu-train.py(见下节)
- 图间复制:多进程、跨多机的分布式训练,使用同步优化器(SyncReplicasOptimizer)实现分布式梯度计算和模型参数更新。实例,tensorflow/tools/dist_test/python/mnist_replica.py(分布式同步训练实践,见下下节)
根据参数更新机制分类:
- 异步训练:各个worker独立训练,计算出梯度后即刻更新参数,不需要等待其他worker完成计算
- 同步训练:所有worker完成本轮计算后,汇总梯度,更新模型,计算能力强的worker需要阻塞等待其他worker
两种训练机制同时支持上面两周数据流图构建模式。一般来说同步机制收敛快,异步单步计算快,但易受单批数据影响,不稳定。
3,同步优化器
tensorflow进行同步(同步训练模式专用)各个worker梯度并进行优化时,会使用特殊的优化器即同步优化器,tf.train.SyncReplicasOptimizer,其第一个参数为普通优化器,我们可以定义一个普通的优化器传入,后续参数如下:
| 参数名称 | 功能说明 | 默认值 |
| replicas_to_aggragate | 并行副本数 | num_workers |
| total_num_replicas | 实际副本数(worker数目) | num_workers |
并行副本数指期望的每一步中并行的batch数据数目,实际副本数指参与的workers数目,
- 并行=实际:全民参与,一个worker领取一个batch数据
- 并行>实际:能者多劳,先完成自己batch的worker会继续领取未训练数据,PS会等到梯度份数到达并行数后进行模型参数计算
- 并行<实际:替补等位,存在空闲的worker,取代可能出现的异常worker,确保训练过程高可用
运算过程
- 计算梯度过程同普通优化器,调用基类的Optimizer的compute_gradients成员方法
- 更新参数时重写了Optimizer的apply_gradients方法,见tensorflow/python/training/sync_replicas_optimizer.py
讲解同步优化器工作逻辑之前,介绍两个概念,
梯度聚合器
每一个模型参数有一个自己队列,收集来自不同worker的梯度值,梯度聚合器包含M个队列对应M个模型参数,每个队列收集来自N个worker计算出来的N个梯度值。
同步标记队列
存储同步标记,实际上就是N个global_step值,每个worker领取一个,用于控制同步
以全民参与模式为例
worker工作模式如下:
- 从同步标记队列领取一个global_step,表示全局训练步数的同步标记
- 将同步标记值赋予worker的本地训练步数local_step
- 从PS获取最新模型参数
- 计算出M个梯度值
- 将M个梯度值推送到PS上的M个梯度队列中
PS工作模式如下:
- 从梯度聚合器上收集worker推送过来的梯度值,每个队列收集N份(对应N个global_step下训练值)后,计算均值,收集齐M个均值后,得到M对{模型参数,梯度值}的聚合元组
- 更新模型参数
- 向同步标记队列推送N个global_step+1标记
聚合器收集梯度值并校验local_step是否符合global_step,是则接收梯度值,计算能力强的worker提交梯度后由于没有同步标记可以领取所以被阻塞,PS集齐N份后更新参数,发布下次N个同步标记,开始下一步训练。
由于初始PS不会更新参数发布同步标记,所以需要初始化同步标记队列——sync_init_op,直接向队列注入N个0标记。
分布式模型训练需要的主要初始化操作如下(opt指tf.train.SyncReplicasOptimizer):
| 操作名称 | 常用变量名 | 功能说明 |
| opt.local_step_init_op | local_init_op | loacl_step初始值 |
| pot.chief_init_op | local_init_op | gobal_step初始值 |
| opt.ready_for_local_init_op | ready_for_local_init_op | 为未初始化的Variable设置初始值 |
| opt.get_chief_queue_runner | chief_queue_runner | 同步标记队列启动QueueRunner实例 |
| opt.get_init_tockens_op | sync_init_op | 同步标记队列初始化 |
| tf.global_variables_initializer | init_op | 全局Variable设置初始值 |
如果使用模型管理类Supervsor,可以将大部分工作交由其代劳。
以能者多劳模式对比
模型参数个数M,worker个数N,并行副本数R(R>N),此时
梯度聚合器仍然有M个参数收集队列,每一个队列要收集R份才进行汇总,R>N所以会存在某个worker领取多份数据的情况。
同步标记队列存储R个同步标记,以确保每一步中梯度聚合器可以收集到R份数据。
4,异步优化器
异步优化器没有很多附加参量,和单机训练几乎一致,只是每个worker获取参数需要从另一个进程PS中得到而已。
5,模型管理类Supervsor
本质上是对Saver(模型参数存储恢复)、Coordinator(多线程服务生命周期管理)、SessionManager(单机以及分布式会话管理)三个类的封装
Coordinator会监测程序的线程是否运行正常,任何异常的出现都会向Supervisor报告,此时Coordinator讲程序的停止条件设置为True,Supervisor停止训练并清理工作(关闭会话、回收内存等),其他服务检测到True后会各自关闭服务,终止线程。
SessionManager帮助用户创建管理单机或是分布式会话,以便简化数据流图的生命周期和维护逻辑,同事负责将checkpoint文件中加载出的数据恢复到数据流图中。
流程逻辑如下:
- 创建Supervisor实例,构造方法需要传入checkpoint文件和summary文件存储目录(Supervisor的logdir参数)
- 调用tf.train.Supervisor.managed_session,从Supervisor实例获取会话实例
- 使用该会话执行训练,训练中需要检查停止条件,保证训练正确性
获取managed_session时,Supervisor会通过QueueRunner同时启动一下三个服务:
- 检查点服务:将数据流图中的参数定期保存,默认10min保存一次,且会识别global_step(Supervisor的global_step参数)
- 汇总服务:默认2min一次
- 步数计数器服务:向汇总添加global_step/sec,2min一次
使用managed_session创建会话时,会自动恢复上一次的结果并继续训练。
『TensorFlow』分布式训练_其一_逻辑梳理的更多相关文章
- 『TensorFlow』分布式训练_其三_多机分布式
本节中的代码大量使用『TensorFlow』分布式训练_其一_逻辑梳理中介绍的概念,是成熟的多机分布式训练样例 一.基本概念 Cluster.Job.task概念:三者可以简单的看成是层次关系,tas ...
- 『TensorFlow』分布式训练_其二_单机多GPU并行&GPU模式设定
建议比对『MXNet』第七弹_多GPU并行程序设计 一.tensorflow GPU设置 GPU指定占用 gpu_options = tf.GPUOptions(per_process_gpu_mem ...
- 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍
一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...
- 『TensorFlow』SSD源码学习_其五:TFR数据读取&数据预处理
Fork版本项目地址:SSD 一.TFR数据读取 创建slim.dataset.Dataset对象 在train_ssd_network.py获取数据操作如下,首先需要slim.dataset.Dat ...
- 『TensorFlow』SSD源码学习_其八:网络训练
Fork版本项目地址:SSD 作者使用了分布式训练的写法,这使得训练部分代码异常臃肿,我给出了部分注释.我对于多机分布式并不很熟,而且不是重点,所以不过多介绍,简单的给出一点训练中作者的优化手段,包含 ...
- 『TensorFlow』SSD源码学习_其四:数据介绍及TFR文件生成
Fork版本项目地址:SSD 一.数据格式介绍 数据文件夹命名为VOC2012,内部有5个子文件夹,如下, 我们的检测任务中使用JPEGImages文件夹和Annotations文件夹. JPEGIm ...
- 『TensorFlow』网络操作API_中_损失函数及分类器
一.误差值 度量两个张量或者一个张量和零之间的损失误差,这个可用于在一个回归任务或者用于正则的目的(权重衰减). l2_loss tf.nn.l2_loss(t, name=None) 解释:这个函数 ...
- 『TensorFlow』SSD源码学习_其二:基于VGG的SSD网络前向架构
Fork版本项目地址:SSD 参考自集智专栏 一.SSD基础 在分类器基础之上想要识别物体,实质就是 用分类器扫描整张图像,定位特征位置 .这里的关键就是用什么算法扫描,比如可以将图片分成若干网格,用 ...
- 『TensorFlow』SSD源码学习_其三:锚框生成
Fork版本项目地址:SSD 上一节中我们定义了vgg_300的网络结构,实际使用中还需要匹配SSD另一关键组件:被选取特征层的搜索网格.在项目中,vgg_300网络和网格生成都被统一进一个class ...
随机推荐
- Lazarus分体式改成一体式窗口
安装包 anchordocking和Sparta_DockedFormEditor 然后点选保存并重新编译IDE即可
- BZOJ 3622 已经没有什么好怕的了
扯淡 看到题目想到二项式反演 然后忘了给求阶乘的时候取模,调了一晚上 真令人窒息 思路 二项式反演 首先二项式反演还有另一种形式(不会证) 设\(G_i\)为有至少i个的方案数量,\(F_i\)为恰好 ...
- 一文读懂 深度强化学习算法 A3C (Actor-Critic Algorithm)
一文读懂 深度强化学习算法 A3C (Actor-Critic Algorithm) 2017-12-25 16:29:19 对于 A3C 算法感觉自己总是一知半解,现将其梳理一下,记录在此,也 ...
- Redis,Memcache比较
简单比较: Redis不仅仅支持简单的k/v类型的数据,同时还提供list,set,hash等数据结构的存储.memcache只支持简单的K/V类型数据, 不过memcache可以缓存其他东西如图片, ...
- 【NOIP 2018】Day2 T3 保卫王国
Problem Description Z 国有\(n\)座城市,\(n - 1\)条双向道路,每条双向道路连接两座城市,且任意两座城市 都能通过若干条道路相互到达. Z 国的国防部长小 Z 要在城市 ...
- sqlserver 中常见的函数 数学函数
create table testnum( ID int identity(1,1), num float) insert testnum values (1) insert testnum valu ...
- 3.git、TortoiseGit的安装、仓库的配置教程
参考:https://blog.csdn.net/hc_ttxs/article/details/79375788 引言: Git: 就是最原始的分布式版本控制系统,是开源的. GitHub:与Git ...
- Python3 数据库连接
PyMySQL是在Python3.x版本中用于连接MySQL服务器的一个库,Python2中使用mysqldb. 数据库连接 连接数据库前,请先确认一下事项: 已经创建数据库testdb. 在test ...
- Token和SessionStorage(会话存储对象)
sessionStorage数据只在当前标签页共享 存在本地 关闭浏览器后会清除数据(关闭标签页不会清楚) localStorage数据会存在浏览器中 浏览器关了数据也还在 只有清除缓存才会消失 ...
- linux网络常用命令
1,显示网桥 brctl show2,显示ip ip a3,查看openvswitch的配置信息 ovs-vsctl show4,显示网络命名空间 ip netns5,显示DHCP信息 ps -ef ...