写在前面的话

MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html

参考视频:https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-01-classifier/

MNIST编程

代码全文

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data',one_hot = True) def add_layer(inputs, in_size, out_size, activation_function=None):
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
Wx_plus_b = tf.matmul(inputs, Weights) + biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs def compute_accuracy(v_xs,v_ys):
global prediction
y_pre = sess.run(prediction,{xs:v_xs})
correct_prediction = tf.equal(tf.argmax(y_pre,1),tf.argmax(v_ys,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
result = sess.run(accuracy,{xs:v_xs,ys:v_ys})
return result xs = tf.placeholder(tf.float32,[None,784])
ys = tf.placeholder(tf.float32,[None,10]) # add hiden layer
prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax) # the error
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),reduction_indices=[1])) # train
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(100)
sess.run(train_step,{xs:batch_xs,ys:batch_ys})
if i % 50 == 0:
print(compute_accuracy(mnist.test.images,mnist.test.labels))

打印结果

解释

总的来说,整个程序分为两层,输入层和输出层,没有隐藏层。用到的激励函数为 softmax 函数。

与之前的 tensorflow之曲线拟合 相比,不同并值得记录的有以下几点:

  • 1 one_hot = True

表示使用热编码 ,什么是热编码呢?这里举个例子

如用[1,0,0,0,0,0,0]表示星期一,[0,1,0,0,0,0,0]表示星期二。这里是手写数字识别,识别0~9 共10个数字。所以这里用热编码来表示5的话就是: [0,0,0,0,0,1,0,0,0,0]。在实际预测过程中,预测值可能为 [0,0,0.1,0,0,0.6,0,0,0.3,0] 这样的形式,代表预测到0~9某一数字的概率。

  • 2 prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax)

激励函数为 softmax ,这里为什么用这个函数,对于分类问题, “最后一层输出会使用Softmax函数进行概率化输出”

  • 3 cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),reduction_indices=[1]))

cross_entropy称为交叉熵 ,该误差模型被广泛应用于机器学习的分类问题。

  • 4 mnist.train.next_batch(100)

这里代表每次从训练样本中拿出100个训练样本样本数据来进行训练。进行1000次迭代循环,实际上使用了100 * 1000 个训练样本。

  • 5 compute_accuracy(mnist.test.images,mnist.test.labels)

函数本身的作用用来计算训练模型后用测试样本来检测准确率的。与上一条解释对比可以看到变量的不同。一个是 train ,一个是test 。这样做是用训练数据去训练模型,用测试数据来测试模型,更能测试模型的鲁棒性。

  • 6 tf.argmax()

tf.argmax(vector, 1):返回的是vector中的最大值的索引号,如果vector是一个向量,那就返回一个值,如果是一个矩阵,那就返回一个向量,这个向量的每一个维度都是相对应矩阵行的最大值元素的索引号。

import tensorflow as tf
import numpy as np A = [[1,3,4,5,6]]
B = [[1,3,4], [2,4,1]] with tf.Session() as sess:
print(sess.run(tf.argmax(A, 1)))
print(sess.run(tf.argmax(B, 1))) #输出:
#[4]
#[2 1]

这里返回最大数值的所引值,实际上就是返回了0~9概率最大的那个数值,然后识别的结果和测试样本的正确结果对比。

关于第二个参数,这涉及到axis ,对于axis取值0或1,关系的是在计算上矩阵时是 的方向还是 的方向。这个百度一下就知道。

  • 7 tf.cast()

cast(x, dtype, name=None)

将x的数据格式转化成dtype.例如,原来x的数据格式是bool, 那么将其转化成float以后,就能够将其转化成0和1的序列。反之也可以。例如

a = tf.Variable([1,0,0,1,1])
b = tf.cast(a,dtype=tf.bool)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
print(sess.run(b))
#输出[ True False False True True]

解析参考:

https://blog.csdn.net/uestc_c2_403/article/details/72232807

https://blog.csdn.net/luoganttcc/article/details/70315538

总结

  • 1 对于线性/非线性拟合

    ① 常用激励函数:relu

    ② 常用平方差,计算误差

  • 2 对于分类问题

    ① 最后一层常用Softmax 激励函数进行概率化输出

    ② 常用交叉熵 ,计算误差

  • 3 对于大批量数据,常用分批次训练,next_batch

