TensorFlow(十三):模型的保存与载入
一:保存
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次100张照片
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b) #二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
for epoch in range(11):
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
#保存模型
saver.save(sess,'net/my_net.ckpt')
结果:
Iter 0,Testing Accuracy 0.8252
Iter 1,Testing Accuracy 0.8916
Iter 2,Testing Accuracy 0.9008
Iter 3,Testing Accuracy 0.906
Iter 4,Testing Accuracy 0.9091
Iter 5,Testing Accuracy 0.9104
Iter 6,Testing Accuracy 0.911
Iter 7,Testing Accuracy 0.9127
Iter 8,Testing Accuracy 0.9145
Iter 9,Testing Accuracy 0.9166
Iter 10,Testing Accuracy 0.9177
二:载入
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次100张照片
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b) #二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
# 未载入模型时的识别率
print('未载入识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
saver.restore(sess,'net/my_net.ckpt')
# 载入模型后的识别率
print('载入后识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
结果:
未载入识别率 0.098
INFO:tensorflow:Restoring parameters from net/my_net.ckpt
载入后识别率 0.9177
TensorFlow(十三):模型的保存与载入的更多相关文章
- TensorFlow 模型的保存与载入
参考学习博客: # https://www.cnblogs.com/felixwang2/p/9190692.html 一.模型保存 # https://www.cnblogs.com/felixwa ...
- TensorFlow笔记-模型的保存,恢复,实现线性回归
模型的保存 tf.train.Saver(var_list=None,max_to_keep=5) •var_list:指定将要保存和还原的变量.它可以作为一个 dict或一个列表传递. •max_t ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
- Tensorflow Learning1 模型的保存和恢复
CKPT->pb Demo 解析 tensor name 和 node name 的区别 Pb 的恢复 CKPT->pb tensorflow的模型保存有两种形式: 1. ckpt:可以恢 ...
- tensorflow:模型的保存和训练过程可视化
在使用tf来训练模型的时候,难免会出现中断的情况.这时候自然就希望能够将辛辛苦苦得到的中间参数保留下来,不然下次又要重新开始. 保存模型的方法: #之前是各种构建模型graph的操作(矩阵相乘,sig ...
- tensorflow 之模型的保存与加载(三)
前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...
- tensorflow 之模型的保存与加载(二)
上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练. 在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量. #!/ ...
- tensorflow 之模型的保存与加载(一)
怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...
- 【TensorFlow】TensorFlow基础 —— 模型的保存读取与可视化方法总结
TensorFlow提供了一个用于保存模型的工具以及一个可视化方案 这里使用的TensorFlow为1.3.0版本 一.保存模型数据 模型数据以文件的形式保存到本地: 使用神经网络模型进行大数据量和复 ...
随机推荐
- Linux安装JDK,Tomcat,Mysql+部署项目
安装VMWare虚拟机 下载地址(http://www.onlinedown.net/soft/2062.htm) 安装步骤很简单(除了选择安装路径),傻瓜式安装 同意协议 选择安装路径 安装 完成 ...
- PMBOK项目管理的五大过程组和十大知识领域
PMBOK五大过程组是:启动过程.规划过程.执行过程.监控过程.收尾过程. 各用一句话概括项目管理知识体系五大过程组: 1.启动过程组:作用是设定项目目标,让项目团队有事可做: 2.规划过程组:作用是 ...
- C# 使用Berkeley DB
Berkeley DB是一个开源的文件数据库,介于关系数据库与内存数据库之间.简称BDB Berkeley DB是嵌入式键值数据库库,为应用程序提供可扩展的高性能数据管理服务. Berkeley DB ...
- C#使用消息队列(MSMQ)
最近项目用到消息队列,找资料学习了下.把学习的结果 分享出来 首先说一下,消息队列 (MSMQ Microsoft Message Queuing)是MS提供的服务,也就是Windows操作系统的功能 ...
- shell 实战 -- 基于一个服务启动,关闭,状态检查的脚本
功能说明: check:检查服务状态,在开启,关闭,状态检查时都会用到这个函数,所以封装起来放到最前面 start:开启服务 stop:关闭服务 fstop:强制关闭 status:检查服务状态 ru ...
- eval函数和isNaN函数
(一)eval函数定义:eval() 函数可计算某个字符串,并执行其中的的 JavaScript 代码. (二)语法:eval(string)string必需. (三)返回值:通过计算 string ...
- webstorm编写react native,代码修改后,重新编译运行没有变化的问题
w我是拷贝一份react native代码到另一台电脑,发现修改代码运行之后不显示修改后的效果,即仍然与原来的效果一样,暂时不知道什么原因, 后来我运行了npm install 就可以了,不知道是不是 ...
- Linux下安装php 扩展fileinfo
在项目初始部署环境的时候,可能考虑的并不全面,就会少装一些扩展,这里讲解如何添加fileinfo扩展 1.找到php安装的压缩包 2.将压缩包cp到 /data目录下,并解压 cp php-7.0.3 ...
- HTTP 协议部分常识简介
1.状态码 具体的状态码可以百度查找,但是对于状态码的大致分类有一个清楚的了解 1XX ----信息状态码------接受的请求正在处理 2XX ------成功状态码 ------请求正常处理完 ...
- windows10下安装tensorflow2.0-GPU和Cupy(不用搞CUDA+cudnn)
0.前言 今年暑假买了个1660ti的游戏本学python,后来发现跑一些数据量比较大的代码和深度学习的时候太慢了,遂想装一下GPU版本,看了网上的资料搞了好几天,又是CUDA又是cudnn的,网速慢 ...