TensorFlow——深入MNIST
程序(有些不甚明白的地方改日修订):
# _*_coding:utf-8_*_ import inputdata
mnist = inputdata.read_data_sets('MNIST_data', one_hot=True) # mnist是一个以numpy数组形式存储训练、验证和测试数据的轻量级类 import tensorflow as tf
sess = tf.InteractiveSession() x = tf.placeholder("float",shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10]) W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10])) sess.run(tf.initialize_all_variables()) y = tf.nn.softmax(tf.matmul(x,W)+b) # nn:neural network # 代价函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) # 最优化算法
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) # 会更新权值 for i in range(1000):
batch = mnist.train.next_batch(50)
train_step.run(feed_dict={x:batch[0], y_:batch[1]}) # 可以用feed_dict来替代任何张量,并不仅限于替换placeholder correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) print accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels}) # 构建多层卷积网络模型 # 初始化W,b的函数
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1) # truncated_normal表示的是截断的正态分布
return tf.Variable(initial) def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) # 卷积和池化
def conv2d(x, W): # 卷积用原版,1步长0边距
return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME') def max_pool_2x2(x): # 池化用传统的2*2模板做max polling
return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') # 第一层卷积
W_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32]) x_image = tf.reshape(x, [-1,28,28,1]) h_conv1= tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1) # 第二层卷积
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64]) h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2) # 密集连接层
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) # dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) # 输出层
W_fc2= weight_variable([1024, 10])
b_fc2 = bias_variable([10]) y_conv= tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) # 训练和评估模型
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.initialize_all_variables())
for i in range(20000):
batch = mnist.train.next_batch(50)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict={ x:batch[0], y_: batch[1], keep_prob: 1.0})
print "step %d, training accuracy %g" %(i, train_accuracy)
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob:0.5}) print "test accuracy %g" % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
运行结果:
0.9092
step 0, training accuracy 0.08
step 100, training accuracy 0.9
step 200, training accuracy 0.94
step 300, training accuracy 0.98
step 400, training accuracy 0.98
step 500, training accuracy 0.9
step 600, training accuracy 0.96
step 700, training accuracy 0.96
step 800, training accuracy 0.96
step 900, training accuracy 0.94
step 1000, training accuracy 0.98
step 1100, training accuracy 1
step 1200, training accuracy 0.92
step 1300, training accuracy 0.96
step 1400, training accuracy 0.92
step 1500, training accuracy 0.98
...明天早上跑出来再贴
TensorFlow——深入MNIST的更多相关文章
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- Ubuntu16.04安装TensorFlow及Mnist训练
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com TensorFlow是Google开发的开源的深度学习框架,也是当前使用最广泛的深度学习框架. 一.安 ...
- 一个简单的TensorFlow可视化MNIST数据集识别程序
下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 使用Tensorflow操作MNIST数据
MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例.而TensorFlow的封装让使用MNIST数据集变得更加方便.MNIST数据集是NIST数据集的 ...
- TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架
TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架 http://blog.sina.com.cn/s/blog_4b0020f30102wv4l.html
- 2、TensorFlow训练MNIST
装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...
- 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门
2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...
- Tensorflow之MNIST的最佳实践思路总结
Tensorflow之MNIST的最佳实践思路总结 在上两篇文章中已经总结出了深层神经网络常用方法和Tensorflow的最佳实践所需要的知识点,如果对这些基础不熟悉,可以返回去看一下.在< ...
- TensorFlow训练MNIST报错ResourceExhaustedError
title: TensorFlow训练MNIST报错ResourceExhaustedError date: 2018-04-01 12:35:44 categories: deep learning ...
随机推荐
- Android Doze模式源码分析
科技的仿生学无处不在,给予我们启发.为了延长电池是使用寿命,google从蛇的冬眠中得到体会,那就是在某种情况下也让手机进入类冬眠的情况,从而引入了今天的主题,Doze模式,Doze中文是打盹儿,打盹 ...
- [make error ]ubuntu显示不全
make时候,输出到文件里 make >&makelog 就会自动出现一个makelog 会慢一些,不要急.
- UVA 177 PaperFolding 折纸痕 (分形,递归)
著名的折纸问题:给你一张很大的纸,对折以后再对折,再对折……每次对折都是从右往左折,因此在折了很多次以后,原先的大纸会变成一个窄窄的纸条.现在把这个纸条沿着折纸的痕迹打开,每次都只打开“一半”,即把每 ...
- Android(java)学习笔记131:关于构造代码块,构造函数的一道面试题(华为面试题)
1. 代码实例: package text; public class TestStaticCon { public static int a = 0; static { a = 10; System ...
- HTML之元素分类
一.元素展示类型 在HTML本身定义了很多元素,这些元素在网页上展示的时候都会有自己的默认状态,例如有些元素在默认状态下对高宽的属性设置不起作用,有些元素都默认情况下都独立一行显示,这种现象我们称之为 ...
- 【转】树莓派3代3.5寸触摸屏驱动的安装(通过ssh安装)
这是用到的配件的树莓派3代 烧录好系统后,启动的树莓派,我的树莓派已经在一开始通过路由器和局域网,登陆了ssh,设置好了开机就能自动连接到电脑的360wifi,所以无论到哪 里,只要自己的笔记本电脑还 ...
- Luogu P2073 送花
权值线段树的模板题 然而AC后才发现,可以用\(\tt{set}\)水过-- 权值线段树类似于用线段树来实现平衡树的一些操作,代码实现还是比较方便的 #include<iostream> ...
- 作业题:输出单个字符 输入单个字符 scanf printf
输出单个字符用putchar() #include <iostream> using namespace std; int main(){ char x='B'; char y='O'; ...
- cocos2d-x中解决暂停并保存画面和开始的功能
1.调用所有对象的pauseSchedulerAndActions().太麻烦,不太现实,而且有很多对象不易获取. 2.CCDirector::sharedirector()->pause(). ...
- centos7重启后/etc/resolv.conf 被还原解决办法
每次重启服务器后,/etc/resolv.conf文件就被自动还原了,最后发现是被Network Manager修改了. 查看Network Manager服务状态 systemctl status ...