1.工程目录

2.导入data和input_data.py

链接:https://pan.baidu.com/s/1EBNyNurBXWeJVyhNeVnmnA 
提取码:4nnl

3.CNN.py

import tensorflow as tf
import matplotlib.pyplot as plt
import input_data mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print('MNIST ready') n_input = 784
n_output = 10 weights = {
'wc1': tf.Variable(tf.truncated_normal([3, 3, 1, 64], stddev=0.1)),
'wc2': tf.Variable(tf.truncated_normal([3, 3, 64, 128], stddev=0.1)),
'wd1': tf.Variable(tf.truncated_normal([7*7*128, 1024], stddev=0.1)),
'wd2': tf.Variable(tf.truncated_normal([1024, n_outpot], stddev=0.1)),
}
biases = {
'bc1': tf.Variable(tf.random_normal([64], stddev=0.1)),
'bc2': tf.Variable(tf.random_normal([128], stddev=0.1)),
'bd1': tf.Variable(tf.random_normal([1024], stddev=0.1)),
'bd2': tf.Variable(tf.random_normal([n_outpot], stddev=0.1)),
} def conv_basic(_input, _w, _b, _keepratio):
_input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])
_conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')
_conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))
_pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
_pool_dr1 = tf.nn.dropout(_pool1, _keepratio)
_conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')
_conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))
_pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
_pool_dr2 = tf.nn.dropout(_pool2, _keepratio)
_densel = tf.reshape(_pool_dr2, [-1, _w['wd1'].get_shape().as_list()[0]])
_fc1 = tf.nn.relu(tf.add(tf.matmul(_densel, _w['wd1']), _b['bd1']))
_fc_dr1 = tf.nn.dropout(_fc1, _keepratio)
_out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])
out = {
'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool_dr1': _pool_dr1,
'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'densel': _densel,
'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out
}
return out print('CNN READY') x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output])
keepratio = tf.placeholder(tf.float32) _pred = conv_basic(x, weights, biases, keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(_pred, y))
optm = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)
_corr = tf.equal(tf.argmax(_pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32))
init = tf.global_variables_initializer() print('GRAPH READY') sess = tf.Session()
sess.run(init)
training_epochs = 15
batch_size = 16
display_step = 1 for epoch in range(training_epochs):
avg_cost = 0.
total_batch = 10
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(optm, feed_dict={x: batch_xs, y: batch_ys, keepratio: 0.7})
avg_cost += sess.run(cost, feed_dict={x: batch_xs, y: batch_ys, keepratio: 1.0})/total_batch if epoch % display_step == 0:
print('Epoch: %03d/%03d cost: %.9f' % (epoch, training_epochs, avg_cost))
train_acc = sess.run(accr, feed_dict={x: batch_xs, y: batch_ys, keepratio: 1.})
print('Training accuracy: %.3f' % (train_acc)) res_dict = {'weight': sess.run(weights), 'biases': sess.run(biases)} import pickle
with open('res_dict.pkl', 'wb') as f:
pickle.dump(res_dict, f, pickle.HIGHEST_PROTOCOL)

4.test.py

import pickle
import numpy as np def load_file(path, name):
with open(path+''+name+'.pkl', 'rb') as f:
return pickle.load(f) res_dict = load_file('', 'res_dict')
print(res_dict['weight']['wc1']) index = 0 import input_data
mnist = input_data.read_data_sets('data/', one_hot=True) test_image = mnist.test.images
test_label = mnist.test.labels import tensorflow as tf def conv_basic(_input, _w, _b, _keepratio):
_input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])
_conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')
_conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))
_pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
_pool_dr1 = tf.nn.dropout(_pool1, _keepratio)
_conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')
_conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))
_pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
_pool_dr2 = tf.nn.dropout(_pool2, _keepratio)
_densel = tf.reshape(_pool_dr2, [-1, _w['wd1'].shape[0]])
_fc1 = tf.nn.relu(tf.add(tf.matmul(_densel, _w['wd1']), _b['bd1']))
_fc_dr1 = tf.nn.dropout(_fc1, _keepratio)
_out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])
out = {
'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool_dr1': _pool_dr1,
'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'densel': _densel,
'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out
}
return out n_input = 784
n_output = 10 x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output]) keepratio = tf.placeholder(tf.float32) _pred = conv_basic(x, res_dict['weight'], res_dict['biases'], keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(_pred, y)) _corr = tf.equal(tf.argmax(_pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32)) init = tf.global_variables_initializer() sess = tf.Session()
sess.run(init)
training_epochs = 1
batch_size = 1
display_step = 1 for epoch in range(training_epochs):
avg_cost = 0.
total_batch = 10
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size) if epoch % display_step == 0:
print('_pre:', np.argmax(sess.run(_pred, feed_dict={x: batch_xs, keepratio: 1. })))
print('answer:', np.argmax(batch_ys))

