tensorflow下识别手写数字基于MLP网络
# coding: utf-8 # In[1]: import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data # In[2]: mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # In[3]: print('train',mnist.train.num_examples,
',validation',mnist.validation.num_examples,
',test',mnist.test.num_examples) # In[4]: print('train images :', mnist.train.images.shape,
'labels:' , mnist.train.labels.shape) # In[5]: import matplotlib.pyplot as plt
def plot_image(image):
plt.imshow(image.reshape(28,28),cmap='binary')
plt.gcf().set_size_inches(2, 4)
plt.show() # In[6]: plot_image(mnist.train.images[0]) # In[7]: import numpy as np
np.argmax(mnist.train.labels[0]) # In[8]: import matplotlib.pyplot as plt
def plot_images_labels_prediction(images,labels,
prediction,idx,num=10):
fig = plt.gcf()
fig.set_size_inches(12, 14)
if num>25: num=25
for i in range(0, num):
ax=plt.subplot(5,5, 1+i) ax.imshow(np.reshape(images[idx],(28, 28)),
cmap='binary') title= "label=" +str(np.argmax(labels[idx]))
if len(prediction)>0:
title+=",predict="+str(prediction[idx]) ax.set_title(title,fontsize=10)
ax.set_xticks([]);ax.set_yticks([])
idx+=1
plt.show() # In[9]: plot_images_labels_prediction(mnist.train.images,
mnist.train.labels,[],0) # In[10]: def layer(output_dim,input_dim,inputs, activation=None):
W = tf.Variable(tf.random_normal([input_dim, output_dim]))
b = tf.Variable(tf.random_normal([1, output_dim]))
XWb = tf.matmul(inputs, W) + b
if activation is None:
outputs = XWb
else:
outputs = activation(XWb)
return outputs # In[11]: x = tf.placeholder("float", [None, 784])
h1=layer(output_dim=256,input_dim=784,
inputs=x ,activation=tf.nn.relu)
y_predict=layer(output_dim=10,input_dim=256,
inputs=h1,activation=None)
y_label = tf.placeholder("float", [None, 10]) # In[12]: loss_function = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits
(logits=y_predict ,
labels=y_label))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_function) # In[13]: correct_prediction = tf.equal(tf.argmax(y_label , 1),
tf.argmax(y_predict, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) # In[14]: trainEpochs = 20
batchSize = 100
totalBatchs = int(mnist.train.num_examples/batchSize)
epoch_list=[];loss_list=[];accuracy_list=[]
from time import time
startTime=time()
sess = tf.Session()
sess.run(tf.global_variables_initializer()) # In[15]: for epoch in range(trainEpochs):
for i in range(totalBatchs):
batch_x, batch_y = mnist.train.next_batch(batchSize)
sess.run(optimizer,feed_dict={x: batch_x,y_label: batch_y}) loss,acc = sess.run([loss_function,accuracy],
feed_dict={x: mnist.validation.images,
y_label: mnist.validation.labels}) epoch_list.append(epoch);
loss_list.append(loss)
accuracy_list.append(acc)
print("Train Epoch:", '%02d' % (epoch+1), "Loss=", "{:.9f}".format(loss)," Accuracy=",acc) duration =time()-startTime
print("Train Finished takes:",duration) # In[16]: get_ipython().magic('matplotlib inline')
import matplotlib.pyplot as plt
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.plot(epoch_list, loss_list, label = 'loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left') # In[17]: plt.plot(epoch_list, accuracy_list,label="accuracy" )
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.ylim(0.8,1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show() # In[18]: print("Accuracy:", sess.run(accuracy,
feed_dict={x: mnist.test.images,
y_label: mnist.test.labels})) # In[19]: prediction_result=sess.run(tf.argmax(y_predict,1),
feed_dict={x: mnist.test.images })
prediction_result[:10] # In[20]: plot_images_labels_prediction(mnist.test.images,
mnist.test.labels,
prediction_result,0) # In[21]: y_predict_Onehot=sess.run(y_predict,
feed_dict={x: mnist.test.images })
y_predict_Onehot[8] # In[22]: for i in range(400):
if prediction_result[i]!=np.argmax(mnist.test.labels[i]):
print("i="+str(i)+
" label=",np.argmax(mnist.test.labels[i]),
"predict=",prediction_result[i]) # In[ ]:
代码如上。
手动建立好输入层,隐层,输出层。

设置损失函数,优化器:

评估方式与准确率:

开始分批次训练:

训练完成后的准确率:

查看某项中的预测概率:

筛选出预测失败的数据:

可以通过:
tf.summary.merge_all()
train_writer = tf.summary.FileWriter('log/area',sess.graph)
保存图。
通过tensorboard --logdir="路径",打开服务,通过输入localhost:6006之类打开网站。
查看生成的图:

tensorflow下识别手写数字基于MLP网络的更多相关文章
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 基于TensorFlow的MNIST手写数字识别-初级
一:MNIST数据集 下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...
- 3 TensorFlow入门之识别手写数字
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- TensorFlow实现Softmax Regression识别手写数字中"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败”问题
出现问题: 在使用TensorFlow实现MNIST手写数字识别时,出现"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应 ...
- 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)
笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...
- 6 TensorFlow实现cnn识别手写数字
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字
TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...
- TensorFlow实战之Softmax Regression识别手写数字
关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...
随机推荐
- MongoDB之Limit选取Skip跳过Sort排序
1.Limit选取 我要从Document中取出多少个 只要2条Document db.Wjs.find().limit(2) 2.Skip跳过 我要跳过多少个Document 我要跳过前两个Docu ...
- http://www.bugku.com:Bugku——SQL注入1(http://103.238.227.13:10087/)
Bugku——SQL注入1(http://103.238.227.13:10087/) 过滤了几乎所有的关键字,尝试绕过无果之后发现,下面有个xss过滤代码.经搜索得该函数会去掉所有的html标签,所 ...
- 转移动APP测试实践
http://blog.csdn.net/hgstclyh/article/details/53115325
- 第一个springMVC小程序
1首先配置一个前端控制器,在web.xml文件中配置(dispatcherservlet) <!-- 前端控制器配置 --> <servlet> <servlet-nam ...
- TZOJ 3665 方格取数(2)(最大点权独立集)
描述 给你一个m*n的格子的棋盘,每个格子里面有一个非负数. 从中取出若干个数,使得任意的两个数所在的格子没有公共边,就是说所取数所在的2个格子不能相邻,并且取出的数的和最大. 输入 包括多个测试实例 ...
- java 空格替换%20
public String replaceSpace(StringBuffer str2) { StringBuffer str4 = new StringBuffer(); int length=s ...
- 【centos】centos命令总结(持续更新)
1.查看系统版本命令 转自:https://blog.csdn.net/networken/article/details/79771212 .查看内核版本 [root@localhost ~]# u ...
- webpack.dev.conf.js
var utils = require('./utils')var webpack = require('webpack')var config = require('../config') // 一 ...
- git查看历史操作
在提交了若干更新,又或者克隆了某个项目之后,偶尔想回顾下过往提交历史.可以使用git log命令来实现. 最简单的查看提交历史命令如下: $ git log $ git log --oneline $ ...
- 队列 和 线程 之GCD dispatch
1.dispatch_queue_create 创建队列开启异步线程(1,4,2,3) // 创建一个队列 dispatch_queue_t queue = dispatch_queue_creat ...