tensorflow之分类学习的更多相关文章

  1. TensorFlow基础笔记(3) cifar10 分类学习

    TensorFlow基础笔记(3) cifar10 分类学习 CIFAR-10 is a common benchmark in machine learning for image recognit ...

  2. TensorFlow和深度学习入门教程(TensorFlow and deep learning without a PhD)【转】

    本文转载自:https://blog.csdn.net/xummgg/article/details/69214366 前言 上月导师在组会上交我们用tensorflow写深度学习和卷积神经网络,并把 ...

  3. TensorFlow和深度学习新手教程(TensorFlow and deep learning without a PhD)

    前言 上月导师在组会上交我们用tensorflow写深度学习和卷积神经网络.并把其PPT的參考学习资料给了我们, 这是codelabs上的教程:<TensorFlow and deep lear ...

  4. 文本分类学习 (五) 机器学习SVM的前奏-特征提取(卡方检验续集)

    前言: 上一篇比较详细的介绍了卡方检验和卡方分布.这篇我们就实际操刀,找到一些训练集,正所谓纸上得来终觉浅,绝知此事要躬行.然而我在躬行的时候,发现了卡方检验对于文本分类来说应该把公式再变形一般,那样 ...

  5. 文本分类学习 (七)支持向量机SVM 的前奏 结构风险最小化和VC维度理论

    前言: 经历过文本的特征提取,使用LibSvm工具包进行了测试,Svm算法的效果还是很好的.于是开始逐一的去了解SVM的原理. SVM 是在建立在结构风险最小化和VC维理论的基础上.所以这篇只介绍关于 ...

  6. 基于TensorFlow的深度学习系列教程 2——常量Constant

    前面介绍过了Tensorflow的基本概念,比如如何使用tensorboard查看计算图.本篇则着重介绍和整理下Constant相关的内容. 基于TensorFlow的深度学习系列教程 1--Hell ...

  7. 文本分类学习 (十)构造机器学习Libsvm 的C# wrapper(调用c/c++动态链接库)

    前言: 对于SVM的了解,看前辈写的博客加上读论文对于SVM的皮毛知识总算有点了解,比如线性分类器,和求凸二次规划中用到的高等数学知识.然而SVM最核心的地方应该在于核函数和求关于α函数的极值的方法: ...

  8. Reshape以及向量机分类学习和等高线绘制代码

    首先科普一下python里面对于数组的处理,就是如果获取数组大小,以及数组元素数量,这个概念是不一样的,就是一个size和len处理不用.老规矩,上代码: arr2 = np.array([-19.5 ...

  9. TensorFlow (RNN)深度学习 双向LSTM(BiLSTM)+CRF 实现 sequence labeling 序列标注问题 源码下载

    http://blog.csdn.net/scotfield_msn/article/details/60339415 在TensorFlow (RNN)深度学习下 双向LSTM(BiLSTM)+CR ...

随机推荐

  1. 前端使用mobx时,变量已经修改了,为什么组件还是没变化,map类型变量,对象类型变量的值获取问题(主要矛盾发生在组件使用时)

    前天我在使用一个前端多选框组件时遇到了一个问题,明明对象内的值已经修改了,但是组件显示的还是没有效果改变,以下是当时打出的log,我打印了这个对象的信息 对象内的值已经修改了但是组件还是不能及时更改, ...

  2. 基于 HTML5 Canvas 的 Web SCADA 组态电机控制面板

    前言 HT For Web 提供完整的基于 HTML5 图形界面组件库.您可以轻松构建现代化的,跨桌面和移动终端的企业应用,无需担忧跨平台兼容性,及触屏手势交互等棘手问题.也可用于快速创建和部署,高度 ...

  3. 获取当前目录下所有php文件内的函数名

    $dir = dirname(__FILE__); $files = scandir($dir); foreach($files as $name){ if($name == '.' || $name ...

  4. FROM_UNIXTIME/CONCAT

    将mysql查询结果中时间戳转化为时间格式 FROM_UNIXTIME( c.createtime, '%Y-%m-%d %H:%i:%S' ) 2个字段合并查询 CONCAT(d.`name`, ' ...

  5. 【Spark】编程实战之模拟SparkRPC原理实现自定义RPC

    1. 什么是RPC RPC(Remote Procedure Call)远程过程调用.在Hadoop和Spark中都使用了PRC,它是一种通过网络从远程计算机程序上请求服务,而不需要了解底层网络技术的 ...

  6. 20155210 潘滢昊 Java第一次实验---凯撒密码

    Java第一次实验---凯撒密码 实验内容 实现凯撒密码,并进行测试. 实验代码 import java.io.*; import java.util.Scanner; public class ks ...

  7. 20155310马英林 实验2 Windows口令破解

    实 验 报 告 实验名称: 实验二 口令破解 姓名:马英林 学号: 20155310 班级: 1553 日期: 2017.10.24 一. 实验环境 •系统环境:Windows •网络环境:交换网络结 ...

  8. jsp+servlet+javabean开发web项目

    一.介绍: 项目依赖包:jdbc数据库包 jsp+servlet+javabean开发web项目,是最接近web项目原生运行原理的. 但是,jsp内容混乱,项目结构复杂时,代码会混乱 二.运行原理: ...

  9. sqlserver安装遇到的问题——1

    SQL Server安装过无数次,今天第一次遇到这样的问题 一.问题消息复制出来是这样的 TITLE: Microsoft SQL Server 2008 R2 安装程序--------------- ...

  10. [NOIp2018]货币系统 背包

    LG传送门 完全背包板子题 显然就是判断有多少种面值的货币可以被其他面值的货币表示,完全背包搞一搞就好了. 考场代码(一看这两格缩进就知道是考场代码): #include<cstdio> ...