import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets("data/",one_hot = True)
#导入Tensorflwo和mnist数据集 #构建输入层
x = tf.placeholder(tf.float32,[None,784],name='X')
y = tf.placeholder(tf.float32,[None,10],name='Y') #隐藏层神经元数量
H1_NN = 256 #第一层神经元数量
W1 = tf.Variable(tf.random_normal([784,H1_NN])) #权重
b1 = tf.Variable(tf.zeros([H1_NN])) #偏置项
Y1 = tf.nn.relu(tf.matmul(x,W1)+b1) #第一层输出
W2 = tf.Variable(tf.random_normal([H1_NN,10]))#权重
b2 = tf.Variable(tf.zeros(10))#偏置项 forward = tf.matmul(Y1,W2)+b2 #定义前向传播
pred = tf.nn.softmax(forward) #激活函数输出 #损失函数
#loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),
# reduction_indices=1))
#(log(0))超出范围报错 loss_function = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(logits=forward,labels=y)) #训练参数
train_epochs = 50 #训练次数
batch_size = 50 #每次训练多少个样本
total_batch = int(mnist.train.num_examples/batch_size) #随机抽取样本
display_step = 1 #训练情况输出
learning_rate = 0.01 #学习率 #优化器
opimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function) #准确率函数
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #记录开始训练时间
from time import time
startTime = time()
#初始化变量
sess =tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
#训练
for epoch in range(train_epochs):
for batch in range(total_batch):
xs,ys = mnist.train.next_batch(batch_size)#读取批次数据
sess.run(opimizer,feed_dict={x:xs,y:ys})#执行批次数据训练 #total_batch个批次训练完成后,使用验证数据计算误差与准确率
loss,acc=sess.run([loss_function,accuracy],
feed_dict={x:mnist.validation.images,
y:mnist.validation.labels})
#输出训练情况
if(epoch+1) % display_step == 0:
print("Train Epoch:",'%02d' % (epoch + 1),
"Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc))
duration = time()-startTime
print("Trian Finshed takes:","{:.2f}".format(duration))#显示预测耗时 #由于pred预测结果是one_hot编码格式,所以需要转换0~9数字
prediction_resul = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images}) prediction_resul[0:10] #模型评估
accu_test = sess.run(accuracy,
feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Accuray:",accu_test) compare_lists = prediction_resul == np.argmax(mnist.test.labels,1)
print(compare_lists)
err_lists = [i for i in range(len(mnist.test.labels)) if compare_lists[i] == False]
print(err_lists,len(err_lists)) index_list = []
def print_predct_errs(labels,#标签列表
perdiction):#预测值列表
count = 0
compare_lists = (perdiction == np.argmax(labels,1))
err_lists = [i for i in range(len(labels)) if compare_lists[i] == False]
for x in err_lists:
index_list.append(x)
print("index="+str(x)+
"标签值=",np.argmax(labels[x]),
"预测值=",perdiction[x])
count = count+1
print("总计:",count)
return index_list print_predct_errs(mnist.test.labels,prediction_resul) def plot_images_labels_prediction(images,labels,prediction,index,num=25):
fig = plt.gcf() # 获取当前图片
fig.set_size_inches(10,12)
if num>=25:
num=25 #最多显示25张图片
for i in range(0,num):
ax = plt.subplot(5,5, i+1) #获取当前要处理的子图 ax.imshow(np.reshape(images[index],(28,28)),cmap='binary')#显示第index个图像
title = 'label=' + str(np.argmax(labels[index]))#构建该图上要显示的title
if len(prediction)>0:
title += 'predict= '+str(prediction[index]) ax.set_title(title,fontsize=10)
ax.set_xticks([])
ax.set_yticks([])
index += 1
plt.show() plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_resul,index=index_list[100])

单纯记录一下个人代码,很基础的一个MNIST手写识别使用Tensorflwo实现,算是入门的Hello world 了,有些奇怪的问题暂时没有解决 训练次数调成40 在训练到第35次左右发生了梯度爆炸,原因未知,损失函数要使用带softmax那个,不然也会发生梯度爆炸

