LSTM是RNN的一种算法, 在序列分类中比较有用。常用于语音识别,文字处理(NLP)等领域。

等同于VGG等CNN模型在在图像识别领域的位置。  本篇文章是叙述LSTM 在MNIST 手写图中的使用。

用来给初步学习RNN的一个范例,便于学习和理解LSTM .

先把工作流程图贴一下

代码片段

数据准备

def makedata():
img_rows, img_cols = 28, 28 mnist = fetch_mldata("MNIST original")
# rescale the data, use the traditional train/test split
X_1D, y_int = mnist.data / 255., mnist.target
y = np_utils.to_categorical(y_int, num_classes=10) X = X_1D.reshape(X_1D.shape[0], img_rows, img_cols ) input_shape = (img_rows, img_cols, 1)
x_train, x_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:] return X, y
pass

下载 MNIST数据, 进行归一化  mnist.data / 255, 把数据[7000,784 ] 转成[ 70000,28,28]

构建模型:

def buildlstm():

    import numpy as np

    data_dim = 28
timesteps = 28
num_classes = 10 # expected input data shape: (batch_size, timesteps, data_dim)
model = Sequential()
model.add(LSTM(32, return_sequences=True, input_shape=(timesteps, data_dim+14)))
model.add(LSTM(32, return_sequences=True))
model.add(LSTM(32))
model.add(Dense(10, activation='softmax')) model.compile(loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
print model.summary()
return model
pass

基础参数: data_dim, timesteps, num_classes   分别为 28,28, 10
网络层级 :    LSTM ----》LSTM ----》LSTM ----》Dense
注意点: input_shape=(timesteps, data_dim+14))   此处 应该为  data_dim , data_dim+14是我做第二个试验使用。
网络理解: RNN是用前一部分数据对当前数据的影响,并共同作用于最后结果。 用基础的深度神经网络(只有Dense层),是把MNIST一个图形,
提取成784个像素数据,把784个数据扔给神经网络,784个数据是同等的概念。 训练出权重来确定最终的分类值。

RNN 之于MNIST, 是把MNIST 分成 28x28 数据。可以理解为用一个激光扫描一个图片,扫成28个(行)数据, 每行为28个像素。 站在时间序列
的角度,其实图片没有序列概念。但是我们可以这样理解, 每一行于下一行是有位置关系的,不能进行顺序变化。 比如一个手写 “7”字, 如果把28行
的上下行顺序打乱, 那么7 上面的一横就可能在中间位置,也可能在下面的位置。  这样,最终的结果就不应该是 7 .  
所以MNIST 的 28x28可以理解为 有时序关系的数据。

训练预测:

def runTrain(model, x_train, x_test, y_train, y_test):
model.fit(x_train, y_train, batch_size= nbatch_size, epochs= nEpoches)
score = model.evaluate(x_test, y_test, batch_size=nbatch_size)
print 'evaluate score:', score
pass

这部分应该没什么好说的

主程序:

def test():

    X,y = makedata2()
x_train, x_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]
model = buildlstm()
runTrain(model, x_train, x_test, y_train, y_test )
pass

运行结果

结构:
Layer (type) Output Shape Param #
=================================================================
lstm_1 (LSTM) (None, 28, 32) 7808
_________________________________________________________________
lstm_2 (LSTM) (None, 28, 32) 8320
_________________________________________________________________
lstm_3 (LSTM) (None, 32) 8320
_________________________________________________________________
dense_1 (Dense) (None, 10) 330
=================================================================
Total params: 24,778
Trainable params: 24,778
Non-trainable params: 0
_________________________________________________________________ 结果:
base lstm for mnist
acc : 98.56% 结果2:
把数据最后增加 50% 的 0 , (dim X 0.5)
acc : 98.39%
结果基本上 与原数据一致

该实验证明两个结论:
1.  LSTM可用于图形识别
2.  在数据中 每行28个基础像素后面 + 14 个空白(0)的元素,不影分类识别。

写在最后:  本实验的目的是为了理解RNN(LSTM),  只有理解了才能很好的使用。 本文章的目的是为记录和分享。
再说下 RNN在其它领域的应用。  比如在语音识别领域,一个音谱,识别成一个单词(词语),可以理解成一个
竖向扫描的MNIST ,   一个股票的K线图,也可以理解一个竖向扫描的MNIST。  还有其它领域,可以归纳递推。 
入门之后, 如何在自己的领域,再深入(构建复杂模型,优化数据的处理),提高网络模型的识别准确,那需要
见仁见智的。

代码文件链接:

源码下载

有对 金融程序化 和 深度学习结合有兴趣的可以加群 , 个人群: 杭州程序化交易群  375129936

