TensorFlow 便捷的实现机器学习 三

MNIST
卷积神经网络
Fly

Overview


Iris花瓣分类中,运行结果后,我们最后只是知道一个最终的结果:

Accuracy: 0.933333
Predictions: [1 2]

我们并不能知道tensorflow执行的过程中发生了写什么。
有一种方式是通过在训练过程中通过多次fit来一步一步的得到结果,但是这种方式会大大的影响执行效率。我们可以使用==tf.contrib.learn提供的Monitor API工具来实现监控。下面主要学习的时候启动Logging一级TensorBoard来对实现过程做一个监控。

Enabling Logging with TensorFlow


TensorFlow提供了五个等级的日记记录,分别是:

  1. DEBUG
  2. INFO
  3. WARN
  4. ERROR
  5. FATAL

默认的情况下,TensorFlow主要设置为WARN等级。我们可以自行调整结果

tf.logging.set_verbosity(tf.logging.INFO)

这样子在运行程序的话,就会看到如下信息:

INFO:tensorflow:Training steps [0,200)
INFO:tensorflow:global_step/sec: 0
INFO:tensorflow:Step 1: loss_1:0 = 1.48073
INFO:tensorflow:training step 100, loss = 0.19847 (0.001 sec/batch).
INFO:tensorflow:Step 101: loss_1:0 = 0.192693
INFO:tensorflow:Step 200: loss_1:0 = 0.0958682
INFO:tensorflow:training step 200, loss = 0.09587 (0.003 sec/batch).

Configuring a ValidationMonitor for Streaming Evaluation


记录训练损失有助于了解你的模型是否融合,但如果你想进一步了解培训期间发生了什么,你有该怎么办?tf.contrib.learn提供了几个高级的Monitor,您可以附加到您的合适的操作,以进一步跟踪指标和/或调试较低级别的TensorFlow操作在模型训练,主要包括:

Monitor Description
CaptureVariable Saves a specified variable's values into a collection at every n steps of training
PrintTensor Logs a specified tensor's values at every n steps of training
SummarySaver Saves Summary protocol buffers for a given tensor using a SummaryWriter at every n steps of training
ValidationMonitor Logs a specified set of evaluation metrics at every n steps of training, and, if desired, implements early stopping under certain conditions

Evaluating Every N Steps

在设置校验ValidationMonitor的时候,你也许想看看这个模型的泛化程度,这个时候你就可以通过设置(test_set.data and test_set.target),以及显示的频率来查看:

validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
test_set.data,
test_set.target,
every_n_steps=50)

然后将这个行代码放在实例化的classifier之前。ValidationMonitor依赖保存的检查点来执行评估操作,因此您需要修改分类器的实例化以添加包含save_checkpoints_secs的RunConfig,它指定在训练期间在检查点保存之间应该经过多少秒。
classifier可是如下设置:

classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10],
n_classes=3,
model_dir="/tmp/iris_model",
config=tf.contrib.learn.RunConfig(
save_checkpoints_secs=1))

然后,再讲设置好的validation_monitor放进去

classifier.fit(x=training_set.data,
y=training_set.target,
steps=2000,
monitors=[validation_monitor])

到此,就可以运行代码,然后就能看到:

INFO:tensorflow:Validation (step 50): loss = 1.71139, global_step = 0, accuracy = 0.266667
...
INFO:tensorflow:Validation (step 300): loss = 0.0714158, global_step = 268, accuracy = 0.966667
...
INFO:tensorflow:Validation (step 1750): loss = 0.0574449, global_step = 1729, accuracy = 0.966667

Customizing the Evaluation Metrics

默认情况下,如果未指定评估指标,ValidationMonitor将同时记录损失和精确度,但您可以自定义每隔50个步骤运行的指标列表。tf.contrib.metrics模块为您可以与ValidationMonitor一起使用的分类模型提供各种其他度量功能,包括streaming_precision和streaming_recall。要指定要在每个评估传递中运行的确切指标,请向ValidationMonitor构造函数中添加一个指标参数。指标采用键/值对的dict,其中每个键是您要为该指标记录的名称,相应的值是计算它的函数。

