下面代码按照之前参加Kaggle的python代码改写,只完成了模型的训练过程,还需要对test集的数据进行转换和对test集进行预测。

scala 2.11.12

spark 2.2.2

package ML.Titanic

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.ml.feature.QuantileDiscretizer
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.OneHotEncoder
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.{TrainValidationSplit, TrainValidationSplitModel}
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.types._ /**
* GBTClassifier for predicting survival in the Titanic ship
*/
object TitanicChallenge { def main(args: Array[String]) { val spark = SparkSession.builder.
master("local[*]")
.appName("example")
.config("spark.sql.shuffle.partitions", 20)
.config("spark.default.parallelism", 20)
.config("spark.driver.memory", "4G")
.config("spark.memory.fraction", 0.75)
.getOrCreate()
val sc = spark.sparkContext spark.sparkContext.setLogLevel("ERROR") val schemaArray = StructType(Array(
StructField("PassengerId", IntegerType, true),
StructField("Survived", IntegerType, true),
StructField("Pclass", IntegerType, true),
StructField("Name", StringType, true),
StructField("Sex", StringType, true),
StructField("Age", FloatType, true),
StructField("SibSp", IntegerType, true),
StructField("Parch", IntegerType, true),
StructField("Ticket", StringType, true),
StructField("Fare", FloatType, true),
StructField("Cabin", StringType, true),
StructField("Embarked", StringType, true)
)) val path = "Titanic/"
val df = spark.read
.option("header", "true")
.schema(schemaArray)
.csv(path + "train.csv")
.drop("PassengerId")
// df.cache() val utils = new TitanicChallenge(spark)
val df2 = utils.transCabin(df)
val df3 = utils.transTicket(sc, df2)
val df4 = utils.transEmbarked(df3)
val df5 = utils.extractTitle(sc, df4)
val df6 = utils.transAge(sc, df5)
val df7 = utils.categorizeAge(df6)
val df8 = utils.createFellow(df7)
val df9 = utils.categorizeFellow(df8)
val df10 = utils.extractFName(df9)
val df11 = utils.transFare(df10) val prePipelineDF = df11.select("Survived", "Pclass", "Sex",
"Age_categorized", "fellow_type", "Fare_categorized",
"Embarked", "Cabin", "Ticket",
"Title", "family_type") // prePipelineDF.show(1)
// +--------+------+----+---------------+-----------+----------------+--------+-----+------+-----+-----------+
// |Survived|Pclass| Sex|Age_categorized|fellow_type|Fare_categorized|Embarked|Cabin|Ticket|Title|family_type|
// +--------+------+----+---------------+-----------+----------------+--------+-----+------+-----+-----------+
// | 0| 3|male| 3.0| Small| 0.0| S| U| 0| Mr| 0|
// +--------+------+----+---------------+-----------+----------------+--------+-----+------+-----+-----------+ val (df_indexed, colsTrain) = utils.index_onehot(prePipelineDF)
df_indexed.cache() //训练模型
val validatorModel = utils.trainData(df_indexed, colsTrain) //打印最优模型的参数
val bestModel = validatorModel.bestModel
println(bestModel.asInstanceOf[PipelineModel].stages.last.extractParamMap) //打印各模型的成绩和参数
val paramsAndMetrics = validatorModel.validationMetrics
.zip(validatorModel.getEstimatorParamMaps)
.sortBy(-_._1)
paramsAndMetrics.foreach { case (metric, params) =>
println(metric)
println(params)
println()
} validatorModel.write.overwrite().save(path + "Titanic_gbtc") spark.stop()
}
} class TitanicChallenge(private val spark: SparkSession) extends Serializable { import spark.implicits._ //Cabin,用“U”填充null,并提取Cabin的首字母
def transCabin(df: Dataset[Row]): Dataset[Row] = {
df.na.fill("U", Seq("Cabin"))
.withColumn("Cabin", substring($"Cabin", 0, 1))
} //
def transTicket(sc: SparkContext, df: Dataset[Row]): Dataset[Row] = { ////提取船票的号码,如“A/5 21171”中的21171
val medDF1 = df.withColumn("Ticket", split($"Ticket", " "))
.withColumn("Ticket", $"Ticket"(size($"Ticket").minus(1)))
.filter($"Ticket" =!= "LINE")//去掉某种特殊的船票 //对船票号进行分类,小于四位号码的为“1”,四位号码的以第一个数字开头,后面接上“0”,大于4位号码的,取前三个数字开头。如21171变为211
val ticketTransUdf = udf((ticket: String) => {
if (ticket.length < 4) {
"1"
} else if (ticket.length == 4){
ticket(0)+"0"
} else {
ticket.slice(0, 3)
}
})
val medDF2 = medDF1.withColumn("Ticket", ticketTransUdf($"Ticket")) //将数量小于等于5的类别统一归为“0”。先统计小于5的名单,然后用udf进行转换。
val filterList = medDF2.groupBy($"Ticket").count()
.filter($"count" <= 5)
.map(row => row.getString(0))
.collect.toList val filterList_bc = sc.broadcast(filterList) val ticketTransAdjustUdf = udf((subticket: String) => {
if (filterList_bc.value.contains(subticket)) "0"
else subticket
}) medDF2.withColumn("Ticket", ticketTransAdjustUdf($"Ticket"))
} //用“S”填充null
def transEmbarked(df: Dataset[Row]): Dataset[Row] = {
df.na.fill("S", Seq("Embarked"))
} def extractTitle(sc: SparkContext, df: Dataset[Row]): Dataset[Row] = {
val regex = ".*, (.*?)\\..*" //对头衔进行归类
val titlesMap = Map(
"Capt"-> "Officer",
"Col"-> "Officer",
"Major"-> "Officer",
"Jonkheer"-> "Royalty",
"Don"-> "Royalty",
"Sir" -> "Royalty",
"Dr"-> "Officer",
"Rev"-> "Officer",
"the Countess"->"Royalty",
"Mme"-> "Mrs",
"Mlle"-> "Miss",
"Ms"-> "Mrs",
"Mr" -> "Mr",
"Mrs" -> "Mrs",
"Miss" -> "Miss",
"Master" -> "Master",
"Lady" -> "Royalty"
) val titlesMap_bc = sc.broadcast(titlesMap) df.withColumn("Title", regexp_extract(($"Name"), regex, 1))
.na.replace("Title", titlesMap_bc.value)
} //根据null age的records对应的Pclass和Name_final分组后的平均来填充缺失age。
// 首先,生成分组key,并获取分组后的平均年龄map。然后广播map,当Age为null时,用udf返回需要填充的值。
def transAge(sc: SparkContext, df: Dataset[Row]): Dataset[Row] = {
val medDF = df.withColumn("Pclass_Title_key", concat($"Title", $"Pclass"))
val meanAgeMap = medDF.groupBy("Pclass_Title_key")
.mean("Age")
.map(row => (row.getString(0), row.getDouble(1)))
.collect().toMap val meanAgeMap_bc = sc.broadcast(meanAgeMap) val fillAgeUdf = udf((comb_key: String) => meanAgeMap_bc.value.getOrElse(comb_key, 0.0)) medDF.withColumn("Age", when($"Age".isNull, fillAgeUdf($"Pclass_Title_key")).otherwise($"Age"))
} //对Age进行分类
def categorizeAge(df: Dataset[Row]): Dataset[Row] = {
val ageBucketBorders = 0.0 +: (10.0 to 60.0 by 5.0).toArray :+ 150.0
val ageBucketer = new Bucketizer().setSplits(ageBucketBorders).setInputCol("Age").setOutputCol("Age_categorized")
ageBucketer.transform(df).drop("Pclass_Title_key")
} //将SibSp和Parch相加,得出同行人数
def createFellow(df: Dataset[Row]): Dataset[Row] = {
df.withColumn("fellow", $"SibSp" + $"Parch")
} //fellow_type, 对fellow进行分类。此处其实可以留到pipeline部分一次性完成。
def categorizeFellow(df: Dataset[Row]): Dataset[Row] = {
df.withColumn("fellow_type", when($"fellow" === 0, "Alone")
.when($"fellow" <= 3, "Small")
.otherwise("Large"))
} def extractFName(df: Dataset[Row]): Dataset[Row] = { //检查df是否有Survived和fellow列
if (!df.columns.contains("Survived") || !df.columns.contains("fellow")){
throw new IllegalArgumentException(
"""
|Check if the argument is a training set or if this training set contains column named \"fellow\"
""".stripMargin)
} //FName,提取家庭名称。例如:"Johnston, Miss. Catherine Helen ""Carrie""" 提取出Johnston
// 由于spark的读取csv时,如果有引号,读取就会出现多余的引号,所以除了split逗号,还要再split一次引号。
val medDF = df
.withColumn("FArray", split($"Name", ","))
.withColumn("FName", expr("FArray[0]"))
.withColumn("FArray", split($"FName", "\""))
.withColumn("FName", $"FArray"(size($"FArray").minus(1))) //family_type,分为三类,第一类是60岁以下女性遇难的家庭,第二类是18岁以上男性存活的家庭,第三类其他。
val femaleDiedFamily_filter = $"Sex" === "female" and $"Age" < 60 and $"Survived" === 0 and $"fellow" > 0 val maleSurvivedFamily_filter = $"Sex" === "male" and $"Age" >= 18 and $"Survived" === 1 and $"fellow" > 1 val resDF = medDF.withColumn("family_type", when(femaleDiedFamily_filter, 1)
.when(maleSurvivedFamily_filter, 2).otherwise(0)) //familyTable,家庭分类名单,用于后续test集的转化。此处用${FName}_${family_type}的形式保存。
resDF.filter($"family_type".isin(1,2))
.select(concat($"FName", lit("_"), $"family_type"))
.dropDuplicates()
.write.format("text").mode("overwrite").save("familyTable") //如果需要直接收集成Map的话,可用下面代码。
// 此代码先利用mapPartitions对各分块的数据进行聚合,降低直接调用count而使driver挂掉的风险。
//另外新建一个默认Set是为了防止某个partition并没有数据的情况(出现概率可能比较少),
// 从而使得Set的类型变为Set[_>:Tuple]而不能直接flatten // val familyMap = df10
// .filter($"family_type" === 1 || $"family_type" === 2)
// .select("FName", "family_type")
// .rdd
// .mapPartitions{iter => {
// if (!iter.isEmpty) {
// Iterator(iter.map(row => (row.getString(0), row.getInt(1))).toSet)}
// else Iterator(Set(("defualt", 9)))}
// }
// .collect()
// .flatten
// .toMap resDF
} //Fare。首先去掉缺失的(test集合中有一个,如果量多的话,也可以像Age那样通过头衔,年龄等因数来推断)
//然后对Fare进行分类
def transFare(df: Dataset[Row]): Dataset[Row] = { val medDF = df.na.drop("any", Seq("Fare"))
val fareBucketer = new QuantileDiscretizer()
.setInputCol("Fare")
.setOutputCol("Fare_categorized")
.setNumBuckets(4) fareBucketer.fit(medDF).transform(medDF)
} def index_onehot(df: Dataset[Row]): Tuple2[Dataset[Row], Array[String]] = {
val stringCols = Array("Sex","fellow_type", "Embarked", "Cabin", "Ticket", "Title")
val subOneHotCols = stringCols.map(cname => s"${cname}_index")
val index_transformers: Array[org.apache.spark.ml.PipelineStage] = stringCols.map(
cname => new StringIndexer()
.setInputCol(cname)
.setOutputCol(s"${cname}_index")
.setHandleInvalid("skip")
) val oneHotCols = subOneHotCols ++ Array("Pclass", "Age_categorized", "Fare_categorized", "family_type")
val vectorCols = oneHotCols.map(cname => s"${cname}_encoded")
val encode_transformers: Array[org.apache.spark.ml.PipelineStage] = oneHotCols.map(
cname => new OneHotEncoder()
.setInputCol(cname)
.setOutputCol(s"${cname}_encoded")
) val pipelineStage = index_transformers ++ encode_transformers
val index_onehot_pipeline = new Pipeline().setStages(pipelineStage)
val index_onehot_pipelineModel = index_onehot_pipeline.fit(df) val resDF = index_onehot_pipelineModel.transform(df).drop(stringCols:_*).drop(subOneHotCols:_*)
println(resDF.columns.size)
(resDF, vectorCols)
} def trainData(df: Dataset[Row], vectorCols: Array[String]): TrainValidationSplitModel = {
//separate and model pipeline,包含划分label和features,机器学习模型的pipeline
val vectorAssembler = new VectorAssembler()
.setInputCols(vectorCols)
.setOutputCol("features") val gbtc = new GBTClassifier()
.setLabelCol("Survived")
.setFeaturesCol("features")
.setPredictionCol("prediction") val pipeline = new Pipeline().setStages(Array(vectorAssembler, gbtc)) val paramGrid = new ParamGridBuilder()
.addGrid(gbtc.stepSize, Seq(0.1))
.addGrid(gbtc.maxDepth, Seq(5))
.addGrid(gbtc.maxIter, Seq(20))
.build() val multiclassEval = new MulticlassClassificationEvaluator()
.setLabelCol("Survived")
.setPredictionCol("prediction")
.setMetricName("accuracy") val tvs = new TrainValidationSplit()
.setTrainRatio(0.75)
.setEstimatorParamMaps(paramGrid)
.setEstimator(pipeline)
.setEvaluator(multiclassEval) tvs.fit(df)
}
}

