TensorFlow 入门之手写识别(MNIST) softmax算法

MNIST
flyu6
softmax回归

softmax回归算法

我们知道MNIST的每一张图片都表示一个数字,从0到9。我们希望得到给定图片代表每个数字的概率。比如说,我们的模型可能推测一张包含9的图片代表数字9的概率是80%但是判断它是8的概率是5%(因为8和9都有上半部分的小圆),然后给予它代表其他数字的概率更小的值。
这是一个使用softmax回归(softmax regression)模型的经典案例。 softmax 模型可以用来给不同的对象分配概率。即使在之后,我们训练更加精细的模型时,最后一步也需要用softmax来分配概率。
这是一个使用softmax回归(softmax regression)模型的经典案例。softmax模型可以用来给不同的对象分配概率。即使在之后,我们训练更加精细的模型时,最后一步也需要用softmax来分配概率。

softmax回归(softmax regression)分两步:第一步

为了得到一张给定图片属于某个特定数字类的证据(evidence),我们对图片像素值进行加权求和。如果这个像素具有很强的证据说明这张图片不属于该类,那么相应的权值为负数,相反如果这个像素拥有有利的证据支持这张图片属于这个类,那么权值是正数。

下面的图片显示了一个模型学习到的图片上每个像素对于特定数字类的权值。红色代表负数权值,蓝色代表正数权值。

数字的特征

我们也需要加入一个额外的偏置量(bias),因为输入往往会带有一些无关的干扰量。因此对于给定的输入图片x它代表的是数字i的证据可以表示为

求和

其中Wi代表权重,bi 代表数字 i 类的偏置量,j 代表给定图片 x 的像素索引用于像素求和。然后用softmax函数可以把这些证据转换成概率 y:

激励函数

这的softmax可是看做是一个sigmoid形式的函数。把我们定义的线性函数的输出转换成我们想要的格式,也就是关于10个数字类的概率分布。因此,给定一张图片,它对于每一个数字的吻合度可以被softmax函数转换成为一个概率值。

归一化处理

展开等式右边的子式,可以得到:

softmax使用的公式

对于softmax回归模型可以用下面的图解释,对于输入的xs加权求和,再分别加上一个偏置量,最后再输入到softmax函数中:

softmax运行方式

如果把它写成一个等式,我们可以得到:

softmax数学表达式

我们也可以用向量表示这个计算过程:用矩阵乘法和向量相加。这有助于提高计算效率。(也是一种更有效的思考方式):

softmax矩阵表现形式

更进一步,可以写成更加紧凑的方式:

最终会使用的表达式

TensorFlow实现softmax

  1. # create a softmax regression 


  2. import tensorflow as tf 

  3. from tensorflow.examples.tutorials.mnist import input_data 


  4. mnist = input_data.read_data_sets("/home/fly/TensorFlow/mnist", one_hot=True) 


  5. x = tf.placeholder(tf.float32,[None, 784]) 


  6. W = tf.Variable(tf.zeros([784, 10])) 


  7. b = tf.Variable(tf.zeros([10])) 


  8. y = tf.nn.softmax(tf.matmul(x,W)+b) 


  9. y_ = tf.placeholder(tf.float32,[None, 10]) 


  10. cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 


  11. train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 


  12. init = tf.initialize_all_variables() 

  13. sess = tf.Session() 

  14. sess.run(init) 


  15. for i in range(1000): 

  16. batch_xs, batch_ys = mnist.train.next_batch(100) 

  17. sess.run(train_step, feed_dict = {x: batch_xs, y_: batch_ys}) 


  18. correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) 

  19. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 

  20. print(sess.run(accuracy, feed_dict={x:mnist.test.images, y_: mnist.test.labels})) 



Fly
2016.6

