整体代码:

#数据读取
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) #定义待输入数据的占位符
#mnist中每张照片共有28*28=784个像素点
x = tf.placeholder(tf.float32,[None,784],name="X") #0-9一共10个数字=>10个类别
y = tf.placeholder(tf.float32,[None,10],name="Y") #定义模型变量
#以正态分布的随机数初始化权重W,以常数0初始化偏置b
#在神经网络中,权值W的初始值通常设为正态分布的随机数,偏置项b的初始值通常也设置为正态分布的随机数或常数。
W = tf.Variable(tf.random_normal([784,10],name="W"))
b = tf.Variable(tf.zeros([10]),name="b") #用单个神经元构建神经网络
forward=tf.matmul(x,W) + b #前向计算 #结果分类
#当我们处理多分类任务的时候,通常需要使用Softmax Regression模型。Softmax会对每一类别估算出一个概率。
#工作原理:将判定为某一类的特征相加,然后将这些特征转化为判定是这一类的概率
pred = tf.nn.softmax(forward) #Softmax分类 #设置训练参数
train_epochs = 120 #训练轮数
batch_size = 120 #单次训练样本数(批次大小)
total_batch = int(mnist.train.num_examples/batch_size) #一轮训练有多少批次
display_step = 1 #显示粒度
learning_rate = 0.01 #学习率 #概率估算值需要将预测输出值控制在[0,1]区间内。二元分类问题的目标是正确预测两个可能标签中的一个
#逻辑回归可以用于处理这类问题。二元逻辑回归的损失函数一般采用对数损失函数
#多元分类:逻辑回归可生成介于0到1.0之间的小数。Softmax将这一想法延伸到多类别领域。
#在多类别问题中,Softmax会为每个类别分配一个用小数表示的概率。这些用小数表示的概率相加之和必须是1.0 #交叉熵损失函数:交叉熵是一个信息论的概念,它原来是用来估算平均编码长度的。
#交叉熵刻画的是两个概率分布之间的距离,p代表正确答案,q代表的预测值,交叉熵越小,两个概率的分布越接近
#定义损失函数
loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) #交叉熵 #选择优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function) #梯度下降优化器 #定义准确率
# 检查预测类别tf.argmax(pred,1)与实际类别tf.argmax(y,1)的匹配情况
#argmax()将数组中最大值的下标取出来
correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) #准确率,将布尔值转化为浮点数,并计算平均值 tf.cast()将布尔值投射成浮点数
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #声明会话,初始化变量
sess = tf.Session()
init = tf.global_variables_initializer() #变量初始化
sess.run(init) #训练模型
for epoch in range(train_epochs):
for batch in range(total_batch):
xs,ys = mnist.train.next_batch(batch_size) #读取批次数据
sess.run(optimizer,feed_dict={x:xs,y:ys}) #执行批次训练 #total_batch个批次训练完成后,使用验证数据计算误差与准确率,验证集没有分批
loss,acc = sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels}) #打印训练过程中的详细信息
if (epoch+1) % display_step == 0:
print("Train Epoch:",'%02d'%(epoch+1),"Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc)) print("Train Finished!") #评估模型
#完成训练后,在测试集上评估模型的准确率
accu_test = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Test Accuracy:",accu_test)
#完成训练后,在验证集上评估模型的准确率
accu_validation = sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
print("Test Accuracy:",accu_validation)
#完成训练后,在训练集上评估模型的准确率
accu_train = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
print("Test Accuracy:",accu_train) #应用模型
#在建立模型并进行训练后,若认为准确率可以接受,则可以使用此模型进行预测
#由于pred预测结果是one_hot编码格式,所以需要转换成0~9数字
prediction_result = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images}) #查看预测结果中的前10项
prediction_result[0:10] #定义可视化函数
def plot_images_labels_prediction(images,labels,prediction,index,num=10): #参数: 图形列表,标签列表,预测值列表,从第index个开始显示,缺省一次显示10幅
fig = plt.gcf() #获取当前图表,Get Current Figure
fig.set_size_inches(10,12) #1英寸等于2.45cm
if num > 25 : #最多显示25个子图
num = 25
for i in range(0,num):
ax = plt.subplot(5,5,i+1) #获取当前要处理的子图
ax.imshow(np.reshape(images[index],(28,28)), cmap = 'binary') #显示第index个图像
title = "labels="+str(np.argmax(labels[index])) #构建该图上要显示的title信息
if len(prediction)>0:
title += ",predict="+str(prediction[index]) ax.set_title(title,fontsize=10) #显示图上的title信息
ax.set_xticks([]) #不显示坐标轴
ax.set_yticks([])
index += 1
plt.show()
#可视化预测结果
plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,10,10)