基于Spark ML的Titanic Challenge (Top 6%)的更多相关文章

  1. 使用spark ml pipeline进行机器学习

    一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...

  2. spark ml 的例子

    一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...

  3. spark ml pipeline构建机器学习任务

    一.关于spark ml pipeline与机器学习一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的流 ...

  4. Spark ML Pipeline简介

    Spark ML Pipeline基于DataFrame构建了一套High-level API,我们可以使用MLPipeline构建机器学习应用,它能够将一个机器学习应用的多个处理过程组织起来,通过在 ...

  5. 基于Spark的电影推荐系统(推荐系统~4)

    第四部分-推荐系统-模型训练 本模块基于第3节 数据加工得到的训练集和测试集数据 做模型训练,最后得到一系列的模型,进而做 预测. 训练多个模型,取其中最好,即取RMSE(均方根误差)值最小的模型 说 ...

  6. 基于Spark ALS构建商品推荐引擎

    基于Spark ALS构建商品推荐引擎   一般来讲,推荐引擎试图对用户与某类物品之间的联系建模,其想法是预测人们可能喜好的物品并通过探索物品之间的联系来辅助这个过程,让用户能更快速.更准确的获得所需 ...

  7. 推荐系统那点事 —— 基于Spark MLlib的特征选择

    在机器学习中,一般都会按照下面几个步骤:特征提取.数据预处理.特征选择.模型训练.检验优化.那么特征的选择就很关键了,一般模型最后效果的好坏往往都是跟特征的选择有关系的,因为模型本身的参数并没有太多优 ...

  8. Spark ML下实现的多分类adaboost+naivebayes算法在文本分类上的应用

    1. Naive Bayes算法 朴素贝叶斯算法算是生成模型中一个最经典的分类算法之一了,常用的有Bernoulli和Multinomial两种.在文本分类上经常会用到这两种方法.在词袋模型中,对于一 ...

  9. Spark ML源码分析之一 设计框架解读

    本博客为作者原创,如需转载请注明参考           在深入理解Spark ML中的各类算法之前,先理一下整个库的设计框架,是非常有必要的,优秀的框架是对复杂问题的抽象和解剖,对这种抽象的学习本身 ...