python,tensorflow,CNN实现mnist数据集的训练与验证正确率的更多相关文章

  1. tensorflow中使用mnist数据集训练全连接神经网络-学习笔记

    tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...

  2. tensorflow读取本地MNIST数据集

    tensorflow读取本地MNIST数据集 数据放入文件夹(不要解压gz): >>> import tensorflow as tf >>> from tenso ...

  3. Ubuntu14.04+caffe+cuda7.5 环境搭建以及MNIST数据集的训练与测试

    Ubuntu14.04+caffe+cuda 环境搭建以及MNIST数据集的训练与测试 一.ubuntu14.04的安装: ubuntu的安装是一件十分简单的事情,这里给出一个参考教程: http:/ ...

  4. 十折交叉验证10-fold cross validation, 数据集划分 训练集 验证集 测试集

    机器学习 数据挖掘 数据集划分 训练集 验证集 测试集 Q:如何将数据集划分为测试数据集和训练数据集? A:three ways: 1.像sklearn一样,提供一个将数据集切分成训练集和测试集的函数 ...

  5. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  6. 学习TensorFlow,邂逅MNIST数据集

    如果说"Hello Word!"是程序员的第一个程序,那么MNIST数据集,毫无疑问是机器学习者第一个训练的数据集,本文将使用Google公布的TensorFLow来学习训练MNI ...

  7. mnist的格式说明,以及在python3.x和python 2.x读取mnist数据集的不同

    有一个关于mnist的一个事例可以参考,我觉得写的很好:http://www.cnblogs.com/x1957/archive/2012/06/02/2531503.html #!/usr/bin/ ...

  8. Python Tensorflow CNN 识别验证码

    Python+Tensorflow的CNN技术快速识别验证码 文章来源于: https://www.jianshu.com/p/26ff7b9075a1 验证码处理的流程是:验证码分析和处理—— te ...

  9. TensorFlow CNN 测试CIFAR-10数据集

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50738311 1 CIFAR-10 数 ...

随机推荐

  1. How to manage local libraries in IntelliJ IDEA

    如何在 IntelliJ IDEA 中管理本地类库 一般来说,如果项目是基于 Maven 管理工具的,我们会在 pom.xml 中添加 dependency 来管理依赖.但有时也会遇到要用的类库不在 ...

  2. Visual Studio各个版本对应关系

  3. Django分页类的封装

    Django分页类的封装 Django ORM  封装 之前有提到(Django分页的实现)会多次用到分页,将分页功能封装起来能极大提高效率. 其实不是很难,就是将之前实现的代码全都放到类中,将需要用 ...

  4. testng XMl 参数化

    方法一: 方法二: 方法三: (1)如果测试的数据较多的情况下,很显然这种方式不适合,那么可以通过@DataProvider生成测试数据,通过@Test(dataProvider = "&q ...

  5. 修改Eclipse jdk环境

    原因:由于项目原因,要将原有的工程从jdk1.6迁移到jdk1.7 问题:Eclipse默认的jdk环境为jdk1.6 解决方法: 1)首先是安装jdk1.7,以及配置环境变量,在这里就不再说了 2) ...

  6. Webapps初步_认识HTTP例子程序读取

    package servlet_01; import java.io.BufferedReader; import java.io.InputStreamReader; import java.io. ...

  7. ES6-Async & 异步

    依赖文件地址 :https://github.com/chanceLe/ES6-Basic-Syntax/tree/master/js <!DOCTYPE html> <html&g ...

  8. vue实现非路由跳转以及数据传递

    定义一个父组件 <template> <v-layout> <v-card contextual-style="dark" v-if="sh ...

  9. CPU漏洞补丁修复导致KeServiceDescriptorTable获取变更

    一.前言 2018年元旦,出现的cpu的漏洞,可以在windows环三直接读取内核数据,windows对该漏洞提供补丁,补丁增加了一个页表,对应的内核处理也增加了,接下来我们看下补丁修复的表象以及对K ...

  10. code EINTEGRITY,npm安装时候报错

    解决方法: 1.如果有package-lock.json文件,就删掉 2.管理员权限进入cmd 3.执行npm cache clean --force 4.之后再npm install 有时候网不好也 ...