import sys,os
sys.path.append(os.pardir)
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from PIL import Image
import tensorflow as tf def predict():
meta_path = 'ckpt/mnist.ckpt.meta'
model_path = 'ckpt/mnist.ckpt'
sess = tf.InteractiveSession ()
saver = tf.train.import_meta_graph (meta_path)
saver.restore (sess, model_path)
graph = tf.get_default_graph ()
W = graph.get_tensor_by_name ("w:0")
b = graph.get_tensor_by_name ("b:0")
x = tf.placeholder (tf.float32, [None, 784])
y = tf.nn.softmax (tf.matmul (x, W) + b)
keep_prob = tf.placeholder (tf.float32)
batch_xs, batch_ys=mnist.train.next_batch (100)
one_img = batch_xs[0].reshape ((1, 784))
one_num = batch_ys[0].reshape ((1, 10))
temp = sess.run (y, feed_dict={x: one_img, keep_prob: 1.0})
b = sess.run (tf.argmax (temp, 1))
a = sess.run (tf.arg_max (one_num, 1))
print(temp)
print(one_num)
if b == a:
print ("success! the num is :", (b[0]))
showImgTest(one_img)
else:
print ("mistakes predict.") def trainNet():
x = tf.placeholder (tf.float32, [None, 784])
W = tf.Variable (tf.zeros ([784, 10]),name="w")
b = tf.Variable (tf.zeros ([10]),name="b")
y = tf.nn.softmax (tf.matmul (x, W) + b)
y_ = tf.placeholder (tf.float32, [None, 10])
keep_prob = tf.placeholder (tf.float32)
# 定义测试的准确率
correct_prediction = tf.equal (tf.argmax (y, 1), tf.argmax (y_, 1))
accuracy = tf.reduce_mean (tf.cast (correct_prediction, tf.float32))
#
saver = tf.train.Saver (max_to_keep=1)
max_acc = 0
train_accuracy = 0
#交叉熵
cross_entropy = tf.reduce_mean (-tf.reduce_sum (y_ * tf.log (y)))
# cross_error=cross_entropy_error_batch(y,y_)
train_step = tf.train.GradientDescentOptimizer (0.01).minimize (cross_entropy)
sess = tf.InteractiveSession ()
tf.global_variables_initializer ().run ()
for i in range (1000):
batch_xs, batch_ys = mnist.train.next_batch (100)
sess.run (train_step, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 1.0})
if i % 100 == 0:
train_accuracy = accuracy.eval (feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 1.0})
print ("step %d, training accuracy %g" % (i, train_accuracy))
if train_accuracy > max_acc:
max_acc = train_accuracy
saver.save (sess, 'ckpt/mnist.ckpt') if __name__ == '__main__':
mnist = input_data.read_data_sets ("MNIST_data/", one_hot=True)
choice=0
while choice == 0:
print ("------------------------tensorflow--------------------------")
print ("\t\t\t1\ttrain model..")
print("\t\t\t2\tpredict model")
print("\t\t\t3\tshow the first image")
print ("\t\t\t0\texit")
choice = input ("please input your choice!")
print(choice)
if choice == "1":
print("start train...")
trainNet()
if choice=="2":
predict()
if choice=="3":
showImg()

注:正在学习CNN,选项4还没有来的及做。后面补上