随机推荐

  1. node里读取命令行参数

    一.process.env process.env属性返回一个包含用户环境信息的对象. 最常见的需求,前端需要根据不同的环境(dev,prd),来调用不同的后端接口.如果用webpack,是这么做的: ...

  2. Spring MVC起步(一)

    下图展示了请求使用Spring MVC所经历的所有站点. 在请求离开浏览器时1,会带有用户请求内容的信息,至少会包含请求的URL.但是还可能包含其他的信息,如用户提交的表单. DispatcherSe ...

  3. 针对mdadm的RAID1失效测试

    背景 对软RAID(mdadm)方式进行各个场景失效测试. 一.初始信息 内核版本: root@omv30:~# uname -a Linux omv30 4.18.0-0.bpo.1-amd64 # ...

  4. Linux - VMware和Centos安装

    目录 Linux - VMware和Centos安装 选择性 下载centos系统ISO镜像 安装虚拟机VMware虚拟机 1. 准备vmware软件 2. 解压软件包, 当前选择vm12 3. vm ...

  5. SLF4J和Logback和Log4j和Logging的区别与联系

    本文转载自:一个著名的日志系统是怎么设计出来的?(作者:刘欣) 前言 Java帝国在诞生之初就提供了集合.线程.IO.网络等常用功能,从C和C++领地那里吸引了大量程序员过来加盟,但是却有意无意地忽略 ...

  6. Redis 原子操作INCR

    The content below come from http://try.redis.io/ There is something special about INCR. Why do we pr ...

  7. sql 语句实现可用户名、邮箱、手机号登录系统

    select top 1 nid from Users where (userName collate Chinese_PRC_CS_AS=@userName or mobile collate Ch ...

  8. bupt summer training for 16 #1 ——简单题目

    D.What a Mess 给n个数,求其中能满足 a[i] % a[j] == 0 的数对之和 n = 1W,max_ai = 100W 不是很大,所以就直接筛就可以了 计算可得最高复杂度 < ...

  9. 清北学堂模拟赛d4t1 a

    分析:大模拟,没什么好说的.我在考场上犯了一个超级低级的错误:while (scanf("%s",s + 1)),导致了死循环,血的教训啊,以后要记住了. /* 1.没有发生改变, ...

  10. 开启mysql远程连接

    mysql默认只允许本地连接,也就是说,在安装完mysql后会存在两个root账户,他们的host分别是localhost和127.0.0.1 use mysql; update user set h ...