用LSTM分类 MNIST的更多相关文章

  1. NLP用CNN分类Mnist,提取出来的特征训练SVM及Keras的使用(demo)

    用CNN分类Mnist http://www.bubuko.com/infodetail-777299.html /DeepLearning Tutorials/keras_usage 提取出来的特征 ...

  2. tensorflow学习笔记————分类MNIST数据集

    在使用tensorflow分类MNIST数据集中,最容易遇到的问题是下载MNIST样本的问题. 一般是通过使用tensorflow内置的函数进行下载和加载, from tensorflow.examp ...

  3. 【转载】用Scikit-Learn构建K-近邻算法,分类MNIST数据集

    原帖地址:https://www.jiqizhixin.com/articles/2018-04-03-5 K 近邻算法,简称 K-NN.在如今深度学习盛行的时代,这个经典的机器学习算法经常被轻视.本 ...

  4. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  5. LSTM用于MNIST手写数字图片分类

    按照惯例,先放代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 ...

  6. 检测用户命令序列异常——使用LSTM分类算法【使用朴素贝叶斯,类似垃圾邮件分类的做法也可以,将命令序列看成是垃圾邮件】

    通过 搜集 Linux 服务器 的 bash 操作 日志, 通过 训练 识别 出 特定 用户 的 操作 习惯, 然后 进一步 识别 出 异常 操作 行为. 使用 SEA 数据 集 涵盖 70 多个 U ...

  7. 分类-MNIST(手写数字识别)

    这是学习<Hands-On Machine Learning with Scikit-Learn and TensorFlow>的笔记,如果此笔记对该书有侵权内容,请联系我,将其删除. 这 ...

  8. 单向LSTM笔记, LSTM做minist数据集分类

    单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...

  9. TensorFlow技术解析与实战学习笔记(15)-----MNIST识别(LSTM)

    一.任务:采用基本的LSTM识别MNIST图片,将其分类成10个数字. 为了使用RNN来分类图片,将每张图片的行看成一个像素序列,因为MNIST图片的大小是28*28像素,所以我们把每一个图像样本看成 ...

随机推荐

  1. Material使用02 图标MdIconModule、矢量图作为图标使用及改进

    1 MdIconModule模块的使用 1.1 在需要用到的模块中引入Material图标模块 import { BrowserModule } from '@angular/platform-bro ...

  2. IntelliJ IDEA创建java项目

    IntelliJ IDEA创建java项目 进入到IntelliJ IDEA启动界面,点击Create New Project 2.这样就进入到了创建项目页面,这里可以创建好多项目,这里我们以java ...

  3. 小白的Python之路 day1 字符编码

    字符编码 python解释器在加载 .py 文件中的代码时,会对内容进行编码(默认ascill) ASCII(American Standard Code for Information Interc ...

  4. Tableau Desktop 10.4.2 的安装和激活

    在安装之前,首先我们要弄清楚Tableau是个什么鬼东西,我们为什么需要安装这款软件? Tableau将数据运算与美观的图表完美地嫁接在一起.它的程序很容易上手,各公司可以用它将大量数据拖放到数字&q ...

  5. 解决报错:IncompleteElementException: Could not find result map...

    今天遇到这样一个报错,记录一下: org.apache.ibatis.builder.IncompleteElementException: Could not find result map com ...

  6. splay小结—植树结

    我要把高级数据结构当爸爸了... ...弱到跪烂了. splay,二叉搜索树的一种,具有稳定变形功能. 二叉搜索树:对于一个节点,都只有不超过2个孩子.其左子树内的点的权值都比这个点小,右子树的点的权 ...

  7. 关于git的一些理论知识

    一.什么是版本控制器 好多刚用git的coder一说起git,就随口会说出版本控制器嘛,我问那是干嘛的,大部分人就回答上传代码的.然后会用,但是有些理论你问他们他们就不知道了,比如不是代码的文件就不能 ...

  8. ActiveMQ (一) 初识ActiveMQ

    了解JMS JMS即Java消息服务(Java Message Service)应用程序接口是一个Java平台中关于面向消息中间件(MOM)的API,用于在两个应用程序之间,或分布式系统中发送消息,进 ...

  9. calling c++ from golang with swig--windows dll (三)

    calling c++ from golang with swig--windows dll 三 使用动态链接库(DLL)主要有两种方式:一种通过链接导入库,在代码中直接调用DLL中的函数:另一种借助 ...

  10. JMeter测试HTTPS

    HTTP和HTTPS测试时稍有不同,HTTPS需要加载证书,端口也不一样,操作如下:  1)下载被测网站证书导入 见图为流程: 2)使用JMeter自带的证书 ApacheJMeterTemporar ...