使用TensorFlow高级别的API进行编程
这里涉及到的高级别API主要是使用Estimator类来编写机器学习的程序,此外你还需要用到一些数据导入的知识。
为什么使用Estimator
Estimator类是定义在tf.estimator.Estimator中的,你可以使用其中已经有的Estimator,叫做预创建的Estimator,也可以自定义Estimator。Estimator已经封装了训练(train),评估(evaluate),预测(predict),导出以供使用等方法。
此外,Estimator会为我们提供诸如图构建、创建session等管道工作,不用我们再做这些重复的工作。它还提供了安全的分布式训练循环。相比于低级的API,我们可以把大部分的时间和精力放在处理数据、训练模型、调整参数上面,而不是创建张量、构建图、使用session运行张量上面。
使用Estimator的步骤
1:需要编写一个数据输入的函数input_fn
input_fn是输入函数,这个函数的作用在于对数据进行预处理,并且在模型train,predict,evaluate的时候给模型送进去数据。所以input_fn主要作用的时机在模型训练、预测和评估的时候,在模型定义的时候不需要传入输入函数,而是传入一个预定义的特征列。可以使用系统自带的函数,可以编写自定义的输入函数。
使用系统自带的数据输入函数:
系统自带的输入函数为tf.estimator.inputs.numpy_input_fn,它的输入参数如下:
def numpy_input_fn(x,
y=None,
batch_size=128,
num_epochs=1,
shuffle=None,
queue_capacity=1000,
num_threads=1)
x为numpy数组或者numpy数组的字典,当为numpy数组的时候,这个数组被当做单一的特征对待。
一个例子如下,这个例子是tf.estimator Quickstart tutorial中的一段代码:
import numpy as np training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32) train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": np.array(training_set.data)},
y=np.array(training_set.target),
num_epochs=None,
shuffle=True) classifier.train(input_fn=train_input_fn, steps=2000)
自定义导入数据的函数:
要自定义导入函数,要知道tensorflow中关于数据的概念,以及知道自定义的函数应该返回的值,下面我将梳理一下这里面的概念:
自定义函数的基本框架以及返回值
def my_input_fn():
# 在这里进行数据的预处理...
# ...返回两个值 1) 一个由特征列和包含特征的Tensors组成的映射(字典) 2) 一个包含labels的Tensor
return feature_cols, labels
自定义函数需要返回两个值,一个值是feature_cols,是一个字典,其中字典的key为特征的列名称,字典的value为包含特征值的Tensor对象。labels是一个包含标签值的Tensor对象。
tf.data.API对于数据的两个抽象:
使用tf.data.API来构建数据输入的管道,帮助我们导入数据,无论是图像,文本还是分布式的数据,都可以用它来完成。
一个抽象的概念是tf.data.Dataset,一个Dataset是一个数据集,它是由一系列的元素组成的,每个元素的类型都是相同的。其中每个元素包含一个或者多个Tensor对象。我们可以以两种方式来创建Dataset对象,一种方式是创建它的来源,比如使用Dataset.from_tensor_slices(),可以使用张量来创建Dataset对象,另外一种方式是运用转换的方式,可以将一个Dataset来变成另外一个Dataset,比如Dataset.batch()。
另外一个抽象的概念是tf.data.Iterator,它代表的是迭代器。表示的是如何从数据集里面取出元素,最简答的迭代器是单次迭代器,Dataset.make_one_shot_iterator()可以创建单次迭代器。创建迭代器以后,可以使用Iterator.get_next()来获取下一个元素。
其它的创建数据集的方法:
Dataset.from_tensor()创建一个Dataset,并将传入的Tensor当做一个元素。 Dataset.from_tensor_slices()会创建一个Dataset,并且将传入的Tensor在第0维上面切面,分成一些列的元素。还可以使用TFRecordDataset来获得磁盘上面TFRecord格式的数据。
其它的创建迭代器的方法:
除了dataset.make_one_shot_iterator()这种单次迭代器以外,你还可以创建可初始化、可重新初始化、可馈送迭代器。
导入数据集的基本的工作机制:
1:创建Dataset对象 –> 2:将Dataset进行转化 –> 3:创建迭代器 –> 4:用迭代器返回下一个元素。
下面用一个例子来说明一下:
from tensorflow.python.data import Dataset
import numpy as np
def my_input_fn(features, targets, batch_size=1, shuffle=True, num_epochs=None):
"""自定义的输入函数 Args:
features: 使用pandas中的DataFrame对象来表示的features
targets: 使用pandas的taFrame对象表示的targets
batch_size: 批次的大小
shuffle: 是否将数据进行重新打乱
num_epochs: 需要重复的epochs的数量,一个epochs代表一个训练周期. None = repeat indefinitely
Returns:
下一批次数据的元组 (features, labels)
""" # 将pandas对象转换为字典,其中字典的值为numpy的数组
features = {key:np.array(value) for key,value in dict(features).items()} # 创建一个Dataset,并且设置好批次和重复的次数
ds = Dataset.from_tensor_slices((features,targets)) # warning: 2GB limit
ds = ds.batch(batch_size).repeat(num_epochs) # 是否进行数据扰动
if shuffle:
ds = ds.shuffle(10000) # 返回下个批次的数据
features, labels = ds.make_one_shot_iterator().get_next()
return features, labels
上面自定义了数据导入的函数,使用Dataset.from_tensor_slices()来创建Dataset。然后使用batch、repeat、shuffle进行转换。 接着创建迭代器,并且获得下一个元素。
2:定义特征列
使用tf.feature_column来标识特征名称、类型和任何输入预处理。
特征列在原始数据和模型之间起到了连接的作用。在编写模型的时候需要预先确定输入数据的特征列。
比如包含经度和维度两个特征的特征列,它们都是数值类型,这个特征列在模型定义的时候需要传入:
import tensorflow as tf
longitude = tf.feature_column.numeric_column('longitude')
latitude = tf.feature_column.numeric_column('latitude')
feature_column = [longitude, latitude]