使用tensorflow实现mnist手写识别(单层神经网络实现)的更多相关文章

  1. 基于tensorflow的MNIST手写识别

    这个例子,是学习tensorflow的人员通常会用到的,也是基本的学习曲线中的一环.我也是! 这个例子很简单,这里,就是简单的说下,不同的tensorflow版本,相关的接口函数,可能会有不一样哟.在 ...

  2. 基于tensorflow实现mnist手写识别 (多层神经网络)

    标题党其实也不多,一个输入层,三个隐藏层,一个输出层 老样子先上代码 导入mnist的路径很长,现在还记不住 import tensorflow as tf import tensorflow.exa ...

  3. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  4. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  5. tensorflow笔记(四)之MNIST手写识别系列一

    tensorflow笔记(四)之MNIST手写识别系列一 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7436310.html ...

  6. tensorflow笔记(五)之MNIST手写识别系列二

    tensorflow笔记(五)之MNIST手写识别系列二 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7455233.html ...

  7. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  8. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  9. Tensorflow编程基础之Mnist手写识别实验+关于cross_entropy的理解

    好久没有静下心来写点东西了,最近好像又回到了高中时候的状态,休息不好,无法全心学习,恶性循环,现在终于调整的好一点了,听着纯音乐突然非常伤感,那些曾经快乐的大学时光啊,突然又慢慢的一下子出现在了眼前, ...

随机推荐

  1. 【转】Nginx学习---深入浅出Nginx的介绍

    [原文]https://www.toutiao.com/i6595428119933354500/ Nginx是一款轻量级的Web服务器.反向代理服务器,由于它的内存占用少,启动极快,高并发能力强,在 ...

  2. 根据字体多少使UILabel自动调节尺寸

    原文:http://blog.csdn.net/enuola/article/details/8559588 在大多属性情况下,给UILabel进行动态数据绑定的时候,往往需要根据字符串的多少,动态调 ...

  3. 基于CNN网络的汉字图像字体识别及其原理

    现代办公要将纸质文档转换为电子文档的需求越来越多,目前针对这种应用场景的系统为OCR系统,也就是光学字符识别系统,例如对于古老出版物的数字化.但是目前OCR系统主要针对文字的识别上,对于出版物的版面以 ...

  4. 【Ansible 文档】【译文】主机清单文件

    Inventory 主机清单文件 Ansible 可以对你的基础设施中多个主机系统同时进行操作.通过选择在Ansible的inventory列出的一部分主机来实现.inventory默认保存在/etc ...

  5. mapreduce设置setMapOutputKeyClass与setMapOutputValueClass原因

    一般的mapreduce的wordcount程序如下: public class WcMapper extends Mapper<LongWritable, Text, Text, LongWr ...

  6. PHP foreach 循环使用"&$val" 地址符“&”

    在熟悉项目代码的时候 看到这样的foreach 循环: foreach($data as &$val){ .... } 第一次看到循环里面使用了地址符“&”,我印象中的这个符号 是直接 ...

  7. Hello Shader之Hello Trangle

    这两天配了一下现代OpenGL的开发环境,同时看了一下基础知识和编程规范 写了一个编译GLSL语言的前端程序和一个Hello trangle的程序 另外,推荐两个资源 1.学习网站Learn Open ...

  8. HDU1875+Prim模板

    https://cn.vjudge.net/problem/HDU-1875 相信大家都听说一个“百岛湖”的地方吧,百岛湖的居民生活在不同的小岛中,当他们想去其他的小岛时都要通过划小船来实现.现在政府 ...

  9. jq中each的中断

    最近在做项目中,遇到jq的each方法中的回调函数里面的break不生效,即通过 jquery 的循环方法进行数组遍历,但是当不符合条件时,怎么跳出当前循环,我们经常会习惯JS中的break和cont ...

  10. jqgrid 设置冻结列

    有时,jqgrid表格的列非常多,而表格的宽度值是固定的,我们需要在表格底部出现滚动条,并且固定前面几个列作为数据参照项,如何实现? 需要用的jqgrid冻结列,步骤如下: 1)设置需要冻结的列属性, ...