Alink漫谈(十三) :在线学习算法FTRL 之 具体实现

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文和上文一起介绍了在线学习算法 FTRL 在Alink中是如何实现的,希望对大家有所帮助。

0x01 回顾

书接上回 Alink漫谈(十二) :在线学习算法FTRL 之 整体设计 。到目前为止,已经处理完毕输入,接下来就是在线训练。训练优化的主要目标是找到一个方向,参数朝这个方向移动之后使得损失函数的值能够减小,这个方向往往由一阶偏导或者二阶偏导各种组合求得。

为了让大家更好理解,我们再次贴出整体流程图:

0x02 在线训练

在线训练主要逻辑是:

  • 1)加载初始化模型到 dataBridge;dataBridge = DirectReader.collect(model);
  • 2)获取相关参数。比如vectorSize默认是30000,是否 hasInterceptItem;
  • 3)获取切分信息。splitInfo = getSplitInfo(featureSize, hasInterceptItem, parallelism); 下面马上会用到。
  • 4)切分高维向量。初始化数据做了特征哈希,会产生高维向量,这里需要进行切割。 initData.flatMap(new SplitVector(splitInfo, hasInterceptItem, vectorSize,vectorTrainIdx, featureIdx, labelIdx));
  • 5)构建一个 IterativeStream.ConnectedIterativeStreams iteration,这样会构建(或者说连接)两个数据流:反馈流和训练流;
  • 6)用iteration来构建迭代体 iterativeBody,其包括两部分:CalcTask,ReduceTask;
    • 6.1)CalcTask分成两个部分。flatMap1 是分布计算FTRL迭代需要的predict,flatMap2 是FTRL的更新参数部分;
    • 6.2)ReduceTask分为两个功能:“归并这些predict计算结果“ / ”如果满足条件则归并模型 & 向下游算子输出模型“;
  • 7)result = iterativeBody.filter;基本是以时间间隔为标准来判断(也可以认为是时间驱动),"时间未过期&向量有意义" 的数据将被发送回反馈数据流,继续迭代,回到步骤 6),进入flatMap2
  • 8)output = iterativeBody.filter;符合标准(时间过期了)的数据将跳出迭代,然后算法会调用WriteModel将LineModelData转换为多条Row,转发给下游operator(也就是在线预测阶段);即定时把模型更新给在线预测阶段

2.1 预置模型

前面说到,FTRL先要训练出一个逻辑回归模型作为FTRL算法的初始模型,这是为了系统冷启动的需要。

2.1.1 训练模型

具体逻辑回归模型设定/训练是 :

// train initial batch model
LogisticRegressionTrainBatchOp lr = new LogisticRegressionTrainBatchOp()
.setVectorCol(vecColName)
.setLabelCol(labelColName)
.setWithIntercept(true)
.setMaxIter(10);
BatchOperator<?> initModel = featurePipelineModel.transform(trainBatchData).link(lr);

训练好之后,模型信息是DataSet类型,位于变量 BatchOperator<?> initModel之中,这是一个批处理算子。

2.1.2 加载模型

FtrlTrainStreamOp将initModel作为初始化参数。

FtrlTrainStreamOp model = new FtrlTrainStreamOp(initModel)

在FtrlTrainStreamOp构造函数中会加载这个模型;

dataBridge = DirectReader.collect(initModel);

具体加载时通过MemoryDataBridge直接获取初始化模型DataSet中的数据。

public MemoryDataBridge generate(BatchOperator batchOperator, Params globalParams) {
return new MemoryDataBridge(batchOperator.collect());
}

2.2 分割高维向量

从前文可知,Alink的FTRL算法设置的特征向量维度是30000。所以算法第一步就是切分高维度向量,以便分布式计算。

String vecColName = "vec";
int numHashFeatures = 30000;

首先要获取切分信息,代码如下,就是将特征数目featureSize 除以 并行度parallelism,然后得到了每个task对应系数的初始位置。

private static int[] getSplitInfo(int featureSize, boolean hasInterceptItem, int parallelism) {
int coefSize = (hasInterceptItem) ? featureSize + 1 : featureSize;
int subSize = coefSize / parallelism;
int[] poses = new int[parallelism + 1];
int offset = coefSize % parallelism;
for (int i = 0; i < offset; ++i) {
poses[i + 1] = poses[i] + subSize + 1;
}
for (int i = offset; i < parallelism; ++i) {
poses[i + 1] = poses[i] + subSize;
}
return poses;
}
//程序运行时变量如下
featureSize = 30000
hasInterceptItem = true
parallelism = 4
coefSize = 30001
subSize = 7500
poses = {int[5]@11660}
0 = 0
1 = 7501
2 = 15001
3 = 22501
4 = 30001
offset = 1

