package kaggle

import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionWithSGD, NaiveBayes, SVMWithSGD}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics /**
* Created by mi on 17-5-23.
*/ object Titanic { def main(args: Array[String]) { // val sparkSession = SparkSession.builder.
// master("local")
// .appName("spark session example")
// .getOrCreate()
// val rawData = sparkSession.read.csv("/home/mi/下载/kaggle/Titanic/nohead-train.csv")
// val d = rawData.map{p => p.asInstanceOf[person]}
// d.show() val conf = new SparkConf().setAppName("WordCount").setMaster("local")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc) //屏蔽日志
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF) // 读取数据
val df = sqlContext.load("com.databricks.spark.csv", Map("path" -> "/home/mi/下载/kaggle/Titanic/train.csv", "header" -> "true")) // 分析年龄数据
val ageAnalysis = df.rdd.filter(d => d(5) != null).map { d =>
val age = d(5).toString.toDouble
Vectors.dense(age)
}
val ageMean = Statistics.colStats(ageAnalysis).mean(0)
val ageMax = Statistics.colStats(ageAnalysis).max(0)
val ageMin = Statistics.colStats(ageAnalysis).min(0)
val ageDiff = ageMax - ageMin // 分析船票价格数据
val fareAnalysis = df.rdd.filter(d => d(9) != null).map { d =>
val fare = d(9).toString.toDouble
Vectors.dense(fare)
}
val fareMean = Statistics.colStats(fareAnalysis).mean(0)
val fareMax = Statistics.colStats(fareAnalysis).max(0)
val fareMin = Statistics.colStats(fareAnalysis).min(0)
val fareDiff = fareMax - fareMin // 数据预处理
val trainData = df.rdd.map { d =>
val label = d(1).toString.toInt
val sex = d(4) match {
case "male" => 0.0
case "female" => 1.0
}
val age = d(5) match {
case null => (ageMean - ageMin) / ageDiff
case _ => (d(5).toString().toDouble - ageMin) / ageDiff
}
val fare = d(9) match {
case null => (fareMean - fareMin) / fareDiff
case _ => (d(9).toString().toDouble - fareMin) / fareDiff
} LabeledPoint(label, Vectors.dense(sex, age, fare))
} // 切分数据集和测试集
val Array(trainingData, testData) = trainData.randomSplit(Array(0.8, 0.2)) // 训练数据
val numIterations = 8
val lrModel = new LogisticRegressionWithLBFGS().setNumClasses(2).run(trainingData)
// val svmModel = SVMWithSGD.train(trainingData, numIterations) val nbTotalCorrect = testData.map { point =>
if (lrModel.predict(point.features) == point.label) 1 else 0
}.sum
val nbAccuracy = nbTotalCorrect / testData.count println("SVM模型正确率:" + nbAccuracy) // 预测
// 读取数据
val testdf = sqlContext.load("com.databricks.spark.csv", Map("path" -> "/home/mi/下载/kaggle/Titanic/test.csv", "header" -> "true")) // 分析测试集年龄数据
val ageTestAnalysis = testdf.rdd.filter(d => d(4) != null).map { d =>
val age = d(4).toString.toDouble
Vectors.dense(age)
}
val ageTestMean = Statistics.colStats(ageTestAnalysis).mean(0)
val ageTestMax = Statistics.colStats(ageTestAnalysis).max(0)
val ageTestMin = Statistics.colStats(ageTestAnalysis).min(0)
val ageTestDiff = ageTestMax - ageTestMin // 分析船票价格数据
val fareTestAnalysis = testdf.rdd.filter(d => d(8) != null).map { d =>
val fare = d(8).toString.toDouble
Vectors.dense(fare)
}
val fareTestMean = Statistics.colStats(fareTestAnalysis).mean(0)
val fareTestMax = Statistics.colStats(fareTestAnalysis).max(0)
val fareTestMin = Statistics.colStats(fareTestAnalysis).min(0)
val fareTestDiff = fareTestMax - fareTestMin // 数据预处理
val data = testdf.rdd.map { d =>
val sex = d(3) match {
case "male" => 0.0
case "female" => 1.0
}
val age = d(4) match {
case null => (ageTestMean - ageTestMin) / ageTestDiff
case _ => (d(4).toString().toDouble - ageTestMin) / ageTestDiff
}
val fare = d(8) match {
case null => (fareTestMean - fareTestMin) / fareTestDiff
case _ => (d(8).toString().toDouble - fareTestMin) / fareTestDiff
} Vectors.dense(sex, age, fare)
} val predictions = lrModel.predict(data).map(p => p.toInt)
// 保存预测结果
predictions.coalesce(1).saveAsTextFile("file:///home/mi/下载/kaggle/Titanic/test_predict")
}
}

