import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets("data/",one_hot = True)
#导入Tensorflwo和mnist数据集 #构建输入层
x = tf.placeholder(tf.float32,[None,784],name='X')
y = tf.placeholder(tf.float32,[None,10],name='Y') #隐藏层神经元数量
H1_NN = 256 #第一层神经元数量
W1 = tf.Variable(tf.random_normal([784,H1_NN])) #权重
b1 = tf.Variable(tf.zeros([H1_NN])) #偏置项
Y1 = tf.nn.relu(tf.matmul(x,W1)+b1) #第一层输出
W2 = tf.Variable(tf.random_normal([H1_NN,10]))#权重
b2 = tf.Variable(tf.zeros(10))#偏置项 forward = tf.matmul(Y1,W2)+b2 #定义前向传播
pred = tf.nn.softmax(forward) #激活函数输出 #损失函数
#loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),
# reduction_indices=1))
#(log(0))超出范围报错 loss_function = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(logits=forward,labels=y)) #训练参数
train_epochs = 50 #训练次数
batch_size = 50 #每次训练多少个样本
total_batch = int(mnist.train.num_examples/batch_size) #随机抽取样本
display_step = 1 #训练情况输出
learning_rate = 0.01 #学习率 #优化器
opimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function) #准确率函数
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #记录开始训练时间
from time import time
startTime = time()
#初始化变量
sess =tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
#训练
for epoch in range(train_epochs):
for batch in range(total_batch):
xs,ys = mnist.train.next_batch(batch_size)#读取批次数据
sess.run(opimizer,feed_dict={x:xs,y:ys})#执行批次数据训练 #total_batch个批次训练完成后,使用验证数据计算误差与准确率
loss,acc=sess.run([loss_function,accuracy],
feed_dict={x:mnist.validation.images,
y:mnist.validation.labels})
#输出训练情况
if(epoch+1) % display_step == 0:
print("Train Epoch:",'%02d' % (epoch + 1),
"Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc))
duration = time()-startTime
print("Trian Finshed takes:","{:.2f}".format(duration))#显示预测耗时 #由于pred预测结果是one_hot编码格式,所以需要转换0~9数字
prediction_resul = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images}) prediction_resul[0:10] #模型评估
accu_test = sess.run(accuracy,
feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Accuray:",accu_test) compare_lists = prediction_resul == np.argmax(mnist.test.labels,1)
print(compare_lists)
err_lists = [i for i in range(len(mnist.test.labels)) if compare_lists[i] == False]
print(err_lists,len(err_lists)) index_list = []
def print_predct_errs(labels,#标签列表
perdiction):#预测值列表
count = 0
compare_lists = (perdiction == np.argmax(labels,1))
err_lists = [i for i in range(len(labels)) if compare_lists[i] == False]
for x in err_lists:
index_list.append(x)
print("index="+str(x)+
"标签值=",np.argmax(labels[x]),
"预测值=",perdiction[x])
count = count+1
print("总计:",count)
return index_list print_predct_errs(mnist.test.labels,prediction_resul) def plot_images_labels_prediction(images,labels,prediction,index,num=25):
fig = plt.gcf() # 获取当前图片
fig.set_size_inches(10,12)
if num>=25:
num=25 #最多显示25张图片
for i in range(0,num):
ax = plt.subplot(5,5, i+1) #获取当前要处理的子图 ax.imshow(np.reshape(images[index],(28,28)),cmap='binary')#显示第index个图像
title = 'label=' + str(np.argmax(labels[index]))#构建该图上要显示的title
if len(prediction)>0:
title += 'predict= '+str(prediction[index]) ax.set_title(title,fontsize=10)
ax.set_xticks([])
ax.set_yticks([])
index += 1
plt.show() plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_resul,index=index_list[100])

