tensorflow的断点续训

2019-09-07

顾名思义,断点续训的意思是因为某些原因模型还没有训练完成就被中断,下一次训练可以在上一次训练的基础上继续训练而不用从头开始;这种方式对于你那些训练时间很长的模型来说非常友好。

如果要进行断点续训,那么得满足两个条件:

(1)本地保存了模型训练中的快照;(即断点数据保存)

(2)可以通过读取快照恢复模型训练的现场环境。(断点数据恢复)

这两个操作都用到了tensorflow中的train.Saver类。

1.tensorflow.trainn.Saver类

__init__(
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=tf.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None
)
这里不对所有参数进行介绍,只介绍常用的参数
max_to_keep:允许保存的模型的个数,默认为5;当保存的个数超过5时,自动删除最旧的模型,以保证最多同时存在5个模型;如果设置为0或者None,则会对所有训练中的模型进行保存,但是这样除了多占硬盘外没什么意义。
其他的参数一般就使用默认值就可以了。
saver = tf.train.Saver(max_to_keep=10)

有机会再补充其他参数的用法。

2.断点数据的保存

使用saver对象的save方法即可保存模型:

save(
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix='meta',
write_meta_graph=True,
write_state=True,
strip_default_attrs=False,
save_debug_info=False
)

常用参数:

sess:需要保存的会话,一般就是我们程序中的sess;

save_path:保存模型的文件路径以及名称,例如“ckpt/my_model”,注意如果要保存在ckpt文件夹下,那么需要在ckpt后面加个斜杠/;

global_step:训练次数,saver会自动将这个值加入到保存的文件名字中。

saver.save(sess,"my_model",global_step=1)
saver.save(sess,"my_model",global_step=100)
saver.save(sess,"ckpt/my_model",global_step=1)

其中1,2,3行代码分别会:

1:在代码的路径下生成名为“my_model_1文件”;

2:在代码的路径下生成名为“my_model_100文件”;

3:在ckpt文件夹下生成名为“my_model_1文件”。

最常见的用法:

for epoch in range(n_iter):
'''
training process
'''
saver.save(sess,ckpt_dir+"model_name",global_step=epoch)

其中ckpt_dir是断点数据存放的路径。

3.断点数据的恢复

3.1 只加载参数,不加载图

需要先建立一个与之前相同的模型;然后再检查有没有断点数据,如果有,则进行恢复。

'''
模型图创建
'''
ckpt_dir = "ckpt/"
#创建Saver对象
saver = tf.train.Saver()
#如果有断点文件,读取最近的断点文件
ckpt = tf.train.latest_checkpoint(ckpt_dir) if ckpt != None:
saver.restore(sess,ckpt)

不需要提供模型的名字,tf.train.latest_checkpoint(ckpt_dir)会去ckpt_dir文件夹中自动寻找最新的模型文件。

这个方法要求模型图建立好之后才允许创建saver,然后进行变量恢复,否则会报错。

当我们基于checkpoint文件(ckpt)加载参数时,实际上我们使用Saver.restore取代了initializer的初始化。

3.2 图结构与参数都加载

不需要自己建立模型图了,全部靠加载:

import tensorflow as tf
#获取最新断点数据路径
ckpt = tf.train.latest_checkpoint("./ckpt/")
#加载图结构
saver = tf.train.import_meta_graph(ckpt+".meta") sess = tf.Session()
#加载参数
saver.restore(sess,ckpt)
#运行sess
sess.run(tf.get_default_graph().get_tensor_by_name("x:0"))

可以通过 tf.get_default_graph().get_tensor_by_name("x:0")获取模型节点,其中“x:0”是创建节点的时候节点的name。

4.模型文件解析

在程序训练过程中保存的模型文件如下图所示:

checkpoint文件会记录保存信息,通过它可以定位最新保存的模型;

.meta文件保存了当前图结构

.data文件保存了当前参数名和值

.index文件保存了辅助索引信息

至于文件名后面的数字表示的是模型训练的不同批次,我们一般只需要最新的那个;由于之前设置最多保存5个模型,所以批次号是从6开始的。

4.1 查看checkpoint

ckpt = tf.train.get_checkpoint_state("./ckpt/")
print(ckpt)

结果是文件的断点状态信息:

断点状态信息下有一个“model_checkpoint_path”属性,属性内容是最新的那个模型的路径,用str类型来表示;

ckpt.model_checkpoint_path

这个与tf.train.latest_checkpoint("./ckpt/")得出的结果是相同的,可以通过这个路径来加载模型参数。

4.2 通过data文件查看变量名和变量值

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file("./ckpt/model.ckpt-10",None,True)
print_tensors_in_checkpoint_file中输入的第一个参数即上一节中获取到的模型路径;结果会以字典的形式展现出来。

4.3 通过meta文件加载图结构

saver = tf.train.import_meta_graph('./ckpt/model.ckpt-10.meta')

注意这里的参数是完整的路径加上meta文件的文件名,后面需要加上“.meta”。

返回的是一个saver对象,这个对象中包含了之前模型的图结构。

