本文首发于微信公众号「对白的算法屋」

大家好,我是对白。

目前,越来越多的互联网公司内部都有自己的一套框架去训练模型,而模型训练时需要的数据则都保存在分布式文件系统(HDFS)上。Hive作为构建在HDFS上的一个数据仓库,它本质上可以看作是一个翻译器,可以将HiveSQL语句翻译成MapReduce程序或Spark程序,因此模型需要的数据例如csv/libsvm文件都会保存成Hive表并存放在HDFS上,那么问题就来了,如何大规模地把HDFS中的数据直接喂到Tensorflow中呢?Tensorflow提供了一种解决方法:spark-tensorflow-connector,支持将spark DataFrame格式数据直接保存为TFRecords格式数据,接下来就带大家了解一下TFRecord的原理、构成和如何生成TFRecords文件。


TFRecord介绍

TFRecord是Tensorflow训练和推断标准的数据存储格式之一,将数据存储为二进制文件(二进制存储具有占用空间少,拷贝和读取(from disk)更加高效的特点),而且不需要单独的标签文件了,其本质是一行行字节字符串构成的样本数据。

一条TFRecord数据代表一个Example,一个Example就是一个样本数据,每个Example内部由一个字典构成,每个key对应一个Feature,key为字段名,Feature为字段名所对应的数据,Feature有三种数据类型:ByteList、FloatList,Int64List。

TFRecord构成

它实质上是由protobuf定义的一种数据协议,其中tensorflow提供了两种Example表示形式 Example和SequenceExample。它的定义代码位于[tensroflow/core/example/example.proto & feature.proto]。

Example和SequenceExample的定义:

message Example {
Features features = 1;
};
message SequenceExample {
Features context = 1;
FeatureLists feature_lists = 2;
}; message Features {
// Map from feature name to feature.
map<string, Feature> feature = 1;
}; // Containers for non-sequential data.
message Feature {
// Each feature can be exactly one kind.
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
}; // Containers for sequential data.
//
// A FeatureList contains lists of Features. These may hold zero or more
// Feature values.
//
// FeatureLists are organized into categories by name. The FeatureLists message
// contains the mapping from name to FeatureList.
//
message FeatureList {
repeated Feature feature = 1;
}; message FeatureLists {
// Map from feature name to feature list.
map<string, FeatureList> feature_list = 1;
};

我们这里以最常用的Example来进行解释。从图中可以看到,在样本生产环节,每个Example内部由一个dict构成,每个key(string)对应着一个Feature结构,这个Feature结构有三种具体形式,分别是ByteList,FloatList,Int64List三种。这三种形式便可以承载string,bytes,float,double,int,long等多种样本结构,并且基于list的表示,使得我们既可以表达scalar,也可以表达vector类型的数据(注意如果想要将一个matrix保存到到一个Feature内,其值需要时按照Row-Major拍平的1-D array, 行列数据需使用额外字段保存,方便反序列化)。这里需要注意的是,我们在序列化的时候,并未将格式信息序列化进去,实质上,序列化后的,每条tfrecord中的数据,只具有以下数据:

TFRecord中每条数据的格式:

uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data

因此我们可以看出来,TFRecord并不是一个self-describing的格式,也就是说,tfrecord的write和read都需要额外指明schema。从上图我们也能看出来,在实际训练的时候,样本都需要经过一个知晓了Schema的Parser来进行解析,然后才能传递给Tensorflow进行实际的训练。

注:这里只展示了CTR场景常使用的Example,当然也有图像等场景需要使用SequenceExample进行一些样本的结构化表达,这里不做展开。根据官方文档来看,SequenceExample主要是使用在时序特征和视频特征。其中context字段描述的是和当期时间和特征不相关的共性数据,而feature_list则持有和时间或者视频帧相关的数据。感兴趣可以参考youtube-8M这个数据集中关于样本数据的表示。

TFRecord的生成(小规模)

TFRecord的生成=Example序列化+写入TFRecord文件

构建Example时需要指定格式信息(字典)key是特征,value是BytesList/FloatList/Int64List值,但Example序列化时并未将格式信息序列化进去,因此读取TFRecord文件需要额外指明schema。

每个Example会序列化成字节字符串并写入TFRecord文件中,代码如下:

import tensorflow as tf

