jare用java实现了论文《Semi-Supervised Recursive Autoencoders for Predicting Sentiment Distributions》中提出的算法——基于半监督的递归自动编码机,用来预测情感分类。详情可查看论文内容,代码git地址为:https://github.com/sancha/jrae。

鸟瞰

主函数训练流程

FineTunableTheta tunedTheta = rae.train(params);// 根据参数和数据训练神经网络权重
tunedTheta.Dump(params.ModelFile); System.out.println("RAE trained. The model file is saved in "
+ params.ModelFile);
    // 特征抽取器
RAEFeatureExtractor fe = new RAEFeatureExtractor(params.EmbeddingSize,
tunedTheta, params.AlphaCat, params.Beta, params.CatSize,
params.Dataset.Vocab.size(), rae.f);
    // 获取训练数据
List<LabeledDatum<Double, Integer>> classifierTrainingData = fe
.extractFeaturesIntoArray(params.Dataset, params.Dataset.Data,
params.TreeDumpDir);
    // 测试精度
SoftmaxClassifier<Double, Integer> classifier = new SoftmaxClassifier<Double, Integer>();
Accuracy TrainAccuracy = classifier.train(classifierTrainingData);
System.out.println("Train Accuracy :" + TrainAccuracy.toString());

几个重要的接口以及实现类

1、Minimizer<T extends DifferentiableFunction>

public interface Minimizer<T extends DifferentiableFunction> {

  /**
* Attempts to find an unconstrained minimum of the objective
* <code>function</code> starting at <code>initial</code>, within
* <code>functionTolerance</code>.
*
* @param function the objective function
* @param functionTolerance a <code>double</code> value
* @param initial a initial feasible point
* @return Unconstrained minimum of function
*/
double[] minimize(T function, double functionTolerance, double[] initial);
double[] minimize(T function, double functionTolerance, double[] initial, int maxIterations); }

如其所述,该接口用来找到给定目标函数的最小化极值,目标函数必须是处处可微的,并实现DifferentiableFunction接口。functionTolerance是最小误差,initial是初始点,maxIterations是最大迭代次数。

public interface DifferentiableFunction extends Function {
double[] derivativeAt(double[] x);
} public interface Function {
int dimension();
double valueAt(double[] x);
}

QNMinimizer类实现了该接口,利用L-BFGS优化算法对目标函数进行优化,下面是算法的注释:

/**
* This code is part of the Stanford NLP Toolkit.
*
*
* An implementation of L-BFGS for Quasi Newton unconstrained minimization.
*
* The general outline of the algorithm is taken from: <blockquote> <i>Numerical
* Optimization</i> (second edition) 2006 Jorge Nocedal and Stephen J. Wright
* </blockquote> A variety of different options are available.
*
* <h3>LINESEARCHES</h3>
*
* BACKTRACKING: This routine simply starts with a guess for step size of 1. If
* the step size doesn't supply a sufficient decrease in the function value the
* step is updated through step = 0.1*step. This method is certainly simpler,
* but doesn't allow for an increase in step size, and isn't well suited for
* Quasi Newton methods.
*
* MINPACK: This routine is based off of the implementation used in MINPACK.
* This routine finds a point satisfying the Wolfe conditions, which state that
* a point must have a sufficiently smaller function value, and a gradient of
* smaller magnitude. This provides enough to prove theoretically quadratic
* convergence. In order to find such a point the linesearch first finds an
* interval which must contain a satisfying point, and then progressively
* reduces that interval all using cubic or quadratic interpolation.
*
*
* SCALING: L-BFGS allows the initial guess at the hessian to be updated at each
* step. Standard BFGS does this by approximating the hessian as a scaled
* identity matrix. To use this method set the scaleOpt to SCALAR. A better way
* of approximate the hessian is by using a scaling diagonal matrix. The
* diagonal can then be updated as more information comes in. This method can be
* used by setting scaleOpt to DIAGONAL.
*
*
* CONVERGENCE: Previously convergence was gauged by looking at the average
* decrease per step dividing that by the current value and terminating when
* that value because smaller than TOL. This method fails when the function
* value approaches zero, so two other convergence criteria are used. The first
* stores the initial gradient norm |g0|, then terminates when the new gradient
* norm, |g| is sufficiently smaller: i.e., |g| < eps*|g0| the second checks
* if |g| < eps*max( 1 , |x| ) which is essentially checking to see if the
* gradient is numerically zero.
*
* Each of these convergence criteria can be turned on or off by setting the
* flags: <blockquote><code>
* private boolean useAveImprovement = true;
* private boolean useRelativeNorm = true;
* private boolean useNumericalZero = true;
* </code></blockquote>
*
* To use the QNMinimizer first construct it using <blockquote><code>
* QNMinimizer qn = new QNMinimizer(mem, true)
* </code>
* </blockquote> mem - the number of previous estimate vector pairs to store,
* generally 15 is plenty. true - this tells the QN to use the MINPACK
* linesearch with DIAGONAL scaling. false would lead to the use of the criteria
* used in the old QNMinimizer class.
*/

