接上篇。

在(一)和(二)中,程序的体系是Net,Propagation,Trainer,Learner,DataProvider。这篇重构这个体系。

Net

首先是Net,在上篇重新定义了激活函数和误差函数后,内容大致是这样的:

List<DoubleMatrix> weights = new ArrayList<DoubleMatrix>();
List<DoubleMatrix> bs = new ArrayList<>();
List<ActivationFunction> activations = new ArrayList<>();
CostFunction costFunc;
CostFunction accuracyFunc;
int[] nodesNum;
int layersNum; public CompactDoubleMatrix getCompact(){
return new CompactDoubleMatrix(this.weights,this.bs);
}

函数getCompact()生成对应的超矩阵。

DataProvider

DataProvider是数据的提供者。

public interface DataProvider {
DoubleMatrix getInput();
DoubleMatrix getTarget();
}

如果输入为向量,还包含一个向量字典。

public interface DictDataProvider extends DataProvider {
public DoubleMatrix getIndexs();
public DoubleMatrix getDict();
}

每一列为一个样本。getIndexs()返回输入向量在字典中的索引。

我写了一个有用的类BatchDataProviderFactory来对样本进行批量分割,分割成minibatch。

int batchSize;
int dataLen;
DataProvider originalProvider;
List<Integer> endPositions;
List<DataProvider> providers; public BatchDataProviderFactory(int batchSize, DataProvider originalProvider) {
super();
this.batchSize = batchSize;
this.originalProvider = originalProvider;
this.dataLen = this.originalProvider.getTarget().columns;
this.initEndPositions();
this.initProviders();
} public BatchDataProviderFactory(DataProvider originalProvider) {
this(4, originalProvider);
} public List<DataProvider> getProviders() {
return providers;
}

batchSize指明要分多少批,getProviders返回生成的minibatch,被分的原始数据为originalProvider。

Propagation

Propagation负责对神经网络的正向传播过程和反向传播过程。接口定义如下:

public interface Propagation {
public PropagationResult propagate(Net net,DataProvider provider);
}

传播函数propagate用指定数据对指定网络进行传播操作,返回执行结果。

BasePropagation实现了该接口,实现了简单的反向传播:

public class BasePropagation implements Propagation{

	// 多个样本。
protected ForwardResult forward(Net net,DoubleMatrix input) { ForwardResult result = new ForwardResult();
result.input = input;
DoubleMatrix currentResult = input;
int index = -1;
for (DoubleMatrix weight : net.weights) {
index++;
DoubleMatrix b = net.bs.get(index);
final ActivationFunction activation = net.activations
.get(index);
currentResult = weight.mmul(currentResult).addColumnVector(b);
result.netResult.add(currentResult); // 乘以导数
DoubleMatrix derivative = activation.derivativeAt(currentResult);
result.derivativeResult.add(derivative); currentResult = activation.valueAt(currentResult);
result.finalResult.add(currentResult); } result.netResult=null;// 不再需要。 return result;
} // 多个样本梯度平均值。
protected BackwardResult backward(Net net,DoubleMatrix target,
ForwardResult forwardResult) {
BackwardResult result = new BackwardResult(); DoubleMatrix output = forwardResult.getOutput();
DoubleMatrix outputDerivative = forwardResult.getOutputDerivative(); result.cost = net.costFunc.valueAt(output, target);
DoubleMatrix outputDelta = net.costFunc.derivativeAt(output, target).muli(outputDerivative);
if (net.accuracyFunc != null) {
result.accuracy=net.accuracyFunc.valueAt(output, target);
} result.deltas.add(outputDelta);
for (int i = net.layersNum - 1; i >= 0; i--) {
DoubleMatrix pdelta = result.deltas.get(result.deltas.size() - 1); // 梯度计算,取所有样本平均
DoubleMatrix layerInput = i == 0 ? forwardResult.input
: forwardResult.finalResult.get(i - 1);
DoubleMatrix gradient = pdelta.mmul(layerInput.transpose()).div(
target.columns);
result.gradients.add(gradient);
// 偏置梯度
result.biasGradients.add(pdelta.rowMeans()); // 计算前一层delta,若i=0,delta为输入层误差,即input调整梯度,不作平均处理。
DoubleMatrix delta = net.weights.get(i).transpose().mmul(pdelta);
if (i > 0)
delta = delta.muli(forwardResult.derivativeResult.get(i - 1));
result.deltas.add(delta);
}
Collections.reverse(result.gradients);
Collections.reverse(result.biasGradients); //其它的delta都不需要。
DoubleMatrix inputDeltas=result.deltas.get(result.deltas.size()-1);
result.deltas.clear();
result.deltas.add(inputDeltas); return result;
} @Override
public PropagationResult propagate(Net net, DataProvider provider) {
ForwardResult forwardResult=this.forward(net, provider.getInput());
BackwardResult backwardResult=this.backward(net, provider.getTarget(), forwardResult);
PropagationResult result=new PropagationResult(backwardResult);
result.output=forwardResult.getOutput();
return result;
}

我们定义的PropagationResult略为:

public class PropagationResult{
DoubleMatrix output;// 输出结果矩阵:outputLen*sampleLength
DoubleMatrix cost;// 误差矩阵:1*sampleLength
DoubleMatrix accuracy;// 准确度矩阵:1*sampleLength
private List<DoubleMatrix> gradients;// 权重梯度矩阵
private List<DoubleMatrix> biasGradients;// 偏置梯度矩阵
DoubleMatrix inputDeltas;//输入层delta矩阵:inputLen*sampleLength public CompactDoubleMatrix getCompact(){
return new CompactDoubleMatrix(gradients,biasGradients);
} }

