tensorflow tf.train.Supervisor作用
tf.train.Supervisor可以简化编程,避免显示地实现restore操作.通过一个例子看.
import tensorflow as tf
import numpy as np
import os
log_path = r"D:\Source\model\linear"
log_name = "linear.ckpt"
# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3
# Try to find values for W and b that compute y_data = W * x_data + b
# (We know that W should be 0.1 and b 0.3, but TensorFlow will
# figure that out for us.)
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b
# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
# Before starting, initialize the variables. We will 'run' this first.
saver = tf.train.Saver()
init = tf.global_variables_initializer()
# Launch the graph.
sess = tf.Session()
sess.run(init)
if len(os.listdir(log_path)) != 0: # 已经有模型直接读取
saver.restore(sess, os.path.join(log_path, log_name))
for step in range(201):
sess.run(train)
if step % 20 == 0:
print(step, sess.run(W), sess.run(b))
saver.save(sess, os.path.join(log_path, log_name))
这段代码是对tensorflow官网上的demo做一个微小的改动.如果模型已经存在,就先读取模型接着训练.tf.train.Supervisor可以简化这个步骤.看下面的代码.
import tensorflow as tf
import numpy as np
import os
log_path = r"D:\Source\model\supervisor"
log_name = "linear.ckpt"
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
saver = tf.train.Saver()
init = tf.global_variables_initializer()
sv = tf.train.Supervisor(logdir=log_path, init_op=init) # logdir用来保存checkpoint和summary
saver = sv.saver # 创建saver
with sv.managed_session() as sess: # 会自动去logdir中去找checkpoint,如果没有的话,自动执行初始化
for i in range(201):
sess.run(train)
if i % 20 == 0:
print(i, sess.run(W), sess.run(b))
saver.save(sess, os.path.join(log_path, log_name))
sv = tf.train.Supervisor(logdir=log_path, init_op=init)会判断模型是否存在.如果存在,会自动读取模型.不用显式地调用restore.
参考资料
tensorflow tf.train.Supervisor作用的更多相关文章
- tensorflow|tf.train.slice_input_producer|tf.train.Coordinator|tf.train.start_queue_runners
#### ''' tf.train.slice_input_producer :定义样本放入文件名队列的方式[迭代次数,是否乱序],但此时文件名队列还没有真正写入数据 slice_input_prod ...
- tensorflow中的Supervisor
tf.train.Supervisor()可以帮我们简化一些事情,可以保存模型参数和Summary,它有以下的作用: 1)自动去checkpoint加载数据或初始化数据 ,因此我们就不需要手动初始化或 ...
- Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析
觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...
- tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数(转)
tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...
- tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数
tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...
- TensorFlow 实战(二)—— tf.train(优化算法)
Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class ...
- 【转载】 tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数
原文地址: https://blog.csdn.net/dcrmg/article/details/79776876 ----------------------------------------- ...
- tensorflow数据读取机制tf.train.slice_input_producer 和 tf.train.batch 函数
tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程 ...
- 图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑
import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np impor ...
随机推荐
- PHP经典算法题
1.百钱买百鸡 公鸡5文钱一只,母鸡3文钱一只,小鸡3只一文钱,用100文钱买一百只鸡,其中公鸡,母鸡,小鸡都必须要有,问公鸡,母鸡,小鸡要买多少只刚好凑足100文钱. 分析:估计现在小学生都能手工推 ...
- 【原创】使用批处理脚本生成包并自动上传到nuget
Hello 大家好,我是TANZAME,我们又见面了. NuGet 是什么这里就不再重复啰嗦,园子里一搜一大把.今天要跟大家分享的是,在日常开发过程中如何统一管理我们的包,如何通过批处理脚本生成包并自 ...
- JS简单循环遍历json数组的方法
例如数据库里面的json字符串是这样的 1 2 3 4 5 var str = '[{"name":"宗2瓜","num":"1& ...
- 获取jar包内部的资源文件
通常获取一个资源文件很简单,问题是对于jar包内的资源文件,可能会发生意外.假如这里有一个文件操作的类: public class FileLoader { public boolean exists ...
- 初识JVM内存模型
计算机内存模型 在程序运行时,CPU通过访问主存获取数据,但随着CPU的快速发展,CPU访问速度越来越高,硬件无法满足CPU的条件下,大多内存加入了高速缓存机制,不同CPU都有对应的多级(一般为三)缓 ...
- GitHub注册失败,卡在第一步
同事说他无法注册GitHub,我一开始以为GitHub又无法登录进去,我就登录了自己的GitHub账号,没有问题,可以登录啊,见第一个标签页.同一局域网,不可能我能登录,你无法完成注册啊.于是,我就在 ...
- Vue注册组件命名时不能用大写的原因浅析
命名使用注意事项: https://www.jb51.net/article/160227.htm
- 记一次LDAP主从同步配置
LDAP主从同步 OpenLDAP在2.3版本之前的同步复制带有一系列缺点如只支持一主多从模式等,在此缺点就不多说,下文着重介绍一下OpenLDAP V2.4以后的同步负复制功能 同步功能 2.4版最 ...
- 【绝对有收获】看看?必须告诉你为什么要使用MQ消息中间件(图解版)
欢迎关注文章系列 ,关注我 <提升能力,涨薪可待> <面试知识,工作可待> <实战演练,拒绝996> 也欢迎关注微信公众号[Ccww笔记],原创技术文章第一时间推出 ...
- ggforce|绘制区域轮廓-区域放大-寻找你的“onepiece”
首发于“生信补给站” https://mp.weixin.qq.com/s/fm69bw-3cww1YEW_kBcTHQ 更多关于R语言,ggplot2绘图,生信分析的内容,关注有惊喜