Tensorflow之MNIST手写数字识别:分类问题(2)的更多相关文章

  1. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  2. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  3. 基于TensorFlow的MNIST手写数字识别-初级

    一:MNIST数据集    下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...

  4. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  5. Tensorflow实现MNIST手写数字识别

    之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...

  6. 基于TensorFlow的MNIST手写数字识别-深入

    构建多层卷积神经网络时需要多组W和偏移项b,我们封装2个方法来产生W和b 初级MNIST中用0初始化W和b,这里用噪声初始化进行对称打破,防止产生梯度0,同时用一个小的正值来初始化b避免dead ne ...

  7. TensorFlow——MNIST手写数字识别

    MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/   一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...

  8. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  9. mnist 手写数字识别

    mnist 手写数字识别三大步骤 1.定义分类模型2.训练模型3.评价模型 import tensorflow as tfimport input_datamnist = input_data.rea ...

随机推荐

  1. Linux:用户权限管理

    用户与用户组的概念 超级用户 拥有对系统的最高管理权限,默认是 root 用户 普通用户 只能对自己目录下的文件进行访问和修改,具有登录系统的权限. 虚拟用户 也叫"伪"用户,这类 ...

  2. LeetCode 200. 岛屿数量

    习题地址 https://leetcode-cn.com/problems/number-of-islands/ 给定一个由 '1'(陆地)和 '0'(水)组成的的二维网格,计算岛屿的数量.一个岛被水 ...

  3. Java Web 学习(9) —— EL 与 JSTL

    EL 与 JSTL EL与JSTL的作用是为了减少JSP页面中的代码. EL EL(Expression Language):表达式语言 常用于取值 语法 EL 表达式以${开头,以}结束. 多个表达 ...

  4. Python程序练习题(一)

    Python:程序练习题(一) 1.2 整数序列求和.用户输入一个正整数N,计算从1到N(包含1和N)相加之后的结果. 代码如下: n=input("请输入整数N:") sum=0 ...

  5. Socket抽象层

    目录 一.Socket抽象层 一.Socket抽象层 我们知道两个进程如果需要进行通讯最基本的一个前提是能够唯一标示一个进程,在本地进程通讯中我们可以使用PID来唯一标示一个进程,但PID只在本地唯一 ...

  6. Java连载48-final关键字

    一.final关键字 1.注意点: (1)final是一个关键字,表示最终的,不可变的. (2)final修饰的类无法被继承 (3)final修饰的方法无法被覆盖 (4)final修饰的变量一旦被赋值 ...

  7. python做中学(九)定时器函数的用法

    程序中,经常用到这种,就是需要固定时间执行的,或者需要每隔一段时间执行的.这里经常用的就是Timer定时器.Thread 类有一个 Timer子类,该子类可用于控制指定函数在特定时间内执行一次. 可以 ...

  8. python接口自动化8-unittest框架使用

    前言 unittest:Python单元测试框架,基于Erich Gamma的JUnit和Kent Beck的sSmalltalk测试框架. 一.unittest框架基本使用 unittest需要注意 ...

  9. pytorch_13-图像处理之skimage

    之前程序使用的是PIL(Python image library),今天遇到了另一种图像处理包--skimage. skimage即scikit-image,PIL和Pillow只提供最基础的数字图像 ...

  10. @PostConstruct - 静态方法调用IOC容器Bean对象

    需求:工具类里面引用IOC容器Bean,强迫症患者在调用工具类时喜欢用静态方法的方式而非注入的方式去调用,但是spring 不支持注解注入静态成员变量. 静态变量/类变量不是对象的属性,而是一个类的属 ...