TensorFlow------单层(全连接层)实现手写数字识别训练及测试实例
TensorFlow之单层(全连接层)实现手写数字识别训练及测试实例:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('is_train',1,'指定程序是预测还是训练') def full_connected():
# 获取真实的数据
mnist = input_data.read_data_sets('./data/mnists/', one_hot=True) # 1.建立数据的占位符 X [None,784] y_true [None,10]
# 创建一个作用域
with tf.variable_scope('data'):
# 特征值
x = tf.placeholder(tf.float32, [None, 784]) # 目标值(真实值)
y_true = tf.placeholder(tf.int32, [None, 10]) # 2. 建立一个全连接层的神经网络 W [784,10] b [10]
with tf.variable_scope('fc_model'):
# 随机初始化权重和偏置
weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0), name='w') bias = tf.Variable(tf.constant(0.0, shape=[10])) # 预测None个样本的输出结果matrix [None,784]*[784,10]+[10] = [None,10]
y_predict = tf.matmul(x, weight) + bias # 3. 求出所有样本的损失,然后求平均值
with tf.variable_scope('soft_cross'):
# 求平均交叉熵损失
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict)) # 4. 梯度下降求出损失(优化)
with tf.variable_scope('optimizer'):
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 5. 计算准确率
with tf.variable_scope('acc'):
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1)) # equal_list None个样本 [1,0,1,0,1,1....]
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32)) # 收集变量,单个数字值收集
tf.summary.scalar('losses',loss)
tf.summary.scalar('acc',accuracy) # 高纬度变量收集
tf.summary.histogram('weightes',weight)
tf.summary.histogram('biases',bias) # 定义一个初始化变量的op
init_op = tf.global_variables_initializer() # 定义一个合并变量的op
merged = tf.summary.merge_all() # 创建一个saver
saver = tf.train.Saver() # 开启会话去训练
with tf.Session() as sess:
# 初始化变量
sess.run(init_op) # 建立events文件,然后写入
filewriter = tf.summary.FileWriter('./tmp/summary/test/',graph=sess.graph) if FLAGS.is_train == 1:
# 迭代步数去训练,更新参数预测
for i in range(2000):
# 取出真实存在的特征值和目标值
mnist_x, mnist_y = mnist.train.next_batch(50) # 运行train_op训练
sess.run(train_op, feed_dict={x: mnist_x, y_true: mnist_y}) # 写入每步训练的值
summary = sess.run(merged,feed_dict={x: mnist_x, y_true: mnist_y}) filewriter.add_summary(summary,i) print('训练第%d步,准确率为:%f' % (i, sess.run(accuracy, feed_dict={x: mnist_x, y_true: mnist_y}))) # 保存模型
saver.save(sess,'./tmp/summary/model/fc_model')
else:
# 加载模型
saver.restore(sess,'./tmp/summary/model/fc_model') # 如果是0,做出预测
for i in range(100): # 每次测试一张图片,[0,0,0,0,0,1,0,0,0]
x_test,y_test = mnist.test.next_batch(1) print('第%d章图片,手写数字目标是:%d,预测结果是:%d' % (
i,
tf.argmax(y_test,1).eval(),
tf.argmax(sess.run(y_predict,feed_dict={x: x_test,y_true: y_test}),1).eval()
)) return None if __name__ == '__main__':
full_connected()
TensorFlow------单层(全连接层)实现手写数字识别训练及测试实例的更多相关文章
- 5 TensorFlow入门笔记之RNN实现手写数字识别
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- Tensorflow项目实战一:MNIST手写数字识别
此模型中,输入是28*28*1的图片,经过两个卷积层(卷积+池化)层之后,尺寸变为7*7*64,将最后一个卷积层展成一个以为向量,然后接两个全连接层,第一个全连接层加一个dropout,最后一个全连接 ...
- TensorFlow(十):卷积神经网络实现手写数字识别以及可视化
上代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...
- Tensorflow手写数字识别训练(梯度下降法)
# coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #p ...
- TensorFlow卷积神经网络实现手写数字识别以及可视化
边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- Tensorflow2.0-mnist手写数字识别示例
Tensorflow2.0-mnist手写数字识别示例 读书不觉春已深,一寸光阴一寸金. 简介:通过CNN 卷积神经网络训练后识别出手写图片,测试图片mnist数据集中的0.1.2.4. ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
- 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)
上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...
随机推荐
- python 几种循环性能测试: while, for, 列表生成式, map等
直接上代码: #!/usr/bin/env python # -*- coding: utf-8 -*- # @Time : 2018/07/24 16:23 import itertools imp ...
- LoadRunner中的C Vuser函数
LoadRunner中的C Vuser函数 事务函数: lr_end_sub_transaction 标记子事务的结束以便进行性能分析. lr_end_transaction 标记事务的结束. ...
- Centos7 Elasticsearch部署
(1)ELKStack简介 1.elk介绍 ELK Stack包含:ElasticSearch.Logstash.Kibana ElasticSearch是一个搜索引擎,用来搜索.分析.存储日志.它是 ...
- Ubuntu 16.04 LTS安装Docker并使用加速器
参考优酷:http://v.youku.com/v_show/id_XMTkxOTYwODcxNg==.html?spm=a2h0k.8191407.0.0&from=s1.8-1-1.2 首 ...
- bootstrap插件学习-bootstrap.tab.js(读码)
先看bootstrap-tab.js的结构 var Tab = function ( element ) {} //构造器 Tab.prototype ={} //构造器的原型 $.fn.tab = ...
- thinkphp 检测上传的图片中是否含有木马脚本
1.检测原理 要想检测图片中是否含有木马脚本,首先从制作原理来分析这种木马程序.这种木马程序是十六进制编码写的,图片的十六进制代码中主要包含<% ( ) %>.<? ( ) ?> ...
- 经典算法-最长公共子序列(LCS)与最长公共子串(DP)
public static int lcs(String str1, String str2) { int len1 = str1.length(); int len2 = str2.length() ...
- Shiro的认证原理(Subject#login的背后故事)
登录操作一般都是我们触发的: Subject subject = SecurityUtils.getSubject(); AuthenticationToken authenticationToken ...
- shell 倒引号
`command` 倒引号 (backticks) 在前面的单双引号,括住的是字串,但如果该字串是一列命令列,会怎样?答案是不会执行.要处理这种情况,我们得用倒单引号来做. fdv=`date +%F ...
- 【BZOJ 4572】【SCOI 2016】围棋
http://www.lydsy.com/JudgeOnline/problem.php?id=4572 轮廓线DP:设\(f(i,j,S,x,y)\). \(S\)表示\((i,1)\)到\((i, ...