import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os mnist = input_data.read_data_sets('MNIST_data', one_hot=True) class MNISTModel(object):
def __init__(self, lr, batch_size, iter_num):
self.lr = lr
self.batch_size = batch_size
self.iter_num = iter_num
# 定义模型结构
# 输入张量,这里还没有数据,先占个地方,所以叫“placeholder”
self.x = tf.placeholder(tf.float32, [None, 784]) # 图像是28*28的大小
self.y = tf.placeholder(tf.float32, [None, 10]) # 输出是0-9的one-hot向量
self.h = tf.layers.dense(self.x, 100, activation=tf.nn.relu, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 一个全连接层
self.y_ = tf.layers.dense(self.h, 10, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 全连接层 # 使用交叉熵损失函数
self.loss = tf.losses.softmax_cross_entropy(self.y, self.y_)
self.optimizer = tf.train.AdamOptimizer()
self.train_step = self.optimizer.minimize(self.loss) # 用于模型训练
self.correct_prediction = tf.equal(tf.argmax(self.y, axis=1), tf.argmax(self.y_, axis=1))
self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32)) # 用于保存训练好的模型
self.saver = tf.train.Saver() def train(self):
with tf.Session() as sess: # 打开一个会话。可以想象成浏览器打开一个标签页一样,直观地理解一下
sess.run(tf.global_variables_initializer()) # 先初始化所有变量。
for i in range(self.iter_num):
batch_x, batch_y = mnist.train.next_batch(self.batch_size) # 读取一批数据
loss, _ = sess.run([self.loss, self.train_step], feed_dict={self.x: batch_x, self.y: batch_y}) # 每调用一次sess.run,就像拧开水管一样,所有self.loss和self.train_step涉及到的运算都会被调用一次。
if i%1000 == 0:
train_accuracy = sess.run(self.accuracy, feed_dict={self.x: batch_x, self.y: batch_y}) # 把训练集数据装填进去
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y}) # 把测试集数据装填进去
print( 'iter\t%i\tloss\t%f\ttrain_accuracy\t%f\ttest_accuracy\t%f' % (i,loss,train_accuracy,test_accuracy))
self.saver.save(sess, 'model/mnistModel') # 保存模型 def test(self):
with tf.Session() as sess:
self.saver.restore(sess, 'model/mnistModel')
Accuracy = []
for i in range(150):
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y})
Accuracy.append(test_accuracy)
print ('==' * 15)
print ('Test Accuracy: ', np.mean(np.array(Accuracy))) model = MNISTModel(0.001, 64, 40000) # 学习率为0.001,每批传入64张图,训练40000次
model.train() # 训练模型
model.test() #测试模型

基于多层感知机的手写数字识别(Tensorflow实现)的更多相关文章

  1. 基于Numpy的神经网络+手写数字识别

    基于Numpy的神经网络+手写数字识别 本文代码来自Tariq Rashid所著<Python神经网络编程> 代码分为三个部分,框架如下所示: # neural network class ...

  2. Mnist手写数字识别 Tensorflow

    Mnist手写数字识别 Tensorflow 任务目标 了解mnist数据集 搭建和测试模型 编辑环境 操作系统:Win10 python版本:3.6 集成开发环境:pycharm tensorflo ...

  3. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  4. 基于卷积神经网络的手写数字识别分类(Tensorflow)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  5. 吴裕雄--天生自然python机器学习:基于支持向量机SVM的手写数字识别

    from numpy import * def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i ...

  6. MNIST手写数字识别 Tensorflow实现

    def conv2d(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 1. strides在官方定义中是一 ...

  7. Keras mlp 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了三层全连接层组成的多层感知机,最后一层为输出层 #基于Keras 2.1.1 Tensorflow 1.4.0 代码: import keras from ...

  8. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  9. Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...

随机推荐

  1. 用Html创建简历

    用最基本的html实现简历的创建 代码 <!doctype html> <html> <head><meta charset="utf-8" ...

  2. allegro中如何添加安装孔(注:在PCB图纸中添加)

    最近再给外国一家公司做某一个小的系统模块的封装,其中这个模块中间是挖空的,这就比较难办,到现在为止我还没有找到如何在封装中添加自己绘制特定形状的过孔,(倒是可以添加软件自带的一些圆形的安装孔)最终解决 ...

  3. unigui作中间件使用

    unigui作中间件使用 可返回string或者tstream数据. 如果返回JSON字符,则UNIGUI就是REST 中间件. procedure TUniServerModule.UniGUISe ...

  4. WPF TreeView BringIntoViewBehavior

    由于项目需要,需要能够定位TreeView中的点,TreeView的节点数过多的情况下,即使找到了对应的节点并选中展示了,由于不在可视区域内,给用户的感觉还是不好,因此设计如下的Behavior,来实 ...

  5. qt在windows下编译遇到的一些问题

    软件是在linux上写的,然而搬到windows上来遇到了好多问题.... 想跪.... 首先就是压根编译不了的问题....这个问题困扰我好久了....一直报错undefined reference ...

  6. JQuery - on绑定多个事件

    一.jquery为多个选择器绑定同一个事件 $("#start,#end").on("click",function(){ alert("The pa ...

  7. 预处理命令使用详解----#if、#endif、#undef、#ifdef、#else、#elif

    预处理命令 在接触#if.#undef这类预处理指令前,大部分都都接触过#define.#include等预处理命令,通俗来讲预处理命令的作用就是在编译和链接之前,对源文件进行一些文本方面的操作,比如 ...

  8. java urlrewrite实现伪静态化

    1.示例 http://www.onlyfun.com/goods/company.jsp?companyId=455326 ==> http://www.onlyfun.com/company ...

  9. Swift5 语言参考(六) 声明

    一个声明引入了一个新的名称或构建到你的程序.例如,您使用声明来引入函数和方法,引入变量和常量,以及定义枚举,结构,类和协议类型.您还可以使用声明来扩展现有命名类型的行为,并将符号导入到其他地方声明的程 ...

  10. Redis---Hash(字典)

    1. 概述 Redis hash 是一个string类型的field和value的映射表,hash特别适合用于存储对象. Redis 中每个 hash 可以存储 232 - 1 键值对(40多亿). ...