import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets("data/",one_hot = True)
#导入Tensorflwo和mnist数据集 #构建输入层
x = tf.placeholder(tf.float32,[None,784],name='X')
y = tf.placeholder(tf.float32,[None,10],name='Y') #隐藏层神经元数量
H1_NN = 256 #第一层神经元数量
W1 = tf.Variable(tf.random_normal([784,H1_NN])) #权重
b1 = tf.Variable(tf.zeros([H1_NN])) #偏置项
Y1 = tf.nn.relu(tf.matmul(x,W1)+b1) #第一层输出
W2 = tf.Variable(tf.random_normal([H1_NN,10]))#权重
b2 = tf.Variable(tf.zeros(10))#偏置项 forward = tf.matmul(Y1,W2)+b2 #定义前向传播
pred = tf.nn.softmax(forward) #激活函数输出 #损失函数
#loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),
# reduction_indices=1))
#(log(0))超出范围报错 loss_function = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(logits=forward,labels=y)) #训练参数
train_epochs = 50 #训练次数
batch_size = 50 #每次训练多少个样本
total_batch = int(mnist.train.num_examples/batch_size) #随机抽取样本
display_step = 1 #训练情况输出
learning_rate = 0.01 #学习率 #优化器
opimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function) #准确率函数
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #记录开始训练时间
from time import time
startTime = time()
#初始化变量
sess =tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
#训练
for epoch in range(train_epochs):
for batch in range(total_batch):
xs,ys = mnist.train.next_batch(batch_size)#读取批次数据
sess.run(opimizer,feed_dict={x:xs,y:ys})#执行批次数据训练 #total_batch个批次训练完成后,使用验证数据计算误差与准确率
loss,acc=sess.run([loss_function,accuracy],
feed_dict={x:mnist.validation.images,
y:mnist.validation.labels})
#输出训练情况
if(epoch+1) % display_step == 0:
print("Train Epoch:",'%02d' % (epoch + 1),
"Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc))
duration = time()-startTime
print("Trian Finshed takes:","{:.2f}".format(duration))#显示预测耗时 #由于pred预测结果是one_hot编码格式,所以需要转换0~9数字
prediction_resul = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images}) prediction_resul[0:10] #模型评估
accu_test = sess.run(accuracy,
feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Accuray:",accu_test) compare_lists = prediction_resul == np.argmax(mnist.test.labels,1)
print(compare_lists)
err_lists = [i for i in range(len(mnist.test.labels)) if compare_lists[i] == False]
print(err_lists,len(err_lists)) index_list = []
def print_predct_errs(labels,#标签列表
perdiction):#预测值列表
count = 0
compare_lists = (perdiction == np.argmax(labels,1))
err_lists = [i for i in range(len(labels)) if compare_lists[i] == False]
for x in err_lists:
index_list.append(x)
print("index="+str(x)+
"标签值=",np.argmax(labels[x]),
"预测值=",perdiction[x])
count = count+1
print("总计:",count)
return index_list print_predct_errs(mnist.test.labels,prediction_resul) def plot_images_labels_prediction(images,labels,prediction,index,num=25):
fig = plt.gcf() # 获取当前图片
fig.set_size_inches(10,12)
if num>=25:
num=25 #最多显示25张图片
for i in range(0,num):
ax = plt.subplot(5,5, i+1) #获取当前要处理的子图 ax.imshow(np.reshape(images[index],(28,28)),cmap='binary')#显示第index个图像
title = 'label=' + str(np.argmax(labels[index]))#构建该图上要显示的title
if len(prediction)>0:
title += 'predict= '+str(prediction[index]) ax.set_title(title,fontsize=10)
ax.set_xticks([])
ax.set_yticks([])
index += 1
plt.show() plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_resul,index=index_list[100])

单纯记录一下个人代码,很基础的一个MNIST手写识别使用Tensorflwo实现,算是入门的Hello world 了,有些奇怪的问题暂时没有解决 训练次数调成40 在训练到第35次左右发生了梯度爆炸,原因未知,损失函数要使用带softmax那个,不然也会发生梯度爆炸

