TensorFlow入门——MNIST深入
#load MNIST data
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) #start tensorflow interactiveSession
import tensorflow as tf
sess = tf.InteractiveSession() #weight initilization
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial) def bias_variable(shape):
initial = tf.constant(0.1, shape= shape)
return tf.Variable(initial) #convolution
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME') #pooling
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1,2,2,1],strides=[1,2,2,1], padding='SAME') #Create the model
#placeholder
x = tf.placeholder("float",[None, 784])
y_ = tf.placeholder("float", [None, 10]) #variable
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x,W) +b) #first convolutional layer
w_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32]) x_image = tf.reshape(x,[-1,28,28,1]) h_conv1 =tf.nn.relu(conv2d(x_image,w_conv1) + b_conv1)
h_pool1 =max_pool_2x2(h_conv1) #second convolutional layer
w_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_variable([64]) h_conv2 =tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
h_pool2 =max_pool_2x2(h_conv2) #densely connected layer
w_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2, [-1,7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1) #dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) #readout layer
w_fc2 = weight_variable([1024,10])
b_fc2 = bias_variable([10]) y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2) + b_fc2) #train and evaluate the model
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
#train_step = tf.train.AdagradOptimizer(1e-4).minimize(cross_entropy)
train_step = tf.train.GradientDescentOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.initialize_all_variables())
for i in range(5000):
batch = mnist.train.next_batch(50)
if i%100 ==0:
train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1], keep_prob:1.0})
print "step %d, train accuracy %g " %(i,train_accuracy)
train_step.run(feed_dict={x:batch[0],y_:batch[1], keep_prob:0.5}) print "test accuracy %g" % accuracy.eval(fedd_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0})
同样是极客学院的课程,其实也是翻译的国外的robot-ai博客上的内容,但是这个博客,现在打不开了,可能是墙的问题?没有太深究。
按照作者的说法,是采用自适应下降的方式,在train阶段能达到99%的正确率,但是,我的结果只有93%左右,修改梯度步长到1e-4也只有94% 。因此尝试换用原来的梯度下降方式,反而能获得97.61%的正确率,在训练中还达到过98%,这个问题比较无奈,修改步长的结果提升也并不明显。有人在评论中说在不同的平台上测试的值不同,比如在纯CPU环境,和我的结果比较相似。在K20环境中能达到99%,这个问题留待以后探索。代码参考至:文章链接: http://blog.csdn.net/yhl_leo/article/details/50624471
TensorFlow入门——MNIST深入的更多相关文章
- TensorFlow入门——MNIST初探
import tensorflow.examples.tutorials.mnist.input_data as input_data import tensorflow as tf mnist = ...
- TensorFlow 入门之手写识别(MNIST) softmax算法
TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- TensorFlow入门之MNIST最佳实践
在上一篇<TensorFlow入门之MNIST样例代码分析>中,我们讲解了如果来用一个三层全连接网络实现手写数字识别.但是在实际运用中我们需要更有效率,更加灵活的代码.在TensorFlo ...
- TensorFlow入门之MNIST最佳实践-深度学习
在上一篇<TensorFlow入门之MNIST样例代码分析>中,我们讲解了如果来用一个三层全连接网络实现手写数字识别.但是在实际运用中我们需要更有效率,更加灵活的代码.在TensorFlo ...
- 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门
2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...
- TensorFlow 入门之手写识别(MNIST) softmax算法 二
TensorFlow 入门之手写识别(MNIST) softmax算法 二 MNIST Fly softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...
- TensorFlow 入门之手写识别(MNIST) 数据处理 一
TensorFlow 入门之手写识别(MNIST) 数据处理 一 MNIST Fly softmax回归 准备数据 解压 与 重构 手写识别入门 MNIST手写数据集 图片以及标签的数据格式处理 准备 ...
- 统计学习方法:罗杰斯特回归及Tensorflow入门
作者:桂. 时间:2017-04-21 21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...
随机推荐
- linux用户和组账户管理
linux操作系统是一个多用户操作系统,它允许多用户同时登录到系统上并使用资源.系统会根据账户来区分每个用户的文件,进程,任务和工作环境,使得每个用户工作都不受干扰. 用户账户 A.普通用户账户:普通 ...
- ElementTree 无法处理中文
ElementTree.fromstring() 导入xml格式时,是可以选择解析parser的编码的,并且 它解析出来的node类型是 严谨且严格的,不会 自己内部全部转换成str,比如 9.87 ...
- 关于springmvc的包含list提交的格式
<%-- Created by IntelliJ IDEA. User: jh Date: 2017/7/12 Time: 14:31 To change this template use F ...
- pyqt5-实时刷新页面(QApplication.processEvents())
对于执行很耗时的程序来说,由于PyQt需要等待程序执行完毕才能进行下一步,这个过程表现在界面上就是卡顿,而如果需要执行这个耗时程序时不断的刷新界面.那么就可以使用QApplication.proces ...
- 如何从项目中移除CocoaPods
一.项目Show in Finder: 删除本地文件(Podfile.Podfile.lock.Pods文件夹) 删除本地生成的xcworkspace文件 打开项目,在Frameworks文件夹下,删 ...
- 【转】Java从hdfs上读取文件中的某一行
[From]https://blog.csdn.net/u010989078/article/details/51790166 package test; import java.io.Buffere ...
- java:面向对象(多态,final,抽象方法,(简单工厂模式即静态方法模式),接口)
* 生活中的多态:同一种物质,因环境不同而表现不同的形态. * 程序中多态:同一个"接口",因不同的实现而执行不同的操作. * 多态和方法的重写经常结合使用,子类重写父类的方法,将 ...
- 【css】纯css实现文字循环滚动效果
不用js来实现. html: <div class="box"> <p class="animate"> 文字滚动的内容文字滚动的内容文 ...
- PJzhang:python基础入门的7个疗程-three
猫宁!!! 参考链接:易灵微课-21天轻松掌握零基础python入门必修课-售价29元人民币 https://www.liaoxuefeng.com/wiki/1016959663602400 第七天 ...
- session到底是何时何地生成的
关于session,之前只是在用,从没考虑到底怎么生成的 今天有空我做了个实验,把监控了一下访问某网站第一二次的请求响应详细信息,终于搞明白了,好了,开始放图 这里发起一个请求,然后我们看下第一次请 ...