TensorFlow(十三):模型的保存与载入
一:保存
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次100张照片
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b) #二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
for epoch in range(11):
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
#保存模型
saver.save(sess,'net/my_net.ckpt')
结果:
Iter 0,Testing Accuracy 0.8252
Iter 1,Testing Accuracy 0.8916
Iter 2,Testing Accuracy 0.9008
Iter 3,Testing Accuracy 0.906
Iter 4,Testing Accuracy 0.9091
Iter 5,Testing Accuracy 0.9104
Iter 6,Testing Accuracy 0.911
Iter 7,Testing Accuracy 0.9127
Iter 8,Testing Accuracy 0.9145
Iter 9,Testing Accuracy 0.9166
Iter 10,Testing Accuracy 0.9177
二:载入
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次100张照片
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b) #二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
# 未载入模型时的识别率
print('未载入识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
saver.restore(sess,'net/my_net.ckpt')
# 载入模型后的识别率
print('载入后识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
结果:
未载入识别率 0.098
INFO:tensorflow:Restoring parameters from net/my_net.ckpt
载入后识别率 0.9177
TensorFlow(十三):模型的保存与载入的更多相关文章
- TensorFlow 模型的保存与载入
参考学习博客: # https://www.cnblogs.com/felixwang2/p/9190692.html 一.模型保存 # https://www.cnblogs.com/felixwa ...
- TensorFlow笔记-模型的保存,恢复,实现线性回归
模型的保存 tf.train.Saver(var_list=None,max_to_keep=5) •var_list:指定将要保存和还原的变量.它可以作为一个 dict或一个列表传递. •max_t ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
- Tensorflow Learning1 模型的保存和恢复
CKPT->pb Demo 解析 tensor name 和 node name 的区别 Pb 的恢复 CKPT->pb tensorflow的模型保存有两种形式: 1. ckpt:可以恢 ...
- tensorflow:模型的保存和训练过程可视化
在使用tf来训练模型的时候,难免会出现中断的情况.这时候自然就希望能够将辛辛苦苦得到的中间参数保留下来,不然下次又要重新开始. 保存模型的方法: #之前是各种构建模型graph的操作(矩阵相乘,sig ...
- tensorflow 之模型的保存与加载(三)
前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...
- tensorflow 之模型的保存与加载(二)
上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练. 在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量. #!/ ...
- tensorflow 之模型的保存与加载(一)
怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...
- 【TensorFlow】TensorFlow基础 —— 模型的保存读取与可视化方法总结
TensorFlow提供了一个用于保存模型的工具以及一个可视化方案 这里使用的TensorFlow为1.3.0版本 一.保存模型数据 模型数据以文件的形式保存到本地: 使用神经网络模型进行大数据量和复 ...
随机推荐
- unittest之makeSuite\testload\discover及测试报告teseReport
转载:http://www.cnblogs.com/sunny0/p/7771089.html 测试套件suite除了使用addTest以外,还有使用操作起来更更简便的makeSuite\testlo ...
- Java基础IO类之数据流
DataInputStream: 数据输入流允许应用程序以与机器无关方式从底层输入流中读取基本java数据类型.应用程序可以使用数据输出流 写入稍后由数据输入流读取的数据.DataInputStrea ...
- webpack 3.1 升级webpack 4.0
webpack 3.1 升级webpack 4.0 为了提升打包速度以及跟上主流技术步伐,前段时间把项目的webpack 升级到4.0版本以上 webpack 官网:https://webpack.j ...
- OpenCvSharp 图像旋转
/// <summary> /// 图像旋转 /// </summary> private Mat MatRotate(Mat src, float angle) { Mat ...
- Spark机器学习基础-特征工程
对连续值处理 0.binarizer/二值化 from __future__ import print_function from pyspark.sql import SparkSession fr ...
- java第一次笔试+面试总结
今天是自己第一次java笔试和面试,总体感觉比预期好一点. 笔试题第一面是问答题,主要考查java基础,一共有18题,我有6道题没有写出来.第二面主要是算法题,一共有8道题,我大概写出来4道题,第三面 ...
- 安装HANA Rules Framework(HRF)
1. 收集文档 1.1 SAP HANA Rules Framework by the SAP HANA Academy link 1.2 HANA Rules Framework (HRF) b ...
- [LeetCode] 76. 最小覆盖子串 ☆☆☆☆☆(滑动窗口)
https://leetcode-cn.com/problems/minimum-window-substring/solution/hua-dong-chuang-kou-suan-fa-tong- ...
- 前端框架开始学习Vue(三)
初步安装.与搭建 https://www.cnblogs.com/yanxulan/p/8978732.html ----如何搭建一个vue项目 安装 nodejs,,, npm i == np ...
- C++——同名隐藏 和 赋值兼容规则
同名隐藏 一旦子类定义了与父类同名的方法,则父类里面该名字的所有方法都被隐藏了.必须显示指定是父类的方法才可以 #include<iostream> using namespace std ...