简介

支持向量机SVM是一种二分类模型。它的基本模型是定义在特征空间上的间隔最大的线性分类器。支持向量机学习方法包含3种模型:线性可分支持向量机、线性支持向量机及非线性支持向量机。当训练数据线性可分时,通过硬间隔最大化,学习一个线性的分类器,即线性可分支持向量机;当训练数据近似线性可分时,通过软间隔最大化,也学习一个线性的分类器,即线性支持向量机;当训练数据线性不可分时,通过使用核技巧及软间隔最大化,学习非线性支持向量机。线性支持向量机支持L1和L2的正则化变型。关于正则化,可以参见http://spark.apache.org/docs/1.6.2/mllib-linear-methods.html#regularizers

基本原理

​ 支持向量机,因其英文名为support vector machine,故一般简称SVM。SVM从线性可分情况下的最优分类面发展而来。最优分类面就是要求分类线不但能将两类正确分开(训练错误率为0),且使分类间隔最大。SVM考虑寻找一个满足分类要求的超平面,并且使训练集中的点距离分类面尽可能的远,也就是寻找一个分类面使它两侧的空白区域(margin)最大。这两类样本中离分类面最近,且平行于最优分类面的超平面上的点,就叫做支持向量(下图中红色的点)。

![此处输入图片的描述][1]

假设超平面可描述为:

![此处输入图片的描述][2]
其分类间隔等于
![此处输入图片的描述][3]

其学习策略是使数据间的间隔最大化,最终可转化为一个凸二次规划问题的求解。

分类器的损失函数(hinge loss铰链损失)如下所示:

![此处输入图片的描述][4]

默认情况下,线性SVM是用L2 正则化来训练的,但也支持L1正则化。在这种情况下,这个问题就变成了一个线性规划。

​ 线性SVM算法输出一个SVM模型。给定一个新的数据点,比如说

![此处输入图片的描述][5]

,这个模型就会根据

![此处输入图片的描述][6]

的值来进行预测。默认情况下,如果

![此处输入图片的描述][7]

,则输出预测结果为正(因为我们想要损失函数最小,如果预测为负,则会导致损失函数大于1),反之则预测为负。

示例代码

下面的例子具体介绍了如何读入一个数据集,然后用SVM对训练数据进行训练,然后用训练得到的模型对测试集进行预测,并计算错误率。以iris数据集(https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data)为例进行分析。iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。

1. 导入需要的包:

首先,我们导入需要的包:

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vectors,Vector}
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

2. 读取数据:

​ 首先,读取文本文件;然后,通过map将每行的数据用“,”隔开,在我们的数据集中,每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类。把这里我们用LabeledPoint来存储标签列和特征列。LabeledPoint在监督学习中常用来存储标签和特征,其中要求标签的类型是double,特征的类型是Vector。所以,我们把莺尾花的分类进行了一下改变,”Iris-setosa”对应分类0,”Iris-versicolor”对应分类1,其余对应分类2;然后获取莺尾花的4个特征,存储在Vector中。

scala> val data = sc.textFile("/home/hadoop/iris.data")
data: org.apache.spark.rdd.RDD[String] = /home/hadoop/iris.data MapPartitionsRDD[1] at textFile at <console>:24 scala> val parsedData = data.map { line =>
| val parts = line.split(',')
| LabeledPoint(
| if(parts(4)=="Iris-setosa") 0.toDouble
| else if (parts(4) =="Iris-versicolor") 1.toDouble
| else 2.toDouble,
| Vectors.dense(parts(0).toDouble,parts(1).toDouble,
| parts(2).toDouble,parts(3).toDouble)
| )
| } parsedData: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[2] at map at <console>:31

把数据文件放到HDFS上去。

./bin/hdfs dfs -mkdir -p /home/hadoop
./bin/hdfs dfs -put iris.data /home/hadoop

3. 构建模型

因为SVM只支持2分类,所以我们要进行一下数据抽取,这里我们通过filter过滤掉第2类的数据,只选取第0类和第1类的数据。然后,我们把数据集划分成两部分,其中训练集占60%,测试集占40%:

scala> val splits = parsedData.filter { point => point.label != 2 }.randomSplit(Array(0.6, 0.4), seed = 11L)
scala> val training = splits(0).cache()
scala> val test = splits(1)
test: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] =
MapPartitionsRDD[5] at randomSplit at <console>:32

