tf.train.Supervisor可以简化编程,避免显示地实现restore操作.通过一个例子看.

import tensorflow as tf
import numpy as np
import os
log_path = r"D:\Source\model\linear"
log_name = "linear.ckpt"
# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3 # Try to find values for W and b that compute y_data = W * x_data + b
# (We know that W should be 0.1 and b 0.3, but TensorFlow will
# figure that out for us.)
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b # Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss) # Before starting, initialize the variables. We will 'run' this first.
saver = tf.train.Saver()
init = tf.global_variables_initializer() # Launch the graph.
sess = tf.Session()
sess.run(init) if len(os.listdir(log_path)) != 0: # 已经有模型直接读取
saver.restore(sess, os.path.join(log_path, log_name))
for step in range(201):
sess.run(train)
if step % 20 == 0:
print(step, sess.run(W), sess.run(b))
saver.save(sess, os.path.join(log_path, log_name))

这段代码是对tensorflow官网上的demo做一个微小的改动.如果模型已经存在,就先读取模型接着训练.tf.train.Supervisor可以简化这个步骤.看下面的代码.

import tensorflow as tf
import numpy as np
import os
log_path = r"D:\Source\model\supervisor"
log_name = "linear.ckpt"
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3 W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss) saver = tf.train.Saver()
init = tf.global_variables_initializer() sv = tf.train.Supervisor(logdir=log_path, init_op=init) # logdir用来保存checkpoint和summary
saver = sv.saver # 创建saver
with sv.managed_session() as sess: # 会自动去logdir中去找checkpoint,如果没有的话,自动执行初始化
for i in range(201):
sess.run(train)
if i % 20 == 0:
print(i, sess.run(W), sess.run(b))
saver.save(sess, os.path.join(log_path, log_name))

sv = tf.train.Supervisor(logdir=log_path, init_op=init)会判断模型是否存在.如果存在,会自动读取模型.不用显式地调用restore.

参考资料

  1. tensorflow官方文档
  2. tensorflow学习笔记(二十二):Supervisor

tensorflow tf.train.Supervisor作用的更多相关文章

  1. tensorflow|tf.train.slice_input_producer|tf.train.Coordinator|tf.train.start_queue_runners

    #### ''' tf.train.slice_input_producer :定义样本放入文件名队列的方式[迭代次数,是否乱序],但此时文件名队列还没有真正写入数据 slice_input_prod ...

  2. tensorflow中的Supervisor

    tf.train.Supervisor()可以帮我们简化一些事情,可以保存模型参数和Summary,它有以下的作用: 1)自动去checkpoint加载数据或初始化数据 ,因此我们就不需要手动初始化或 ...

  3. Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...

  4. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数(转)

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  5. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  6. TensorFlow 实战(二)—— tf.train(优化算法)

    Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class ...

  7. 【转载】 tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    原文地址: https://blog.csdn.net/dcrmg/article/details/79776876 ----------------------------------------- ...

  8. tensorflow数据读取机制tf.train.slice_input_producer 和 tf.train.batch 函数

    tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程 ...

  9. 图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

    import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np impor ...

随机推荐

  1. C++图像加Lidar点云转写rosbag

    近期需要处理一批Lidar+image的数据,拿到的是其他格式,但要转存成rosbag使用,参考部分网上做法,完成并记录. 1.Lidar处理 主要是将Lidar点云信息按点转为pcl::PointX ...

  2. 微信小程序api封装(promise)

    顺带这是我平时公司切换改变网络环境 直接上代码,我相信就可以懂了, //app.js function fetchApi(url, type, params, method) { return new ...

  3. [LC]35题 Search Insert Position (搜索插入位置)

    ①英文题目 Given a sorted array and a target value, return the index if the target is found. If not, retu ...

  4. suseoj 1210: 会场安排问题 (贪心)

    1210: 会场安排问题 时间限制: 1 Sec  内存限制: 128 MB提交: 1  解决: 1[提交][状态][讨论版][命题人:liyuansong] 题目描述 假设要在足够多的会场里安排一批 ...

  5. 【计算机网络】TCP基础知识详解

    1. TCP概念相关 [!NOTE] TCP(Transmission Control Protocol),又叫传输控制协议. TCP协议是面向连接的,可靠的,基于字节流的传输协议.在基于 TCP 进 ...

  6. TestNg练习001

    15分钟入门TestNG 阅读目录 TestNG介绍 在Eclipse中在线安装TestNG 在Eclipse中离线安装TestNg TestNG最简单的测试 TestNG的基本注解 TestNG中如 ...

  7. SpringSecurity退出功能实现的正确方式

    本文将介绍在Spring Security框架下如何实现用户的"退出"logout的功能.其实这是一个非常简单的功能,我见过很多的程序员在使用了Spring Security之后, ...

  8. 3个例子详解C++ 11 中push_back 和 emplace_back差异

    本文首发于个人博客https://kezunlin.me/post/b83bc460/,欢迎阅读最新内容! cpp11 push_back and emplace_back Guide case1 # ...

  9. windows 10上源码编译libjpeg-turbo和使用教程 | compile and use libjpeg-turbo on windows 10

    本文首发于个人博客https://kezunlin.me/post/83828674/,欢迎阅读! compile and use libjpeg-turbo on windows 10 Series ...

  10. numpy和matplotlib下载中出现的问题

    在安装numpy的时候遇到如下所示的错误: 经过几个小时的查找,最终发现是pygame的路径不对导致.将pygame的具体路径加上后,问题解决.实施如下:得出一个结论:路径很重要,千万得小心哦. 报错 ...