I. The paper "QuickScorer: a Fast Algorithm to Rank Documents with Additive Ensembles of Regression Trees" tackles the prediction side of LTR models. When a model such as LambdaMART is trained with many trees and many leaf nodes, scoring a sample means traversing every tree, which is too slow for online serving; the paper accelerates scoring with a bitvector technique. I searched for an open-source implementation for a long time without luck, and then happened upon a modified version of the source inside Solr's LTR module. Since that code is embedded in Solr, I post it here as-is for now and plan to extract it later and integrate it into RankLib for easier use.

II. Walkthrough with figures

1. How an ensemble of trees is scored originally

In the prediction phase of ensemble-tree models such as GBDT, LambdaMART, XGBoost, or LightGBM, an incoming sample is fed as a feature vector into every tree. In each tree it follows if/else comparisons down to a single leaf, and that leaf's value is the tree's score. The prediction for the sample is the sum over all trees of each leaf value multiplied by the learning rate (shrinkage).
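This standard per-tree traversal can be sketched as follows (a minimal illustration; the class and field names are made up for this sketch, not taken from the Solr/RankLib code):

```java
import java.util.List;

// Minimal sketch of standard additive-ensemble scoring.
// Class and field names are illustrative, not the Solr/RankLib code.
class NaiveEnsemble {
    static class Node {
        int featureIndex;   // feature tested by this branch
        float threshold;    // split threshold
        float value;        // leaf output (used only when left == null)
        Node left, right;   // both null for a leaf
    }

    List<Node> roots;   // one root per tree
    float[] weights;    // per-tree shrinkage / learning rate

    float score(float[] features) {
        float score = 0f;
        for (int t = 0; t < roots.size(); ++t) {
            Node n = roots.get(t);
            while (n.left != null) { // walk if/else comparisons down to a leaf
                n = (features[n.featureIndex] <= n.threshold) ? n.left : n.right;
            }
            score += weights[t] * n.value; // weight each leaf value by the shrinkage
        }
        return score;
    }
}
```

Every tree is walked to full depth for every sample; this per-node branching is exactly the cost QuickScorer is designed to remove.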

2. The scoring process described in the paper

A. Label every branch of each regression tree as true or false

Take the sample X = [0.2, 1.1, 0.2] from the figure. Each branch of a regression tree tests one of X[0], X[1], X[2] and gets a true or false label. For example, the root tests X[1] <= 1.0, but the sample has X[1] = 1.1, so that branch is false (taking the left child means true, the right child means false). Label every branch this way; in fact only the false labels matter, because the following steps use only the false branches.

B. Assign a bitvector to every branch

Each bitvector has one bit per leaf. A "0" marks a leaf of the branch's true (left) subtree, i.e. a leaf that is eliminated as a candidate once that branch evaluates to false; the "1"s mark the leaves that remain candidates. For example, with six leaves, "001111" places 0s over the leftmost two leaves, while "110011" (for a branch inside the right subtree) places 0s over the middle two.
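Under the convention that bit i stands for the i-th leaf counted from the left, a branch's bitvector can be built by clearing the bits of its left-subtree leaves. This is a toy helper for illustration only (`build` and `toBitString` are invented names, not from the paper or Solr):

```java
// Sketch: building the bitvector for one branch of a tree with leafNum leaves.
// Bit i stands for leaf i, counted left to right; the 0s cover the leaves of
// the branch's left subtree, i.e. the leaves ruled out when the test is false.
class BranchBitvector {
    static long build(int leafNum, int leftLo, int leftHi) {
        long bv = (leafNum == 64) ? ~0L : (1L << leafNum) - 1; // all leaves start as 1s
        for (int i = leftLo; i < leftHi; ++i) {
            bv &= ~(1L << i); // zero out the left-subtree leaves [leftLo, leftHi)
        }
        return bv;
    }

    // Renders the vector in the article's left-to-right string form, e.g. "001111".
    static String toBitString(long bv, int leafNum) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < leafNum; ++i) {
            sb.append((bv >>> i) & 1L);
        }
        return sb.toString();
    }
}
```

With six leaves, a root whose left subtree holds the first two leaves yields "001111", and a branch whose left subtree holds the middle two yields "110011", matching the figures.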

C. The scoring phase

This is the final prediction step. Following the figures above, take the bitwise AND of the bitvectors of all false branches; the result tells you which leaf the sample falls into. In the figure the result is "001101", and the leftmost 1 marks the index of the exit leaf. Doing this for every regression tree, multiplying each tree's leaf value by the learning rate (shrinkage), and summing gives the sample's prediction.
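The AND-and-find-first-bit step can be sketched as follows (illustrative names and data layout, not the Solr code; with bit i standing for the i-th leaf from the left, the "leftmost 1" of the figures is the lowest set bit, which the Solr listing later locates with a De Bruijn hash):

```java
// Sketch of QuickScorer's per-tree scoring step (illustrative names).
// falseBitvectors[t] holds the bitvectors of the branches of tree t whose
// tests came out false for the current sample; bit i is the i-th leaf
// from the left.
class BitvectorScoring {
    static float score(long[][] falseBitvectors, float[][] leafValues, float[] weights) {
        float score = 0f;
        for (int t = 0; t < leafValues.length; ++t) {
            long v = ~0L; // start with every leaf as a candidate
            for (long bv : falseBitvectors[t]) {
                v &= bv; // each false branch rules out its true-side leaves
            }
            // In this encoding the "leftmost 1" of the figures is the lowest
            // set bit (if no branch is false, the exit is the leftmost leaf).
            int exitLeaf = Long.numberOfTrailingZeros(v);
            score += weights[t] * leafValues[t][exitLeaf];
        }
        return score;
    }
}
```

Note the inner loop never branches on the tree structure: it is a linear pass of AND operations, which is what makes QuickScorer cache- and branch-predictor-friendly.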

III. Code

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;

import java.util.*;

public class MultipleAdditiveTreesModel extends LTRScoringModel {

  // feature name -> index (0-based)
  private final HashMap<String, Integer> fname2index = new HashMap<>();
  private List<RegressionTree> trees;

  private RegressionTree createRegressionTree(Map<String, Object> map) {
    RegressionTree rt = new RegressionTree();
    if (map != null) {
      SolrPluginUtils.invokeSetters(rt, map.entrySet());
    }
    return rt;
  }

  private RegressionTreeNode createRegressionTreeNode(Map<String, Object> map) {
    RegressionTreeNode rtn = new RegressionTreeNode();
    if (map != null) {
      SolrPluginUtils.invokeSetters(rtn, map.entrySet());
    }
    return rtn;
  }

  public void setTrees(Object trees) {
    this.trees = new ArrayList<>();
    for (Object o : (List) trees) {
      this.trees.add(createRegressionTree((Map) o));
    }
  }

  public void setTrees(List<RegressionTree> trees) {
    this.trees = trees;
  }

  public List<RegressionTree> getTrees() {
    return this.trees;
  }

  public MultipleAdditiveTreesModel(String name, List<Feature> features, List<Normalizer> norms,
      String featureStoreName, List<Feature> allFeatures, Map<String, Object> params) {
    super(name, features, norms, featureStoreName, allFeatures, params);
    for (int i = 0; i < features.size(); ++i) {
      // feature name -> index
      this.fname2index.put(features.get(i).getName(), i);
    }
  }

  public void validate() throws ModelException {
    super.validate();
    if (this.trees == null) {
      throw new ModelException("no trees declared for model " + this.name);
    } else {
      for (RegressionTree tree : this.trees) {
        tree.validate();
      }
    }
  }

  public float score(float[] modelFeatureValuesNormalized) {
    // prediction = sum over trees of tree weight * leaf value
    float score = 0.0F;
    for (RegressionTree t : this.trees) {
      score += t.score(modelFeatureValuesNormalized);
    }
    return score;
  }

  public Explanation explain(LeafReaderContext context, int doc, float finalScore, List<Explanation> featureExplanations) {
    float[] fv = new float[featureExplanations.size()];
    int index = 0;
    for (Explanation featureExplain : featureExplanations) {
      fv[index++] = featureExplain.getValue();
    }

    List<Explanation> details = new ArrayList<>();
    index = 0;
    for (RegressionTree t : this.trees) {
      float score = t.score(fv);
      details.add(Explanation.match(score, "tree " + index + " | " + t.explain(fv)));
      ++index;
    }
    return Explanation.match(finalScore, this.toString() + " model applied to features, sum of:", details);
  }

  public String toString() {
    StringBuilder sb = new StringBuilder(this.getClass().getSimpleName());
    sb.append("(name=").append(this.getName());
    sb.append(",trees=[");
    for (int ii = 0; ii < this.trees.size(); ++ii) {
      if (ii > 0) {
        sb.append(',');
      }
      sb.append(this.trees.get(ii));
    }
    sb.append("])");
    return sb.toString();
  }

  public class RegressionTree {
    private Float weight;
    private RegressionTreeNode root;

    public RegressionTree() {
    }

    public void setWeight(float weight) {
      this.weight = weight;
    }

    public void setWeight(String weight) {
      this.weight = Float.valueOf(weight);
    }

    public float getWeight() {
      return this.weight;
    }

    public void setRoot(Object root) {
      this.root = createRegressionTreeNode((Map) root);
    }

    public RegressionTreeNode getRoot() {
      return this.root;
    }

    public float score(float[] featureVector) {
      return this.weight * this.root.score(featureVector);
    }

    public String explain(float[] featureVector) {
      return this.root.explain(featureVector);
    }

    public String toString() {
      StringBuilder sb = new StringBuilder();
      sb.append("(weight=").append(this.weight);
      sb.append(",root=").append(this.root);
      sb.append(")");
      return sb.toString();
    }

    public void validate() throws ModelException {
      if (this.weight == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight");
      } else if (this.root == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
      } else {
        this.root.validate();
      }
    }
  }

  public class RegressionTreeNode {
    private static final float NODE_SPLIT_SLACK = 1.0E-6F;

    private float value = 0.0F;
    private String feature;
    private int featureIndex = -1;
    private Float threshold;
    private RegressionTreeNode left;
    private RegressionTreeNode right;

    public RegressionTreeNode() {
    }

    public void setValue(float value) {
      this.value = value;
    }

    public void setValue(String value) {
      this.value = Float.parseFloat(value);
    }

    public void setFeature(String feature) {
      this.feature = feature;
      Integer idx = fname2index.get(this.feature);
      // remains -1 if the tree references a feature the store does not know
      this.featureIndex = (idx == null) ? -1 : idx;
    }

    public int getFeatureIndex() {
      return this.featureIndex;
    }

    public void setThreshold(float threshold) {
      this.threshold = threshold + NODE_SPLIT_SLACK;
    }

    public void setThreshold(String threshold) {
      this.threshold = Float.parseFloat(threshold) + NODE_SPLIT_SLACK;
    }

    public float getThreshold() {
      return this.threshold;
    }

    public void setLeft(Object left) {
      this.left = createRegressionTreeNode((Map) left);
    }

    public RegressionTreeNode getLeft() {
      return this.left;
    }

    public void setRight(Object right) {
      this.right = createRegressionTreeNode((Map) right);
    }

    public RegressionTreeNode getRight() {
      return this.right;
    }

    public boolean isLeaf() {
      return this.feature == null;
    }

    public float score(float[] featureVector) {
      if (isLeaf()) {
        return this.value;
      }
      if (this.featureIndex < 0 || this.featureIndex >= featureVector.length) {
        return 0.0F; // unknown feature: contribute nothing
      }
      return (featureVector[this.featureIndex] <= this.threshold)
          ? this.left.score(featureVector)
          : this.right.score(featureVector);
    }

    public String explain(float[] featureVector) {
      if (isLeaf()) {
        return "val: " + this.value;
      } else if (this.featureIndex >= 0 && this.featureIndex < featureVector.length) {
        if (featureVector[this.featureIndex] <= this.threshold) {
          return "'" + this.feature + "':" + featureVector[this.featureIndex] + " <= " + this.threshold
              + ", Go Left | " + this.left.explain(featureVector);
        } else {
          return "'" + this.feature + "':" + featureVector[this.featureIndex] + " > " + this.threshold
              + ", Go Right | " + this.right.explain(featureVector);
        }
      } else {
        return "'" + this.feature + "' does not exist in FV, Return Zero";
      }
    }

    public String toString() {
      StringBuilder sb = new StringBuilder();
      if (isLeaf()) {
        sb.append(this.value);
      } else {
        sb.append("(feature=").append(this.feature);
        sb.append(",threshold=").append(this.threshold - NODE_SPLIT_SLACK);
        sb.append(",left=").append(this.left);
        sb.append(",right=").append(this.right);
        sb.append(')');
      }
      return sb.toString();
    }

    public void validate() throws ModelException {
      if (isLeaf()) {
        if (this.left != null || this.right != null) {
          throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left=" + this.left
              + " and right=" + this.right);
        }
      } else if (this.threshold == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
      } else if (this.left == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
      } else {
        this.left.validate();
        if (this.right == null) {
          throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
        } else {
          this.right.validate();
        }
      }
    }
  }
}
import org.apache.commons.lang.ArrayUtils;
import org.apache.lucene.util.CloseableThreadLocal;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.norm.Normalizer;

import java.util.*;

public class QuickScorerTreesModel extends MultipleAdditiveTreesModel {

  private static final long MAX_BITS = 0xFFFFFFFFFFFFFFFFL;

  // 64-bit De Bruijn sequence B(2, 6)
  // see: http://chessprogramming.wikispaces.com/DeBruijnsequence#Binary alphabet-B(2, 6)
  private static final long HASH_BITS = 0x022fdd63cc95386dL;

  private static final int[] hashTable = new int[64];

  static {
    long hash = HASH_BITS;
    for (int i = 0; i < 64; ++i) {
      hashTable[(int) (hash >>> 58)] = i;
      hash <<= 1;
    }
  }

  /**
   * Finds the index of the rightmost set bit in O(1) by using the De Bruijn strategy.
   *
   * @param bits target bits (64 bits)
   * @see <a href="http://supertech.csail.mit.edu/papers/debruijn.pdf">http://supertech.csail.mit.edu/papers/debruijn.pdf</a>
   */
  private static int findIndexOfRightMostBit(long bits) {
    return hashTable[(int) (((bits & -bits) * HASH_BITS) >>> 58)];
  }

  /** The number of trees of this model. */
  private int treeNum;

  /** Weights of each tree. */
  private float[] weights;

  /**
   * List of all leaves of this model.
   * We keep tree nodes instead of plain values so that wide trees (more than 64 leaves) can be handled.
   */
  private RegressionTreeNode[] leaves;

  /** Offsets of the leaf block corresponding to each tree. */
  private int[] leafOffsets;

  /** The number of conditions (branches) of this model. */
  private int condNum;

  /**
   * Thresholds of each condition.
   * The thresholds are grouped by the corresponding feature, and each block is sorted by threshold value.
   */
  private float[] thresholds;

  /** Corresponding featureIndex of each condition. */
  private int[] featureIndexes;

  /** Offsets of the condition block corresponding to each feature. */
  private int[] condOffsets;

  /** Forward bitvectors of each condition, corresponding to the original additive trees. */
  private long[] forwardBitVectors;

  /** Backward bitvectors of each condition, corresponding to the inverted additive trees. */
  private long[] backwardBitVectors;

  /** Mapping from condition index to tree index. */
  private int[] treeIds;

  /**
   * Per-tree bitvectors used while calculating the score.
   * The arrays are reused per thread to avoid re-allocating them on every call.
   */
  private CloseableThreadLocal<long[]> threadLocalTreeBitvectors = null;

  /**
   * Statistical tendency of the conditions of this model.
   * If the conditions tend to be false, the inverted (backward) bitvectors are used for speed.
   */
  private volatile float falseRatio = 0.5f;

  /**
   * The decay factor for updating falseRatio at each evaluation step,
   * used as {@code ratio = prevRatio * decay + currentRatio * (1 - decay)}.
   */
  private float falseRatioDecay = 0.99f;

  /** Comparable node cost for selecting leaf candidates. */
  private static class NodeCost implements Comparable<NodeCost> {
private final int id;
private final int cost;
private final int depth;
private final int left;
private final int right; private NodeCost(int id, int cost, int depth, int left, int right) {
this.id = id;
this.cost = cost;
this.depth = depth;
this.left = left;
this.right = right;
} public int getId() {
return id;
} public int getLeft() {
return left;
} public int getRight() {
return right;
} /**
* Sorts by cost and depth.
     * We prefer cheaper cost and, for equal cost, deeper nodes.
*/
@Override
public int compareTo(NodeCost n) {
if (cost != n.cost) {
return Integer.compare(cost, n.cost);
} else if (depth != n.depth) {
        return Integer.compare(n.depth, depth); // reverse order: prefer deeper nodes
} else {
return Integer.compare(id, n.id);
}
}
} /**
* Comparable condition for constructing and sorting bitvectors.
*/
private static class Condition implements Comparable<Condition> {
private final int featureIndex;
private final float threshold;
private final int treeId;
private final long forwardBitvector;
private final long backwardBitvector; private Condition(int featureIndex, float threshold, int treeId, long forwardBitvector, long backwardBitvector) {
this.featureIndex = featureIndex;
this.threshold = threshold;
this.treeId = treeId;
this.forwardBitvector = forwardBitvector;
this.backwardBitvector = backwardBitvector;
} int getFeatureIndex() {
return featureIndex;
} float getThreshold() {
return threshold;
} int getTreeId() {
return treeId;
} long getForwardBitvector() {
return forwardBitvector;
} long getBackwardBitvector() {
return backwardBitvector;
    }

    /*
     * Sort by featureIndex and threshold in ascending order.
     */
@Override
public int compareTo(Condition c) {
if (featureIndex != c.featureIndex) {
return Integer.compare(featureIndex, c.featureIndex);
} else {
return Float.compare(threshold, c.threshold);
}
}
} /**
* Base class for traversing node with depth first order.
*/
private abstract static class Visitor {
private int nodeId = 0; int getNodeId() {
return nodeId;
} void visit(RegressionTree tree) {
nodeId = 0;
visit(tree.getRoot(), 0);
} private void visit(RegressionTreeNode node, int depth) {
if (node.isLeaf()) {
doVisitLeaf(node, depth);
} else {
// visit children first
visit(node.getLeft(), depth + 1);
      visit(node.getRight(), depth + 1);
      doVisitBranch(node, depth);
}
++nodeId;
} protected abstract void doVisitLeaf(RegressionTreeNode node, int depth); protected abstract void doVisitBranch(RegressionTreeNode node, int depth);
} /**
* {@link Visitor} implementation for calculating the cost of each node.
*/
private static class NodeCostVisitor extends Visitor { private final Stack<AbstractMap.SimpleEntry<Integer, Integer>> idCostStack = new Stack<>();
private final PriorityQueue<NodeCost> nodeCostQueue = new PriorityQueue<>(); PriorityQueue<NodeCost> getNodeCostQueue() {
return nodeCostQueue;
} @Override
protected void doVisitLeaf(RegressionTreeNode node, int depth) {
nodeCostQueue.add(new NodeCost(getNodeId(), 0, depth, -1, -1));
idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), 1));
} @Override
protected void doVisitBranch(RegressionTreeNode node, int depth) {
// calculate the cost of this node from children costs
final AbstractMap.SimpleEntry<Integer, Integer> rightIdCost = idCostStack.pop();
final AbstractMap.SimpleEntry<Integer, Integer> leftIdCost = idCostStack.pop();
final int cost = Math.max(leftIdCost.getValue(), rightIdCost.getValue()); nodeCostQueue.add(new NodeCost(getNodeId(), cost, depth, leftIdCost.getKey(), rightIdCost.getKey()));
idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), cost + 1));
}
} /**
* {@link Visitor} implementation for extracting leaves and bitvectors.
*/
private static class QuickScorerVisitor extends Visitor { private final int treeId;
private final int leafNum;
private final Set<Integer> leafIdSet;
private final Set<Integer> skipIdSet; private final Stack<Long> bitsStack = new Stack<>();
private final List<RegressionTreeNode> leafList = new ArrayList<>();
private final List<Condition> conditionList = new ArrayList<>(); private QuickScorerVisitor(int treeId, int leafNum, Set<Integer> leafIdSet, Set<Integer> skipIdSet) {
this.treeId = treeId;
this.leafNum = leafNum;
this.leafIdSet = leafIdSet;
this.skipIdSet = skipIdSet;
} List<RegressionTreeNode> getLeafList() {
return leafList;
} List<Condition> getConditionList() {
return conditionList;
} private long reverseBits(long bits) {
long revBits = 0L;
long mask = (1L << (leafNum - 1));
for (int i = 0; i < leafNum; ++i) {
if ((bits & mask) != 0L) revBits |= (1L << i);
mask >>>= 1;
}
return revBits;
} @Override
protected void doVisitLeaf(RegressionTreeNode node, int depth) {
      if (skipIdSet.contains(getNodeId())) return;
      bitsStack.add(1L << leafList.size()); // we use the rightmost bit to mark this leaf
leafList.add(node);
} @Override
protected void doVisitBranch(RegressionTreeNode node, int depth) {
      if (skipIdSet.contains(getNodeId())) return;
      if (leafIdSet.contains(getNodeId())) {
// an endpoint of QuickScorer
doVisitLeaf(node, depth);
return;
} final long rightBits = bitsStack.pop(); // bits of false branch
final long leftBits = bitsStack.pop(); // bits of true branch
/*
* NOTE:
* forwardBitvector = ~leftBits
* backwardBitvector = ~(reverse(rightBits))
*/
conditionList.add(
new Condition(node.getFeatureIndex(), node.getThreshold(), treeId, ~leftBits, ~reverseBits(rightBits)));
bitsStack.add(leftBits | rightBits);
}
} public QuickScorerTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
List<Feature> allFeatures, Map<String, Object> params) {
super(name, features, norms, featureStoreName, allFeatures, params);
  }

  /**
   * Sets the falseRatioDecay parameter of this model.
*
* @param falseRatioDecay decay parameter for updating falseRatio
*/
public void setFalseRatioDecay(float falseRatioDecay) {
this.falseRatioDecay = falseRatioDecay;
} /**
* @see #setFalseRatioDecay(float)
*/
public void setFalseRatioDecay(String falseRatioDecay) {
this.falseRatioDecay = Float.parseFloat(falseRatioDecay);
} /**
* {@inheritDoc}
*/
@Override
public void validate() throws ModelException {
// validate trees before initializing QuickScorer
super.validate(); // initialize QuickScorer with validated trees
init(getTrees());
  }

  /**
   * Initializes the quick scorer from the given trees.
   *
   * @param trees base additive trees model
   */
private void init(List<RegressionTree> trees) {
this.treeNum = trees.size();
this.weights = new float[trees.size()];
this.leafOffsets = new int[trees.size() + 1];
this.leafOffsets[0] = 0; // re-create tree bitvectors
if (this.threadLocalTreeBitvectors != null) this.threadLocalTreeBitvectors.close();
this.threadLocalTreeBitvectors = new CloseableThreadLocal<long[]>() {
@Override
protected long[] initialValue() {
return new long[treeNum];
}
}; int treeId = 0;
List<RegressionTreeNode> leafList = new ArrayList<>();
List<Condition> conditionList = new ArrayList<>();
for (RegressionTree tree : trees) {
// select up to 64 leaves from given tree
QuickScorerVisitor visitor = fitLeavesTo64bits(treeId, tree); // extract leaves and conditions with selected leaf candidates
visitor.visit(tree);
leafList.addAll(visitor.getLeafList());
conditionList.addAll(visitor.getConditionList()); // update weight, offset and treeId
this.weights[treeId] = tree.getWeight();
this.leafOffsets[treeId + 1] = this.leafOffsets[treeId] + visitor.getLeafList().size();
++treeId;
    }

    // remap lists to arrays for performance reasons
    this.leaves = leafList.toArray(new RegressionTreeNode[0]);

    // sort conditions in ascending order of featureIndex and threshold
    Collections.sort(conditionList);

    // remap condition information
int idx = 0;
int preFeatureIndex = -1;
this.condNum = conditionList.size();
this.thresholds = new float[conditionList.size()];
this.forwardBitVectors = new long[conditionList.size()];
this.backwardBitVectors = new long[conditionList.size()];
this.treeIds = new int[conditionList.size()];
List<Integer> featureIndexList = new ArrayList<>();
List<Integer> condOffsetList = new ArrayList<>();
for (Condition condition : conditionList) {
this.thresholds[idx] = condition.threshold;
this.forwardBitVectors[idx] = condition.getForwardBitvector();
this.backwardBitVectors[idx] = condition.getBackwardBitvector();
this.treeIds[idx] = condition.getTreeId(); if (preFeatureIndex != condition.getFeatureIndex()) {
featureIndexList.add(condition.getFeatureIndex());
condOffsetList.add(idx);
preFeatureIndex = condition.getFeatureIndex();
} ++idx;
}
    condOffsetList.add(conditionList.size()); // guard
    this.featureIndexes = ArrayUtils.toPrimitive(featureIndexList.toArray(new Integer[0]));
    this.condOffsets = ArrayUtils.toPrimitive(condOffsetList.toArray(new Integer[0]));
} /**
* Checks costs of all nodes and select leaves up to 64.
*
* <p>NOTE:
* We can use {@link java.util.BitSet} instead of {@code long} to represent bitvectors longer than 64bits.
* However, this modification caused performance degradation in our experiments, and we decided to use this form.
*
* @param treeId index of given regression tree
* @param tree target regression tree
* @return QuickScorerVisitor with proper id sets
*/
private QuickScorerVisitor fitLeavesTo64bits(int treeId, RegressionTree tree) {
// calculate costs of all nodes
NodeCostVisitor nodeCostVisitor = new NodeCostVisitor();
nodeCostVisitor.visit(tree); // poll zero cost nodes (i.e., real leaves)
Set<Integer> leafIdSet = new HashSet<>();
Set<Integer> skipIdSet = new HashSet<>();
while (!nodeCostVisitor.getNodeCostQueue().isEmpty()) {
if (nodeCostVisitor.getNodeCostQueue().peek().cost > 0) break;
NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
leafIdSet.add(nodeCost.id);
} // merge leaves until the number of leaves reaches 64
while (leafIdSet.size() > 64) {
final NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
assert nodeCost.left >= 0 && nodeCost.right >= 0; // update leaves
leafIdSet.remove(nodeCost.left);
leafIdSet.remove(nodeCost.right);
leafIdSet.add(nodeCost.id); // register previous leaves to skip ids
skipIdSet.add(nodeCost.left);
skipIdSet.add(nodeCost.right);
} return new QuickScorerVisitor(treeId, leafIdSet.size(), leafIdSet, skipIdSet);
} /**
* {@inheritDoc}
*/
@Override
public float score(float[] modelFeatureValuesNormalized) {
assert threadLocalTreeBitvectors != null;
long[] treeBitvectors = threadLocalTreeBitvectors.get();
    Arrays.fill(treeBitvectors, MAX_BITS);

    int falseNum = 0;
float score = 0.0f;
if (falseRatio <= 0.5) {
// use forward bitvectors
for (int i = 0; i < condOffsets.length - 1; ++i) {
final int featureIndex = featureIndexes[i];
for (int j = condOffsets[i]; j < condOffsets[i + 1]; ++j) {
if (modelFeatureValuesNormalized[featureIndex] <= thresholds[j]) break;
treeBitvectors[treeIds[j]] &= forwardBitVectors[j];
++falseNum;
}
} for (int i = 0; i < leafOffsets.length - 1; ++i) {
final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
score += weights[i] * leaves[leafOffsets[i] + leafIdx].score(modelFeatureValuesNormalized);
}
} else {
// use backward bitvectors
falseNum = condNum;
for (int i = 0; i < condOffsets.length - 1; ++i) {
final int featureIndex = featureIndexes[i];
for (int j = condOffsets[i + 1] - 1; j >= condOffsets[i]; --j) {
if (modelFeatureValuesNormalized[featureIndex] > thresholds[j]) break;
treeBitvectors[treeIds[j]] &= backwardBitVectors[j];
--falseNum;
}
} for (int i = 0; i < leafOffsets.length - 1; ++i) {
final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
score += weights[i] * leaves[leafOffsets[i + 1] - 1 - leafIdx].score(modelFeatureValuesNormalized);
}
} // update false ratio
falseRatio = falseRatio * falseRatioDecay + (falseNum * 1.0f / condNum) * (1.0f - falseRatioDecay);
return score;
} }
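The `findIndexOfRightMostBit` helper in the listing above is worth a note: `bits & -bits` isolates the lowest set bit (a power of two, say `1 << k`), so the multiply reduces to `HASH_BITS << k`, whose top six bits form a unique lookup key for `k`. Here is a standalone copy of the trick (same constant as the listing, class name mine), checkable against `Long.numberOfTrailingZeros`:

```java
// Standalone copy of the De Bruijn lowest-set-bit trick from the listing above.
class DeBruijnIndex {
    // 64-bit De Bruijn sequence B(2, 6): every 6-bit window is unique.
    private static final long HASH_BITS = 0x022fdd63cc95386dL;
    private static final int[] TABLE = new int[64];

    static {
        long hash = HASH_BITS;
        for (int i = 0; i < 64; ++i) {
            TABLE[(int) (hash >>> 58)] = i; // top 6 bits after i shifts identify i
            hash <<= 1;
        }
    }

    // Index of the lowest set bit of a nonzero word, branch-free in O(1):
    // bits & -bits isolates the lowest set bit as 1 << k, the multiply
    // becomes HASH_BITS << k, and its top 6 bits look up k in the table.
    static int lowestSetBit(long bits) {
        return TABLE[(int) (((bits & -bits) * HASH_BITS) >>> 58)];
    }
}
```

On HotSpot, `Long.numberOfTrailingZeros` is an intrinsic that compiles down to a single bit-scan instruction, so it is a reasonable drop-in alternative to the table lookup.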
 import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FeatureException;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.request.SolrQueryRequest;
import org.junit.Ignore;
import org.junit.Test;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;

public class TestQuickScorerTreesModelBenchmark {

  /**
   * Creates dummy features.
   *
   * @param featureNum number of features
   */
  private List<Feature> createDummyFeatures(int featureNum) {
List<Feature> features = new ArrayList<>();
for (int i = 0; i < featureNum; ++i) {
features.add(new Feature("fv_" + i, null) {
@Override
protected void validate() throws FeatureException { } @Override
public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores, SolrQueryRequest request,
Query originalQuery, Map<String, String[]> efi) throws IOException {
return null;
} @Override
public LinkedHashMap<String, Object> paramsToMap() {
return null;
}
});
}
return features;
} private List<Normalizer> createDummyNormalizer(int featureNum) {
List<Normalizer> normalizers = new ArrayList<>();
for (int i = 0; i < featureNum; ++i) {
normalizers.add(new IdentityNormalizer());
}
return normalizers;
  }

  /**
   * Builds a single random tree, calling itself recursively.
   *
   * @param leafNum number of leaves
   * @param features feature list
   * @param rand random number generator
   */
private Map<String, Object> createRandomTree(int leafNum, List<Feature> features, Random rand) {
Map<String, Object> node = new HashMap<>();
if (leafNum == 1) {
// leaf
node.put("value", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5)
return node;
} // branch
node.put("feature", features.get(rand.nextInt(features.size())).getName());
node.put("threshold", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5)
node.put("left", createRandomTree(leafNum / 2, features, rand));
node.put("right", createRandomTree(leafNum - leafNum / 2, features, rand));
return node;
  }

  /**
   * Builds multiple random trees to use as a test model.
   *
   * @param treeNum number of trees
   * @param leafNum number of leaves per tree
   * @param features feature list
   * @param rand random number generator
   */
private List<Object> createRandomMultipleAdditiveTrees(int treeNum, int leafNum, List<Feature> features,
Random rand) {
List<Object> trees = new ArrayList<>();
for (int i = 0; i < treeNum; ++i) {
Map<String, Object> tree = new HashMap<>();
      tree.put("weight", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5): per-tree weight (learning rate)
tree.put("root", createRandomTree(leafNum, features, rand));
trees.add(tree);
}
return trees;
  }

  /**
   * Checks that the two scoring models produce identical scores.
   *
   * @param featureNum number of features
   * @param treeNum number of trees
   * @param leafNum number of leaves per tree
   * @param loopNum number of random models to try
   * @throws Exception thrown if a model fails to initialize
   */
  private void compareScore(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception {
    Random rand = new Random(0);
    List<Feature> features = createDummyFeatures(featureNum); // generate dummy features
    List<Normalizer> norms = createDummyNormalizer(featureNum); // identity normalizers

    for (int i = 0; i < loopNum; ++i) {
      List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand);

      MultipleAdditiveTreesModel matModel = new MultipleAdditiveTreesModel("multipleadditivetrees", features, norms,
          "dummy", features, null);
      matModel.setTrees(trees);
      matModel.validate();

      QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features,
          null);
      qstModel.setTrees(trees); // hand the same trees to the quick scorer
      qstModel.validate();      // validation also builds the QuickScorer structures

      float[] featureValues = new float[featureNum];
      for (int j = 0; j < 100; ++j) {
        for (int k = 0; k < featureNum; ++k) featureValues[k] = rand.nextFloat() - 0.5f; // [-0.5, 0.5)
        float expected = matModel.score(featureValues);
        float actual = qstModel.score(featureValues);
        assertThat(actual, is(expected));
//System.out.println("expected: " + expected + " actual: " + actual);
}
}
  }

  /**
   * Tests that the two models produce identical scores.
   *
   * @throws Exception thrown if testcase failed to initialize models
   */
  /*@Test
  public void testAccuracy() throws Exception {
    compareScore(25, 200, 32, 100);
    //compareScore(19, 500, 31, 10000);
  }*/

  /**
   * Compares the scoring time of the two models.
   *
   * @param featureNum number of features
   * @param treeNum number of trees
   * @param leafNum number of leaves per tree
   * @param loopNum number of samples
   * @throws Exception thrown if a model fails to initialize
   */
private void compareTime(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception {
    Random rand = new Random(0);

    // generate random features
    List<Feature> features = createDummyFeatures(featureNum);
    // identity normalizers
    List<Normalizer> norms = createDummyNormalizer(featureNum);
    // build random trees
    List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand);

    // initialize the plain multiple additive trees model
    MultipleAdditiveTreesModel matModel = new MultipleAdditiveTreesModel("multipleadditivetrees", features, norms,
        "dummy", features, null);
    matModel.setTrees(trees);
    matModel.validate();

    // initialize the quick scorer trees model
    QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features,
        null);
    qstModel.setTrees(trees);
    qstModel.validate();

    // generate loopNum random samples of featureNum values each
    float[][] featureValues = new float[loopNum][featureNum];
    for (int i = 0; i < loopNum; ++i) {
      for (int k = 0; k < featureNum; ++k) {
        featureValues[i][k] = rand.nextFloat() * 2.0f - 1.0f; // [-1.0, 1.0)
      }
    }

    long start;
    /*long matOpNsec = 0;
    for (int i = 0; i < loopNum; ++i) {
      start = System.nanoTime();
      matModel.score(featureValues[i]);
      matOpNsec += System.nanoTime() - start;
    }
    long qstOpNsec = 0;
    for (int i = 0; i < loopNum; ++i) {
      start = System.nanoTime();
      qstModel.score(featureValues[i]);
      qstOpNsec += System.nanoTime() - start;
    }
    System.out.println("MultipleAdditiveTreesModel : " + matOpNsec / 1000.0 / loopNum + " usec/op");
    System.out.println("QuickScorerTreesModel : " + qstOpNsec / 1000.0 / loopNum + " usec/op");*/

    long matOpMsec;
    start = System.currentTimeMillis();
    for (int i = 0; i < loopNum; i++) {
      matModel.score(featureValues[i]);
    }
    matOpMsec = System.currentTimeMillis() - start;

    long qstOpMsec;
    start = System.currentTimeMillis();
    for (int i = 0; i < loopNum; i++) {
      qstModel.score(featureValues[i]);
    }
    qstOpMsec = System.currentTimeMillis() - start;

    System.out.println("MultipleAdditiveTreesModel : " + matOpMsec + " msec");
    System.out.println("QuickScorerTreesModel : " + qstOpMsec + " msec");
    //assertThat(matOpMsec > qstOpMsec, is(true));
  }

  /**
   * Benchmarks the two models against each other.
   *
   * @throws Exception thrown if testcase failed to initialize models
   */
  @Test
  public void testPerformance() throws Exception {
    // features, trees, leaves, samples
    compareTime(20, 500, 61, 10000);
  }
}
