学习深度学习,首先从深度学习的入门MNIST入手。通过这个例子,了解Tensorflow的工作流程和机器学习的基本概念。

一  MNIST数据集

MNIST是入门级的计算机视觉数据集,包含了各种手写数字的图片。在这个例子中就是通过机器学习训练一个模型,以识别图片中的数字。

MNIST数据集来自 http://yann.lecun.com/exdb/mnist/

Tensorflow提供了一份python代码用于自动下载安装数据集。Tensorflow官方文档中的url打不开,在CSDN上找到了一个分享:http://download.csdn.net/detail/u010417185/9588647

和官方有点不同的是,我直接把四个数据集下载下来,放在/tmp/mnist下,在项目文件中使用以下代码导入:

import input_data
import tensorflow as tf mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)

这里的数据集分为两个部分:60000的训练数据集(mnist.train)和10000的测试数据集(mnist.test),测试集的作用是帮助模型泛化。数据对应包含图片和标签,分别用mnist.train.images,mnist.train.lables,mnist.test.images,mnist.test.lables来表示。每张图片有28×28=784个像素点,因此训练图片mnist.train.images的张量表示为 [60000, 784],第一个纬度用于索引图片,第二纬度用于索引像素点。由于判断10个数字,这里采用热独,即one-hot-vectors,除了一位数字为1外其他纬度数字为0。例如判断数字为0则其表示为[1,0,0,0,0,0,0,0,0,0]。因此训练标签表示为[10000,10],第一纬度索引图片,第二纬度判断数字。

二  softmax回归介绍

softmax模型可以给不同的对象分配概率。根据下图,对输入的x的加权求和,再分别加上一个偏置量,最后输入到softmax函数中:

具体转换为公式,即:

三  实现回归模型

首先进行模型的定义,如下:

x = tf.placeholder(tf.float32, [None, 784]) #使用占位符placeholder,第一维度可指定图片的数量是任意的
W = tf.Variable(tf.zeros([784,10])) #初始化权值
b = tf.Variable(tf.zeros([10])) #初始化偏置值
y = tf.nn.softmax(tf.matmul(x,W) + b) #根据公式计算

四  训练模型

选用的损失函数为交叉熵,其定义如下:

其中y为预测的概率分布,y'为实际分布。

代码如下:

y_ = tf.placeholder("float", [None,10])  #表示实际的分布
cross_entropy = -tf.reduce_sum(y_*tf.log(y)) #计算损失函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #以梯度下降算法最小化损失函数
init = tf.initialize_all_variables() #初始化所有变量
sess = tf.Session() #定义会话
sess.run(init) #初始化会话 for i in range(1000): #开始训练,循环训练1000次
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

五  评估模型

选用tf.argmax函数评估,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签,比如tf.argmax(y,1)返回的是模型对于任一输入x预测到的标签值,而 tf.argmax(y_,1) 代表正确的标签,用 tf.equal 来检测预测是否与真实标签匹配(索引位置一样表示匹配)。

代码如下:

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))  #评估
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float")) #将结果转换为浮点数
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}) #输出

六  代码

import input_data
import tensorflow as tf mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True) x = tf.placeholder(tf.float32, [None, 784]) #使用占位符placeholder,第一维度可指定图片的数量是任意的
W = tf.Variable(tf.zeros([784,10])) #初始化权值
b = tf.Variable(tf.zeros([10])) #初始化偏置值
y = tf.nn.softmax(tf.matmul(x,W) + b) #根据公式计算
y_ = tf.placeholder("float", [None,10]) #表示实际的分布
cross_entropy = -tf.reduce_sum(y_*tf.log(y)) #计算损失函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #以梯度下降算法最小化损失函数
init = tf.initialize_all_variables() #初始化所有变量
sess = tf.Session() #定义会话
sess.run(init) #初始化会话 for i in range(1000): #开始训练,循环训练1000次
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) #评估
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float")) #将结果转换为浮点数
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}) #输出

七  实验结果

最终测试结果精确度在91%左右。