另一个实现了该接口的类为MiniBatchPropagation。他在内部用并行方式对样本进行传播,然后对每个minipatch结果进行综合,内部用到了BatchDataProviderFactory类和BasePropagation类。

Trainer

Trainer接口定义为:

public interface Trainer {
public void train(Net net,DataProvider provider);
}

简单的实现类为:

public class CommonTrainer implements Trainer {
int ecophs;
Learner learner;
Propagation propagation;
List<Double> costs = new ArrayList<>();
List<Double> accuracys = new ArrayList<>();
public void trainOne(Net net, DataProvider provider) {
PropagationResult propResult = this.propagation
.propagate(net, provider);
learner.learn(net, propResult, provider); Double cost = propResult.getMeanCost();
Double accuracy = propResult.getMeanAccuracy();
if (cost != null)
costs.add(cost);
if (accuracy != null)
accuracys.add(accuracy);
} @Override
public void train(Net net, DataProvider provider) {
for (int i = 0; i < this.ecophs; i++) {
System.out.println("echops:"+i);
this.trainOne(net, provider);
} }
}

简单的迭代echops此,没有智能停止功能,每次迭代用Learner调节权重。

Learner

Learner根据每次传播结果对网络权重进行调整,接口定义如下:

public interface Learner<N extends Net,P extends DataProvider> {
public void learn(N net,PropagationResult propResult,P provider);
}

一个简单的根据动量因子-自适应学习率进行调整的实现类为:

public class MomentAdaptLearner<N extends Net, P extends DataProvider>
implements Learner<N, P> {
double moment = 0.7;
double lmd = 1.05;
double preCost = 0;
double eta = 0.01;
double currentEta = eta;
double currentMoment = moment;
CompactDoubleMatrix preGradient; public MomentAdaptLearner(double moment, double eta) {
super();
this.moment = moment;
this.eta = eta;
this.currentEta = eta;
this.currentMoment = moment;
} public MomentAdaptLearner() { } @Override
public void learn(N net, PropagationResult propResult, P provider) {
if (this.preGradient == null)
init(net, propResult, provider); double cost = propResult.getMeanCost();
this.modifyParameter(cost);
System.out.println("current eta:" + this.currentEta);
System.out.println("current moment:" + this.currentMoment);
this.updateGradient(net, propResult, provider); } public void updateGradient(N net, PropagationResult propResult, P provider) {
CompactDoubleMatrix netCompact = this.getNetCompact(net, propResult,
provider);
CompactDoubleMatrix gradCompact = this.getGradientCompact(net,
propResult, provider);
gradCompact = gradCompact.mul(currentEta * (1 - currentMoment)).addi(
preGradient.mul(currentMoment));
netCompact.subi(gradCompact);
this.preGradient = gradCompact;
} public CompactDoubleMatrix getNetCompact(N net,
PropagationResult propResult, P provider) {
return net.getCompact();
} public CompactDoubleMatrix getGradientCompact(N net,
PropagationResult propResult, P provider) {
return propResult.getCompact();
} public void modifyParameter(double cost) { if (this.currentEta > 10) {
this.currentEta = 10;
} else if (this.currentEta < 0.0001) {
this.currentEta = 0.0001;
} else if (cost < this.preCost) {
this.currentEta *= 1.05;
this.currentMoment = moment;
} else if (cost < 1.04 * this.preCost) {
this.currentEta *= 0.7;
this.currentMoment *= 0.7;
} else {
this.currentEta = eta;
this.currentMoment = 0.1;
}
this.preCost = cost;
} public void init(Net net, PropagationResult propResult, P provider) {
PropagationResult pResult = new PropagationResult(net);
preGradient = pResult.getCompact().dup();
} }

在上面的代码中,我们可以看到CompactDoubleMatrix类对权重自变量的封装,使代码更加简洁,它在此表现出来的就是一个超矩阵,超向量,完全忽略了内部的结构。

同时,其子类实现了同步更新字典的功能,代码也很简洁,只是简单的把需要调整的矩阵append到超矩阵中去即可,在父类中会统一对其进行调整:

