import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# number 1 to 10 data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) def compute_accuracy(v_xs, v_ys):
global prediction
y_pre = sess.run(prediction, feed_dict={xs: v_xs, keep_prob: 1})
correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(v_ys,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys, keep_prob: 1})
return result def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) return tf.Variable(initial) def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') xs = tf.placeholder(tf.float32, [None, 784]) # 28x28
ys = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
x_image = tf.reshape(xs, [-1, 28, 28, 1]) ## conv1 layer;
W_conv1 = weight_variable([5,5, 1,32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
W_conv2 = weight_variable([5,5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2) W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
prediction = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) # the error between prediction and real data
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) sess = tf.Session()
# important step
sess.run(tf.global_variables_initializer()) for i in range(10):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5})
if i % 50 == 0:
print(compute_accuracy(
mnist.test.images, mnist.test.labels))

TF:TF下CNN实现mnist数据集预测 96%采用placeholder用法+2层C及其max_pool法+隐藏层dropout法+输出层softmax法+目标函数cross_entropy法+AdamOptimizer算法的更多相关文章

  1. 卷积神经网络CNN识别MNIST数据集

    这次我们将建立一个卷积神经网络,它可以把MNIST手写字符的识别准确率提升到99%,读者可能需要一些卷积神经网络的基础知识才能更好的理解本节的内容. 程序的开头是导入TensorFlow: impor ...

  2. CNN完成mnist数据集手写数字识别

    # coding: utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data d ...

  3. 基于 tensorflow 的 mnist 数据集预测

    1. tensorflow 基本使用方法 2. mnist 数据集简介与预处理 3. 聚类算法模型 4. 使用卷积神经网络进行特征生成 5. 训练网络模型生成结果 how to install ten ...

  4. python,tensorflow,CNN实现mnist数据集的训练与验证正确率

    1.工程目录 2.导入data和input_data.py 链接:https://pan.baidu.com/s/1EBNyNurBXWeJVyhNeVnmnA 提取码:4nnl 3.CNN.py i ...

  5. Python实现bp神经网络识别MNIST数据集

    title: "Python实现bp神经网络识别MNIST数据集" date: 2018-06-18T14:01:49+08:00 tags: [""] cat ...

  6. TF之RNN:基于顺序的RNN分类案例对手写数字图片mnist数据集实现高精度预测—Jason niu

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

  7. 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)

    1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...

  8. TF:TF分类问题之MNIST手写50000数据集实现87.4%准确率识别:SGD法+softmax法+cross_entropy法—Jason niu

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # number 1 to 10 ...

  9. TensorFlow——CNN卷积神经网络处理Mnist数据集

    CNN卷积神经网络处理Mnist数据集 CNN模型结构: 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5*5,步长为1,卷积核:32个 第一层池化:池化视野2*2,步长为2 第二层卷积 ...

随机推荐

  1. Confluence 6 发送 Confluence 通知到其他 Confluence 服务器

    你可以配置 Confluence 服务器向其他的 Confluence 服务器发送消息.在这种情况下,Confluence 服务器将不会显示 workbox. 希望发送消息到其他 Confluence ...

  2. Hive shell 基本命令

    首先连接 hive shell 直接输入 hive启动, 使用--开头的字符串来表示注释 hive>quit; --退出hive hive> exit; --exit会影响之前的使用,所以 ...

  3. Vue+restfulframework示例

    一.简单回顾vue 前不久我们已经了解了vue前端框架,所以现在强调几点: 修改源: npm config set registry https://registry.npm.taobao.org 创 ...

  4. 【python】threadpool的内存占用问题

    先说结论: 在使用多线程时,不要使用threadpool,应该使用threading, 尤其是数据量大的情况.因为threadpool会导致严重的内存占用问题! 对比threading和threadp ...

  5. 【python】gearman阻塞非阻塞,同步/异步,状态

    参考: http://pythonhosted.org/gearman/client.html?highlight=submit_multiple_jobs#gearman.client.Gearma ...

  6. python(6):Scipy之pandas

    pandas下面的数据结构是Series , DataFrame 在字典中, key 与 value对应, 但是key value 不是独立的, 但是在Series 中index 与value 是独立 ...

  7. hdu4612 卡cin e-DCC缩点

    /* 给定无向图,求加入一条边后最少剩下多少桥 */ #include<bits/stdc++.h> using namespace std; #define maxn 200005 #d ...

  8. Nginx详解二十三:Nginx深度学习篇之Nginx+Lua开发环境搭建

    Nginx+Lua开发环境 1.下载LuaJIT解释器wget http://luajit.org/download/LuaJIT-2.0.2.tar.gztar -zxvf LuaJIT-2.0.2 ...

  9. Mysql 5.7 CentOS 7 安装MHA

    Table of Contents 1. MHA简介 1.1. 功能 1.2. MHA切换逻辑 1.3. 工具 2. 环境 2.1. 软件 2.2. 环境 3. Mysql 主从复制 3.1. Mys ...

  10. Android 网络请求框架

    1.okHttp 特点 简单.灵活.无连接.无状态 优势: 谷歌官方API在6.0之后在Android SDK中移除了HttpClient,然后他火了起来, 他支持SPDY(谷歌开发的基于TCP应用层 ...