# 回忆上一小节介绍的,每个Example内部实际有若干种Feature表达,下面
# 的四个工具方法方便我们进行Feature的构造
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _int64list_feature(value_list):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list)) # Example序列化成字节字符串
def serialize_example(user_id, city_id, app_type, viewd_pois, avg_paid, comment):
# 注意我们需要按照格式来进行数据的组装,这里的dict便按照指定Schema构造了一条Example
feature = {
'user_id': _int64_feature(user_id),
'city_id': _int64_feature(city_id),
'app_type': _int64_feature(app_type),
'viewd_pois': _int64list_feature(viewd_pois),
'avg_paid': _float_feature(avg_paid),
'comment': _bytes_feature(comment),
}
# 调用相关api将Example序列化为字节字符串
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString() # 样本的生产,这里展示了将2条样本数据写入到了TFRecord文件中
def write_demo(filepath):
with tf.python_io.TFRecordWriter(filepath) as writer:
writer.write(serialize_example(1, 10, 1, [658, 325], 36.3, "yummy food."))
writer.write(serialize_example(2, 20, 2, [897, 568, 126], 89.6, "nice place to have dinner."))
print "write demo data done." filepath = "testdata.tfrecord"
write_demo(filepath)

由以上代码可知,TFRecord的原理是:将每个样本传给serialize_example函数并输出字节字符串,再通过TFRecordWriter类写入TFRecord文件中,有多少个样本就会生成多少个字节字符串。

TFRecord的生成(大规模)

TFRecord的生成=spark DataFrame格式数据保存为tfrecords格式数据

from pyspark.sql.types import *
def main():
#从hive表中读取数据
df=spark.sql("""
select * from experiment.table""")
#tfrecords保存路径
path = "viewfs:///user/hadoop-hdp/ml/demo/tensorflow/data/tfrecord"
#将spark DataFrame格式数据转换为tfrecords格式数据
df.repartition(file_num).write \
.mode("overwrite") \
.format("tfrecords") \
.option("recordType", "Example")\
.save(path)
if __name__ == "__main__":
main()

TFRecord的读取

在模型训练的时候需要读取TFRecord文件,有三个步骤:

1、首先通过tf.data.TFRecordDataset() API读取TFRecord文件并创建dataset;

2、定义schema;

3、使用tf.parse_single_example() 按照schema解析dataset中每个样本;

schema的意义在于指定每个样本的每一列数据应该用哪一种特征解析函数去解析。

Tensorflow提供了三种解析函数:

1、tf.FixedLenFeature(shape,dtype,default_value):解析定长特征,shape:输入数据形状、dtype:输入数据类型、default_value:默认值;

2、tf.VarLenFeature(dtype):解析变长特征,dtype:输入数据类型;

3、tf.FixedSequenceFeature(shape,dtype,default_value):解析定长序列特征,shape:输入数据形状、dtype:输入数据类型、default_value:默认值;

代码如下:

def read_demo(filepath):
# 定义schema
schema = {
'user_id': tf.FixedLenFeature([], tf.int64),
'city_id': tf.FixedLenFeature([], tf.int64),
'app_type': tf.FixedLenFeature([], tf.int64),
'viewed_pois': tf.VarLenFeature(tf.int64),
'avg_paid': tf.FixedLenFeature([], tf.float32, default_value=0.0),
'comment': tf.FixedLenFeature([], tf.string, default_value=''),
} # 使用相关api,按照schema解析dataset中的样本
def _parse_function(example_proto):
return tf.parse_single_example(example_proto, schema) # 读取TFRecord文件来创建dataset
dataset = tf.data.TFRecordDataset(filepath)
#按照schema解析dataset中的每个样本
parsed_dataset = dataset.map(_parse_function)
#创建Iterator并迭代Iterator即可访问dataset中的样本
next = parsed_dataset.make_one_shot_iterator().get_next() # 这里直接利用session,打印dataset中的样本
with tf.Session() as sess:
while True:
try:
print sess.run(next)
except:
print "out of data"
break

其中,

tf.parse_single_example(
serialized,
features,
name=None,
example_names=None
)

参数:

  • serialized:序列化的Example。
  • features:一个字典,key是特征,value是FixedLenFeature/VarLenFeature/FixedSequenceFeature值。
  • name:此操作的名称(可选)。
  • example_names:(可选)标量字符串张量,关联的名称。

返回:

一个字典,key是特征,value是Tensor或Sparse Tensor值。

最后欢迎大家关注我的微信公众号:对白的算法屋(duibainotes),跟踪NLP、推荐系统和对比学习等机器学习领域前沿。

想进一步交流的同学也可以通过公众号加我的微信一同探讨技术问题,谢谢。

推荐阅读

TFRecord&tf.Example

tensorflow/ecosystem

Linkedin Spark-TFRecord