接下来,通过训练集构建模型SVMWithSGD。这里的SGD即著名的随机梯度下降算法(Stochastic Gradient Descent)。设置迭代次数为1000,除此之外还有stepSize(迭代步伐大小),regParam(regularization正则化控制参数),miniBatchFraction(每次迭代参与计算的样本比例),initialWeights(weight向量初始值)等参数可以进行设置。

scala> val numIterations = 1000
scala> val model = SVMWithSGD.train(training, numIterations)
2018-04-22 06:08:43 WARN BLAS:61 - Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
2018-04-22 06:08:43 WARN BLAS:61 - Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
model: org.apache.spark.mllib.classification.SVMModel = org.apache.spark.mllib.classification.SVMModel: intercept = 0.0, numFeatures = 4, numClasses = 2, threshold = 0.0

4.模型评估

接下来,我们清除默认阈值,这样会输出原始的预测评分,即带有确信度的结果。

scala> model.clearThreshold()
scala> val scoreAndLabels = test.map { point =>
| val score = model.predict(point.features)
| (score, point.label)
| }
scala> scoreAndLabels.foreach(println)
(-3.0127314882950778,0.0)
(-2.4596261094505403,0.0)
(-2.64505513159329,0.0)
(-3.503342620026854,0.0)
(-2.717199557755541,0.0)
(-2.6779191149350754,0.0)
... ...

那如果设置了阈值,则会把大于阈值的结果当成正预测,小于阈值的结果当成负预测。

scala> model.setThreshold(0.0)
scala> scoreAndLabels.foreach(println)
(0.0,0.0)
(0.0,0.0)
(0.0,0.0)
(0.0,0.0)
(0.0,0.0)
(0.0,0.0)
... ...

​ 最后,我们构建评估矩阵,把模型预测的准确性打印出来:

scala> val metrics = new BinaryClassificationMetrics(scoreAndLabels)
scala> val auROC = metrics.areaUnderROC()
auROC: Double = 1.0
scala> println("Area under ROC = " + auROC)
Area under ROC = 1.0

其中, SVMWithSGD.train() 方法默认的通过把正则化参数设为1来执行来范数。如果我们想配置这个算法,可以通过创建一个新的 SVMWithSGD对象然后调用他的setter方法来进行重新配置。下面这个例子,我们构建了一个正则化参数为0.1的L1正则化SVM方法 ,然后迭代这个训练算法2000次。

import org.apache.spark.mllib.optimization.L1Updater
scala> val svmAlg = new SVMWithSGD()
svmAlg: org.apache.spark.mllib.classification.SVMWithSGD = org.apache.spark.mlli
b.classification.SVMWithSGD@475774a9 scala> svmAlg.optimizer.
| setNumIterations(2000).
| setRegParam(0.1).
| setUpdater(new L1Updater) scala> val modelL1 = svmAlg.run(training)
modelL1: org.apache.spark.mllib.classification.SVMModel = org.apache.spark.mllib
.classification.SVMModel: intercept = 0.0, numFeatures = 4, numClasses = 2, threshold = 0.0

