本文细述上文引出的RAECost和SoftmaxCost两个类。

SoftmaxCost

我们已经知道。SoftmaxCost类在给定features和label的情况下(超參数给定),衡量给定权重(hidden×catSize)的误差值cost,并指出当前的权重梯度。看代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
@Override
    public double valueAt(double[]
x)
    {
        if(
!requiresEvaluation(x) )
            return value;
        int numDataItems
= Features.columns;
         
        int[]
requiredRows = ArraysHelper.makeArray(
0,
CatSize-
2);
        ClassifierTheta
Theta =
new ClassifierTheta(x,FeatureLength,CatSize);
        DoubleMatrix
Prediction = getPredictions (Theta, Features);
         
        double MeanTerm
=
1.0 /
(
double)
numDataItems;
        double Cost
= getLoss (Prediction, Labels).sum() * MeanTerm;
        double RegularisationTerm
=
0.5 *
Lambda * DoubleMatrixFunctions.SquaredNorm(Theta.W);
         
        DoubleMatrix
Diff = Prediction.sub(Labels).muli(MeanTerm);
        DoubleMatrix
Delta = Features.mmul(Diff.transpose());
     
        DoubleMatrix
gradW = Delta.getColumns(requiredRows);
        DoubleMatrix
gradb = ((Diff.rowSums()).getRows(requiredRows));
         
        //Regularizing.
Bias does not have one.
        gradW
= gradW.addi(Theta.W.mul(Lambda));
         
        Gradient
=
new ClassifierTheta(gradW,gradb);
        value
= Cost + RegularisationTerm;
        gradient
= Gradient.Theta;
        return value;
    }<br><br>public DoubleMatrix
getPredictions (ClassifierTheta Theta, DoubleMatrix Features)<br>    {<br>        
int numDataItems
= Features.columns;<br>        DoubleMatrix Input = ((Theta.W.transpose()).mmul(Features)).addColumnVector(Theta.b);<br>        Input = DoubleMatrix.concatVertically(Input, DoubleMatrix.zeros(
1,numDataItems));<br>  
     
return Activation.valueAt(Input);
<br>    }

是个典型的2层神经网络,没有隐层,首先依据features预測labels,预測结果用softmax归一化,然后依据误差反向传播算出权重梯度。

此处添加200字。

这个典型的2层神经网络,label为一列向量,目标label置1,其余为0;转换函数为softmax函数,输出为每一个label的概率。

计算cost的函数为getLoss。如果目标label的预測输出为p∗,则每一个样本的cost也即误差函数为:

cost=E(p∗)=−log(p∗)

依据前述的神经网络后向传播算法,我们得到(j为目标label时,否则为0):

∂E∂wij=∂E∂pj∂hj∂netjxi=−1pjpj(1−pj)xi=−(1−pj)xi=−(labelj−pj)featurei

因此我们便理解了以下代码的含义:

1
DoubleMatrix
Delta = Features.mmul(Diff.transpose());

RAECost

先看实现代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
@Override
    public double valueAt(double[]
x)
    {
        if(!requiresEvaluation(x))
            return value;
         
        Theta
Theta1 =
new Theta(x,hiddenSize,visibleSize,dictionaryLength);
        FineTunableTheta
Theta2 =
new FineTunableTheta(x,hiddenSize,visibleSize,catSize,dictionaryLength);
        Theta2.setWe(
Theta2.We.add(WeOrig) );
         
        final RAEClassificationCost
classificationCost =
new RAEClassificationCost(
                catSize,
AlphaCat, Beta, dictionaryLength, hiddenSize, Lambda, f, Theta2);
        final RAEFeatureCost
featureCost =
new RAEFeatureCost(
                AlphaCat,
Beta, dictionaryLength, hiddenSize, Lambda, f, WeOrig, Theta1);
     
        Parallel.For(DataCell,
            new Parallel.Operation<LabeledDatum<Integer,Integer>>()
{
                public void perform(int index,
LabeledDatum<Integer,Integer> Data)
                {
                    try {
                        LabeledRAETree
Tree = featureCost.Compute(Data);
                        classificationCost.Compute(Data,
Tree);                
                    }
catch (Exception
e) {
                        System.err.println(e.getMessage());
                    }
                }
        });
         
        double costRAE
= featureCost.getCost();
        double[]
gradRAE = featureCost.getGradient().clone();
             
        double costSUP
= classificationCost.getCost();
        gradient
= classificationCost.getGradient();
             
        value
= costRAE + costSUP;
        for(int i=0;
i<gradRAE.length; i++)
            gradient[i]
+= gradRAE[i];
         
        System.gc();   
System.gc();
        System.gc();   
System.gc();
        System.gc();   
System.gc();
        System.gc();   
System.gc();
         
        return value;
    }

cost由两部分组成,featureCost和classificationCost。程序遍历每一个样本,用featureCost.Compute(Data)生成一个递归树,同一时候累加cost和gradient。然后用classificationCost.Compute(Data, Tree)依据生成的树计算并累加cost和gradient。因此关键类为RAEFeatureCost和RAEClassificationCost。

