本文细述上文引出的RAECost和SoftmaxCost两个类。

SoftmaxCost

我们已经知道,SoftmaxCost类在给定features和label的情况下(超参数给定),衡量给定权重($hidden\times catSize$)的误差值$cost$,并指出当前的权重梯度。看代码。

@Override
public double valueAt(double[] x)
{
if( !requiresEvaluation(x) )
return value;
int numDataItems = Features.columns; int[] requiredRows = ArraysHelper.makeArray(0, CatSize-2);
ClassifierTheta Theta = new ClassifierTheta(x,FeatureLength,CatSize);
DoubleMatrix Prediction = getPredictions (Theta, Features); double MeanTerm = 1.0 / (double) numDataItems;
double Cost = getLoss (Prediction, Labels).sum() * MeanTerm;
double RegularisationTerm = 0.5 * Lambda * DoubleMatrixFunctions.SquaredNorm(Theta.W); DoubleMatrix Diff = Prediction.sub(Labels).muli(MeanTerm);
DoubleMatrix Delta = Features.mmul(Diff.transpose()); DoubleMatrix gradW = Delta.getColumns(requiredRows);
DoubleMatrix gradb = ((Diff.rowSums()).getRows(requiredRows)); //Regularizing. Bias does not have one.
gradW = gradW.addi(Theta.W.mul(Lambda)); Gradient = new ClassifierTheta(gradW,gradb);
value = Cost + RegularisationTerm;
gradient = Gradient.Theta;
return value;
} public DoubleMatrix getPredictions (ClassifierTheta Theta, DoubleMatrix Features)
    {
        int numDataItems = Features.columns;
        DoubleMatrix Input = ((Theta.W.transpose()).mmul(Features)).addColumnVector(Theta.b);
        Input = DoubleMatrix.concatVertically(Input, DoubleMatrix.zeros(1,numDataItems));
        return Activation.valueAt(Input);
    }

是个典型的2层神经网络,没有隐层,首先根据features预测labels,预测结果用softmax归一化,然后根据误差反向传播算出权重梯度。

此处增加200字。

这个典型的2层神经网络,label为一列向量,目标label置1,其余为0;转换函数为softmax函数,输出为每个label的概率。

计算cost的函数为getLoss,假设目标label的预测输出为$p^*$,则每个样本的cost也即误差函数为:

$$cost=E(p^*)=-\log(p^*)$$

根据前述的神经网络后向传播算法,我们得到($j$为目标label时,否则为0):

$$\frac{\partial E}{\partial w_{ij}}=\frac{\partial E}{\partial p_j}\frac{\partial h_j}{\partial net_j}x_i=-\frac{1}{p_j}p_j(1-p_j)x_i=-(1-p_j)x_i=-(label_j-p_j)feature_i$$

因此我们便理解了下面代码的含义:

DoubleMatrix Delta = Features.mmul(Diff.transpose());

RAECost

先看实现代码:

@Override
public double valueAt(double[] x)
{
if(!requiresEvaluation(x))
return value; Theta Theta1 = new Theta(x,hiddenSize,visibleSize,dictionaryLength);
FineTunableTheta Theta2 = new FineTunableTheta(x,hiddenSize,visibleSize,catSize,dictionaryLength);
Theta2.setWe( Theta2.We.add(WeOrig) ); final RAEClassificationCost classificationCost = new RAEClassificationCost(
catSize, AlphaCat, Beta, dictionaryLength, hiddenSize, Lambda, f, Theta2);
final RAEFeatureCost featureCost = new RAEFeatureCost(
AlphaCat, Beta, dictionaryLength, hiddenSize, Lambda, f, WeOrig, Theta1); Parallel.For(DataCell,
new Parallel.Operation<LabeledDatum<Integer,Integer>>() {
public void perform(int index, LabeledDatum<Integer,Integer> Data)
{
try {
LabeledRAETree Tree = featureCost.Compute(Data);
classificationCost.Compute(Data, Tree);
} catch (Exception e) {
System.err.println(e.getMessage());
}
}
}); double costRAE = featureCost.getCost();
double[] gradRAE = featureCost.getGradient().clone(); double costSUP = classificationCost.getCost();
gradient = classificationCost.getGradient(); value = costRAE + costSUP;
for(int i=0; i<gradRAE.length; i++)
gradient[i] += gradRAE[i]; System.gc(); System.gc();
System.gc(); System.gc();
System.gc(); System.gc();
System.gc(); System.gc(); return value;
}

cost由两部分组成,featureCost和classificationCost。程序遍历每个样本,用featureCost.Compute(Data)生成一个递归树,同时累加cost和gradient,然后用classificationCost.Compute(Data, Tree)根据生成的树计算并累加cost和gradient。因此关键类为RAEFeatureCost和RAEClassificationCost。

RAEFeatureCost类在Compute函数中调用RAEPropagation的ForwardPropagate函数生成一棵树,然后调用BackPropagate计算梯度并累加。具体的算法过程,下一章分解。

