1.工程目录

2.导入data和input_data.py

链接:https://pan.baidu.com/s/1EBNyNurBXWeJVyhNeVnmnA 
提取码:4nnl

3.CNN.py

import tensorflow as tf
import matplotlib.pyplot as plt
import input_data mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print('MNIST ready') n_input = 784
n_output = 10 weights = {
'wc1': tf.Variable(tf.truncated_normal([3, 3, 1, 64], stddev=0.1)),
'wc2': tf.Variable(tf.truncated_normal([3, 3, 64, 128], stddev=0.1)),
'wd1': tf.Variable(tf.truncated_normal([7*7*128, 1024], stddev=0.1)),
'wd2': tf.Variable(tf.truncated_normal([1024, n_outpot], stddev=0.1)),
}
biases = {
'bc1': tf.Variable(tf.random_normal([64], stddev=0.1)),
'bc2': tf.Variable(tf.random_normal([128], stddev=0.1)),
'bd1': tf.Variable(tf.random_normal([1024], stddev=0.1)),
'bd2': tf.Variable(tf.random_normal([n_outpot], stddev=0.1)),
} def conv_basic(_input, _w, _b, _keepratio):
_input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])
_conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')
_conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))
_pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
_pool_dr1 = tf.nn.dropout(_pool1, _keepratio)
_conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')
_conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))
_pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
_pool_dr2 = tf.nn.dropout(_pool2, _keepratio)
_densel = tf.reshape(_pool_dr2, [-1, _w['wd1'].get_shape().as_list()[0]])
_fc1 = tf.nn.relu(tf.add(tf.matmul(_densel, _w['wd1']), _b['bd1']))
_fc_dr1 = tf.nn.dropout(_fc1, _keepratio)
_out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])
out = {
'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool_dr1': _pool_dr1,
'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'densel': _densel,
'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out
}
return out print('CNN READY') x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output])
keepratio = tf.placeholder(tf.float32) _pred = conv_basic(x, weights, biases, keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(_pred, y))
optm = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)
_corr = tf.equal(tf.argmax(_pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32))
init = tf.global_variables_initializer() print('GRAPH READY') sess = tf.Session()
sess.run(init)
training_epochs = 15
batch_size = 16
display_step = 1 for epoch in range(training_epochs):
avg_cost = 0.
total_batch = 10
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(optm, feed_dict={x: batch_xs, y: batch_ys, keepratio: 0.7})
avg_cost += sess.run(cost, feed_dict={x: batch_xs, y: batch_ys, keepratio: 1.0})/total_batch if epoch % display_step == 0:
print('Epoch: %03d/%03d cost: %.9f' % (epoch, training_epochs, avg_cost))
train_acc = sess.run(accr, feed_dict={x: batch_xs, y: batch_ys, keepratio: 1.})
print('Training accuracy: %.3f' % (train_acc)) res_dict = {'weight': sess.run(weights), 'biases': sess.run(biases)} import pickle
with open('res_dict.pkl', 'wb') as f:
pickle.dump(res_dict, f, pickle.HIGHEST_PROTOCOL)

4.test.py

import pickle
import numpy as np def load_file(path, name):
with open(path+''+name+'.pkl', 'rb') as f:
return pickle.load(f) res_dict = load_file('', 'res_dict')
print(res_dict['weight']['wc1']) index = 0 import input_data
mnist = input_data.read_data_sets('data/', one_hot=True) test_image = mnist.test.images
test_label = mnist.test.labels import tensorflow as tf def conv_basic(_input, _w, _b, _keepratio):
_input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])
_conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')
_conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))
_pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
_pool_dr1 = tf.nn.dropout(_pool1, _keepratio)
_conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')
_conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))
_pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
_pool_dr2 = tf.nn.dropout(_pool2, _keepratio)
_densel = tf.reshape(_pool_dr2, [-1, _w['wd1'].shape[0]])
_fc1 = tf.nn.relu(tf.add(tf.matmul(_densel, _w['wd1']), _b['bd1']))
_fc_dr1 = tf.nn.dropout(_fc1, _keepratio)
_out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])
out = {
'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool_dr1': _pool_dr1,
'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'densel': _densel,
'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out
}
return out n_input = 784
n_output = 10 x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output]) keepratio = tf.placeholder(tf.float32) _pred = conv_basic(x, res_dict['weight'], res_dict['biases'], keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(_pred, y)) _corr = tf.equal(tf.argmax(_pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32)) init = tf.global_variables_initializer() sess = tf.Session()
sess.run(init)
training_epochs = 1
batch_size = 1
display_step = 1 for epoch in range(training_epochs):
avg_cost = 0.
total_batch = 10
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size) if epoch % display_step == 0:
print('_pre:', np.argmax(sess.run(_pred, feed_dict={x: batch_xs, keepratio: 1. })))
print('answer:', np.argmax(batch_ys))

