基于Spark的GBDT + LR模型实现

测试数据来源http://archive.ics.uci.edu/ml/machine-learning-databases/adult/

该模型利用Spark mllib的GradientBoostedTrees作为GBDT部分,因为ml模块的GBTClassifier对所生成的模型做了相当严密的封装,导致难以获取某些类或方法。而GradientBoostedTrees所需的训练数据为mllib下的LabeledPoint,所以下面的数据预处理的目标是将cat数据进行编码并生成LabeledPoint。

数据预处理部分

import org.apache.spark.mllib.linalg.{SparseVector => OldSparseVector}
import org.apache.spark.sql.functions._
import spark.implicits._ val path = "" val manualSchema = StructType(Array(
StructField("age", IntegerType, true),
StructField("workclass", StringType, true),
StructField("fnlwgt", IntegerType, true),
StructField("education", StringType, true),
StructField("education-num", IntegerType, true),
StructField("marital-status", StringType, true),
StructField("occupation", StringType, true),
StructField("relationship", StringType, true),
StructField("race", StringType, true),
StructField("sex", StringType, true),
StructField("capital-gain", IntegerType, true),
StructField("capital-loss", IntegerType, true),
StructField("hours-per-week", IntegerType, true),
StructField("native-country", StringType, true),
StructField("label", StringType, true))) val df = spark.read
.option("header", false)
.option("delimiter", ",")
.option("nullValue", "?")
.schema(manualSchema)
.format("csv")
.load(path + "adult.data.txt")
// .limit(1000) // 去掉代表序列号的col
var df1 = df.drop("fnlwgt")
.na.drop() val allFeature = df1.columns.dropRight(1) // colName和index的映射
val colIdx = new util.HashMap[String, Int](allFeature.length)
var idx = 0
while (idx < allFeature.length){
colIdx.put(allFeature(idx), idx)
idx += 1
} val numCols = Array("age", "education-num", "capital-gain", "capital-loss", "hours-per-week")
val catCols = df1.columns.dropRight(1).diff(numCols)
val numLen = numCols.length
val catLen = catCols.length // 处理label
def labeludf(elem: String):Int = {
if (elem == "<=50K") 0
else 1
} val labelIndexer = udf(labeludf(_:String):Int) // 也可以用 when 函数
// val labelIndexer = when($"lable" === "<=50K", 0).otherwise(1) df1 = df1.withColumn("indexed_label", labelIndexer(col("label"))).drop("label") // 处理cat列
// 所有cat列统一编码,例如有两列cat,第一列为性别,第二列为早、午、晚,那么第一列的编码为0或1,而第二列的编码为2、3或4。下面实现仿照StringIndexer,可能更高效
val inderMap: util.HashMap[String, util.HashMap[String, Int]] = new util.HashMap(catCols.length)
var i = numCols.length
for (column <- catCols) {
val uniqueElem = df1.select(column)
.groupBy(column)
.agg(count(column))
.select(column)
.map(_.getAs[String](0))
.collect() val len = uniqueElem.length
var index = 0
val freqMap = new util.HashMap[String, Int](len) while (index < len) {
freqMap.put(uniqueElem(index), i)
index += 1
i += 1
}
inderMap.put(column, freqMap)
} val bcMap = spark.sparkContext.broadcast(inderMap) val d = i // 合并为LabeledPoint
val df2 = df1.rdd.map { row =>
val indics = new Array[Int](numLen + catLen)
val value = new Array[Double](numLen + catLen)
var i = 0
for (col <- numCols) {
indics(i) = i
value(i) = row.getAs[Int](colIdx.get(col)).toDouble
i += 1
} for (col <- catCols) {
indics(i) = bcMap.value.get(col).get(row.getAs[String](colIdx.get(col)))
value(i) = 1
i += 1
} new LabeledPoint(row.getAs[Int](numLen + catLen), new OldSparseVector(d, indics, value))
}
val ds = df2.toDF("label", "feature")
ds.write.save(path + "processed")

GBDT模型部分(省略调参部分)

val path = ""
val df = spark.read
.load(path)
.rdd
.map(row => LabeledPoint(row.getAs[Double](0), row.getAs[OldSparseVector](1))) // Train a GradientBoostedTrees model.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 10
boostingStrategy.treeStrategy.numClasses = 2
boostingStrategy.treeStrategy.maxDepth = 3
boostingStrategy.learningRate = 0.3
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() val model = GradientBoostedTrees.train(df, boostingStrategy) model.save(spark.sparkContext, path + "GBDTmodel")

GBDT与LR混合部分

object GBTLRTraining {

