一:保存

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次100张照片
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b) #二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
for epoch in range(11):
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
#保存模型
saver.save(sess,'net/my_net.ckpt')

结果:

Iter 0,Testing Accuracy 0.8252
Iter 1,Testing Accuracy 0.8916
Iter 2,Testing Accuracy 0.9008
Iter 3,Testing Accuracy 0.906
Iter 4,Testing Accuracy 0.9091
Iter 5,Testing Accuracy 0.9104
Iter 6,Testing Accuracy 0.911
Iter 7,Testing Accuracy 0.9127
Iter 8,Testing Accuracy 0.9145
Iter 9,Testing Accuracy 0.9166
Iter 10,Testing Accuracy 0.9177

二:载入

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次100张照片
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b) #二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
# 未载入模型时的识别率
print('未载入识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
saver.restore(sess,'net/my_net.ckpt')
# 载入模型后的识别率
print('载入后识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))

结果:

未载入识别率 0.098
INFO:tensorflow:Restoring parameters from net/my_net.ckpt
载入后识别率 0.9177

TensorFlow(十三):模型的保存与载入的更多相关文章

  1. TensorFlow 模型的保存与载入

    参考学习博客: # https://www.cnblogs.com/felixwang2/p/9190692.html 一.模型保存 # https://www.cnblogs.com/felixwa ...

  2. TensorFlow笔记-模型的保存,恢复,实现线性回归

    模型的保存 tf.train.Saver(var_list=None,max_to_keep=5) •var_list:指定将要保存和还原的变量.它可以作为一个 dict或一个列表传递. •max_t ...

  3. Python之TensorFlow的模型训练保存与加载-3

    一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...

  4. Tensorflow Learning1 模型的保存和恢复

    CKPT->pb Demo 解析 tensor name 和 node name 的区别 Pb 的恢复 CKPT->pb tensorflow的模型保存有两种形式: 1. ckpt:可以恢 ...

  5. tensorflow:模型的保存和训练过程可视化

    在使用tf来训练模型的时候,难免会出现中断的情况.这时候自然就希望能够将辛辛苦苦得到的中间参数保留下来,不然下次又要重新开始. 保存模型的方法: #之前是各种构建模型graph的操作(矩阵相乘,sig ...

  6. tensorflow 之模型的保存与加载(三)

    前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...

  7. tensorflow 之模型的保存与加载(二)

    上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练. 在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量. #!/ ...

  8. tensorflow 之模型的保存与加载(一)

    怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...

  9. 【TensorFlow】TensorFlow基础 —— 模型的保存读取与可视化方法总结

    TensorFlow提供了一个用于保存模型的工具以及一个可视化方案 这里使用的TensorFlow为1.3.0版本 一.保存模型数据 模型数据以文件的形式保存到本地: 使用神经网络模型进行大数据量和复 ...

随机推荐

  1. Sping Aop日志实现(六)--日志查询功能

    最终效果: 日志查询流程分析: Controller代码: Mapper:

  2. jvm--工具

    jps (java process status) == ps / top 作用:显示所有运行中的java进程. jstat 作用:查看类装载,内存,垃圾收集,jit编译的信息. jinfo 作用:实 ...

  3. Spring Cloud 基于Consul 实现配置服务

    Spring Cloud体系中提供了Config组件来进行配置服务管理.而Consul除了提供服务注册与发现功能外,同时也提供配置管理功能.本位将介绍如何结合Spring Cloud + Consul ...

  4. k8s 开源web操作平台

    https://kuboard.cn/install/install-dashboard.html kuborad

  5. 基于SCADA数据驱动的风电机组部件故障预警

    吴亚联 1 , 梁坤鑫 1 , 苏永新 1* , 詹 俊 2(1.湘潭大学 信息工程学院, 湖南 湘潭 411105: 2.湖南优利泰克自动化系统有限公司, 湖南 长沙 410205) 摘 要: 为提 ...

  6. Android NDK 学习之接受Java传入的Int数组

    本博客主要是在Ubuntu 下开发,且默认你已经安装了Eclipse,Android SDK, Android NDK, CDT插件. 在Eclipse中添加配置NDK,路径如下Eclipse-> ...

  7. 【完整篇】orangepi香橙派新手入门之被官方坑

    图片特意缩小,看不清请打开另一个窗口查看原图. 第一步:烧录系统,我烧录的是Ubuntu_Desktop[请注意!!!!用户名是错的!!用户名是错的!!用户名是错的!!] 正确的用户名是orangep ...

  8. OpenStack kilo版(5) Neutron部署

    neutron简介: Neutron 通过 plugin 和 agent 提供的网络服务. plugin 位于 Neutron server,包括 core plugin 和 service plug ...

  9. Nginx作为缓存服务

    缓存类型 (1) 服务器缓存 服务端缓存一般使用Memcache.Redis (2)代理缓存 (3)客户端缓存 代理缓存流程图 第一步:客户端第一次向Nginx请求数据a: 第二步:当Nginx发现缓 ...

  10. outlook配置其他邮箱登录如qq邮箱或登录无邮件信息记录

    今天加班想想自己outlook还没登登录过,于是想着登录一下outlook方便管理邮箱信息,才发现原来登录邮箱都要配置,感觉真是醉了.下面开始正式的配置流程. 选择添加账户 首先,点击文件选择账户设置 ...