tensorflow 之模型的保存与加载(二)
上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练。
在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量。
#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: restore.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-22 22:34:16
############################ import tensorflow as tf v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
print(v1)
result = v1 + v2
print(result) saver = tf.train.Saver([v1])#只有变量v1会被加载 with tf.Session() as sess:
saver.restore(sess, "my_test_model.ckpt")
print(sess.run(result))
执行上面的代码,会报错:tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value v2
字典一般方便实现变量重命名,因为在某些场景下,模型中变量的命名和当前需要加载的变量名称并不相同而且有时候对于那些TF自动生成的变量的名称太长不好表示,那么为了不导致加载模型时找不到变量的问题。
在上一篇博文中,保存的两个变量的名称为v1和v2。
import tensorflow as tf
#保存或加载时给变量重命名
a1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
a2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
result = a1 + a2 #使用字典来重命名变量就可以加载原模型中的相应变量.如下指定了原来名称为v1的变量现在加载到变量a1中,原来名称为v2的变量现在加载到变量a2中
saver = tf.train.Saver({"v1":a1, "v2":a2})
#因为有时候模型保存时的变量名称和加载时的变量名称不一致,为了解决这个问题,TF可以通过字典将模型保存时的变量名和需要加载的变量关联起来. with tf.Session() as sess:
saver.restore(sess, "my_test_model.ckpt")
print(sess.run(result))
在使用滑动平均模型时,tf.train.ExponentialMovingAverage对每一个变量会维护一个影子变量(shadow variable),这个影子变量是TF自动生成的,那么为了方便加载使用影子变量,就可以使用变量重命名。
#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: saver_ema.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-25 21:02:23
############################
import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v")
v2 = tf.Variable(0, dtype=tf.float32, name="v2")
for variables in tf.global_variables():
print(variables.name)
#v:0
#v2:0 #在声明滑动平均模型后,TF会自动生成一个影子变量
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
print(variables.name)
#v:0
#v2:0
#v/ExponentialMovingAverage:0
#v2/ExponentialMovingAverage:0 print(ema.variables_to_restore())
#{'v2/ExponentialMovingAverage': <tf.Variable 'v2:0' shape=() dtype=float32_ref>, 'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} saver = tf.train.Saver() with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op) sess.run(tf.assign(v, 10))
sess.run(tf.assign(v2, 10))
sess.run(maintain_averages_op) saver.save(sess, "moving_average.ckpt")
print(sess.run([v, ema.average(v)]))
#[10.0, 0.099999905]
滑动平均模型主要作用是为了增加模型的泛化性,可针对所有参数进行优化。
在TF中,每一个变量的滑动平均值是通过影子变量维护的,所以要获得变量的滑动平均值实际上就是获取这个变量的影子变量的值。如果在加载模型时直接将影子变量映射到变量自身,那么在使用训练好的模型时,就不需要再进行相应变量的滑动平均值的计算。
#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: restore_ema.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-25 21:51:31
############################ import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v") saver = tf.train.Saver({"v/ExponentialMovingAverage":v})#通过变量重命名将原来的变量v的滑动平均值直接赋给变量v with tf.Session() as sess:
saver.restore(sess, "moving_average.ckpt")
print(sess.run(v))
源码路径:
https://github.com/suonikeyinsuxiao/tf_notes/blob/master/save_restore/saver_ema.py
https://github.com/suonikeyinsuxiao/tf_notes/blob/master/save_restore/restore_ema.py
tensorflow 之模型的保存与加载(二)的更多相关文章
- tensorflow 之模型的保存与加载(三)
前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...
- tensorflow 之模型的保存与加载(一)
怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
- tensorflow模型的保存与加载
模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...
- (sklearn)机器学习模型的保存与加载
需求: 一直写的代码都是从加载数据,模型训练,模型预测,模型评估走出来的,但是实际业务线上咱们肯定不能每次都来训练模型,而是应该将训练好的模型保存下来 ,如果有新数据直接套用模型就行了吧?现在问题就是 ...
- pytorch_模型参数-保存,加载,打印
1.保存模型参数(gen-我自己的模型名字) torch.save(self.gen.state_dict(), os.path.join(self.gen_save_path, 'gen_%d.pt ...
- pytorch 中模型的保存与加载,增量训练
让模型接着上次保存好的模型训练,模型加载 #实例化模型.优化器.损失函数 model = MnistModel().to(config.device) optimizer = optim.Adam( ...
- fashion_mnist多分类训练,两种模型的保存与加载
from tensorflow.python.keras.preprocessing.image import load_img,img_to_array from tensorflow.python ...
- tensorflow1.0 模型的保存与加载
import tensorflow as tf import numpy as np # ##Save to file # W = tf.Variable([[4,5,6],[7,8,9]],dtyp ...
随机推荐
- RRGGBBAA或者RRGGBB转换成rgba()
//十六进制颜色值的正则表达式 var reg = /^#([0-9a-fA-f]{3}|[0-9a-fA-f]{6})$/; /*16进制颜色转为RGB格式*/ var colorRgb = fun ...
- Eclipse 生成WebService客户端代码
1. 打开Eclipse,新建一个普通的Javaproject,然后在新建的项目上右键点击项目,New---->other---->Web Services -------->Web ...
- Kubernetes下的Redis主从配置架构
文章看了一大堆,但都是直接从各种地方直接拉master,slave镜像,没有交代这些镜像如何构建出来的 好把,我这篇就讲讲这些master,slave镜像如何做成. 先得找到一个标准的redis镜像, ...
- Android源代码编译apk导入第三方包报错
报错内容例如以下: make: *** 没有规则能够创建"out/target/common/obj/APPS/ AndroidWFS_intermediates/classes-full- ...
- Ftp上传文件
package net.util.common; import java.io.File; import java.io.FileInputStream; import java.io.FileOut ...
- 2018.1.9 博客迁移至csdn
http://blog.csdn.net/liyuhui195134?ref=toolbar
- java中Xml、json之间的相互转换
旁白: 最近关于xml与json之间的转换都搞蒙了,这里写一个demo,以后备用. 正题: project格式是: jar包是一个一个检出来的,还算干净了. 代码: 工具类: package exer ...
- 近期写的一个控件——Well Swipe beta 1.0
原文地址:http://blog.csdn.net/u013045971/article/details/51119507 近期花了大概一个半月的业余时间写的.从没有到有,中间也碰到了非常多的坑,一点 ...
- C#秘密武器之表达式树
一.表达式树入门 Lambda表达式树很复杂,从概念上很难理解清楚,一句话,表达式树是一种数据结构!这里我们通过下面的这个例子来理解一下表达式树,你就能看个大概: lambda表达式树动态创建方法 s ...
- java中的super限定
super的用法: (1)如果需要在子类中调用父类中被覆盖的实例方法,可以用super限定来调用父类中被覆盖的方法.当然,也可以调用从父类继承的实例变量. public void callOverri ...