tf.train.Supervisor可以简化编程,避免显示地实现restore操作.通过一个例子看.

import tensorflow as tf
import numpy as np
import os
log_path = r"D:\Source\model\linear"
log_name = "linear.ckpt"
# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3 # Try to find values for W and b that compute y_data = W * x_data + b
# (We know that W should be 0.1 and b 0.3, but TensorFlow will
# figure that out for us.)
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b # Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss) # Before starting, initialize the variables. We will 'run' this first.
saver = tf.train.Saver()
init = tf.global_variables_initializer() # Launch the graph.
sess = tf.Session()
sess.run(init) if len(os.listdir(log_path)) != 0: # 已经有模型直接读取
saver.restore(sess, os.path.join(log_path, log_name))
for step in range(201):
sess.run(train)
if step % 20 == 0:
print(step, sess.run(W), sess.run(b))
saver.save(sess, os.path.join(log_path, log_name))

这段代码是对tensorflow官网上的demo做一个微小的改动.如果模型已经存在,就先读取模型接着训练.tf.train.Supervisor可以简化这个步骤.看下面的代码.

import tensorflow as tf
import numpy as np
import os
log_path = r"D:\Source\model\supervisor"
log_name = "linear.ckpt"
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3 W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss) saver = tf.train.Saver()
init = tf.global_variables_initializer() sv = tf.train.Supervisor(logdir=log_path, init_op=init) # logdir用来保存checkpoint和summary
saver = sv.saver # 创建saver
with sv.managed_session() as sess: # 会自动去logdir中去找checkpoint,如果没有的话,自动执行初始化
for i in range(201):
sess.run(train)
if i % 20 == 0:
print(i, sess.run(W), sess.run(b))
saver.save(sess, os.path.join(log_path, log_name))

sv = tf.train.Supervisor(logdir=log_path, init_op=init)会判断模型是否存在.如果存在,会自动读取模型.不用显式地调用restore.

参考资料

  1. tensorflow官方文档
  2. tensorflow学习笔记(二十二):Supervisor

tensorflow tf.train.Supervisor作用的更多相关文章

  1. tensorflow|tf.train.slice_input_producer|tf.train.Coordinator|tf.train.start_queue_runners

    #### ''' tf.train.slice_input_producer :定义样本放入文件名队列的方式[迭代次数,是否乱序],但此时文件名队列还没有真正写入数据 slice_input_prod ...

  2. tensorflow中的Supervisor

    tf.train.Supervisor()可以帮我们简化一些事情,可以保存模型参数和Summary,它有以下的作用: 1)自动去checkpoint加载数据或初始化数据 ,因此我们就不需要手动初始化或 ...

  3. Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...

  4. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数(转)

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  5. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  6. TensorFlow 实战(二)—— tf.train(优化算法)

    Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class ...

  7. 【转载】 tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    原文地址: https://blog.csdn.net/dcrmg/article/details/79776876 ----------------------------------------- ...

  8. tensorflow数据读取机制tf.train.slice_input_producer 和 tf.train.batch 函数

    tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程 ...

  9. 图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

    import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np impor ...

随机推荐

  1. PHP队列的实现详细操作步骤

    队列是一种特殊的线性表,它只允许在表的前端,可以称之为front,进行删除操作:而在表的后端,可以称之为rear进行插入操作.队列和堆栈一样,是一种操作受限制的线性表,和堆栈不同之处在于:队列是遵循“ ...

  2. 三、netcore跨平台之 Linux配置nginx负载均衡

    前面两章讲了netcore在linux上部署以及配置nginx,并让nginx代理webapi. 这一章主要讲如何配置负载均衡,有些步骤在前两章讲的很详细了,所以这一章我就不会一个个截图了. 因为本人 ...

  3. 力扣(LeetCode)二进制求和 个人题解

    给定两个二进制字符串,返回他们的和(用二进制表示). 输入为非空字符串且只包含数字 1 和 0. 示例 1: 输入: a = "11", b = "1" 输出: ...

  4. CentOS7和Ubuntu下安装Docker & Docker-Compose

    本篇介绍如何在CentOS 7.6和Ubuntu 16.04下安装Docker & Docker-Compose. CentOS篇 安装Docker # cat /etc/redhat-rel ...

  5. oracle实现"limit"功能

    转载于http://blog.sina.com.cn/s/blog_67e2758d0100s3oc.html oracle数据库不支持mysql中limit功能,但可以通过rownum来限制返回的结 ...

  6. PowerMock学习(四)之Mock static的使用

    我们编写代码的时候,总会写一些工具类,为了方便调用喜欢使用static关键字来修饰对应方法. 那么现在举例说明,还是准备两个接口,第一个是查询学生总数,第二个是新增学生两个接口,具体示例代码如下: p ...

  7. Redux中间件Redux-thunk的配置

    当做固定写法吧 截图里少一个括号,已代码为主 import {createStore,applyMiddleware,compose} from 'redux' import thunk from ' ...

  8. 24 道 shell 脚本面试题

    想要成为中高级phper, shell 脚本是需要掌握的,它有助于你在工作环境中自动完成很多任务. 如下是一些面试过程中,经常会遇到的 shell 脚本面试问题及解答: Q:1 Shell脚本是什么. ...

  9. 看了这篇Redis,我以大专生的身份,进入了阿里,定级P7

    摘要: 前几天讲了Redis的面试知识点,当然那只是一部分,我相信各位在面试,或者实际开发过程中对缓存雪崩,穿透,击穿也不陌生吧,就算没遇到过但是你肯定听过,那三者到底有什么区别,我们又应该怎么去防止 ...

  10. mysql--时区问题(时间差8个小时?修改Mysql 时区)

    发现评论时间比本地时间晚8小时,原因:mysql默认时区选择了CST 解决办法: Ubuntu系统环境下: 1.检查mysql系统时区 进入mysql:mysql -u root -p mysql&g ...