1.保存

将训练好的模型参数保存起来,以便以后进行验证或测试。tf里面提供模型保存的是tf.train.Saver()模块。

模型保存,先要创建一个Saver对象:如

saver=tf.train.Saver()

在创建这个Saver对象的时候,有一个参数经常会用到,max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐,如:

saver=tf.train.Saver(max_to_keep=0)

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)

创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,‘ckpt/mnist.ckpt',global_step=step)

第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

saver.save(sess, 'my-model', global_step=0) ==>      filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

2.举例

import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4
w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b
loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''
saver = tf.train.Saver()
# defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
if isTrain:
for i in xrange(train_steps):
sess.run(train, feed_dict={x: x_data})
if (i + 1) % checkpoint_steps == 0:
saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
else:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
pass
print(sess.run(w))
print(sess.run(b)) 

3.恢复

用saver.restore()方法恢复变量:

saver.restore(sess,'ckpt.model_checkpoint_path')

sess:表示当前会话,之前保存的结果将被加载入这个会话;

ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。

转载:

【1】https://www.cnblogs.com/denny402/p/6940134.html

【2】https://blog.csdn.net/u011500062/article/details/51728830

TensorFlow:tf.train.Saver()模型保存与恢复的更多相关文章

  1. tensorflow的tf.train.Saver()模型保存与恢复

    将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.Saver( ...

  2. TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

      TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...

  3. 跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()

    save =  tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflo ...

  4. 一份快速完整的Tensorflow模型保存和恢复教程(译)(转载)

    该文章转自https://blog.csdn.net/sinat_34474705/article/details/78995196 我在进行图像识别使用ckpt文件预测的时候,这个文章给我提供了极大 ...

  5. 机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用

    1. tf.train.Saver() tf.train.Saver()是一个类,提供了变量.模型(也称图Graph)的保存和恢复模型方法. TensorFlow是通过构造Graph的方式进行深度学习 ...

  6. TensorFlow进阶(六)---模型保存与恢复、自定义命令行参数

    模型保存与恢复.自定义命令行参数. 在我们训练或者测试过程中,总会遇到需要保存训练完成的模型,然后从中恢复继续我们的测试或者其它使用.模型的保存和恢复也是通过tf.train.Saver类去实现,它主 ...

  7. TensorFlow使用记录 (九): 模型保存与恢复

    模型文件 tensorflow 训练保存的模型注意包含两个部分:网络结构和参数值. .meta .meta 文件以 “protocol buffer”格式保存了整个模型的结构图,模型上定义的操作等信息 ...

  8. 图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

    import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np impor ...

  9. tensorflow tf.train.Supervisor作用

    tf.train.Supervisor可以简化编程,避免显示地实现restore操作.通过一个例子看. import tensorflow as tf import numpy as np impor ...

随机推荐

  1. 苹果电脑快速安装双系统 Winclone镜像包 Winclone安装Win7/Win8.1/Win10镜像

    原文:https://bbs.feng.com/read-htm-tid-9940193.html by:旋律 2015年及以后的安装win8.1及win10(不支持win7) 请根据自己的机型安装合 ...

  2. 【Linux】crontab 定时启动sh

    crontab -e 20点59分启动脚本 59 20 * * * sh /home/fzuir/xingye4crawl/endXingYe4Crawl.sh >/home/fzuir/xin ...

  3. 【Python爬虫】教务处模拟登陆

    Python2模拟登陆获取cookie import urllib import urllib2 import cookielib filename = 'cookie.txt' #声明一个Mozil ...

  4. maven 添加jetty 支持

    maven jetty run 即可 配置  <plugin> <groupId>org.mortbay.jetty</groupId> <artifactI ...

  5. Python3玩转单链表——逆转单向链表pythonic版

    [本文出自天外归云的博客园] 链表是由节点构成的,一个指针代表一个方向,如果一个构成链表的节点都只包含一个指针,那么这个链表就是单向链表. 单向链表中的节点不光有代表方向的指针变量,也有值变量.所以我 ...

  6. C#学习笔记(22)——C#创建文本文件txt并追加写入数据

    说明(2017-7-31 16:25:06): 1. 有两种办法,第一种是用FileStream创建txt,用StreamWriter写入数据,期间还要加上判断,是否存在这个txt文件,如果不存在就创 ...

  7. iOS 统计Xcode整个工程的代码行数

    小技巧5-iOS 统计Xcode整个工程的代码行数 1.打开终端 2.cd 空格 将工程的文件夹拖到终端上,回车,此时进入到工程的路径 此时已经进入到工程文件夹下 3.运行指令 a. find . - ...

  8. WCF终结点——终结点地址(EndpointAddress)

    终结点的地址的Uri属性作为终结点地址的唯一标示. 包括客户端终结点和服务端终结点. 一.服务端终结点: 服务端的终结点通过宿主的添加方法暴露出来,从而成为可以调用的资源. 下面是将服务绑定到宿主的代 ...

  9. B2C和B2B之间有多大差距

    从产品应用的角度,我们团队经历了企图将B2C系统套用到B2B业务流程上的阶段,对于自营业务这还勉强可以实施,但对于外部用户的实施难度就太大了,用户体验也不好.这个过程中,我只关注了技术范畴的迭代速度. ...

  10. Maven项目编译后classes文件中没有.xml问题

    在做spring+mybatiss时,自动扫描都配置正确了,却在运行时出现了如下错误.后来查看target/classes/.../dao/文件夹下,发现只有mapper的class文件,而没有xml ...