Alink漫谈(二十一) :回归评估之源码分析
Alink漫谈(二十一) :回归评估之源码分析
0x00 摘要
Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文和将带领大家来分析Alink中 回归评估 的实现。
这是剖析Alink以来,最轻松的一次了。因为这里的概念和实现逻辑都非常清晰。
0x01 背景概念
1.1 功能介绍
回归评估是对回归算法的预测结果进行效果评估,支持下列评估指标。这些指标基本都是统计领域概念。
1.2 具体指标
Alink 提供如下指标:
count 行数
SST 总平方和(Sum of Squared for Total),度量了Y在样本中的分散程度。
\]
SSE 误差平方和(Sum of Squares for Error),度量了总样本变异。
\]
SSR 回归平方和(Sum of Squares for Regression),度量了残差的样本变异。
\]
R^2 判定系数(Coefficient of Determination),用于估计回归方程是否很好的拟合了样本的数据,判定系数为估计的回归方程提供了一个拟合优度的度量。
\]
R 多重相关系数(Multiple Correlation Coeffient),指一个随机变量与某一组随机变量间线性相依性的度量。
\]
MSE 均方误差(Mean Squared Error),均方差(标准差)、方差都是用来描述数据集的离散程度。
均方误差是衡量“平均误差”的一种较方便的方法,可以评价数据的变化程度。从类别来看属于预测评价与预测组合;从字面上看来,“均”指的是平均,即求其平均值,“方差”即是在概率论中用来衡量随机变量和其估计值(其平均值)之间的偏离程度的度量值,“误”可以理解为测定值与真实值之间的误差。
\]
RMSE 均方根误差(Root Mean Squared Error)
\]
SAE/SAD 绝对误差(Sum of Absolute Error/Difference)
\]
MAE/MAD 平均绝对误差(Mean Absolute Error/Difference)
\]
MAPE 平均绝对百分误差(Mean Absolute Percentage Error)
\]
explained variance 解释方差
\]
0x02 示例代码
直接拿出来Alink的示例代码。
public class EvalRegressionBatchOpExp {
public static void main(String[] args) throws Exception {
Row[] data =
new Row[] {
Row.of(0.4, 0.5),
Row.of(0.3, 0.5),
Row.of(0.2, 0.6),
Row.of(0.6, 0.7),
Row.of(0.1, 0.5)
};
MemSourceBatchOp input = new MemSourceBatchOp(data, new String[] {"label", "pred"});
RegressionMetrics metrics = new EvalRegressionBatchOp()
.setLabelCol("label")
.setPredictionCol("pred")
.linkFrom(input)
.collectMetrics();
System.out.println(metrics.getRmse());
System.out.println(metrics.getR2());
System.out.println(metrics.getSse());
System.out.println(metrics.getMape());
System.out.println(metrics.getMae());
System.out.println(metrics.getSsr());
System.out.println(metrics.getSst());
}
}
输出为:
0.27568097504180444
-1.5675675675675653
0.38
141.66666666666669
0.24
0.31999999999999973
0.14800000000000013
0x03 总体逻辑
总体逻辑是:
- 调用 CalcLocal 进行分区计算各种统计数值;
- reduce 调用 ReduceBaseMetrics 进行归并各种统计数值;
- 调用 SaveDataAsParams 存储;
getLabelCol 就是 y,getPredictionCol 就是 y_hat。
public EvalRegressionBatchOp linkFrom(BatchOperator<?>... inputs) {
BatchOperator in = checkAndGetFirst(inputs);
// 这里就是找到y, y_hat
TableUtil.findColIndexWithAssertAndHint(in.getColNames(), this.getLabelCol());
TableUtil.findColIndexWithAssertAndHint(in.getColNames(), this.getPredictionCol());
// 利用y, y_hat来构建Metrics
TableUtil.assertNumericalCols(in.getSchema(), this.getLabelCol(), this.getPredictionCol());
DataSet<Row> out = in.select(new String[] {this.getLabelCol(), this.getPredictionCol()})
.getDataSet()
.rebalance()
.mapPartition(new CalcLocal())
.reduce(new EvaluationUtil.ReduceBaseMetrics())
.flatMap(new EvaluationUtil.SaveDataAsParams());
this.setOutputTable(DataSetConversionUtil.toTable(getMLEnvironmentId(),
out, new TableSchema(new String[] {"regression_eval_result"}, new TypeInformation[] {Types.STRING})
));
return this;
}
0x04 分区计算统计数值
调用 CalcLocal 进行分区计算各种统计数值,间接调用getRegressionStatistics。
/**
* Get the label sum, predResult sum, SSE, MAE, MAPE of one partition.
*/
public static class CalcLocal implements MapPartitionFunction<Row, BaseMetricsSummary> {
@Override
public void mapPartition(Iterable<Row> rows, Collector<BaseMetricsSummary> collector)
throws Exception {
collector.collect(getRegressionStatistics(rows));
}
}
getRegressionStatistics作用是遍历输入数据,在本Partition内部计算各种累积数值,为后续做准备。
/**
* Calculate the RegressionMetrics from local data.
*
* @param rows Input rows, the first field is label value, the second field is prediction value.
* @return RegressionMetricsSummary.
*/
public static RegressionMetricsSummary getRegressionStatistics(Iterable<Row> rows) {
RegressionMetricsSummary regressionSummary = new RegressionMetricsSummary();
for (Row row : rows) {
if (checkRowFieldNotNull(row)) {
double yVal = ((Number)row.getField(0)).doubleValue();
double predictVal = ((Number)row.getField(1)).doubleValue();
double diff = Math.abs(yVal - predictVal);
regressionSummary.ySumLocal += yVal;
regressionSummary.ySum2Local += yVal * yVal;
regressionSummary.predSumLocal += predictVal;
regressionSummary.predSum2Local += predictVal * predictVal;
regressionSummary.maeLocal += diff;
regressionSummary.sseLocal += diff * diff;
regressionSummary.mapeLocal += Math.abs(diff / yVal);
regressionSummary.total++;
}
}
return regressionSummary.total == 0 ? null : regressionSummary;
}
0x05 归并统计数值
reduce 调用 ReduceBaseMetrics 进行归并各种统计数值:
/**
* Merge the BaseMetrics calculated locally.
*/
public static class ReduceBaseMetrics implements ReduceFunction<BaseMetricsSummary> {
@Override
public BaseMetricsSummary reduce(BaseMetricsSummary t1, BaseMetricsSummary t2) throws Exception {
return null == t1 ? t2 : t1.merge(t2);
}
}
0x06 存储模型
这里调用SaveDataAsParams来存储模型。
/**
* After merging all the BaseMetrics, we get the total BaseMetrics. Calculate the indexes and save them into params.
*/
public static class SaveDataAsParams implements FlatMapFunction<BaseMetricsSummary, Row> {
@Override
public void flatMap(BaseMetricsSummary t, Collector<Row> collector) throws Exception {
collector.collect(t.toMetrics().serialize());
}
}
0x07 toMetrics
最后呈现出统计指标。
public RegressionMetrics toMetrics() {
Params params = new Params();
params.set(RegressionMetrics.SST, ySum2Local - ySumLocal * ySumLocal / total);
params.set(RegressionMetrics.SSE, sseLocal);
params.set(RegressionMetrics.SSR,
predSum2Local - 2 * ySumLocal * predSumLocal / total + ySumLocal * ySumLocal / total);
params.set(RegressionMetrics.R2, 1 - params.get(RegressionMetrics.SSE) / params.get(RegressionMetrics.SST));
params.set(RegressionMetrics.R, Math.sqrt(params.get(RegressionMetrics.R2)));
params.set(RegressionMetrics.MSE, params.get(RegressionMetrics.SSE) / total);
params.set(RegressionMetrics.RMSE, Math.sqrt(params.get(RegressionMetrics.MSE)));
params.set(RegressionMetrics.SAE, maeLocal);
params.set(RegressionMetrics.MAE, params.get(RegressionMetrics.SAE) / total);
params.set(RegressionMetrics.COUNT, (double)total);
params.set(RegressionMetrics.MAPE, mapeLocal * 100 / total);
params.set(RegressionMetrics.Y_MEAN, ySumLocal / total);
params.set(RegressionMetrics.PREDICTION_MEAN, predSumLocal / total);
params.set(RegressionMetrics.EXPLAINED_VARIANCE, params.get(RegressionMetrics.SSR) / total);
return new RegressionMetrics(params);
}
最后得到结果
params = {Params@9098} "Params {R2=-1.5675675675675693, predictionMean=0.5599999999999999, SSE=0.38, count=5.0, MAPE=141.66666666666666, RMSE=0.27568097504180444, MAE=0.24, R=NaN, SSR=0.3200000000000002, yMean=0.32, SST=0.1479999999999999, SAE=1.2, Explained Variance=0.06400000000000003, MSE=0.076}"
params = {HashMap@9101} size = 14
"R2" -> "-1.5675675675675693"
"predictionMean" -> "0.5599999999999999"
"SSE" -> "0.38"
"count" -> "5.0"
"MAPE" -> "141.66666666666666"
"RMSE" -> "0.27568097504180444"
"MAE" -> "0.24"
"R" -> "NaN"
"SSR" -> "0.3200000000000002"
"yMean" -> "0.32"
"SST" -> "0.1479999999999999"
"SAE" -> "1.2"
"Explained Variance" -> "0.06400000000000003"
"MSE" -> "0.076"
0xFF 参考
Alink漫谈(二十一) :回归评估之源码分析的更多相关文章
- TeamTalk源码分析(十一) —— pc客户端源码分析
--写在前面的话 在要不要写这篇文章的纠结中挣扎了好久,就我个人而已,我接触windows编程,已经六七个年头了,尤其是在我读研的三年内,基本心思都是花在学习和研究windows程序上 ...
- Java I/O系列(二)ByteArrayInputStream与ByteArrayOutputStream源码分析及理解
1. ByteArrayInputStream 定义 继承了InputStream,数据源是内置的byte数组buf,那read ()方法的使命(读取一个个字节出来),在ByteArrayInputS ...
- equals和==方法比较(二)--Long中equals源码分析
接上篇,分析equals方法在Long包装类中的重写,其他类及我们自定义的类,同样可以根据需要重新equals方法. equals方法定义 equals方法是Object类中的方法,java中所有的对 ...
- ABP源码分析一:整体项目结构及目录
ABP是一套非常优秀的web应用程序架构,适合用来搭建集中式架构的web应用程序. 整个Abp的Infrastructure是以Abp这个package为核心模块(core)+15个模块(module ...
- Solr4.8.0源码分析(19)之缓存机制(二)
Solr4.8.0源码分析(19)之缓存机制(二) 前文<Solr4.8.0源码分析(18)之缓存机制(一)>介绍了Solr缓存的生命周期,重点介绍了Solr缓存的warn过程.本节将更深 ...
- ConcurrenHashMap源码分析(二)
本篇博客的目录: 一:put方法源码 二:get方法源码 三:rehash的过程 四:总结 一:put方法的源码 首先,我们来看一下segment内部类中put方法的源码,这个方法它是segment片 ...
- 【集合框架】JDK1.8源码分析HashSet && LinkedHashSet(八)
一.前言 分析完了List的两个主要类之后,我们来分析Set接口下的类,HashSet和LinkedHashSet,其实,在分析完HashMap与LinkedHashMap之后,再来分析HashSet ...
- redis源码分析之发布订阅(pub/sub)
redis算是缓存界的老大哥了,最近做的事情对redis依赖较多,使用了里面的发布订阅功能,事务功能以及SortedSet等数据结构,后面准备好好学习总结一下redis的一些知识点. 原文地址:htt ...
- heapster源码分析——kubelet的api调用分析
一.heapster简介 什么是Heapster? Heapster是容器集群监控和性能分析工具,天然的支持Kubernetes和CoreOS.Kubernetes有个出名的监控agent---cAd ...
- redis源码分析之有序集SortedSet
有序集SortedSet算是redis中一个很有特色的数据结构,通过这篇文章来总结一下这块知识点. 原文地址:http://www.jianshu.com/p/75ca5a359f9f 一.有序集So ...
随机推荐
- 静态分析工具及使用总结(二)CheckStyle
这里主要介绍三种开源的工具,PMD.CheckStyle和FindBugs,着重是在Ant里的调用,据说商业软件JTest也是著名的代码分析工具,哈哈,要花钱的没有用过. Checkstyle (ht ...
- Xdebug+Phpstorm本地调试
很久不用php进行开发, debug插件的安装与配置都忘完了, 看了下自己之前记录的一篇文章, 有点太乱了, 这里简约介绍下,方便后面快捷使用 XDebug下载地址: https://xdebug.o ...
- js返回的字符串中添加空格
labelFormatter: function() { return `${this.name}\xa0\xa0\xa0${this.y}%`; } 使用"\xa0"
- Business Object 开发
一 什么是BO BO(Business Object),封装在数据库之上,用于直接操作数据(增.删.改.查) 针对不同的BO,在安装目录下有对应的DLL文件,其中封装了BO各式针对具体的业务的方法, ...
- 使用AES加密时,结果不一样
使用AES加密时,发现得到的结果不一致. python示例 from Crypto.Cipher import AES from Crypto.Util.Padding import pad from ...
- 从底层源码深入分析Spring的IoC容器初始化过程
IOC容器的初始化整体过程 Spring是如何实现将资源配置(以xml配置为例)通过加载,解析,生成BeanDefination并注册到IoC容器中的?这主要会经过以下 4 步: 从XML中读取配置文 ...
- nodejs koa2 ocr识别 身份证信息
1. 安装依赖 npm install baidu-aip-sdk 2.创建AipOcrClient 注:需要到百度api创建应用,拿到所需的APPID/AK/SK https://console.b ...
- 【javaweb】【Session】记录用户访问时间
效果 Servlet import jakarta.servlet.*; import jakarta.servlet.http.*; import jakarta.servlet.annotatio ...
- 【shell】远程执行shell|多节点并行执行shell|远程执行注意
目录 前提条件 shell远程执行 多节点上并行执行命令的三种方法 方法1 使用bash执行命令 方法2 使用clustershell执行命令--还能收集结果 方法3 使用pdsh 执行命令 远程执行 ...
- Qt编写的项目作品6-可视化大屏电子看板系统
一.功能特点 采用分层设计,整体总共分三级界面,一级界面是整体布局,二级界面是单个功能模块,三级界面是单个控件. 子控件包括饼图.圆环图.曲线图.柱状图.柱状分组图.横向柱状图.横向柱状分组图.合格率 ...