TensorFlow学习笔记(8)--网络模型的保存和读取【转】
转自:http://blog.csdn.net/lwplwf/article/details/62419087
之前的笔记里实现了softmax回归分类、简单的含有一个隐层的神经网络、卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用,需要将训练好的神经网络模型持久化,这就是这篇笔记里要写的东西。
TensorFlow提供了一个非常简单的API,即tf.train.Saver
类来保存和还原一个神经网络模型。
下面代码给出了保存TensorFlow模型的方法:
import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
sess.run(init_op)
print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
print("v2:", sess.run(v2))
saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
print("Model saved in file:", saver_path)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
这段代码中,通过saver.save
函数将TensorFlow模型保存到了save/model.ckpt文件中,这里代码中指定路径为"save/model.ckpt"
,也就是保存到了当前程序所在文件夹里面的save
文件夹中。
TensorFlow模型会保存在后缀为.ckpt
的文件中。保存后在save这个文件夹中实际会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。
model.ckpt.meta
文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构model.ckpt
文件保存了TensorFlow程序中每一个变量的取值checkpoint
文件保存了一个目录下所有的模型文件列表
下面代码给出了加载TensorFlow模型的方法:
可以对比一下v1、v2的值是随机初始化的值还是和之前保存的值是一样的?
import tensorflow as tf
# 使用和保存模型代码中一样的方式来声明变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
saver.restore(sess, "save/model.ckpt") # 即将固化到硬盘中的Session从保存路径再读取出来
print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的进行对比
print("v2:", sess.run(v2))
print("Model Restored")
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
运行结果:
v1: [[ 0.76705766 1.82217288]]
v2: [[-0.98012197 1.2369734 0.5797025 ]
[ 2.50458145 0.81897354 0.07858191]]
Model Restored
- 1
- 2
- 3
- 4
- 1
- 2
- 3
- 4
这段加载模型的代码基本上和保存模型的代码是一样的。也是先定义了TensorFlow计算图上所有的运算,并声明了一个tf.train.Saver
类。两段唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。
也就是说使用TensorFlow完成了一次模型的保存和读取的操作。
如果不希望重复定义图上的运算,也可以直接加载已经持久化的图:
import tensorflow as tf
# 在下面的代码中,默认加载了TensorFlow计算图上定义的全部变量
# 直接加载持久化的图
saver = tf.train.import_meta_graph("save/model.ckpt.meta")
with tf.Session() as sess:
saver.restore(sess, "save/model.ckpt")
# 通过张量的名称来获取张量
print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
运行程序,输出:
[[ 0.76705766 1.82217288]]
- 1
- 1
有时可能只需要保存或者加载部分变量。
比如,可能有一个之前训练好的5层神经网络模型,但现在想写一个6层的神经网络,那么可以将之前5层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。
为了保存或者加载部分变量,在声明tf.train.Saver
类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train.Saver([v1])
命令来构建tf.train.Saver
类,那么只有变量v1会被加载进来。
…未完待续
TensorFlow学习笔记(8)--网络模型的保存和读取【转】的更多相关文章
- matlab学习笔记4--MAT文件的保存和读取
一起来学matlab-matlab学习笔记4 数据导入和导出_1 MAT文件的保存和读取 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考书籍 <matlab 程序设计与综合应用&g ...
- TensorFlow基础笔记(14) 网络模型的保存与恢复_mnist数据实例
http://blog.csdn.net/huachao1001/article/details/78502910 http://blog.csdn.net/u014432647/article/de ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- TensorFlow学习笔记——LeNet-5(训练自己的数据集)
在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)
续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...
- Tensorflow学习笔记No.7
tf.data与自定义训练综合实例 使用tf.data自定义猫狗数据集,并使用自定义训练实现猫狗数据集的分类. 1.使用tf.data创建自定义数据集 我们使用kaggle上的猫狗数据以及tf.dat ...
- Tensorflow学习笔记No.8
使用VGG16网络进行迁移学习 使用在ImageNet数据上预训练的VGG16网络模型对猫狗数据集进行分类识别. 1.预训练网络 预训练网络是一个保存好的,已经在大型数据集上训练好的卷积神经网络. 如 ...
- Tensorflow学习笔记2:About Session, Graph, Operation and Tensor
简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...
- Tensorflow学习笔记2019.01.22
tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...
- Tensorflow学习笔记2019.01.03
tensorflow学习笔记: 3.2 Tensorflow中定义数据流图 张量知识矩阵的一个超集. 超集:如果一个集合S2中的每一个元素都在集合S1中,且集合S1中可能包含S2中没有的元素,则集合S ...
随机推荐
- keras callback中的stop_training
keras这个框架简洁优美,设计上堪称典范.而tensorflow就显得臃肿庞杂,混乱不清.当然,keras的周边部件比如callbacks.datasets.preprocessing有许多过度设计 ...
- 一个worker thread服务一个客户端
服务器端对一个客户端来了就开启一个工作线程,最多可接受64个. 具体看代码: #pragma once #include <winsock.h> #include <stdio.h& ...
- 消息队列状态:struct msqid_ds
Linux的消息队列(queue)实质上是一个链表, 它有消息队列标识符(queue ID). msgget创建一个新队列或打开一个存在的队列; msgsnd向队列末端添加一条新消息; msgrcv从 ...
- Socket模型(二):完成端口(IOCP)
为什么要采用Socket模型,而不直接使用Socket? 原因源于recv()方法是堵塞式的,当多个客户端连接服务器时,其中一个socket的recv调用时,会产生堵塞,使其他链接不能继续.这样我们又 ...
- 【jQuery】利用jQuery实现“记住我”的功能
[1]先下载jQuery.cookie插件:使用帮助请参考链接(https://github.com/carhartl/jquery-cookie). [2]安装插件: <script type ...
- 【Algorithm】快速排序(续)
前面在常用的排序算法中,已经写过一篇关于快速排序算法的博客,但是最近看到<The C Programming Language>这本书中的快速排序算法写的不错,所以就拿过来分享一下,下面我 ...
- Redis踩过的坑
现象:在使用redis云提供的redis服务后,经常出现connect timeout: redis.clients.jedis.exceptions.JedisConnectionException ...
- java反射之获取枚举对象
项目中导入大量枚举对象,用来定义常量.随着带来一个问题,就是每个枚举类都需要通过key来获取对应枚举的需求. public enum ExamType { CRAFT(1, "草稿" ...
- Android 热修复 Tinker接入及源码浅析
一.概述 放了一个大长假,happy,先祝大家2017年笑口常开. 假期中一行代码没写,但是想着马上要上班了,赶紧写篇博客回顾下技能,于是便有了本文. 热修复这项技术,基本上已经成为项目比较重要的模块 ...
- DVWA安装——一个菜鸟的入门教程
DVWA的安装非常简单: 1.更改config/config.inc.php文件中的数据库配置信息 2.访问setup.php,点击create/reset database即可 3.默认用户名/密码 ...