莫烦大大keras学习Mnist识别(4)-----RNN
一、步骤:
导入包以及读取数据
设置参数
数据预处理
构建模型
编译模型
训练以及测试模型
二、代码:
1、导入包以及读取数据
#导入包
import numpy as np
np.random.seed(1337) #设置之后每次执行代码,产生的随机数都一样 from tensorflow.examples.tutorials.mnist import input_data
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import SimpleRNN , Activation , Dense
from keras.optimizers import Adam #读取数据
mnist = input_data.read_data_sets('E:\jupyter\TensorFlow\MNIST_data',one_hot = True)
X_train = mnist.train.images
Y_train = mnist.train.labels
X_test = mnist.test.images
Y_test = mnist.test.labels
2、设置参数
#设置参数
time_steps = 28 # same as the height of the image
input_size = 28 # same as the width of the image
batch_size = 50
batch_index = 0
output_size = 10
cell_size = 50
lr = 0.001
3、数据预处理
#数据预处理
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
4.1、构建RNN模型
#构建模型
model = Sequential() #RNN层
model.add(SimpleRNN(
batch_input_shape =(None,time_steps,input_size), # 输入维度
output_dim = cell_size, #输出维度
)) #输出层
model.add(Dense(output_size))
model.add(Activation('softmax'))
4.2、构建LSTM模型
def builtLSTMModel(time_steps,input_size,cell_size,output_size):
model = Sequential() #添加LSTM
model.add(LSTM(
batch_input_size = (None,time_steps,input_size),
output_dim = cell_size,
return_sequences = True, #要不要每个时间点的输出都输出
stateful = True, #batch和batch有联系,batch和batch之间的状态需要连接起来
)) #添加输出层
model.add(TimeDistributed(Dense(output_size))) #每一个时间点的输出都要加入全连接层。
return model
5、训练模型以及测试
#训练模型
for step in range(4001): X_batch = X_train[batch_index:batch_size + batch_index,:,:]
Y_batch = Y_train[batch_index:batch_size + batch_index,:]
cost = model.train_on_batch(X_batch,Y_batch) batch_index += batch_size
batch_index = 0 if batch_index >= X_train.shape[0] else batch_index if step % 500 ==0:
loss , acc = model.evaluate(X_test,Y_test,batch_size =Y_test.shape[0])
print(loss,',',acc)
莫烦大大keras学习Mnist识别(4)-----RNN的更多相关文章
- 莫烦大大keras学习Mnist识别(3)-----CNN
一.步骤: 导入模块以及读取数据 数据预处理 构建模型 编译模型 训练模型 测试 二.代码: 导入模块以及读取数据 #导包 import numpy as np np.random.seed(1337 ...
- 莫烦大大keras的Mnist手写识别(5)----自编码
一.步骤: 导入包和读取数据 数据预处理 编码层和解码层的建立 + 构建模型 编译模型 训练模型 测试模型[只用编码层来画图] 二.代码: 1.导入包和读取数据 #导入相关的包 import nump ...
- 莫烦大大TensorFlow学习笔记(9)----可视化
一.Matplotlib[结果可视化] #import os #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf i ...
- 莫烦python教程学习笔记——总结篇
一.机器学习算法分类: 监督学习:提供数据和数据分类标签.--分类.回归 非监督学习:只提供数据,不提供标签. 半监督学习 强化学习:尝试各种手段,自己去适应环境和规则.总结经验利用反馈,不断提高算法 ...
- 莫烦大大TensorFlow学习笔记(8)----优化器
一.TensorFlow中的优化器 tf.train.GradientDescentOptimizer:梯度下降算法 tf.train.AdadeltaOptimizer tf.train.Adagr ...
- 莫烦python教程学习笔记——保存模型、加载模型的两种方法
# View more python tutorials on my Youtube and Youku channel!!! # Youtube video tutorial: https://ww ...
- 莫烦python教程学习笔记——validation_curve用于调参
# View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...
- 莫烦python教程学习笔记——learn_curve曲线用于过拟合问题
# View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...
- 莫烦python教程学习笔记——利用交叉验证计算模型得分、选择模型参数
# View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...
随机推荐
- MVC.Net:通过Global.asax捕捉错误
在MVC.Net中,如果我们想做一个统一的错误处理的模块,有几个选择,一种是通过一个Base Controller来实现,另外一种就是在Global.asax中实现.这里介绍后一种方法. 首先打开Gl ...
- 使用c3p0与DBCP连接池,造成的MySql 8小时问题解决方式
本文提供了对c3p0与DBCP连接池连接MySql数据库时. 8小时内无请求自己主动断开连接的解决方式.首先介绍一下我在项目(c3p0连接池)中遇到的问题,后面还提供了使用DBCP连接池的解决方式. ...
- 【Git使用具体解释】Egit使用过程中遇到的问题及解决的方法
1. Git错误non-fast-forward后的冲突解决 问题(Non-fast-forward)的出现原因在于:git仓库中已经有一部分代码,所以它不同意你直接把你的代码覆盖上去.于是你有2 ...
- 操作系统: 二级文件夹文件系统的实现(c/c++语言)
操作系统的一个课程设计,实现一个二级文件夹文件系统. 用disk.txt模拟磁盘,使用Help查看支持的命令及其操作方式,root为超级用户(写在disk.txt中) 文件的逻辑结构:流式文件. 物理 ...
- 【HDOJ 2255】奔小康赚大钱(KM算法)
[HDOJ 2255]奔小康赚大钱(KM算法) 奔小康赚大钱 Time Limit: 1000/1000 MS (Java/Others) Memory Limit: 32768/32768 K ...
- 音频格式opus
人耳能听到自然界的声音是20HZ-20KHZ,一般高保真音质采样率只有达到最高采样率的2倍以上即可,平时电话采样率8KHZ,CD音质的采样率44.1KHZ. IBM 的Watson的音频转文字接口支持 ...
- c++ string 解析ip
比如输入是192.168.80.12-15,解析成192.168.80.12.192.168.80.13.192.168.80.14.192.168.80.15. #include <iostr ...
- iOS手势识别
一.手势识别与触摸事件 1.如果想监听一个view上面的触摸事件,可选的做法是: (1)自定义一个view (2)实现view的touches方法,在方法内部实现具体处理代码 2.通过touches方 ...
- [Apple开发者帐户帮助]三、创建证书(8)撤销证书
您可以根据证书类型和角色撤消证书.有关详细信息,请转到撤消权限. 要了解撤销证书时会发生什么,请转到Apple Developer支持中的证书. 所需角色:帐户持有人或管理员. 在“ 证书”,“标识符 ...
- C++中const用法
1.const和指针: 如果const出现在星号左边,表示被指物是常量:如果出现在星号右边,表示指针自身是常量:如果出现在星号两边,表示被指物和指针两者都是常量. char greet[] = “He ...