TensorFlow------单层(全连接层)实现手写数字识别训练及测试实例
TensorFlow之单层(全连接层)实现手写数字识别训练及测试实例:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('is_train',1,'指定程序是预测还是训练') def full_connected():
# 获取真实的数据
mnist = input_data.read_data_sets('./data/mnists/', one_hot=True) # 1.建立数据的占位符 X [None,784] y_true [None,10]
# 创建一个作用域
with tf.variable_scope('data'):
# 特征值
x = tf.placeholder(tf.float32, [None, 784]) # 目标值(真实值)
y_true = tf.placeholder(tf.int32, [None, 10]) # 2. 建立一个全连接层的神经网络 W [784,10] b [10]
with tf.variable_scope('fc_model'):
# 随机初始化权重和偏置
weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0), name='w') bias = tf.Variable(tf.constant(0.0, shape=[10])) # 预测None个样本的输出结果matrix [None,784]*[784,10]+[10] = [None,10]
y_predict = tf.matmul(x, weight) + bias # 3. 求出所有样本的损失,然后求平均值
with tf.variable_scope('soft_cross'):
# 求平均交叉熵损失
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict)) # 4. 梯度下降求出损失(优化)
with tf.variable_scope('optimizer'):
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 5. 计算准确率
with tf.variable_scope('acc'):
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1)) # equal_list None个样本 [1,0,1,0,1,1....]
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32)) # 收集变量,单个数字值收集
tf.summary.scalar('losses',loss)
tf.summary.scalar('acc',accuracy) # 高纬度变量收集
tf.summary.histogram('weightes',weight)
tf.summary.histogram('biases',bias) # 定义一个初始化变量的op
init_op = tf.global_variables_initializer() # 定义一个合并变量的op
merged = tf.summary.merge_all() # 创建一个saver
saver = tf.train.Saver() # 开启会话去训练
with tf.Session() as sess:
# 初始化变量
sess.run(init_op) # 建立events文件,然后写入
filewriter = tf.summary.FileWriter('./tmp/summary/test/',graph=sess.graph) if FLAGS.is_train == 1:
# 迭代步数去训练,更新参数预测
for i in range(2000):
# 取出真实存在的特征值和目标值
mnist_x, mnist_y = mnist.train.next_batch(50) # 运行train_op训练
sess.run(train_op, feed_dict={x: mnist_x, y_true: mnist_y}) # 写入每步训练的值
summary = sess.run(merged,feed_dict={x: mnist_x, y_true: mnist_y}) filewriter.add_summary(summary,i) print('训练第%d步,准确率为:%f' % (i, sess.run(accuracy, feed_dict={x: mnist_x, y_true: mnist_y}))) # 保存模型
saver.save(sess,'./tmp/summary/model/fc_model')
else:
# 加载模型
saver.restore(sess,'./tmp/summary/model/fc_model') # 如果是0,做出预测
for i in range(100): # 每次测试一张图片,[0,0,0,0,0,1,0,0,0]
x_test,y_test = mnist.test.next_batch(1) print('第%d章图片,手写数字目标是:%d,预测结果是:%d' % (
i,
tf.argmax(y_test,1).eval(),
tf.argmax(sess.run(y_predict,feed_dict={x: x_test,y_true: y_test}),1).eval()
)) return None if __name__ == '__main__':
full_connected()
TensorFlow------单层(全连接层)实现手写数字识别训练及测试实例的更多相关文章
- 5 TensorFlow入门笔记之RNN实现手写数字识别
		
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
 - Tensorflow项目实战一:MNIST手写数字识别
		
此模型中,输入是28*28*1的图片,经过两个卷积层(卷积+池化)层之后,尺寸变为7*7*64,将最后一个卷积层展成一个以为向量,然后接两个全连接层,第一个全连接层加一个dropout,最后一个全连接 ...
 - TensorFlow(十):卷积神经网络实现手写数字识别以及可视化
		
上代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...
 - Tensorflow手写数字识别训练(梯度下降法)
		
# coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #p ...
 - TensorFlow卷积神经网络实现手写数字识别以及可视化
		
边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...
 - [Python]基于CNN的MNIST手写数字识别
		
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
 - Tensorflow2.0-mnist手写数字识别示例
		
Tensorflow2.0-mnist手写数字识别示例 读书不觉春已深,一寸光阴一寸金. 简介:通过CNN 卷积神经网络训练后识别出手写图片,测试图片mnist数据集中的0.1.2.4. ...
 - 深度学习之PyTorch实战(3)——实战手写数字识别
		
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
 - 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)
		
上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...
 
随机推荐
- 谈谈对final的理解
			
1.final修饰类 类不能被继承,类中的所有方法都是final的 2.final修饰方法 方法不能被覆盖,private修饰的方法隐性的添加了final 3.final修饰方法内参数 方法内的参数不 ...
 - mocha测试es6问题
			
平时在写完正常的逻辑代码后,需要使用单元测试去测试逻辑代码,现在比较流行的是使用mocha进行测试 现在都是使用es6的写法,如果直接使用mocha test\某个文件,会出现下面的错误,原因是因为m ...
 - YII2 源码阅读 综述
			
如何阅读源码呢? 我的方法是,打开xdebug的auto_trace [XDebug] ;xdebug.profiler_append = 0 ;xdebug.profiler_enable = 1 ...
 - CSS3主要的几个样式笔记
			
1.边框:border-color: 设置对象边框的颜色. 使用CSS3的border-radius属性,如果你设置了border的宽度是X px,那么你就可以在这个border上使用X ...
 - 深入解析php中的foreach问题
			
本篇文章是对php中的foreach问题进行了详细的分析介绍,需要的朋友参考下 前言:php4中引入了foreach结构,这是一种遍历数组的简单方式.相比传统的for循环,foreach能够更加便 ...
 - Flume学习应用:Java写日志数据到MongoDB
			
概述 Windows平台:Java写日志到Flume,Flume最终把日志写到MongoDB. 系统环境 操作系统:win7 64 JDK:1.6.0_43 资源下载 Maven:3.3.3下载.安装 ...
 - Python开发基础-Day7-闭包函数和装饰器基础
			
补充:全局变量声明及局部变量引用 python引用变量的顺序: 当前作用域局部变量->外层作用域变量->当前模块中的全局变量->python内置变量 global关键字用来在函数或其 ...
 - PHP单例类
			
单例模式按字面来看就是某一个类只有一个实例,这样做的好处还是很大的,比如说数据库的连接,我们只需要实例化一次,不需要每次都去new了,这样极大的降低了资源的耗费. 单例类至少拥有以下三种公共元素: 必 ...
 - 【BZOJ 3175】 3175: [Tjoi2013]攻击装置(二分图匹配)
			
3175: [Tjoi2013]攻击装置 Description 给定一个01矩阵,其中你可以在0的位置放置攻击装置.每一个攻击装置(x,y)都可以按照“日”字攻击其周围的 8个位置(x-1,y-2) ...
 - 初雪-Diary?
			
who care ------------2018 11 6-------------- 终于AK一场啦 ------------2018 10 18-------------- 嗯....今天T2多 ...