使用tensorflow进行mnist数字识别【模型训练+预测+模型保存+模型恢复】的更多相关文章

  1. TensorFlow学习笔记(三)MNIST数字识别问题

    一.MNSIT数据处理 MNSIT是一个非常有名的手写体数字识别数据集.包含60000张训练图片,10000张测试图片.每张图片是28X28的数字. TonserFlow提供了一个类来处理 MNSIT ...

  2. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  3. MNIST数字识别问题

    摘自<Tensorflow:实战Google深度学习框架> import tensorflow as tf from tensorflow.examples.tutorials.mnist ...

  4. Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)

    目录 1. 准备数据集 1.1 MNIST数据集获取: 1.2 程序部分 2. 设计网络结构 2.1 网络设计 2.2 程序部分 3. 迭代训练 4. 测试集预测部分 5. 全部代码 1. 准备数据集 ...

  5. Tensorflow手写数字识别训练(梯度下降法)

    # coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #p ...

  6. TensorFlow深度学习实战---MNIST数字识别问题

    1.滑动平均模型: 用途:用于控制变量的更新幅度,使得模型在训练初期参数更新较快,在接近最优值处参数更新较慢,幅度较小 方式:主要通过不断更新衰减率来控制变量的更新幅度. 衰减率计算公式 : deca ...

  7. 吴裕雄 python 神经网络——TensorFlow实现回归模型训练预测MNIST手写数据集

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

  8. Tensorflow手写数字识别---MNIST

    MNIST数据集:包含数字0-9的灰度图, 图片size为28x28.训练样本:55000,测试样本:10000,验证集:5000

  9. Tensorflow手写数字识别(交叉熵)练习

    # coding: utf-8import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #pr ...

随机推荐

  1. MDRT_<>$表

    数据库中有好多的MDRT打头的表,而这些表的字段都是一样的,那这些表是做什么用呢? MDRT_<>$: 用来存储与空间索引相关的信息.这些表与常规表不一样,不能做复制,删除,新建等.如果对 ...

  2. 基于easyui开发Web版Activiti流程定制器详解(四)——页面结构(下)

    题外话: 这两天周末在家陪老婆和儿子没上来更新请大家见谅!上一篇介绍了调色板和画布区的页面结构,这篇讲解一下属性区的结构也是定制器最重要的一个页面. 属性区整体页面结构如图:  在这个区域可以定义工作 ...

  3. 20165318 2017-2018-2 《Java程序设计》第二周学习总结

    20165318 2017-2018-2 <Java程序设计>第二周学习总结 教材学习内容总结 本周学习了第二章和第三章的内容,掌握了Java中基本数据类型.数组.运算符.表达式和语句等方 ...

  4. macaca自动化初体验

     1.安装#cnpm i -g macaca-cli macaca-ios# Install Tools And Driver2.检查安装环境#macaca doctor ANT_HOME未设置,下载 ...

  5. Docker实战(一)之使用Docker镜像

    镜像是Docker三大核心概念中最为重要的,自Docker诞生之日起“镜像”就是相关社区最为热门的关键字. Docker运行容器前需要本地存在对应的镜像,如果镜像没有保存至本地,Docker会尝试先从 ...

  6. Kafka设计解析(六)Kafka高性能架构之道

    转载自 技术世界,原文链接 Kafka设计解析(六)- Kafka高性能架构之道 本文从宏观架构层面和微观实现层面分析了Kafka如何实现高性能.包含Kafka如何利用Partition实现并行处理和 ...

  7. docker - kubernetes 网络(转)+ 架构图

    1.host网络 连接到 host 网络的容器共享 Docker host 的网络栈,容器的网络配置与 host 完全一样.可以通过--network=host指定使用 host 网络.docker ...

  8. Apache24 + php5.6.31 +Sql server R2 环境搭建①

    win8(7)x64系统下 :PHP5.5.15 + Apache2.4.10 + SQL server 2008 R2  的配置方法分享给大家,32位的同理,不过下载的软件需要也是32位的. 好久未 ...

  9. TarsGo新版本发布,支持protobuf,zipkin和自定义插件

    本文作者:陈明杰(sandyskies) Tars是腾讯从2008年到今天一直在使用的后台逻辑层的统一应用框架,目前支持C++,Java,PHP,Nodejs,Golang语言.该框架为用户提供了涉及 ...

  10. 学习java前端 两种form表单提交方式

    第一种:原生方式 注意点:button标签的style为submit <form action="/trans/doTrans.do" method="post&q ...