Tensorflow学习笔记(一):MNIST机器学习入门的更多相关文章

  1. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  2. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

  3. TensorFlow框架(3)之MNIST机器学习入门

    1. MNIST数据集 1.1 概述 Tensorflow框架载tensorflow.contrib.learn.python.learn.datasets包中提供多个机器学习的数据集.本节介绍的是M ...

  4. TensorFlow学习笔记(MNIST报错修正 适用Tensorflow1.3)

    在Tensorflow实战Google框架下的深度学习这本书的MNIST的图像识别例子中,每次都要报错   错误如下: Only call `sparse_softmax_cross_entropy_ ...

  5. tensorflow学习笔记————分类MNIST数据集

    在使用tensorflow分类MNIST数据集中,最容易遇到的问题是下载MNIST样本的问题. 一般是通过使用tensorflow内置的函数进行下载和加载, from tensorflow.examp ...

  6. tensorflow学习笔记(10) mnist格式数据转换为TFrecords

    本程序 (1)mnist的图片转换成TFrecords格式 (2) 读取TFrecords格式 # coding:utf-8 # 将MNIST输入数据转化为TFRecord的格式 # http://b ...

  7. MNIST机器学习入门【学习笔记】

    平台信息:PC:ubuntu18.04.i5.anaconda2.cuda9.0.cudnn7.0.5.tensorflow1.10.GTX1060 作者:庄泽彬(欢迎转载,请注明作者) 说明:本文是 ...

  8. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  9. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  10. tensorflow学习笔记二:入门基础 好教程 可用

    http://www.cnblogs.com/denny402/p/5852083.html tensorflow学习笔记二:入门基础   TensorFlow用张量这种数据结构来表示所有的数据.用一 ...

随机推荐

  1. 警惕javascript变量的全局污染问题

    作用域的概念总是和变量形影不离,它不是javascript语言独有的概念,只是其运用上与其他大型语言略有不同,JavaScript语言中采用的是弱类型的变量类型,对使用的数据类型未做出严格的要求,是基 ...

  2. NOIP2009普及组细胞分裂(数论)——yhx

    题目描述 Hanks 博士是 BT (Bio-Tech,生物技术) 领域的知名专家.现在,他正在为一个细胞实 验做准备工作:培养细胞样本. Hanks 博士手里现在有 N 种细胞,编号从 1~N,一个 ...

  3. Web Storage中的sessionStorage和localStorage

    html5中的Web Storage包括了两种存储方式:sessionStorage和localStorage. sessionStorage用于本地存储一个会话(session)中的数据,这些数据只 ...

  4. 页的lock

    文件为什么要加锁? 页的操作为什么要加锁? http://linux.chinaunix.net/techdoc/system/2007/06/11/959844.shtml 上面一个页面有简单介绍什 ...

  5. 【转】【C#】C#性能优化总结

    1.  C#语言方面         1.1 垃圾回收    垃圾回收解放了手工管理对象的工作,提高了程序的健壮性,但副作用就是程序代码可能对于对象创建变得随意.    1.1.1 避免不必要的对象创 ...

  6. C# 通过消息捕获处理窗体最大化/最小化

    通过以下的一些代码可以实现捕获相关的一些消息事件; 以及可以通过调用 SetCloseMenu();实现关闭一些按钮功能如屏蔽关闭按钮功能等; 需要添加命名空间:using System.Runtim ...

  7. 有一家做BPM的公司叫K2,Gartner和IDC都说好!

    有一家公司被Gartner称为成长最快速的BPMS厂商,被IDC称为破坏性创新者… IDC及Gartner均称K2为成长最快速的商务流程管理套装平台(BPMS)厂商.IDC称K2为“破坏性创新者,在关 ...

  8. 记录使用gogs,drone搭建自动部署测试环境

    使用gogs,drone,docker搭建自动部署测试环境 Gogs是一个使用go语言开发的自助git服务,支持所有平台 Docker是使用go开发的开源容器引擎 Drone是一个基于容器技术的持续集 ...

  9. sed 4个功能

    [root@lanny test]# cat test.txt test liyao lanny 经典博文: http://oldboy.blog.51cto.com/2561410/949365 h ...

  10. Linux Shell编程一

    交互模式 --当Shell收到用户输入命令后,就开始执行这项命令,并把结果显示到屏幕上,结束后Shell又会显示系统提示符,等待用户输入下一条命令. 后台运行 --后台运行的符号为"& ...