Tensorflow之TFRecord的原理和使用心得的更多相关文章

  1. 4. Tensorflow的Estimator实践原理

    1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...

  2. 3. Tensorflow生成TFRecord

    1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...

  3. tensorflow核心概念和原理介绍

    关于 TensorFlow TensorFlow 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库. 节点(Nodes)在图中表示数学操作,图中的线(edges)则表示 ...

  4. Tensorflow 读写 tfrecord 文件(Python3)

    TensorFlow笔记博客:https://blog.csdn.net/xierhacker/article/category/6511974 写入tfrecord文件 import tensorf ...

  5. tensorflow制作tfrecord格式数据

    tf.Example msg tensorflow提供了一种统一的格式.tfrecord来存储图像数据.用的是自家的google protobuf.就是把图像数据序列化成自定义格式的二进制数据. To ...

  6. 《转》从系统和代码实现角度解析TensorFlow的内部实现原理 | 深度

    from https://www.leiphone.com/news/201702/n0uj58iHaNpW9RJG.html?viewType=weixin 摘要 2015年11月9日,Google ...

  7. tensorflow的tfrecord操作代码与数据协议规范

    tensorflow的数据集可以说是非常重要的部分,我认为人工智能就是数据加算法,数据没处理好哪来的算法? 对此tensorflow有一个专门管理数据集的方式tfrecord·在训练数据时提取图片与标 ...

  8. 移动端开发rem布局之less+媒体查询布局的原理步骤和心得

    rem即是以html文件中font-size的大小的倍数rem布局的原理:通过媒体查询设置不同屏幕宽度下的html的font-size大小,然后在css布局时用rem单位取代px,从而实现页面元素大小 ...

  9. 分享《机器学习实战基于Scikit-Learn和TensorFlow》中英文PDF源代码+《深度学习之TensorFlow入门原理与进阶实战》PDF+源代码

    下载:https://pan.baidu.com/s/1qKaDd9PSUUGbBQNB3tkDzw <机器学习实战:基于Scikit-Learn和TensorFlow>高清中文版PDF+ ...

随机推荐

  1. python验证码图片生成

    环境:win10(64位)+pycharm2018+pillow5.4+python3.7 对Django的跨站请求保护的有所了解的同学会知道{%csrf_token%}在实际上作用并不是那么大,只要 ...

  2. Python脚本:爬取天气数据并发邮件给心爱的Ta

    第一部分:爬取天气数据 # 在函数调用 get_weather(url = 'https://www.tianqi.com/foshan') 的 url中更改城市,foshan为佛山市 1 impor ...

  3. 犀牛Rhino 7.0中文版安装破解教程

    犀牛Rhino 7.0中文版是一款专业的.功能强大的三维建模软件,利用它可以创建.编辑.分析.提供.渲染.动画与转换NURBS线条.曲面.实体与多边形网格:它能轻易整合3DS MAX 与Softima ...

  4. Hyper-V下Internal vSwitch的配置和Linux虚拟机的SSH连接

    最近工作中要在Windows Server 2016/Hyper-V 10中运行Ubuntu16实例,需要制作出"即插即用"的镜像文件,也就是安装好后即可从外部SSH进去.之前我使 ...

  5. Mysql读写锁保姆级图文教程

    摘要:读锁会阻塞写,但是不会阻塞读,而写锁会把杜希俄都阻塞. 本文分享自华为云社区<Mysql保姆级读写锁图文教程丨[绽放吧!数据库]>,作者:Code皮皮虾 . 准备 创建mylock表 ...

  6. Spring Cloud 专题之六:bus

    书接上回: SpringCloud专题之一:Eureka Spring Cloud专题之二:OpenFeign Spring Cloud专题之三:Hystrix Spring Cloud 专题之四:Z ...

  7. Jarvis OJ部分逆向

    Jarvis OJ部分逆向题解 很久没有写博客了,前天上Jarvis OJ刷了几道逆向,保持了一下感觉.都是简单题目,写个writeup记录一下. easycrackme int __cdecl ma ...

  8. Solidity

    起因是Xenc师傅给我截了张图,我日 居然看不懂 ,一搜才知道,之前学的版本有些老了.. 这次学下新一点的记录下 HelloWorld pragma solidity ^0.6.0; // versi ...

  9. Vue2.x响应式原理

    一.回顾Vue响应式用法 ​ vue响应式,我们都很熟悉了.当我们修改vue中data对象中的属性时,页面中引用该属性的地方就会发生相应的改变.避免了我们再去操作dom,进行数据绑定. 二.Vue响应 ...

  10. 特殊回文数 BASIC-9

    特殊回文数 代码 import java.util.Scanner; /*123321是一个非常特殊的数,它从左边读和从右边读是一样的. 输入一个正整数n, 编程求所有这样的五位和六位十进制数, 满足 ...