Study Notes TF065: TensorFlowOnSpark
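A big-data system in the Hadoop ecosystem has three parts: YARN, HDFS, and the MapReduce computing framework. Distributed TensorFlow plays the role of the MapReduce computing framework, and Kubernetes plays the role of the YARN scheduler. TensorFlowOnSpark brings deep learning and big data together. It runs TensorFlow on an existing Spark/Hadoop cluster, which handles storage and scheduling, and can optionally use Remote Direct Memory Access (RDMA) for communication between nodes. TensorFlowOnSpark (TFoS) is an open-source project from Yahoo: https://github.com/yahoo/TensorFlowOnSpark. It supports distributed TensorFlow training and inference on Apache Spark clusters. TensorFlowOnSpark provides a bridge: each Spark executor launches a corresponding TensorFlow process, and these processes interact through remote procedure calls (RPC).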
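TensorFlowOnSpark architecture. The TensorFlow training program runs on the Spark cluster, which manages it in these steps:
1. Reservation: on each executor, reserve a port for the TensorFlow process and start a data/message listener.
2. Startup: launch the TensorFlow main function on each executor.
3. Data ingestion, in one of two modes:
   - Readers: TensorFlow's Readers and QueueRunners read the data files directly from HDFS, and Spark does not touch the data.
   - Feeding: Spark sends RDD data to the TensorFlow nodes, and the data enters the TensorFlow graph through the feed_dict mechanism.
4. Shutdown: shut down the TensorFlow worker nodes and parameter-server nodes on the executors.
The component stack is Spark Driver -> Spark Executor -> parameter server -> TensorFlow Core -> gRPC/RDMA -> HDFS dataset. Source: http://yahoohadoop.tumblr.com/post/157196317141/open-sourcing-tensorflowonspark-distributed-deep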
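The reservation step can be illustrated with a small plain-Python sketch. This is not TensorFlowOnSpark's implementation. reserve_port and reserve_cluster are hypothetical names, everything runs in one process on 127.0.0.1, and a loop stands in for the executors. On a real cluster, each executor would reserve a port on its own machine and report host:port back to the driver. Binding to port 0 lets the operating system pick a free port. Keeping the socket open holds that port until the TensorFlow server is ready to use it.

```python
import socket


def reserve_port(host="127.0.0.1"):
    """Bind to port 0 so the OS picks a free port; the open socket holds it."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind((host, 0))
    port = sock.getsockname()[1]
    return sock, port


def reserve_cluster(num_executors, host="127.0.0.1"):
    """Simulate the reservation step: each (hypothetical) executor reserves
    one port and reports its executor id, host, and port to the 'driver'."""
    sockets, reservations = [], []
    for executor_id in range(num_executors):
        sock, port = reserve_port(host)
        sockets.append(sock)
        reservations.append({"executor_id": executor_id, "host": host, "port": port})
    return sockets, reservations


if __name__ == "__main__":
    socks, res = reserve_cluster(3)
    for r in res:
        print(r)
    # Release the ports, as would happen just before TensorFlow servers bind them.
    for s in socks:
        s.close()
```

Closing the sockets at the end corresponds to the moment just before each TensorFlow server binds its reserved port. The collected host:port pairs are the raw material for a cluster specification.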
TensorFlowOnSpark MNIST example: https://github.com/yahoo/TensorFlowOnSpark/wiki/GetStarted_standalone. This setup runs a Spark cluster in Standalone mode on a single computer. Install Spark and Hadoop, and deploy the Java 1.8.0 JDK. Download Spark 2.1.0 from http://spark.apache.org/downloads.html and Hadoop 2.7.3 from http://hadoop.apache.org/#Download+Hadoop. TensorFlow 0.12.1 is well supported.
Edit the configuration files, set the environment variables, and start Hadoop with $HADOOP_HOME/sbin/start-all.sh. Then check out the TensorFlowOnSpark source code:
```bash
git clone --recurse-submodules https://github.com/yahoo/TensorFlowOnSpark.git
cd TensorFlowOnSpark
git submodule init
git submodule update --force
git submodule foreach --recursive git clean -dfx
```
Package the source code so it can be used when submitting jobs:
```bash
cd TensorFlowOnSpark/src
zip -r ../tfspark.zip *
```
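The archive is useful because Python can import modules directly from a zip file on its path. spark-submit's --py-files option places .zip and .py files on the executors' Python path. The sketch below uses the standard zipfile module to imitate `zip -r ../tfspark.zip *` and then imports a module from the result. The package name demo_pkg and the temporary directory layout are made up for illustration. Unlike `zip ... *`, os.walk also picks up hidden files.

```python
import os
import sys
import tempfile
import zipfile


def build_zip(src_dir, zip_path):
    """Roughly what `zip -r ../tfspark.zip *` does: store every file under
    src_dir with paths relative to src_dir, so packages sit at the zip root."""
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, files in os.walk(src_dir):
            for name in files:
                full = os.path.join(root, name)
                zf.write(full, os.path.relpath(full, src_dir))


if __name__ == "__main__":
    work = tempfile.mkdtemp()
    pkg_dir = os.path.join(work, "src", "demo_pkg")  # hypothetical package
    os.makedirs(pkg_dir)
    with open(os.path.join(pkg_dir, "__init__.py"), "w") as f:
        f.write("GREETING = 'hello from the zip'\n")
    zip_path = os.path.join(work, "demo.zip")
    build_zip(os.path.join(work, "src"), zip_path)
    # A zip archive on sys.path is importable via Python's built-in zipimport.
    sys.path.insert(0, zip_path)
    import demo_pkg
    print(demo_pkg.GREETING)
```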
Set an environment variable that points to the TensorFlowOnSpark root directory:
```bash
cd TensorFlowOnSpark
export TFoS_HOME=$(pwd)
```
Start the Spark master node:
```bash
${SPARK_HOME}/sbin/start-master.sh
```
Configure two worker instances and connect them to the master through the master Spark URL:
```bash
export MASTER=spark://$(hostname):7077
export SPARK_WORKER_INSTANCES=2
export CORES_PER_WORKER=1
export TOTAL_CORES=$((CORES_PER_WORKER * SPARK_WORKER_INSTANCES))
${SPARK_HOME}/sbin/start-slave.sh -c $CORES_PER_WORKER -m 3G ${MASTER}
```
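These core settings determine how many TensorFlow nodes can run at the same time. The arithmetic sketch below assumes Spark standalone's default: one executor per worker instance, using all of that worker's cores. The spark-submit commands below set spark.task.cpus equal to CORES_PER_WORKER, so each executor runs one task at a time. The number of concurrent tasks therefore equals SPARK_WORKER_INSTANCES, which is the value passed as --cluster_size. spark_slots is a hypothetical helper.

```python
def spark_slots(worker_instances, cores_per_worker, task_cpus):
    """Hypothetical helper: how many Spark tasks can run at once, assuming one
    executor per worker instance that uses all of that worker's cores."""
    total_cores = worker_instances * cores_per_worker  # TOTAL_CORES
    tasks_per_executor = cores_per_worker // task_cpus  # concurrent tasks per executor
    concurrent_tasks = total_cores // task_cpus  # concurrent tasks cluster-wide
    return total_cores, tasks_per_executor, concurrent_tasks


total, per_executor, concurrent = spark_slots(2, 1, 1)
print(total, per_executor, concurrent)  # 2 1 2
```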
Submit a job that converts the MNIST zip files into an RDD dataset on HDFS:
```bash
${SPARK_HOME}/bin/spark-submit \
--master ${MASTER} --conf spark.ui.port=4048 --verbose \
${TFoS_HOME}/examples/mnist/mnist_data_setup.py \
--output examples/mnist/csv \
--format csv
```
Inspect the processed dataset:
```bash
hadoop fs -ls hdfs://localhost:9000/user/libinggen/examples/mnist/csv
```
Inspect the saved image and label vectors, for example the labels:
```bash
hadoop fs -ls hdfs://localhost:9000/user/libinggen/examples/mnist/csv/train/labels
```
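The following pure-Python sketch shows the text layout that the csv and csv2 formats produce. It mirrors toCSV/fromCSV in the data-setup script.
- **Images:** each image is one line of 28 x 28 = 784 pixel values.
- **csv labels:** each label is a one-hot vector of 10 values. The TensorFlow MNIST helper builds one-hot labels as float arrays, so they print as 0.0/1.0.
- **csv2 lines:** csv2 writes the raw label, a '|', and then the pixels.

The helper names and the sample image are hypothetical.

```python
def to_csv(vec):
    """Join values with commas, like toCSV in mnist_data_setup.py."""
    return ",".join(str(v) for v in vec)


def from_csv(line):
    """Inverse of to_csv; returns floats, like fromCSV."""
    return [float(x) for x in line.split(",")] if line else []


def one_hot(label, num_classes=10):
    """Float one-hot vector, matching the float labels written by the script."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


# Hypothetical 28x28 image flattened to 784 values (all zero except one pixel).
pixels = [0] * 784
pixels[100] = 255
image_line = to_csv(pixels)  # one line in .../csv/train/images/part-*
label_line = to_csv(one_hot(3))  # one line in .../csv/train/labels/part-*
csv2_line = "3|" + image_line  # csv2 format: raw label, '|', pixels
print(label_line)  # 0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
```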
The training set and the test set are each saved as separate RDD data. The data-preparation script is https://github.com/yahoo/TensorFlowOnSpark/blob/master/examples/mnist/mnist_data_setup.py:
```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy
import tensorflow as tf
from array import array
from tensorflow.contrib.learn.python.learn.datasets import mnist


def toTFExample(image, label):
    """Serializes an image/label as a TFExample byte string"""
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label.astype("int64"))),
                'image': tf.train.Feature(int64_list=tf.train.Int64List(value=image.astype("int64")))
            }
        )
    )
    return example.SerializeToString()


def fromTFExample(bytestr):
    """Deserializes a TFExample from a byte string"""
    example = tf.train.Example()
    example.ParseFromString(bytestr)
    return example


def toCSV(vec):
    """Converts a vector/array into a CSV string"""
    return ','.join([str(i) for i in vec])


def fromCSV(s):
    """Converts a CSV string to a vector/array"""
    return [float(x) for x in s.split(',') if len(s) > 0]


def writeMNIST(sc, input_images, input_labels, output, format, num_partitions):
    """Writes MNIST image/label vectors into parallelized files on HDFS"""
    # load MNIST gzip into memory
    with open(input_images, 'rb') as f:
        images = numpy.array(mnist.extract_images(f))

    with open(input_labels, 'rb') as f:
        if format == "csv2":
            labels = numpy.array(mnist.extract_labels(f, one_hot=False))
        else:
            labels = numpy.array(mnist.extract_labels(f, one_hot=True))

    shape = images.shape
    print("images.shape: {0}".format(shape))  # 60000 x 28 x 28
    print("labels.shape: {0}".format(labels.shape))  # 60000 x 10

    # create RDDs of vectors
    imageRDD = sc.parallelize(images.reshape(shape[0], shape[1] * shape[2]), num_partitions)
    labelRDD = sc.parallelize(labels, num_partitions)

    output_images = output + "/images"
    output_labels = output + "/labels"

    # save RDDs as specific format
    if format == "pickle":
        imageRDD.saveAsPickleFile(output_images)
        labelRDD.saveAsPickleFile(output_labels)
    elif format == "csv":
        imageRDD.map(toCSV).saveAsTextFile(output_images)
        labelRDD.map(toCSV).saveAsTextFile(output_labels)
    elif format == "csv2":
        imageRDD.map(toCSV).zip(labelRDD).map(lambda x: str(x[1]) + "|" + x[0]).saveAsTextFile(output)
    else:  # format == "tfr":
        tfRDD = imageRDD.zip(labelRDD).map(lambda x: (bytearray(toTFExample(x[0], x[1])), None))
        # requires: --jars tensorflow-hadoop-1.0-SNAPSHOT.jar
        tfRDD.saveAsNewAPIHadoopFile(output, "org.tensorflow.hadoop.io.TFRecordFileOutputFormat",
                                     keyClass="org.apache.hadoop.io.BytesWritable",
                                     valueClass="org.apache.hadoop.io.NullWritable")
    # Note: this creates TFRecord files w/o requiring a custom Input/Output format
    # else:  # format == "tfr":
    #     def writeTFRecords(index, iter):
    #         output_path = "{0}/part-{1:05d}".format(output, index)
    #         writer = tf.python_io.TFRecordWriter(output_path)
    #         for example in iter:
    #             writer.write(example)
    #         return [output_path]
    #     tfRDD = imageRDD.zip(labelRDD).map(lambda x: toTFExample(x[0], x[1]))
    #     tfRDD.mapPartitionsWithIndex(writeTFRecords).collect()


def readMNIST(sc, output, format):
    """Reads/verifies previously created output"""
    output_images = output + "/images"
    output_labels = output + "/labels"
    imageRDD = None
    labelRDD = None

    if format == "pickle":
        imageRDD = sc.pickleFile(output_images)
        labelRDD = sc.pickleFile(output_labels)
    elif format == "csv":
        imageRDD = sc.textFile(output_images).map(fromCSV)
        labelRDD = sc.textFile(output_labels).map(fromCSV)
    else:  # format.startswith("tf"):
        # requires: --jars tensorflow-hadoop-1.0-SNAPSHOT.jar
        tfRDD = sc.newAPIHadoopFile(output, "org.tensorflow.hadoop.io.TFRecordFileInputFormat",
                                    keyClass="org.apache.hadoop.io.BytesWritable",
                                    valueClass="org.apache.hadoop.io.NullWritable")
        # bytes() keeps the raw record bytes under both Python 2 and Python 3
        imageRDD = tfRDD.map(lambda x: fromTFExample(bytes(x[0])))

    num_images = imageRDD.count()
    num_labels = labelRDD.count() if labelRDD is not None else num_images
    samples = imageRDD.take(10)
    print("num_images: ", num_images)
    print("num_labels: ", num_labels)
    print("samples: ", samples)


if __name__ == "__main__":
    import argparse

    from pyspark.context import SparkContext
    from pyspark.conf import SparkConf

    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--format", help="output format", choices=["csv", "csv2", "pickle", "tf", "tfr"], default="csv")
    parser.add_argument("-n", "--num-partitions", help="Number of output partitions", type=int, default=10)
    parser.add_argument("-o", "--output", help="HDFS directory to save examples in parallelized format", default="mnist_data")
    parser.add_argument("-r", "--read", help="read previously saved examples", action="store_true")
    parser.add_argument("-v", "--verify", help="verify saved examples after writing", action="store_true")
    args = parser.parse_args()
    print("args:", args)

    sc = SparkContext(conf=SparkConf().setAppName("mnist_parallelize"))

    if not args.read:
        # Note: these files are inside the mnist.zip file
        writeMNIST(sc, "mnist/train-images-idx3-ubyte.gz", "mnist/train-labels-idx1-ubyte.gz", args.output + "/train", args.format, args.num_partitions)
        writeMNIST(sc, "mnist/t10k-images-idx3-ubyte.gz", "mnist/t10k-labels-idx1-ubyte.gz", args.output + "/test", args.format, args.num_partitions)

    if args.read or args.verify:
        readMNIST(sc, args.output + "/train", args.format)
```
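Both writeMNIST and mnist_spark.py pair images with labels using RDD.zip. RDD.zip assumes that the two RDDs have the same number of partitions and the same number of elements in each partition. The script meets this requirement by parallelizing image and label arrays of equal length with the same num_partitions. The pure-Python model below makes the requirement explicit. zip_partitions is a hypothetical name, and lists of lists stand in for partitions.

```python
def zip_partitions(left, right):
    """Model of RDD.zip: pair elements partition by partition, enforcing the
    same-partition-count and same-partition-size requirement."""
    if len(left) != len(right):
        raise ValueError("different number of partitions")
    zipped = []
    for lp, rp in zip(left, right):
        if len(lp) != len(rp):
            raise ValueError("partitions have different sizes")
        zipped.append(list(zip(lp, rp)))
    return zipped


images = [["img0", "img1"], ["img2", "img3"]]  # 2 partitions of images
labels = [["lab0", "lab1"], ["lab2", "lab3"]]  # 2 matching partitions of labels
print(zip_partitions(images, labels))
# [[('img0', 'lab0'), ('img1', 'lab1')], [('img2', 'lab2'), ('img3', 'lab3')]]
```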
Submit the training job. Training writes the model to mnist_model on HDFS:
```bash
${SPARK_HOME}/bin/spark-submit \
--master ${MASTER} \
--py-files ${TFoS_HOME}/examples/mnist/spark/mnist_dist.py \
--conf spark.cores.max=${TOTAL_CORES} \
--conf spark.task.cpus=${CORES_PER_WORKER} \
--conf spark.executorEnv.JAVA_HOME="$JAVA_HOME" \
${TFoS_HOME}/examples/mnist/spark/mnist_spark.py \
--cluster_size ${SPARK_WORKER_INSTANCES} \
--images examples/mnist/csv/train/images \
--labels examples/mnist/csv/train/labels \
--format csv \
--mode train \
--model mnist_model
```
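Paths such as examples/mnist/csv/train/images and mnist_model are relative. HDFS resolves relative paths against the user's home directory /user/<user>, which is why the earlier hadoop fs -ls output is under hdfs://localhost:9000/user/libinggen/. The small sketch below assumes fs.defaultFS is hdfs://localhost:9000. resolve_hdfs_path is a hypothetical helper, not a Hadoop API.

```python
import posixpath


def resolve_hdfs_path(path, user, namenode="hdfs://localhost:9000"):
    """Hypothetical helper: resolve a path the way HDFS treats relative paths,
    i.e. against the home directory /user/<user>."""
    if path.startswith("hdfs://"):
        return path  # already a full URI
    if not path.startswith("/"):
        path = posixpath.join("/user", user, path)
    return namenode + posixpath.normpath(path)


print(resolve_hdfs_path("examples/mnist/csv/train/images", "libinggen"))
# hdfs://localhost:9000/user/libinggen/examples/mnist/csv/train/images
```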
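mnist_dist.py builds the distributed TensorFlow job. It defines map_fun, the TensorFlow main function launched on each executor, and gets its data through Feeding. It obtains the TensorFlow cluster and server instances with: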
```python
cluster, server = TFNode.start_cluster_server(ctx, 1, args.rdma)
```
Here TFNode refers to TFNode.py, which is packaged in tfspark.zip.
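The cluster that start_cluster_server returns describes which addresses serve which TensorFlow job. The sketch below shows the shape of that mapping under one assumption: the first num_ps reserved executors become parameter servers (job "ps") and the rest become workers. The result uses the layout that tf.train.ClusterSpec accepts, a dict from job name to a list of addresses. build_cluster_spec, role_for, and the addresses are hypothetical, and TFNode's own bookkeeping may differ.

```python
def build_cluster_spec(addresses, num_ps):
    """Assumption: first num_ps addresses are parameter servers, rest are workers.
    Returns a job-name -> address-list dict (the layout tf.train.ClusterSpec takes)."""
    return {"ps": addresses[:num_ps], "worker": addresses[num_ps:]}


def role_for(executor_id, num_ps):
    """Map an executor id to (job_name, task_index) under the same assumption."""
    if executor_id < num_ps:
        return "ps", executor_id
    return "worker", executor_id - num_ps


# e.g. --cluster_size 2 with num_ps = 1, as in mnist_spark.py
addresses = ["host1:40001", "host2:40002"]
print(build_cluster_spec(addresses, num_ps=1))  # {'ps': ['host1:40001'], 'worker': ['host2:40002']}
print([role_for(i, 1) for i in range(2)])  # [('ps', 0), ('worker', 0)]
```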
mnist_spark.py is the main training program. It carries out the TensorFlowOnSpark deployment steps:
```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pyspark.context import SparkContext
from pyspark.conf import SparkConf

import argparse
import os
import numpy
import sys
import tensorflow as tf
import threading
import time
from datetime import datetime

from tensorflowonspark import TFCluster
import mnist_dist

sc = SparkContext(conf=SparkConf().setAppName("mnist_spark"))
executors = sc._conf.get("spark.executor.instances")
num_executors = int(executors) if executors is not None else 1
num_ps = 1

parser = argparse.ArgumentParser()
parser.add_argument("-b", "--batch_size", help="number of records per batch", type=int, default=100)
parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=1)
parser.add_argument("-f", "--format", help="example format: (csv|pickle|tfr)", choices=["csv", "pickle", "tfr"], default="csv")
parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("-l", "--labels", help="HDFS path to MNIST labels in parallelized format")
parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/inference", default="mnist_model")
parser.add_argument("-n", "--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions")
parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=1000)
parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
parser.add_argument("-X", "--mode", help="train|inference", default="train")
parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
args = parser.parse_args()
print("args:", args)

print("{0} ===== Start".format(datetime.now().isoformat()))

if args.format == "tfr":
    images = sc.newAPIHadoopFile(args.images, "org.tensorflow.hadoop.io.TFRecordFileInputFormat",
                                 keyClass="org.apache.hadoop.io.BytesWritable",
                                 valueClass="org.apache.hadoop.io.NullWritable")

    def toNumpy(bytestr):
        example = tf.train.Example()
        example.ParseFromString(bytestr)
        features = example.features.feature
        image = numpy.array(features['image'].int64_list.value)
        label = numpy.array(features['label'].int64_list.value)
        return (image, label)

    # bytes() keeps the raw record bytes under both Python 2 and Python 3
    dataRDD = images.map(lambda x: toNumpy(bytes(x[0])))
else:
    if args.format == "csv":
        images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])
        labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
    else:  # args.format == "pickle":
        images = sc.pickleFile(args.images)
        labels = sc.pickleFile(args.labels)
    print("zipping images and labels")
    dataRDD = images.zip(labels)

# 1-2. Reserve a port for each TensorFlow process on the executors and launch the
#      TensorFlow main function mnist_dist.map_fun; TFCluster.run does both steps
cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size, num_ps, args.tensorboard, TFCluster.InputMode.SPARK)
if args.mode == "train":
    # 3. Training: feed dataRDD to the TensorFlow nodes for args.epochs epochs
    cluster.train(dataRDD, args.epochs)
else:
    # 3. Inference: feed dataRDD and save the predictions
    labelRDD = cluster.inference(dataRDD)
    labelRDD.saveAsTextFile(args.output)
# 4. Shut down the TensorFlow worker and parameter-server nodes on the executors
cluster.shutdown()

print("{0} ===== Stop".format(datetime.now().isoformat()))
```
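With InputMode.SPARK, cluster.train(dataRDD, args.epochs) feeds the RDD partitions to the TensorFlow workers, repeating the data args.epochs times. Each worker consumes its input in batches of --batch_size. The rough pure-Python sketch below models that consumption and the batch count. next_batch and batches_per_epoch are hypothetical helpers, not TensorFlowOnSpark's API. With the 60,000 training images and the defaults above (batch_size 100, epochs 1), one epoch is 600 batches, compared with the default --steps of 1000.

```python
import math
from itertools import islice


def next_batch(records, batch_size):
    """Pull up to batch_size records from an iterator (an empty list means done)."""
    return list(islice(records, batch_size))


def batches_per_epoch(num_records, batch_size):
    """Number of batches needed to cover num_records once; the last may be partial."""
    return math.ceil(num_records / batch_size)


print(batches_per_epoch(60000, 100))  # 600

it = iter(range(250))  # a stand-in for one partition's records
sizes = []
while True:
    batch = next_batch(it, 100)
    if not batch:
        break
    sizes.append(len(batch))
print(sizes)  # [100, 100, 50]
```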
Inference command:
```bash
${SPARK_HOME}/bin/spark-submit \
--master ${MASTER} \
--py-files ${TFoS_HOME}/examples/mnist/spark/mnist_dist.py \
--conf spark.cores.max=${TOTAL_CORES} \
--conf spark.task.cpus=${CORES_PER_WORKER} \
--conf spark.executorEnv.JAVA_HOME="$JAVA_HOME" \
${TFoS_HOME}/examples/mnist/spark/mnist_spark.py \
--cluster_size ${SPARK_WORKER_INSTANCES} \
--images examples/mnist/csv/test/images \
--labels examples/mnist/csv/test/labels \
--mode inference \
--format csv \
--model mnist_model \
--output predictions
```
TensorFlowOnSpark can also run on Amazon EC2, and on Hadoop clusters in YARN mode.
References:
《TensorFlow技术解析与实战》 (TensorFlow: Technical Analysis and Practice)
Recommendations for machine-learning jobs in Shanghai are welcome. My WeChat: qingxingfengzi