1. Project directory
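
A sketch of the project layout implied by the scripts in sections 3 and 4 (file and directory names are taken from the code; nothing else is assumed):

CNN.py          # builds and trains the network, saves parameters to res_dict.pkl
test.py         # rebuilds the network from res_dict.pkl and checks predictions
input_data.py   # MNIST loader (see section 2)
data/           # MNIST data files (see section 2)
res_dict.pkl    # produced by CNN.py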

2. Import data and input_data.py

Link: https://pan.baidu.com/s/1EBNyNurBXWeJVyhNeVnmnA
Extraction code: 4nnl
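
If the network-disk download is not available, TensorFlow 1.x ships an MNIST loader with the same read_data_sets interface in its tutorial package, which should work as a drop-in replacement (a minimal sketch, assuming TensorFlow 1.x; the first run downloads the MNIST files into data/):

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data/', one_hot=True)
print(mnist.train.images.shape)   # (55000, 784)
print(mnist.test.labels.shape)    # (10000, 10)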

3. CNN.py

import pickle

import tensorflow as tf
import matplotlib.pyplot as plt

import input_data

# load MNIST as flat 784-dimensional vectors with one-hot labels
mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print('MNIST ready')

n_input = 784
n_output = 10

# two 3x3 conv layers (64 and 128 feature maps) followed by two fully connected layers
weights = {
    'wc1': tf.Variable(tf.truncated_normal([3, 3, 1, 64], stddev=0.1)),
    'wc2': tf.Variable(tf.truncated_normal([3, 3, 64, 128], stddev=0.1)),
    'wd1': tf.Variable(tf.truncated_normal([7*7*128, 1024], stddev=0.1)),
    'wd2': tf.Variable(tf.truncated_normal([1024, n_output], stddev=0.1)),
}
biases = {
    'bc1': tf.Variable(tf.random_normal([64], stddev=0.1)),
    'bc2': tf.Variable(tf.random_normal([128], stddev=0.1)),
    'bd1': tf.Variable(tf.random_normal([1024], stddev=0.1)),
    'bd2': tf.Variable(tf.random_normal([n_output], stddev=0.1)),
}


def conv_basic(_input, _w, _b, _keepratio):
    # reshape the flat 784-vector into a 28x28 single-channel image
    _input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])
    # conv layer 1: conv -> ReLU -> 2x2 max pooling -> dropout
    _conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')
    _conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))
    _pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    _pool_dr1 = tf.nn.dropout(_pool1, _keepratio)
    # conv layer 2: conv -> ReLU -> 2x2 max pooling -> dropout
    _conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')
    _conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))
    _pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    _pool_dr2 = tf.nn.dropout(_pool2, _keepratio)
    # flatten to 7*7*128 and apply the two fully connected layers
    _dense1 = tf.reshape(_pool_dr2, [-1, _w['wd1'].get_shape().as_list()[0]])
    _fc1 = tf.nn.relu(tf.add(tf.matmul(_dense1, _w['wd1']), _b['bd1']))
    _fc_dr1 = tf.nn.dropout(_fc1, _keepratio)
    _out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])
    out = {
        'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool_dr1': _pool_dr1,
        'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'dense1': _dense1,
        'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out
    }
    return out


print('CNN READY')

x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output])
keepratio = tf.placeholder(tf.float32)

_pred = conv_basic(x, weights, biases, keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=_pred, labels=y))
optm = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)
_corr = tf.equal(tf.argmax(_pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32))
init = tf.global_variables_initializer()
print('GRAPH READY')

sess = tf.Session()
sess.run(init)

training_epochs = 15
batch_size = 16
display_step = 1
for epoch in range(training_epochs):
    avg_cost = 0.
    total_batch = 10
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # dropout keep probability 0.7 while training, 1.0 when measuring the cost
        sess.run(optm, feed_dict={x: batch_xs, y: batch_ys, keepratio: 0.7})
        avg_cost += sess.run(cost, feed_dict={x: batch_xs, y: batch_ys, keepratio: 1.0}) / total_batch
    if epoch % display_step == 0:
        print('Epoch: %03d/%03d cost: %.9f' % (epoch, training_epochs, avg_cost))
        train_acc = sess.run(accr, feed_dict={x: batch_xs, y: batch_ys, keepratio: 1.})
        print('Training accuracy: %.3f' % (train_acc))

# save the trained weights and biases so test.py can rebuild the network
res_dict = {'weight': sess.run(weights), 'biases': sess.run(biases)}
with open('res_dict.pkl', 'wb') as f:
    pickle.dump(res_dict, f, pickle.HIGHEST_PROTOCOL)
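
CNN.py above only reports the accuracy of the most recent training batch. To also check accuracy on held-out data, a minimal sketch that could be appended after the training loop (it assumes sess, accr, x, y, keepratio and mnist from CNN.py are still in scope, and evaluates only a slice of the test set to keep memory modest):

# evaluate the first 1000 test images; dropout is disabled via keepratio=1.0
test_acc = sess.run(accr, feed_dict={x: mnist.test.images[:1000],
                                     y: mnist.test.labels[:1000],
                                     keepratio: 1.0})
print('Test accuracy (first 1000 images): %.3f' % test_acc)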

4. test.py

import pickle

import numpy as np
import tensorflow as tf

import input_data


def load_file(path, name):
    with open(path + name + '.pkl', 'rb') as f:
        return pickle.load(f)


# load the weights and biases saved by CNN.py
res_dict = load_file('', 'res_dict')
print(res_dict['weight']['wc1'])

index = 0

mnist = input_data.read_data_sets('data/', one_hot=True)
test_image = mnist.test.images
test_label = mnist.test.labels


def conv_basic(_input, _w, _b, _keepratio):
    # same architecture as in CNN.py, but built from the saved numpy arrays
    _input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])
    _conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')
    _conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))
    _pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    _pool_dr1 = tf.nn.dropout(_pool1, _keepratio)
    _conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')
    _conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))
    _pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    _pool_dr2 = tf.nn.dropout(_pool2, _keepratio)
    _dense1 = tf.reshape(_pool_dr2, [-1, _w['wd1'].shape[0]])
    _fc1 = tf.nn.relu(tf.add(tf.matmul(_dense1, _w['wd1']), _b['bd1']))
    _fc_dr1 = tf.nn.dropout(_fc1, _keepratio)
    _out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])
    out = {
        'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool_dr1': _pool_dr1,
        'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'dense1': _dense1,
        'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out
    }
    return out


n_input = 784
n_output = 10

x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output])
keepratio = tf.placeholder(tf.float32)

_pred = conv_basic(x, res_dict['weight'], res_dict['biases'], keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=_pred, labels=y))
_corr = tf.equal(tf.argmax(_pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32))

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

training_epochs = 1
batch_size = 1
display_step = 1
for epoch in range(training_epochs):
    avg_cost = 0.
    total_batch = 10
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        if epoch % display_step == 0:
            # predicted digit vs. ground-truth digit for one sample at a time
            print('_pre:', np.argmax(sess.run(_pred, feed_dict={x: batch_xs, keepratio: 1.})))
            print('answer:', np.argmax(batch_ys))
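
test.py only prints predictions for ten individual training samples. A minimal sketch (not part of the original script) for scoring the restored network on the whole test set, reusing accr, sess, test_image and test_label defined above and evaluating in batches to bound memory:

# accumulate correct counts batch by batch, then divide by the test-set size
correct = 0.0
eval_batch = 500
for start in range(0, len(test_image), eval_batch):
    xs = test_image[start:start + eval_batch]
    ys = test_label[start:start + eval_batch]
    correct += sess.run(accr, feed_dict={x: xs, y: ys, keepratio: 1.0}) * len(xs)
print('Test accuracy with restored weights: %.3f' % (correct / len(test_image)))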