OK,可以结合我前面文章,了解L-BFGS算法的原理,然后该类实现了这个算法,并且在某些细节上做了一些修改。具体的实现算法先略去不议,日后再说。

2、DifferentiableFunction

DifferentiableFunction定义上面已经给出,对应一个可微的函数。抽象类MemoizedDifferentiableFunction实现了这个接口,封装了一些通用的代码:

public abstract class MemoizedDifferentiableFunction implements DifferentiableFunction {
protected double[] prevQuery, gradient;
protected double value;
protected int evalCount; protected void initPrevQuery()
{
prevQuery = new double[ dimension() ];
} protected boolean requiresEvaluation(double[] x)
{
if(DoubleArrays.equals(x,prevQuery))
return false; System.arraycopy(x, 0, prevQuery, 0, x.length);
evalCount++;
return true;
} @Override
public double[] derivativeAt(double[] x){
if(DoubleArrays.equals(x,prevQuery))
return gradient;
valueAt(x);
return gradient;
}
}

封装的通用方法为,保存了上次请求的参数,如果传入参数已经被请求过,直接返回结果即可;保存了执行请求的次数;实现了求导流程,首先调用valueAt求得当前值$f(x)$,然后返回梯度(导数),valueAt由子类实现,即约定子类在计算$f(x)$的时候顺便计算好了$f'(x)$,然后保存到gradient变量中。

两个子类分别为RAECost和SoftmaxCost。

SoftmaxCost类表示,在给定样本的情况下,计算出给定权重的误差,导数指明减小误差的梯度。对应的是一个2层的网络,输入层为features(特征),输出层为label,并且转换函数为softmax(能量函数)。

RAECost类表示,在给定样本的情况下,计算出给定权重的误差,误差包括生成递归树的误差与label分类的误差只和,导数指明梯度,也是两者梯度之和。

在调用Minimizer接口进行优化时,传入的第一个参数即是RAECost对象,优化完毕时即是训练完毕时。

参考文献:

http://www.socher.org/index.php/Main/Semi-SupervisedRecursiveAutoencodersForPredictingSentimentDistributions