然后根据切分信息对高维向量进行切割。

// Tuple5<SampleId, taskId, numSubVec, SubVec, label>
DataStream<Tuple5<Long, Integer, Integer, Vector, Object>> input
= initData.flatMap(new SplitVector(splitInfo, hasInterceptItem, vectorSize,
vectorTrainIdx, featureIdx, labelIdx))
.partitionCustom(new CustomBlockPartitioner(), 1);

具体切分在SplitVector.flatMap函数完成,结果就是把一个高维度向量分割给各个CalcTask

代码摘要如下:

public void flatMap(Row row, Collector<Tuple5<Long, Integer, Integer, Vector, Object>> collector) throws Exception {
long sampleId = counter;
counter += parallelism;
Vector vec;
if (vectorTrainIdx == -1) {
.....
} else {
// 输入row的第vectorTrainIdx个field就是那个30000大小的系数向量
vec = VectorUtil.getVector(row.getField(vectorTrainIdx));
} if (vec instanceof SparseVector) {
Map<Integer, Vector> tmpVec = new HashMap<>();
for (int i = 0; i < indices.length; ++i) {
.....
// 此处迭代完成后,tmpVec中就是task number个元素,每一个元素是分割好的系数向量。
}
for (Integer key : tmpVec.keySet()) {
//此处遍历,给后面所有CalcTask发送五元组数据。
collector.collect(Tuple5.of(sampleId, key, subNum, tmpVec.get(key), row.getField(labelIdx)));
}
} else {
......
}
}
}

这个Tuple5.of(sampleId, key, subNum, tmpVec.get(key), row.getField(labelIdx) )就是后面CalcTask的输入。

2.3 迭代训练

此处理论上有以下几个重点:

  • 预测方法:在每一轮t中,针对特征样本xt,以及迭代后(第一次则是给定初值)的模型参数wt,我们可以预测该样本的标记值:pt=σ(wt,xt),其中σ(a)=1/(1+exp(−a))是一个sigmoid函数。

  • 损失函数:对一个特征样本xt,其对应的标记为yt ∈ 0,1,则通过 logistic loss 来作为损失函数。

  • 迭代公式:我们的目的是使得损失函数尽可能的小,即可以采用极大似然估计来求解参数。首先求梯度,然后使用FTRL进行迭代。

伪代码思路大致如下

double p = learner.predict(x); //预测
learner.updateModel(x, p, y); //更新模型
double loss = LogLossEvalutor.calLogLoss(p, y); //计算损失
evalutor.addLogLoss(loss); //更新损失
totalLoss += loss;
trainedNum += 1;

具体实施上Alink有自己的特点和调整。

2.3.1 Flink Stream迭代功能

机器学习都需要迭代训练,Alink这里利用了Flink Stream的迭代功能

IterativeStream的实例是通过DataStream的iterate方法创建的˙。iterate方法存在两个重载形式:

  • 一种是无参的,表示不限定最大等待时间;
  • 一种提供一个长整型maxWaitTimeMillis参数,允许用户指定等待反馈边的下一个输入元素的最大时间间隔。

Alink选择了第二种。

在创建ConnectedIterativeStreams时候,用迭代流的初始输入作为第一个输入流,用反馈流作为第二个输入

每一种数据流(DataStream)都会有与之对应的流转换(StreamTransformation)。IterativeStream对应的转换是FeedbackTransformation。

迭代流(IterativeStream)对应的转换是反馈转换(FeedbackTransformation),它表示拓扑中的一个反馈点(也即迭代头)。一个反馈点包含一个输入边以及若干个反馈边,且Flink要求每个反馈边的并行度必须跟输入边的并行度一致,这一点在往该转换中加入反馈边时会进行校验。

当IterativeStream对象被构造时,FeedbackTransformation的实例会被创建并传递给DataStream的构造方法。

迭代的关闭是通过调用IterativeStream的实例方法closeWith来实现的。这个函数指定了某个流将成为迭代程序的结束,并且这个流将作为输入的第二部分(second input)被反馈回迭代。

2.3.2 迭代构建

对于Alink来说,迭代构建代码是:

// train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label>
// feedback format = Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps>
IterativeStream.ConnectedIterativeStreams<
Tuple5<Long, Integer, Integer, Vector, Object>,
Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
iteration = input.iterate(Long.MAX_VALUE)
.withFeedbackType(TypeInformation
.of(new TypeHint<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {})); // 即iteration是一个 IterativeStream.ConnectedIterativeStreams<...>
2.3.2.1 迭代的输入

