# coding: utf-8

 # In[1]:

 import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data # In[2]: mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # In[3]: print('train',mnist.train.num_examples,
',validation',mnist.validation.num_examples,
',test',mnist.test.num_examples) # In[4]: print('train images :', mnist.train.images.shape,
'labels:' , mnist.train.labels.shape) # In[5]: import matplotlib.pyplot as plt
def plot_image(image):
plt.imshow(image.reshape(28,28),cmap='binary')
plt.gcf().set_size_inches(2, 4)
plt.show() # In[6]: plot_image(mnist.train.images[0]) # In[7]: import numpy as np
np.argmax(mnist.train.labels[0]) # In[8]: import matplotlib.pyplot as plt
def plot_images_labels_prediction(images,labels,
prediction,idx,num=10):
fig = plt.gcf()
fig.set_size_inches(12, 14)
if num>25: num=25
for i in range(0, num):
ax=plt.subplot(5,5, 1+i) ax.imshow(np.reshape(images[idx],(28, 28)),
cmap='binary') title= "label=" +str(np.argmax(labels[idx]))
if len(prediction)>0:
title+=",predict="+str(prediction[idx]) ax.set_title(title,fontsize=10)
ax.set_xticks([]);ax.set_yticks([])
idx+=1
plt.show() # In[9]: plot_images_labels_prediction(mnist.train.images,
mnist.train.labels,[],0) # In[10]: def layer(output_dim,input_dim,inputs, activation=None):
W = tf.Variable(tf.random_normal([input_dim, output_dim]))
b = tf.Variable(tf.random_normal([1, output_dim]))
XWb = tf.matmul(inputs, W) + b
if activation is None:
outputs = XWb
else:
outputs = activation(XWb)
return outputs # In[11]: x = tf.placeholder("float", [None, 784])
h1=layer(output_dim=256,input_dim=784,
inputs=x ,activation=tf.nn.relu)
y_predict=layer(output_dim=10,input_dim=256,
inputs=h1,activation=None)
y_label = tf.placeholder("float", [None, 10]) # In[12]: loss_function = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits
(logits=y_predict ,
labels=y_label))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_function) # In[13]: correct_prediction = tf.equal(tf.argmax(y_label , 1),
tf.argmax(y_predict, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) # In[14]: trainEpochs = 20
batchSize = 100
totalBatchs = int(mnist.train.num_examples/batchSize)
epoch_list=[];loss_list=[];accuracy_list=[]
from time import time
startTime=time()
sess = tf.Session()
sess.run(tf.global_variables_initializer()) # In[15]: for epoch in range(trainEpochs):
for i in range(totalBatchs):
batch_x, batch_y = mnist.train.next_batch(batchSize)
sess.run(optimizer,feed_dict={x: batch_x,y_label: batch_y}) loss,acc = sess.run([loss_function,accuracy],
feed_dict={x: mnist.validation.images,
y_label: mnist.validation.labels}) epoch_list.append(epoch);
loss_list.append(loss)
accuracy_list.append(acc)
print("Train Epoch:", '%02d' % (epoch+1), "Loss=", "{:.9f}".format(loss)," Accuracy=",acc) duration =time()-startTime
print("Train Finished takes:",duration) # In[16]: get_ipython().magic('matplotlib inline')
import matplotlib.pyplot as plt
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.plot(epoch_list, loss_list, label = 'loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left') # In[17]: plt.plot(epoch_list, accuracy_list,label="accuracy" )
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.ylim(0.8,1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show() # In[18]: print("Accuracy:", sess.run(accuracy,
feed_dict={x: mnist.test.images,
y_label: mnist.test.labels})) # In[19]: prediction_result=sess.run(tf.argmax(y_predict,1),
feed_dict={x: mnist.test.images })
prediction_result[:10] # In[20]: plot_images_labels_prediction(mnist.test.images,
mnist.test.labels,
prediction_result,0) # In[21]: y_predict_Onehot=sess.run(y_predict,
feed_dict={x: mnist.test.images })
y_predict_Onehot[8] # In[22]: for i in range(400):
if prediction_result[i]!=np.argmax(mnist.test.labels[i]):
print("i="+str(i)+
" label=",np.argmax(mnist.test.labels[i]),
"predict=",prediction_result[i]) # In[ ]:

代码如上。

手动建立好输入层,隐层,输出层。

设置损失函数,优化器:

评估方式与准确率:

开始分批次训练:

训练完成后的准确率:

查看某项中的预测概率:

筛选出预测失败的数据:

可以通过:

tf.summary.merge_all()
train_writer = tf.summary.FileWriter('log/area',sess.graph)

保存图。

通过tensorboard --logdir="路径",打开服务,通过输入localhost:6006之类打开网站。

查看生成的图:

tensorflow下识别手写数字基于MLP网络的更多相关文章

  1. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  2. 基于TensorFlow的MNIST手写数字识别-初级

    一:MNIST数据集    下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...

  3. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  4. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  5. TensorFlow实现Softmax Regression识别手写数字中"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败”问题

    出现问题: 在使用TensorFlow实现MNIST手写数字识别时,出现"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应 ...

  6. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  7. 6 TensorFlow实现cnn识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  8. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  9. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

随机推荐

  1. K-邻近算法简单例子

    from numpy import * import operator import matplotlib.pyplot as plt def creatDataset(): group = arra ...

  2. HDU-1087.SuperJUmpingJUmpingJumping.(DP and LISPP)

    本题大意:给定一个长度为n的序列a,让你输出这个序列子序列中元素和最大的最大上升子序列. 本题思路:一开始肯定可以想到用LIS实现,我们用LIS实现的时候可以发现这个问题并不满足LIS问题的最优子结构 ...

  3. TOJ2811: Bessie's Weight Problem(完全背包)

    传送门(<---可以点的) 描述 Bessie, like so many of her sisters, has put on a few too many pounds enjoying t ...

  4. Shell教程 之字符串

    1.Shell字符串 字符串是shell编程中最常用最有用的数据类型(除了数字和字符串,也没啥其它类型好用了),字符串可以用单引号,也可以用双引号,也可以不用引号. 1.1 单引号 str='I am ...

  5. tcp/ip通信第5期之服务器端程序

    /* 此程序是tcp/ip通信服务器端程序,测试运行在redhat5上 重构readline函数,解决粘包问题——利用“\n”识别一个消息边界 */ #include<stdio.h> # ...

  6. 利用jenkins+saltstack+sh 修改nginx配置文件并重新加载

    jenkins的配置(这里作用只是当做界面使用,利用它来管理执行salt命令) 1.构建操作来执行shell脚本 (pillar可以配置灵活的参数) saltstack 的 sls文件编写 nginx ...

  7. sourceTree git的一些命令

    经常使用的三个命令 1.添加修改过的文件到缓冲区 git add. 2.commit到本地 git commit -am ' 更改描述' 3.如果是多人开发的话,中间可能会有别人先提交的这是就需要先把 ...

  8. gearman中worker常驻后台,导致MySQL server has gone away

    产生这个原因主要有如下几点: 1.mysql服务宕机了 2.长时间没有操作,超过了wait_timeout的设置,mysql自动断开 3.mysql请求链接被主动kill 4.发送的请求或返回结果过大 ...

  9. webpack搭建自己的项目

    使用代理的方式参考地址:http://www.cnblogs.com/wangyongcun/p/7665687.html npm 常用模块 npm init 创建package.json文件 一值回 ...

  10. c#Loading 页SplashScreenManager的使用

    一.新建一个加载界面: SplashScreenManager控件只是作为加载界面的统一管理器,我们要使用加载界面,需要自行创建加载界面,两种方法如下: 1.点击SplashScreenManager ...