用java写bp神经网络(一)
根据前篇博文《神经网络之后向传播算法》,现在用java实现一个bp神经网络。矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法。
上帝说,要有神经网络,于是,便有了一个神经网络。上帝还说,神经网络要有节点,权重,激活函数,输出函数,目标函数,然后也许还要有一个准确率函数,于是,神经网络完成了:
public class Net {
List<DoubleMatrix> weights = new ArrayList<DoubleMatrix>();
List<DoubleMatrix> bs = new ArrayList<>();
List<ScalarDifferentiableFunction> activations = new ArrayList<>();
CostFunctionFactory costFunc;
CostFunctionFactory accuracyFunc;
int[] nodesNum;
int layersNum;
public Net(int[] nodesNum, ScalarDifferentiableFunction[] activations,CostFunctionFactory costFunc) {
super();
this.initNet(nodesNum, activations);
this.costFunc=costFunc;
this.layersNum=nodesNum.length-1;
}
public Net(int[] nodesNum, ScalarDifferentiableFunction[] activations,CostFunctionFactory costFunc,CostFunctionFactory accuracyFunc) {
this(nodesNum,activations,costFunc);
this.accuracyFunc=accuracyFunc;
}
public void resetNet() {
this.initNet(nodesNum, (ScalarDifferentiableFunction[]) activations.toArray());
}
private void initNet(int[] nodesNum, ScalarDifferentiableFunction[] activations) {
assert (nodesNum != null && activations != null
&& nodesNum.length == activations.length + 1 && nodesNum.length > 1);
this.nodesNum = nodesNum;
this.weights.clear();
this.bs.clear();
this.activations.clear();
for (int i = 0; i < nodesNum.length - 1; i++) {
// 列数==输入;行数==输出。
int columns = nodesNum[i];
int rows = nodesNum[i + 1];
double r1 = Math.sqrt(6) / Math.sqrt(rows + columns + 1);
//r1=0.001;
// W
DoubleMatrix weight = DoubleMatrix.rand(rows, columns).muli(2*r1).subi(r1);
//weight=DoubleMatrix.ones(rows, columns);
weights.add(weight);
// b
DoubleMatrix b = DoubleMatrix.zeros(rows, 1);
bs.add(b);
// activations
this.activations.add(activations[i]);
}
}
}
上帝造完了神经网络,去休息了。人说,我要使用神经网络,我要利用正向传播计算各层的结果,然后利用反向传播调整网络的状态,最后,我要让它能告诉我猎物在什么方向,花儿为什么这样香。
public class Propagation {
Net net;
public Propagation(Net net) {
super();
this.net = net;
}
// 多个样本。
public ForwardResult forward(DoubleMatrix input) {
ForwardResult result = new ForwardResult();
result.input = input;
DoubleMatrix currentResult = input;
int index = -1;
for (DoubleMatrix weight : net.weights) {
index++;
DoubleMatrix b = net.bs.get(index);
final ScalarDifferentiableFunction activation = net.activations
.get(index);
currentResult = weight.mmul(currentResult).addColumnVector(b);
result.netResult.add(currentResult);
// 乘以导数
DoubleMatrix derivative = MatrixUtil.applyNewElements(
new ScalarFunction() {
@Override
public double valueAt(double x) {
return activation.derivativeAt(x);
}
}, currentResult);
currentResult = MatrixUtil.applyNewElements(activation,
currentResult);
result.finalResult.add(currentResult);
result.derivativeResult.add(derivative);
}
result.netResult=null;// 不再需要。
return result;
}
// 多个样本梯度平均值。
public BackwardResult backward(DoubleMatrix target,
ForwardResult forwardResult) {
BackwardResult result = new BackwardResult();
DoubleMatrix cost = DoubleMatrix.zeros(1,target.columns);
DoubleMatrix output = forwardResult.finalResult
.get(forwardResult.finalResult.size() - 1);
DoubleMatrix outputDelta = DoubleMatrix.zeros(output.rows,
output.columns);
DoubleMatrix outputDerivative = forwardResult.derivativeResult
.get(forwardResult.derivativeResult.size() - 1);
DoubleMatrix accuracy = null;
if (net.accuracyFunc != null) {
accuracy = DoubleMatrix.zeros(1,target.columns);
}
for (int i = 0; i < target.columns; i++) {
CostFunction costFunc = net.costFunc.create(target.getColumn(i)
.toArray());
cost.put(i, costFunc.valueAt(output.getColumn(i).toArray()));
// System.out.println(i);
DoubleMatrix column1 = new DoubleMatrix(
costFunc.derivativeAt(output.getColumn(i).toArray()));
DoubleMatrix column2 = outputDerivative.getColumn(i);
outputDelta.putColumn(i, column1.muli(column2));
if (net.accuracyFunc != null) {
CostFunction accuracyFunc = net.accuracyFunc.create(target
.getColumn(i).toArray());
accuracy.put(i,
accuracyFunc.valueAt(output.getColumn(i).toArray()));
}
}
result.deltas.add(outputDelta);
result.cost = cost;
result.accuracy = accuracy;
for (int i = net.layersNum - 1; i >= 0; i--) {
DoubleMatrix pdelta = result.deltas.get(result.deltas.size() - 1);
// 梯度计算,取所有样本平均
DoubleMatrix layerInput = i == 0 ? forwardResult.input
: forwardResult.finalResult.get(i - 1);
DoubleMatrix gradient = pdelta.mmul(layerInput.transpose()).div(
target.columns);
result.gradients.add(gradient);
// 偏置梯度
result.biasGradients.add(pdelta.rowMeans());
// 计算前一层delta,若i=0,delta为输入层误差,即input调整梯度,不作平均处理。
DoubleMatrix delta = net.weights.get(i).transpose().mmul(pdelta);
if (i > 0)
delta = delta.muli(forwardResult.derivativeResult.get(i - 1));
result.deltas.add(delta);
}
Collections.reverse(result.gradients);
Collections.reverse(result.biasGradients);
//其它的delta都不需要。
DoubleMatrix inputDeltas=result.deltas.get(result.deltas.size()-1);
result.deltas.clear();
result.deltas.add(inputDeltas);
return result;
}
public Net getNet() {
return net;
}
}
上面是一次正向/反向传播的具体代码。训练方式为批量训练,即所有样本一起训练。然而我们可以传入只有一列的input/target样本实现adapt方式的串行训练,也可以把样本分成很多批传入实现mini-batch方式的训练,这,不是Propagation要考虑的事情,它只是忠实的把传入的数据正向过一遍,反向过一遍,然后把过后的数据原封不动的返回给你。至于传入什么,以及结果怎么运用,是Trainer和Learner要做的事情。下回分解。
用java写bp神经网络(一)的更多相关文章
- 用java写bp神经网络(四)
接上篇. 在(一)和(二)中,程序的体系是Net,Propagation,Trainer,Learner,DataProvider.这篇重构这个体系. Net 首先是Net,在上篇重新定义了激活函数和 ...
- 用java写bp神经网络(三)
孔子曰,吾日三省吾身.我们如果跟程序打交道,除了一日三省吾身外,还要三日一省吾代码.看代码是否可以更简洁,更易懂,更容易扩展,更通用,算法是否可以再优化,结构是否可以再往上抽象.代码在不断的重构过程中 ...
- 用java写bp神经网络(二)
接上篇. Net和Propagation具备后,我们就可以训练了.训练师要做的事情就是,怎么把一大批样本分成小批训练,然后把小批的结果合并成完整的结果(批量/增量):什么时候调用学习师根据训练的结果进 ...
- python手写bp神经网络实现人脸性别识别1.0
写在前面:本实验用到的图片均来自google图片,侵删! 实验介绍 用python手写一个简单bp神经网络,实现人脸的性别识别.由于本人的机器配置比较差,所以无法使用网上很红的人脸大数据数据集(如lf ...
- JAVA实现BP神经网络算法
工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...
- java写卷积神经网络---CupCnn简介
https://blog.csdn.net/u011913612/article/details/79253450
- BP神经网络的手写数字识别
BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...
- 【机器学习】BP神经网络实现手写数字识别
最近用python写了一个实现手写数字识别的BP神经网络,BP的推导到处都是,但是一动手才知道,会理论推导跟实现它是两回事.关于BP神经网络的实现网上有一些代码,可惜或多或少都有各种问题,在下手写了一 ...
- BP神经网络的直观推导与Java实现
人工神经网络模拟人体对于外界刺激的反应.某种刺激经过人体多层神经细胞传递后,可以触发人脑中特定的区域做出反应.人体神经网络的作用就是把某种刺激与大脑中的特定区域关联起来了,这样我们对于不同的刺激就可以 ...
随机推荐
- maven jetty plugin
转载:http://blog.163.com/xueling1231989@126/blog/static/1026408072013101311395492/ 前言: 在 maven 下测试调试时, ...
- 怎样在delphi中实现控件和窗体的拖拽
下面这2种方法都能实现对控件和窗体的拖拽 方法1 procedure TForm1.FormMouseDown(Sender: TObject; Button: TMouseButton; Shift ...
- Struts2 权限验证
之前的Struts2项目通过再Sitemesh的母版页中使用Struts的if标签进行了session判断,使得未登录的用户不能看到页面,但是这 种现仅仅在view层进行,如果未登录用户直接在地址栏输 ...
- 【Linux】鸟哥的Linux私房菜基础学习篇整理(七)
1. test命令的测试功能.测试的标志:(1)关于文件类型的检测 test [-efdbcSpL] filename-e:该文件名是否存在:-f:该文件名是否为文件:-d:该文件名是否为目录:-b: ...
- 自己动手实现Queue
前言: 看到许多面经说,有时候面试官要你自己当场用模板写出自己的vector容器.于是,我也琢磨着怎么自己动手写一个,可是本人才刚刚学C++模板编程不久,会的不多.不过,我恰好在C++ Primer上 ...
- spoj-694-Distinct Substrings(后缀数组)
题意: 给定一个字符串,求不相同的子串的个数 分析: 每个子串一定是某个后缀的前缀,那么原问题等价于求所有后缀之间的不相同 的 前 缀 的 个 数 . 如 果 所 有 的 后 缀 按 照 suffix ...
- Palindrome - POJ 3974 (最长回文子串,Manacher模板)
题意:就是求一个串的最长回文子串....输出长度. 直接上代码吧,没什么好分析的了. 代码如下: ================================================= ...
- 谈谈C#中的接口
接口的相关陈述 1.一个接口定义了一个契约. 2.接口可以包容方法.C#属性.事件.以及索引器. 3.在一个接口声明中,我们可以声明零个或者多个成员. 4.所有接口成员的默认访问类型都是public. ...
- win10在安装Oracle11g时出现了:[INS-13001]环境不满足最低要求,及未找到文件 E:\app\xxj\product\11.2.0\dbhome_1\owb\external\oc4j_applications\applications\WFMLRSVCApp.ear
win10安装Oracle11g碰到的3个问题: 1.win10在安装Oracle11g时出现了:[INS-13001]环境不满足最低要求 2.未找到文件 E:\app\xxj\product\11. ...
- 【算法与数据结构】在n个数中取第k大的数(基础篇)
(转载请注明出处:http://blog.csdn.net/buptgshengod) 题目介绍 在n个数中取第k大的数(基础篇),之所以叫基础篇是因为还有很多更高级的算法,这些 ...