从代码和注释可以看出,迭代的两种输入是:

  • train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label>;这种其实是训练数据
  • Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps>;这种其实是反馈数据,就是“迭代的反馈流”作为这个第二输入 (second input);
2.3.2.2 迭代的反馈

反馈流的设置是通过调用IterativeStream的实例方法closeWith来实现的。Alink这里是

DataStream<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
result = iterativeBody.filter(
return (t3.f0 > 0 && t3.f2 > 0); // 这里是省略版本代码
); iteration.closeWith(result);

前面已经提到过,result filter 的判断是 return (t3.f0 > 0 && t3.f2 > 0)如果满足条件,则说明时间未过期&向量有意义,所以此时应该反馈回去,继续训练

反馈流的格式是:

  • Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps>;

2.3.3 迭代体 CalcTask / ReduceTask

迭代体由两部分构成:CalcTask / ReduceTask。

CalcTask每一个实例都拥有初始化模型dataBridge

DataStream iterativeBody = iteration.flatMap(
new CalcTask(dataBridge, splitInfo, getParams()))
2.3.3.1 迭代初始化

迭代是由 CalcTask.open 函数开始,主要做如下几件事

  • 设定各种参数,比如

    • 工作task个数,numWorkers = getRuntimeContext().getNumberOfParallelSubtasks();
    • 本task的id,workerId = getRuntimeContext().getIndexOfThisSubtask();
  • 读取初始化模型
    • List modelRows = DirectReader.directRead(dataBridge);
    • 把Row类型数据转换为线性模型 LinearModelData model = new LinearModelDataConverter().load(modelRows);
  • 读取本task对应的系数 coef[i - startIdx],这里就是把整个模型切分到numWorkers这么多的Task中,并行更新
  • 指定本task的开始时间 startTime = System.currentTimeMillis();
2.3.3.2 处理输入数据

CalcTask.flatMap1主要实现的是FTRL算法中的predict部分(注意,不是FTRL预测)。

解释:pt=σ(Xt⋅w)是LR的预测函数,求出pt的唯一目的是为了求出目标函数(在LR中采用交叉熵损失函数作为目标函数)对参数w的一阶导数g,gi=(pt−yt)xi。此步骤同样适用于FTRL优化其他目标函数,唯一的不同就是求次梯度g(次梯度是左导和右导之间的集合,函数可导--左导等于右导时,次梯度就等于一阶梯度)的方法不同。

函数的输入是 "训练输入数据",即SplitVector.flatMap的输出 ----> CalcCalcTask的输入。输入数据是一个五元组,其格式为 train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label>;

有三点需要注意:

  • 是如果是第一次进入,则需要savedFristModel;
  • 这里是有输入就处理,然后立即输出(和flatMap2不同,flatMap2有输入就处理,但不是立即输出,而是当时间到期了再输出);
  • predict的实现:((SparseVector)vec).getValues()[i] * coef[indices[i] - startIdx];

大家会说,不对!predict函数应该是 sigmoid = 1.0 / (1.0 + np.exp(-w.dot(x)))。是的,这里还没有做 sigmoid 操作。当ReduceTask做了聚合之后,会把聚合好的 p 反馈回迭代体,然后在 CalcTask.flatMap2 中才会做 sigmoid 操作

public void flatMap1(Tuple5<Long, Integer, Integer, Vector, Object> value,
Collector<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> out) throws Exception {
if (!savedFristModel) { //第一次进入需要存模型
out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(),
new DenseVector(coef), labelValues, -1.0, modelId++));
savedFristModel = true;
}
Long timeStamps = System.currentTimeMillis();
double wx = 0.0;
Long sampleId = value.f0;
Vector vec = value.f3;
if (vec instanceof SparseVector) {
int[] indices = ((SparseVector)vec).getIndices();
// 这里就是具体的Predict
for (int i = 0; i < indices.length; ++i) {
wx += ((SparseVector)vec).getValues()[i] * coef[indices[i] - startIdx];
}
} else {
......
}
//处理了就输出
out.collect(Tuple7.of(sampleId, value.f1, value.f2, value.f3, value.f4, wx, timeStamps));
}
2.3.3.3 归并数据

ReduceTask.flatMap 负责归并数据。

public static class ReduceTask extends
RichFlatMapFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>,
Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> {
private int parallelism;
private int[] poses;
private Map<Long, List<Object>> buffer;
private Map<Long, List<Tuple2<Integer, DenseVector>>> models = new HashMap<>();
}