jrae源码解析(一)的更多相关文章

  1. jrae源码解析(二)

    本文细述上文引出的RAECost和SoftmaxCost两个类. SoftmaxCost 我们已经知道,SoftmaxCost类在给定features和label的情况下(超参数给定),衡量给定权重( ...

  2. 【原】Android热更新开源项目Tinker源码解析系列之三:so热更新

    本系列将从以下三个方面对Tinker进行源码解析: Android热更新开源项目Tinker源码解析系列之一:Dex热更新 Android热更新开源项目Tinker源码解析系列之二:资源文件热更新 A ...

  3. 【原】Android热更新开源项目Tinker源码解析系列之一:Dex热更新

    [原]Android热更新开源项目Tinker源码解析系列之一:Dex热更新 Tinker是微信的第一个开源项目,主要用于安卓应用bug的热修复和功能的迭代. Tinker github地址:http ...

  4. 【原】Android热更新开源项目Tinker源码解析系列之二:资源文件热更新

    上一篇文章介绍了Dex文件的热更新流程,本文将会分析Tinker中对资源文件的热更新流程. 同Dex,资源文件的热更新同样包括三个部分:资源补丁生成,资源补丁合成及资源补丁加载. 本系列将从以下三个方 ...

  5. 多线程爬坑之路-Thread和Runable源码解析之基本方法的运用实例

    前面的文章:多线程爬坑之路-学习多线程需要来了解哪些东西?(concurrent并发包的数据结构和线程池,Locks锁,Atomic原子类) 多线程爬坑之路-Thread和Runable源码解析 前面 ...

  6. jQuery2.x源码解析(缓存篇)

    jQuery2.x源码解析(构建篇) jQuery2.x源码解析(设计篇) jQuery2.x源码解析(回调篇) jQuery2.x源码解析(缓存篇) 缓存是jQuery中的又一核心设计,jQuery ...

  7. Spring IoC源码解析——Bean的创建和初始化

    Spring介绍 Spring(http://spring.io/)是一个轻量级的Java 开发框架,同时也是轻量级的IoC和AOP的容器框架,主要是针对JavaBean的生命周期进行管理的轻量级容器 ...

  8. jQuery2.x源码解析(构建篇)

    jQuery2.x源码解析(构建篇) jQuery2.x源码解析(设计篇) jQuery2.x源码解析(回调篇) jQuery2.x源码解析(缓存篇) 笔者阅读了园友艾伦 Aaron的系列博客< ...

  9. jQuery2.x源码解析(设计篇)

    jQuery2.x源码解析(构建篇) jQuery2.x源码解析(设计篇) jQuery2.x源码解析(回调篇) jQuery2.x源码解析(缓存篇) 这一篇笔者主要以设计的角度探索jQuery的源代 ...

随机推荐

  1. 【Java】在Eclipse中使用JUnit4进行单元测试(初级篇)

    本文绝大部分内容引自这篇文章: http://www.devx.com/Java/Article/31983/0/page/1 我们在编写大型程序的时候,需要写成千上万个方法或函数,这些函数的功能可能 ...

  2. DB2中SQLSTATE=57016 SQLCODE=-668

    执行 alter table DW_RPT.TRPT_JV_COGNOS_RPT add CENTER_CD varchar(10) ALTER TABLE DW_RPT.TRPT_JV_COGNOS ...

  3. Word Pattern

    ​package cn.edu.xidian.sselab.hashtable; import java.util.HashMap;import java.util.Map; /** *  * @au ...

  4. Spring MVC学习总结。

    公司项目用的Spring MVC.顺便学习学习. 其实框架并没有想象中的复杂.尤其对于初学者,总觉得SSH是一些很复杂的东西似的.其实对初学者来说能够用这些框架就足够了.在公司里也是,基本功能会用了就 ...

  5. OC中的野指针(僵尸指针)

    涉及到内存管理问题的都是类类型的变量,而在OC中我们操纵这些对象都是通过操纵指向他们的指针来完成的,一致很多时候会忽略指针存在.比如定义UIView * view = [[UIView alloc]i ...

  6. 开源库CImg 数据格式存储之二(RGB 顺序)

    在上一篇博客中已经初步说明了GDI和CImg数据的存储格式感谢博友 Imageshop 评论说明 CImg的说明文档中已有详细说明(详见上篇博客说明) CImg的数据格式确实是RRRGGGBBB顺序存 ...

  7. 高效算法——M 扫描法

    In an open credit system, the students can choose any course they like, but there is a problem. Some ...

  8. Android吧数据保存成xml文件

    public class MainActivity extends Activity { private List<Person> persons; @Override protected ...

  9. 软工UML学习札记

    UML模型由:事物.关系和图组成 (1)类(class)── 类用带有类名.属性和操作的矩形框来表示. (2)主动类(active class)── 主动类的实例应具有一个或多个进程或线程,能够启动控 ...

  10. DFS序 参考许昊然《数据结构漫谈》

    网上特别讲DFS序的东西好像很少 太简单了? 实用性不大? 看了论文中 7个经典问题, 觉得挺有用的 原文 "所谓DFS序, 就是DFS整棵树依次访问到的结点组成的序列" &quo ...