# -*- coding: utf-8 -*-
import numpy as np
np.random.seed(1337) from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense
from keras.optimizers import Adam TIME_STEPS = 28 #图片的高
INPUT_SIZE = 28 #图片的行
BATCH_SIZE = 50 #每批训练多少图片
BATCH_INDEX = 0
OUTPUT_SIZE = 10
CELL_SIZE = 50
LR = 0.001 #下载mnist数据集
# X shape (60000,28*28) ,y shape (10000)
(X_train,y_train),(X_test,y_test) = mnist.load_data() # 数据预处理
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10) # 建模型
model = Sequential()
# RNN
model.add(SimpleRNN(
batch_input_shape=(None,TIME_STEPS,INPUT_SIZE),# 每次训练的量(None表示全部),图片大小
output_dim=CELL_SIZE,
))
# 输出层
model.add(Dense(OUTPUT_SIZE))
model.add(Activation('softmax')) # 优化器
adam = Adam(LR)
model.compile(optimizer=adam,
loss='categorical_crossentropy',
metrics=['accuracy']) # 训练
for step in range(4001):
X_batch=X_train[BATCH_INDEX:BATCH_SIZE+BATCH_INDEX,:,:]
Y_batch=y_train[BATCH_INDEX:BATCH_SIZE+BATCH_INDEX,:]
cost = model.train_on_batch(X_batch,Y_batch) BATCH_INDEX += BATCH_SIZE
BATCH_INDEX = 0 if BATCH_INDEX>=X_train.shape[0] else BATCH_INDEX if step % 500 == 0:
cost,accuracy = model.evaluate(X_test,y_test,batch_size=y_test.shape[0],verbose=False)
print('test cost: ',cost,'test accuracy: ',accuracy)

用Keras搭建神经网络 简单模版(四)—— RNN Classifier 循环神经网络(手写数字图片识别)的更多相关文章

  1. 用Keras搭建神经网络 简单模版(三)—— CNN 卷积神经网络(手写数字图片识别)

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...

  2. 吴裕雄 python神经网络 手写数字图片识别(5)

    import kerasimport matplotlib.pyplot as pltfrom keras.models import Sequentialfrom keras.layers impo ...

  3. 吴裕雄 python 神经网络——TensorFlow 卷积神经网络手写数字图片识别

    import os import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_N ...

  4. 用Keras搭建神经网络 简单模版(二)——Classifier分类(手写数字识别)

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...

  5. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  6. RNN探索(2)之手写数字识别

    这篇博文不介绍基础的RNN理论知识,只是初步探索如何使用Tensorflow,之后会用笔推导RNN的公式和理论,现在时间紧迫所以先使用为主~~ import numpy as np import te ...

  7. 用tensorflow求手写数字的识别准确率 (简单版)

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist = in ...

  8. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  9. 100天搞定机器学习|day39 Tensorflow Keras手写数字识别

    提示:建议先看day36-38的内容 TensorFlow™ 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库.节点(Nodes)在图中表示数学操作,图中的线(edge ...

随机推荐

  1. nginx精准反向代理

    1,完全反向代理,将请求10.130.111.110服务器的请求全部转发到10.130.111.111服务器 location / { proxy_pass http://10.130.111.111 ...

  2. php 生成gif 动图,可控制每张图时间

    <?php //namespace gifCreator; /** * Create an animated GIF from multiple images */ class gifcreat ...

  3. IIS搭建ASP站点

    1. 进入控制面板悬着打开或者关闭Windows功能. 2. 手工选择需要的功能进行安装. 3. 打开运行Internet信息服务(IIS)管理工具. 4. 展开左侧栏看到“Default Web S ...

  4. 多域名解析到同一网站C的php重定向代码

    在index.php最前面加上以下代码: <?php if(strpos($_SERVER['HTTP_HOST'],'afish.cnblogs.com')===false){ #header ...

  5. centos ntfs-3g not find

    1,CentOS默认源里没有ntfs3g,想要添加ntfs支持,需要自己下载编译安装或者加源yum安装.我这里使用的是添加aliyun的epel源来yum安装的方式. 2,添加epel yum源wge ...

  6. java判断文件是否为图片

    /** * 判断文件是否为图片<br> * <br> * @param pInput 文件名<br> * @param pImgeFlag 判断具体文件类型< ...

  7. JQuery实践--Why JQuery

    给页面增加动态功能的工作流模式:选择一个元素或一组元素,然后以某种方式对其进行操作. 利用原始的JavaScript完成这些任务中的任何一个,都会需要数十行代码,JQuery让这些常见的任务变得简单 ...

  8. 数值(Number,Math, 运算符)

    1.js中数字 1.数字存储 Javascript中所有数字的存储都是64位浮点数.整数也一样. 1 === 1.0 // true 2. 数字大小范围 可以表示的最大正数和最小负数 (-Math.p ...

  9. 参数类型 (Mapper.java)常用

    UserBaseInfo selectByMobile(@Param("mobile")String mobile,@Param("isDeleted")Int ...

  10. bzoj2725

    * 给出一张图 * 每次删掉一条边后求 the shortest path from S to T * 线段树维护最短路径树 * 具体维护从某点开始偏离最短路而到达 T 的最小距离 * 首先记录下最短 ...