RAEFeatureCost类在Compute函数中调用RAEPropagation的ForwardPropagate函数生成一棵树。然后调用BackPropagate计算梯度并累加。详细的算法过程。下一章分解。

jrae源代码解析(二)的更多相关文章

  1. Spring源代码解析

    Spring源代码解析(一):IOC容器:http://www.iteye.com/topic/86339 Spring源代码解析(二):IoC容器在Web容器中的启动:http://www.itey ...

  2. Spring源代码解析(收藏)

    Spring源代码解析(收藏)   Spring源代码解析(一):IOC容器:http://www.iteye.com/topic/86339 Spring源代码解析(二):IoC容器在Web容器中的 ...

  3. C#使用zxing,zbar,thoughtworkQRcode解析二维码,附源代码

    最近做项目需要解析二维码图片,找了一大圈,发现没有人去整理下开源的几个库案例,花了点时间 做了zxing,zbar和thoughtworkqrcode解析二维码案例,希望大家有帮助. zxing是谷歌 ...

  4. NIO框架之MINA源代码解析(二):mina核心引擎

    NIO框架之MINA源代码解析(一):背景 MINA的底层还是利用了jdk提供了nio功能,mina仅仅是对nio进行封装.包含MINA用的线程池都是jdk直接提供的. MINA的server端主要有 ...

  5. SDWebImage源代码解析(二)

    上一篇:SDWebImage源代码解析(一) 2.缓存 为了降低网络流量的消耗.我们都希望下载下来的图片缓存到本地.下次再去获取同一张图片时.能够直接从本地获取,而不再从远程server获取.这样做的 ...

  6. redis之字符串命令源代码解析(二)

    形象化设计模式实战             HELLO!架构                     redis命令源代码解析 在redis之字符串命令源代码解析(一)中讲了get的简单实现,并没有对 ...

  7. asp.net C#生成和解析二维码代码

    类库文件我们在文件最后面下载 [ThoughtWorks.QRCode.dll 就是类库] 使用时需要增加: using ThoughtWorks.QRCode.Codec;using Thought ...

  8. Fixflow引擎解析(二)(模型) - BPMN2.0读写

    Fixflow引擎解析(四)(模型) - 通过EMF扩展BPMN2.0元素 Fixflow引擎解析(三)(模型) - 创建EMF模型来读写XML文件 Fixflow引擎解析(二)(模型) - BPMN ...

  9. Arrays.sort源代码解析

    Java Arrays.sort源代码解析 Java Arrays中提供了对所有类型的排序.其中主要分为Primitive(8种基本类型)和Object两大类. 基本类型:采用调优的快速排序: 对象类 ...

随机推荐

  1. [unity菜鸟] 笔记1 —— 函数篇

    SendMessage() 调用其他物体中的指令,先在脚本中编写一个自定义的函数,然后使用SendMessage()命令来调用那个物体上的命令 //①将以下函数附给target对象 void Rena ...

  2. 第九章 Mass Storage设备

    9.1 Mass Storage设备介绍 USB的Mass Storage类是USB大容量储存设备类(Mass Storage Device Class).专门用于大容量存储设备,比如U盘.移动硬盘. ...

  3. 170. Two Sum III - Data structure design

    题目: Design and implement a TwoSum class. It should support the following operations: add and find. a ...

  4. Linux 线程优先级

    http://www.cnblogs.com/imapla/p/4234258.html http://blog.csdn.net/lanseshenhua/article/details/55247 ...

  5. Android开发UI之自定义动画

    自定义动画,需要新建一个类,继承Animation类. 重写applyTransformation()方法和initialize()方法. applyTransformation(float inte ...

  6. 【转】【iOS知识学习】_视图控制对象生命周期-init、viewDidLoad、viewWillAppear、viewDidAppear、viewWillDisappear等的区别及用途

    原文网址:http://blog.csdn.net/weasleyqi/article/details/8090373 iOS视图控制对象生命周期-init.viewDidLoad.viewWillA ...

  7. 【转】Android 4.3源码的下载和编译环境的安装及编译

    原文网址:http://jingyan.baidu.com/article/c85b7a641200e0003bac95a3.html  告诉windows用户一个不好的消息,windows环境下没法 ...

  8. uboot环境变量(设置bootargs向linux内核传递正确的参数)

    这是我uboot的环境变量设置,在该设置下可以运行initram内核(从内存下载到nandflash再运行),但是运行nfs根文件系统的时候一直出错,各种错误.查看了很多资料后猜想应该是uboot传递 ...

  9. WebService优点和缺点小结

    最近做的几个项目都用到了webservice,通过自己的实践和网上资料的汇总,现在做个小结:        当前WebService是一个热门话题.但是,WebService究竟是什么?,WebSer ...

  10. zookeeper 客户端编程

    zookeeper是一个分布式的开源的分布式协调服务,用它可以来现同步服务,配置维护.zookeeper的稳定性也是可以保证的,笔者曾参与过的使用zookeeper的两个应用,一个是用zookeepe ...