特征列在原始数据与模型所需的数据之间架起了桥梁。
3:实例化相关的预创建的Estimator
这个步骤就简单了,以深度学习模型为例,运用上面创建的经纬度特征列,使用10*10的隐层创建一个深度神经网络的回归模型:
hidden_units = [10, 10]
dnn_regressor = tf.estimator.DNNRegressor(
feature_columns=feature_columns,
hidden_units=hidden_units,
)
4:调用训练、评估或推理方法
使用上述创建的模型进行train、evaluate、predict操作。首先需要定理训练的输入函数,将训练集的特征和标签都传进去,然后开始训练,例子如下:
training_input_fn = lambda:my_input_fn(train_df, train_target_df)
dnn_regressor.train(
input_fn=training_input_fn,
steps=300
)
参考:
Estimator 高级的API,介绍了创建estimator的流程
导入数据 介绍了数据集,还有迭代器的知识
Building Input Functions with tf.estimator 讲解了如何定义输入函数
特征列 详细介绍了特征列,里面有9中特征列可以学习
google机器学习速成课程的神经网络简介 ,完整的机器学习过程
使用TensorFlow高级别的API进行编程的更多相关文章
- 使用TensorFlow低级别的API进行编程
Tensorflow的低级API要使用张量(Tensor).图(Graph).会话(Session)等来进行编程.虽然从一定程度上来看使用低级的API非常的繁重,但是它能够帮助我们更好的理解Tenso ...
- 谷歌开源的TensorFlow Object Detection API视频物体识别系统实现教程
视频中的物体识别 摘要 物体识别(Object Recognition)在计算机视觉领域里指的是在一张图像或一组视频序列中找到给定的物体.本文主要是利用谷歌开源TensorFlow Object De ...
- [Tensorflow] Object Detection API - predict through your exclusive model
开始预测 一.训练结果 From: Testing Custom Object Detector - TensorFlow Object Detection API Tutorial p.6 训练结果 ...
- 使用Tensorflow object detection API——环境搭建与测试
[软件环境搭建] 操作系统:windows 10 64位 内存:8G CPU:I7-6700 Tensorflow: 1.4 Python:3.5 Anaconda3 (64-bit) 以上环境搭建请 ...
- 【翻译】Keras.NET简介 - 高级神经网络API in C#
Keras.NET是一个高级神经网络API,它使用C#编写,并带有Python绑定,可以在Tensorflow.CNTK或Theano上运行.其关注点是实现快速实验.因为做好研究的关键是:能在尽可能短 ...
- Tensorflow object detection API(1)---环境搭建与测试
参考: https://blog.csdn.net/dy_guox/article/details/79081499 https://blog.csdn.net/u010103202/article/ ...
- 使用TensorFlow Object Detection API+Google ML Engine训练自己的手掌识别器
上次使用Google ML Engine跑了一下TensorFlow Object Detection API中的Quick Start(http://www.cnblogs.com/take-fet ...
- TensorFlow object detection API
cloud执行:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_pet ...
- Tensorflow object detection API 搭建属于自己的物体识别模型
一.下载Tensorflow object detection API工程源码 网址:https://github.com/tensorflow/models,可通过Git下载,打开Git Bash, ...
随机推荐
- ajax局部刷新后里面的jquery事件失效的解决方法
live() 与bind()作用基本一样. 最重要区别:live()可以将事件绑定到当前和将来的元素(eg:为id=zy元素绑定点击事件,而当你用js动态生成一个节点并插入到dom文档结构中时,如果你 ...
- SPOJ 16549 - QTREE6 - Query on a tree VI 「一种维护树上颜色连通块的操作」
题意 有操作 $0$ $u$:询问有多少个节点 $v$ 满足路径 $u$ 到 $v$ 上所有节点(包括)都拥有相同的颜色$1$ $u$:翻转 $u$ 的颜色 题解 直接用一个 $LCT$ 去暴力删边连 ...
- git内部原理
Git 内部原理 无论是从之前的章节直接跳到本章,还是读完了其余章节一直到这——你都将在本章见识到 Git 的内部工作原理 和实现方式. 我们发现学习这部分内容对于理解 Git 的用途和强大至关重要. ...
- css-css背景
CSS 允许应用纯色作为背景,也允许使用背景图像创建相当复杂的效果 一:背景色background-color 属性 p {background-color: gray;} 二:背景图像 backgr ...
- opencv之dft及mat类型转换
跑实验时用到dft这个函数,根据教程,需要先将其扩充到最优尺寸,但我用逆变换后发现得到的mat的维数竟然不一样.因此还是不要扩展尺寸了. 参考:http://www.xpc-yx.com/2014/1 ...
- 应用服务器中对JDK的epoll空转bug的处理
原文链接:应用服务器中对JDK的epoll空转bug的处理 前面讲到了epoll的一些机制,与select和poll等传统古老的IO多路复用机制的一些区别,这些区别实质可以总结为一句话, 就是epol ...
- R语言学习笔记:choose、factorial、combn排列组合函数
一.总结 组合数:choose(n,k) —— 从n个中选出k个 阶乘:factorial(k) —— k! 排列数:choose(n,k) * factorial(k) 幂:^ 余数:%% 整数商: ...
- 【BZOJ】4671: 异或图
题解 写完之后开始TTTTTTT--懵逼 这道题我们考虑一个东西叫容斥系数啊>< 这个是什么东西呢 也就是\(\sum_{i = 1}^{m}\binom{m}{i}f_{i} = [m ...
- bzoj 1237 [SCOI2008]配对 贪心+dp
思路:dp[ i ] 表示 排序后前 i 个元素匹配的最小值, 我们可以发现每个点和它匹配的点的距离不会超过2,这样就能转移啦. #include<bits/stdc++.h> #defi ...
- 终端(terminal)、tty、shell、控制台(console)、bash之间的区别与联系
1.终端(terminal) 终端(termimal)= tty(Teletypewriter, 电传打印机),作用是提供一个命令的输入输出环境,在linux下使用组合键ctrl+alt+T打开的就是 ...