这篇文章介绍tf.estimator,一个高级TensorFlow API,可以极大简化机器学习编程。Estimators封装了下面几个活动。

  • 训练
  • 评估
  • 预测
  • 出口服务(export for serving)

可以使用tensorflow中自带的Estimators,也可以自定义Estimators。所有的Estimators,都继承自tf.estimator.Estimator类。

1. Estimators的优点

  • 可以在分布式多服务器环境下,无需修改代码运行基于Estimator的模型。可以运行Estimator-based模型在CPUs,GPUs, TPUs,无需重新编码模型。
  • 简化了模型开发者间的共享实现。
  • 开发高级的直观的代码。比使用低级API容易。
  • Estimators建立在tf.keras.layers,简化了自定义。
  • Estimators为你建立图
  • Estimators提供一个安全的分布式训练循环,控制如何和何时 1)建立图  2)初始变量  3)加载数据  4)处理异常  5)生成检查点文件和从错误中恢复   6)保存TensorBoard需要的summaries

当使用Estimators时,必须将数据输入管道和模型分开。这种分离简化了不同数据集上的实验。

2.  Pre-made Estimators

pre-made Estimators生成和管理tf.graph和tf.Session。并且只需要作出很小的代码改动,就能实验各种模型结构。下面以一个基于全链接,前馈神经网络训练分类模型为例。

pre-made Estimators 程序的结构

包含下面四个步骤:

1.     写一个或多个数据集导入函数。你肯呢个会生成一个函数用来导入训练集,另一个函数导入测试集。每个数据集导入函数必须包含两个对象:1)一个字典,keys是特征名,values是Tensors(或 Sparse Tensors)包含对应的特征数据。2)一个Tensor,包含一个或多个标签

代码的基本骨架如下:

def input_fn(dataset):
... # manipulate dataset, extracting the feature dict and the label
return feature_dict, label

2.定义特征列。每个tf.feature_column识别一个特征名,它的类型和任何输入预处理。下面的代码段生成三个特征列。前来那个个特征那个列简单地识别特征名和类型。第三个特征那个列指定一个lambda程序,伸缩原始数据。

# Define three numeric feature columns.
population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column('median_education',
normalizer_fn=lambda x: x - global_education_mean)

3.实例化相关pre-made Estimator.

# Instantiate an estimator, passing the feature columns.
estimator = tf.estimator.LinearClassifier(
feature_columns=[population, crime_rate, median_education],
)

4.调用训练,评估,或预测方法。

# my_training_set is the function created in Step 1
estimator.train(input_fn=my_training_set, steps=2000)

3. 自定义Estimators

无论是pre-made还是自定义Estimators,核心都是模型函数(建立图,用于训练,评估和预测),pre-made Estimators已经实现了这些。自定义Estimators需要自己实现。

推荐工作流

1.假设一个合适的pre-made Estimator存在,使用它建立你的地一个模型,使用结果建立baseline.

2.建立和测试你的整个管道,包括使用pre-made Estimator的整个代码的完整性和可靠性。

3.如果合适的可供替代的pre-made Estimators存在,运行实验,决定哪个pre-made Estimator产生最好的结果。

4.改善你的代码,建立自己的自定义Estimator.

4.  从Keras模型中生成Estimators

可以将存在的Keras模型转换成Estimators。这样可以使你的Keras模型拥有Estimator的优势,比如分布式训练。调用tf.keras.estimator.model_to_estimator.

# Instantiate a Keras inception v3 model.
keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None)
# Compile model with the optimizer, loss, and metrics you'd like to train with.
keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
loss='categorical_crossentropy',
metric='accuracy')
# Create an Estimator from the compiled Keras model. Note the initial model
# state of the keras model is preserved in the created Estimator.
est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3) # Treat the derived Estimator as you would with any other Estimator.
# First, recover the input name(s) of Keras model, so we can use them as the
# feature column name(s) of the Estimator input function:
keras_inception_v3.input_names # print out: ['input_1']
# Once we have the input name(s), we can create the input function, for example,
# for input(s) in the format of numpy ndarray:
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"input_1": train_data},
y=train_labels,
num_epochs=1,
shuffle=False)
# To train, we call Estimator's train function:
est_inception_v3.train(input_fn=train_input_fn, steps=2000)

