【转】tf.SessionRunHook使用方法
原文地址:https://blog.csdn.net/mrr1ght/article/details/81011280 。本文有删减。
tf.train.SessionRunHook()是一个类;用来定义Hooks;
Hooks是什么,官方文档中关于training hooks的定义是:
Hooks are tools that run in the process of training/evaluation of the model.
Hooks是在模型训练/测试过程中的工具。Pytorch中也经常会有这个概念出现,其实也就跟keras里的callbacks一样,hook和callback都是在训练过程中执行特定的任务。
例如判断是否需要停止训练的EarlyStopping;改变学习率的LearningRateScheduler,他们都有一个共性,就是在每个step开始/结束或者每个epoch开始/结束时需要执行某个操作。如每个epoch结束都保存一次checkpoint;每个epoch结束时都判断一次loss有没有下降,如果loss没有下降的轮数大于提取设定的阈值,就终止训练。当然以上的功能我们都可以自己完全重头实现。但是这些keras和tersorflow提供了更好的工具就是hook和callback,并且一些常用的功能都已经实现好了。说到底每个hook和callback都是按照固定格式定义了在每个step开始/结束要执行的操作,每个epoch开始/结束执行的操作。
Hooks都是继承自父类tf.train.SessionRunHook(),首先看一下这个父类的定义源码;
tf.train.SessionRunHook()定义
tf.train.SessionRunHook()类定义在tensorflow/python/training/session_run_hook.py,类中每个函数的作用与什么时候调用都已加入函数注释中;
class SessionRunHook(object):
"""Hook to extend calls to MonitoredSession.run()."""
def begin(self):
"""再创建会话之前调用
调用begin()时,default graph会被创建,
可在此处向default graph增加新op,begin()调用后,default graph不能再被修改
"""
pass
def after_create_session(self, session, coord): # pylint: disable=unused-argument
"""tf.Session被创建后调用
调用后会指示所有的Hooks有一个新的会话被创建
Args:
session: A TensorFlow Session that has been created.
coord: A Coordinator object which keeps track of all threads.
"""
pass
def before_run(self, run_context): # pylint: disable=unused-argument
"""调用在每个sess.run()执行之前
可以返回一个tf.train.SessRunArgs(op/tensor),在即将运行的会话中加入这些op/tensor;
加入的op/tensor会和sess.run()中已定义的op/tensor合并,然后一起执行;
Args:
run_context: A `SessionRunContext` object.
Returns:
None or a `SessionRunArgs` object.
"""
return None
def after_run(self,
run_context, # pylint: disable=unused-argument
run_values): # pylint: disable=unused-argument
"""调用在每个sess.run()之后
参数run_values是befor_run()中要求的op/tensor的返回值;
可以调用run_context.qeruest_stop()用于停止迭代
sess.run抛出任何异常after_run不会被调用
Args:
run_context: A `SessionRunContext` object.
run_values: A SessionRunValues object.
"""
pass
def end(self, session): # pylint: disable=unused-argument
"""在会话结束时调用
end()常被用于Hook想要执行最后的操作,如保存最后一个checkpoint
如果sess.run()抛出除了代表迭代结束的OutOfRange/StopIteration异常外,
end()不会被调用
Args:
session: A TensorFlow Session that will be soon closed.
"""
pass
tf.train.SessionRunHook()类中定义的方法的参数run_context,run_values,run_args,包含sess.run()会话运行所需的一切信息,
run_context:类tf.train.SessRunContext的实例run_values:类tf.train.SessRunValues的实例run_args:类tf.train.SessRunArgs的实例.
这三个类会在下面详细介绍
tf.train.SessionRunHook()的使用
(1)可以使用tf中已经预定义好的Hook,其都是tf.train.SessionRunHook()的子类;如
- StopAtStepHook:设置用于停止迭代的max_step或num_step,两者只能设置其一
- NanTensorHook:如果loss的值为Nan,则停止训练;
- tensorflow中有许多预定义的Hook,想了解更多的同学可以去官方文档tf.train.下查看
(2)也可用tf.train.SessionRunHook()定义自己的Hook,并重写类中的方法;然后把想要使用的Hook(预定义好的或者自己定义的)放到tf.train.MonitorTrainingSession()参数[Hook]列表中;
关于tf.train.MonitorTrainingSession()参见tf.train.MonitoredTrainingSession()解析。
给一个定义自己Hook的栗子,来自cifar10
class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.
def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time#duration持续的时间
self._start_time = current_time
loss_value = run_values.results
examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
sec_per_batch = float(duration / FLAGS.log_frequency)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
SessRunContext/SessRunValues/SessRunArgs
这三个类都服务于sess.run(),区别如下:
- tf.train.SessRunContext和tf.train.SessRunArgs提供会话运行所需的信息,
- tf.train.SessRunValues保存会话运行的结果
(1) tf.train.SessRunArgs类
提供给会话运行的参数,与sess.run()参数定义一样:
fethes,feeds,option
(2) tf.train.SessRunValues
用于保存sess.run()的结果,其中resluts是sess.run()返回值中对应于SessRunArgs()的返回值,
(3) tf.train.SessRunContext
SessRunContext包含sess.run()所需的一切信息
属性:
- original_args:sess.run所需的参数,是一个tf.train.SessRunArgs实例
- session:指定要运行的会话
- stop_request:返回一个bool值,用于判断是否停止迭代;
方法:
equest_stop(): 设置_stop_request值为True
cifar10 中的运用实例
tf.train.SessionRunHook()和tf.train.MonitorTrainingSession()一般一起使用,下面是cifar10中的使用实例
class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.
def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time#duration持续的时间
self._start_time = current_time
loss_value = run_values.results
examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
sec_per_batch = float(duration / FLAGS.log_frequency)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
#monitored 被监控的
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.train_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
【转】tf.SessionRunHook使用方法的更多相关文章
- tf.truncated_normal和tf.random_normal使用方法的区别
1.tf.truncated_normal使用方法 tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=No ...
- Android5.0以后版本把应用移动到SD或者TF卡的方法
由于手机内存较小,才8G,用的时间一久,内部存储就满了,天天删垃圾,WIFI还老断线,终于忍无可忍了,决定把应用移动到SD卡,实践下来,只有少部分App默认支持移动到SD卡,大部分程序不支持只能装在内 ...
- TensorFlow使用记录 (二): 理解tf.nn.conv2d方法
方法定义 tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=True, data_format="NHWC&quo ...
- 【TensorFlow】理解tf.nn.conv2d方法 ( 附代码详解注释 )
最近在研究学习TensorFlow,在做识别手写数字的demo时,遇到了tf.nn.conv2d这个方法,查阅了官网的API 发现讲得比较简略,还是没理解.google了一下,参考了网上一些朋友写得博 ...
- [tensorflow] tf.gather使用方法
tf.gather:用一个一维的索引数组,将张量中对应索引的向量提取出来 import tensorflow as tf a = tf.Variable([[1,2,3,4,5], [6,7,8,9, ...
- ValueError:GraphDef cannot be larger than 2GB.解决办法
在使用TensorFlow 1.X版本的estimator的时候经常会碰到类似于ValueError:GraphDef cannot be larger than 2GB的报错信息,可能的原因是数据太 ...
- 机器学习笔记5-Tensorflow高级API之tf.estimator
前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...
- 文本分类学习(三) 特征权重(TF/IDF)和特征提取
上一篇中,主要说的就是词袋模型.回顾一下,在进行文本分类之前,我们需要把待分类文本先用词袋模型进行文本表示.首先是将训练集中的所有单词经过去停用词之后组合成一个词袋,或者叫做字典,实际上一个维度很大的 ...
- ROS tf
一.节点中使用(cpp,python) 1. ros wiki 提供的tutorials 2. https://blog.csdn.net/start_from_scratch/article/det ...
随机推荐
- Xcode一个project多个target
project添加target https://blog.csdn.net/vbirdbest/article/details/53466009 https://www.cnblogs.com/Bob ...
- 用Python 绘制分布(折线)图
用Python 绘制分布(折线)图,使用的是 plot()函数. 一个简单的例子: # encoding=utf-8 import matplotlib.pyplot as plt from pyla ...
- TP-LINK WR703N OpenWrt 无线配网历程
① 创建了两个 Interfaces,名字分别为 lan.wlan0 (可自行设定),一个负责连接 PPPoE,一个负责提供 AP 热点. ② 配置 wlan0 相关 ip 地址,该地址为无线网内网地 ...
- mke2fs和mkfs命令使用
1.mke2fs命令 在Linux系统下,mke2fs命令可用于创建磁盘分区上的”ext2/ext3”文件系统. (1)语法 mke2fs(选项)(参数) (2)常用选项 -b<区块大小> ...
- Win10 系统直接在目录下打开cmd
每次用cmd命令,就要定位到当前文件夹,很麻烦 这里介绍一种直接定位到要操作的文件夹的方法: 操作步骤: (1)选择要cmd的文件夹,按住Shift键,鼠标右键快捷方式,先打开Powershell窗口 ...
- 分享一个Linux C++消息通信框架TCPSHM
由于本人从事行业关系,Linux环境下的低延迟通信是我关注的技术之一.要达到极端的低延迟,当然同机器内IPC比网络通信快,而Linux IPC方式中无疑是共享内存延迟最低.不过相对于TCP这种通用的通 ...
- 啊哈!算法(第一章)C#实现
第1节 最简单的排序--桶排序 期末考试完了老师要将同学们的分数按照从高到低排序. 小哼的班上只有 5 个同学,这 5 个同学分别考了 5 分.3 分.5 分.2 分和 8 分,考得真是惨不忍 ...
- django开发_七牛云图片管理
七牛云注册 https://www.qiniu.com/ 实名认证成功之后,赠送10G存储空间 复制粘贴AK和SK 创建存储空间,填写空间名称,选择存储区域.访问控制选择位公开空间 获取测试域名 七牛 ...
- Dubbo面试踩坑
1.Dubbo支持哪些协议,每种协议的应用场景,优缺点? dubbo: 单一长连接和NIO异步通讯,适合大并发小数据量的服务调用,以及消费者远大于提供者.传输协议TCP,异步,Hessian序列化: ...
- 《MySQL实战45讲》学习笔记2——MySQL的日志系统
一.日志类型 逻辑日志:存储了逻辑SQL修改语句 物理日志:存储了数据被修改的值 二.binlog 1.定义 binlog 是 MySQL 的逻辑日志,也叫二进制日志.归档日志,由 MySQL Ser ...