本文目的

在介绍estimator分布式的时候,官方文档由于版本更新导致与接口不一致。具体是:在estimator分布式当中,使用dataset作为数据输入,在1.12版本中,数据训练只是dataset的数据,就是所有设备加起来,跑一遍数据。

而在2.0版本中,训练数据是dataset的数据乘以分

布式的设备数。也就是说,在每个设备当中都会完整地跑一遍dataset的所有数据。

1.12版本读取

1. 在主线程当中创建图

下面这段代码中,在client中调用了input function,得到迭代器。这是属于estimator distribute train调用的代码

with ops.Graph().as_default() as g:
# We want to create the iterations variable outside the distribution scope
# as that is just stored on the host and mainly used to drive the loop
# and doesn't need to be a Mirrored/Device variable.
if is_tpu_strategy:
steps_per_run_variable = training.get_or_create_steps_per_run_variable()
with self._train_distribution.scope():
random_seed.set_random_seed(self._config.tf_random_seed)
iterator, input_hooks = self._get_iterator_from_input_fn(
input_fn, model_fn_lib.ModeKeys.TRAIN, self._train_distribution)
  • _get_iterator_from_input_fn * 这个函数会生成迭代器供后续训练读取数据。
  def _get_iterator_from_input_fn(self, input_fn, mode, distribution=None):
if distribution is not None:
result = distribution.distribute_dataset(
lambda: self._call_input_fn(input_fn, mode))
else:
result = self._call_input_fn(input_fn, mode) iterator = result.make_initializable_iterator()
input_hooks = [estimator_util._DatasetInitializerHook(iterator)] # pylint: disable=protected-access
return iterator, input_hooks

这里会调用distribute_dataset生成dataset。

再点进去看以后可看到会创建这样一个PerDeviceDataset

class PerDeviceDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" def __init__(self, dataset, devices, prefetch_on_device=None):
self._devices = devices # Default to using prefetching in graph mode, unless specified.
# TODO(priyag): Enable prefetching in eager mode.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = not context.executing_eagerly()
assert not (self._prefetch_on_device and context.executing_eagerly()), (
"Prefetching is only supported in graph mode currently") if self._prefetch_on_device:
self._dataset = dataset.apply(
prefetching_ops_v2.prefetch_to_devices(self._devices))
else:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
self._dataset = dataset.batch(len(devices), drop_remainder=True)

最后一行代码可以看到,在原dataset上又封装了一层batch。将数据根据设备数切分。

后面创建迭代器也是封装为PerDeviceDataIterator,形成一个字典映射,不同设备不同数据,根据batch 的index切分。

分布式训练

在1.12版本中的训练比较简单。对于MirroredStrategy来说,会给每个一个device创建一个线程,

有一个缺点就是,每一次run都会创建线程,在todo里看到,后续会优化掉应该。

下面是在client中从迭代器获取数据,传递给每个device去运算的代码,

self._train_distribution.call_for_each_tower

features, labels = estimator_util.parse_iterator_result(
iterator.get_next())
grouped_estimator_spec = self._train_distribution.call_for_each_tower(
self._call_model_fn,
features,
labels, # although this will be None it seems
model_fn_lib.ModeKeys.TRAIN,
self.config)
loss = self._train_distribution.unwrap(
self._train_distribution.reduce(
distribute_lib.get_loss_reduction(),
grouped_estimator_spec.loss,
destinations='/device:CPU:0'))[0]
distributed_train_op = grouped_estimator_spec.train_op

call_for_each_tower是每个设备训练的接口