单纯记录一下个人代码,很基础的一个MNIST手写识别使用Tensorflwo实现,算是入门的Hello world 了,有些奇怪的问题暂时没有解决 训练次数调成40 在训练到第35次左右发生了梯度爆炸,原因未知,损失函数要使用带softmax那个,不然也会发生梯度爆炸

使用tensorflow实现mnist手写识别(单层神经网络实现)的更多相关文章

  1. 基于tensorflow的MNIST手写识别

    这个例子,是学习tensorflow的人员通常会用到的,也是基本的学习曲线中的一环.我也是! 这个例子很简单,这里,就是简单的说下,不同的tensorflow版本,相关的接口函数,可能会有不一样哟.在 ...

  2. 基于tensorflow实现mnist手写识别 (多层神经网络)

    标题党其实也不多,一个输入层,三个隐藏层,一个输出层 老样子先上代码 导入mnist的路径很长,现在还记不住 import tensorflow as tf import tensorflow.exa ...

  3. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  4. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  5. tensorflow笔记(四)之MNIST手写识别系列一

    tensorflow笔记(四)之MNIST手写识别系列一 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7436310.html ...

  6. tensorflow笔记(五)之MNIST手写识别系列二

    tensorflow笔记(五)之MNIST手写识别系列二 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7455233.html ...

  7. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  8. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  9. Tensorflow编程基础之Mnist手写识别实验+关于cross_entropy的理解

    好久没有静下心来写点东西了,最近好像又回到了高中时候的状态,休息不好,无法全心学习,恶性循环,现在终于调整的好一点了,听着纯音乐突然非常伤感,那些曾经快乐的大学时光啊,突然又慢慢的一下子出现在了眼前, ...

随机推荐

  1. windows的一些好用命令-自己总结:

    在win+R运行框中:     cmd:进入命令行界面     msconfig:可以查看“系统配置”     msinfo32:查看系统信息     services.msc打开"服务&q ...

  2. October 17th 2017 Week 42nd Tuesday

    We execuse our sloth under the pretext of difficulty. 我们常以困难为由,作为懒惰的借口. The process of my system-tra ...

  3. Git提交分支

    Git提交分支操作 1.git add 命令告诉 Git 开始对这些文件进行跟踪 git add . 2.然后提交 git commit -m'这是注释信息' 3.git pull命令用于从另一个存储 ...

  4. 协程运行原理猜测: async/await

    1.根据await调用链寻找最终的生产者或服务提供者: 2.请求服务: 3.进行执行环境切换,跳出顶层函数(第一个无await修饰的函数),执行后面的语句: 4.服务完成,将服务数据复制给最底层的aw ...

  5. Nowcoder 提高组练习赛-R3

    https://www.nowcoder.com/acm/contest/174#question 今天的题好难呀,只有94个人有分.然后我就爆零光荣 考到一半发现我们班要上物理课,还要去做物理实验( ...

  6. import org.apache.http.xxxxxx 爆红,包不存在之解决办法

    问题如下:import org.apache.http.HttpResponse;import org.apache.http.NameValuePair;import org.apache.http ...

  7. 【2018暑假集训模拟一】Day1题解

    T1准确率 [题目描述] 你是一个骁勇善战.日刷百题的OIer. 今天你已经在你OJ 上提交了y 次,其中x次是正确的,这时,你的准确率是x/y.然而,你最喜欢一个在[0; 1] 中的有理数p/q(是 ...

  8. 用BCP从SQL Server 数据库中导出Excel文件

    BCP(Bulk Copy Program)是一种简单高效的数据传输方式在SQL Server中,其他数据传输方式还有SSIS和DTS. 这个程序的主要功能是从数据库中查询Job中指定step的执行信 ...

  9. 纯css实现弹窗左右垂直居中效果

    1.HTML <div class="container"> <div class="dialog"> <div class=&q ...

  10. JavaScript 删除数组中的对象

    1.获得对象在数组中的下标 function (_arr,_obj) { var len = _arr.length; for(var i = 0; i < len; i++){ if(_arr ...