使用tensorflow实现mnist手写识别(单层神经网络实现)的更多相关文章

  1. 基于tensorflow的MNIST手写识别

    这个例子,是学习tensorflow的人员通常会用到的,也是基本的学习曲线中的一环.我也是! 这个例子很简单,这里,就是简单的说下,不同的tensorflow版本,相关的接口函数,可能会有不一样哟.在 ...

  2. 基于tensorflow实现mnist手写识别 (多层神经网络)

    标题党其实也不多,一个输入层,三个隐藏层,一个输出层 老样子先上代码 导入mnist的路径很长,现在还记不住 import tensorflow as tf import tensorflow.exa ...

  3. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  4. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  5. tensorflow笔记(四)之MNIST手写识别系列一

    tensorflow笔记(四)之MNIST手写识别系列一 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7436310.html ...

  6. tensorflow笔记(五)之MNIST手写识别系列二

    tensorflow笔记(五)之MNIST手写识别系列二 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7455233.html ...

  7. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  8. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  9. Tensorflow编程基础之Mnist手写识别实验+关于cross_entropy的理解

    好久没有静下心来写点东西了,最近好像又回到了高中时候的状态,休息不好,无法全心学习,恶性循环,现在终于调整的好一点了,听着纯音乐突然非常伤感,那些曾经快乐的大学时光啊,突然又慢慢的一下子出现在了眼前, ...

随机推荐

  1. react的新手基础知识笔记

    <!DOCTYPE html> <html> <head> <script src="../build/react.js">< ...

  2. 团队作业——Beta冲刺2

    团队作业--Beta冲刺 冲刺任务安排 杨光海天 今日任务:根据冲刺内容,具体分配个人任务,对于冲刺内容做准备 明日任务:图片详情界面的开发 吴松青 今日任务:学习熟悉安卓开发,跟随组员快速了解其代码 ...

  3. sql !=与null

    在写SQL 条件语句是经常用到 不等于‘<>’的筛选条件,此时要注意此条件会将字段为null的数据也当做满足不等于的条件而将数据筛选掉. 例:表A A1  B1 1 0 2 1 3 NUL ...

  4. BZOJ3251:树上三角形(乱搞)

    Description 给定一大小为n的有点权树,每次询问一对点(u,v),问是否能在u到v的简单路径上取三个点权,以这三个权值为边长构成一个三角形.同时还支持单点修改. Input 第一行两个整数n ...

  5. Cobalt Strike深入使用

    System Profiler使用 System Profiler 模块,搜集目标的各类机器信息(操作系统版本,浏览器版本等) Attacks->web drive-by->System ...

  6. linux IP 命令使用举例(转)

    ip 1.作用ip是iproute2软件包里面的一个强大的网络配置工具,它能够替代一些传统的网络管理工具,例如ifconfig.route等,使用权限为超级用户.几乎所有的Linux发行版本都支持该命 ...

  7. 学习Kali Linux必须知道的几点

    Kali Linux 在渗透测试和白帽子方面是业界领先的 Linux 发行版.默认情况下,该发行版附带了大量入侵和渗透的工具和软件,并且在全世界都得到了广泛认可.即使在那些甚至可能不知道 Linux ...

  8. Oracle 表删除操作

    删除表内容(dml):delete from 删除表结构(ddl):drop table xx 清空表(ddl):truncate table 清空整张表,不能回滚,不会产生大量日志文件: 表空间会得 ...

  9. debian jessie 网络设置

    从stable更换到testing后,更新系统apt-get dist-upgrade,然后是等待, 然后不耐烦了不等了,关机! 第二天早上开机apt-get update,找不到源! 用ifconf ...

  10. 数字IC设计入门书单

    首发于观芯志 写文章     数字IC设计入门书单 Forever snow   1 年前 作者:Forever snow链接:你所在领域的入门书单? - 知乎用户的回答来源:知乎著作权归作者所有,转 ...