1 import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import matplotlib.pyplot as plt
import numpy as np
mnist=input_data.read_data_sets("MNIST_data/",one_hot=True) #下载据数
print('train images:',mnist.train.images.shape, #查看数据
'labels:',mnist.train.labels.shape)
print('validation images:',mnist.validation.images.shape,
'labels:',mnist.validation.labels.shape)
print('test images:',mnist.test.images.shape,
'labels:',mnist.test.labels.shape
#定义显示多项图像的函数
def plot_images_labels_prediction_3(images,labels,prediction,idx,num=):
fig=plt.gcf()
fig.set_size_inches(,)
if num>:num=
for i in range(,num):
ax=plt.subplot(,,i+)
ax.imshow(np.reshape(images[idx],(,)),cmap='binary')
title='lable='+str(np.argmax(labels[idx]))
if len(prediction)>:
title+=",prediction="+str(prediction[idx])
ax.set_title(title,fontsize=)
ax.set_xticks([]);ax.set_yticks([])
idx+=
plt.show() plot_images_labels_prediction_3(mnist.train.images,mnist.train.labels,[],)
#定义layer函数,构建多层感知器模型
def layer(output_dim,input_dim,inputs,activation=None):
W=tf.Variable(tf.random_normal([input_dim,output_dim]))
b=tf.Variable(tf.random_normal([,output_dim]))
XWb=tf.matmul(inputs,W)+b
if activation is None:
outputs=XWb
else:
outputs=activation(XWb)
return outputs
#建立输入层
x=tf.placeholder("float",[None,])
#建立隐藏层
h1=layer(output_dim=,input_dim=,inputs=x,
activation=tf.nn.relu)
#建立输出层
y_predict=layer(output_dim=,input_dim=,inputs=h1,
activation=None)
y_label=tf.placeholder("float",[None,])
#定义损失函数
loss_function=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits
(logits=y_predict,
labels=y_label))
#定义优化器
optimizer=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_function)
#计算每一项数据是否预测正确
correct_prediction=tf.equal(tf.argmax(y_label,),
tf.argmax(y_predict,))
#计算预测正确结果的平均值
accuracy=tf.reduce_mean(tf.cast(correct_prediction,"float"))
#、定义训练参数
trainEpochs= #设置执行15个训练周期
batchSize= #每一批次项数为100
totalBatchs=int(mnist.train.num_examples/batchSize) #计算每个训练周期
loss_list=[];epoch_list=[];accuracy_list=[] #初始化训练周期、误差、准确率
from time import time #导入时间模块
startTime=time() #开始计算时间
sess=tf.Session() #建立Session
sess.run(tf.global_variables_initializer()) #初始化TensorFlow global 变量
#、进行训练
for epoch in range(trainEpochs):
for i in range(totalBatchs):
batch_x,batch_y=mnist.train.next_batch(batchSize) #使用mnist.train.next_batch方法读取批次数据,传入参数batchSize是100
sess.run(optimizer,feed_dict={x:batch_x,
y_label:batch_y}) #执行批次训练
loss,acc=sess.run([loss_function,accuracy], #使用验证数据计算准确率
feed_dict={x:mnist.validation.images,
y_label:mnist.validation.labels})
epoch_list.append(epoch); #加入训练周期列表
loss_list.append(loss) #加入误差列表
accuracy_list.append(acc) #加入准确率列表
print("Train Epoch:",'%02d' % (epoch+),"Loss=",\
"{:.9f}".format(loss),"Accuracy=",acc)
duration=time()-startTime
print("Train Finished takes:",duration) #计算并显示全部训练所需时间
#画出误差执行结果 fig=plt.gcf()
fig.set_size_inches(,)
plt.plot(epoch_list,loss_list,label='loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'],loc='upper left')
#画出准确率执行结果
plt.plot(epoch_list,accuracy_list,label="accuracy")
fig=plt.gcf()
fig.set_size_inches(,)
plt.ylim(0.8,)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()
#评估模型准确率
print("accuracy:",sess.run(accuracy,
feed_dict={x:mnist.test.images,
y_label:mnist.test.labels}))
#进行预测
#.执行预测
prediction_result=sess.run(tf.argmax(y_predict,),
feed_dict={x:mnist.test.images})
#.预测结果
print(prediction_result[:])
#.显示前10项预测结果
plot_images_labels_prediction_3(mnist.test.images,
mnist.test.labels,
prediction_result,)

运行结果:

TensorFlow—多层感知器—MNIST手写数字识别的更多相关文章

  1. 【TensorFlow-windows】(三) 多层感知器进行手写数字识别(mnist)

    主要内容: 1.基于多层感知器的mnist手写数字识别(代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  2. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  3. Tensorflow实现MNIST手写数字识别

    之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...

  4. TensorFlow——MNIST手写数字识别

    MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/   一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...

  5. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  6. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  7. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  8. 基于TensorFlow的MNIST手写数字识别-初级

    一:MNIST数据集    下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...

  9. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

随机推荐

  1. HDU 2063 过山车(匈牙利算法)

    过山车 Time Limit : 1000/1000ms (Java/Other)   Memory Limit : 32768/32768K (Java/Other) Total Submissio ...

  2. 术语-服务:IaaS

    ylbtech-术语-服务:IaaS IaaS(Infrastructure as a Service),即基础设施即服务.消费者通过Internet 可以从完善的计算机基础设施获得服务.这类服务称为 ...

  3. [转]身份证从 15 >> 18

    身份证号码的结构和表达形式 1.号码的结构 由十七位数字本体码和一位效验码组成.排列顺序从左至右依次为:六位数字地址码,八位数字出生日期码,三位数字顺序码和一位数字效验码.2.地址码 表示编码对象常住 ...

  4. idea配置(卡顿、开发环境等配置),code style template

    Tomcat配置VM Options:    -XX:PermSize=512m -XX:MaxPermSize=1024m 1.IDEA卡顿,修改IDEA使用内存 修改idea配置文件 在IDEA的 ...

  5. Python - Django - App 的概念

    App 方便我们在一个大的项目中,管理实现不同的业务功能 创建 App: 命令行: python manage.py startapp app名 使用 Pycharm 创建: 文件 -> 新建项 ...

  6. SPM——Using Maven+Junit to test Hello Wudi

    Last week, ours teacher taught us 'Software Delivery and Build Management'. And in this class, our t ...

  7. 经典算法 Morris遍历

    内容: 1.什么是morris遍历 2.morris遍历规则与过程 3.先序及中序 4.后序 5.morris遍历时间复杂度分析 1.什么是morris遍历 关于二叉树先序.中序.后序遍历的递归和非递 ...

  8. ansible的安装与使用

    ansible的特点: 1. 基于ssh运行 2. 无需客户端 安装ansible 这里提供四种安装方式,根据自己的需要任选一种即可 1.1使用yum安装 yum install epel-relea ...

  9. shell语法(二)

    Shell脚本语法 条件测试:test. [ ] 命令test或[可以测试一个条件是否成立,如果测试结果为真,则该命令的Exit Status为0,如果测试结果为假,则命令的Exit Status为1 ...

  10. 0_Simple__simplePitchLinearTexture

    对比设备线性二维数组和 CUDA 二维数组在纹理引用中的效率 ▶ 源代码.分别绑定相同大小的设备线性二维数组和 CUDA 二维数组为纹理引用,做简单的平移操作,重复若干次计算带宽和访问速度. #inc ...