I. The paper "QuickScorer: a Fast Algorithm to Rank Documents with Additive Ensembles of Regression Trees" addresses the prediction side of LTR models. When LambdaMART (or a similar tree-based LTR method) produces an ensemble with many trees and many leaves, scoring a sample means traversing every tree, which is too slow for online serving; the paper speeds up this scoring step with a bitvector technique. I searched for an open-source implementation for a long time without success, and later happened to find a modified version of the code inside Solr LTR. That version is integrated into Solr, so I am posting it here as-is for now; later I plan to extract it and integrate it into RankLib so that it is easier to use.

II. Explanation of the figures

1. The standard scoring procedure of tree ensembles

In ensemble-tree models such as GBDT, LambdaMART, XGBoost, and LightGBM, prediction works as follows: a sample arrives as a feature vector and is fed into every tree; within each tree it follows the if/else comparisons down to a single leaf, and the value stored at that leaf is that tree's score. Multiplying each tree's score by the learning rate (shrinkage) and summing over all trees gives the sample's predicted score.
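As a reference point, here is a minimal sketch of that baseline traversal; the class and field names (Node, Ensemble) are illustrative only and are not the Solr classes quoted below.

// Minimal sketch of standard ensemble scoring: walk each tree root-to-leaf,
// then sum the weighted leaf values.
final class Node {
  String feature;          // null for a leaf
  int featureIndex = -1;   // index into the feature vector
  float threshold;         // split threshold (branch nodes only)
  float value;             // prediction (leaf nodes only)
  Node left, right;

  float score(float[] x) {
    if (feature == null) return value;                     // leaf: return its value
    return x[featureIndex] <= threshold ? left.score(x)    // condition true  -> left child
                                        : right.score(x);  // condition false -> right child
  }
}

final class Ensemble {
  Node[] roots;
  float[] weights; // per-tree learning rate (shrinkage)

  float score(float[] x) {
    float s = 0.0f;
    for (int t = 0; t < roots.length; ++t) {
      s += weights[t] * roots[t].score(x); // every tree is traversed for every sample
    }
    return s;
  }
}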

2. The scoring procedure described in the paper

A. Label every branch in the regression tree as true or false

Take the sample X = [0.2, 1.1, 0.2] from the figure. At each branch of the regression tree we evaluate the test on X[0], X[1], or X[2] and record whether it is true or false. For example, the root tests X[1] <= 1.0; since X[1] = 1.1, the test is false (a true test sends the sample down the left branch, a false test down the right branch). After labelling every branch this way, we actually only need to keep the false branches — the true ones never have to be touched, because the following steps operate exclusively on the set of false branches.

B. Assign a bitvector to each branch

Each branch node is assigned a bitvector with one bit per leaf. The "0" bits mark the leaves of that node's left (true) subtree, i.e., exactly the leaves that can no longer be the exit leaf once the node's test turns out to be false; the "1" bits are the leaves that remain candidates. In the 6-leaf example from the figure, "001111" is the bitvector of a node whose left subtree contains the two leftmost leaves, and "110011" belongs to a node inside the right subtree whose left subtree contains the two middle leaves.
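Below is a small hypothetical helper (not part of the Solr code) that reproduces the two masks of the 6-leaf example, assuming leaf 0 is the leftmost leaf and occupies the most significant of the six bits. Note that the Solr implementation quoted later uses the opposite convention: leaf 0 sits in the lowest bit, and the exit leaf is found via the rightmost set bit.

// Hypothetical helper reproducing the masks from the 6-leaf example above.
// A node's bitvector has 0s exactly at the leaves of its left subtree.
final class LeafMaskExample {
  static long maskForNode(int leafCount, int firstLeftLeaf, int lastLeftLeaf) {
    long mask = (1L << leafCount) - 1;               // all leaves start as candidates
    for (int leaf = firstLeftLeaf; leaf <= lastLeftLeaf; ++leaf) {
      mask &= ~(1L << (leafCount - 1 - leaf));       // clear the left-subtree leaves
    }
    return mask;
  }

  static String toBits(long mask, int width) {
    StringBuilder sb = new StringBuilder();
    for (int i = width - 1; i >= 0; --i) {
      sb.append((mask >>> i) & 1L);
    }
    return sb.toString();
  }

  public static void main(String[] args) {
    System.out.println(toBits(maskForNode(6, 0, 1), 6)); // root: left subtree = leaves 0-1 -> 001111
    System.out.println(toBits(maskForNode(6, 2, 3), 6)); // node whose left subtree = leaves 2-3 -> 110011
  }
}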

C. The scoring phase

This is the final prediction step. Following the figures above, the bitvectors of all false branches are combined with a bitwise AND, and the result tells us which leaf the sample exits from. In the figure the AND yields "001101", and the leftmost 1 bit identifies the exit leaf. Every regression tree is processed the same way to obtain its leaf value; multiplying each value by the learning rate (shrinkage) and summing over the trees gives the sample's prediction.
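The sketch below illustrates this step for a single tree with at most 64 leaves, using the same bit convention as the figures (leaf 0 in the most significant of the leafCount bits). It is only an illustration of the idea, not the Solr code: the QuickScorerTreesModel in section III additionally keeps forward and backward bitvectors and switches between them depending on how often conditions evaluate to false.

// Sketch of the QuickScorer scoring step for one tree with at most 64 leaves.
// Conditions (one per branch node) are grouped by feature and sorted by threshold.
final class QuickScorerTreeSketch {
  int[] featureIndex;  // feature tested by each condition
  float[] threshold;   // threshold of each condition, ascending within a feature block
  long[] mask;         // bitvector of each condition: 0s at the leaves of its left subtree
  int[] condOffset;    // conditions of feature f live in [condOffset[f], condOffset[f+1])
  float[] leafValue;   // leaf predictions, leaf 0 = leftmost leaf
  int leafCount;       // number of leaves (<= 64)
  float weight;        // learning rate / shrinkage of this tree

  float score(float[] x) {
    long candidates = ~0L;                             // every leaf starts as a candidate
    for (int f = 0; f < condOffset.length - 1; ++f) {
      for (int c = condOffset[f]; c < condOffset[f + 1]; ++c) {
        if (x[featureIndex[c]] <= threshold[c]) {
          break;                                       // this and all later conditions are true
        }
        candidates &= mask[c];                         // false node: drop its left-subtree leaves
      }
    }
    long window = (leafCount == 64) ? ~0L : ((1L << leafCount) - 1);
    long surviving = candidates & window;
    // exit leaf = leftmost surviving 1 bit; leaf i occupies bit (leafCount - 1 - i)
    int exitLeaf = (leafCount - 1) - (63 - Long.numberOfLeadingZeros(surviving));
    return weight * leafValue[exitLeaf];
  }
}

Because the conditions of each feature are sorted by threshold, the inner loop can stop at the first condition that is still true; this early exit is where most of the speed-up over per-tree traversal comes from.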

III. Code

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;

import java.util.*;

public class MultipleAdditiveTreesModel extends LTRScoringModel {

  // feature name -> feature index (0-based)
  private final HashMap<String, Integer> fname2index = new HashMap();
  private List<RegressionTree> trees;

  private MultipleAdditiveTreesModel.RegressionTree createRegressionTree(Map<String, Object> map) {
    MultipleAdditiveTreesModel.RegressionTree rt = new MultipleAdditiveTreesModel.RegressionTree();
    if (map != null) {
      SolrPluginUtils.invokeSetters(rt, map.entrySet());
    }
    return rt;
  }

  private MultipleAdditiveTreesModel.RegressionTreeNode createRegressionTreeNode(Map<String, Object> map) {
    MultipleAdditiveTreesModel.RegressionTreeNode rtn = new MultipleAdditiveTreesModel.RegressionTreeNode();
    if (map != null) {
      SolrPluginUtils.invokeSetters(rtn, map.entrySet());
    }
    return rtn;
  }

  public void setTrees(Object trees) {
    this.trees = new ArrayList();
    Iterator var2 = ((List) trees).iterator();
    while (var2.hasNext()) {
      Object o = var2.next();
      MultipleAdditiveTreesModel.RegressionTree rt = this.createRegressionTree((Map) o);
      this.trees.add(rt);
    }
  }

  public void setTrees(List<RegressionTree> trees) {
    this.trees = trees;
  }

  public List<RegressionTree> getTrees() {
    return this.trees;
  }

  public MultipleAdditiveTreesModel(String name, List<Feature> features, List<Normalizer> norms,
      String featureStoreName, List<Feature> allFeatures, Map<String, Object> params) {
    super(name, features, norms, featureStoreName, allFeatures, params);
    for (int i = 0; i < features.size(); ++i) {
      String key = ((Feature) features.get(i)).getName();
      this.fname2index.put(key, Integer.valueOf(i)); // feature name -> index
    }
  }

  public void validate() throws ModelException {
    super.validate();
    if (this.trees == null) {
      throw new ModelException("no trees declared for model " + this.name);
    } else {
      Iterator var1 = this.trees.iterator();
      while (var1.hasNext()) {
        MultipleAdditiveTreesModel.RegressionTree tree = (MultipleAdditiveTreesModel.RegressionTree) var1.next();
        tree.validate();
      }
    }
  }

  public float score(float[] modelFeatureValuesNormalized) {
    float score = 0.0F;
    MultipleAdditiveTreesModel.RegressionTree t;
    for (Iterator var3 = this.trees.iterator(); var3.hasNext(); score += t.score(modelFeatureValuesNormalized)) {
      t = (MultipleAdditiveTreesModel.RegressionTree) var3.next();
    }
    return score;
  }

  public Explanation explain(LeafReaderContext context, int doc, float finalScore,
      List<Explanation> featureExplanations) {
    float[] fv = new float[featureExplanations.size()];
    int index = 0;
    for (Iterator details = featureExplanations.iterator(); details.hasNext(); ++index) {
      Explanation featureExplain = (Explanation) details.next();
      fv[index] = featureExplain.getValue();
    }
    ArrayList var12 = new ArrayList();
    index = 0;
    for (Iterator var13 = this.trees.iterator(); var13.hasNext(); ++index) {
      MultipleAdditiveTreesModel.RegressionTree t = (MultipleAdditiveTreesModel.RegressionTree) var13.next();
      float score = t.score(fv);
      Explanation p = Explanation.match(score, "tree " + index + " | " + t.explain(fv), new Explanation[0]);
      var12.add(p);
    }
    return Explanation.match(finalScore, this.toString() + " model applied to features, sum of:", var12);
  }

  public String toString() {
    StringBuilder sb = new StringBuilder(this.getClass().getSimpleName());
    sb.append("(name=").append(this.getName());
    sb.append(",trees=[");
    for (int ii = 0; ii < this.trees.size(); ++ii) {
      if (ii > 0) {
        sb.append(',');
      }
      sb.append(this.trees.get(ii));
    }
    sb.append("])");
    return sb.toString();
  }

  public class RegressionTree {
    private Float weight;
    private MultipleAdditiveTreesModel.RegressionTreeNode root;

    public void setWeight(float weight) {
      this.weight = new Float(weight);
    }

    public void setWeight(String weight) {
      this.weight = new Float(weight);
    }

    public float getWeight() {
      return this.weight;
    }

    public void setRoot(Object root) {
      this.root = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map) root);
    }

    public RegressionTreeNode getRoot() {
      return this.root;
    }

    public float score(float[] featureVector) {
      return this.weight.floatValue() * this.root.score(featureVector);
    }

    public String explain(float[] featureVector) {
      return this.root.explain(featureVector);
    }

    public String toString() {
      StringBuilder sb = new StringBuilder();
      sb.append("(weight=").append(this.weight);
      sb.append(",root=").append(this.root);
      sb.append(")");
      return sb.toString();
    }

    public RegressionTree() {
    }

    public void validate() throws ModelException {
      if (this.weight == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight");
      } else if (this.root == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
      } else {
        this.root.validate();
      }
    }
  }

  public class RegressionTreeNode {
    private static final float NODE_SPLIT_SLACK = 1.0E-6F;

    private float value = 0.0F;
    private String feature;
    private int featureIndex = -1;
    private Float threshold;
    private MultipleAdditiveTreesModel.RegressionTreeNode left;
    private MultipleAdditiveTreesModel.RegressionTreeNode right;

    public void setValue(float value) {
      this.value = value;
    }

    public void setValue(String value) {
      this.value = Float.parseFloat(value);
    }

    public void setFeature(String feature) {
      this.feature = feature;
      Integer idx = (Integer) MultipleAdditiveTreesModel.this.fname2index.get(this.feature);
      this.featureIndex = idx == null ? -1 : idx.intValue();
    }

    public int getFeatureIndex() {
      return this.featureIndex;
    }

    public void setThreshold(float threshold) {
      this.threshold = Float.valueOf(threshold + 1.0E-6F);
    }

    public void setThreshold(String threshold) {
      this.threshold = Float.valueOf(Float.parseFloat(threshold) + 1.0E-6F);
    }

    public float getThreshold() {
      return this.threshold;
    }

    public void setLeft(Object left) {
      this.left = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map) left);
    }

    public RegressionTreeNode getLeft() {
      return this.left;
    }

    public void setRight(Object right) {
      this.right = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map) right);
    }

    public RegressionTreeNode getRight() {
      return this.right;
    }

    public boolean isLeaf() {
      return this.feature == null;
    }

    public float score(float[] featureVector) {
      if (this.isLeaf()) {
        return this.value;
      }
      if (this.featureIndex >= 0 && this.featureIndex < featureVector.length) {
        return featureVector[this.featureIndex] <= this.threshold.floatValue()
            ? this.left.score(featureVector)
            : this.right.score(featureVector);
      }
      return 0.0F;
    }

    public String explain(float[] featureVector) {
      if (this.isLeaf()) {
        return "val: " + this.value;
      } else if (this.featureIndex >= 0 && this.featureIndex < featureVector.length) {
        String rval;
        if (featureVector[this.featureIndex] <= this.threshold.floatValue()) {
          rval = "'" + this.feature + "':" + featureVector[this.featureIndex] + " <= " + this.threshold + ", Go Left | ";
          return rval + this.left.explain(featureVector);
        } else {
          rval = "'" + this.feature + "':" + featureVector[this.featureIndex] + " > " + this.threshold + ", Go Right | ";
          return rval + this.right.explain(featureVector);
        }
      } else {
        return "'" + this.feature + "' does not exist in FV, Return Zero";
      }
    }

    public String toString() {
      StringBuilder sb = new StringBuilder();
      if (this.isLeaf()) {
        sb.append(this.value);
      } else {
        sb.append("(feature=").append(this.feature);
        sb.append(",threshold=").append(this.threshold.floatValue() - 1.0E-6F);
        sb.append(",left=").append(this.left);
        sb.append(",right=").append(this.right);
        sb.append(')');
      }
      return sb.toString();
    }

    public RegressionTreeNode() {
    }

    public void validate() throws ModelException {
      if (this.isLeaf()) {
        if (this.left != null || this.right != null) {
          throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left=" + this.left + " and right=" + this.right);
        }
      } else if (null == this.threshold) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
      } else if (null == this.left) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
      } else {
        this.left.validate();
        if (null == this.right) {
          throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
        } else {
          this.right.validate();
        }
      }
    }
  }
}

import org.apache.commons.lang.ArrayUtils;
import org.apache.lucene.util.CloseableThreadLocal;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.norm.Normalizer;

import java.util.*;

public class QuickScorerTreesModel extends MultipleAdditiveTreesModel {

  private static final long MAX_BITS = 0xFFFFFFFFFFFFFFFFL;

  // 64bits De Bruijn Sequence
  // see: http://chessprogramming.wikispaces.com/DeBruijnsequence#Binary alphabet-B(2, 6)
  private static final long HASH_BITS = 0x022fdd63cc95386dL;

  private static final int[] hashTable = new int[64];

  static {
    long hash = HASH_BITS;
    for (int i = 0; i < 64; ++i) {
      hashTable[(int) (hash >>> 58)] = i;
      hash <<= 1;
    }
  }

  /**
   * Finds the index of rightmost bit with O(1) by using De Bruijn strategy.
   *
   * @param bits target bits (64bits)
   * @see <a href="http://supertech.csail.mit.edu/papers/debruijn.pdf">http://supertech.csail.mit.edu/papers/debruijn.pdf</a>
   */
  private static int findIndexOfRightMostBit(long bits) {
    return hashTable[(int) (((bits & -bits) * HASH_BITS) >>> 58)];
  }

  /**
   * The number of trees of this model.
   */
  private int treeNum;

  /**
   * Weights of each tree.
   */
  private float[] weights;

  /**
   * List of all leaves of this model.
   * We use tree instead of value to manage wide (i.e., more than 64 leaves) trees.
   */
  private RegressionTreeNode[] leaves;

  /**
   * Offsets of each leaf block correspond to each tree.
   */
  private int[] leafOffsets;

  /**
   * The number of conditions of this model.
   */
  private int condNum;

  /**
   * Thresholds of each condition.
   * These thresholds are grouped by corresponding feature and each block is sorted by threshold values.
   */
  private float[] thresholds;

  /**
   * Corresponding featureIndex of each condition.
   */
  private int[] featureIndexes;

  /**
   * Offsets of each condition block correspond to each feature.
   */
  private int[] condOffsets;

  /**
   * Forward bitvectors of each condition which correspond to original additive trees.
   */
  private long[] forwardBitVectors;

  /**
   * Backward bitvectors of each condition which correspond to inverted additive trees.
   */
  private long[] backwardBitVectors;

  /**
   * Mappings from threshold indexes to tree indexes.
   */
  private int[] treeIds;

  /**
   * Bitvectors of each tree for calculating the score.
   * We reuse bitvectors instance in each thread to prevent from re-allocating arrays.
   */
  private CloseableThreadLocal<long[]> threadLocalTreeBitvectors = null;

  /**
   * Boolean statistical tendency of this model.
   * If conditions of the model tend to be false, we use inverted bitvectors for speeding up.
   */
  private volatile float falseRatio = 0.5f;

  /**
   * The decay factor for updating falseRatio in each evaluation step.
   * This factor is used like "{@code ratio = preRatio * decay + ratio * (1 - decay)}".
   */
  private float falseRatioDecay = 0.99f;

  /**
   * Comparable node cost for selecting leaf candidates.
   */
  private static class NodeCost implements Comparable<NodeCost> {
    private final int id;
    private final int cost;
    private final int depth;
    private final int left;
    private final int right;

    private NodeCost(int id, int cost, int depth, int left, int right) {
      this.id = id;
      this.cost = cost;
      this.depth = depth;
      this.left = left;
      this.right = right;
    }

    public int getId() {
      return id;
    }

    public int getLeft() {
      return left;
    }

    public int getRight() {
      return right;
    }

    /**
     * Sorts by cost and depth.
     * We prefer cheaper cost and deeper one.
     */
    @Override
    public int compareTo(NodeCost n) {
      if (cost != n.cost) {
        return Integer.compare(cost, n.cost);
      } else if (depth != n.depth) {
        return Integer.compare(n.depth, depth); // reverse order
      } else {
        return Integer.compare(id, n.id);
      }
    }
  }

  /**
   * Comparable condition for constructing and sorting bitvectors.
   */
  private static class Condition implements Comparable<Condition> {
    private final int featureIndex;
    private final float threshold;
    private final int treeId;
    private final long forwardBitvector;
    private final long backwardBitvector;

    private Condition(int featureIndex, float threshold, int treeId, long forwardBitvector, long backwardBitvector) {
      this.featureIndex = featureIndex;
      this.threshold = threshold;
      this.treeId = treeId;
      this.forwardBitvector = forwardBitvector;
      this.backwardBitvector = backwardBitvector;
    }

    int getFeatureIndex() {
      return featureIndex;
    }

    float getThreshold() {
      return threshold;
    }

    int getTreeId() {
      return treeId;
    }

    long getForwardBitvector() {
      return forwardBitvector;
    }

    long getBackwardBitvector() {
      return backwardBitvector;
    }

    /*
     * Sort by featureIndex and threshold with ascent order.
     */
    @Override
    public int compareTo(Condition c) {
      if (featureIndex != c.featureIndex) {
        return Integer.compare(featureIndex, c.featureIndex);
      } else {
        return Float.compare(threshold, c.threshold);
      }
    }
  }

  /**
   * Base class for traversing node with depth first order.
   */
  private abstract static class Visitor {
    private int nodeId = 0;

    int getNodeId() {
      return nodeId;
    }

    void visit(RegressionTree tree) {
      nodeId = 0;
      visit(tree.getRoot(), 0);
    }

    private void visit(RegressionTreeNode node, int depth) {
      if (node.isLeaf()) {
        doVisitLeaf(node, depth);
      } else {
        // visit children first
        visit(node.getLeft(), depth + 1);
        visit(node.getRight(), depth + 1);
        doVisitBranch(node, depth);
      }
      ++nodeId;
    }

    protected abstract void doVisitLeaf(RegressionTreeNode node, int depth);

    protected abstract void doVisitBranch(RegressionTreeNode node, int depth);
  }

  /**
   * {@link Visitor} implementation for calculating the cost of each node.
   */
  private static class NodeCostVisitor extends Visitor {
    private final Stack<AbstractMap.SimpleEntry<Integer, Integer>> idCostStack = new Stack<>();
    private final PriorityQueue<NodeCost> nodeCostQueue = new PriorityQueue<>();

    PriorityQueue<NodeCost> getNodeCostQueue() {
      return nodeCostQueue;
    }

    @Override
    protected void doVisitLeaf(RegressionTreeNode node, int depth) {
      nodeCostQueue.add(new NodeCost(getNodeId(), 0, depth, -1, -1));
      idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), 1));
    }

    @Override
    protected void doVisitBranch(RegressionTreeNode node, int depth) {
      // calculate the cost of this node from children costs
      final AbstractMap.SimpleEntry<Integer, Integer> rightIdCost = idCostStack.pop();
      final AbstractMap.SimpleEntry<Integer, Integer> leftIdCost = idCostStack.pop();
      final int cost = Math.max(leftIdCost.getValue(), rightIdCost.getValue());
      nodeCostQueue.add(new NodeCost(getNodeId(), cost, depth, leftIdCost.getKey(), rightIdCost.getKey()));
      idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), cost + 1));
    }
  }

  /**
   * {@link Visitor} implementation for extracting leaves and bitvectors.
   */
  private static class QuickScorerVisitor extends Visitor {
    private final int treeId;
    private final int leafNum;
    private final Set<Integer> leafIdSet;
    private final Set<Integer> skipIdSet;

    private final Stack<Long> bitsStack = new Stack<>();
    private final List<RegressionTreeNode> leafList = new ArrayList<>();
    private final List<Condition> conditionList = new ArrayList<>();

    private QuickScorerVisitor(int treeId, int leafNum, Set<Integer> leafIdSet, Set<Integer> skipIdSet) {
      this.treeId = treeId;
      this.leafNum = leafNum;
      this.leafIdSet = leafIdSet;
      this.skipIdSet = skipIdSet;
    }

    List<RegressionTreeNode> getLeafList() {
      return leafList;
    }

    List<Condition> getConditionList() {
      return conditionList;
    }

    private long reverseBits(long bits) {
      long revBits = 0L;
      long mask = (1L << (leafNum - 1));
      for (int i = 0; i < leafNum; ++i) {
        if ((bits & mask) != 0L) revBits |= (1L << i);
        mask >>>= 1;
      }
      return revBits;
    }

    @Override
    protected void doVisitLeaf(RegressionTreeNode node, int depth) {
      if (skipIdSet.contains(getNodeId())) return;
      bitsStack.add(1L << leafList.size()); // we use rightmost bit for detecting leaf
      leafList.add(node);
    }

    @Override
    protected void doVisitBranch(RegressionTreeNode node, int depth) {
      if (skipIdSet.contains(getNodeId())) return;
      if (leafIdSet.contains(getNodeId())) {
        // an endpoint of QuickScorer
        doVisitLeaf(node, depth);
        return;
      }
      final long rightBits = bitsStack.pop(); // bits of false branch
      final long leftBits = bitsStack.pop(); // bits of true branch
      /*
       * NOTE:
       * forwardBitvector = ~leftBits
       * backwardBitvector = ~(reverse(rightBits))
       */
      conditionList.add(
          new Condition(node.getFeatureIndex(), node.getThreshold(), treeId, ~leftBits, ~reverseBits(rightBits)));
      bitsStack.add(leftBits | rightBits);
    }
  }

  public QuickScorerTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
      List<Feature> allFeatures, Map<String, Object> params) {
    super(name, features, norms, featureStoreName, allFeatures, params);
  }

  /**
   * Set falseRatioDecay parameter of this model.
   *
   * @param falseRatioDecay decay parameter for updating falseRatio
   */
  public void setFalseRatioDecay(float falseRatioDecay) {
    this.falseRatioDecay = falseRatioDecay;
  }

  /**
   * @see #setFalseRatioDecay(float)
   */
  public void setFalseRatioDecay(String falseRatioDecay) {
    this.falseRatioDecay = Float.parseFloat(falseRatioDecay);
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public void validate() throws ModelException {
    // validate trees before initializing QuickScorer
    super.validate();
    // initialize QuickScorer with validated trees
    init(getTrees());
  }

  /**
   * Initializes quick scorer with given trees.
   *
   * @param trees base additive trees model
   */
  private void init(List<RegressionTree> trees) {
    this.treeNum = trees.size();
    this.weights = new float[trees.size()];
    this.leafOffsets = new int[trees.size() + 1];
    this.leafOffsets[0] = 0;

    // re-create tree bitvectors
    if (this.threadLocalTreeBitvectors != null) this.threadLocalTreeBitvectors.close();
    this.threadLocalTreeBitvectors = new CloseableThreadLocal<long[]>() {
      @Override
      protected long[] initialValue() {
        return new long[treeNum];
      }
    };

    int treeId = 0;
    List<RegressionTreeNode> leafList = new ArrayList<>();
    List<Condition> conditionList = new ArrayList<>();
    for (RegressionTree tree : trees) {
      // select up to 64 leaves from given tree
      QuickScorerVisitor visitor = fitLeavesTo64bits(treeId, tree);
      // extract leaves and conditions with selected leaf candidates
      visitor.visit(tree);
      leafList.addAll(visitor.getLeafList());
      conditionList.addAll(visitor.getConditionList());
      // update weight, offset and treeId
      this.weights[treeId] = tree.getWeight();
      this.leafOffsets[treeId + 1] = this.leafOffsets[treeId] + visitor.getLeafList().size();
      ++treeId;
    }

    // remap list to array for performance reason
    this.leaves = leafList.toArray(new RegressionTreeNode[0]);

    // sort conditions by ascent order of featureIndex and threshold
    Collections.sort(conditionList);

    // remap information of conditions
    int idx = 0;
    int preFeatureIndex = -1;
    this.condNum = conditionList.size();
    this.thresholds = new float[conditionList.size()];
    this.forwardBitVectors = new long[conditionList.size()];
    this.backwardBitVectors = new long[conditionList.size()];
    this.treeIds = new int[conditionList.size()];
    List<Integer> featureIndexList = new ArrayList<>();
    List<Integer> condOffsetList = new ArrayList<>();
    for (Condition condition : conditionList) {
      this.thresholds[idx] = condition.getThreshold();
      this.forwardBitVectors[idx] = condition.getForwardBitvector();
      this.backwardBitVectors[idx] = condition.getBackwardBitvector();
      this.treeIds[idx] = condition.getTreeId();
      if (preFeatureIndex != condition.getFeatureIndex()) {
        featureIndexList.add(condition.getFeatureIndex());
        condOffsetList.add(idx);
        preFeatureIndex = condition.getFeatureIndex();
      }
      ++idx;
    }
    condOffsetList.add(conditionList.size()); // guard

    this.featureIndexes = ArrayUtils.toPrimitive(featureIndexList.toArray(new Integer[0]));
    this.condOffsets = ArrayUtils.toPrimitive(condOffsetList.toArray(new Integer[0]));
  }

  /**
   * Checks costs of all nodes and select leaves up to 64.
   *
   * <p>NOTE:
   * We can use {@link java.util.BitSet} instead of {@code long} to represent bitvectors longer than 64bits.
   * However, this modification caused performance degradation in our experiments, and we decided to use this form.
   *
   * @param treeId index of given regression tree
   * @param tree target regression tree
   * @return QuickScorerVisitor with proper id sets
   */
  private QuickScorerVisitor fitLeavesTo64bits(int treeId, RegressionTree tree) {
    // calculate costs of all nodes
    NodeCostVisitor nodeCostVisitor = new NodeCostVisitor();
    nodeCostVisitor.visit(tree);

    // poll zero cost nodes (i.e., real leaves)
    Set<Integer> leafIdSet = new HashSet<>();
    Set<Integer> skipIdSet = new HashSet<>();
    while (!nodeCostVisitor.getNodeCostQueue().isEmpty()) {
      if (nodeCostVisitor.getNodeCostQueue().peek().cost > 0) break;
      NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
      leafIdSet.add(nodeCost.id);
    }

    // merge leaves until the number of leaves reaches 64
    while (leafIdSet.size() > 64) {
      final NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
      assert nodeCost.left >= 0 && nodeCost.right >= 0;
      // update leaves
      leafIdSet.remove(nodeCost.left);
      leafIdSet.remove(nodeCost.right);
      leafIdSet.add(nodeCost.id);
      // register previous leaves to skip ids
      skipIdSet.add(nodeCost.left);
      skipIdSet.add(nodeCost.right);
    }
    return new QuickScorerVisitor(treeId, leafIdSet.size(), leafIdSet, skipIdSet);
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public float score(float[] modelFeatureValuesNormalized) {
    assert threadLocalTreeBitvectors != null;
    long[] treeBitvectors = threadLocalTreeBitvectors.get();
    Arrays.fill(treeBitvectors, MAX_BITS);

    int falseNum = 0;
    float score = 0.0f;
    if (falseRatio <= 0.5) {
      // use forward bitvectors
      for (int i = 0; i < condOffsets.length - 1; ++i) {
        final int featureIndex = featureIndexes[i];
        for (int j = condOffsets[i]; j < condOffsets[i + 1]; ++j) {
          if (modelFeatureValuesNormalized[featureIndex] <= thresholds[j]) break;
          treeBitvectors[treeIds[j]] &= forwardBitVectors[j];
          ++falseNum;
        }
      }
      for (int i = 0; i < leafOffsets.length - 1; ++i) {
        final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
        score += weights[i] * leaves[leafOffsets[i] + leafIdx].score(modelFeatureValuesNormalized);
      }
    } else {
      // use backward bitvectors
      falseNum = condNum;
      for (int i = 0; i < condOffsets.length - 1; ++i) {
        final int featureIndex = featureIndexes[i];
        for (int j = condOffsets[i + 1] - 1; j >= condOffsets[i]; --j) {
          if (modelFeatureValuesNormalized[featureIndex] > thresholds[j]) break;
          treeBitvectors[treeIds[j]] &= backwardBitVectors[j];
          --falseNum;
        }
      }
      for (int i = 0; i < leafOffsets.length - 1; ++i) {
        final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
        score += weights[i] * leaves[leafOffsets[i + 1] - 1 - leafIdx].score(modelFeatureValuesNormalized);
      }
    }

    // update false ratio
    falseRatio = falseRatio * falseRatioDecay + (falseNum * 1.0f / condNum) * (1.0f - falseRatioDecay);
    return score;
  }
}

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FeatureException;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.request.SolrQueryRequest;
import org.junit.Ignore;
import org.junit.Test;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;

public class TestQuickScorerTreesModelBenchmark {

  /**
   * Creates dummy features.
   * @param featureNum number of features
   * @return list of dummy features
   */
  private List<Feature> createDummyFeatures(int featureNum) {
    List<Feature> features = new ArrayList<>();
    for (int i = 0; i < featureNum; ++i) {
      features.add(new Feature("fv_" + i, null) {
        @Override
        protected void validate() throws FeatureException {
        }

        @Override
        public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores, SolrQueryRequest request,
            Query originalQuery, Map<String, String[]> efi) throws IOException {
          return null;
        }

        @Override
        public LinkedHashMap<String, Object> paramsToMap() {
          return null;
        }
      });
    }
    return features;
  }

  private List<Normalizer> createDummyNormalizer(int featureNum) {
    List<Normalizer> normalizers = new ArrayList<>();
    for (int i = 0; i < featureNum; ++i) {
      normalizers.add(new IdentityNormalizer());
    }
    return normalizers;
  }

  /**
   * Creates a single random tree (calls itself recursively).
   * @param leafNum number of leaves
   * @param features features
   * @param rand random number generator
   * @return tree as a nested map
   */
  private Map<String, Object> createRandomTree(int leafNum, List<Feature> features, Random rand) {
    Map<String, Object> node = new HashMap<>();
    if (leafNum == 1) {
      // leaf
      node.put("value", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5)
      return node;
    }
    // branch
    node.put("feature", features.get(rand.nextInt(features.size())).getName());
    node.put("threshold", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5)
    node.put("left", createRandomTree(leafNum / 2, features, rand));
    node.put("right", createRandomTree(leafNum - leafNum / 2, features, rand));
    return node;
  }

  /**
   * Creates multiple random trees used as the test model.
   * @param treeNum number of trees
   * @param leafNum number of leaves per tree
   * @param features features
   * @param rand random number generator
   * @return list of trees
   */
  private List<Object> createRandomMultipleAdditiveTrees(int treeNum, int leafNum, List<Feature> features,
      Random rand) {
    List<Object> trees = new ArrayList<>();
    for (int i = 0; i < treeNum; ++i) {
      Map<String, Object> tree = new HashMap<>();
      tree.put("weight", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5), per-tree weight (shrinkage)
      tree.put("root", createRandomTree(leafNum, features, rand));
      trees.add(tree);
    }
    return trees;
  }

  /**
   * Checks that the two scoring models produce identical scores.
   * @param featureNum number of features
   * @param treeNum number of trees
   * @param leafNum number of leaves per tree
   * @param loopNum number of samples
   * @throws Exception
   */
  private void compareScore(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception {
    Random rand = new Random(0);
    List<Feature> features = createDummyFeatures(featureNum);   // create features
    List<Normalizer> norms = createDummyNormalizer(featureNum); // normalizers

    for (int i = 0; i < loopNum; ++i) {
      List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand);

      MultipleAdditiveTreesModel matModel = new MultipleAdditiveTreesModel("multipleadditivetrees", features, norms,
          "dummy", features, null);
      matModel.setTrees(trees);
      matModel.validate();

      QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features,
          null);
      qstModel.setTrees(trees); // set the provided trees
      qstModel.validate();      // validate the trees and build the QuickScorer structures

      float[] featureValues = new float[featureNum];
      for (int j = 0; j < 100; ++j) {
        for (int k = 0; k < featureNum; ++k) {
          featureValues[k] = rand.nextFloat() - 0.5f; // [-0.5, 0.5)
        }
        float expected = matModel.score(featureValues);
        float actual = qstModel.score(featureValues);
        assertThat(actual, is(expected));
        //System.out.println("expected: " + expected + " actual: " + actual);
      }
    }
  }

  /**
   * Checks whether the two models produce the same scores.
   *
   * @throws Exception thrown if testcase failed to initialize models
   */
  /*@Test
  public void testAccuracy() throws Exception {
    compareScore(25, 200, 32, 100);
    //compareScore(19, 500, 31, 10000);
  }*/

  /**
   * Compares the scoring time of the two models.
   * @param featureNum number of features
   * @param treeNum number of trees
   * @param leafNum number of leaves per tree
   * @param loopNum number of samples
   * @throws Exception
   */
  private void compareTime(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception {
    Random rand = new Random(0);
    // randomly create features
    List<Feature> features = createDummyFeatures(featureNum);
    // create normalizers
    List<Normalizer> norms = createDummyNormalizer(featureNum);
    // randomly create trees
    List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand);

    // initialize the multiple additive trees model
    MultipleAdditiveTreesModel matModel = new MultipleAdditiveTreesModel("multipleadditivetrees", features, norms,
        "dummy", features, null);
    matModel.setTrees(trees);
    matModel.validate();

    // initialize the quick scorer trees model
    QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features,
        null);
    qstModel.setTrees(trees);
    qstModel.validate();

    // randomly generate samples, loopNum * featureNum
    float[][] featureValues = new float[loopNum][featureNum];
    for (int i = 0; i < loopNum; ++i) {
      for (int k = 0; k < featureNum; ++k) {
        featureValues[i][k] = rand.nextFloat() * 2.0f - 1.0f; // [-1.0, 1.0)
      }
    }

    long start;
    /*long matOpNsec = 0;
    for (int i = 0; i < loopNum; ++i) {
      start = System.nanoTime();
      matModel.score(featureValues[i]);
      matOpNsec += System.nanoTime() - start;
    }
    long qstOpNsec = 0;
    for (int i = 0; i < loopNum; ++i) {
      start = System.nanoTime();
      qstModel.score(featureValues[i]);
      qstOpNsec += System.nanoTime() - start;
    }
    System.out.println("MultipleAdditiveTreesModel : " + matOpNsec / 1000.0 / loopNum + " usec/op");
    System.out.println("QuickScorerTreesModel : " + qstOpNsec / 1000.0 / loopNum + " usec/op");*/

    long matOpNsec = 0;
    start = System.currentTimeMillis();
    for (int i = 0; i < loopNum; i++) {
      matModel.score(featureValues[i]);
    }
    matOpNsec = System.currentTimeMillis() - start;

    long qstOpNsec = 0;
    start = System.currentTimeMillis();
    for (int i = 0; i < loopNum; i++) {
      qstModel.score(featureValues[i]);
    }
    qstOpNsec = System.currentTimeMillis() - start;

    System.out.println("MultipleAdditiveTreesModel : " + matOpNsec);
    System.out.println("QuickScorerTreesModel : " + qstOpNsec);
    //assertThat(matOpNsec > qstOpNsec, is(true));
  }

  /**
   * Benchmarks the two models.
   * @throws Exception thrown if testcase failed to initialize models
   */
  @Test
  public void testPerformance() throws Exception {
    // features, trees, leaves, samples
    compareTime(20, 500, 61, 10000);
  }
}
