TensorFlow------单层(全连接层)实现手写数字识别训练及测试实例
TensorFlow之单层(全连接层)实现手写数字识别训练及测试实例:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('is_train',1,'指定程序是预测还是训练') def full_connected():
# 获取真实的数据
mnist = input_data.read_data_sets('./data/mnists/', one_hot=True) # 1.建立数据的占位符 X [None,784] y_true [None,10]
# 创建一个作用域
with tf.variable_scope('data'):
# 特征值
x = tf.placeholder(tf.float32, [None, 784]) # 目标值(真实值)
y_true = tf.placeholder(tf.int32, [None, 10]) # 2. 建立一个全连接层的神经网络 W [784,10] b [10]
with tf.variable_scope('fc_model'):
# 随机初始化权重和偏置
weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0), name='w') bias = tf.Variable(tf.constant(0.0, shape=[10])) # 预测None个样本的输出结果matrix [None,784]*[784,10]+[10] = [None,10]
y_predict = tf.matmul(x, weight) + bias # 3. 求出所有样本的损失,然后求平均值
with tf.variable_scope('soft_cross'):
# 求平均交叉熵损失
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict)) # 4. 梯度下降求出损失(优化)
with tf.variable_scope('optimizer'):
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 5. 计算准确率
with tf.variable_scope('acc'):
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1)) # equal_list None个样本 [1,0,1,0,1,1....]
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32)) # 收集变量,单个数字值收集
tf.summary.scalar('losses',loss)
tf.summary.scalar('acc',accuracy) # 高纬度变量收集
tf.summary.histogram('weightes',weight)
tf.summary.histogram('biases',bias) # 定义一个初始化变量的op
init_op = tf.global_variables_initializer() # 定义一个合并变量的op
merged = tf.summary.merge_all() # 创建一个saver
saver = tf.train.Saver() # 开启会话去训练
with tf.Session() as sess:
# 初始化变量
sess.run(init_op) # 建立events文件,然后写入
filewriter = tf.summary.FileWriter('./tmp/summary/test/',graph=sess.graph) if FLAGS.is_train == 1:
# 迭代步数去训练,更新参数预测
for i in range(2000):
# 取出真实存在的特征值和目标值
mnist_x, mnist_y = mnist.train.next_batch(50) # 运行train_op训练
sess.run(train_op, feed_dict={x: mnist_x, y_true: mnist_y}) # 写入每步训练的值
summary = sess.run(merged,feed_dict={x: mnist_x, y_true: mnist_y}) filewriter.add_summary(summary,i) print('训练第%d步,准确率为:%f' % (i, sess.run(accuracy, feed_dict={x: mnist_x, y_true: mnist_y}))) # 保存模型
saver.save(sess,'./tmp/summary/model/fc_model')
else:
# 加载模型
saver.restore(sess,'./tmp/summary/model/fc_model') # 如果是0,做出预测
for i in range(100): # 每次测试一张图片,[0,0,0,0,0,1,0,0,0]
x_test,y_test = mnist.test.next_batch(1) print('第%d章图片,手写数字目标是:%d,预测结果是:%d' % (
i,
tf.argmax(y_test,1).eval(),
tf.argmax(sess.run(y_predict,feed_dict={x: x_test,y_true: y_test}),1).eval()
)) return None if __name__ == '__main__':
full_connected()
TensorFlow------单层(全连接层)实现手写数字识别训练及测试实例的更多相关文章
- 5 TensorFlow入门笔记之RNN实现手写数字识别
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- Tensorflow项目实战一:MNIST手写数字识别
此模型中,输入是28*28*1的图片,经过两个卷积层(卷积+池化)层之后,尺寸变为7*7*64,将最后一个卷积层展成一个以为向量,然后接两个全连接层,第一个全连接层加一个dropout,最后一个全连接 ...
- TensorFlow(十):卷积神经网络实现手写数字识别以及可视化
上代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...
- Tensorflow手写数字识别训练(梯度下降法)
# coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #p ...
- TensorFlow卷积神经网络实现手写数字识别以及可视化
边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- Tensorflow2.0-mnist手写数字识别示例
Tensorflow2.0-mnist手写数字识别示例 读书不觉春已深,一寸光阴一寸金. 简介:通过CNN 卷积神经网络训练后识别出手写图片,测试图片mnist数据集中的0.1.2.4. ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
- 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)
上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...
随机推荐
- LeetCode解题报告—— Sum Root to Leaf Numbers & Surrounded Regions & Single Number II
1. Sum Root to Leaf Numbers Given a binary tree containing digits from 0-9 only, each root-to-leaf p ...
- redis之(一)redis的简单介绍
[一]:概念 --->Redis是一个开源的,高性能的,基于键值对的缓存与存储系统 --->Redis数据库中的多有数据都存储在内存中,由于内存的读写速度远快于硬盘,一秒读写超过10万键值 ...
- Longest Valid Parentheses——仍然需要认真看看(动态规划)
Given a string containing just the characters '(' and ')', find the length of the longest valid (wel ...
- [vim]使用中问题
bug1: vim文档中文注释为乱码 step1: vim /var/lib/locales/supported.d/local 在其中添加下面的中文字符集 zh_CN.GBK GBK zh_CN.G ...
- 利用CSS3伪类做3D按钮
这是通过css3伪类实现的3d按钮,html代码为: <div id="container_buttons"> <p><a class="a ...
- bzoj 2938 AC自动机 + dfs判环
#include<bits/stdc++.h> #define LL long long #define ll long long #define fi first #define se ...
- ubuntu wine 使用
运行程序 wine xxx.exe 图形界面程序(普通程序):直接使用 wine 命令行的DOS程序:wineconsole 代替 wine.这才是正常的运行方式.不使用wineconsole运行命令 ...
- cocos-js Http方式网络请求
(转http://blog.csdn.net/sinat_28338727/article/details/52804167) 网络结构 网络结构是网络的构建方式,目前流行的有客户端服务器结构网络和点 ...
- thinkphp join 表前缀
public function get_user_group_title($uid){ $pre = C('DB_PREFIX'); $res = M('AuthGroupAccess aga')-& ...
- 用Logger来解释拦截
HZ 动态代理学了 不知道在工作中杂用哦 HE 现在一般不会直接用吧,一般都是用aspectJ这种完整aop的实现 STST 拦截方法调用 HZ 我见过把所有accessor方法放到切面的 还有tra ...