flatMap函数大致完成如下功能,即两种归并:

  • 为了输出模型使用。判断是否时间过期 if (value.f0 < 0),如果过期,则归并模型

    • 生成一个List<Tuple2<Integer, DenseVector>> model = models.get(value.f6); 以value.f6,即时间戳为key,插入到HashMap中。
    • 如果全部收集完成,则向下游算子输出模型,并且从HashMap中删除暂存的模型。
  • 为了归并predict使用。归并每个CalcTask计算的predict,形成一个 lable y;
    • 用 label y 更新 Tuple7的f5,即Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps> 中的 label,也就是预测的 y。
    • 给每个下游算子(就是每个CalcTask了,不过是作为flatMap2的输入)发送这个新Tuple7;

当具体用作输出模型使用时,其变量如下:

models = {HashMap@13258}  size = 1
{Long@13456} 1 -> {ArrayList@13678} size = 1
key = {Long@13456} 1
value = {ArrayList@13678} size = 1
0 = {Tuple2@13698} "(1,0.0 -8.244533295515879E-5 0.0 -1.103997743166529E-4 0.0 -3.336931546279811E-5....."
2.3.3.4 判断是否反馈

这个 filter result 是用来判断是否反馈的。这里t3.f0 是sampleId, t3.f2是subNum。

DataStream<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
result = iterativeBody.filter(
new FilterFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {
@Override
public boolean filter(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> t3)
throws Exception {
// if t3.f0 > 0 && t3.f2 > 0 then feedback
return (t3.f0 > 0 && t3.f2 > 0);
}
});

对于 t3.f0,有两处代码会设置为负值。

  • 会在savedFirstModel 这里设置一次"-1";即

    if (!savedFristModel) {
    out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(),
    new DenseVector(coef), labelValues, -1.0, modelId++));
    savedFristModel = true;
    }
  • 也会在时间过期时候设置为 "-1"。

    if (System.currentTimeMillis() - startTime > modelSaveTimeInterval) {
    startTime = System.currentTimeMillis();
    out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(),
    new DenseVector(coef), labelValues, -1.0, modelId++));
    }

对于 t3.f2,如果 subNum 大于零,说明在高维向量切分时候,是得到了有意义的数值。

因此 return (t3.f0 > 0 && t3.f2 > 0) 说明时间未过期&向量有意义,所以此时应该反馈回去,继续训练。

2.3.3.5 判断是否输出模型

这里是filter output。

value.f0 < 0 说明时间到期了,应该输出模型。

DataStream<Row> output = iterativeBody.filter(
new FilterFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {
@Override
public boolean filter(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value)
{
/* if value.f0 small than 0, then output */
return value.f0 < 0;
}
}).flatMap(new WriteModel(labelType, getVectorCol(), featureCols, hasInterceptItem));
2.3.3.6 处理反馈数据/更新参数

CalcTask.flatMap2实际完成的是FTRL算法的其余部分,即更新参数部分。主要逻辑如下:

  • 计算时间间隔 timeInterval = System.currentTimeMillis() - value.f6;
  • 正式计算predict, p = 1 / (1 + Math.exp(-p)); 即sigmoid 操作;
  • 计算梯度 g = (p - label) * values[i] / Math.sqrt(timeInterval); 这里除以了时间间隔;
  • 更新参数;
  • 输入。注意,这里是有输入就处理,但 不是立即输出,而是累积参数,当时间到期了再输出,也就是做到了定期输出模型;

Logistic Regression 中,sigmoid函数是σ(a) = 1 / (1 + exp(-a)) ,预估 pt = σ(xt . wt), 则 LogLoss 函数是

\[l_t(w_t) = -y_t log(p_t) - (1-y_t)log(1-p_t)
\]

直接计算可以得到

\[∇l(w) = (σ(w.x_t) - y_t)x_t = (p_t - y_t)x_t
\]

具体 LR + FTRL 算法实现如下:

@Override
public void flatMap2(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value,
Collector<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> out)
throws Exception {
double p = value.f5;
// 计算时间间隔
long timeInterval = System.currentTimeMillis() - value.f6;
Vector vec = value.f3; /* eta */
// 正式计算predict,之前只是计算了一半,这里计算后半部,即
p = 1 / (1 + Math.exp(-p));
..... if (vec instanceof SparseVector) {
// 这里是更新参数
int[] indices = ((SparseVector)vec).getIndices();
double[] values = ((SparseVector)vec).getValues(); for (int i = 0; i < indices.length; ++i) {
// update zParam nParam
int id = indices[i] - startIdx;
// values[i]是xi
// 下面的计算基本和Google伪代码一致
double g = (p - label) * values[i] / Math.sqrt(timeInterval);
double sigma = (Math.sqrt(nParam[id] + g * g) - Math.sqrt(nParam[id])) / alpha;
zParam[id] += g - sigma * coef[id];
nParam[id] += g * g; // update model coefficient
if (Math.abs(zParam[id]) <= l1) {
coef[id] = 0.0;
} else {
coef[id] = ((zParam[id] < 0 ? -1 : 1) * l1 - zParam[id])
/ ((beta + Math.sqrt(nParam[id]) / alpha + l2));
}
}
} else {
......
} // 当时间到期了再输出,即做到了定期输出模型
if (System.currentTimeMillis() - startTime > modelSaveTimeInterval) {
startTime = System.currentTimeMillis();
out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(),
new DenseVector(coef), labelValues, -1.0, modelId++));
}
}