按照如下方式修改ValidationMonitor构造函数,以添加精度和回调的记录,以及精度(损失总是记录,不需要明确指定):

validation_metrics = {"accuracy": tf.contrib.metrics.streaming_accuracy,
"precision": tf.contrib.metrics.streaming_precision,
"recall": tf.contrib.metrics.streaming_recall}
validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
test_set.data,
test_set.target,
every_n_steps=50,
metrics=validation_metrics)

Early Stopping with ValidationMonitor

注意,在上述对数输出中,通过步骤150,模型已经实现了1.0的精确度和召回率。这提出了一个问题,即模型训练是否可以从早期停止中受益。除了记录eval指标,ValidationMonitor使得在满足指定条件时容易实现提前停止,通过如下参数:

Param Description
early_stopping_metric Metric that triggers early stopping (e.g., loss or accuracy) under conditions specified in early_stopping_rounds and early_stopping_metric_minimize. Default is "loss".
early_stopping_metric_minimize True if desired model behavior is to minimize the value of early_stopping_metric; False if desired model behavior is to maximize the value of early_stopping_metric. Default is True.
early_stopping_rounds Sets a number of steps during which if the early_stopping_metric does not decrease (if early_stopping_metric_minimize is True) or increase (if early_stopping_metric_minimize is False), training will be stopped. Default is None, which means early stopping will never occur.

我们可以做如下设置:

validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
test_set.data,
test_set.target,
every_n_steps=50,
metrics=validation_metrics,
early_stopping_metric="loss",
early_stopping_metric_minimize=True,
early_stopping_rounds=200)

这样就会提前停止,而不需要到2000步,结果如下:

...
INFO:tensorflow:Validation (step 1450): recall = 1.0, accuracy = 0.966667, global_step = 1431, precision = 1.0, loss = 0.0550445
INFO:tensorflow:Stopping. Best step: 1150 with loss = 0.0506100878119.

实际上,这里的训练在步骤1450停止,指示对于过去200个步骤,损失没有减少,并且总体来说,步骤1150针对测试数据集产生最小损失值。这表明通过减少步数来额外校准超参数可以进一步改善模型。

Visualizing Log Data with TensorBoard


通过阅读ValidationMonitor生成的日志,可以在训练期间提供大量有关模型性能的原始数据,但也可以查看此数据的可视化,以便进一步了解趋势,例如,精确度如何更改步数。您可以使用TensorBoard(与TensorFlow一起打包的单独程序)通过将logdir命令行参数设置为保存模型训练数据的目录(此处为/ tmp / iris_model)来绘制这样的图。在命令行上运行以下命令:

