day-20 tensorflow持久化之入门学习
如果不对模型参数进行保存,当训练结束以后,模型也在内存中被释放,下一轮又需要对模型进行重新训练,有没有一种方法,可以利用之前已经训练好的模型参数值,直接进行模型推理或者继续训练?这里需要引入一个数据之久化的概念,其通用定义就是将内存中的数据模型转换为存储模型,以及将存储模型转换为内存中的数据模型的统称。
OK,在tensorflow中,持久化可以是我们训练好的神经网络权重值和biase值写入到文件中,下一次直接从文件中进行读取,而不需要重新对模型进行训练。
用tensorflow写一个简单的示例:求两个变量v1和v2的和,然后将其保存result变量中,然后将其保存到文件中,下一次训练时直接读取文件。
先看保存程序:
import tensorflow as tf # 定义两个变量,并对其进行求和
v1 = tf.Variable(tf.constant(value=1.0,dtype=tf.float32,shape=[1],name="v1"))
v2 = tf.Variable(tf.constant(value=2.0,dtype=tf.float32,shape=[1],name="v2"))
result = v1 + v2 # 将求和操作加到result集合中
tf.add_to_collection('result',result) # 新建一个持久化对象
saver = tf.train.Saver() # 运行会话,并持久化模型
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
# 如下操作执行完以后,会在sample_test目录下生成四个文件:
# checkpoint:所有模型文件列表
# model.data-00000-of-00001:
# model-index:
# model.meta:计算图的结构
saver.save(sess=sess,save_path="sample_test/model")
如果要重新加载模型,新的代码可以这么写:
import tensorflow as tf # 从之前保存的点新建一个持久化对象
saver = tf.train.import_meta_graph(meta_graph_or_file="sample_test/model.meta") with tf.Session() as sess:
# 重新加载保存的参数值
saver.restore(sess=sess,save_path="sample_test/model")
# 注意get_collection返回一个列表,如果直接运行,结果也是一个List,注意比较下面的区别:
print(tf.get_collection(key="result"))
print(sess.run(tf.get_collection(key='result')))
'''
[<tf.Tensor 'add:0' shape=(1,) dtype=float32>]
[array([3.], dtype=float32)]
''' print(tf.get_collection(key="result")[0])
print(sess.run(tf.get_collection(key='result')[0]))
'''
tf.Tensor 'add:0' shape=(1,)
[3.]
'''
进一步,如果我们的网络结构加入了滑动平均模型,重新加载模型时,我们往往是希望用其进行验证,需要使用滑动平均模型参数的值,一个完整的示例如下:
训练时:
# 导入库
import tensorflow as tf # 定义一个变量
v = tf.Variable(initial_value=0,dtype=tf.float32,name='v') # 显示当前有哪些变量
# <tf.Variable 'v:0' shape=() dtype=float32_ref>
for variable in tf.global_variables():
print(variable) # 定义一个滑动平均模型,和变量应用模型的操作
ema = tf.train.ExponentialMovingAverage(0.999)
maintain_average_op = ema.apply(tf.global_variables()) # 显示当前有哪些变量
# <tf.Variable 'v:0' shape=() dtype=float32_ref>
# <tf.Variable 'v/ExponentialMovingAverage:0' shape=() dtype=float32_ref>
for variable in tf.global_variables():
print(variable) saver = tf.train.Saver() # 执行会话,并持久化
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
sess.run(tf.assign(ref=v,value=10))
sess.run(maintain_average_op)
saver.save(sess=sess,save_path="sample_test/model")
print(sess.run([v,ema.average(v)]))
重新加载时:
import tensorflow as tf v = tf.Variable(0,dtype=tf.float32,name='v') ema = tf.train.ExponentialMovingAverage(0.999)
saver = tf.train.Saver(ema.variables_to_restore()) # {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
print(ema.variables_to_restore()) with tf.Session() as sess:
saver.restore(sess,save_path="sample_test/model")
# 自动加载滑动平均值来代替变量的值
# 0.009999871
print(sess.run(v))
day-20 tensorflow持久化之入门学习的更多相关文章
- 人工智能新手入门学习路线和学习资源合集(含AI综述/python/机器学习/深度学习/tensorflow)
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![握手][握手] 1. 分享个人对于人工智能领域的算法综述:如果你想开始学习算法,不妨先了解人工 ...
- TensorFlow.NET机器学习入门【8】采用GPU进行学习
随着网络越来约复杂,训练难度越来越大,有条件的可以采用GPU进行学习.本文介绍如何在GPU环境下使用TensorFlow.NET. TensorFlow.NET使用GPU非常的简单,代码不用做任何修改 ...
- msp430入门学习20
msp430的USART模式 msp430入门学习
- TensorFlow.NET机器学习入门【4】采用神经网络处理分类问题
上一篇文章我们介绍了通过神经网络来处理一个非线性回归的问题,这次我们将采用神经网络来处理一个多元分类的问题. 这次我们解决这样一个问题:输入一个人的身高和体重的数据,程序判断出这个人的身材状况,一共三 ...
- MyBatis入门学习教程-使用MyBatis对表执行CRUD操作
上一篇MyBatis学习总结(一)--MyBatis快速入门中我们讲了如何使用Mybatis查询users表中的数据,算是对MyBatis有一个初步的入门了,今天讲解一下如何使用MyBatis对use ...
- opengl入门学习
OpenGL入门学习 说起编程作图,大概还有很多人想起TC的#include <graphics.h>吧? 但是各位是否想过,那些画面绚丽的PC游戏是如何编写出来的?就靠TC那可怜的640 ...
- MyBatis入门学习(二)
在MyBatis入门学习(一)中我们完成了对MyBatis简要的介绍以及简单的入门小项目测试,主要完成对一个用户信息的查询.这一节我们主要来简要的介绍MyBatis框架的增删改查操作,加深对该框架的了 ...
- OpenGL入门学习(转)
OpenGL入门学习 http://www.cppblog.com/doing5552/archive/2009/01/08/71532.html 说起编程作图,大概还有很多人想起TC的#includ ...
- Bootstrap3.0入门学习系列
Bootstrap3.0入门学习系列规划[持续更新] 前言 首先在此多谢博友们在前几篇博文当中给与的支持和鼓励,以及在回复中提出的问题.意见和看法. 在此先声明一下,之前在下小菜所有的随笔文章中, ...
随机推荐
- mysql事务隔离
一.事务的特性 原子性.一致性.隔离性.持久性 二.事务的隔离级别 1.未提交读 (Read Uncommitpeatableted) 臭名昭著的脏读 ,事务A读到事务B未提交的数据 2.提交读RC( ...
- SQL 一
1.所有表都必须在模式中.2.SYS模式不是默认模式3.虽然有概念用户PUBLIC,但它根本没有模式.4.索引有自己的名称空间,存储过程.同义词.表和视图都在同一名称空间里.5.堆是可变长度行的表,这 ...
- update更新修改数据
update ---整表更新数据 update 表名 set 需要调整字段1= '值1' ,需要调整字段2= '值2' …… ---更新条件数据 update 表名 set 需要调整字段 ...
- 为什么我们需要DTO?
最近在写代码时突然产生了这个疑惑,我们为什么需要DTO进行数据传输呢? 要了解DTO首先我们要知道什么是DAO,DAO就是数据库的一个数据模型,是一个类文件里面存储着数据库的字段及其getter&am ...
- CSS动画详解及transform、transition、translate的区别
刚看完一节慕课网的css动画,在此总结下 1. 先说下 transform.transition.translate的区别 transform 和 transition是css的2个属性,transl ...
- rem和em的用法
1.rem转化为向素值的方法 rem单位转化为像素大小取决于根元素的字体大小,即HTML元素的字体大小,根元素字体大小乘以rem. 例:根元素的字体大小 16px,10rem 将等同于 160px,即 ...
- SST-超级简单任务调度器结构分析
SST(Super Simple Task) 是一个基于任务优先级.抢占式.事件驱动.RTC.单堆栈的超级简单任务调度器,它基于Rober Ward一篇论文的思想,Miro Samek用C重新编程实现 ...
- ruby 爬虫爬取拉钩网职位信息,产生词云报告
思路:1.获取拉勾网搜索到职位的页数 2.调用接口获取职位id 3.根据职位id访问页面,匹配出关键字 url访问采用unirest,由于拉钩反爬虫,短时间内频繁访问会被限制访问,所以没有采用多线程, ...
- QOS-配置拥塞避免机制
QOS-配置拥塞避免机制 2018年7月7日 20:29 尾丢弃及其导致的问题: 队列满时路由器进行尾丢弃,即新到的所有数据包都全部丢弃 丢弃的结果造成高延迟.高抖动.丧失服务保证.TCP全局同步.T ...
- MongoDB入门---数据库&&&集合的基本操作
MongoDB作为一种nosql的数据库,它自己本身的增伤改查还有数据库集合的创建和展示与一般的数据库较之是有一部分差别的.我们今天就来看一下MongoDB的一些基本操作. 首先呢,就是先来数据 ...