2.4 输出模型

WriteModel 类实现了输出模型功能,大致逻辑如下:

  • 生成一个LinearModelData,用训练好的Tuple7来填充这个 LinearModelData。其中两个重要点:

    • modelData.coefVector = (DenseVector)value.f3;
    • modelData.labelValues = (Object[])value.f4;
  • 把模型数据转换成List rows。LinearModelDataConverter().save(modelData, listCollector);
  • 序列化,发送给下游算子。因为模型可能会很大,所以这里打散之后分布发送给下游算子
public void flatMap(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value, Collector<Row> out){

//输入value变量打印如下:
value = {Tuple7@13296}
f0 = {Long@13306} -1
f1 = {Integer@13307} 0
f2 = {Integer@13308} 2
f3 = {DenseVector@13309} "-0.7383426732137565 0.0 0.0 0.0 1.5885293675862715E-4 -4.834608575902742E-5 0.0 0.0 -6.754208708318647E-5 ......"
data = {double[30001]@13314}
f4 = {Object[2]@13310}
f5 = {Double@13311} -1.0
f6 = {Long@13312} 0 //生成模型
LinearModelData modelData = new LinearModelData();
......
modelData.coefVector = (DenseVector)value.f3;
modelData.labelValues = (Object[])value.f4; //把模型数据转换成List<Row> rows
RowCollector listCollector = new RowCollector();
new LinearModelDataConverter().save(modelData, listCollector);
List<Row> rows = listCollector.getRows(); for (Row r : rows) {
int rowSize = r.getArity();
for (int j = 0; j < rowSize; ++j) {
.....
//序列化
}
out.collect(row);
} iter++;
}
}

0x03 在线预测

预测功能是在 FtrlPredictStreamOp 完成的。

// ftrl predict
FtrlPredictStreamOp predictResult = new FtrlPredictStreamOp(initModel)
.setVectorCol(vecColName)
.setPredictionCol("pred")
.setReservedCols(new String[]{labelColName})
.setPredictionDetailCol("details")
.linkFrom(model, featurePipelineModel.transform(splitter.getSideOutput(0)));

从上面代码我们可以看到

  • FtrlPredict 功能同样需要初始模型 initModel,我们也是把逻辑回归模型赋予它。这样也是为了冷启动,即当FTRL训练模块还没有产生模型之前,FTRL预测模块也是可以对其输入数据做预测的。
  • model 是 FtrlTrainStreamOp 的输出,即 FTRL 的训练输出。所以 WriteModel 就直接把输出传给了 FtrlPredict功能。
  • splitter.getSideOutput(0) 这里是前面提到的测试输入,就是测试数据集。