$ tensorboard --logdir=/tmp/iris_model/
Starting TensorBoard 22 on port 6006
(You can navigate to http://0.0.0.0:6006)

然后在浏览器中加载提供的URL(此处为http://0.0.0.0:6006)。就可以可视化查看结果了。

reference

[1] https://www.tensorflow.org/tutorials/monitors/

TensorFlow 便捷的实现机器学习 三的更多相关文章

  1. [译]与TensorFlow的第一次接触(三)之聚类

    转自 [译]与TensorFlow的第一次接触(三)之聚类 2016.08.09 16:58* 字数 4316 阅读 7916评论 5喜欢 18 前一章节中介绍的线性回归是一种监督学习算法,我们使用数 ...

  2. 《转载》python/人工智能/Tensorflow/自然语言处理/计算机视觉/机器学习学习资源分享

    本次分享一部分python/人工智能/Tensorflow/自然语言处理/计算机视觉/机器学习的学习资源,也是一些比较基础的,如果大家有看过网易云课堂的吴恩达的入门课程,在看这些视频还是一个很不错的提 ...

  3. TensorFlow框架(5)之机器学习实践

    1. Iris data set Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理.Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集.数据集包含150个数据集,分为3类, ...

  4. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  5. 机器学习与Tensorflow(1)——机器学习基本概念、tensorflow实现简单线性回归

    一.机器学习基本概念 1.训练集和测试集 训练集(training set/data)/训练样例(training examples): 用来进行训练,也就是产生模型或者算法的数据集 测试集(test ...

  6. Hands on Machine Learning with sklearn and TensorFlow —— 一个完整的机器学习项目(加州房地产)

    数据集地址:https://github.com/ageron/handson-ml/tree/master/datasets 先行知识准备:NumPy,Pandas,Matplotlib的模块使用 ...

  7. ubuntu16.04安装tensorflow官方教程与机器学习资料【学习笔记】

    tensorflow官网有官方的安装教程:https://www.tensorflow.org/install/install_linux google的机器学习官方快速入门教程:https://de ...

  8. 机器学习 (三) 逻辑回归 Logistic Regression

    文章内容均来自斯坦福大学的Andrew Ng教授讲解的Machine Learning课程,本文是针对该课程的个人学习笔记,如有疏漏,请以原课程所讲述内容为准.感谢博主Rachel Zhang 的个人 ...

  9. 机器学习(三) Jupyter Notebook, numpy和matplotlib的详细使用 (上)

    工欲善其事,必先利其器.在本章,我们将学习和机器学习相关的基础工具的使用:Jupyter Notebook, numpy和matplotlib.大多数教程在讲解机器学习的时候,大量使用这些工具,却不对 ...

随机推荐

  1. 洛谷 P3515 [ POI 2011 ] Lightning Conductor —— 决策单调性DP

    题目:https://www.luogu.org/problemnew/show/P3515 决策单调性... 参考TJ:https://www.cnblogs.com/CQzhangyu/p/725 ...

  2. python 内存泄露的诊断

    对于一个用 python 实现的,长期运行的后台服务进程来说,如果内存持续增长,那么很可能是有了"内存泄露" 一.内存泄露的原因 对于 python 这种支持垃圾回收的语言来说,怎 ...

  3. Coursera Algorithms week2 基础排序 练习测验: Intersection of two sets

    题目原文: Given two arrays a[] and b[], each containing n distinct 2D points in the plane, design a subq ...

  4. HDU3085 Nightmare Ⅱ

    题目: Last night, little erriyue had a horrible nightmare. He dreamed that he and his girl friend were ...

  5. POJ 1523 Tarjan求割点

    SPF Description Consider the two networks shown below. Assuming that data moves around these network ...

  6. Linux系统下通过命令行对mysql数据进行备份和还原

    一.备份 1.进入mysql目录 cd /var/lib/mysql (进入mysql目录,根据安装情况会有差别) 2.备份 mysqldump -u root -p密码 数据库名 数据表名 > ...

  7. 各个数据库中,查询前n条记录的方法

    SQL查询前10条的方法为: 1.select top X *  from table_name --查询前X条记录,可以改成需要的数字,比如前10条. 2.select top X *  from  ...

  8. 【sqli-labs】 less17 POST - Update Query- Error Based - String (基于错误的更新查询POST注入)

    这是一个重置密码界面,查看源码可以看到username作了防注入处理 逻辑是先通过用户名查出数据,在进行密码的update操作 所以要先知道用户名,实际情况中可以注册用户然后实行攻击,这里先用admi ...

  9. (转)Arcgis for Js之GeometryService实现测量距离和面积

    http://blog.csdn.net/gisshixisheng/article/details/40540601 距离和面积的测量时GIS常见的功能,在本节,讲述的是通过GeometryServ ...

  10. Ubuntu 16.04安装和卸载软件命令

    安装软件 apt-get install softname1 softname2 softname3…… 卸载软件 apt-get remove softname1 softname2 softnam ...