上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练。

在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量。

 #!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: restore.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-22 22:34:16
############################ import tensorflow as tf v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
print(v1)
result = v1 + v2
print(result) saver = tf.train.Saver([v1])#只有变量v1会被加载 with tf.Session() as sess:
saver.restore(sess, "my_test_model.ckpt")
print(sess.run(result))

执行上面的代码,会报错:tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value v2

字典一般方便实现变量重命名,因为在某些场景下,模型中变量的命名和当前需要加载的变量名称并不相同而且有时候对于那些TF自动生成的变量的名称太长不好表示,那么为了不导致加载模型时找不到变量的问题。

在上一篇博文中,保存的两个变量的名称为v1和v2。

 import tensorflow as tf
#保存或加载时给变量重命名
a1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
a2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
result = a1 + a2 #使用字典来重命名变量就可以加载原模型中的相应变量.如下指定了原来名称为v1的变量现在加载到变量a1中,原来名称为v2的变量现在加载到变量a2中
saver = tf.train.Saver({"v1":a1, "v2":a2})
#因为有时候模型保存时的变量名称和加载时的变量名称不一致,为了解决这个问题,TF可以通过字典将模型保存时的变量名和需要加载的变量关联起来. with tf.Session() as sess:
saver.restore(sess, "my_test_model.ckpt")
print(sess.run(result))

在使用滑动平均模型时,tf.train.ExponentialMovingAverage对每一个变量会维护一个影子变量(shadow variable),这个影子变量是TF自动生成的,那么为了方便加载使用影子变量,就可以使用变量重命名。

 #!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: saver_ema.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-25 21:02:23
############################
import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v")
v2 = tf.Variable(0, dtype=tf.float32, name="v2")
for variables in tf.global_variables():
print(variables.name)
#v:0
#v2:0 #在声明滑动平均模型后,TF会自动生成一个影子变量
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
print(variables.name)
#v:0
#v2:0
#v/ExponentialMovingAverage:0
#v2/ExponentialMovingAverage:0 print(ema.variables_to_restore())
#{'v2/ExponentialMovingAverage': <tf.Variable 'v2:0' shape=() dtype=float32_ref>, 'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} saver = tf.train.Saver() with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op) sess.run(tf.assign(v, 10))
sess.run(tf.assign(v2, 10))
sess.run(maintain_averages_op) saver.save(sess, "moving_average.ckpt")
print(sess.run([v, ema.average(v)]))
#[10.0, 0.099999905]

滑动平均模型主要作用是为了增加模型的泛化性,可针对所有参数进行优化。

在TF中,每一个变量的滑动平均值是通过影子变量维护的,所以要获得变量的滑动平均值实际上就是获取这个变量的影子变量的值。如果在加载模型时直接将影子变量映射到变量自身,那么在使用训练好的模型时,就不需要再进行相应变量的滑动平均值的计算。

 #!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: restore_ema.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-25 21:51:31
############################ import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v") saver = tf.train.Saver({"v/ExponentialMovingAverage":v})#通过变量重命名将原来的变量v的滑动平均值直接赋给变量v with tf.Session() as sess:
saver.restore(sess, "moving_average.ckpt")
print(sess.run(v))

源码路径:

https://github.com/suonikeyinsuxiao/tf_notes/blob/master/save_restore/saver_ema.py

https://github.com/suonikeyinsuxiao/tf_notes/blob/master/save_restore/restore_ema.py

tensorflow 之模型的保存与加载(二)的更多相关文章

  1. tensorflow 之模型的保存与加载(三)

    前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...

  2. tensorflow 之模型的保存与加载(一)

    怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...

  3. Python之TensorFlow的模型训练保存与加载-3

    一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...

  4. tensorflow模型的保存与加载

    模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...

  5. (sklearn)机器学习模型的保存与加载

    需求: 一直写的代码都是从加载数据,模型训练,模型预测,模型评估走出来的,但是实际业务线上咱们肯定不能每次都来训练模型,而是应该将训练好的模型保存下来 ,如果有新数据直接套用模型就行了吧?现在问题就是 ...

  6. pytorch_模型参数-保存,加载,打印

    1.保存模型参数(gen-我自己的模型名字) torch.save(self.gen.state_dict(), os.path.join(self.gen_save_path, 'gen_%d.pt ...

  7. pytorch 中模型的保存与加载,增量训练

     让模型接着上次保存好的模型训练,模型加载 #实例化模型.优化器.损失函数 model = MnistModel().to(config.device) optimizer = optim.Adam( ...

  8. fashion_mnist多分类训练,两种模型的保存与加载

    from tensorflow.python.keras.preprocessing.image import load_img,img_to_array from tensorflow.python ...

  9. tensorflow1.0 模型的保存与加载

    import tensorflow as tf import numpy as np # ##Save to file # W = tf.Variable([[4,5,6],[7,8,9]],dtyp ...

随机推荐

  1. [Linux] ubuntu 安装 Wireshark

    Wireshark是一款非常流行的协议分析软件.自然可以网络抓包的需求. sudo apt-get install wireshark 出于安全方面的考虑,普通用户不能够打开网卡设备进行抓包,wire ...

  2. javascript 中contentWindow和 frames和iframe之间通信

    iframe父子兄弟之间通过jquery传值(contentWindow && parent),iframe的调用包括以下几个方面:(调用包含html dom,js全局变量,js方法) ...

  3. linux自定义开机启动服务和chkconfig使用方法

    linux自定义开机启动服务和chkconfig使用方法 1. 服务概述在linux操作系统下,经常需要创建一些服务,这些服务被做成shell脚本,这些服务需要在系统启动的时候自动启动,关闭的时候自动 ...

  4. [转载]Delphi 版 everything、光速搜索代码

    近日没啥事情,研究了一下 everything.光速搜索原理.花了一个礼拜时间,终于搞定. 废话不多说,直接上代码: unit uMFTSearchFile; { dbyoung@sina.com 2 ...

  5. SQL盲注测试高级技巧

    写在前面: 这篇文章主要写了一些加快盲注速度的技巧和盲注中比较精巧的语句,虽然注入并不是什么新技术了.但是数据库注入漏洞依然困扰着每一个安全厂商,也鞭策着每一个安全从业者不断前进. 正文: 首先来简单 ...

  6. 流畅的python第十七章使用期物处理并发

    从 Python 3.4 起,标准库中有两个名为 Future 的类:concurrent.futures.Future 和asyncio.Future.这两个类的作用相同:两个 Future 类的实 ...

  7. sqlserver中事务总结:begin tran,rollback tran,commit tran

     第1个相关用法:摘自:https://shiyousan.com/post/f13d29b7-0d87-4168-bd8b-8b28b0991b5a 以下是出现错误的SQL部分语句: 此错误的原因是 ...

  8. Silverlight 安装失败 提示 消息 ID 1603 的解决方法

    消息 ID: 1603 安装过程中出现错误.请执行以下步骤 原因是在以前安装过silverlight,没有安装成功或者没有彻底卸载干净,遗留了一些文件,尤其是安装时突然中断的时候会出现这个问题. 解决 ...

  9. 新闻焦点切换flash应用

    pixviewer.zip <!-- pixviewer.swf使用--> <script language="javascript" type="te ...

  10. PHP和数据访问之(插入。删除。和更新数据)

    插入: <?php $conn=@new mysqli('localhost','root','123','mytestdb'); $q_str=<<<EOM insert i ...