Spark学习笔记——泰坦尼克生还预测的更多相关文章

  1. Spark学习笔记之SparkRDD

    Spark学习笔记之SparkRDD 一.   基本概念 RDD(resilient distributed datasets)弹性分布式数据集. 来自于两方面 ①   内存集合和外部存储系统 ②   ...

  2. spark学习笔记总结-spark入门资料精化

    Spark学习笔记 Spark简介 spark 可以很容易和yarn结合,直接调用HDFS.Hbase上面的数据,和hadoop结合.配置很容易. spark发展迅猛,框架比hadoop更加灵活实用. ...

  3. Spark学习笔记2(spark所需环境配置

    Spark学习笔记2 配置spark所需环境 1.首先先把本地的maven的压缩包解压到本地文件夹中,安装好本地的maven客户端程序,版本没有什么要求 不需要最新版的maven客户端. 解压完成之后 ...

  4. Spark学习笔记3(IDEA编写scala代码并打包上传集群运行)

    Spark学习笔记3 IDEA编写scala代码并打包上传集群运行 我们在IDEA上的maven项目已经搭建完成了,现在可以写一个简单的spark代码并且打成jar包 上传至集群,来检验一下我们的sp ...

  5. Spark学习笔记-GraphX-1

    Spark学习笔记-GraphX-1 标签: SparkGraphGraphX图计算 2014-09-29 13:04 2339人阅读 评论(0) 收藏 举报  分类: Spark(8)  版权声明: ...

  6. Spark学习笔记3——RDD(下)

    目录 Spark学习笔记3--RDD(下) 向Spark传递函数 通过匿名内部类 通过具名类传递 通过带参数的 Java 函数类传递 通过 lambda 表达式传递(仅限于 Java 8 及以上) 常 ...

  7. Spark学习笔记0——简单了解和技术架构

    目录 Spark学习笔记0--简单了解和技术架构 什么是Spark 技术架构和软件栈 Spark Core Spark SQL Spark Streaming MLlib GraphX 集群管理器 受 ...

  8. Spark学习笔记2——RDD(上)

    目录 Spark学习笔记2--RDD(上) RDD是什么? 例子 创建 RDD 并行化方式 读取外部数据集方式 RDD 操作 转化操作 行动操作 惰性求值 Spark学习笔记2--RDD(上) 笔记摘 ...

  9. Spark学习笔记1——第一个Spark程序:单词数统计

    Spark学习笔记1--第一个Spark程序:单词数统计 笔记摘抄自 [美] Holden Karau 等著的<Spark快速大数据分析> 添加依赖 通过 Maven 添加 Spark-c ...

随机推荐

  1. [模板][P3377]左偏树

    Description: 一开始有N个小根堆,每个堆包含且仅包含一个数.接下来需要支持两种操作: 操作1: 1 x y 将第x个数和第y个数所在的小根堆合并(若第x或第y个数已经被删除或第x和第y个数 ...

  2. [BJOI2014]大融合

    Description 给你一个n个点的森林,要求支持m个操作: 1.连接两个点 x,y 2.询问若断掉 x,y这条边,两点所在联通块乘积的大小 Hint: \(n,m<=10^5\) Solu ...

  3. 2016年3月4日Android实习笔记

    1.让水平LinearLayout中的两个子元素分别居左和居右 在LinearLayout中有两个子元素,LinearLayout的orientation是horizontal.需要让第一个元素居左, ...

  4. python 发送邮件脚本

    一.该脚本适合在 linux 中做邮件发送测试用,只需要填写好 发送账号和密码以及发送人即可,然后使用  python ./filename.py (当前目录下)即可.如果发送出错,会将错误详情抛出来 ...

  5. 连接mysql 出现:java.sql.SQLException: Unable to load authentication plugin 'caching_sha2_password'.

    数据测试的时候出现: 网上查资料说的是mysql5.x 版本和 8.x版本的区别: 5.7版本是:default_authentication_plugin=mysql_native_password ...

  6. Android典型界面设计(3)——访网易新闻实现双导航tab切换

    一.问题描述 双导航tab切换(底部区块+区域内头部导航),实现方案底部区域使用FragmentTabHost+Fragment, 区域内头部导航使用ViewPager+Fragment,可在之前博客 ...

  7. What's the difference between ConcurrentHashMap and Collections.synchronizedMap(Map)?

    来自:http://stackoverflow.com/questions/510632/whats-the-difference-between-concurrenthashmap-and-coll ...

  8. 【Zookeeper】源码分析之服务器(四)之FollowerZooKeeperServer

    一.前言 前面分析了LeaderZooKeeperServer,接着分析FollowerZooKeeperServer. 二.FollowerZooKeeperServer源码分析 2.1 类的继承关 ...

  9. zabbix 中文乱码的处理

    一.乱码原因 查看cpu负载,中文乱码如下 这个问题是由于zabbix的web端没有中文字库,我们最需要把中文字库加上即可 二.解决zabbix乱码方法 2.1 上传字体文件到zabbix中 找到本地 ...

  10. [3] MQTT,mosquitto,Eclipse Paho---怎样使用 Eclipse Paho MQTT工具来发送订阅MQTT消息?

    在上两节,笔者主要介绍了 MQTT,mosquitto,Eclipse Paho的基本概念已经怎样安装mosquitto. 在这个章节我们就来看看怎样用 Eclipse Paho MQTT工具来发送接 ...