TensorFlow入门——MNIST深入
#load MNIST data
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) #start tensorflow interactiveSession
import tensorflow as tf
sess = tf.InteractiveSession() #weight initilization
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial) def bias_variable(shape):
initial = tf.constant(0.1, shape= shape)
return tf.Variable(initial) #convolution
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME') #pooling
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1,2,2,1],strides=[1,2,2,1], padding='SAME') #Create the model
#placeholder
x = tf.placeholder("float",[None, 784])
y_ = tf.placeholder("float", [None, 10]) #variable
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x,W) +b) #first convolutional layer
w_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32]) x_image = tf.reshape(x,[-1,28,28,1]) h_conv1 =tf.nn.relu(conv2d(x_image,w_conv1) + b_conv1)
h_pool1 =max_pool_2x2(h_conv1) #second convolutional layer
w_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_variable([64]) h_conv2 =tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
h_pool2 =max_pool_2x2(h_conv2) #densely connected layer
w_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2, [-1,7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1) #dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) #readout layer
w_fc2 = weight_variable([1024,10])
b_fc2 = bias_variable([10]) y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2) + b_fc2) #train and evaluate the model
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
#train_step = tf.train.AdagradOptimizer(1e-4).minimize(cross_entropy)
train_step = tf.train.GradientDescentOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.initialize_all_variables())
for i in range(5000):
batch = mnist.train.next_batch(50)
if i%100 ==0:
train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1], keep_prob:1.0})
print "step %d, train accuracy %g " %(i,train_accuracy)
train_step.run(feed_dict={x:batch[0],y_:batch[1], keep_prob:0.5}) print "test accuracy %g" % accuracy.eval(fedd_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0})
同样是极客学院的课程,其实也是翻译的国外的robot-ai博客上的内容,但是这个博客,现在打不开了,可能是墙的问题?没有太深究。
按照作者的说法,是采用自适应下降的方式,在train阶段能达到99%的正确率,但是,我的结果只有93%左右,修改梯度步长到1e-4也只有94% 。因此尝试换用原来的梯度下降方式,反而能获得97.61%的正确率,在训练中还达到过98%,这个问题比较无奈,修改步长的结果提升也并不明显。有人在评论中说在不同的平台上测试的值不同,比如在纯CPU环境,和我的结果比较相似。在K20环境中能达到99%,这个问题留待以后探索。代码参考至:文章链接: http://blog.csdn.net/yhl_leo/article/details/50624471
TensorFlow入门——MNIST深入的更多相关文章
- TensorFlow入门——MNIST初探
import tensorflow.examples.tutorials.mnist.input_data as input_data import tensorflow as tf mnist = ...
- TensorFlow 入门之手写识别(MNIST) softmax算法
TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- TensorFlow入门之MNIST最佳实践
在上一篇<TensorFlow入门之MNIST样例代码分析>中,我们讲解了如果来用一个三层全连接网络实现手写数字识别.但是在实际运用中我们需要更有效率,更加灵活的代码.在TensorFlo ...
- TensorFlow入门之MNIST最佳实践-深度学习
在上一篇<TensorFlow入门之MNIST样例代码分析>中,我们讲解了如果来用一个三层全连接网络实现手写数字识别.但是在实际运用中我们需要更有效率,更加灵活的代码.在TensorFlo ...
- 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门
2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...
- TensorFlow 入门之手写识别(MNIST) softmax算法 二
TensorFlow 入门之手写识别(MNIST) softmax算法 二 MNIST Fly softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...
- TensorFlow 入门之手写识别(MNIST) 数据处理 一
TensorFlow 入门之手写识别(MNIST) 数据处理 一 MNIST Fly softmax回归 准备数据 解压 与 重构 手写识别入门 MNIST手写数据集 图片以及标签的数据格式处理 准备 ...
- 统计学习方法:罗杰斯特回归及Tensorflow入门
作者:桂. 时间:2017-04-21 21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...
随机推荐
- js如何获取鼠标位置
获取鼠标位置,首先需要加载js文件: 然后设置一个div,给定大小: 最后进行具体操作: //首先要先设置一个div,给定大小 <div id="m"></div ...
- nginx目录及配置语法
一.Nginx安装目录 1.查看安装目录. 采用yum的方式安装,其实都是安装的一个一个的 pm 包,故可采用如下命令查看 rpm -ql nginx 遵循了 rpm 包管理规范. 2.安装目录详解 ...
- python内存泄露memory leak排查记录
问题描述 A服务,是一个检测MGR集群主节点是否发生变化的服务,使用python语言实现的. 针对每个集群,主线程会创建一个子线程,并由子线程去检测.子线程会频繁的创建和销毁. 上线以后,由于经常会有 ...
- flask_sqlalchemy的session线程安全源码解读
flask_sqlalchemy是如何在多线程中对数据库操作不相互影响 数据库操作隔离 结论:使用scoped_session实现数据库操作隔离 flask的api.route()接收一个请求,就会创 ...
- 【巨坑】springmvc 输出json格式数据的几种方式!
最近公司项目需要发布一些数据服务,从设计到实现两天就弄完了,心中窃喜之. 结果临近部署时突然发现..... 服务输出的JSON 数据中 date 类型数据输出格式要么是时间戳,要么是 {&quo ...
- python 连接oracle数据库:cx_Oracle
注意:64位操作系统必须安装64位oracle客户端,否则会连接报错 安装第三方库:cx_Oracle 一.连接数据库的三种方式: 1.cx_Oracle.connect('账号/密码@ip:端口/数 ...
- 阶段3 3.SpringMVC·_05.文件上传_1 文件上传之上传原理分析和搭建环境
分成几个部分 里面可能就包含文件上传的值 提交方式要改成post 第三个就是提供一个input file的文件选择域 新建项目 新建一个项目 当前项目没有父工程 跳过联网下载 改成02 构建 编译和目 ...
- SQL Server 等待统计信息基线收集
背景 我们随时监控每个服务器不同时间段的wait statistics ,可以根据监控信息大概判断什么时候开始出现异常,相当于一个wait statistics基线收集,还可以具体分析占比高的等待类型 ...
- style属性
style加样式是加在行间,取样式也是在行间取: 我们来看下面这段代码: <!DOCTYPE HTML> <html> <head> <meta charse ...
- 使用robotframework做接口测试之一——准备工作
最近发现做接口测试的朋友越来越多了,打算写一个系列的rf+requests做接口测试(主要是Http接口)的文档,可以帮助新入门的同学对接口测试有个大概的了解,同时也是敦促自己做总结的一种手段.希望经 ...