Tensorflow保存神经网络参数有妙招:Saver和Restore
摘要:这篇文章将讲解TensorFlow如何保存变量和神经网络参数,通过Saver保存神经网络,再通过Restore调用训练好的神经网络。
本文分享自华为云社区《[Python人工智能] 十一.Tensorflow如何保存神经网络参数 丨【百变AI秀】》,作者: eastmount。
一.保存变量
通过tf.Variable()定义权重和偏置变量,然后调用tf.train.Saver()存储变量,将数据保存至本地“my_net/save_net.ckpt”文件中。
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 2 20:04:57 2020
@author: xiuzhang Eastmount CSDN
"""
import tensorflow as tf
import numpy as np #---------------------------------------保存文件---------------------------------------
W = tf.Variable([[1,2,3], [3,4,5]], dtype=tf.float32, name='weights') #2行3列的数据
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases') # 初始化
init = tf.initialize_all_variables() # 定义saver 存储各种变量
saver = tf.train.Saver() # 使用Session运行初始化
with tf.Session() as sess:
sess.run(init)
# 保存 官方保存格式为ckpt
save_path = saver.save(sess, "my_net/save_net.ckpt")
print("Save to path:", save_path)
“Save to path: my_net/save_net.ckpt”保存成功如下图所示:

打开内容如下图所示:

接着定义标记变量train,通过Restore操作使用我们保存好的变量。注意,在Restore时需要定义相同的dtype和shape,不需要再定义init。最后直接通过 saver.restore(sess, “my_net/save_net.ckpt”) 提取保存的变量并输出即可。
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 2 20:04:57 2020
@author: xiuzhang Eastmount CSDN
"""
import tensorflow as tf
import numpy as np # 标记变量
train = False #---------------------------------------保存文件---------------------------------------
# Save
if train==True:
# 定义变量
W = tf.Variable([[1,2,3], [3,4,5]], dtype=tf.float32, name='weights') #2行3列的数据
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases') # 初始化
init = tf.global_variables_initializer() # 定义saver 存储各种变量
saver = tf.train.Saver() # 使用Session运行初始化
with tf.Session() as sess:
sess.run(init)
# 保存 官方保存格式为ckpt
save_path = saver.save(sess, "my_net/save_net.ckpt")
print("Save to path:", save_path)
#---------------------------------------Restore变量-------------------------------------
# Restore
if train==False:
# 记住在Restore时定义相同的dtype和shape
# redefine the same shape and same type for your variables
W = tf.Variable(np.arange(6).reshape((2,3)), dtype=tf.float32, name='weights') #空变量
b = tf.Variable(np.arange(3).reshape((1,3)), dtype=tf.float32, name='biases') #空变量 # Restore不需要定义init
saver = tf.train.Saver()
with tf.Session() as sess:
# 提取保存的变量
saver.restore(sess, "my_net/save_net.ckpt")
# 寻找相同名字和标识的变量并存储在W和b中
print("weights", sess.run(W))
print("biases", sess.run(b))
运行代码,如果报错“NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. ”,则需要重置Spyder即可。

最后输出之前所保存的变量,weights为 [[1,2,3], [3,4,5]],偏置为 [[1,2,3]]。