linkFrom函数完成了业务逻辑,大致功能如下:

  • 使用 inputs[0].getDataStream().flatMap ------> partition ----> map ----> flatMap(new CollectModel()) 得到了模型 LinearModelData modelstr;
  • 使用 DataStream.connect 把输入的测试数据集 和 模型 LinearModelData modelstr关联起来,这样每个task都拥有了在线模型 modelstr,就可以通过 flatMap(new PredictProcess(...) 进行分布式预测;
  • 使用 setOutputTable 和 LinearModelMapper 把预测结果输出;

FTRL的预测功能有三个输入

  • 初始模型 initModel ----->  最后被 PredictProcess.open 加载,作为冷启动的预测模型;
  • 测试数据流 -----> 被 PredictProcess.flatMap1处理,进行预测;
  • FTRL训练阶段产生的模型数据流 ----> 被 PredictProcess.flatMap2 处理,进行在线模型更新;

3.1 初始化

构造函数中完成了初始化,即获取事先训练好的逻辑回归模型。

public FtrlPredictStreamOp(BatchOperator model) {
super(new Params());
if (model != null) {
dataBridge = DirectReader.collect(model);
} else {
throw new IllegalArgumentException("Ftrl algo: initial model is null. Please set a valid initial model.");
}
}

3.2 获取在线训练模型

CollectModel完成了 获取在线训练模型 功能。

其逻辑主要是:模型被分成若干块,其中 (long)inRow.getField(1) 这里记录了具体有多少块。所以 flatMap 函数会把这些块累积起来,最后组装成模型,统一发送给下游算子

具体是通过一个 HashMap<> buffers 来完成临时拼装/最后组装的。

public static class CollectModel implements FlatMapFunction<Row, LinearModelData> {

    private Map<Long, List<Row>> buffers = new HashMap<>(0);

    @Override
public void flatMap(Row inRow, Collector<LinearModelData> out) throws Exception { // 输入参数如下
inRow = {Row@13389} "0,19,0,{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"},null"
fields = {Object[5]@13405}
0 = {Long@13406} 0
1 = {Long@13403} 19
2 = {Long@13406} 0
3 = "{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"}" long id = (long)inRow.getField(0);
Long nTab = (long)inRow.getField(1); Row row = new Row(inRow.getArity() - 2); for (int i = 0; i < row.getArity(); ++i) {
row.setField(i, inRow.getField(i + 2));
} if (buffers.containsKey(id) && buffers.get(id).size() == nTab.intValue() - 1) {
buffers.get(id).add(row);
// 如果累积完成,则组装成模型
LinearModelData ret = new LinearModelDataConverter().load(buffers.get(id));
buffers.get(id).clear();
// 发送给下游算子。
out.collect(ret);
} else {
if (buffers.containsKey(id)) {
//如果有key。则往list添加。
buffers.get(id).add(row);
} else {
// 如果没有key,则添加list
List<Row> buffer = new ArrayList<>(0);
buffer.add(row);
buffers.put(id, buffer);
}
}
}
} //变量类似这种
this = {FtrlPredictStreamOp$CollectModel@13388}
buffers = {HashMap@13393} size = 1
{Long@13406} 0 -> {ArrayList@13431} size = 2
key = {Long@13406} 0
value = 0
value = {ArrayList@13431} size = 2
0 = {Row@13409} "0,{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"},null"
1 = {Row@13471} "1048576,{"featureColNames":null,"featureColTypes":null,"coefVector":{"data":[-0.7383426732137549,0.0,0.0,0.0,1.5885293675862704E-4,-4.834608575902738E-5,0.0,0.0,-6.754208708318643E-5,-1.5904172331763155E-4,0.0,-1.315219790338925E-4,0.0,-4.994749246390495E-4,0.0,2.755456604395511E-4,-9.616429481614131E-4,-9.601054004112163E-5,0.0,-1.6679174640370486E-4,0.0,......"

3.3 在线预测

PredictProcess 完成了在线预测功能,LinearModelMapper 是具体预测实现。

public static class PredictProcess extends RichCoFlatMapFunction<Row, LinearModelData, Row> {
private LinearModelMapper predictor = null;
private String modelSchemaJson;
private String dataSchemaJson;
private Params params;
private int iter = 0;
private DataBridge dataBridge;
}

3.3.1 加载预设置模型

其构造函数获得了 FtrlPredictStreamOp 类的 dataBridge,即事先训练好的逻辑回归模型。每一个Task都拥有完整的模型。

open函数会加载逻辑回归模型。

public void open(Configuration parameters) throws Exception {
this.predictor = new LinearModelMapper(TableUtil.fromSchemaJson(modelSchemaJson),
TableUtil.fromSchemaJson(dataSchemaJson), this.params);
if (dataBridge != null) {
// read init model
List<Row> modelRows = DirectReader.directRead(dataBridge);
LinearModelData model = new LinearModelDataConverter().load(modelRows);
this.predictor.loadModel(model);
}
}

3.3.2 在线预测

FtrlPredictStreamOp.flatMap1 函数完成了在线预测。

public void flatMap1(Row row, Collector<Row> collector) throws Exception {
collector.collect(this.predictor.map(row));
}

调用栈如下:

predictWithProb:157, LinearModelMapper (com.alibaba.alink.operator.common.linear)
predictResultDetail:114, LinearModelMapper (com.alibaba.alink.operator.common.linear)
map:90, RichModelMapper (com.alibaba.alink.common.mapper)
flatMap1:174, FtrlPredictStreamOp$PredictProcess (com.alibaba.alink.operator.stream.onlinelearning)
flatMap1:143, FtrlPredictStreamOp$PredictProcess (com.alibaba.alink.operator.stream.onlinelearning)
processElement1:53, CoStreamFlatMap (org.apache.flink.streaming.api.operators.co)
processRecord1:135, StreamTwoInputProcessor (org.apache.flink.streaming.runtime.io)

