TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结
写在前面
我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度快一些。在使用tf.nn.rnn_cell.BasicLSTMCell()进行建模的时候,遇到了模型保存、加载的问题。
查找了一些博主的经验,再加上自己摸索,在这里做个笔记,总结经验。其中关键要素有以下3点:
1.需要保存哪些变量(tensor),就要给哪些变量取名字(即name='XXXXX')。
2.将tf.train.Saver()与需要保存的变量(tensor)定义在一个函数里,否则保存会出错。
3.加载模型的时候,先加载图,再加载变量(tensor)。
下面通过实例进行描述。
模型保存
tf.train.Saver()可以自动保存变量和计算图。
保存前注意!!!需要对要保存的变量命名,即属性中的name=XXX
下面是使用tf.nn.rnn_cell.BasicLSTMCell()自建的一个LSTM_Cel
class LSTM_Cell(object):
# train_data 格式示例,batch_size*num_steps*input_dim 批大小*时间窗口长度*单时间节点输入维度
# train_label格式示例,batch_size*1 # TODO 该模型紧输出一维结果。
# input_dim 格式 int, 输入数据在单时间节点上的维度
# num_nodes 神经元数目/维度
def __init__(self, train_data, train_label, input_dim, batch_size=10, num_nodes=64):
tf.reset_default_graph()
self.num_nodes = num_nodes
self.input_dim = input_dim
self.train_data = train_data
self.train_label = train_label
self.batch_size = batch_size
gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.666)
self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) def loss_func(self,lr=0.001):
self.w = tf.Variable(tf.truncated_normal([self.num_nodes, 1], -0.1, 0.1),name='w') # 1 是指输出维度,这里预测一个值,因此维度是1
self.b = tf.Variable(tf.zeros([1]),name='b')
self.batch_in = tf.placeholder(tf.float32, [None, self.train_data.shape[1], self.input_dim],name='batch_in')
self.batch_out = tf.placeholder(tf.float32, [None, 1],name='batch_out')
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.num_nodes,forget_bias=1.0,state_is_tuple=True)
# init_state = lstm_cell.zero_state(self.batch_in[0],dtype=tf.float32)
output, final_state = tf.nn.dynamic_rnn(lstm_cell, self.batch_in, time_major=False, dtype=tf.float32) # initial_state=init_state,
self.y_pre = tf.nn.sigmoid(tf.matmul(final_state[1], self.w) + self.b,name="y_pre")
self.mse = tf.sqrt(tf.reduce_mean(tf.square(self.y_pre-self.batch_out)),name='mse')
self.cross_entropy = -tf.reduce_mean(self.batch_out * tf.log(self.y_pre),name='cross_entropy')
self.train_op = tf.train.GradientDescentOptimizer(lr).minimize(self.mse)
self.saver = tf.train.Saver() def train_model(self,savepath,epochs=1000):
self.sess.run(tf.global_variables_initializer())
for i in range(epochs):
for j in range(int(len(self.train_data)/self.batch_size)):
batch_i = self.train_data[j*self.batch_size:(j+1)*self.batch_size]
batch_o = self.train_label[j*self.batch_size:(j+1)*self.batch_size]
self.sess.run(self.train_op, feed_dict={self.batch_in:batch_i, \
self.batch_out:batch_o.reshape(self.batch_size,1)})
if (i+1)%200==0:
print('epoch:%d'%(i+1),self.sess.run(self.mse,feed_dict={self.batch_in:batch_i, \
self.batch_out:batch_o.reshape(self.batch_size,1)}))
save_path = self.saver.save(self.sess, savepath)
print("模型保存于: ", save_path)
在LSTM_Cell类中,构造函数定义了一些固定参数以及TensorFlow会话(tf.Session()),而我们所要保存的变量(tensor)都在loss_func()函数中定义。包括:
①最后一个全连接层的w和b;
②输入、输出变量的占位符batch_in,batch_out;
③LSTM单元的计算过程;
④计算最终计算结果y_pre,均方根误差mse,交叉熵计算结果cross_entropy,使用随机梯度下降的训练步骤train_op;
⑤存储器tf.train.Saver()。
在本例中,只有变量’w’, ’b’, ’batch_in’, ’batch_out’, ’y_pre’, ’mse’, ’cross_entropy’在属性中有过命名,会被保存下来。
这里,tf.train.Saver()只能保存本函数(即loss_func)中定义的变量(tensor)。
train_model()函数实现训练过程,并调用self.saver.save(self.sess, savepath)来对模型及命名了的变量(tensor)进行保存。
下面是调用LSTM_Cell类进行训练并保存模型的代码:
# 初始化LSTM类
lstm_obj = LSTM_Cell(sample_input,sample_output,input_dim=1,batch_size=_batch_size,num_nodes=hidden_size)
lstm_obj.loss_func(lr) # 构建计算图
# TODO 训练
lstm_obj.train_model(savepath=saved_path,epochs=epochs)
其中,保存路径为
saved_path = "./standard_LSTM/models/Basic_LSTM_TF_models/59model.ckpt"
最终得到的保存结果为下方4个文件(暂时无视两个png图片)
.meta文档是计算图保存的位置,.data是参数数据,后面的00000-of-00001是模型的版本号。
模型加载
加载困扰了我很久,后面经过摸索才知道有两个关键部分,一个是计算图的加载,一个是变量的加载,两者缺一不可。
LSTM_Cell类中,加载函数(load_model)定义如下。
def load_model(self,savepath):
len_last = len(savepath.split('/')[-1])
self.saver = tf.train.import_meta_graph(savepath+'.meta')
self.saver.restore(self.sess,tf.train.latest_checkpoint(savepath[:-len_last])) # 加载最后一个模型
self.graph = tf.get_default_graph()
tensor_name_list = [tensor.name for tensor in self.graph.as_graph_def().node]
self.w = self.graph.get_tensor_by_name('w:0')
self.b = self.graph.get_tensor_by_name('b:0')
self.batch_in = self.graph.get_tensor_by_name("batch_in:0")
self.batch_out = self.graph.get_tensor_by_name("batch_out:0")
self.y_pre = self.graph.get_tensor_by_name('y_pre:0')
self.mse = self.graph.get_tensor_by_name('mse:0')
self.cross_entropy = self.graph.get_tensor_by_name('cross_entropy:0')
首先定义一个self.saver,来辅助加载图及变量。
第一步加载图,即tf.train.import_meta_graph(savepath+'.meta'),就是加载上图中的 59model.ckpt.meta
saver.restore()函数将模型参数进行加载,savepath[:-len_last]是指保存模型的文件夹路径,即"./standard_LSTM/models/Basic_LSTM_TF_models/" ,将模型加载到默认的计算图中(default_graph)。
此时,各变量(即tensor)已经在计算图中了,但要正常调用,还需要从图中取出并将其设置成变量。
具体方法是先取得默认的计算图self.graph,再通过get_tensor_by_name()方法将tensor实例化,每个tensor的名称与模型保存时name=”XXX”的名称相同,并且后方需要加上:<index>,不重名的情况下这个index一般是0。
也有博主说tensor的名称可以在tensor_name_list中查看到,但我打印出来后发现这个list太长,不大实用。
这时候就加载模型完毕了,可以调用self.sess对self.y_pre、self.mse进行计算。
示例计算如下:
def predict_next_one(self,batch_i): # batch_i长度 为样本时间序列长度
temp = self.sess.run(self.y_pre,feed_dict={self.batch_in:batch_i.reshape(1,len(batch_i),1)})
return temp[0][0]
外部的调用方法如下,(构造函数后就不用使用loss_func构建计算过程了,直接加载模型就行。)
# 初始化LSTM类
lstm_obj = LSTM_Cell(sample_input,sample_output,input_dim=1,batch_size=_batch_size,num_nodes=hidden_size)
# TODO 加载模型
lstm_obj.load_model(savepath=saved_path)
下面两个图是训练完后直接预测以及加载模型再预测的结果,可以看出模型加载后,计算结果与之前一致。
TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结的更多相关文章
- tensorflow学习笔记2:c++程序静态链接tensorflow库加载模型文件
首先需要搞定tensorflow c++库,搜了一遍没有找到现成的包,于是下载tensorflow的源码开始编译: tensorflow的contrib中有一个makefile项目,极大的简化的接下来 ...
- [深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题
[深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存 ...
- PyTorch保存模型与加载模型+Finetune预训练模型使用
Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...
- 使用Pytorch在多GPU下保存和加载训练模型参数遇到的问题
最近使用Pytorch在学习一个深度学习项目,在模型保存和加载过程中遇到了问题,最终通过在网卡查找资料得已解决,故以此记之,以备忘却. 首先,是在使用多GPU进行模型训练的过程中,在保存模型参数时,应 ...
- 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)
1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...
- 【4】TensorFlow光速入门-保存模型及加载模型并使用
本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...
- [Pytorch]Pytorch 保存模型与加载模型(转)
转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...
- 132、TensorFlow加载模型
# The tf.train.Saver对象不仅保存变量到checkpoint文件 # 它也恢复变量,当你恢复变量的时候,你就不必须要提前初始化他们 # 列如如下的代码片段解释了如何去调用tf.tra ...
- MindSpore保存与加载模型
技术背景 近几年在机器学习和传统搜索算法的结合中,逐渐发展出了一种Search To Optimization的思维,旨在通过构造一个特定的机器学习模型,来替代传统算法中的搜索过程,进而加速经典图论等 ...
随机推荐
- 使用WireShark进行网络流量安全分析
WireShark的过滤规则 伯克利包过滤(BPF)(应用在wireshark的捕获过滤器上) ** 伯克利包过滤中的限定符有下面的三种:** Type:这种限定符表示指代的对象,例如IP地址,子网或 ...
- linux 配置网卡、远程拷贝文件、建立软硬链接、打包/解包、压缩/解压缩、包操作、yum配置使用、root密码忘记
目录 一.配置网卡 二.xshell连接 三.远程拷贝文件 四.建立软硬连接 五.打包/解包和压缩/解压缩 六.包操作 七.配置yum源 配置yum源 配置阿里云源 常用命令 yum其他命令 八.重置 ...
- 关于浏览器Number.toFixed的错误修复
问题描述如下: var n = 1.255; var fixed = n.toFixed(2); console.log(fixed);//结果:1.25 /* 以上代码运行预期的结果是1.26,但是 ...
- 数据库安装和基本sql语句
数据库概念 文件作为数据进行存储,数据格式千差万别 将保存数据的地方统一起来 MYSQL--------->一款应用软件 用来帮你操作文件的 只要是基于网络通信,底层就是socket 服务端 - ...
- hdu4587 Two Nodes 求图中删除两个结点剩余的连通分量的数量
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4587 题目给了12000ms,对于tarjan这种O(|V|+|E|)复杂度的算法来说,暴力是能狗住的 ...
- hdu6026 dijkstra
题目链接:http://icpc.njust.edu.cn/Problem/Hdu/6026/ 题意大致是:给定一个图,要求删边使他变成树,使得每个点到0的距离就是原图中0到这个点的最短路径.其实就是 ...
- JSP(二)----指令,注释,内置对象
## JSP 1.指令 * 作用:用于配置JSP页面,导入资源文件 * 格式: <%@ 指令名称 属性名1=属性值1 属性名2=属性值2 %> <%@ page con ...
- 配置ssh免密登录遇到的问题——使用VMware多虚拟机搭建Hadoop集群
搭建环境: 虚拟机 VMware12Pro 操作系统 centos6.8 hadoop 1.2.1 1.导入镜像文件,添加java环境 1.查看当前系统中安装的java,ls ...
- ECMAScript 6 基础
ECMAScript 6 基础 ECMAScript 6 简介 JavaScript 三大组成部分 ECMAScript DOM BOM ECMAScript 发展历史 https://develop ...
- OpenCV-Python 图像分割与Watershed算法 | 三十四
目标 在本章中, 我们将学习使用分水岭算法实现基于标记的图像分割 我们将看到:cv.watershed() 理论 任何灰度图像都可以看作是一个地形表面,其中高强度表示山峰,低强度表示山谷.你开始用不同 ...