TensorFlow 入门之手写识别(MNIST) softmax算法的更多相关文章

  1. TensorFlow 入门之手写识别(MNIST) softmax算法 二

    TensorFlow 入门之手写识别(MNIST) softmax算法 二 MNIST Fly softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  2. TensorFlow 入门之手写识别(MNIST) 数据处理 一

    TensorFlow 入门之手写识别(MNIST) 数据处理 一 MNIST Fly softmax回归 准备数据 解压 与 重构 手写识别入门 MNIST手写数据集 图片以及标签的数据格式处理 准备 ...

  3. TensorFlow 入门之手写识别CNN 三

    TensorFlow 入门之手写识别CNN 三 MNIST 卷积神经网络 Fly 多层卷积网络 多层卷积网络的基本理论 构建一个多层卷积网络 权值初始化 卷积和池化 第一层卷积 第二层卷积 密集层连接 ...

  4. densenet tensorflow 中文汉字手写识别

    densenet 中文汉字手写识别,代码如下: import tensorflow as tf import os import random import math import tensorflo ...

  5. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  6. TensorFlow MNIST(手写识别 softmax)实例运行

    TensorFlow MNIST(手写识别 softmax)实例运行 首先要有编译环境,并且已经正确的编译安装,关于环境配置参考:http://www.cnblogs.com/dyufei/p/802 ...

  7. 使用tensorflow实现mnist手写识别(单层神经网络实现)

    import tensorflow as tf import tensorflow.examples.tutorials.mnist.input_data as input_data import n ...

  8. tensorflow笔记(四)之MNIST手写识别系列一

    tensorflow笔记(四)之MNIST手写识别系列一 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7436310.html ...

  9. tensorflow笔记(五)之MNIST手写识别系列二

    tensorflow笔记(五)之MNIST手写识别系列二 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7455233.html ...

随机推荐

  1. <C++ 实现设计模式> 观察者模式

    观察者模式,又称公布--订阅,mvc模式等. 通俗点讲,比方股票来说,非常多人关注一支股票,派一个人去观察股票的情况,一有变化(观察),就通知全部的预定这个消息的人. 而我们常见的mvc模式,v是指v ...

  2. java抽象类和接口的区别(转载)

    1.Java接口和Java抽象类最大的一个区别,就在于Java抽象类可以提供某些方法的部分实现,而Java接口不可以,这大概就是Java抽象类唯一的优点吧,但这个优点非常有用. 如果向一个抽象类里加入 ...

  3. Android开发之Mediaplayer

    Android提供了常见的音频.视频的编码.解码机制.借助于多媒体类MediaPlayer的支持,开发者能够非常方便在在应用中播放音频.视频.本篇博客主要解说在Android平台下怎样播放一个音频文件 ...

  4. log4e插件的安装和使用

    1.首先下载log4e小工具.放入myeclipse10安装文件夹D:\Program Files (x86)\myEclipse10\MyEclipse Blue Edition 10\dropin ...

  5. jmeter参数化之CSV Data Set Config

    在jmeter中,可以用CSV Data Set Config实现参数化. 1.准备参数化数据

  6. ”Validation of viewstate MAC failed” 错误

    ”Validation of viewstate MAC failed” 错误 在ASP.NET里面,View State使用较为广泛.它作为一个隐藏字段,可以帮助服务端”记住“客户端的改变,这样客户 ...

  7. CentOS 7 结构体GCC 4.8.2 32位编译环境

    centos 7 结构体gcc 32位编译环境 1介绍 1.1背景 学习新 C++ 2011和C11标准. 1.2使用软件 CentOS 7(Linux version 3.10.0-123.el7. ...

  8. Java泛型和集合之泛型介绍

    在声明一个接口和类的时候可以使用尖括号带有一个或者多个参数但是当你在声明属于一个接口或者类的变量的时候或者你在创建一个类实例的时候需要提供他们的具体类型.我们来看下下面这个例子 List<Str ...

  9. 解决IIS7 HTTP/405 Method Not Allowed 问题的方法.

    1.处理程序映射 2.添加脚本映射 3.请求路径:*.html 4.可执行文件:C:/windows/system32/inetsrv/asp.dll 5.请求限制-谓词:输入需要允许请求的谓词(po ...

  10. js面向对象学习总结

    1.函数作为参数进行传递 function a(str,fun){ console.log(fun(str)) }; function up(str){ return str.toUpperCase( ...