BP神经网络的Java实现(转)
http://fantasticinblur.iteye.com/blog/1465497
课程作业要求实现一个BPNN。这次尝试使用Java实现了一个。现共享之。版权属于大家。关于BPNN的原理,就不赘述了。
下面是BPNN的实现代码。类名为BP。
- package ml;
- import java.util.Random;
- /**
- * BPNN.
- *
- * @author RenaQiu
- *
- */
- public class BP {
- /**
- * input vector.
- */
- private final double[] input;
- /**
- * hidden layer.
- */
- private final double[] hidden;
- /**
- * output layer.
- */
- private final double[] output;
- /**
- * target.
- */
- private final double[] target;
- /**
- * delta vector of the hidden layer .
- */
- private final double[] hidDelta;
- /**
- * output layer of the output layer.
- */
- private final double[] optDelta;
- /**
- * learning rate.
- */
- private final double eta;
- /**
- * momentum.
- */
- private final double momentum;
- /**
- * weight matrix from input layer to hidden layer.
- */
- private final double[][] iptHidWeights;
- /**
- * weight matrix from hidden layer to output layer.
- */
- private final double[][] hidOptWeights;
- /**
- * previous weight update.
- */
- private final double[][] iptHidPrevUptWeights;
- /**
- * previous weight update.
- */
- private final double[][] hidOptPrevUptWeights;
- public double optErrSum = 0d;
- public double hidErrSum = 0d;
- private final Random random;
- /**
- * Constructor.
- * <p>
- * <strong>Note:</strong> The capacity of each layer will be the parameter
- * plus 1. The additional unit is used for smoothness.
- * </p>
- *
- * @param inputSize
- * @param hiddenSize
- * @param outputSize
- * @param eta
- * @param momentum
- * @param epoch
- */
- public BP(int inputSize, int hiddenSize, int outputSize, double eta,
- double momentum) {
- input = new double[inputSize + 1];
- hidden = new double[hiddenSize + 1];
- output = new double[outputSize + 1];
- target = new double[outputSize + 1];
- hidDelta = new double[hiddenSize + 1];
- optDelta = new double[outputSize + 1];
- iptHidWeights = new double[inputSize + 1][hiddenSize + 1];
- hidOptWeights = new double[hiddenSize + 1][outputSize + 1];
- random = new Random(19881211);
- randomizeWeights(iptHidWeights);
- randomizeWeights(hidOptWeights);
- iptHidPrevUptWeights = new double[inputSize + 1][hiddenSize + 1];
- hidOptPrevUptWeights = new double[hiddenSize + 1][outputSize + 1];
- this.eta = eta;
- this.momentum = momentum;
- }
- private void randomizeWeights(double[][] matrix) {
- for (int i = 0, len = matrix.length; i != len; i++)
- for (int j = 0, len2 = matrix[i].length; j != len2; j++) {
- double real = random.nextDouble();
- matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
- }
- }
- /**
- * Constructor with default eta = 0.25 and momentum = 0.3.
- *
- * @param inputSize
- * @param hiddenSize
- * @param outputSize
- * @param epoch
- */
- public BP(int inputSize, int hiddenSize, int outputSize) {
- this(inputSize, hiddenSize, outputSize, 0.25, 0.9);
- }
- /**
- * Entry method. The train data should be a one-dim vector.
- *
- * @param trainData
- * @param target
- */
- public void train(double[] trainData, double[] target) {
- loadInput(trainData);
- loadTarget(target);
- forward();
- calculateDelta();
- adjustWeight();
- }
- /**
- * Test the BPNN.
- *
- * @param inData
- * @return
- */
- public double[] test(double[] inData) {
- if (inData.length != input.length - 1) {
- throw new IllegalArgumentException("Size Do Not Match.");
- }
- System.arraycopy(inData, 0, input, 1, inData.length);
- forward();
- return getNetworkOutput();
- }
- /**
- * Return the output layer.
- *
- * @return
- */
- private double[] getNetworkOutput() {
- int len = output.length;
- double[] temp = new double[len - 1];
- for (int i = 1; i != len; i++)
- temp[i - 1] = output[i];
- return temp;
- }
- /**
- * Load the target data.
- *
- * @param arg
- */
- private void loadTarget(double[] arg) {
- if (arg.length != target.length - 1) {
- throw new IllegalArgumentException("Size Do Not Match.");
- }
- System.arraycopy(arg, 0, target, 1, arg.length);
- }
- /**
- * Load the training data.
- *
- * @param inData
- */
- private void loadInput(double[] inData) {
- if (inData.length != input.length - 1) {
- throw new IllegalArgumentException("Size Do Not Match.");
- }
- System.arraycopy(inData, 0, input, 1, inData.length);
- }
- /**
- * Forward.
- *
- * @param layer0
- * @param layer1
- * @param weight
- */
- private void forward(double[] layer0, double[] layer1, double[][] weight) {
- // threshold unit.
- layer0[0] = 1.0;
- for (int j = 1, len = layer1.length; j != len; ++j) {
- double sum = 0;
- for (int i = 0, len2 = layer0.length; i != len2; ++i)
- sum += weight[i][j] * layer0[i];
- layer1[j] = sigmoid(sum);
- }
- }
- /**
- * Forward.
- */
- private void forward() {
- forward(input, hidden, iptHidWeights);
- forward(hidden, output, hidOptWeights);
- }
- /**
- * Calculate output error.
- */
- private void outputErr() {
- double errSum = 0;
- for (int idx = 1, len = optDelta.length; idx != len; ++idx) {
- double o = output[idx];
- optDelta[idx] = o * (1d - o) * (target[idx] - o);
- errSum += Math.abs(optDelta[idx]);
- }
- optErrSum = errSum;
- }
- /**
- * Calculate hidden errors.
- */
- private void hiddenErr() {
- double errSum = 0;
- for (int j = 1, len = hidDelta.length; j != len; ++j) {
- double o = hidden[j];
- double sum = 0;
- for (int k = 1, len2 = optDelta.length; k != len2; ++k)
- sum += hidOptWeights[j][k] * optDelta[k];
- hidDelta[j] = o * (1d - o) * sum;
- errSum += Math.abs(hidDelta[j]);
- }
- hidErrSum = errSum;
- }
- /**
- * Calculate errors of all layers.
- */
- private void calculateDelta() {
- outputErr();
- hiddenErr();
- }
- /**
- * Adjust the weight matrix.
- *
- * @param delta
- * @param layer
- * @param weight
- * @param prevWeight
- */
- private void adjustWeight(double[] delta, double[] layer,
- double[][] weight, double[][] prevWeight) {
- layer[0] = 1;
- for (int i = 1, len = delta.length; i != len; ++i) {
- for (int j = 0, len2 = layer.length; j != len2; ++j) {
- double newVal = momentum * prevWeight[j][i] + eta * delta[i]
- * layer[j];
- weight[j][i] += newVal;
- prevWeight[j][i] = newVal;
- }
- }
- }
- /**
- * Adjust all weight matrices.
- */
- private void adjustWeight() {
- adjustWeight(optDelta, hidden, hidOptWeights, hidOptPrevUptWeights);
- adjustWeight(hidDelta, input, iptHidWeights, iptHidPrevUptWeights);
- }
- /**
- * Sigmoid.
- *
- * @param val
- * @return
- */
- private double sigmoid(double val) {
- return 1d / (1d + Math.exp(-val));
- }
- }
为了验证正确性,我写了一个测试用例,目的是对于任意的整数(int型),BPNN在经过训练之后,能够准确地判断出它是奇数还是偶数,正数还是负数。首先对于训练的样本(是随机生成的数字),将它转化为一个32位的向量,向量的每个分量就是其二进制形式对应的位上的0或1。将目标输出视作一个4维的向量,[1,0,0,0]代表正奇数,[0,1,0,0]代表正偶数,[0,0,1,0]代表负奇数,[0,0,0,1]代表负偶数。
训练样本为1000个,学习200次。
- package ml;
- import java.io.IOException;
- import java.util.ArrayList;
- import java.util.List;
- import java.util.Random;
- public class Test {
- /**
- * @param args
- * @throws IOException
- */
- public static void main(String[] args) throws IOException {
- BP bp = new BP(32, 15, 4);
- Random random = new Random();
- List<Integer> list = new ArrayList<Integer>();
- for (int i = 0; i != 1000; i++) {
- int value = random.nextInt();
- list.add(value);
- }
- for (int i = 0; i != 200; i++) {
- for (int value : list) {
- double[] real = new double[4];
- if (value >= 0)
- if ((value & 1) == 1)
- real[0] = 1;
- else
- real[1] = 1;
- else if ((value & 1) == 1)
- real[2] = 1;
- else
- real[3] = 1;
- double[] binary = new double[32];
- int index = 31;
- do {
- binary[index--] = (value & 1);
- value >>>= 1;
- } while (value != 0);
- bp.train(binary, real);
- }
- }
- System.out.println("训练完毕,下面请输入一个任意数字,神经网络将自动判断它是正数还是复数,奇数还是偶数。");
- while (true) {
- byte[] input = new byte[10];
- System.in.read(input);
- Integer value = Integer.parseInt(new String(input).trim());
- int rawVal = value;
- double[] binary = new double[32];
- int index = 31;
- do {
- binary[index--] = (value & 1);
- value >>>= 1;
- } while (value != 0);
- double[] result = bp.test(binary);
- double max = -Integer.MIN_VALUE;
- int idx = -1;
- for (int i = 0; i != result.length; i++) {
- if (result[i] > max) {
- max = result[i];
- idx = i;
- }
- }
- switch (idx) {
- case 0:
- System.out.format("%d是一个正奇数\n", rawVal);
- break;
- case 1:
- System.out.format("%d是一个正偶数\n", rawVal);
- break;
- case 2:
- System.out.format("%d是一个负奇数\n", rawVal);
- break;
- case 3:
- System.out.format("%d是一个负偶数\n", rawVal);
- break;
- }
- }
- }
- }
运行结果截图如下:

这个测试的例子非常简单。大家可以根据自己的需要去使用BP这个类。
BP神经网络的Java实现(转)的更多相关文章
- BP神经网络的Java实现(转载)
神经网络的计算过程 神经网络结构如下图所示,最左边的是输入层,最右边的是输出层,中间是多个隐含层,隐含层和输出层的每个神经节点,都是由上一层节点乘以其权重累加得到,标上“+1”的圆圈为截距项b,对输入 ...
- BP神经网络的Java实现
http://fantasticinblur.iteye.com/blog/1465497
- BP神经网络的直观推导与Java实现
人工神经网络模拟人体对于外界刺激的反应.某种刺激经过人体多层神经细胞传递后,可以触发人脑中特定的区域做出反应.人体神经网络的作用就是把某种刺激与大脑中的特定区域关联起来了,这样我们对于不同的刺激就可以 ...
- 用java写bp神经网络(一)
根据前篇博文<神经网络之后向传播算法>,现在用java实现一个bp神经网络.矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法. ...
- BP神经网络—java实现(转载)
神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...
- BP神经网络—java实现
神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...
- JAVA实现BP神经网络算法
工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...
- 数据挖掘系列(9)——BP神经网络算法与实践
神经网络曾经很火,有过一段低迷期,现在因为深度学习的原因继续火起来了.神经网络有很多种:前向传输网络.反向传输网络.递归神经网络.卷积神经网络等.本文介绍基本的反向传输神经网络(Backpropaga ...
- BP神经网络的数学原理及其算法实现
什么是BP网络 BP网络的数学原理 BP网络算法实现 转载请声明出处http://blog.csdn.net/zhongkejingwang/article/details/44514073 上一篇 ...
随机推荐
- 2015.7.8js-05(简单日历)
今天做一个简单的小日历,12个月份,鼠标移动其中一个月份时添加高亮并显示本月的活动.其实同理与选项卡致.不过是内容存在js里 window.onload = function(){ var oMain ...
- 更新jenkins插件,报错 Perhaps you need to run your container with "-Djava.awt.headless=true"?
Configuring the Java environment variables vi ~/.bash_profile 在最后一行加入: export JAVA_OPTS=-Djava.awt.h ...
- 【Android】 导入项目报错的解决方案
1.打项目的properties -->android 为其指一个运版本, 2.修改default properties 文件 ,改相应版本等级 3.选中项目,单击右键,选中properties ...
- oracle简单存储过程以及如何查看编译错误
oracle简单存储过程以及如何查看编译错误; CREATE OR REPLACE PROCEDURE procedure_test ISval VARCHAR2(200);BEGIN /* val ...
- split陷阱
如果split最后一个为空,则要这么写 String[] lines=line.split(",",-1);
- thinkphp开启事物的简单方法
使用thinkphp开启事务,ThinkPHP 3.2.2实现事务操作的方法: 开启事务: $User->startTrans() 提交事务: $User->commit() 事务回滚: ...
- 关于51单片机使用printf串口调试
在51系列单片机上面使用串口的时候,有时候为了方便调试看一下输出结果,会用到printf函数输出到电脑终端,再用串口助手显示.但是单片机使用printf的时候有一点需要注意的地方. 1.首先添加头文件 ...
- C++和Java中枚举enum的用法
在C++和java中都有枚举enum这个关键字,但是它们之间又不太一样.对于C++来说,枚举是一系列命名了的整型常量,而且从枚举值转化为对应的整型值是在内部进行的.而对于Java来说,枚举更像一个类的 ...
- Windows下使用Gflags和UMDH查找内存泄漏
GFlags和UMDH与WinDbg一样,都是Debugging Tools for Windows里的工具. 1.设置符号路径 去微软官网下载对应的操作系统的符号安装文件,并安装到某个目录,如C:\ ...
- php代码不支持多维数组,注释和没有缓存功能。
php代码:simplet.class.php<?phpclass SimpleT {private $t_vars;private $templates_dir;private $templa ...