LSTM java 实现
由于实验室事情缘故,需要将Python写的神经网络转成Java版本的,但是python中的numpy等啥包也不知道在Java里面对应的是什么工具,所以索性直接寻找一个现成可用的Java神经网络框架,于是就找到了JOONE,JOONE是一个神经网络的开源框架,使用的是BP算法进行迭代计算参数,使用起来比较方便也比较实用,下面介绍一下JOONE的一些使用方法。
JOONE需要使用一些外部的依赖包,这在官方网站上有,也可以在这里下载。将所需的包引入工程之后,就可以进行编码实现了。
首先看下完整的程序,这个是上面那个超链接给出的程序,应该是官方给出的一个示例吧,因为好多文章都用这个,这其实是神经网络训练一个异或计算器:
- import org.joone.engine.*;
- import org.joone.engine.learning.*;
- import org.joone.io.*;
- import org.joone.net.*;
- /*
- *
- * JOONE实现
- *
- * */
- public class XOR_using_NeuralNet implements NeuralNetListener
- {
- private NeuralNet nnet = null;
- private MemoryInputSynapse inputSynapse, desiredOutputSynapse;
- LinearLayer input;
- SigmoidLayer hidden, output;
- boolean singleThreadMode = true;
- // XOR input
- private double[][] inputArray = new double[][]
- {
- { 0.0, 0.0 },
- { 0.0, 1.0 },
- { 1.0, 0.0 },
- { 1.0, 1.0 } };
- // XOR desired output
- private double[][] desiredOutputArray = new double[][]
- {
- { 0.0 },
- { 1.0 },
- { 1.0 },
- { 0.0 } };
- /**
- * @param args
- * the command line arguments
- */
- public static void main(String args[])
- {
- XOR_using_NeuralNet xor = new XOR_using_NeuralNet();
- xor.initNeuralNet();
- xor.train();
- xor.interrogate();
- }
- /**
- * Method declaration
- */
- public void train()
- {
- // set the inputs
- inputSynapse.setInputArray(inputArray);
- inputSynapse.setAdvancedColumnSelector(" 1,2 ");
- // set the desired outputs
- desiredOutputSynapse.setInputArray(desiredOutputArray);
- desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");
- // get the monitor object to train or feed forward
- Monitor monitor = nnet.getMonitor();
- // set the monitor parameters
- monitor.setLearningRate(0.8);
- monitor.setMomentum(0.3);
- monitor.setTrainingPatterns(inputArray.length);
- monitor.setTotCicles(5000);
- monitor.setLearning(true);
- long initms = System.currentTimeMillis();
- // Run the network in single-thread, synchronized mode
- nnet.getMonitor().setSingleThreadMode(singleThreadMode);
- nnet.go(true);
- System.out.println(" Total time= "
- + (System.currentTimeMillis() - initms) + " ms ");
- }
- private void interrogate()
- {
- double[][] inputArray = new double[][]
- {
- { 1.0, 1.0 } };
- // set the inputs
- inputSynapse.setInputArray(inputArray);
- inputSynapse.setAdvancedColumnSelector(" 1,2 ");
- Monitor monitor = nnet.getMonitor();
- monitor.setTrainingPatterns(4);
- monitor.setTotCicles(1);
- monitor.setLearning(false);
- MemoryOutputSynapse memOut = new MemoryOutputSynapse();
- // set the output synapse to write the output of the net
- if (nnet != null)
- {
- nnet.addOutputSynapse(memOut);
- System.out.println(nnet.check());
- nnet.getMonitor().setSingleThreadMode(singleThreadMode);
- nnet.go();
- for (int i = 0; i < 4; i++)
- {
- double[] pattern = memOut.getNextPattern();
- System.out.println(" Output pattern # " + (i + 1) + " = "
- + pattern[0]);
- }
- System.out.println(" Interrogating Finished ");
- }
- }
- /**
- * Method declaration
- */
- protected void initNeuralNet()
- {
- // First create the three layers
- input = new LinearLayer();
- hidden = new SigmoidLayer();
- output = new SigmoidLayer();
- // set the dimensions of the layers
- input.setRows(2);
- hidden.setRows(3);
- output.setRows(1);
- input.setLayerName(" L.input ");
- hidden.setLayerName(" L.hidden ");
- output.setLayerName(" L.output ");
- // Now create the two Synapses
- FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */
- FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */
- // Connect the input layer whit the hidden layer
- input.addOutputSynapse(synapse_IH);
- hidden.addInputSynapse(synapse_IH);
- // Connect the hidden layer whit the output layer
- hidden.addOutputSynapse(synapse_HO);
- output.addInputSynapse(synapse_HO);
- // the input to the neural net
- inputSynapse = new MemoryInputSynapse();
- input.addInputSynapse(inputSynapse);
- // The Trainer and its desired output
- desiredOutputSynapse = new MemoryInputSynapse();
- TeachingSynapse trainer = new TeachingSynapse();
- trainer.setDesired(desiredOutputSynapse);
- // Now we add this structure to a NeuralNet object
- nnet = new NeuralNet();
- nnet.addLayer(input, NeuralNet.INPUT_LAYER);
- nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);
- nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);
- nnet.setTeacher(trainer);
- output.addOutputSynapse(trainer);
- nnet.addNeuralNetListener(this);
- }
- public void cicleTerminated(NeuralNetEvent e)
- {
- }
- public void errorChanged(NeuralNetEvent e)
- {
- Monitor mon = (Monitor) e.getSource();
- if (mon.getCurrentCicle() % 100 == 0)
- System.out.println(" Epoch: "
- + (mon.getTotCicles() - mon.getCurrentCicle()) + " RMSE: "
- + mon.getGlobalError());
- }
- public void netStarted(NeuralNetEvent e)
- {
- Monitor mon = (Monitor) e.getSource();
- System.out.print(" Network started for ");
- if (mon.isLearning())
- System.out.println(" training. ");
- else
- System.out.println(" interrogation. ");
- }
- public void netStopped(NeuralNetEvent e)
- {
- Monitor mon = (Monitor) e.getSource();
- System.out.println(" Network stopped. Last RMSE= "
- + mon.getGlobalError());
- }
- public void netStoppedError(NeuralNetEvent e, String error)
- {
- System.out.println(" Network stopped due the following error: "
- + error);
- }
- }
现在我会逐步解释上面的程序。
【1】 从main方法开始说起,首先第一步新建一个对象:
- XOR_using_NeuralNet xor = new XOR_using_NeuralNet();
【2】然后初始化神经网络:
- xor.initNeuralNet();
初始化神经网络的方法中:
- // First create the three layers
- input = new LinearLayer();
- hidden = new SigmoidLayer();
- output = new SigmoidLayer();
- // set the dimensions of the layers
- input.setRows(2);
- hidden.setRows(3);
- output.setRows(1);
- input.setLayerName(" L.input ");
- hidden.setLayerName(" L.hidden ");
- output.setLayerName(" L.output ");
上面代码解释:
input=new LinearLayer()是新建一个输入层,因为神经网络的输入层并没有训练参数,所以使用的是线性层;
hidden = new SigmoidLayer();这里是新建一个隐含层,使用sigmoid函数作为激励函数,当然你也可以选择其他的激励函数,如softmax激励函数
output则是新建一个输出层
之后的三行代码是建立输入层、隐含层、输出层的神经元个数,这里表示输入层为2个神经元,隐含层是3个神经元,输出层是1个神经元
最后的三行代码是给每个输出层取一个名字。
- // Now create the two Synapses
- FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */
- FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */
- // Connect the input layer whit the hidden layer
- input.addOutputSynapse(synapse_IH);
- hidden.addInputSynapse(synapse_IH);
- // Connect the hidden layer whit the output layer
- hidden.addOutputSynapse(synapse_HO);
- output.addInputSynapse(synapse_HO);
上面代码解释:
上面代码的主要作用是将三个层连接起来,synapse_IH用来连接输入层和隐含层,synapse_HO用来连接隐含层和输出层
- // the input to the neural net
- inputSynapse = new MemoryInputSynapse();
- input.addInputSynapse(inputSynapse);
- // The Trainer and its desired output
- desiredOutputSynapse = new MemoryInputSynapse();
- TeachingSynapse trainer = new TeachingSynapse();
- trainer.setDesired(desiredOutputSynapse);
上面代码解释:
上面的代码是在训练的时候指定输入层的数据和目的输出的数据,
inputSynapse = new MemoryInputSynapse();这里指的是使用了从内存中输入数据的方法,指的是输入层输入数据,当然还有从文件输入的方法,这点在文章后面再谈。同理,desiredOutputSynapse = new MemoryInputSynapse();也是从内存中输入数据,指的是从输入层应该输出的数据
- // Now we add this structure to a NeuralNet object
- nnet = new NeuralNet();
- nnet.addLayer(input, NeuralNet.INPUT_LAYER);
- nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);
- nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);
- nnet.setTeacher(trainer);
- output.addOutputSynapse(trainer);
- nnet.addNeuralNetListener(this);
上面代码解释:
这段代码指的是将之前初始化的构件连接成一个神经网络,NeuralNet是JOONE提供的类,主要是连接各个神经层,最后一个nnet.addNeuralNetListener(this);这个作用是对神经网络的训练过程进行监听,因为这个类实现了NeuralNetListener这个接口,这个接口有一些方法,可以实现观察神经网络训练过程,有助于参数调整。
【3】然后我们来看一下train这个方法:
- inputSynapse.setInputArray(inputArray);
- inputSynapse.setAdvancedColumnSelector(" 1,2 ");
- // set the desired outputs
- desiredOutputSynapse.setInputArray(desiredOutputArray);
- desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");
上面代码解释:
inputSynapse.setInputArray(inputArray);这个方法是初始化输入层数据,也就是指定输入层数据的内容,inputArray是程序中给定的二维数组,这也就是为什么之前初始化神经网络的时候使用的是MemoryInputSynapse,表示从内存中读取数据
inputSynapse.setAdvancedColumnSelector(" 1,2 ");这个表示的是输入层数据使用的是inputArray的前两列数据。
desiredOutputSynapse这个也同理
- Monitor monitor = nnet.getMonitor();
- // set the monitor parameters
- monitor.setLearningRate(0.8);
- monitor.setMomentum(0.3);
- monitor.setTrainingPatterns(inputArray.length);
- monitor.setTotCicles(5000);
- <span style="line-height: 1.5;">monitor.setLearning(true);
上面代码解释:
这个monitor类也是JOONE框架提供的,主要是用来调节神经网络的参数,monitor.setLearningRate(0.8);是用来设置神经网络训练的步长参数,步长越大,神经网络梯度下降的速度越快,monitor.setTrainingPatterns(inputArray.length);这个是设置神经网络的输入层的训练数据大小size,这里使用的是数组的长度;monitor.setTotCicles(5000);这个指的是设置迭代数目;monitor.setLearning(true);这个true表示是在训练过程。
- nnet.getMonitor().setSingleThreadMode(singleThreadMode);
- nnet.go(true);
上面代码解释:
nnet.getMonitor().setSingleThreadMode(singleThreadMode);这个指的是是不是使用多线程,但是我不太清楚这里的多线程指的是什么意思
nnet.go(true)表示的是开始训练。
【4】最后来看一下interrogate方法
- double[][] inputArray = new double[][]
- {
- { 1.0, 1.0 } };
- // set the inputs
- inputSynapse.setInputArray(inputArray);
- inputSynapse.setAdvancedColumnSelector(" 1,2 ");
- Monitor monitor = nnet.getMonitor();
- monitor.setTrainingPatterns(4);
- monitor.setTotCicles(1);
- monitor.setLearning(false);
- MemoryOutputSynapse memOut = new MemoryOutputSynapse();
- // set the output synapse to write the output of the net
- if (nnet != null)
- {
- nnet.addOutputSynapse(memOut);
- System.out.println(nnet.check());
- nnet.getMonitor().setSingleThreadMode(singleThreadMode);
- nnet.go();
- for (int i = 0; i < 4; i++)
- {
- double[] pattern = memOut.getNextPattern();
- System.out.println(" Output pattern # " + (i + 1) + " = "
- + pattern[0]);
- }
- System.out.println(" Interrogating Finished ");
- }
这个方法相当于测试方法,这里的inputArray是测试数据, 注意这里需要设置monitor.setLearning(false);,因为这不是训练过程,并不需要学习,monitor.setTrainingPatterns(4);这个是指测试的数量,4表示有4个测试数据(虽然这里只有一个)。这里还给nnet添加了一个输出层数据对象,这个对象mmOut是初始测试结果,注意到之前我们初始化神经网络的时候并没有给输出层指定数据对象,因为那个时候我们在训练,而且指定了trainer作为目的输出。
接下来就是输出结果数据了,pattern的个数和输出层的神经元个数一样大,这里输出层神经元的个数是1,所以pattern大小为1.
【5】我们看一下测试结果:
- Output pattern # 1 = 0.018303527517809233
表示输出结果为0.01,根据sigmoid函数特性,我们得到的输出是0,和预期结果一致。如果输出层神经元个数大于1,那么输出值将会有多个,因为输出层结果是0|1离散值,所以我们取输出最大的那个神经元的输出值取为1,其他为0
【6】最后我们来看一下神经网络训练过程中的一些监听函数:
cicleTerminated:每个循环结束后输出的信息
errorChanged:神经网络错误率变化时候输出的信息
netStarted:神经网络开始运行的时候输出的信息
netStopped:神经网络停止的时候输出的信息
【7】好了,JOONE基本上内容就是这些。还有一些额外东西需要说明:
1,从文件中读取数据构建神经网络
2.如何保存训练好的神经网络到文件夹中,只要测试的时候直接load到内存中就行,而不用每次都需要训练。
【8】先看第一个问题:
从文件中读取数据:
文件的格式:
0;0;0
1;0;1
1;1;0
0;1;1
中间使用分号隔开,使用方法如下,也就是把上文的MemoryInputSynapse换成FileInputSynapse即可。
- fileInputSynapse = new FileInputSynapse();
- input.addInputSynapse(fileInputSynapse);
- fileDisireOutputSynapse = new FileInputSynapse();
- TeachingSynapse trainer = new TeachingSynapse();
- trainer.setDesired(fileDisireOutputSynapse);
我们看下文件是如何输出数据的:
- private File inputFile = new File(Constants.TRAIN_WORD_VEC_PATH);
- fileInputSynapse.setInputFile(inputFile);
- fileInputSynapse.setFirstCol(2);//使用文件的第2列到第3列作为输出层输入
- fileInputSynapse.setLastCol(3);
- fileDisireOutputSynapse.setInputFile(inputFile);
- fileDisireOutputSynapse.setFirstCol(1);//使用文件的第1列作为输出数据
- fileDisireOutputSynapse.setLastCol(1);
其余的代码和上文的是一样的。
【9】然后看第二个问题:
如何保存神经网络
其实很简单,直接序列化nnet对象就行了,然后读取该对象就是java的反序列化,这个就不多做介绍了,比较简单。但是需要说明的是,保存神经网络的时机一定是在神经网络训练完毕后,可以使用下面代码:
- public void netStopped(NeuralNetEvent e) {
- Monitor mon = (Monitor) e.getSource();
- try {
- if (mon.isLearning()) {
- saveModel(nnet); //序列化对象
- }
- } catch (IOException ee) {
- // TODO Auto-generated catch block
- ee.printStackTrace();
- }
LSTM java 实现的更多相关文章
- Spark案例分析
一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...
- Python中利用LSTM模型进行时间序列预测分析
时间序列模型 时间序列预测分析就是利用过去一段时间内某事件时间的特征来预测未来一段时间内该事件的特征.这是一类相对比较复杂的预测建模问题,和回归分析模型的预测不同,时间序列模型是依赖于事件发生的先后顺 ...
- 新手教程之:循环网络和LSTM指南 (A Beginner’s Guide to Recurrent Networks and LSTMs)
新手教程之:循环网络和LSTM指南 (A Beginner’s Guide to Recurrent Networks and LSTMs) 本文翻译自:http://deeplearning4j.o ...
- [转] 图 + 文 + 公式 理解LSTM
转自公号“机器之心” LSTM入门必读:从入门基础到工作方式详解 长短期记忆(LSTM)是一种非常重要的神经网络技术,其在语音识别和自然语言处理等许多领域都得到了广泛的应用..在这篇文章中,Edwin ...
- 机器学习与Tensorflow(6)——LSTM的Tensorflow实现、Tensorboard简单实现、CNN应用
最近写的一些程序以及做的一个关于轴承故障诊断的程序 最近学习进度有些慢 而且马上假期 要去补习班 去赚下学期生活费 额.... 抓紧时间再多学习点 1.RNN递归神经网络Tensorflow实现程序 ...
- Tesseract:简单的Java光学字符识别
1.1 介绍 开发具有一定价值的符号是人类特有的特征.对于人们来说识别这些符号和理解图片上的文字是非常正常的事情.与计算机那样去抓取文字不同,我们完全是基于视觉的本能去阅读它们. 另一方面,计算机的工 ...
- 尚学堂JAVA基础学习笔记
目录 尚学堂JAVA基础学习笔记 写在前面 第1章 JAVA入门 第2章 数据类型和运算符 第3章 控制语句 第4章 Java面向对象基础 1. 面向对象基础 2. 面向对象的内存分析 3. 构造方法 ...
- java 读取CSV数据并写入txt文本
java 读取CSV数据并写入txt文本 package com.vfsd; import java.io.BufferedWriter; import java.io.File; import ja ...
- Tika结合Tesseract-OCR 实现光学汉字识别(简体、宋体的识别率百分之百)—附Java源码、测试数据和训练集下载地址
OCR(Optical character recognition) —— 光学字符识别,是图像处理的一个重要分支,中文的识别具有一定挑战性,特别是手写体和草书的识别,是重要和热门的科学研究方向.可 ...
随机推荐
- Linux内核态、用户态简介与IntelCPU特权级别--Ring0-3
一.现代操作系统的权限分离: 现代操作系统一般都至少分为内核态和用户态.一般应用程序通常运行于用户态,而当应用程序调用系统调用时候会执行内核代码,此时会处于内核态.一般的,应用程序是不能随便进入内核态 ...
- Kafka+SpringMVC+Maven应用示例
本文借助主流SpringMVC框架向大家介绍如何在具体应用中简单快捷的使用kafka.kafka.maven以及SpringMVC在现在的企业级应用中都占据着非常重要的地位,所以本文将三者结合起来也可 ...
- myeclipse 代码提示(alt+/)
windows -->preference-->general-->keys找到 alt+/ 解除绑定 windows -->preference-->general-- ...
- Mysql 多主一从数据备份
Mysql 多主一从数据备份 概述 对任何一个数据库的操作都自动应用到另外一个数据库,始终保持两个数据库中的数据一致. 这样做有如下几点好处: 可以做灾备,其中一个坏了可以切换到另一个. 可以做负载均 ...
- ubuntu 打开 gbk编码的txt乱码
iconv -f gbk -t utf8 filename.txt > filename.txt.utf8
- Python面向对象高级
一 反射 反射也可以说是python的自省机制 反射就是通过字符串的形式,导入模块,然后以字符串的形式去查找指定函数并执行.利用字符串的形式去模块(对象)中操作(查找/获取/添加/删除)属性,是一种 ...
- Python标准库 之 turtle(海龟绘图)
turtle库介绍 首先,turtle库是一个点线面的简单图像库(也被人们成为海龟绘图),在Python2.6之后被引入进来,能够完成一些比较简单的几何图像可视化.它就像一个小乌龟,在一个横轴为x.纵 ...
- Shiro框架简介
Apache Shiro是Java的一个安全框架.对比另一个安全框架Spring Sercurity,它更简单和灵活. Shiro可以帮助我们完成:认证.授权.加密.会话管理.Web集成.缓存等. A ...
- bug-sqlite3
[root@izj6c6b4i40od17ev77lhez Python-3.7.0]# python Python 3.7.0 (default, Sep 5 2018, 00:40:27) [GC ...
- 【开发者笔记】按List中存放对象的某一字段计数的问题
如题,假设有如下表t_info: name date info a 20127-12-20 xxxx描述 b 20127-12-20 yyyyy描述 c 20127-12-21 zzz描述 d 201 ...