jrae源码解析(二)的更多相关文章

  1. Mybatis源码解析(二) —— 加载 Configuration

    Mybatis源码解析(二) -- 加载 Configuration    正如上文所看到的 Configuration 对象保存了所有Mybatis的配置信息,也就是说mybatis-config. ...

  2. RxJava2源码解析(二)

    title: RxJava2源码解析(二) categories: 源码解析 tags: 源码解析 rxJava2 前言 本篇主要解析RxJava的线程切换的原理实现 subscribeOn 首先, ...

  3. Sentinel源码解析二(Slot总览)

    写在前面 本文继续来分析Sentinel的源码,上篇文章对Sentinel的调用过程做了深入分析,主要涉及到了两个概念:插槽链和Node节点.那么接下来我们就根据插槽链的调用关系来依次分析每个插槽(s ...

  4. iOS即时通讯之CocoaAsyncSocket源码解析二

    原文 前言 本文承接上文:iOS即时通讯之CocoaAsyncSocket源码解析一 上文我们提到了GCDAsyncSocket的初始化,以及最终connect之前的准备工作,包括一些错误检查:本机地 ...

  5. jQuery 源码解析二:jQuery.fn.extend=jQuery.extend 方法探究

    终于动笔开始 jQuery 源码解析第二篇,写文章还真是有难度,要把自已懂的表述清楚,要让别人听懂真的不是一见易事. 在 jQuery 源码解析一:jQuery 类库整体架构设计解析 一文,大致描述了 ...

  6. Common.Logging源码解析二

    Common.Logging源码解析一分析了LogManager主入口的整个逻辑,其中第二步生成日志实例工厂类接口分析的很模糊,本随笔将会详细讲解整个日志实例工厂类接口的生成过程! (1).关于如何生 ...

  7. erlang下lists模块sort(排序)方法源码解析(二)

    上接erlang下lists模块sort(排序)方法源码解析(一),到目前为止,list列表已经被分割成N个列表,而且每个列表的元素是有序的(从大到小) 下面我们重点来看看mergel和rmergel ...

  8. element-ui 源码解析 二

    Carousel 走马灯源码解析 1. 基本原理:页面切换 页面切换使用的是 transform 2D 转换和 transition 过渡 可以看出是采用内联样式来实现的 举个栗子 <div : ...

  9. ArrayList源码解析(二)

    欢迎转载,转载烦请注明出处,谢谢. https://www.cnblogs.com/sx-wuyj/p/11177257.html 自己学习ArrayList源码的一些心得记录. 继续上一篇,Arra ...

  10. React的Component,PureComponent源码解析(二)

    1.什么是Component,PureComponent? 都是class方式定义的基类,两者没有什么大的区别,只是PureComponent内部使用shouldComponentUpdate(nex ...

随机推荐

  1. Jenkins 快速搭建持续集成环境

    持续集成概述 什么是持续集成 随着软件开发复杂度的不断提高,团队开发成员间如何更好地协同工作以确保软件开发的质量已经慢慢成为开发过程中不可回避的问题.尤其是近些年来,敏捷(Agile) 在软件工程领域 ...

  2. android-wear开发之定义布局

    Android Wear使用跟手机一样的布局技术,但需要对特定情况进行设计.不要把手机的UI直接照搬过来.更多可查看:Android Wear Design Guidelines 当创建android ...

  3. JDK安装配置与升级

    一.jdk1.4卸载 Redhat Enterprise 5 中自带安装了jdk1.4,在安装jdk1.6前,把jdk1.4卸载: 1. 首先查看系统自带的JDK版本: [root@linux ~]# ...

  4. soap协议

    定义: 简单对象访问协议是交换数据的一种协议规范,是一种轻量的.简单的.基于XML(标准通用标记语言下的一个子集)的协议,它被设计成在WEB上交换结构化的和固化的信息. 协议中的几个关键词术语: SO ...

  5. 根据标点符号分行,StringBuilder的使用;将字符串的每个字符颠倒输出,Reverse的使用

    一:根据标点符号分行,上图,代码很简单 二:代码 using System; using System.Collections.Generic; using System.ComponentModel ...

  6. 《algorithm puzzles》——概述

    这个专题我们开始对<algorithm puzzles>一书的学习,这本书是一本谜题集,包括一些数学与计算机起源性的古典命题和一些比较新颖的谜题,序章的几句话非常好,在这里做简单的摘录. ...

  7. SRM 404(1-250pt, 1-500pt)

    DIV1 250pt 题意:对于1-9数字三角形如下图,设其为a[i][j],则a[i][j] = (a[i-1][j] + a[i-1][j+1]) % 10.现在对于某个数字三角形, 每行告诉你某 ...

  8. J - Fire!

    题目大意: 这是一个放火逃生的游戏,就是给出来一个迷宫,迷宫里面有人‘J’和火焰‘F’当然这些火焰可能不止一处,然后问这个人最快从迷宫里面逃出来需要多久 /////////////////////// ...

  9. @property中有哪些属性关键字?/ @property 后面可以有哪些修饰符?

    出题者简介: 孙源(sunnyxx),目前就职于百度 整理者简介:陈奕龙(子循),目前就职于滴滴出行. 转载者:豆电雨(starain)微信:doudianyu 属性可以拥有的特质分为四类: 原子性- ...

  10. artTemplate的使用总结

    原生语法 使用原生语法,需要导入template-native.js文件. 在HTML中定义模板,注意模板的位置,不要放到被渲染区域,防止模板丢失. <script id="main_ ...