tensorflow 之模型的保存与加载(二)
上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练。
在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量。
#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: restore.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-22 22:34:16
############################ import tensorflow as tf v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
print(v1)
result = v1 + v2
print(result) saver = tf.train.Saver([v1])#只有变量v1会被加载 with tf.Session() as sess:
saver.restore(sess, "my_test_model.ckpt")
print(sess.run(result))
执行上面的代码,会报错:tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value v2
字典一般方便实现变量重命名,因为在某些场景下,模型中变量的命名和当前需要加载的变量名称并不相同而且有时候对于那些TF自动生成的变量的名称太长不好表示,那么为了不导致加载模型时找不到变量的问题。
在上一篇博文中,保存的两个变量的名称为v1和v2。
import tensorflow as tf
#保存或加载时给变量重命名
a1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
a2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
result = a1 + a2 #使用字典来重命名变量就可以加载原模型中的相应变量.如下指定了原来名称为v1的变量现在加载到变量a1中,原来名称为v2的变量现在加载到变量a2中
saver = tf.train.Saver({"v1":a1, "v2":a2})
#因为有时候模型保存时的变量名称和加载时的变量名称不一致,为了解决这个问题,TF可以通过字典将模型保存时的变量名和需要加载的变量关联起来. with tf.Session() as sess:
saver.restore(sess, "my_test_model.ckpt")
print(sess.run(result))
在使用滑动平均模型时,tf.train.ExponentialMovingAverage对每一个变量会维护一个影子变量(shadow variable),这个影子变量是TF自动生成的,那么为了方便加载使用影子变量,就可以使用变量重命名。
#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: saver_ema.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-25 21:02:23
############################
import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v")
v2 = tf.Variable(0, dtype=tf.float32, name="v2")
for variables in tf.global_variables():
print(variables.name)
#v:0
#v2:0 #在声明滑动平均模型后,TF会自动生成一个影子变量
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
print(variables.name)
#v:0
#v2:0
#v/ExponentialMovingAverage:0
#v2/ExponentialMovingAverage:0 print(ema.variables_to_restore())
#{'v2/ExponentialMovingAverage': <tf.Variable 'v2:0' shape=() dtype=float32_ref>, 'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} saver = tf.train.Saver() with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op) sess.run(tf.assign(v, 10))
sess.run(tf.assign(v2, 10))
sess.run(maintain_averages_op) saver.save(sess, "moving_average.ckpt")
print(sess.run([v, ema.average(v)]))
#[10.0, 0.099999905]
滑动平均模型主要作用是为了增加模型的泛化性,可针对所有参数进行优化。
在TF中,每一个变量的滑动平均值是通过影子变量维护的,所以要获得变量的滑动平均值实际上就是获取这个变量的影子变量的值。如果在加载模型时直接将影子变量映射到变量自身,那么在使用训练好的模型时,就不需要再进行相应变量的滑动平均值的计算。
#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: restore_ema.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-25 21:51:31
############################ import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v") saver = tf.train.Saver({"v/ExponentialMovingAverage":v})#通过变量重命名将原来的变量v的滑动平均值直接赋给变量v with tf.Session() as sess:
saver.restore(sess, "moving_average.ckpt")
print(sess.run(v))
源码路径:
https://github.com/suonikeyinsuxiao/tf_notes/blob/master/save_restore/saver_ema.py
https://github.com/suonikeyinsuxiao/tf_notes/blob/master/save_restore/restore_ema.py
tensorflow 之模型的保存与加载(二)的更多相关文章
- tensorflow 之模型的保存与加载(三)
前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...
- tensorflow 之模型的保存与加载(一)
怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
- tensorflow模型的保存与加载
模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...
- (sklearn)机器学习模型的保存与加载
需求: 一直写的代码都是从加载数据,模型训练,模型预测,模型评估走出来的,但是实际业务线上咱们肯定不能每次都来训练模型,而是应该将训练好的模型保存下来 ,如果有新数据直接套用模型就行了吧?现在问题就是 ...
- pytorch_模型参数-保存,加载,打印
1.保存模型参数(gen-我自己的模型名字) torch.save(self.gen.state_dict(), os.path.join(self.gen_save_path, 'gen_%d.pt ...
- pytorch 中模型的保存与加载,增量训练
让模型接着上次保存好的模型训练,模型加载 #实例化模型.优化器.损失函数 model = MnistModel().to(config.device) optimizer = optim.Adam( ...
- fashion_mnist多分类训练,两种模型的保存与加载
from tensorflow.python.keras.preprocessing.image import load_img,img_to_array from tensorflow.python ...
- tensorflow1.0 模型的保存与加载
import tensorflow as tf import numpy as np # ##Save to file # W = tf.Variable([[4,5,6],[7,8,9]],dtyp ...
随机推荐
- Android内存优化4 了解java GC 垃圾回收机制2 GC执行finalize的过程
1. finalize的作用 finalize()是Object的protected方法,子类可以覆盖该方法以实现资源清理工作,GC在回收对象之前调用该方法. finalize()与C++中的析构函数 ...
- CSS margin-top 属性
1.margin-top 属性设置元素的上外边距. 注意:允许使用负值. 2.html 文件 <html> <head> <style type="text/c ...
- mongo 误操作恢复数据
场景:我往同一个集合里面插入 三条数据 aa:aa bb:bb cc:cc .后来我后悔了,不想插入 bb:bb,通过oplog重放过滤好 bb:bb这条数据. 原理: 1.通过 oplog.r ...
- python部署工具fabric
两台机器:10.1.6.186.10.1.6.159.fabric部署在10.1.6.186上面 1 执行和1相同的任务,不过排除掉10.1.6.159这台机器 1 #!/usr/bin/pytho ...
- git 统计代码量 shell脚本
#!/bin/bash # 统计代码量 # 使用方法: sh gitstat.sh "2017-11-01" "2017-11-30" "JamKon ...
- JSP和Servlet中的几个编码的作用及原理
首先,说说JSP和Servlet中的几个编码的作用. 在JSP和Servlet中主要有以下几个地方可以设置编码,pageEncoding="UTF-8".contentType=& ...
- 用BETTERCAP和RASPBERRY PI ZERO W制作迷你WiFi干扰器
我并不是一个特别勤快的人,几天前我终于开始将我几周以来的一些想法付诸于实践,即使用Raspberry Pi Zero W制作一个可随身携带的迷你WiFi干扰器.有了它,我就可以随时随地的收集附近无线接 ...
- 流畅的python第十六章协程学习记录
从句法上看,协程与生成器类似,都是定义体中包含 yield 关键字的函数.可是,在协程中,yield 通常出现在表达式的右边(例如,datum = yield),可以产出值,也可以不产出——如果 yi ...
- screen space shadowmap unity
unity用到了screen space shadow map 1.camera 在light pos 生成depth1 2.screen space depth2 3.根据depth1 depth2 ...
- Vuex内容解析和vue cli项目中使用状态管理模式Vuex
中文文档:vuex官方中文网站 一.vuex里面都有些什么内容? const store = new Vuex.Store({ state: { name: 'weish', age: }, gett ...