具体是通过 LinearModelMapper 完成。

public abstract class RichModelMapper extends ModelMapper {
public Row map(Row row) throws Exception {
if (isPredDetail) {
// 我们的示例代码在这里
Tuple2<Object, String> t2 = predictResultDetail(row);
return this.outputColsHelper.getResultRow(row, Row.of(t2.f0, t2.f1));
} else {
return this.outputColsHelper.getResultRow(row, Row.of(predictResult(row)));
}
}
}

预测代码如下,可以看出来使用了sigmoid。

/**
* Predict the label information with the probability of each label.
*/
public Tuple2 <Object, Double[]> predictWithProb(Vector vector) {
double dotValue = MatVecOp.dot(vector, model.coefVector);
switch (model.linearModelType) {
case LR:
case SVM:
double prob = sigmoid(dotValue);
return new Tuple2 <>(dotValue >= 0 ? model.labelValues[0] : model.labelValues[1],
new Double[] {prob, 1 - prob});
}
}

3.3.3 在线更新模型

FtrlPredictStreamOp.flatMap2 函数完成了处理在线训练输出的模型数据流,在线更新模型。

LinearModelData参数是由CollectModel完成加载并且传输出来的。

在模型加载过程中,是不能预测的,没有看到相关保护机制。如果我疏漏请大家指出。

public void flatMap2(LinearModelData linearModel, Collector<Row> collector) throws Exception {
this.predictor.loadModel(linearModel);
}

0x04 问题解答

针对之前我们提出的问题,现在总结归纳如下:

  • 训练阶段和预测阶段都有预制模型以应对"冷启动"嘛?都有预制模型
  • 训练阶段和预测阶段是如何关联起来的?用 linkFrom 直接把训练阶段和预测阶段的算子连在一起
  • 如何把训练出来的模型传给预测阶段?训练阶段用 Flink collector.collect 把模型发给下游算子
  • 输出模型时候,模型过大怎么处理?在线训练会 模型打散 之后分布发送给下游算子
  • 在线训练的模型通过什么机制实现更新?是定时驱动更新嘛?定时更新
  • 预测阶段加载模型过程中,还可以预测嘛?有没有机制保证这段时间内也能预测?目前没有发现类似保护机制
  • 训练阶段中,有哪些阶段用到了并行处理?训练过程中主要是FTRL算法的"预测predict" 和 "更新参数"两个部分,以及发送模型
  • 预测阶段中,有哪些阶段用到了并行处理?预测过程中主要是分布式接受模型和分布式预测
  • 遇到高维向量如何处理?切分开嘛?切分处理

0xFF 参考

【机器学习】逻辑回归(非常详细)

逻辑回归(logistics regression)

【机器学习】LR的分布式(并行化)实现

并行逻辑回归

机器学习算法及其并行化讨论

Online LR—— FTRL 算法理解

在线优化算法 FTRL 的原理与实现

LR+FTRL算法原理以及工程化实现

Flink流处理之迭代API分析

FTRL公式推导

FTRL论文笔记

在线机器学习FTRL(Follow-the-regularized-Leader)算法介绍

FTRL代码实现

FTRL实战之LR+FTRL(代码采用的稠密数据)

在线学习算法FTRL-Proximal原理

基于FTRL的在线CTR预测算法

CTR预测算法之FTRL-Proximal

各大公司广泛使用的在线学习算法FTRL详解

在线最优化求解(Online Optimization)之五:FTRL

FOLLOW THE REGULARIZED LEADER (FTRL) 算法总结