Tensorflow Estimators的更多相关文章

  1. Convolutional Neural Network in TensorFlow

    翻译自Build a Convolutional Neural Network using Estimators TensorFlow的layer模块提供了一个轻松构建神经网络的高端API,它提供了创 ...

  2. Awesome TensorFlow

    Awesome TensorFlow  A curated list of awesome TensorFlow experiments, libraries, and projects. Inspi ...

  3. [TensorFlow] Creating Custom Estimators in TensorFlow

    Welcome to Part 3 of a blog series that introduces TensorFlow Datasets and Estimators. Part 1 focuse ...

  4. [TensorFlow] Introduction to TensorFlow Datasets and Estimators

    Datasets and Estimators are two key TensorFlow features you should use: Datasets: The best practice ...

  5. TensorFlow框架(5)之机器学习实践

    1. Iris data set Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理.Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集.数据集包含150个数据集,分为3类, ...

  6. TensorFlow 中文资源全集,官方网站,安装教程,入门教程,实战项目,学习路径。

    Awesome-TensorFlow-Chinese TensorFlow 中文资源全集,学习路径推荐: 官方网站,初步了解. 安装教程,安装之后跑起来. 入门教程,简单的模型学习和运行. 实战项目, ...

  7. TensorFlow tutorial

    代码示例来自https://github.com/aymericdamien/TensorFlow-Examples tensorflow先定义运算图,在run的时候才会进行真正的运算. run之前需 ...

  8. TensorFlow.org教程笔记(一)Tensorflow初上手

    本文同时也发布在自建博客地址. 本文翻译自www.tensorflow.org的英文教程. 本文档介绍了TensorFlow编程环境,并向您展示了如何使用Tensorflow解决鸢尾花分类问题. 先决 ...

  9. tensorflow estimator API小栗子

    TensorFlow的高级机器学习API(tf.estimator)可以轻松配置,训练和评估各种机器学习模型. 在本教程中,您将使用tf.estimator构建一个神经网络分类器,并在Iris数据集上 ...

随机推荐

  1. Android学习笔记_21_ViewFlipper使用详解 手势识别器

    一.介绍ViewFilpper类 1.1 屏幕切换 屏幕切换指的是在同一个Activity内屏幕见的切换,最长见的情况就是在一个FrameLayout内有多个页面,比如一个系统设置页面:一个个性化设置 ...

  2. 读取静态的json文件

    <!DOCTYPE html><html><head><meta http-equiv="Content-Type" content=&q ...

  3. 第一次写C语言小程序,可以初步理解学生成绩管理系统的概念

    1 成绩管理系统概述 1.1  管理信息系统的概念  管理信息系统(Management Information Systems,简称MIS),是一个不断发展的新型学科,MIS的定义随着科技的进步也在 ...

  4. Python基础—15-正则表达式

    正则表达式 应用场景 特定规律字符串的查找替换切割等 邮箱格式.URL.IP等的校验 爬虫项目中,特定内容的提取 使用原则 只要是能够使用字符串函数解决的问题,就不要使用正则 正则的效率较低,还会降低 ...

  5. js/jquery 禁用点击事件

    前言 工作中经常遇到这种情况:验证邮箱页面的重新发送需要在3分钟后才可以点击触发请求,所以在这之前需要禁用他的点击. 网上查了后有以下几种实现方法 1.css禁用鼠标点击事件 .disabled { ...

  6. 洛谷P1731 [NOI1999]生日蛋糕(爆搜)

    题目背景 7月17日是Mr.W的生日,ACM-THU为此要制作一个体积为Nπ的M层 生日蛋糕,每层都是一个圆柱体. 设从下往上数第i(1<=i<=M)层蛋糕是半径为Ri, 高度为Hi的圆柱 ...

  7. 最长公共子序列Lcs (51Nod - 1006)

    20180604   11:28   给出两个字符串A B,求A与B的最长公共子序列(子序列不要求是连续的).   比如两个串为:   abcicba abdkscab   ab是两个串的子序列,ab ...

  8. Mac系统配置php环境

    [写在前面——叨叨叨] -_-#急着配环境的同志们可以绕道.最近学校的实验室里接了一个小项目——考勤刷卡系统,利用RFID在硬件层获取学生卡的ID,通过wifi传输至服务器,进行考勤信息存储,手机端获 ...

  9. hibernate连接oracle数据库进行查询

    按主键查询 dao层 public Emp get(Serializable id){ //通过session的get方法根据加载指定对象 return (Emp)HibernateUtil.curr ...

  10. JS高级. 06 缓存、分析解决递归斐波那契数列、jQuery缓存、沙箱、函数的四种调用方式、call和apply修改函数调用方法

    缓存 cache 作用就是将一些常用的数据存储起来 提升性能 cdn //-----------------分析解决递归斐波那契数列<script> //定义一个缓存数组,存储已经计算出来 ...