# -*- coding: utf-8 -*-
import numpy as np
np.random.seed(1337) from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense
from keras.optimizers import Adam TIME_STEPS = 28 #图片的高
INPUT_SIZE = 28 #图片的行
BATCH_SIZE = 50 #每批训练多少图片
BATCH_INDEX = 0
OUTPUT_SIZE = 10
CELL_SIZE = 50
LR = 0.001 #下载mnist数据集
# X shape (60000,28*28) ,y shape (10000)
(X_train,y_train),(X_test,y_test) = mnist.load_data() # 数据预处理
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10) # 建模型
model = Sequential()
# RNN
model.add(SimpleRNN(
batch_input_shape=(None,TIME_STEPS,INPUT_SIZE),# 每次训练的量(None表示全部),图片大小
output_dim=CELL_SIZE,
))
# 输出层
model.add(Dense(OUTPUT_SIZE))
model.add(Activation('softmax')) # 优化器
adam = Adam(LR)
model.compile(optimizer=adam,
loss='categorical_crossentropy',
metrics=['accuracy']) # 训练
for step in range(4001):
X_batch=X_train[BATCH_INDEX:BATCH_SIZE+BATCH_INDEX,:,:]
Y_batch=y_train[BATCH_INDEX:BATCH_SIZE+BATCH_INDEX,:]
cost = model.train_on_batch(X_batch,Y_batch) BATCH_INDEX += BATCH_SIZE
BATCH_INDEX = 0 if BATCH_INDEX>=X_train.shape[0] else BATCH_INDEX if step % 500 == 0:
cost,accuracy = model.evaluate(X_test,y_test,batch_size=y_test.shape[0],verbose=False)
print('test cost: ',cost,'test accuracy: ',accuracy)

用Keras搭建神经网络 简单模版(四)—— RNN Classifier 循环神经网络(手写数字图片识别)的更多相关文章

  1. 用Keras搭建神经网络 简单模版(三)—— CNN 卷积神经网络(手写数字图片识别)

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...

  2. 吴裕雄 python神经网络 手写数字图片识别(5)

    import kerasimport matplotlib.pyplot as pltfrom keras.models import Sequentialfrom keras.layers impo ...

  3. 吴裕雄 python 神经网络——TensorFlow 卷积神经网络手写数字图片识别

    import os import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_N ...

  4. 用Keras搭建神经网络 简单模版(二)——Classifier分类(手写数字识别)

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...

  5. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  6. RNN探索(2)之手写数字识别

    这篇博文不介绍基础的RNN理论知识,只是初步探索如何使用Tensorflow,之后会用笔推导RNN的公式和理论,现在时间紧迫所以先使用为主~~ import numpy as np import te ...

  7. 用tensorflow求手写数字的识别准确率 (简单版)

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist = in ...

  8. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  9. 100天搞定机器学习|day39 Tensorflow Keras手写数字识别

    提示:建议先看day36-38的内容 TensorFlow™ 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库.节点(Nodes)在图中表示数学操作,图中的线(edge ...

随机推荐

  1. python3 pyinstaller

    一.安装python.pywin32.pyinstaller库 二.官网:https://pyinstaller.readthedocs.io/en/v3.3.1/usage.html#general ...

  2. 由浅入深剖析 go channel

    原文:https://www.jianshu.com/p/24ede9e90490 ---------------------------------- 由浅入深剖析 go channel chann ...

  3. JS window对象详解

    window 是客户端浏览器对象模型的基类,window 对象是客户端 JavaScript 的全局对象.一个 window 对象实际上就是一个独立的窗口,对于框架页面来说,浏览器窗口每个框架都包含一 ...

  4. 月球-I型,月份日历生成器----基于PHP7.3

    生成月份周日的类 <?php class mycalendar { function __construct($year,$mon) { $'; $this->firstday=strto ...

  5. 为什么添加了@Aspect 还要加@Component(转)

    官方文档中有写: You may register aspect classes as regular beans in your Spring XML configuration, or autod ...

  6. 使用jdk自带的线程池。加载10个线程。

    在开发中使用线程,经常不经意间就new Thread()一个出来,然后发现,这样做不是很好,特别是很多线程同时处理的时候,会出现CPU被用光导致机器假死,线程运行完成自动销毁后,又复活的情况. 所以在 ...

  7. 【51nod 1340】地铁环线

    题目 有一个地铁环线,环线中有N个站台,标号为0,1,2,...,N-1.这个环线是单行线,一共由N条有向边构成,即从0到1,1到2,..k到k+1,...,N-2到N-1,N-1到0各有一条边.定义 ...

  8. ueditor粘贴word图片无法显示的问题

    图片的复制无非有两种方法,一种是图片直接上传到服务器,另外一种转换成二进制流的base64码目前限chrome浏览器使用首先以um-editor的二进制流保存为例:打开umeditor.js,找到UM ...

  9. Git Clone 的时候遇到 Filename too long 错误

    在对某些仓库进行 Git Clone 的时候遇到了 Filename too long 的错误提示. 错误提示如下图: 可以有下面的一些解决办法: 可以有下面的一些解决办法: 在 Git bash 中 ...

  10. leetcode解题报告(10):Merge Two Sorted Lists

    描述 Merge two sorted linked lists and return it as a new list. > The new list should be made by sp ...