public class DictMomentLearner extends
MomentAdaptLearner<Net, DictDataProvider> { public DictMomentLearner(double moment, double eta) {
super(moment, eta);
} public DictMomentLearner() {
super();
} @Override
public CompactDoubleMatrix getNetCompact(Net net,
PropagationResult propResult, DictDataProvider provider) {
CompactDoubleMatrix result = super.getNetCompact(net, propResult,
provider);
result.append(provider.getDict());
return result;
} @Override
public CompactDoubleMatrix getGradientCompact(Net net,
PropagationResult propResult, DictDataProvider provider) {
CompactDoubleMatrix result = super.getGradientCompact(net, propResult,
provider);
result.append(DictUtil.getDictGradient(provider, propResult));
return result;
} @Override
public void init(Net net, PropagationResult propResult,
DictDataProvider provider) {
DoubleMatrix preDictGradient = DoubleMatrix.zeros(
provider.getDict().rows, provider.getDict().columns);
super.init(net, propResult, provider);
this.preGradient.append(preDictGradient);
}
}

用java写bp神经网络(四)的更多相关文章

  1. 用java写bp神经网络(一)

    根据前篇博文<神经网络之后向传播算法>,现在用java实现一个bp神经网络.矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法. ...

  2. 用java写bp神经网络(三)

    孔子曰,吾日三省吾身.我们如果跟程序打交道,除了一日三省吾身外,还要三日一省吾代码.看代码是否可以更简洁,更易懂,更容易扩展,更通用,算法是否可以再优化,结构是否可以再往上抽象.代码在不断的重构过程中 ...

  3. 用java写bp神经网络(二)

    接上篇. Net和Propagation具备后,我们就可以训练了.训练师要做的事情就是,怎么把一大批样本分成小批训练,然后把小批的结果合并成完整的结果(批量/增量):什么时候调用学习师根据训练的结果进 ...

  4. python手写bp神经网络实现人脸性别识别1.0

    写在前面:本实验用到的图片均来自google图片,侵删! 实验介绍 用python手写一个简单bp神经网络,实现人脸的性别识别.由于本人的机器配置比较差,所以无法使用网上很红的人脸大数据数据集(如lf ...

  5. JAVA实现BP神经网络算法

    工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...

  6. java写卷积神经网络---CupCnn简介

    https://blog.csdn.net/u011913612/article/details/79253450

  7. 【机器学习】BP神经网络实现手写数字识别

    最近用python写了一个实现手写数字识别的BP神经网络,BP的推导到处都是,但是一动手才知道,会理论推导跟实现它是两回事.关于BP神经网络的实现网上有一些代码,可惜或多或少都有各种问题,在下手写了一 ...

  8. BP神经网络—java实现(转载)

    神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...

  9. BP神经网络的手写数字识别

    BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...

随机推荐

  1. 14.5.5 Deadlocks in InnoDB

    14.5.5 Deadlocks in InnoDB 14.5.5.1 An InnoDB Deadlock Example 14.5.5.2 Deadlock Detection and Rollb ...

  2. ASP.NET MVC 后台接收集合参数和 jquery ajax 传值

    MVC 接收参数数组(集合)   示例样本:   public class Person {      public string FirstName { get; set; }      publi ...

  3. Linux Shell编程(17)——嵌套循环

    嵌套循环就是在一个循环中还有一个循环,内部循环在外部循环体中.在外部循环的每次执行过程中都会触发内部循环,直到内部循环执行结束.外部循环执行了多少次,内部循环就完成多少次.当然,不论是外部循环或内部循 ...

  4. (java) Merge k sorted linked lists and return it as one sorted list. Analyze and describe its complexity.

    /** * Definition for singly-linked list. * public class ListNode { * int val; * ListNode next; * Lis ...

  5. Android WebRTC 音视频开发总结

    www.cnblogs.com/lingyunhu/p/3621057.html 前面介绍了WebRTCDemo的基本结构,本节主要介绍WebRTC音视频服务端的处理,,转载请说明出处(博客园RTC. ...

  6. 分布式系统里session同步的那些事儿

    几周前,有个盆友问老王,说现在有多台服务器,怎么样来解决这些服务器间的session同步问题?老王一下就来精神了,因为在n年以前,老王还在学校和几个同学一起所谓创业的时候,也遇到了类似的问题.当时查了 ...

  7. 佛山Uber优步司机奖励政策(1月18日~1月24日)

    滴快车单单2.5倍,注册地址:http://www.udache.com/ 如何注册Uber司机(全国版最新最详细注册流程)/月入2万/不用抢单:http://www.cnblogs.com/mfry ...

  8. How to install Python 2.7 and Python 3.3 on CentOS 6

    原文地址:http://toomuchdata.com/2014/02/16/how-to-install-python-on-centos/

  9. Index of super-prime - SGU 116(素数+背包)

    题目大意:素数表2,3,5,7,11.....如果一个素数所在的位置还是素数,那么这个素数就是超级素数,比如3在第2位置,那么3就是超级素数.....现在给你一个数,求出来这个数由最少的超级素数的和组 ...

  10. The equation - SGU 106(扩展欧几里得)

    题目大意:有一个二元一次方程,给出系数值和x与y的取值范围,求出来总共有多少对整数解. 分析:有以下几点情况. 1,系数a=0, b=0, 当c != 0的时候结果很明显是无解,当c=0的时候x,y可 ...