def _call_for_each_tower(distribution, fn, *args, **kwargs):
"""Run `fn` in separate threads, once per tower/worker device.
run_concurrently = kwargs.pop("run_concurrently", True)
if not context.executing_eagerly():
# Lots of TF library code isn't thread-safe in graph mode, and
# there is little to be gained by turning on multithreading when
# constructing a graph.
run_concurrently = False
# Needed for per-thread device, etc. contexts in graph mode.
ops.get_default_graph().switch_to_thread_local()
elif run_concurrently is None:
run_concurrently = True coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) shared_variable_store = {} # TODO(isaprykin): Create these threads once instead of during every run()
# call.
threads = []
for index, d in enumerate(distribution.worker_devices):
variable_creator_fn = shared_variable_creator.make_fn(
shared_variable_store, index)
t = MirroredStrategy._MirroredTowerThread( # pylint: disable=protected-access
distribution, coord, d, variable_creator_fn, fn,
*values.select_device(d, args), **values.select_device(d, kwargs))
threads.append(t) for t in threads:
t.start()

其中,select_device就是取对应设备key对应的值。完成整个分布式训练。

TensorFlow Distribution(分布式中的数据读取和训练)的更多相关文章

  1. DataTable to Excel(使用NPOI、EPPlus将数据表中的数据读取到excel格式内存中)

    /// <summary> /// DataTable to Excel(将数据表中的数据读取到excel格式内存中) /// </summary> /// <param ...

  2. TensorFlow走过的坑之---数据读取和tf中batch的使用方法

    首先介绍数据读取问题,现在TensorFlow官方推荐的数据读取方法是使用tf.data.Dataset,具体的细节不在这里赘述,看官方文档更清楚,这里主要记录一下官方文档没有提到的坑,以示" ...

  3. oracle中的数据读取与查找

    数据读取 首先数据块读入到Buffer Cache中,并将其放在LRU(Last Recently Used)链表的MRU(Most Recently Used)端,当需要再次访问该块时可以直接从bu ...

  4. c#中使用数据读取器读取查询结果

    今天有时间了. 在看<c#数据库入门经典> ,总结数据读取器查询结果. 针对单个结果集使用读取器,有3中方法: String connString =..; String sql =@&q ...

  5. 如何在ADO中使用数据读取器(DataReader)读取数据

    DbDataReader类型(实现IDataReader接口)是从数据源获取信息最简单也最快速的方法. 数据读取器是只读向前的效据流.井且一次返回一条记录.因此.只有当你向数据源提交 Select 查 ...

  6. 《TensorFlow实战》中AlexNet卷积神经网络的训练中

    TensorFlow实战中AlexNet卷积神经网络的训练 01 出错 TypeError: as_default() missing 1 required positional argument: ...

  7. Android中Json数据读取与创建

    一:  Json的特性和在数据交互中的地位就不用说了,直接看案例. 首先在android studio中创建assets文件目录,用于存放Json数据文件,android studio 1.3 默认项 ...

  8. Android中Json数据读取与创建的方法

    转自:http://www.jb51.net/article/70875.htm 首先介绍下JSON的定义,JSON是JavaScript Object Notation的缩写. 一种轻量级的数据交换 ...

  9. TensorFlow实践笔记(一):数据读取

    本文整理了TensorFlow中的数据读取方法,在TensorFlow中主要有三种方法读取数据: Feeding:由Python提供数据. Preloaded data:预加载数据. Reading ...

随机推荐

  1. .net core开发从未如此简单,比abp更接地气

    在谈起java一家独大的时候,dotnet人员总是一边嘲笑大量滥竽充数的java从业者,一边羡慕人家的生态.以前是只能羡慕,现在dotnet core开源了,我们都可以为dotnet core的开原生 ...

  2. 一看就懂的K近邻算法(KNN),K-D树,并实现手写数字识别!

    1. 什么是KNN 1.1 KNN的通俗解释 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1 ...

  3. Mybatis与Spring集成时都做了什么?

    Mybatis是java开发者非常熟悉的ORM框架,Spring集成Mybatis更是我们的日常开发姿势. 本篇主要讲Mybatis与Spring集成所做的事情,让读过本文的开发者对Mybatis和S ...

  4. Selenium+java - 调用JavaScript操作

    前言 在做web自动化时,有些情况selenium的api无法完成,需要通过第三方手段比如js来完成实现,比如去改变某些元素对象的属性或者进行一些特殊的操作,本文将来讲解怎样来调用JavaScript ...

  5. 佳木斯集训Day1

    23333第一次写博客 其实在佳木斯集训之前我都已经两三个月没打代码了 在佳木斯的时候前几天真心手生,导致了前几次考试考的很差... D1的考试还是比较良心的,T1是一道大模拟,直接枚举最后几位是00 ...

  6. 为什么我们不用JIRA

    很多人问我,缺陷管理工具,为什么不用jira?而去自己造轮子开发一款bug记录系统 缄默如我,原因众多.如果只是3-5分钟就能讲的请的时候,我会先列出什么糟点呢? 1. 收费,一个人一个月的费用差不多 ...

  7. hadoop hdfs 分布式存储

    1.克隆前的工作 1.配置好网络nat  需要设置静态ip并能通过主机上网 ssh   和  rsync  是必须下载的 2.yum install vim wget  rsync  ssh   并配 ...

  8. 不相交路径[BZOJ1471] 容斥原理 拓扑排序

    最近学容斥的时候又碰到一道类似的题目,所以想分享一个套路,拿这题来举例 [题目描述] 给出一个\(N(N\leq 150)\)个结点的有向无环简单图.给出4个不同的点\(a,b,c,d\),定义不相交 ...

  9. Linux--shell重定向与文件处理命令--02

    一.IO重定向 1.数据输入:键盘---标准输入,但并不是唯一输入方式 ” | passwd –stdin username #同时添加用户和密码 while line;do 循环体...$line ...

  10. HelloDjango 第 08 篇:开发博客文章详情页

    作者:HelloGitHub-追梦人物 文中涉及的示例代码,已同步更新到 HelloGitHub-Team 仓库 首页展示的是所有文章的列表,当用户看到感兴趣的文章时,他点击文章的标题或者继续阅读的按 ...