二.保存神经网络
那么,TensorFlow如何保存我们的神经网络框架呢?我们需要把整个网络训练好再进行保存,其方法和上面类似,完整代码如下:
"""
Created on Sun Dec 29 19:21:08 2019
@author: xiuzhang Eastmount CSDN
"""
import os
import glob
import cv2
import numpy as np
import tensorflow as tf # 定义图片路径
path = 'photo/' #---------------------------------第一步 读取图像-----------------------------------
def read_img(path):
cate = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]
imgs = []
labels = []
fpath = []
for idx, folder in enumerate(cate):
# 遍历整个目录判断每个文件是不是符合
for im in glob.glob(folder + '/*.jpg'):
#print('reading the images:%s' % (im))
img = cv2.imread(im) #调用opencv库读取像素点
img = cv2.resize(img, (32, 32)) #图像像素大小一致
imgs.append(img) #图像数据
labels.append(idx) #图像类标
fpath.append(path+im) #图像路径名
#print(path+im, idx) return np.asarray(fpath, np.string_), np.asarray(imgs, np.float32), np.asarray(labels, np.int32) # 读取图像
fpaths, data, label = read_img(path)
print(data.shape) # (1000, 256, 256, 3)
# 计算有多少类图片
num_classes = len(set(label))
print(num_classes) # 生成等差数列随机调整图像顺序
num_example = data.shape[0]
arr = np.arange(num_example)
np.random.shuffle(arr)
data = data[arr]
label = label[arr]
fpaths = fpaths[arr] # 拆分训练集和测试集 80%训练集 20%测试集
ratio = 0.8
s = np.int(num_example * ratio)
x_train = data[:s]
y_train = label[:s]
fpaths_train = fpaths[:s]
x_val = data[s:]
y_val = label[s:]
fpaths_test = fpaths[s:]
print(len(x_train),len(y_train),len(x_val),len(y_val)) #800 800 200 200
print(y_val)
#---------------------------------第二步 建立神经网络-----------------------------------
# 定义Placeholder
xs = tf.placeholder(tf.float32, [None, 32, 32, 3]) #每张图片32*32*3个点
ys = tf.placeholder(tf.int32, [None]) #每个样本有1个输出
# 存放DropOut参数的容器
drop = tf.placeholder(tf.float32) #训练时为0.25 测试时为0 # 定义卷积层 conv0
conv0 = tf.layers.conv2d(xs, 20, 5, activation=tf.nn.relu) #20个卷积核 卷积核大小为5 Relu激活
# 定义max-pooling层 pool0
pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2]) #pooling窗口为2x2 步长为2x2
print("Layer0:\n", conv0, pool0) # 定义卷积层 conv1
conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu) #40个卷积核 卷积核大小为4 Relu激活
# 定义max-pooling层 pool1
pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2]) #pooling窗口为2x2 步长为2x2
print("Layer1:\n", conv1, pool1) # 将3维特征转换为1维向量
flatten = tf.layers.flatten(pool1) # 全连接层 转换为长度为400的特征向量
fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)
print("Layer2:\n", fc) # 加上DropOut防止过拟合
dropout_fc = tf.layers.dropout(fc, drop) # 未激活的输出层
logits = tf.layers.dense(dropout_fc, num_classes)
print("Output:\n", logits) # 定义输出结果
predicted_labels = tf.arg_max(logits, 1)
#---------------------------------第三步 定义损失函数和优化器--------------------------------- # 利用交叉熵定义损失
losses = tf.nn.softmax_cross_entropy_with_logits(
labels = tf.one_hot(ys, num_classes), #将input转化为one-hot类型数据输出
logits = logits) # 平均损失
mean_loss = tf.reduce_mean(losses) # 定义优化器 学习效率设置为0.0001
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(losses)
#------------------------------------第四步 模型训练和预测-----------------------------------
# 用于保存和载入模型
saver = tf.train.Saver()
# 训练或预测
train = False
# 模型文件路径
model_path = "model/image_model" with tf.Session() as sess:
if train:
print("训练模式")
# 训练初始化参数
sess.run(tf.global_variables_initializer())
# 定义输入和Label以填充容器 训练时dropout为0.25
train_feed_dict = {
xs: x_train,
ys: y_train,
drop: 0.25
}
# 训练学习1000次
for step in range(1000):
_, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
if step % 50 == 0: #每隔50次输出一次结果
print("step = {}\t mean loss = {}".format(step, mean_loss_val))
# 保存模型
saver.save(sess, model_path)
print("训练结束,保存模型到{}".format(model_path))
else:
print("测试模式")
# 测试载入参数
saver.restore(sess, model_path)
print("从{}载入模型".format(model_path))
# label和名称的对照关系
label_name_dict = {
0: "人类",
1: "沙滩",
2: "建筑",
3: "公交",
4: "恐龙",
5: "大象",
6: "花朵",
7: "野马",
8: "雪山",
9: "美食"
}
# 定义输入和Label以填充容器 测试时dropout为0
test_feed_dict = {
xs: x_val,
ys: y_val,
drop: 0
} # 真实label与模型预测label
predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
for fpath, real_label, predicted_label in zip(fpaths_test, y_val, predicted_labels_val):
# 将label id转换为label名
real_label_name = label_name_dict[real_label]
predicted_label_name = label_name_dict[predicted_label]
print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))
# 评价结果
print("正确预测个数:", sum(y_val==predicted_labels_val))
print("准确度为:", 1.0*sum(y_val==predicted_labels_val) / len(y_val))
核心步骤为:
saver = tf.train.Saver()
model_path = "model/image_model"
with tf.Session() as sess:
if train:
#保存神经网络
sess.run(tf.global_variables_initializer())
for step in range(1000):
_, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
if step % 50 == 0:
print("step = {}\t mean loss = {}".format(step, mean_loss_val))
saver.save(sess, model_path)
else:
#载入神经网络
saver.restore(sess, model_path)
predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
for fpath, real_label, predicted_label in zip(fpaths_test, y_val, predicted_labels_val):
real_label_name = label_name_dict[real_label]
predicted_label_name = label_name_dict[predicted_label]
print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))
预测输出结果如下图所示,最终预测正确181张图片,准确度为0.905。相比之前机器学习KNN的0.500有非常高的提升。