大数据-10-Spark入门之支持向量机SVM分类器的更多相关文章

  1. 【互动问答分享】第10期决胜云计算大数据时代Spark亚太研究院公益大讲堂

    “决胜云计算大数据时代” Spark亚太研究院100期公益大讲堂 [第10期互动问答分享] Q1:Spark on Yarn的运行方式是什么? Spark on Yarn的运行方式有两种:Client ...

  2. 大数据:Hadoop入门

    大数据:Hadoop入门 一:什么是大数据 什么是大数据: (1.)大数据是指在一定时间内无法用常规软件对其内容进行抓取,管理和处理的数据集合,简而言之就是数据量非常大,大到无法用常规工具进行处理,如 ...

  3. 【互动问答分享】第15期决胜云计算大数据时代Spark亚太研究院公益大讲堂

    "决胜云计算大数据时代" Spark亚太研究院100期公益大讲堂 [第15期互动问答分享] Q1:AppClient和worker.master之间的关系是什么? AppClien ...

  4. Kaggle大数据竞赛平台入门

    Kaggle大数据竞赛平台入门 大数据竞赛平台,国内主要是天池大数据竞赛和DataCastle,国外主要就是Kaggle.Kaggle是一个数据挖掘的竞赛平台,网站为:https://www.kagg ...

  5. 【互动问答分享】第13期决胜云计算大数据时代Spark亚太研究院公益大讲堂

    “决胜云计算大数据时代” Spark亚太研究院100期公益大讲堂 [第13期互动问答分享] Q1:tachyon+spark框架现在有很多大公司在使用吧? Yahoo!已经在长期大规模使用: 国内也有 ...

  6. 【互动问答分享】第8期决胜云计算大数据时代Spark亚太研究院公益大讲堂

    “决胜云计算大数据时代” Spark亚太研究院100期公益大讲堂 [第8期互动问答分享] Q1:spark线上用什么版本好? 建议从最低使用的Spark 1.0.0版本,Spark在1.0.0开始核心 ...

  7. 【互动问答分享】第7期决胜云计算大数据时代Spark亚太研究院公益大讲堂

    “决胜云计算大数据时代” Spark亚太研究院100期公益大讲堂 [第7期互动问答分享] Q1:Spark中的RDD到底是什么? RDD是Spark的核心抽象,可以把RDD看做“分布式函数编程语言”. ...

  8. 【互动问答分享】第6期决胜云计算大数据时代Spark亚太研究院公益大讲堂

    “决胜云计算大数据时代” Spark亚太研究院100期公益大讲堂 [第6期互动问答分享] Q1:spark streaming 可以不同数据流 join吗? Spark Streaming不同的数据流 ...

  9. 大数据技术Hadoop入门理论系列之一----hadoop生态圈介绍

    Technorati 标记: hadoop,生态圈,ecosystem,yarn,spark,入门 1. hadoop 生态概况 Hadoop是一个由Apache基金会所开发的分布式系统基础架构. 用 ...

随机推荐

  1. 如何使a标签打开新页面并阻止刷新当前页面

    错误: HTML中,使用href属性时,当前页面和新页面均跳转到URL指定的页面,即当前页面也刷新: <li id='goToBack'><a href='**.action' ta ...

  2. docker 系列之 docker安装

    Docker支持以下的CentOS版本 CentOS 7 (64-bit) CentOS 6.5 (64-bit) 或更高的版本 前提条件 目前,CentOS 仅发行版本中的内核支持 Docker. ...

  3. 【转】SQLServer汉字转全拼音函数

    USE Test go IF OBJECT_ID('Fn_GetQuanPin','Fn') IS NOT NULL DROP FUNCTION fn_GetQuanPin go create fun ...

  4. C#实体对象序列化成Json并让字段的首字母小写的两种解决方法

    引言:最近在工作中遇到与某些API对接的post的数据需要将对象的字段首字母小写.解决办法有两种:第一种:使用对象的字段属性设置JsonProperty来实现(不推荐,因为需要手动的修改每个字段的属性 ...

  5. python字符串内建函数

  6. 微信小程序生成二维码

    微信小程序客户端生成二维码的方法, 请参考这里: https://github.com/demi520/wxapp-qrcode  代码片段如下: const QR = require(". ...

  7. python 首次安装 报错

    最近python很火,想在空余时间学习一波,但是安装完Python后运行发现居然报错了,错误代码是0xc000007b,于是通过往上查找发现是因为首次安装Python缺乏VC++库的原因 错误提示如下 ...

  8. php ajax bootstrap多文件上传图片预览,ajax上传文件

    <form enctype="multipart/form-data" id="upForm"> <label class="btn ...

  9. Spring Boot + Spring Cloud 实现权限管理系统 (集成 Shiro 框架)

    Apache Shiro 优势特点 它是一个功能强大.灵活的,优秀开源的安全框架. 它可以处理身份验证.授权.企业会话管理和加密. 它易于使用和理解,相比Spring Security入门门槛低. 主 ...

  10. JavaScript创建对象(三)——原型模式

    在JavaScript创建对象(二)——构造函数模式中提到,构造函数模式存在相同功能的函数定义多次的问题.本篇文章就来讨论一下该问题的解决方案——原型模式. 首先我们来看下什么是原型.我们在创建一个函 ...