TensorFlow——深入MNIST
程序(有些不甚明白的地方改日修订):
# _*_coding:utf-8_*_ import inputdata
mnist = inputdata.read_data_sets('MNIST_data', one_hot=True) # mnist是一个以numpy数组形式存储训练、验证和测试数据的轻量级类 import tensorflow as tf
sess = tf.InteractiveSession() x = tf.placeholder("float",shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10]) W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10])) sess.run(tf.initialize_all_variables()) y = tf.nn.softmax(tf.matmul(x,W)+b) # nn:neural network # 代价函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) # 最优化算法
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) # 会更新权值 for i in range(1000):
batch = mnist.train.next_batch(50)
train_step.run(feed_dict={x:batch[0], y_:batch[1]}) # 可以用feed_dict来替代任何张量,并不仅限于替换placeholder correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) print accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels}) # 构建多层卷积网络模型 # 初始化W,b的函数
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1) # truncated_normal表示的是截断的正态分布
return tf.Variable(initial) def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) # 卷积和池化
def conv2d(x, W): # 卷积用原版,1步长0边距
return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME') def max_pool_2x2(x): # 池化用传统的2*2模板做max polling
return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') # 第一层卷积
W_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32]) x_image = tf.reshape(x, [-1,28,28,1]) h_conv1= tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1) # 第二层卷积
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64]) h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2) # 密集连接层
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) # dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) # 输出层
W_fc2= weight_variable([1024, 10])
b_fc2 = bias_variable([10]) y_conv= tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) # 训练和评估模型
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.initialize_all_variables())
for i in range(20000):
batch = mnist.train.next_batch(50)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict={ x:batch[0], y_: batch[1], keep_prob: 1.0})
print "step %d, training accuracy %g" %(i, train_accuracy)
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob:0.5}) print "test accuracy %g" % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
运行结果:
0.9092
step 0, training accuracy 0.08
step 100, training accuracy 0.9
step 200, training accuracy 0.94
step 300, training accuracy 0.98
step 400, training accuracy 0.98
step 500, training accuracy 0.9
step 600, training accuracy 0.96
step 700, training accuracy 0.96
step 800, training accuracy 0.96
step 900, training accuracy 0.94
step 1000, training accuracy 0.98
step 1100, training accuracy 1
step 1200, training accuracy 0.92
step 1300, training accuracy 0.96
step 1400, training accuracy 0.92
step 1500, training accuracy 0.98
...明天早上跑出来再贴
TensorFlow——深入MNIST的更多相关文章
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- Ubuntu16.04安装TensorFlow及Mnist训练
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com TensorFlow是Google开发的开源的深度学习框架,也是当前使用最广泛的深度学习框架. 一.安 ...
- 一个简单的TensorFlow可视化MNIST数据集识别程序
下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 使用Tensorflow操作MNIST数据
MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例.而TensorFlow的封装让使用MNIST数据集变得更加方便.MNIST数据集是NIST数据集的 ...
- TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架
TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架 http://blog.sina.com.cn/s/blog_4b0020f30102wv4l.html
- 2、TensorFlow训练MNIST
装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...
- 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门
2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...
- Tensorflow之MNIST的最佳实践思路总结
Tensorflow之MNIST的最佳实践思路总结 在上两篇文章中已经总结出了深层神经网络常用方法和Tensorflow的最佳实践所需要的知识点,如果对这些基础不熟悉,可以返回去看一下.在< ...
- TensorFlow训练MNIST报错ResourceExhaustedError
title: TensorFlow训练MNIST报错ResourceExhaustedError date: 2018-04-01 12:35:44 categories: deep learning ...
随机推荐
- ionic 2 起航 控件的使用 客户列表场景(四)
接下来,我们的客户列表要怎么刷新数据呢? 我们不会安卓开发,不会ios开发,没关系,我们还有ionic 2.ionic 2的控件 Ion-refresher 轻松帮我们搞掂. <!--下拉刷新- ...
- log4cpp安装使用
1. 主页:http://log4cpp.sourceforge.net“Log4cpp is library of C++ classes for flexible logging to files ...
- sql问题:备份集中的数据库备份与现有的 '办公系统' 数据库不同
解决方法:把备份的数据库从原有的地方先分离,再拷贝一份,在需要还原的服务器上附加到数据库中,在根数据库上点击“还原数据库”,选择需要还原的数据库名称,以及还原的bak备份文件,在选择“选项”,勾选上“ ...
- CentOS-语言设置
查看所有的locale语言 # locale -a # locale -a|grep en 查看当前操作系统使用的语言 # echo $LANG 设置系统locale语言为中文环境(永久生效) # v ...
- HDU - 5491 The Next 2015 ACM/ICPC Asia Regional Hefei Online
从D+1开始,对于一个数x,区间[x,x+lowbit(x))内的数字的二进制位上1的数量整体来说是单调不减的,因此可快速得出1在这个区间的取值范围. 每次判断一下有没有和[s1,s2]有没有交集,一 ...
- 用dfs求解八皇后问题
相信大家都已经很熟悉八皇后问题了,就是指:在8X8格的国际象棋上摆放八个皇后,使其不能互相攻击,即任意两个皇后都不能处于同一行.同一列或同一斜线上,问有多少种摆法.主要思路:按行进行深度优先搜索,在该 ...
- 01_11_Strtus2简单数据验证
01_11_Strtus2简单数据验证 1. 引入struts标签 <%@taglib uri="/struts-tags" prefix="s" %&g ...
- I/O理解
I/O是什么 我的理解I/O就是用于读写的一个流 官方解释:I/O(英语:Input/Output),即输入/输出,通常指数据在内部存储器和外部存储器或其他周边设备之间的输入和输出. node中的io ...
- C++内存管理(effective c++ 04)
阅读effective c++ 04 (30页) 提到的static对象和堆与栈对象.看了看侯老师的内存管理视频1~3.有点深. 了解一下. 目录 1 内存管理 1.1 C++内存管理详解 1.1.1 ...
- 二叉搜索树详解(Java实现)
1.二叉搜索树定义 二叉搜索树,是指一棵空树或者具有下列性质的二叉树: 若任意节点的左子树不空,则左子树上所有节点的值均小于它的根节点的值: 若任意节点的右子树不空,则右子树上所有节点的值均大于它的根 ...