5.2.从数据中提取合适的特征

[root@demo1 ch05]# sed 1d train.tsv > train_noheader.tsv
[root@demo1 ch05]# ll
total 42920
-rw-r--r-- 1 root root 21972457 Jan 31 15:03 train_noheader.tsv
-rw-r--r-- 1 root root 21972916 Jan 31 15:00 train.tsv
[root@demo1 ch05]# hdfs dfs -mkdir /user/root/studio/MachineLearningWithSpark/ch05
[root@demo1 ch05]# hdfs dfs -put train_noheader.tsv /user/root/studio/MachineLearningWithSpark/ch05

[root@demo1 ch05]# spark-shell --master yarn

scala> val rawData = sc.textFile("/user/root/studio/MachineLearningWithSpark/ch05/train_noheader.tsv")
rawData: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[1] at textFile at <console>:27

scala> val records = rawData.map(line => line.split("\t"))
records: org.apache.spark.rdd.RDD[Array[String]] = MapPartitionsRDD[2] at map at <console>:29

scala> records.first()
res1: Array[String] = Array("http://www.bloomberg.com/news/2010-12-23/ibm-predicts-holographic-calls-air-breathing-batteries-by-2015.html", "4042", "{""title"":""IBM Sees Holographic Calls Air Breathing Batteries ibm sees holographic calls, air-breathing batteries"",""body"":""A sign stands outside the International Business Machines Corp IBM Almaden Research Center campus in San Jose California Photographer Tony Avelar Bloomberg Buildings stand at the International Business Machines Corp IBM Almaden Research Center campus in the Santa Teresa Hills of San Jose California Photographer Tony Avelar Bloomberg By 2015 your mobile phone will project a 3 D image of anyone who calls and your laptop will be powered by kinetic energy At least that s what International Business Machines Corp sees ...
scala> import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.regression.LabeledPoint

scala> import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.Vectors

scala> val data = records.map{ r =>
| val trimmed = r.map(_.replaceAll("\"",""))
| val label = trimmed(r.size - 1).toInt
| val features = trimmed.slice(4, r.size - 1).map(d => if (d == "?") 0.0 else d.toDouble)
| LabeledPoint(label, Vectors.dense(features))
| }
data: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[3] at map at <console>:33

5.3.训练分类模型

scala> import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD

scala> import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.classification.SVMWithSGD

scala> import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.classification.NaiveBayes

scala> import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.DecisionTree

scala> import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.Algo

scala> import org.apache.spark.mllib.tree.impurity.Entropy
import org.apache.spark.mllib.tree.impurity.Entropy

scala> val numIterations = 10
numIterations: Int = 10

scala> val maxTreeDepth = 5
maxTreeDepth: Int = 5

scala> val lrModel = LogisticRegressionWithSGD.train(data, numIterations)
lrModel: org.apache.spark.mllib.classification.LogisticRegressionModel = org.apache.spark.mllib.classification.LogisticRegressionModel: intercept = 0.0, numFeatures = 22, numClasses = 2, threshold = 0.5

scala> val svmModel = SVMWithSGD.train(data, numIterations)
svmModel: org.apache.spark.mllib.classification.SVMModel = org.apache.spark.mllib.classification.SVMModel: intercept = 0.0, numFeatures = 22, numClasses = 2, threshold = 0.0

scala> val nbModel = NaiveBayes.train(nbData)
nbModel: org.apache.spark.mllib.classification.NaiveBayesModel = org.apache.spark.mllib.classification.NaiveBayesModel@42cf75c1

scala> val dtModel = DecisionTree.train(data, Algo.Classification, Entropy, maxTreeDepth)
dtModel: org.apache.spark.mllib.tree.model.DecisionTreeModel = DecisionTreeModel classifier of depth 5 with 61 nodes

5.4使用分类模型

scala> val dataPoint = data.first
dataPoint: org.apache.spark.mllib.regression.LabeledPoint = (0.0,[0.789131,2.055555556,0.676470588,0.205882353,0.047058824,0.023529412,0.443783175,0.0,0.0,0.09077381,0.0,0.245831182,0.003883495,1.0,1.0,24.0,0.0,5424.0,170.0,8.0,0.152941176,0.079129575])

scala> val prediction = lrModel.predict(dataPoint.features)
prediction: Double = 1.0

scala> val trueLabel = dataPoint.label
trueLabel: Double = 0.0

scala> val predictions = lrModel.predict(data.map(lp => lp.features))
predictions: org.apache.spark.rdd.RDD[Double] = MapPartitionsRDD[99] at mapPartitions at GeneralizedLinearAlgorithm.scala:69

scala> predictions.take(5)
res3: Array[Double] = Array(1.0, 1.0, 1.0, 1.0, 1.0)

5.5.评估分类模型的性能

scala> val lrTotalCorrect = data.map { point =>
| if (lrModel.predict(point.features) == point.label) 1 else 0
| }.sum
lrTotalCorrect: Double = 3806.0

scala> val lrAccuracy = lrTotalCorrect / data.count
lrAccuracy: Double = 0.5146720757268425

scala> val svmTotalCorrect = data.map { point =>
| if (svmModel.predict(point.features) == point.label) 1 else 0
| }.sum
svmTotalCorrect: Double = 3806.0

scala> val svmAccuracy = svmTotalCorrect / data.count
svmAccuracy: Double = 0.5146720757268425

scala> val nbTotalCorrect = nbData.map { point =>
| if (nbModel.predict(point.features) == point.label) 1 else 0
| }.sum
nbTotalCorrect: Double = 4292.0

scala> val nbAccuracy = nbTotalCorrect / data.count
nbAccuracy: Double = 0.5803921568627451

scala> val dtTotalCorrect = data.map { point =>
| val score = dtModel.predict(point.features)
| val predicted = if (score > 0.5) 1 else 0
| if (predicted == point.label) 1 else 0
| }.sum
dtTotalCorrect: Double = 4794.0

scala> val dtAccuracy = dtTotalCorrect / data.count
dtAccuracy: Double = 0.6482758620689655

scala> import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

scala> val metrics = Seq(lrModel, svmModel).map { model =>
| val scoreAndLabels = data.map { point => (model.predict(point.features), point.label) }
| val metrics = new BinaryClassificationMetrics(scoreAndLabels)
| (model.getClass.getSimpleName, metrics.areaUnderPR, metrics.areaUnderROC)
| }
metrics: Seq[(String, Double, Double)] = List((LogisticRegressionModel,0.7567586293858841,0.5014181143280931), (SVMModel,0.7567586293858841,0.5014181143280931))

scala> val nbMetrics = Seq(nbModel).map { model =>
| val scoreAndLabels = nbData.map { point =>
| val score = model.predict(point.features)
| (if (score > 0.5) 1.0 else 0.0, point.label)
| }
| val metrics = new BinaryClassificationMetrics(scoreAndLabels)
| (model.getClass.getSimpleName, metrics.areaUnderPR, metrics.areaUnderROC)
| }
nbMetrics: Seq[(String, Double, Double)] = List((NaiveBayesModel,0.6808510815151734,0.5835585110136261))

scala> val dtMetrics = Seq(dtModel).map { model =>
| val scoreAndLabels = data.map { point =>
| val score = model.predict(point.features)
| (if (score > 0.5) 1.0 else 0.0, point.label)
| }
| val metrics = new BinaryClassificationMetrics(scoreAndLabels)
| (model.getClass.getSimpleName, metrics.areaUnderPR, metrics.areaUnderROC)
| }
dtMetrics: Seq[(String, Double, Double)] = List((DecisionTreeModel,0.7430805993331199,0.6488371887050935))

scala> val allMetrics = metrics ++ nbMetrics ++ dtMetrics
allMetrics: Seq[(String, Double, Double)] = List((LogisticRegressionModel,0.7567586293858841,0.5014181143280931), (SVMModel,0.7567586293858841,0.5014181143280931), (NaiveBayesModel,0.6808510815151734,0.5835585110136261), (DecisionTreeModel,0.7430805993331199,0.6488371887050935))

scala> allMetrics.foreach { case (m, pr, roc) =>
| println(f"$m, Area under PR: ${pr * 100.0}%2.4f%%, Area under ROC: ${roc * 100.0}%2.4f%%")
| }
LogisticRegressionModel, Area under PR: 75.6759%, Area under ROC: 50.1418%
SVMModel, Area under PR: 75.6759%, Area under ROC: 50.1418%
NaiveBayesModel, Area under PR: 68.0851%, Area under ROC: 58.3559%
DecisionTreeModel, Area under PR: 74.3081%, Area under ROC: 64.8837%

Spark机器学习读书笔记-CH05的更多相关文章

  1. Spark机器学习读书笔记-CH04

    [root@demo1 ch04]# spark-shell --master yarn --jars /root/studio/jblas-1.2.3.jar scala> val rawDa ...

  2. Spark机器学习读书笔记-CH03

    3.1.获取数据: wget http://files.grouplens.org/datasets/movielens/ml-100k.zip 3.2.探索与可视化数据: In [3]: user_ ...

  3. 视觉机器学习读书笔记--------BP学习

    反向传播算法(Back-Propagtion Algorithm)即BP学习属于监督式学习算法,是非常重要的一种人工神经网络学习方法,常被用来训练前馈型多层感知器神经网络. 一.BP学习原理 1.前馈 ...

  4. 视觉机器学习读书笔记--------SVM方法

    SVM是一种二类分类模型,有监督的统计学习方法,能够最小化经验误差和最大化几何边缘,被称为最大间隔分类器,可用于分类和回归分析.支持向量机的学习策略就是间隔最大化,可形式化为一个求解凸二次规划的问题, ...

  5. 机器学习读书笔记(一)k-近邻算法

    一.机器学习是什么 机器学习的英文名称叫Machine Learning,简称ML,该领域主要研究的是如何使计算机能够模拟人类的学习行为从而获得新的知识和技能,并且重新组织已学习到的知识和和技能,使之 ...

  6. 机器学习读书笔记(七)支持向量机之线性SVM

    一.SVM SVM的英文全称是Support Vector Machines,我们叫它支持向量机.支持向量机是我们用于分类的一种算法. 1 示例: 先用一个例子,来了解一下SVM 桌子上放了两种颜色的 ...

  7. 机器学习读书笔记(五)AdaBoost

    一.Boosting算法 .Boosting算法是一种把若干个分类器整合为一个分类器的方法,在boosting算法产生之前,还出现过两种比较重要的将多个分类器整合为一个分类器的方法,即boostrap ...

  8. 机器学习读书笔记(二)使用k-近邻算法改进约会网站的配对效果

    一.背景 海伦女士一直使用在线约会网站寻找适合自己的约会对象.尽管约会网站会推荐不同的任选,但她并不是喜欢每一个人.经过一番总结,她发现自己交往过的人可以进行如下分类 不喜欢的人 魅力一般的人 极具魅 ...

  9. 【Todo】【读书笔记】机器学习-周志华

    书籍位置: /Users/baidu/Documents/Data/Interview/机器学习-数据挖掘/<机器学习_周志华.pdf> 一共442页.能不能这个周末先囫囵吞枣看完呢.哈哈 ...

随机推荐

  1. 22.mongodb副本集集群

    软件版本64位:     $ wget https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-rhel62-3.2.0.tgz     mongo ...

  2. FAQ

    1.Baudrare and the speed of Byte. 2. Linux FS and Flash store. 3. SW's Coupling. 4. Protocol and Pro ...

  3. -/bin/sh: ./led: not found的解决办法

    环境介绍: 开发板:qq2440 交叉编译器:arm-linux-gcc 3.4.1 内核版本:2.6.13 一.针对该类问题从两个方面入手: 1.从权限出发,权限不够会出现此问题 2.从库文件出发, ...

  4. CSS与JavaScript的一些问题汇总

    通过最近的学习,总结了一些问题,可能总结得不够完善,但是好记性不如烂笔头,先记在这儿,后面看到更完整的回答,再进行修改. 1.事件流,如何阻止冒泡事件流:在点击一个按钮时,实则,按的父容器与按钮的父容 ...

  5. SpringMVC生成任意文件,访问链接即下载

    原理上讲就是返回的 ResponseEntity<byte[]> 形式的值就可以了 @RequestMapping("/api/watermark_download") ...

  6. iOS 开发之崩溃日志分析

    1. (js 与webview 交互崩溃)-[CFRunLoopTimer release]: message sent to deallocated instance 0x62398f80 I've ...

  7. Could not resolve placeholder 解决方案

    spring 配置加载properties文件的时候,报 Could not resolve placeholder 错误. 经过仔细查找,排除文件路径,文件类容错误的原因,经过查找相关资料,出现&q ...

  8. JQuery_表单选择器

    表单作为HTML 中一种特殊的元素,操作方法较为多样性和特殊性 开发者不但可以使用之前的常规选择器或过滤器,也可以使用jQuery 为表单专门提供的选择器和过滤器来准确的定位表单元素. 一.常规选择器 ...

  9. java中Array/List/Map/Object与Json互相转换详解

    http://blog.csdn.net/xiaomu709421487/article/details/51456705 JSON(JavaScript Object Notation): 是一种轻 ...

  10. java中的数据结构(集合|容器)

    对java中的数据结构做一个小小的个人总结,虽然还没有到研究透彻jdk源码的地步.首先.java中为何需要集合的出现?什么需求导致.我想对于面向对象来说,对象适用于描述任何事物,所以为了方便对于对象的 ...