tensorflow 之模型的保存与加载(二)
上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练。
在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量。
#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: restore.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-22 22:34:16
############################ import tensorflow as tf v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
print(v1)
result = v1 + v2
print(result) saver = tf.train.Saver([v1])#只有变量v1会被加载 with tf.Session() as sess:
saver.restore(sess, "my_test_model.ckpt")
print(sess.run(result))
执行上面的代码,会报错:tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value v2
字典一般方便实现变量重命名,因为在某些场景下,模型中变量的命名和当前需要加载的变量名称并不相同而且有时候对于那些TF自动生成的变量的名称太长不好表示,那么为了不导致加载模型时找不到变量的问题。
在上一篇博文中,保存的两个变量的名称为v1和v2。
import tensorflow as tf
#保存或加载时给变量重命名
a1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
a2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
result = a1 + a2 #使用字典来重命名变量就可以加载原模型中的相应变量.如下指定了原来名称为v1的变量现在加载到变量a1中,原来名称为v2的变量现在加载到变量a2中
saver = tf.train.Saver({"v1":a1, "v2":a2})
#因为有时候模型保存时的变量名称和加载时的变量名称不一致,为了解决这个问题,TF可以通过字典将模型保存时的变量名和需要加载的变量关联起来. with tf.Session() as sess:
saver.restore(sess, "my_test_model.ckpt")
print(sess.run(result))
在使用滑动平均模型时,tf.train.ExponentialMovingAverage对每一个变量会维护一个影子变量(shadow variable),这个影子变量是TF自动生成的,那么为了方便加载使用影子变量,就可以使用变量重命名。
#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: saver_ema.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-25 21:02:23
############################
import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v")
v2 = tf.Variable(0, dtype=tf.float32, name="v2")
for variables in tf.global_variables():
print(variables.name)
#v:0
#v2:0 #在声明滑动平均模型后,TF会自动生成一个影子变量
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
print(variables.name)
#v:0
#v2:0
#v/ExponentialMovingAverage:0
#v2/ExponentialMovingAverage:0 print(ema.variables_to_restore())
#{'v2/ExponentialMovingAverage': <tf.Variable 'v2:0' shape=() dtype=float32_ref>, 'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} saver = tf.train.Saver() with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op) sess.run(tf.assign(v, 10))
sess.run(tf.assign(v2, 10))
sess.run(maintain_averages_op) saver.save(sess, "moving_average.ckpt")
print(sess.run([v, ema.average(v)]))
#[10.0, 0.099999905]
滑动平均模型主要作用是为了增加模型的泛化性,可针对所有参数进行优化。
在TF中,每一个变量的滑动平均值是通过影子变量维护的,所以要获得变量的滑动平均值实际上就是获取这个变量的影子变量的值。如果在加载模型时直接将影子变量映射到变量自身,那么在使用训练好的模型时,就不需要再进行相应变量的滑动平均值的计算。
#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: restore_ema.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-25 21:51:31
############################ import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v") saver = tf.train.Saver({"v/ExponentialMovingAverage":v})#通过变量重命名将原来的变量v的滑动平均值直接赋给变量v with tf.Session() as sess:
saver.restore(sess, "moving_average.ckpt")
print(sess.run(v))
源码路径:
https://github.com/suonikeyinsuxiao/tf_notes/blob/master/save_restore/saver_ema.py
https://github.com/suonikeyinsuxiao/tf_notes/blob/master/save_restore/restore_ema.py
tensorflow 之模型的保存与加载(二)的更多相关文章
- tensorflow 之模型的保存与加载(三)
前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...
- tensorflow 之模型的保存与加载(一)
怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
- tensorflow模型的保存与加载
模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...
- (sklearn)机器学习模型的保存与加载
需求: 一直写的代码都是从加载数据,模型训练,模型预测,模型评估走出来的,但是实际业务线上咱们肯定不能每次都来训练模型,而是应该将训练好的模型保存下来 ,如果有新数据直接套用模型就行了吧?现在问题就是 ...
- pytorch_模型参数-保存,加载,打印
1.保存模型参数(gen-我自己的模型名字) torch.save(self.gen.state_dict(), os.path.join(self.gen_save_path, 'gen_%d.pt ...
- pytorch 中模型的保存与加载,增量训练
让模型接着上次保存好的模型训练,模型加载 #实例化模型.优化器.损失函数 model = MnistModel().to(config.device) optimizer = optim.Adam( ...
- fashion_mnist多分类训练,两种模型的保存与加载
from tensorflow.python.keras.preprocessing.image import load_img,img_to_array from tensorflow.python ...
- tensorflow1.0 模型的保存与加载
import tensorflow as tf import numpy as np # ##Save to file # W = tf.Variable([[4,5,6],[7,8,9]],dtyp ...
随机推荐
- 关于单例模式的N种实现方式
在开发中经常用到单例模式,单例模式也算是设计模式中最容易理解,也是最容易手写代码的模式,所以也常作为面试题来考.所以想总结一下单例模式的理论知识,方便同学们面试使用. 单例模式实现的方式只有两种类型, ...
- Windows命令行的使用
在介绍Windows批处命令前,我们首先来介绍Windows命令行的使用. Windows shell提供了一个黑色的框框界面,即命令行操作界面,关于命令行的作用和好处,我就不费口舌了,下面仅窥见一斑 ...
- 【LaTeX】E喵的LaTeX新手入门教程(6)中文
假期玩得有点凶 ._.前情回顾[LaTeX]E喵的LaTeX新手入门教程(1)准备篇 [LaTeX]E喵的LaTeX新手入门教程(2)基础排版 [LaTeX]E喵的LaTeX新手入门教程(3)数学公式 ...
- WCF和Socket
WCF的全称是:Windows Communication Foundation.它是建立在Web Service架构上的一个全新的通信平台.它使用相同的基础结构和 API 来创建应用程序,这些应用程 ...
- FXC Define的使用方法
https://docs.microsoft.com/en-us/windows/desktop/direct3dtools/dx-graphics-tools-fxc-syntax https:// ...
- Newtonsoft.Json.4.5.11使用方法总结---反序列化json字符串
写在开头: 最近项目需求,需要在C#中处理json字符串,毫不犹豫的下载了Newtonsoft.Json 4.5.11(2012.12.17)http://json.codeplex.com/,然后百 ...
- 使用FreeMarker的Web Project例子
1 创建一个名为FreemarkerDemo的Web Project 2 删除index.jsp,新建index.html,index.html中的内容为: <html> <head ...
- PHP变量引用赋值与变量赋值变量的区别
变量默认总是传值赋值.那也就是说,当将一个表达式的值赋予一个变量时,整个原始表达式的值被赋值到目标变量.这意味着,例如,当一个变量的值赋予另外一个变量时,改变其中一个变量的值,将不会影响到另外一个变量 ...
- unicode 编码在线转换工具
字符串 unideo的16进制值
- Linux命令计算文件中某一列的平均值
例如每秒执行一次top命令,把结果输出到某个文件中保存,现在需要统计这段时间内某个进程的平均CPU占用率,可使用以下命令 | grep "GameServer_r" | awk ' ...