tensorflow 之模型的保存与加载(二)
上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练。
在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量。
#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: restore.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-22 22:34:16
############################ import tensorflow as tf v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
print(v1)
result = v1 + v2
print(result) saver = tf.train.Saver([v1])#只有变量v1会被加载 with tf.Session() as sess:
saver.restore(sess, "my_test_model.ckpt")
print(sess.run(result))
执行上面的代码,会报错:tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value v2
字典一般方便实现变量重命名,因为在某些场景下,模型中变量的命名和当前需要加载的变量名称并不相同而且有时候对于那些TF自动生成的变量的名称太长不好表示,那么为了不导致加载模型时找不到变量的问题。
在上一篇博文中,保存的两个变量的名称为v1和v2。
import tensorflow as tf
#保存或加载时给变量重命名
a1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
a2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
result = a1 + a2 #使用字典来重命名变量就可以加载原模型中的相应变量.如下指定了原来名称为v1的变量现在加载到变量a1中,原来名称为v2的变量现在加载到变量a2中
saver = tf.train.Saver({"v1":a1, "v2":a2})
#因为有时候模型保存时的变量名称和加载时的变量名称不一致,为了解决这个问题,TF可以通过字典将模型保存时的变量名和需要加载的变量关联起来. with tf.Session() as sess:
saver.restore(sess, "my_test_model.ckpt")
print(sess.run(result))
在使用滑动平均模型时,tf.train.ExponentialMovingAverage对每一个变量会维护一个影子变量(shadow variable),这个影子变量是TF自动生成的,那么为了方便加载使用影子变量,就可以使用变量重命名。
#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: saver_ema.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-25 21:02:23
############################
import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v")
v2 = tf.Variable(0, dtype=tf.float32, name="v2")
for variables in tf.global_variables():
print(variables.name)
#v:0
#v2:0 #在声明滑动平均模型后,TF会自动生成一个影子变量
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
print(variables.name)
#v:0
#v2:0
#v/ExponentialMovingAverage:0
#v2/ExponentialMovingAverage:0 print(ema.variables_to_restore())
#{'v2/ExponentialMovingAverage': <tf.Variable 'v2:0' shape=() dtype=float32_ref>, 'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} saver = tf.train.Saver() with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op) sess.run(tf.assign(v, 10))
sess.run(tf.assign(v2, 10))
sess.run(maintain_averages_op) saver.save(sess, "moving_average.ckpt")
print(sess.run([v, ema.average(v)]))
#[10.0, 0.099999905]
滑动平均模型主要作用是为了增加模型的泛化性,可针对所有参数进行优化。
在TF中,每一个变量的滑动平均值是通过影子变量维护的,所以要获得变量的滑动平均值实际上就是获取这个变量的影子变量的值。如果在加载模型时直接将影子变量映射到变量自身,那么在使用训练好的模型时,就不需要再进行相应变量的滑动平均值的计算。
#!/usr/bin/env python3
#-*- coding:utf-8 -*-
############################
#File Name: restore_ema.py
#Brief:
#Author: frank
#Mail: frank0903@aliyun.com
#Created Time:2018-06-25 21:51:31
############################ import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v") saver = tf.train.Saver({"v/ExponentialMovingAverage":v})#通过变量重命名将原来的变量v的滑动平均值直接赋给变量v with tf.Session() as sess:
saver.restore(sess, "moving_average.ckpt")
print(sess.run(v))
源码路径:
https://github.com/suonikeyinsuxiao/tf_notes/blob/master/save_restore/saver_ema.py
https://github.com/suonikeyinsuxiao/tf_notes/blob/master/save_restore/restore_ema.py
tensorflow 之模型的保存与加载(二)的更多相关文章
- tensorflow 之模型的保存与加载(三)
前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...
- tensorflow 之模型的保存与加载(一)
怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
- tensorflow模型的保存与加载
模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...
- (sklearn)机器学习模型的保存与加载
需求: 一直写的代码都是从加载数据,模型训练,模型预测,模型评估走出来的,但是实际业务线上咱们肯定不能每次都来训练模型,而是应该将训练好的模型保存下来 ,如果有新数据直接套用模型就行了吧?现在问题就是 ...
- pytorch_模型参数-保存,加载,打印
1.保存模型参数(gen-我自己的模型名字) torch.save(self.gen.state_dict(), os.path.join(self.gen_save_path, 'gen_%d.pt ...
- pytorch 中模型的保存与加载,增量训练
让模型接着上次保存好的模型训练,模型加载 #实例化模型.优化器.损失函数 model = MnistModel().to(config.device) optimizer = optim.Adam( ...
- fashion_mnist多分类训练,两种模型的保存与加载
from tensorflow.python.keras.preprocessing.image import load_img,img_to_array from tensorflow.python ...
- tensorflow1.0 模型的保存与加载
import tensorflow as tf import numpy as np # ##Save to file # W = tf.Variable([[4,5,6],[7,8,9]],dtyp ...
随机推荐
- shell 脚本中执行mysql语句
通过hash建表之后,表的数据量巨大2048,那怎么去验证表是否建成功呢? 逻辑生成表名这部分就不写了.只要能建表成功,这部分的脚本肯定是有的.那么怎么在shell中执行selec查询并返回呢 只要在 ...
- DESede对称加密算法工具类
利用Cipher的核心功能,自己封装了一个加密解密的工具类,可以直接使用.在使用之前需要先下载commons-codec-1.9.jar,并导入项目. 工具类如下: package com.pcict ...
- [Python爬虫] 之十二:Selenium +phantomjs抓取中的url编码问题
最近在抓取活动树网站 (http://www.huodongshu.com/html/find.html) 上数据时发现,在用搜索框输入中文后,点击搜索,phantomjs抓取数据怎么也抓取不到,但是 ...
- 虚拟机、linux系统安装
下载VMWare解压后依据提示正触安装VMWare到硬盘中 (1) 建立虚拟机 A.用鼠标左建双击桌面中的"VMwareworkstation"图标.执行虚拟机 B.建立一台虚拟机 ...
- redis学习笔记——RDB和AOF持久化二
上一篇对RDB的源码分析是比较多的,但是AOF持久化执行进行了一些理论上的分析和概念的说明.本来想自己偷一些懒,将上篇文章中最后所给链接的AOF实现代码随便过一过算了,后来也就是在过的过程中发现自己这 ...
- Office 如何下载网页的视频 JWPlayer的内嵌视频
右击页面空白处,查看页面源代码 在里面搜索mp4或者swf,video,一般网页中的视频都是这些格式,仔细找一定能找到对应的地址 然后复制到迅雷下载即可
- 网页制作,网站制作中put和get的区别
Http定义了与服务器交互的不同方法,最基本的方法有4种,分别是GET,POST,PUT,DELETE.URL全称是资源描述符,我们可以这样认为:一个URL地址,它用于描述一个网络上的资源,而HTTP ...
- C#秘密武器之多线程——基础
多线程概述 什么是进程? 当一个程序开始运行时,它就是一个进程,进程包括运行中的程序和程序所使用到的内存和系统资源.而一个进程又是由多个线程所组成的. 什么是线程? 线程是程序中的一个执行流,每个线程 ...
- SpringBoot环境属性占位符解析和类型转换
前提 前面写过一篇关于Environment属性加载的源码分析和扩展,里面提到属性的占位符解析和类型转换是相对复杂的,这篇文章就是要分析和解读这两个复杂的问题.关于这两个问题,选用一个比较复杂的参数处 ...
- python——关于Python Profilers性能分析器
1. 介绍性能分析器 profiler是一个程序,用来描述运行时的程序性能,并且从不同方面提供统计数据加以表述.Python中含有3个模块提供这样的功能,分别是cProfile, profile和ps ...