tensorflow的断点续训的更多相关文章

  1. Keras模型训练的断点续训、早停、效果可视化

    训练:model.fit()函数 fit(x=None, y=None, batch_size=None, epochs=, verbose=, callbacks=None, validation_ ...

  2. curl断点续载

    摘自http://blog.csdn.net/zmy12007/article/details/37157297 摘自http://www.linuxidc.com/Linux/2014-10/107 ...

  3. 关于视频断点续播和H5的本地存储

    前段时间,需要在下实现一个视频的断点续播功能,呃,我不会呀,这就很尴尬了.然后呢,在下就想起了一个叫做localStorage的东西.这是个什么东西呢?在网上查阅了一些资料后,在下发现这是webSto ...

  4. scrapy爬虫之断点续爬和多个spider同时爬取

    from scrapy.commands import ScrapyCommand from scrapy.utils.project import get_project_settings #断点续 ...

  5. 迁移学习——使用Tensorflow和VGG16预训模型进行预测

    使用Tensorflow和VGG16预训模型进行预测 from:https://zhuanlan.zhihu.com/p/28997549   fast.ai的入门教程中使用了kaggle: dogs ...

  6. python3.6 单文件爬虫 断点续存 普通版 文件续存方式

    # 导入必备的包 # 本文爬取的是顶点小说中的完美世界为列.文中的aa.text,bb.text为自己创建的text文件 import requests from bs4 import Beautif ...

  7. Electron 的断点续下载

    最近用 Electron 做了个壁纸程序,需要断点续下载,在这里记录一下. HTTP断点下载相关的报文 Accept-Ranges 告诉客户端服务器是否支持断点续传,服务器返回 Content-Ran ...

  8. HTML 5 断点续上传

    断点上传,java里面比较靠谱一点的,一般都会选用Flex.我承认,Flex只是摸了一下,不精通.HTML 5 有个Blob对象(File对象继承它),这个对象有个方法slice方法,可以对一个文件进 ...

  9. scrapy 断点续爬

    第一步:安装berkeleydb数据库 第二部:pip install bsddb3 第三部:pip install scrapy-deltafetch 第四部: settings.py设置 SPID ...

  10. python下载mp4 同步和异步下载支持断点续下

    Range 用于请求头中,指定第一个字节的位置和最后一个字节的位置,一般格式: Range:(unit=first byte pos)-[last byte pos] Range 头部的格式有以下几种 ...

随机推荐

  1. 原创分享 HubbleDotNet 最新绿色版,服务端免安装,基于eaglet 最后V1.2.8.9版本开发,bug修正,支持一键生成同步表

    HubbleDotNet 是一个基于.net framework 的开源免费的全文搜索数据库组件.开源协议是 Apache 2.0.HubbleDotNet提供了基于SQL的全文检索接口,使用者只需会 ...

  2. 网络很慢mtu设置

    [root@db-***** etc]# cat /etc/rc.local #!/bin/sh # # This script will be executed *after* all the ot ...

  3. 微信公众号 H5授权登录

    首先微信公众号 必须是服务号,订阅号没有 "网页授权获取用户基本信息" 没有这个权限.服务号也必须认证后才有这个权限

  4. struts2 显示表格

    <%@ taglib uri="/struts-tags" prefix="s"%> <h3>All Records:</h3&g ...

  5. 项目启动报错:关于modals以及node版本相关

    programme1: 1.代码用master分支的. 2. 删除node_module ,  yarn lock 文件,package-lock文件. 3. 用 npm install 或者 yar ...

  6. ZIP文件操作工具类

    2 3 import lombok.extern.slf4j.Slf4j; 4 import org.apache.commons.io.FilenameUtils; 5 6 import java. ...

  7. 2022-04-11内部群每日三题-清辉PMP

    1.项目经理从制造商那里收到一个更新信息,说一个必要的设备修理可能会导致他们的可交付成果迟八周时间.项目经理应该怎么做? A.确定关键路径 B.实施沟通管理计划 C.执行假设情景分析 D.对项目进度赶 ...

  8. 每日一抄 Go语言封装qsort快速排序函数

    package qsort /* <GO语言高级编程>设计中案例,仅作为笔记进行收藏. qsort快速排序函数是C语⾔的⾼阶函数,⽀持⽤于⾃定义排序⽐较函数,可以对任意类型的数组进⾏排序. ...

  9. Swift中 堆(heap)和栈(stack)的区别

    1.内存空间分为堆空间和栈空间 2.堆->引用类型(对象.函数.闭包)  栈->值类型(结构体.枚举.元组) 3.值类型赋值->深拷贝  引用类型赋值->浅拷贝 let a = ...

  10. 开发谷歌插件--web3钱包(一)

    之前开发了一款谷歌插件,因为很简单没有什么好记录的. 这次记录下一款新的钱包功能的插件,其中遇到的问题,以及解决方案. 首先遇到的问题就是唤醒: 小狐狸钱包应该都用过,点击图标就会唤起登录页面(pop ...