python,tensorflow,CNN实现mnist数据集的训练与验证正确率的更多相关文章

  1. tensorflow中使用mnist数据集训练全连接神经网络-学习笔记

    tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...

  2. tensorflow读取本地MNIST数据集

    tensorflow读取本地MNIST数据集 数据放入文件夹(不要解压gz): >>> import tensorflow as tf >>> from tenso ...

  3. Ubuntu14.04+caffe+cuda7.5 环境搭建以及MNIST数据集的训练与测试

    Ubuntu14.04+caffe+cuda 环境搭建以及MNIST数据集的训练与测试 一.ubuntu14.04的安装: ubuntu的安装是一件十分简单的事情,这里给出一个参考教程: http:/ ...

  4. 十折交叉验证10-fold cross validation, 数据集划分 训练集 验证集 测试集

    机器学习 数据挖掘 数据集划分 训练集 验证集 测试集 Q:如何将数据集划分为测试数据集和训练数据集? A:three ways: 1.像sklearn一样,提供一个将数据集切分成训练集和测试集的函数 ...

  5. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  6. 学习TensorFlow,邂逅MNIST数据集

    如果说"Hello Word!"是程序员的第一个程序,那么MNIST数据集,毫无疑问是机器学习者第一个训练的数据集,本文将使用Google公布的TensorFLow来学习训练MNI ...

  7. mnist的格式说明,以及在python3.x和python 2.x读取mnist数据集的不同

    有一个关于mnist的一个事例可以参考,我觉得写的很好:http://www.cnblogs.com/x1957/archive/2012/06/02/2531503.html #!/usr/bin/ ...

  8. Python Tensorflow CNN 识别验证码

    Python+Tensorflow的CNN技术快速识别验证码 文章来源于: https://www.jianshu.com/p/26ff7b9075a1 验证码处理的流程是:验证码分析和处理—— te ...

  9. TensorFlow CNN 测试CIFAR-10数据集

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50738311 1 CIFAR-10 数 ...

随机推荐

  1. 理解 atime,ctime,mtime (下)

    话不多说,开始下篇. # 前言 通过 "理解 atime,ctime,mtime (上)" 我们已经知道了atime 是文件访问时间:ctime是文件权限改变时间:mtime是文件 ...

  2. leetcode-40-组合总和 II

    题目描述: 给定一个数组 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合. candidates 中的每个数字在每个组合中只 ...

  3. 【转载】MDX Step by Step 读书笔记(三) - Understanding Tuples (理解元组)

    1. 在 Analysis Service 分析服务中,Cube (多维数据集) 是以一个多维数据空间来呈现的.在Cube 中,每一个纬度的属性层次结构都形成了一个轴.沿着这个轴,在属性层次结构上的每 ...

  4. mfix输出自定义数据

    有时候需要输出一些自定义的网格或者DEM颗粒信息,比如输出颗粒的受力,这里举例颗粒自定义数据的输出.网格自定义输出方法类似. 首先用FileLocatorPro(网上很多绿色版),搜索一下代码里mod ...

  5. 「Neerc2016」Expect to Wait

    题目描述 ls最近开了一家图书馆,大家听说是ls开的,纷纷过来借书,自然就会出现供不应求的情况, 并且借书的过程类 似一个队列,每次有人来借书就将它加至队尾,每次有人来还书就把书借给队头的若干个人,定 ...

  6. Monkey and Banana

    Monkey and BananaTime Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/32768 K (Java/Others) ...

  7. Mac 10.12原生方法对NTFS分区进行读写的配置

    说明:不一定有效,最简单的方法就是不用NTFS,直接FAT32,对于大文件就用切割. 方法: 1.确定U盘名称 diskutil list ls /Volumes/ 2.比如我找到的U盘名称为Unti ...

  8. (转)OpenStack各服务所用端口号总结

    参考:Firewalls and default ports 注:可执行 sudo netstat -tnlp 查看 端口 服务描述 22 SSH 3306 MariaDB(MySQL) 27017 ...

  9. (转)Mysql占用过高CPU时的优化手段

    Mysql占用CPU过高的时候,该从哪些方面下手进行优化?占用CPU过高,可以做如下考虑:1)一般来讲,排除高并发的因素,还是要找到导致你CPU过高的哪几条在执行的SQL,show processli ...

  10. 利用setTimeout来实现setInterval

    在Js中,当我们要在一定间隔时间内不断执行同一函数,我们可以使用setInterval函数,但setInterval在某些情况下使用时也存在一定问题. 1.不去关心回调函数是否还在运行 在某些情况下, ...