Alink漫谈(十三) :在线学习算法FTRL 之 具体实现的更多相关文章

  1. Alink漫谈(十二) :在线学习算法FTRL 之 整体设计

    Alink漫谈(十二) :在线学习算法FTRL 之 整体设计 目录 Alink漫谈(十二) :在线学习算法FTRL 之 整体设计 0x00 摘要 0x01概念 1.1 逻辑回归 1.1.1 推导过程 ...

  2. 各大公司广泛使用的在线学习算法FTRL详解

    各大公司广泛使用的在线学习算法FTRL详解 现在做在线学习和CTR常常会用到逻辑回归( Logistic Regression),而传统的批量(batch)算法无法有效地处理超大规模的数据集和在线数据 ...

  3. 各大公司广泛使用的在线学习算法FTRL详解 - EE_NovRain

    转载请注明本文链接:http://www.cnblogs.com/EE-NovRain/p/3810737.html 现在做在线学习和CTR常常会用到逻辑回归( Logistic Regression ...

  4. 广告点击率预测(CTR) —— 在线学习算法FTRL的应用

    FTRL由google工程师提出,在13的paper中给出了伪代码和实现细节,paper地址:http://www.eecs.tufts.edu/~dsculley/papers/ad-click-p ...

  5. 在线优化算法 FTRL 的原理与实现

    在线学习想要解决的问题 在线学习 ( \(\it{Online \;Learning}\) ) 代表了一系列机器学习算法,特点是每来一个样本就能训练,能够根据线上反馈数据,实时快速地进行模型调整,使得 ...

  6. Alink漫谈(一) : 从KMeans算法实现不同看Alink设计思想

    Alink漫谈(一) : 从KMeans算法实现不同看Alink设计思想 目录 Alink漫谈(一) : 从KMeans算法实现不同看Alink设计思想 0x00 摘要 0x01 Flink 是什么 ...

  7. Bandit:一种简单而强大的在线学习算法

    假设我有5枚硬币,都是正反面不均匀的.我们玩一个游戏,每次你可以选择其中一枚硬币掷出,如果掷出正面,你将得到一百块奖励.掷硬币的次数有限(比如10000次),显然,如果要拿到最多的利益,你要做的就是尽 ...

  8. Alink漫谈(二) : 从源码看机器学习平台Alink设计和架构

    Alink漫谈(二) : 从源码看机器学习平台Alink设计和架构 目录 Alink漫谈(二) : 从源码看机器学习平台Alink设计和架构 0x00 摘要 0x01 Alink设计原则 0x02 A ...

  9. Alink漫谈(十一) :线性回归 之 L-BFGS优化

    Alink漫谈(十一) :线性回归 之 L-BFGS优化 目录 Alink漫谈(十一) :线性回归 之 L-BFGS优化 0x00 摘要 0x01 回顾 1.1 优化基本思路 1.2 各类优化方法 0 ...

随机推荐

  1. myeclipse集成jad反编译步骤

    (1) 将jad.exe放到java的jdk\bin目录下(2) 将jadeclipse插件net.sf.jadclipse_3.3.0.jar 拷贝到myeclipse安装目\Common\plug ...

  2. STL初步学习(map)

    3.map map作为一个映射,有两个参数,第一个参数作为关键值,第二个参数为对应的值,关键值是唯一的 在平时使用的数组中,也有点类似于映射的方法,例如a[10]=1,但其实我们的关键值和对应的值只能 ...

  3. css 分割线样式_css实现文章分割线的多种方法总结

    这篇文章整理css如何实现文章分割线的多种方式,分割线在页面中可以起到美化作用,那么就来看看使用css实现分割线样式的多种方法.效果如下: 方式一:单个标签实现分隔线: html: <div c ...

  4. scrapy框架携带cookie访问淘宝购物车

    我们知道,有的网页必须要登录才能访问其内容.scrapy登录的实现一般就三种方式. 1.在第一次请求中直接携带用户名和密码. 2.必须要访问一次目标地址,服务器返回一些参数,例如验证码,一些特定的加密 ...

  5. Django---进阶1

    目录 静态文件配置 request对象方法初识 pycharm链接数据库(MySQL) django链接数据库(MySQL) Django ORM 字段的增删改查 数据的增删改查 今日作业 静态文件配 ...

  6. Django setting设置 常用设置

    目录 Django配置文件基本设置 前言 setting配置汇总 一.APP路径 二.数据库配置 三.sql语句展示 四.静态文件目录 五.media文件配置 六.数据库中的UserInfo(用户表) ...

  7. pigctf期末测评

    pigctf期末测评 MISC 1 拿到图片,先binwalk一下,如下图 果然发现png图片后面跟了个ZIP,然后提取出来打开发现了一个flag.png,然后查看16进制文件没有发现什么问题,之后查 ...

  8. Scala 基础(十四):Scala 模式匹配(二)

    1 匹配数组 1)Array(0) 匹配只有一个元素且为0的数组. 2)Array(x,y) 匹配数组有两个元素,并将两个元素赋值为x和y.当然可以依次类推Array(x,y,z) 匹配数组有3个元素 ...

  9. Kubernetes部署通用手册 (支持版本1.19,1.18,1.17,1.16)

    Kubernetes平台环境规划 操作环境 rbac 划分(HA高可用双master部署实例) 本文穿插了ha 高可用部署的实例,当前章节设计的是ha部署双master 部署 内网ip 角色 安装软件 ...

  10. 深度剖析分布式单点登录框架XXL-SSO

    于2018年初,在github上创建XXL-SSO项目仓库并提交第一个commit,随之进行系统结构设计,UI选型,交互设计-- 于2018年初,在github上创建XXL-SSO项目仓库并提交第一个 ...