测试模式
INFO:tensorflow:Restoring parameters from model/image_model
从model/image_model载入模型
b'photo/photo/3\\335.jpg' 公交 => 公交
b'photo/photo/1\\129.jpg' 沙滩 => 沙滩
b'photo/photo/7\\740.jpg' 野马 => 野马
b'photo/photo/5\\564.jpg' 大象 => 大象
...
b'photo/photo/9\\974.jpg' 美食 => 美食
b'photo/photo/2\\220.jpg' 建筑 => 公交
b'photo/photo/9\\912.jpg' 美食 => 美食
b'photo/photo/4\\459.jpg' 恐龙 => 恐龙
b'photo/photo/5\\525.jpg' 大象 => 大象
b'photo/photo/0\\44.jpg' 人类 => 人类 正确预测个数: 181
准确度为: 0.905
Tensorflow保存神经网络参数有妙招:Saver和Restore的更多相关文章
- 神经网络参数与TensorFlow变量
在TensorFlow中变量的作用是保存和更新神经网络中的参数,需要给变量指定初始值,如下声明一个2x3矩阵变量 weights =tf.Variable(tf.random_normal([2,3] ...
- Tensorflow学习教程------参数保存和提取重利用
#coding:utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mni ...
- (转)一文学会用 Tensorflow 搭建神经网络
一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...
- 用Tensorflow搭建神经网络的一般步骤
用Tensorflow搭建神经网络的一般步骤如下: ① 导入模块 ② 创建模型变量和占位符 ③ 建立模型 ④ 定义loss函数 ⑤ 定义优化器(optimizer), 使 loss 达到最小 ⑥ 引入 ...
- 一文学会用 Tensorflow 搭建神经网络
http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...
- tensorflow 保存训练模型ckpt 查看ckpt文件中的变量名和对应值
TensorFlow 模型保存与恢复 一个快速完整的教程,以保存和恢复Tensorflow模型. 在本教程中,我将会解释: TensorFlow模型是什么样的? 如何保存TensorFlow模型? 如 ...
- TensorFlow实现超参数调整
TensorFlow实现超参数调整 正如你目前所看到的,神经网络的性能非常依赖超参数.因此,了解这些参数如何影响网络变得至关重要. 常见的超参数是学习率.正则化器.正则化系数.隐藏层的维数.初始权重值 ...
- tensorflow之神经网络实现流程总结
tensorflow之神经网络实现流程总结 1.数据预处理preprocess 2.前向传播的神经网络搭建(包括activation_function和层数) 3.指数下降的learning_rate ...
- Tensorflow搭建神经网络及使用Tensorboard进行可视化
创建神经网络模型 1.构建神经网络结构,并进行模型训练 import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt ...
随机推荐
- K8S系列第九篇(持久化存储,emptyDir、hostPath、PV/PVC)
更多k8s内容,请关注威信公众好:新猿技术生态圈 一.数据持久化 Pod是由容器组成的,而容器宕机或停止之后,数据就随之丢了,那么这也就意味着我们在做Kubernetes集群的时候就不得不考虑存储的问 ...
- Vulhub-Mysql 身份认证绕过漏洞(CVE-2012-2122)
前言 当连接MariaDB/MySQL时,输入的密码会与期望的正确密码比较,由于不正确的处理,会导致即便是memcmp()返回一个非零值,也会使MySQL认为两个密码是相同的.也就是说只要知道用户名, ...
- Upload-labs 文件上传靶场通关攻略(上)
Upload-labs 文件上传靶场通关攻略(上) 文件上传是Web网页中常见的功能之一,通常情况下恶意的文件上传,会形成漏洞. 逻辑是这样的:用户通过上传点上传了恶意文件,通过服务器的校验后保存到指 ...
- 十分钟带你了解CANN应用开发全流程
摘要:CANN作为昇腾AI处理器的发动机,支持业界多种主流的AI框架,包括MindSpore.TensorFlow.Pytorch.Caffe等,并提供1200多个基础算子. 2021年7月8日,第四 ...
- 快速理解VLAN与三层交换机
一.VLAN 1.1.VLAN的概述与优势 VLAN是逻辑隔离的虚拟局域网,作用是分割广播域(分为物理分割和逻辑分割) VLAN的优势:控制广播.增强网络安全性.简化网络管理 1.2.VLAN的种类 ...
- moco模拟接口具体操作
1.get请求 [ { "description": "模拟一个没有参数的get请求", "request": { "uri&qu ...
- 30 个极大提高开发效率超级实用的 VSCode 插件
Visual Studio Code 的插件对于在提升编程效率和加快工作速度非常重要.这里有 30 个最受欢迎的 VSCode 插件,它们将使你成为更高效的搬砖摸鱼大师.这些插件主要适用于前端开发人员 ...
- axios 请求数据跳转页面报'$router' of undefined问题
代码: this.$axios.post("/auth", { 'username': this.username, 'password': this.password }).th ...
- springboot如何使用事物注解方式
1.在启动类Application中添加注解@EnableTransactionManagement import tk.mybatis.spring.annotation.MapperScan; i ...
- vue去掉一些烦人的校验规则
例如:括号前没有加空格报错,很难受 如何处理呢,故意犯错,然后打开页面出现错误信息,如下图复制错误 space-before-function-paren 找到项目中的.eslintrc.js 添加一 ...