import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os mnist = input_data.read_data_sets('MNIST_data', one_hot=True) class MNISTModel(object):
def __init__(self, lr, batch_size, iter_num):
self.lr = lr
self.batch_size = batch_size
self.iter_num = iter_num
# 定义模型结构
# 输入张量,这里还没有数据,先占个地方,所以叫“placeholder”
self.x = tf.placeholder(tf.float32, [None, 784]) # 图像是28*28的大小
self.y = tf.placeholder(tf.float32, [None, 10]) # 输出是0-9的one-hot向量
self.h = tf.layers.dense(self.x, 100, activation=tf.nn.relu, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 一个全连接层
self.y_ = tf.layers.dense(self.h, 10, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 全连接层 # 使用交叉熵损失函数
self.loss = tf.losses.softmax_cross_entropy(self.y, self.y_)
self.optimizer = tf.train.AdamOptimizer()
self.train_step = self.optimizer.minimize(self.loss) # 用于模型训练
self.correct_prediction = tf.equal(tf.argmax(self.y, axis=1), tf.argmax(self.y_, axis=1))
self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32)) # 用于保存训练好的模型
self.saver = tf.train.Saver() def train(self):
with tf.Session() as sess: # 打开一个会话。可以想象成浏览器打开一个标签页一样,直观地理解一下
sess.run(tf.global_variables_initializer()) # 先初始化所有变量。
for i in range(self.iter_num):
batch_x, batch_y = mnist.train.next_batch(self.batch_size) # 读取一批数据
loss, _ = sess.run([self.loss, self.train_step], feed_dict={self.x: batch_x, self.y: batch_y}) # 每调用一次sess.run,就像拧开水管一样,所有self.loss和self.train_step涉及到的运算都会被调用一次。
if i%1000 == 0:
train_accuracy = sess.run(self.accuracy, feed_dict={self.x: batch_x, self.y: batch_y}) # 把训练集数据装填进去
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y}) # 把测试集数据装填进去
print( 'iter\t%i\tloss\t%f\ttrain_accuracy\t%f\ttest_accuracy\t%f' % (i,loss,train_accuracy,test_accuracy))
self.saver.save(sess, 'model/mnistModel') # 保存模型 def test(self):
with tf.Session() as sess:
self.saver.restore(sess, 'model/mnistModel')
Accuracy = []
for i in range(150):
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y})
Accuracy.append(test_accuracy)
print ('==' * 15)
print ('Test Accuracy: ', np.mean(np.array(Accuracy))) model = MNISTModel(0.001, 64, 40000) # 学习率为0.001,每批传入64张图,训练40000次
model.train() # 训练模型
model.test() #测试模型

基于多层感知机的手写数字识别(Tensorflow实现)的更多相关文章

  1. 基于Numpy的神经网络+手写数字识别

    基于Numpy的神经网络+手写数字识别 本文代码来自Tariq Rashid所著<Python神经网络编程> 代码分为三个部分,框架如下所示: # neural network class ...

  2. Mnist手写数字识别 Tensorflow

    Mnist手写数字识别 Tensorflow 任务目标 了解mnist数据集 搭建和测试模型 编辑环境 操作系统:Win10 python版本:3.6 集成开发环境:pycharm tensorflo ...

  3. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  4. 基于卷积神经网络的手写数字识别分类(Tensorflow)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  5. 吴裕雄--天生自然python机器学习:基于支持向量机SVM的手写数字识别

    from numpy import * def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i ...

  6. MNIST手写数字识别 Tensorflow实现

    def conv2d(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 1. strides在官方定义中是一 ...

  7. Keras mlp 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了三层全连接层组成的多层感知机,最后一层为输出层 #基于Keras 2.1.1 Tensorflow 1.4.0 代码: import keras from ...

  8. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  9. Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...

随机推荐

  1. pycharm的基本使用

    1.设置注释字体的颜色,如下图 2.常用快捷键 Ctrl + Space 基本的代码完成(类.方法.属性) Ctrl + Alt + Space 类名完成 Ctrl + Shift + Enter 语 ...

  2. 【转】Windows Server 2016 安装 IIS 服务时提示指定备用源路径

    原文地址:http://www.codingwhy.com/view/973.html 在Windows Serever 2016中安装IIS的时候,遇到以下提示 是否需要指定备用源路径?一个或多个安 ...

  3. 关于ubuntu软件卸载的问题

    ...... 起因很菜....就是手贱把/usr/lib/下的R的目录给rm -rf掉了, 然后在R的软件包里用./configure也生不成, R RHOME还在, 然而因为lib下的R删掉了所以R ...

  4. MariaDB 数据类型与运算符(4)

    MariaDB数据库管理系统是MySQL的一个分支,主要由开源社区在维护,采用GPL授权许可MariaDB的目的是完全兼容MySQL,包括API和命令行,MySQL由于现在闭源了,而能轻松成为MySQ ...

  5. idea 映射文件同class文件一起打包安装

    经过几天的摸索,已经能够用idea做日常的Demo了,在复习的过程中,又在学的知识,所以进度有点慢,但自己好像有点着急,为自己的效率 但是自己也是知道的,只顾速度,最后的学完的效果也不是自己想要的,所 ...

  6. 计算机中的K、M、G、T到底指的是

    计算机语言是二进制数字01组成. 在计算机工作中,2的10次方用K(kilo)表示,2的20次方用M(mega)表示,2的30次方用G(giga)表示,2的40次方用T(tera)表示.因此,4K=2 ...

  7. 上下文相关的GMM-HMM声学模型续:参数共享

    一.三音素建模存在的问题 问题一:很多三音素在训练数据中没有出现(尤其跨词三音素) 问题二:在训练数据中出现过的三音素有相当一部分出现的频次较少 因此,三音素模型训练时存在较严重的数据不足问题 二.参 ...

  8. Docker 镜像安装 GitLab 中文社区版

    docker run \ --detach \ --publish : \ --publish : \ --name gitlab \ --restart unless-stopped \ --vol ...

  9. Identity Server4学习系列二之令牌(Token)的概念

    1.简介 通过前文知道了Identity Server4的基本用途,现在必须了解一些实现它的基本细节. 2.关于服务端生成Token令牌 头部(Header): { “typ”: “JWT”, //t ...

  10. android的电话监听

    android的电话监听 新建一个项目,结构图如下: PhoneService: package com.demo.tingdianhua; import android.app.Service; i ...