TensorFlow------单层(全连接层)实现手写数字识别训练及测试实例
TensorFlow之单层(全连接层)实现手写数字识别训练及测试实例:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('is_train',1,'指定程序是预测还是训练') def full_connected():
# 获取真实的数据
mnist = input_data.read_data_sets('./data/mnists/', one_hot=True) # 1.建立数据的占位符 X [None,784] y_true [None,10]
# 创建一个作用域
with tf.variable_scope('data'):
# 特征值
x = tf.placeholder(tf.float32, [None, 784]) # 目标值(真实值)
y_true = tf.placeholder(tf.int32, [None, 10]) # 2. 建立一个全连接层的神经网络 W [784,10] b [10]
with tf.variable_scope('fc_model'):
# 随机初始化权重和偏置
weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0), name='w') bias = tf.Variable(tf.constant(0.0, shape=[10])) # 预测None个样本的输出结果matrix [None,784]*[784,10]+[10] = [None,10]
y_predict = tf.matmul(x, weight) + bias # 3. 求出所有样本的损失,然后求平均值
with tf.variable_scope('soft_cross'):
# 求平均交叉熵损失
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict)) # 4. 梯度下降求出损失(优化)
with tf.variable_scope('optimizer'):
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 5. 计算准确率
with tf.variable_scope('acc'):
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1)) # equal_list None个样本 [1,0,1,0,1,1....]
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32)) # 收集变量,单个数字值收集
tf.summary.scalar('losses',loss)
tf.summary.scalar('acc',accuracy) # 高纬度变量收集
tf.summary.histogram('weightes',weight)
tf.summary.histogram('biases',bias) # 定义一个初始化变量的op
init_op = tf.global_variables_initializer() # 定义一个合并变量的op
merged = tf.summary.merge_all() # 创建一个saver
saver = tf.train.Saver() # 开启会话去训练
with tf.Session() as sess:
# 初始化变量
sess.run(init_op) # 建立events文件,然后写入
filewriter = tf.summary.FileWriter('./tmp/summary/test/',graph=sess.graph) if FLAGS.is_train == 1:
# 迭代步数去训练,更新参数预测
for i in range(2000):
# 取出真实存在的特征值和目标值
mnist_x, mnist_y = mnist.train.next_batch(50) # 运行train_op训练
sess.run(train_op, feed_dict={x: mnist_x, y_true: mnist_y}) # 写入每步训练的值
summary = sess.run(merged,feed_dict={x: mnist_x, y_true: mnist_y}) filewriter.add_summary(summary,i) print('训练第%d步,准确率为:%f' % (i, sess.run(accuracy, feed_dict={x: mnist_x, y_true: mnist_y}))) # 保存模型
saver.save(sess,'./tmp/summary/model/fc_model')
else:
# 加载模型
saver.restore(sess,'./tmp/summary/model/fc_model') # 如果是0,做出预测
for i in range(100): # 每次测试一张图片,[0,0,0,0,0,1,0,0,0]
x_test,y_test = mnist.test.next_batch(1) print('第%d章图片,手写数字目标是:%d,预测结果是:%d' % (
i,
tf.argmax(y_test,1).eval(),
tf.argmax(sess.run(y_predict,feed_dict={x: x_test,y_true: y_test}),1).eval()
)) return None if __name__ == '__main__':
full_connected()
TensorFlow------单层(全连接层)实现手写数字识别训练及测试实例的更多相关文章
- 5 TensorFlow入门笔记之RNN实现手写数字识别
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- Tensorflow项目实战一:MNIST手写数字识别
此模型中,输入是28*28*1的图片,经过两个卷积层(卷积+池化)层之后,尺寸变为7*7*64,将最后一个卷积层展成一个以为向量,然后接两个全连接层,第一个全连接层加一个dropout,最后一个全连接 ...
- TensorFlow(十):卷积神经网络实现手写数字识别以及可视化
上代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...
- Tensorflow手写数字识别训练(梯度下降法)
# coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #p ...
- TensorFlow卷积神经网络实现手写数字识别以及可视化
边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- Tensorflow2.0-mnist手写数字识别示例
Tensorflow2.0-mnist手写数字识别示例 读书不觉春已深,一寸光阴一寸金. 简介:通过CNN 卷积神经网络训练后识别出手写图片,测试图片mnist数据集中的0.1.2.4. ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
- 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)
上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...
随机推荐
- Java基础:类加载机制
之前的<java基础:内存模型>当中,我们大体了解了在java当中,不同类型的信息,都存放于java当中哪个部位当中,那么有了对于堆.栈.方法区.的基本理解以后,今天我们来好好剖析一下,j ...
- ubuntu 安装mysql及目录位置
安装 sudo apt-get install MySQL-server mysql-client 查看安装端口情况 sudo netstat -tap | grep mysql 配置文件位置 sud ...
- 视频H5のVideo标签在微信里的坑和技巧
随着 4G 的普遍以及 WiFi 的广泛使用,手机上的网速已经足够稳定和高速,以视频为主的 HTML5 也越来越普遍了,相比帧动画,视频的表现更加丰富,前段时间开发了一个以视频为主的移动端 HTML5 ...
- Apache配置文件相关命令
转:http://www.365mini.com/page/apache-options-directive.htm Options指令是Apache配置文件中一个比较常见也比较重要的指令,Optio ...
- 货币金额javascript正则表达式
最多保留两位小数,货币金额(不能为0): /^(([1-9]\d*)(\.\d{1,2})?)$|^(0\.0?([1-9]\d?))$/
- 【linux入门必备】小白需要掌握的基础知识_不定期更新
因为博主对linux掌握暂时不需要太精通,遇到一个记录一个. 零碎 知识点: 杂类常用命令: 模糊匹配补齐 TAB 系统相关命令: 查阅手册 man 更新软件 sudo apt-get update ...
- Maven实用总结
使用Maven还是推荐IDEA,以前用eclipse总是喜欢出现乱七八糟的问题,具体错误和解决方案也记不清楚了. 下面总结下IDEA中遇到的问题和解决方法: 与IDEA搭配的相关问题 如何根据模板快速 ...
- maven "Generating project in Batch mode"问题的解决
在maven的五分钟入门里面,有这样一个命令: mvn archetype:generate -DgroupId=com.mycompany.app -DartifactId=my-app -Darc ...
- 解决Linux用户模板文件被删除后显示不正常问题
缺失用户模板文件(用户骨架文件)会导致shell提示符不完整,可以到/etc/skel/目录下复制相关文件来恢复 (1).创建测试环境,删除模板文件 [root@xuexi ~]# useradd t ...
- NGUI的异步场景加载进度条
1.直接创建三个场景,其中第二个场景是用来显示进度条加载的界面,进度条用UISlider,不会的看我前面的博文就可以了. 2.这里提供两种方法,建议使用第一种,加载比较平缓 方法一: using Sy ...