莫烦Python 4

新建模板小书匠

RNN Classifier 循环神经网络

问题描述

使用RNN对MNIST里面的图片进行分类

关键

SimpleRNN()参数

  • batch_input_shape

    使用状态RNN的注意事项

可以将RNN设置为‘stateful’,意味着由每个batch计算出的状态都会被重用于初始化下一个batch的初始状态。状态RNN假设连续的两个batch之中,相同下标的元素有一一映射关系。

要启用状态RNN,请在实例化层对象时指定参数stateful=True,并在Sequential模型使用固定大小的batch:通过在模型的第一层传入batch_size=(…)和input_shape来实现。在函数式模型中,对所有的输入都要指定相同的batch_size。

如果要将循环层的状态重置,请调用.reset_states(),对模型调用将重置模型中所有状态RNN的状态。对单个层调用则只重置该层的状态。

(samples,timesteps,input_dim)

代码

'''
RNN Classifier 循环神经网络
'''
import numpy as np
np.random.seed(1337) from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import SimpleRNN, Activation, Dense
from keras.optimizers import Adam time_step = 28
input_size = 28
batch_size = 50
output_size = 10
cell_size = 50
LR = 0.001 (X_train, y_train), (X_test, y_test) = mnist.load_data() X_train = X_train.reshape(-1, 28, 28) / 255. # normalize
X_test = X_test.reshape(-1, 28, 28) / 255. # normalize
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10) model = Sequential()
model.add(
SimpleRNN(
batch_input_shape=(None, time_step, input_size),
units=cell_size
)
) model.add(
Dense(output_size)
) model.add(Activation('softmax')) adam = Adam(LR)
model.compile(
optimizer=adam,
loss='categorical_crossentropy',
metrics=['accuracy']
)
model.summary() model.fit(X_train, y_train, batch_size=batch_size, epochs=2, verbose=2, validation_data=(X_test, y_test))

结果

Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn_2 (SimpleRNN) (None, 50) 3950
_________________________________________________________________
dense_2 (Dense) (None, 10) 510
_________________________________________________________________
activation_2 (Activation) (None, 10) 0
=================================================================
Total params: 4,460
Trainable params: 4,460
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/2
- 12s - loss: 0.6643 - accuracy: 0.7966 - val_loss: 0.4501 - val_accuracy: 0.8550
Epoch 2/2
- 9s - loss: 0.3220 - accuracy: 0.9087 - val_loss: 0.2445 - val_accuracy: 0.9359

莫烦Python 4的更多相关文章

  1. 莫烦python教程学习笔记——保存模型、加载模型的两种方法

    # View more python tutorials on my Youtube and Youku channel!!! # Youtube video tutorial: https://ww ...

  2. 莫烦python教程学习笔记——validation_curve用于调参

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  3. 莫烦python教程学习笔记——learn_curve曲线用于过拟合问题

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  4. 莫烦python教程学习笔记——利用交叉验证计算模型得分、选择模型参数

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  5. 莫烦python教程学习笔记——数据预处理之normalization

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  6. 莫烦python教程学习笔记——线性回归模型的属性

    #调用查看线性回归的几个属性 # Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg # ...

  7. 莫烦python教程学习笔记——使用波士顿数据集、生成用于回归的数据集

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  8. 莫烦python教程学习笔记——使用鸢尾花数据集

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  9. 莫烦python课程里面的bug修复;课程爬虫小练习爬百度百科

    我今天弄了一下午修改这个代码,最后还是弄好了.原因是正则表达式的筛选不够准确,有时候是会带http:baidu这些东西的.所以需要一个正则表达式的断言,然后还有一点是如果his里面只有一个元素就不要再 ...

  10. Tensorflow 教程系列 | 莫烦Python

    Tensorflow 简介 1.1 科普: 人工神经网络 VS 生物神经网络 1.2 什么是神经网络 (Neural Network) 1.3 神经网络 梯度下降 1.4 科普: 神经网络的黑盒不黑 ...

随机推荐

  1. Android IO 框架 Okio 的实现原理,如何检测超时?

    本文已收录到  AndroidFamily,技术和职场问题,请关注公众号 [彭旭锐] 提问. 前言 大家好,我是小彭. 在上一篇文章里,我们聊到了 Square 开源的 I/O 框架 Okio 的三个 ...

  2. wsl 自动配置代理地址

  3. Git03 自建代码托管平台-GitLab

    1 GitLab 简介 GitLab 是由 GitLabInc.开发,使用 MIT 许可证的基于网络的 Git 仓库管理工具,且具有wiki 和 issue 跟踪功能.使用 Git 作为代码管理工具, ...

  4. HTTPS基础原理和配置-2

    〇.概述 作为概述,以下是本文要讲的内容.HTTPS 是什么? 每个人都可能从浏览器上认出 HTTPS,并对它有好感.然后再讲一遍基础知识,再详细讲一下协议版本,密码套件(Cipher Suites) ...

  5. Error: EPERM: operation not permitted, mkdir ‘C:\Program Files\nodejs‘TypeError: Cannot read proper

    出现问题: 问题如题,出现场景:vscode运行npm命令 解决办法: 有的友友说安装nodejs时用管理员身份安装,右键没找到最后删掉了此文件即可. 这个文件缓存了之前的配置与现在安装的nodejs ...

  6. 转载:屎人-->诗人系列--码农之歌

    转贴经常关注的一个博主的文,感觉还挺有趣: https://goofegg.github.io/content.html?id=141 ************************** 这个系列第 ...

  7. Vue3 企业级优雅实战 - 组件库框架 - 12 发布开源组件库

    前面使用了 11 篇文章分享基于 vue3 .Monorepo 的组件库工程完整四件套(组件库.文档.example.cli)的开发.构建及组件库的发布.本文属于这 11 篇文章的扩展 -- 如何发布 ...

  8. MySQL视图、存储过程、函数、触发器、定时任务、流程控制总结

    视图的增删改查 视图相当于一张只能读的表,不可以修改.当组成视图的表发生数据变化的时候,视图会相对应的进行改变. 存储过程的练习 创建存储过程: create [if not exists] proc ...

  9. Git多分支 远程仓库 协同开发以及解决冲突

    目录 一.Git多分支及远程仓库 1.Git多分支 2.正常密码链接远程仓库 3.ssh公钥私钥方式链接远程仓库 三.协同开发及解决冲突 1.协同开发 2.解决冲突 四.线上分支合并及远程仓库回滚 1 ...

  10. HTML+js页面横向分栏效果

    <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/ ...