  // 遍历一颗决策树,找出其出口的叶子节点id
def predictModify(node: OldNode, features: OldSparseVector): Int = {
val split = node.split
if (node.isLeaf) {
node.id - 1 // 改为0-base
} else {
if (split.get.featureType == FeatureType.Continuous) {
if (features(split.get.feature) <= split.get.threshold) {
predictModify(node.leftNode.get, features)
} else {
predictModify(node.rightNode.get, features)
}
} else {
if (split.get.categories.contains(features(split.get.feature))) {
predictModify(node.leftNode.get, features)
} else {
predictModify(node.rightNode.get, features)
}
}
}
} // 获取每棵树的出口叶子节点id数组
def getGBTFeatures(gbtModel: GradientBoostedTreesModel, oldFeatures: OldSparseVector): Array[Int] = {
val GBTMaxIter = gbtModel.trees.length
val leafIdArray = new Array[Int](GBTMaxIter)
for (i <- 0 until GBTMaxIter) {
val treePredict = predictModify(gbtModel.trees(i).topNode, oldFeatures)
leafIdArray(i) = treePredict
}
leafIdArray
} def main(args: Array[String]): Unit = { val spark = SparkSession
.builder()
.master("local[*]")
.appName("TEST")
// 本地配置
.config("spark.sql.shuffle.partitions", 12)
.config("spark.default.parallelism", 12)
.config("spark.memory.fraction", 0.75)
// .config("spark.memory.ofHeap.enabled", true)
// .config("spark.memory.ofHeapa.size", "2G")
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
// .config("spark.executor.memory", "2G")
.getOrCreate() spark.sparkContext.setLogLevel("ERROR") import org.apache.spark.sql.functions._
import spark.implicits._ val path = ""
val df = spark.read
.load(path) val model = GradientBoostedTreesModel.load(spark.sparkContext, path + "GBDTmodel") val bcmodel = spark.sparkContext.broadcast(model) var treeNodeNum = 0
var treeDepth = 0
// 获取最大的树的数据
for (elem <- model.trees) {
if (treeNodeNum < elem.numNodes){
treeNodeNum = elem.numNodes
treeDepth = elem.depth
}
} val leafNum = math.pow(2, treeDepth).toInt
val nonLeafNum = treeNodeNum - leafNum
val totalColNum = leafNum * model.trees.length // print(leafNum + " " + nonLeafNum + " " + tree.numNodes + " " + totalColNum) // 利用之前训练好的GBT模型进行特征提取,并把原特征OldSparseVector转化为ml的SparseVector,让后续的LR使用
val addFeatureUDF = udf { features: OldSparseVector =>
val gbtFeatures = getGBTFeatures(bcmodel.value, features)
var i = 0
while (i < gbtFeatures.length){
val leafIdx = gbtFeatures(i) - nonLeafNum
// 有些树可能没有生长完全,leafIdx没有达到最大的树的最后一层,这里将这些情况默认为最大的树的最后一层的第一个叶子节点。
gbtFeatures(i) = (if (leafIdx < 0) 0 else leafIdx) + i * leafNum
i += 1
}
val idx = gbtFeatures
val values = Array.fill[Double](idx.length)(1.0)
Vectors.sparse(totalColNum, idx, values)
} val dsWithCombinedFeatures = df
.withColumn("lr_feature", addFeatureUDF(col("feature"))) // dsWithCombinedFeatures.show(false) val lr = new LogisticRegression()
.setMaxIter(500)
.setFeaturesCol("lr_feature")
.setLabelCol("label") val lrmodel = lr.fit(dsWithCombinedFeatures) val res = lrmodel.transform(dsWithCombinedFeatures) // res.show(false) val evaluator1 = new MulticlassClassificationEvaluator().setMetricName("accuracy")
.setLabelCol("label")
.setPredictionCol("prediction")
println("ACC:" + evaluator1.evaluate(res)) val evaluator2 = new BinaryClassificationEvaluator().setMetricName("areaUnderROC")
.setLabelCol("label")
.setRawPredictionCol("prediction")
println("AUC:" + evaluator2.evaluate(res)) }
}

参考资料:

https://github.com/wzhe06/CTRmodel

