用LSTM分类 MNIST
LSTM是RNN的一种算法, 在序列分类中比较有用。常用于语音识别,文字处理(NLP)等领域。
等同于VGG等CNN模型在在图像识别领域的位置。 本篇文章是叙述LSTM 在MNIST 手写图中的使用。
用来给初步学习RNN的一个范例,便于学习和理解LSTM .
先把工作流程图贴一下:
代码片段 :
数据准备
def makedata():
img_rows, img_cols = 28, 28 mnist = fetch_mldata("MNIST original")
# rescale the data, use the traditional train/test split
X_1D, y_int = mnist.data / 255., mnist.target
y = np_utils.to_categorical(y_int, num_classes=10) X = X_1D.reshape(X_1D.shape[0], img_rows, img_cols ) input_shape = (img_rows, img_cols, 1)
x_train, x_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:] return X, y
pass
下载 MNIST数据, 进行归一化 mnist.data / 255, 把数据[7000,784 ] 转成[ 70000,28,28]
构建模型:
def buildlstm(): import numpy as np data_dim = 28
timesteps = 28
num_classes = 10 # expected input data shape: (batch_size, timesteps, data_dim)
model = Sequential()
model.add(LSTM(32, return_sequences=True, input_shape=(timesteps, data_dim+14)))
model.add(LSTM(32, return_sequences=True))
model.add(LSTM(32))
model.add(Dense(10, activation='softmax')) model.compile(loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
print model.summary()
return model
pass
基础参数: data_dim, timesteps, num_classes 分别为 28,28, 10
网络层级 : LSTM ----》LSTM ----》LSTM ----》Dense
注意点: input_shape=(timesteps, data_dim+14)) 此处 应该为 data_dim , data_dim+14是我做第二个试验使用。
网络理解: RNN是用前一部分数据对当前数据的影响,并共同作用于最后结果。 用基础的深度神经网络(只有Dense层),是把MNIST一个图形,
提取成784个像素数据,把784个数据扔给神经网络,784个数据是同等的概念。 训练出权重来确定最终的分类值。
RNN 之于MNIST, 是把MNIST 分成 28x28 数据。可以理解为用一个激光扫描一个图片,扫成28个(行)数据, 每行为28个像素。 站在时间序列
的角度,其实图片没有序列概念。但是我们可以这样理解, 每一行于下一行是有位置关系的,不能进行顺序变化。 比如一个手写 “7”字, 如果把28行
的上下行顺序打乱, 那么7 上面的一横就可能在中间位置,也可能在下面的位置。 这样,最终的结果就不应该是 7 .
所以MNIST 的 28x28可以理解为 有时序关系的数据。
训练预测:
def runTrain(model, x_train, x_test, y_train, y_test):
model.fit(x_train, y_train, batch_size= nbatch_size, epochs= nEpoches)
score = model.evaluate(x_test, y_test, batch_size=nbatch_size)
print 'evaluate score:', score
pass
这部分应该没什么好说的
主程序:
def test(): X,y = makedata2()
x_train, x_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]
model = buildlstm()
runTrain(model, x_train, x_test, y_train, y_test )
pass
运行结果:
结构:
Layer (type) Output Shape Param #
=================================================================
lstm_1 (LSTM) (None, 28, 32) 7808
_________________________________________________________________
lstm_2 (LSTM) (None, 28, 32) 8320
_________________________________________________________________
lstm_3 (LSTM) (None, 32) 8320
_________________________________________________________________
dense_1 (Dense) (None, 10) 330
=================================================================
Total params: 24,778
Trainable params: 24,778
Non-trainable params: 0
_________________________________________________________________ 结果:
base lstm for mnist
acc : 98.56% 结果2:
把数据最后增加 50% 的 0 , (dim X 0.5)
acc : 98.39%
结果基本上 与原数据一致
该实验证明两个结论:
1. LSTM可用于图形识别
2. 在数据中 每行28个基础像素后面 + 14 个空白(0)的元素,不影分类识别。
写在最后: 本实验的目的是为了理解RNN(LSTM), 只有理解了才能很好的使用。 本文章的目的是为记录和分享。
再说下 RNN在其它领域的应用。 比如在语音识别领域,一个音谱,识别成一个单词(词语),可以理解成一个
竖向扫描的MNIST , 一个股票的K线图,也可以理解一个竖向扫描的MNIST。 还有其它领域,可以归纳递推。
入门之后, 如何在自己的领域,再深入(构建复杂模型,优化数据的处理),提高网络模型的识别准确,那需要
见仁见智的。
代码文件链接:
有对 金融程序化 和 深度学习结合有兴趣的可以加群 , 个人群: 杭州程序化交易群 375129936
用LSTM分类 MNIST的更多相关文章
- NLP用CNN分类Mnist,提取出来的特征训练SVM及Keras的使用(demo)
用CNN分类Mnist http://www.bubuko.com/infodetail-777299.html /DeepLearning Tutorials/keras_usage 提取出来的特征 ...
- tensorflow学习笔记————分类MNIST数据集
在使用tensorflow分类MNIST数据集中,最容易遇到的问题是下载MNIST样本的问题. 一般是通过使用tensorflow内置的函数进行下载和加载, from tensorflow.examp ...
- 【转载】用Scikit-Learn构建K-近邻算法,分类MNIST数据集
原帖地址:https://www.jiqizhixin.com/articles/2018-04-03-5 K 近邻算法,简称 K-NN.在如今深度学习盛行的时代,这个经典的机器学习算法经常被轻视.本 ...
- 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识
用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...
- LSTM用于MNIST手写数字图片分类
按照惯例,先放代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 ...
- 检测用户命令序列异常——使用LSTM分类算法【使用朴素贝叶斯,类似垃圾邮件分类的做法也可以,将命令序列看成是垃圾邮件】
通过 搜集 Linux 服务器 的 bash 操作 日志, 通过 训练 识别 出 特定 用户 的 操作 习惯, 然后 进一步 识别 出 异常 操作 行为. 使用 SEA 数据 集 涵盖 70 多个 U ...
- 分类-MNIST(手写数字识别)
这是学习<Hands-On Machine Learning with Scikit-Learn and TensorFlow>的笔记,如果此笔记对该书有侵权内容,请联系我,将其删除. 这 ...
- 单向LSTM笔记, LSTM做minist数据集分类
单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...
- TensorFlow技术解析与实战学习笔记(15)-----MNIST识别(LSTM)
一.任务:采用基本的LSTM识别MNIST图片,将其分类成10个数字. 为了使用RNN来分类图片,将每张图片的行看成一个像素序列,因为MNIST图片的大小是28*28像素,所以我们把每一个图像样本看成 ...
随机推荐
- 在Office Add-in中实现单点登陆(SSO)
作者:陈希章 发表于 2017年12月27日 这篇文章经过多次修改,终于在今天晚上写完了,演示用的范例代码也终于跑通了.因为这个SSO的功能目前只是Preview的状态,所以本篇文章严格参考了官方的文 ...
- IDEA快速创建Maven+SpringBoot项目时:Cannot download https://start.spring.io;Status:403
先展示一下我遇到的问题: 用浏览器搜索是有页面的,但是但是但是呢,用IDEA快速构建的时候就报403 咳咳!巴格虐我万千遍,我待技术如初恋... 我看到的解决办法有以下两种,当然,我只想说:" ...
- 创建、设置和安装Windows服务
文章大部分内容转自:http://www.cnblogs.com/greatandforever/archive/2008/10/14/1310504.html:和:http://www.cnblog ...
- myecplise自带的tomcat问题
今天做一个项目时候,发现myecplise自带的tomcat上面部署了是可以运行的,可是当部署到自己下载的tomcat时候,就报错,tomcat可以启动,项目无法启动,查了问题,发现是web,xml中 ...
- 7.nginx伪静态规则
网上收集的一些常用的,要用的时候就仿照一下,或直接拿来用. WordPress伪静态规则 location / { index index.html index.php; if (-f $reques ...
- java juint框架的windows自动化-自动运行juint程序简述
在京东混了一个月,基本有点稳定了,觉得也有所余力了现在,继续写博客吧,不过以后更新也许不是那么频繁了 本人使用的是juint框架,对开发是一个单元测试的java框架,但是对测试而言是java的基石之一 ...
- 鸟哥的linux私房菜学习-(一)优缺点分析以及主机规划与磁盘分区
一.linux的优缺点 那干嘛要使用Linux做为我们的主机系统呢?这是因为Linux有底下这些优点: 稳定的系统:Linux本来就是基于Unix概念而发展出来的操作系统,因此,Linux具有与Uni ...
- python科学计算_scipy_常数与优化
scipy在numpy的基础上提供了众多的数学.科学以及工程计算中常用的模块:是强大的数值计算库: 1. 常数和特殊函数 scipy的constants模块包含了众多的物理常数: import sci ...
- Java求循环节长度
两个整数做除法,有时会产生循环小数,其循环部分称为:循环节.比如,11/13=6=>0.846153846153..... 其循环节为[846153] 共有6位.下面的方法,可以求出循环节的长 ...
- IndentationError: unexpected indent
都知道python是对格式要求很严格的,写了一些python但是也没发现他严格在哪里,今天遇到了IndentationError: unexpected indent错误我才知道他是多么的严格. ...