本章已机器学习领域的Hello World任务----MNIST手写识别做为TensorFlow的开始。MNIST是一个非常简单的机器视觉数据集,是由几万张28像素*28像素的手写数字组成,这些图片只包含灰度值信息。

下面提取了784维的特征,也就是2828个点展开成一维的结果,所以训练数据是一个55000784的Tensor,label是一个55000*10的tensor。当我们处理多分类任务时,通常需要使用Softmax Regression模型。它的工作原理很简单,将可以判定为某类的特征相加,然后将这些特征转化为判定是这一类的概率。其本质就是多类别逻辑回归。

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data",one_hot=True)#从TensorFlow读取数据 print (mnist.train.images.shape,mnist.train.labels.shape)
print (mnist.test.images.shape,mnist.test.labels.shape)
print (mnist.validation.images.shape,mnist.validation.labels.shape) import tensorflow as tf
sess = tf.InteractiveSession()#创建一个session,之后的运算都在这个session里,不同session的数据和运算是相互独立的
x = tf.placeholder(tf.float32,[None,784])#输入数据的地方,第一个参数是数据类型,第二个是tensor的shape W = tf.Variable(tf.zeros([784,10]))#Variable是存储模型参数的,不同于存储数据的tensor一旦使用掉就消失,Variable在模型训练迭代过程中是持久化的。
b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x,W)+b)#实现Softmax Regression算法 y_ = tf.placeholder(tf.float32,[None,10])#定义一个真实的label,与下面的结果做比较
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))#计算模型的loss train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#定义了损失函数之后,再定义一个优化算法,本代码使用SGD算法
tf.global_variables_initializer().run()#使用全局参数初始化器初始化参数 for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(100)#每次选择100条数据
train_step.run({x:batch_xs,y_:batch_ys})#选择好数据之后用SGD算法做迭代 correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))#比较预测结果是否准确
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#转化成准确率
print (accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))#输出结果,正确率为91.57%

TensorFlow实现Softmax Regression识别手写数字的更多相关文章

  1. TensorFlow实现Softmax Regression识别手写数字中"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败”问题

    出现问题: 在使用TensorFlow实现MNIST手写数字识别时,出现"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应 ...

  2. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  3. 使用TensorFlow的卷积神经网络识别手写数字(2)-训练篇

    import numpy as np import tensorflow as tf import matplotlib import matplotlib.pyplot as plt import ...

  4. 【TensorFlow-windows】(一)实现Softmax Regression进行手写数字识别(mnist)

    博文主要内容有: 1.softmax regression的TensorFlow实现代码(教科书级的代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3 ...

  5. 使用TensorFlow的卷积神经网络识别手写数字(3)-识别篇

    from PIL import Image import numpy as np import tensorflow as tf import time bShowAccuracy = True # ...

  6. 使用TensorFlow的卷积神经网络识别手写数字(1)-预处理篇

    功能: 将文件夹下的20*20像素黑白图片,根据重心位置绘制到28*28图片上,然后保存.经过预处理的图片有利于数字的准确识别.参见MNIST对图片的要求. 此处可下载已处理好的图片: https:/ ...

  7. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  8. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  9. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

随机推荐

  1. centos7设置rsyslog日志服务集中服务器

    centos7设置rsyslog日志服务集中服务器 环境:centos6.9_x86_64,自带的rsyslog版本是7.4.7,很多配置都不支持,于是进行升级后配置 # 安装新版本的rsyslog程 ...

  2. datax实例——全量、增量同步

    一.全量同步 本文以mysql -> mysql为示例: 本次测试的表为mysql的系统库-sakila中的actor表,由于不支持目的端自动建表,此处预先建立目的表: CREATE TABLE ...

  3. ROS学习笔记(一)

    运行ROS例程(turtlesim)1)安装turtlesim包sudo apt-get install ros-kinetic-turtlesim2)运行管理器节点roscore3)运行turtle ...

  4. 算法习题---5-2Ducci序列(UVa1594)

    一:题目 对于一个n元组(a1, a2, …, an),可以对于每个数求出它和下一个数的差的绝对值,得到一个新的n元组(|a1-a2|, |a2-a3|, …, |an-a1|).重复这个过程,得到的 ...

  5. k8s记录-kubeadm安装(二)(转载)

    kubeadm安装安装环境(vm6.5下虚拟机3台,centos 7.4):master:10.20.0.191Node1:10.20.0.192Node2:10.20.0.193 1.安装虚拟机,配 ...

  6. matlab基本函数randperm end数组索引

    一起来学演化计算-matlab基本函数randperm end数组索引 觉得有用的话,欢迎一起讨论相互学习~Follow Me 随机排列 语法 p = randperm(n) p = randperm ...

  7. CF1237D Balanced Playlist

    思路:假设从第i首歌开始听,结束位置为j,那么从第i+1首歌开始听,结束位置一定不早于j.可以用反证法证明.想到这一点,就不难解决了. 实现: #include <bits/stdc++.h&g ...

  8. C++STL位标志、智能指针与异常处理

    参考<21天学通C++>第25章节,对STL位标志进行介绍.就是当需要不是像char int long 等整个字节数的数据表示形式,而是使用二进制位表示的时候,通常使用这里讲到的位标志.从 ...

  9. composer安装扩展包异常

    我是tp5.1下,用composer安装扩展包,在命令行运行,无任何不反应,不下载也不报错,这时,我们先ctrl+c退出执行的命令,然后在tp5.1根目录下,找到composer.json文件,并用编 ...

  10. Vmware问题: 开机提示“虚拟机已被打开,是否获得所有权?”& Vmware检测不到USB

    "一只美丽的小鸟,在绿色的草坪上蹦来跳去,很是可爱"----清风徐来 问题1: Vmware开机提示"虚拟机已被打开,是否获得所有权?" 解决: 关闭虚拟机,用 ...