TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

Checkmate is designed to be a simple drop-in solution for a very common Tensorflow use-case: keeping track of the best model checkpoints during training.

The BestCheckpointSaver is a wrapper around a tf.train.Saver.

The BestCheckpointSaver provides the ability to save the best n checkpoints, whereas the tf.train.Saver can only save the last n checkpoints.

Features

  • Save only best n checkpoints
  • Compares checkpoints based on a user-provided value
  • Can rank checkpoints by highest or lowest values
  • Automatically delete outdated checkpoints
  • Provide at a glance record of each checkpoint's associated value (the user-provided value obtained from that checkpoint)

Using the BestCheckpointSaver

from checkmate import BestCheckpointSaver

# ...build model...

best_ckpt_saver = BestCheckpointSaver(
save_dir=best_checkpoint_dir,
num_to_keep=3,
maximize=True
) # train and evaluate
for train_step in range(max_steps):
sess.run(train_op)
if train_step % evaluation_interval == 0:
accuracy = sess.run(eval_op, feed_dict=validation_data)
best_ckpt_saver.handle(accuracy, sess, global_step_tensor)

Loading the best checkpoint

import checkmate

# ...build model...

saver = tf.train.Saver()
saver.restore(sess, checkmate.get_best_checkpoint(best_checkpoint_dir, select_maximum_value=True))

At this stage, the module is no-frills with limited documentation. It is not intended to work in distributed settings or with complex Session/Graph management (i.e. the tf.Estimator framework). Contributions are welcome.

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model的更多相关文章

  1. 跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()

    save =  tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflo ...

  2. TensorFlow:tf.train.Saver()模型保存与恢复

    1.保存 将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.S ...

  3. tensorflow的tf.train.Saver()模型保存与恢复

    将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.Saver( ...

  4. 图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

    import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np impor ...

  5. 机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用

    1. tf.train.Saver() tf.train.Saver()是一个类,提供了变量.模型(也称图Graph)的保存和恢复模型方法. TensorFlow是通过构造Graph的方式进行深度学习 ...

  6. tf.train.Saver()

    1. 实例化对象 saver = tf.train.Saver(max_to_keep=1) max_to_keep: 表明保存的最大checkpoint文件数.当一个新文件创建的时候,旧文件就会被删 ...

  7. Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...

  8. tensorflow 下的滑动平均模型 —— tf.train.ExponentialMovingAverage

    在采用随机梯度下降算法训练神经网络时,使用 tf.train.ExponentialMovingAverage 滑动平均操作的意义在于提高模型在测试数据上的健壮性(robustness). tenso ...

  9. TensorFlow 实战(二)—— tf.train(优化算法)

    Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class ...

随机推荐

  1. redhat_6.5下载地址

    redhat官方下载(需要注册帐号+订阅/申请试用方可下载) https://access.redhat.com/downloads/ 网络资源:附RHEL 6.5安装文件MD5及SHA-256:一. ...

  2. Linux CPU实时监控工具

    注:ubuntu需要安装sysstat: sudo apt install sysstat [root@testDb ~]# mpstat 1 10    ---显示操作系统内核版本 以及硬件配置   ...

  3. Android Design Support Library——Navigation View

    前沿 Android 从5.0开始引入了Material design元素的设计,这种新的设计语言让整个安卓的用户体验焕然一新,google在Android Design Support Librar ...

  4. 删除qq互联

    1.删除source/plugin 下的文件qqconnect文件夹 2.删除数据: DELETE FROM `dn_common_plugin` WHERE `identifier` = 'qqco ...

  5. hadoop家族成员

    1.概述 使用hadoop已经有一段时间了,从最开始懵懂到迷茫,再到各种阅读与写作,再到如今各种组合应用,逐渐已经离不开hadoop了,hadoop在大数据行业的成功,加速了它本身的发展,各大社区都能 ...

  6. spring boot 与 thymeleaf (2): 常用表达式

    在asp.net mvc 中, 有一个视图解析器, 可以支持Razor语法. 使用起来, 是非常的方便, 并且, 写在前台页面的后台方法, 是可调试的. 但是在java中, 目前我还没有接触到, 像. ...

  7. Solidity的delete操作

    Solidity中有个特殊的操作符delete用于释放空间,因为区块链技术做为一种公用资源,为避免大家滥用.且鼓励主动对空间的回收,释放空间将会返还一些gas. delete关键字的作用是对某个类型值 ...

  8. Tomcat学习总结(13)—— Tomcat常用参数配置说明

    1.修改端口号 Tomcat端口配置在server.xml文件的Connector标签中,默认为8080,可根据实际情况修改. 修改端口号 2.解决URL中文参数乱码 在server.xml文件的Co ...

  9. 了解MySQL联表查询中的驱动表,优化查询,以小表驱动大表

    一.为什么要用小表驱动大表 1.驱动表的定义 当进行多表连接查询时, [驱动表] 的定义为: 1)指定了联接条件时,满足查询条件的记录行数少的表为[驱动表] 2)未指定联接条件时,行数少的表为[驱动表 ...

  10. Spring Aop 注解方式参数传递

    v\:* {behavior:url(#default#VML);} o\:* {behavior:url(#default#VML);} w\:* {behavior:url(#default#VM ...