使用tensorflow实现mnist手写识别(单层神经网络实现)
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets("data/",one_hot = True)
#导入Tensorflwo和mnist数据集 #构建输入层
x = tf.placeholder(tf.float32,[None,784],name='X')
y = tf.placeholder(tf.float32,[None,10],name='Y') #隐藏层神经元数量
H1_NN = 256 #第一层神经元数量
W1 = tf.Variable(tf.random_normal([784,H1_NN])) #权重
b1 = tf.Variable(tf.zeros([H1_NN])) #偏置项
Y1 = tf.nn.relu(tf.matmul(x,W1)+b1) #第一层输出
W2 = tf.Variable(tf.random_normal([H1_NN,10]))#权重
b2 = tf.Variable(tf.zeros(10))#偏置项 forward = tf.matmul(Y1,W2)+b2 #定义前向传播
pred = tf.nn.softmax(forward) #激活函数输出 #损失函数
#loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),
# reduction_indices=1))
#(log(0))超出范围报错 loss_function = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(logits=forward,labels=y)) #训练参数
train_epochs = 50 #训练次数
batch_size = 50 #每次训练多少个样本
total_batch = int(mnist.train.num_examples/batch_size) #随机抽取样本
display_step = 1 #训练情况输出
learning_rate = 0.01 #学习率 #优化器
opimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function) #准确率函数
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #记录开始训练时间
from time import time
startTime = time()
#初始化变量
sess =tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
#训练
for epoch in range(train_epochs):
for batch in range(total_batch):
xs,ys = mnist.train.next_batch(batch_size)#读取批次数据
sess.run(opimizer,feed_dict={x:xs,y:ys})#执行批次数据训练 #total_batch个批次训练完成后,使用验证数据计算误差与准确率
loss,acc=sess.run([loss_function,accuracy],
feed_dict={x:mnist.validation.images,
y:mnist.validation.labels})
#输出训练情况
if(epoch+1) % display_step == 0:
print("Train Epoch:",'%02d' % (epoch + 1),
"Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc))
duration = time()-startTime
print("Trian Finshed takes:","{:.2f}".format(duration))#显示预测耗时 #由于pred预测结果是one_hot编码格式,所以需要转换0~9数字
prediction_resul = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images}) prediction_resul[0:10] #模型评估
accu_test = sess.run(accuracy,
feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Accuray:",accu_test) compare_lists = prediction_resul == np.argmax(mnist.test.labels,1)
print(compare_lists)
err_lists = [i for i in range(len(mnist.test.labels)) if compare_lists[i] == False]
print(err_lists,len(err_lists)) index_list = []
def print_predct_errs(labels,#标签列表
perdiction):#预测值列表
count = 0
compare_lists = (perdiction == np.argmax(labels,1))
err_lists = [i for i in range(len(labels)) if compare_lists[i] == False]
for x in err_lists:
index_list.append(x)
print("index="+str(x)+
"标签值=",np.argmax(labels[x]),
"预测值=",perdiction[x])
count = count+1
print("总计:",count)
return index_list print_predct_errs(mnist.test.labels,prediction_resul) def plot_images_labels_prediction(images,labels,prediction,index,num=25):
fig = plt.gcf() # 获取当前图片
fig.set_size_inches(10,12)
if num>=25:
num=25 #最多显示25张图片
for i in range(0,num):
ax = plt.subplot(5,5, i+1) #获取当前要处理的子图 ax.imshow(np.reshape(images[index],(28,28)),cmap='binary')#显示第index个图像
title = 'label=' + str(np.argmax(labels[index]))#构建该图上要显示的title
if len(prediction)>0:
title += 'predict= '+str(prediction[index]) ax.set_title(title,fontsize=10)
ax.set_xticks([])
ax.set_yticks([])
index += 1
plt.show() plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_resul,index=index_list[100])
单纯记录一下个人代码,很基础的一个MNIST手写识别使用Tensorflwo实现,算是入门的Hello world 了,有些奇怪的问题暂时没有解决 训练次数调成40 在训练到第35次左右发生了梯度爆炸,原因未知,损失函数要使用带softmax那个,不然也会发生梯度爆炸
使用tensorflow实现mnist手写识别(单层神经网络实现)的更多相关文章
- 基于tensorflow的MNIST手写识别
这个例子,是学习tensorflow的人员通常会用到的,也是基本的学习曲线中的一环.我也是! 这个例子很简单,这里,就是简单的说下,不同的tensorflow版本,相关的接口函数,可能会有不一样哟.在 ...
- 基于tensorflow实现mnist手写识别 (多层神经网络)
标题党其实也不多,一个输入层,三个隐藏层,一个输出层 老样子先上代码 导入mnist的路径很长,现在还记不住 import tensorflow as tf import tensorflow.exa ...
- Tensorflow之基于MNIST手写识别的入门介绍
Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...
- TensorFlow 入门之手写识别(MNIST) softmax算法
TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...
- tensorflow笔记(四)之MNIST手写识别系列一
tensorflow笔记(四)之MNIST手写识别系列一 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7436310.html ...
- tensorflow笔记(五)之MNIST手写识别系列二
tensorflow笔记(五)之MNIST手写识别系列二 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7455233.html ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- Tensorflow编程基础之Mnist手写识别实验+关于cross_entropy的理解
好久没有静下心来写点东西了,最近好像又回到了高中时候的状态,休息不好,无法全心学习,恶性循环,现在终于调整的好一点了,听着纯音乐突然非常伤感,那些曾经快乐的大学时光啊,突然又慢慢的一下子出现在了眼前, ...
随机推荐
- vcenter server appliance(vcsa) 配置IP的方法
方法一: vcenter server appliance 5.1 及以后版本包括5.5,在安装完毕后,console界面是没有网络配置项的,如果需要进行IP配置,可以login后,输入命令yast( ...
- OWASP TOP10(2017)
原文链接:https://www.t00ls.net/viewthread.php?from=notice&tid=39385
- SDN2017 第四次作业
1.阅读 了解SDN控制器的发展 http://www.sdnlab.com/13306.html http://www.docin.com/p-1536626509.html 了解ryu控制器 ht ...
- 乐视4.14硬件免费日de用户体验
此贴用于记录2016年4月14日乐视硬件免费日购买X65超级电视的用户体验.后续将动态更新 我是乐视电视的第一批用户,从乐视上市第一批超级电视,我先后帮助家人.同事.朋友买了6台乐视超级电视,也算是乐 ...
- apache出现You don’t have permission to access / on this server问题的解决
今天在部署一个系统时,在apache中新开了一个VirtualHost,然后设置了DocumentRoot,等访问时却提示“You don’t have permission to access / ...
- 网络对抗技术作业一 P201421410029
网络对抗技术作业一 14网安一区李政浩 201421410029 虚拟机 xp 虚拟机Windows xp的 ip地址 本机win10 IP xp虚拟机与主机ping Dir显示目录 Cd进入目录 A ...
- 测试udp服务的端口是否可用
测试tcp服务的端口是否可用,可以使用: telnet ip port 但是如果这个用在upd服务上,就会报错, 因为telnet走的是tcp协议, 比如说192.168.80.131在8888端 ...
- shiro实战系列(九)之Web
一.Configuration(配置) 将 Shiro 集成到任何 Web 应用程序的最简单的方法是在 web.xml 中配置 ContextListener 和 Filter,理解如何读取 Shir ...
- Hibernate Validator注解大全
hibernate Validator 是 Bean Validation 的参考实现 .Hibernate Validator 提供了 JSR 303 规范中所有内置 constraint 的实现, ...
- greys java在线诊断工具
greys是一个开源的github项目,用来分析运行中的java类.方法等信息. greys工具地址: https://github.com/oldmanpushcart/greys-anatomy/ ...