基于Spark的GBDT + LR模型实现的更多相关文章

  1. GBDT+LR算法解析及Python实现

    1. GBDT + LR 是什么 本质上GBDT+LR是一种具有stacking思想的二分类器模型,所以可以用来解决二分类问题.这个方法出自于Facebook 2014年的论文 Practical L ...

  2. ctr中的GBDT+LR的优点

    1 为什么gbdt+lr优于gbdt? 其实gbdt+lr类似于做了一个stacking.gbdt+lr模型中,把gbdt的叶子节点作为lr的输入,而gbdt的叶子节点相当于它的输出y',用这个y'作 ...

  3. 客户流失?来看看大厂如何基于spark+机器学习构建千万数据规模上的用户留存模型 ⛵

    作者:韩信子@ShowMeAI 大数据技术 ◉ 技能提升系列:https://www.showmeai.tech/tutorials/84 行业名企应用系列:https://www.showmeai. ...

  4. 在Java Web中使用Spark MLlib训练的模型

    PMML是一种通用的配置文件,只要遵循标准的配置文件,就可以在Spark中训练机器学习模型,然后再web接口端去使用.目前应用最广的就是基于Jpmml来加载模型在javaweb中应用,这样就可以实现跨 ...

  5. 基于spark邮件自动分类

    代码放在github上:click me 一.数据说明 数据集为英文语料集,一共包含20种类别的邮件,除了类别soc.religion.christian的邮件数为997以外每个类别的邮件数都是100 ...

  6. 基于Spark ALS构建商品推荐引擎

    基于Spark ALS构建商品推荐引擎   一般来讲,推荐引擎试图对用户与某类物品之间的联系建模,其想法是预测人们可能喜好的物品并通过探索物品之间的联系来辅助这个过程,让用户能更快速.更准确的获得所需 ...

  7. 大数据实时处理-基于Spark的大数据实时处理及应用技术培训

    随着互联网.移动互联网和物联网的发展,我们已经切实地迎来了一个大数据 的时代.大数据是指无法在一定时间内用常规软件工具对其内容进行抓取.管理和处理的数据集合,对大数据的分析已经成为一个非常重要且紧迫的 ...

  8. spark概念、编程模型和模块概述

    http://blog.csdn.net/pipisorry/article/details/50931274 spark基本概念 Spark一种与 Hadoop 相似的通用的集群计算框架,通过将大量 ...

  9. 基于Spark自动扩展scikit-learn (spark-sklearn)(转载)

    转载自:https://blog.csdn.net/sunbow0/article/details/50848719 1.基于Spark自动扩展scikit-learn(spark-sklearn)1 ...

随机推荐

  1. [luoguP1489] 猫狗大战(DP)

    传送门 类似背包的做法. f[i][j]表示是否能放i个物品,价格为j #include <cstdio> #include <iostream> #define N 8001 ...

  2. Codeforces Round #264 (Div. 2) D

    题意: 给出最多5个序列,问这几个序列的最长公共子序列的长度是多少. solution : 脑抽级别我是,第一个序列每个数字位置固定,这样只要维护一个k-1维的偏序集就好了.然后在保证每个位置合法的情 ...

  3. 【HDOJ6146】Pokémon GO(DP,计数)

    题意:一个2*n的矩阵,从任意一格出发,不重复且不遗漏地走遍所有格子,问方案数 mo 10^9+7 n<=10000 思路:因为OEIS搜出来的两个数列都是错误的,所以考虑DP 设B[i]为2* ...

  4. tyvj1045 最大的算式

    描述 题目很简单,给出N个数字,不改变它们的相对位置,在中间加入K个乘号和N-K-1个加号,(括号随便加)使最终结果尽量大.因为乘号和加号一共就是N-1个了,所以恰好每两个相邻数字之间都有一个符号.例 ...

  5. 【Tomcat】tomcat启动后查看运行时JVM参数

    Tomcat优化配置参考http://www.cnblogs.com/qlqwjy/p/8007490.html 1.启动服务后访问localhost,点击Server Status

  6. linux 用户管理、权限管理

    1.useradd -[ugGdsce]2.passwd 用户名 ================================================ 1.chmod 2.chown 3. ...

  7. noip 2011

    铺地毯 题目描述 为了准备一个独特的颁奖典礼,组织者在会场的一片矩形区域(可看做是平面直角坐标系的第一象限)铺上一些矩形地毯.一共有 n 张地毯,编号从 1 到n .现在将这些地毯按照编号从小到大的顺 ...

  8. ZOJ3953 ZJU2017校赛(贪心)

    题意:给出n个区间,求至少删掉多少个区间使得不存在区间a, b, c 两两相交    (定义两个区间相交是,区间[l1, r1]和区间[l2, r2]相交,当且仅当存在一个数x,l1<=x< ...

  9. javaweb开发页面数字过长显示科学计数法的问题

    1. 检查该字段是否为double类型,如果是,请改成BigDecimal 2.如果是导出excel里面为科学计数法,原页面正常,是因为excel设置的原因,请参考https://jingyan.ba ...

  10. mysql性能调优——锁优化

    影响mysql server性能的相关因素 需求和架构及业务实现优化:55% Query语句优化:30% 数据库自身优化:15% 很多时候大家看到数据库应用系统中性能瓶颈出现在数据库方面,就希望通过数 ...