>>提君博客原创  http://www.cnblogs.com/tijun/  <<

假定线性拟合方程:

提君博客原创

变量 X是 i 个变量或者说属性 

参数 ai 是模型训练的目的就是计算出这些参数的值。

线性回归分析的整个过程可以简单描述为如下三个步骤:

  1. 寻找合适的预测函数,即上文中的 h(x)h(x) ,用来预测输入数据的判断结果。这个过程时非常关键的,需要对数据有一定的了解或分析,知道或者猜测预测函数的“大概”形式,比如是线性函数还是非线性函数,若是非线性的则无法用线性回归来得出高质量的结果。
  2. 构造一个Loss函数(损失函数),该函数表示预测的输出(h)与训练数据标签之间的偏差,可以是二者之间的差(h-y)或者是其他的形式(如平方差开方)。综合考虑所有训练数据的“损失”,将Loss求和或者求平均,记为 J(θ)J(θ) 函数,表示所有训练数据预测值与实际类别的偏差。
  3. 显然, J(θ)J(θ) 函数的值越小表示预测函数越准确(即h函数越准确),所以这一步需要做的是找到 J(θ)J(θ) 函数的最小值。找函数的最小值有不同的方法,Spark中采用的是梯度下降法(stochastic gradient descent, SGD)。

线性回归同样可以采用正则化手段,其主要目的就是防止过拟合。

当采用L1正则化时,则变成了Lasso Regresion;当采用L2正则化时,则变成了Ridge Regression;线性回归未采用正则化手段。通常来说,在训练模型时是建议采用正则化手段的,特别是在训练数据的量特别少的时候,若不采用正则化手段,过拟合现象会非常严重。L2正则化相比L1而言会更容易收敛(迭代次数少),但L1可以解决训练数据量小于维度的问题(也就是n元一次方程只有不到n个表达式,这种情况下是多解或无穷解的)。

提君博客原创

在spark中分三种回归:LinearRegression、Lasso和RidgeRegression(岭回归)

采用L1正则化时为Lasso回归(元素绝对值),采用L2时为RidgeRegression回归(元素平方),没有正则化时就是线性回归。

比如岭回归的损失函数: 

显然,损失函数值越小说明当前这条直线拟合效果越好>>提君博客原创  http://www.cnblogs.com/tijun/  <<
通常用梯度下降法 用来最小化损失值? 

spark中线性回归算法可使用的类包括LinearRegression、LassoWithSGD、RidgeRegressionWithSGD(SGD代表随机梯度下降法),

这几个类都有几个可以用来对算法调优的参数

  • numIterations 要迭代的次数
  • stepSize 梯度下降的步长(默认1.0)
  • intercept 是否给数据加上一个干扰特征或者偏差特征(默认:false)
  • regParam Lasso和ridge的正规参数(默认1.0)

下面是实例>>提君博客原创  http://www.cnblogs.com/tijun/  <<

训练集下载

训练集概况

-0.4307829,-1.63735562648104 -2.00621178480549 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
-0.1625189,-1.98898046126935 -0.722008756122123 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
-0.1625189,-1.57881887548545 -2.1887840293994 1.36116336875686 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.155348103855541
...

数据格式:逗号之前为label;之后为8个特征值,以空格分隔。

代码

package com.ltt.spark.ml.example;

import org.apache.spark.api.java.*;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.GeneralizedLinearModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LassoModel;
import org.apache.spark.mllib.regression.LassoWithSGD;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import org.apache.spark.mllib.regression.RidgeRegressionModel;
import org.apache.spark.mllib.regression.RidgeRegressionWithSGD; import java.util.Arrays; import org.apache.spark.SparkConf;
import scala.Tuple2; /**
*
* Title: LinearRegresionExample.java
* Description: 本地代码执行,机器学习之线性回归
* <br/>
* @author liutiti
* @created 2017年11月21日 下午4:03:45
*/
@SuppressWarnings("resource")
public class LinearRegresionExample { /**
*
* @discription 程序测试入口
* @author liutiti
* @created 2017年11月21日 上午4:03:45
* @param args
*/
public static void main(String[] args) {
SparkConf sparkConf = new SparkConf().setAppName("LinearRegresion").setMaster("local[*]");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
//原始的数据-0.4307829,-1.63735562648104 -2.00621178480549 ...
JavaRDD<String> data = sc.textFile("E:\\spark-ml-data\\lpsa.txt"); //转换数据格式:把每一行原始的数据(num1,num2 num3 ...)转换成LabeledPoint(label, features)
JavaRDD<LabeledPoint> parsedData = data.filter(line -> { //过滤掉不符合的数据行
if(line.length() > 3)
return true;
return false;
}).map(line -> { //读取转换成LabeledPoint
String[] parts = line.split(","); //逗号分隔
double[] ds = Arrays.stream(parts[1].split(" ")) //空格分隔
.mapToDouble(Double::parseDouble)
.toArray();
return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(ds));
});
//rdd持久化内存中,后边反复使用,不必再从磁盘加载
parsedData.cache(); //设置迭代次数
int numIterations = 100;
//三种模型进行训练
LinearRegressionModel linearModel = LinearRegressionWithSGD.train(parsedData.rdd(), numIterations);
RidgeRegressionModel ridgeModel = RidgeRegressionWithSGD.train(parsedData.rdd(), numIterations);
LassoModel lassoModel = LassoWithSGD.train(parsedData.rdd(), numIterations);
//打印信息
print(parsedData, linearModel);
print(parsedData, ridgeModel);
print(parsedData, lassoModel); //预测一条新数据方法,8个特征值
double[] d = new double[]{1.0, 1.0, 2.0, 1.0, 3.0, -1.0, 1.0, -2.0};
Vector v = Vectors.dense(d);
System.out.println("Prediction of linear: "+linearModel.predict(v));
System.out.println("Prediction of ridge: "+ridgeModel.predict(v));
System.out.println("Prediction of lasso: "+lassoModel.predict(v)); // //保存模型
// model.save(sc.sc(),"myModelPath" );
// //加载模型
// LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); //关闭
sc.close();
} /**
*
* @discription 统一输出方法
* @author liutiti
* @created 2017年11月22日 上午10:00:27
* @param parsedData
* @param model
*/
public static void print(JavaRDD<LabeledPoint> parsedData, GeneralizedLinearModel model) {
JavaPairRDD<Double, Double> valuesAndPreds = parsedData.mapToPair(point -> {
double prediction = model.predict(point.features()); //用模型预测训练数据
return new Tuple2<>(point.label(), prediction);
});
//打印训练集中的真实值与相对应的预测值
valuesAndPreds.foreach((Tuple2<Double, Double> t) -> {
System.out.println("训练集真实值:"+t._1()+" ,预测值: "+t._2());
});
//计算预测值与实际值差值的平方值的均值
Double MSE = valuesAndPreds.mapToDouble((Tuple2<Double, Double> t) -> Math.pow(t._1() - t._2(), 2)).mean();
System.out.println(model.getClass().getName() + " training Mean Squared Error = " + MSE);
}
}

提君博客原创

>>提君博客原创  http://www.cnblogs.com/tijun/  <<

spark官方java api 文档

spark-MLlib之线性回归的更多相关文章

  1. Spark MLlib之线性回归源代码分析

    1.理论基础 线性回归(Linear Regression)问题属于监督学习(Supervised Learning)范畴,又称分类(Classification)或归纳学习(Inductive Le ...

  2. spark mllib 之线性回归

    public static void main(String[] args) { SparkConf sparkConf = new SparkConf() .setAppName("Reg ...

  3. Spark MLlib回归算法------线性回归、逻辑回归、SVM和ALS

    Spark MLlib回归算法------线性回归.逻辑回归.SVM和ALS 1.线性回归: (1)模型的建立: 回归正则化方法(Lasso,Ridge和ElasticNet)在高维和数据集变量之间多 ...

  4. Spark Mllib里如何生成KMeans的训练样本数据、生成线性回归的训练样本数据、生成逻辑回归的训练样本数据和其他数据生成

    不多说,直接上干货! 具体,见 Spark Mllib机器学习(算法.源码及实战详解)的第2章 Spark数据操作

  5. 《Spark MLlib机器学习实践》内容简介、目录

      http://product.dangdang.com/23829918.html Spark作为新兴的.应用范围最为广泛的大数据处理开源框架引起了广泛的关注,它吸引了大量程序设计和开发人员进行相 ...

  6. Spark入门实战系列--8.Spark MLlib(上)--机器学习及SparkMLlib简介

    [注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 .机器学习概念 1.1 机器学习的定义 在维基百科上对机器学习提出以下几种定义: l“机器学 ...

  7. Spark入门实战系列--8.Spark MLlib(下)--机器学习库SparkMLlib实战

    [注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 .MLlib实例 1.1 聚类实例 1.1.1 算法说明 聚类(Cluster analys ...

  8. Spark MLlib知识点学习整理

    MLlib的设计原理:把数据以RDD的形式表示,然后在分布式数据集上调用各种算法.MLlib就是RDD上一系列可供调用的函数的集合. 操作步骤: 1.用字符串RDD来表示信息. 2.运行MLlib中的 ...

  9. 推荐系统那点事 —— 基于Spark MLlib的特征选择

    在机器学习中,一般都会按照下面几个步骤:特征提取.数据预处理.特征选择.模型训练.检验优化.那么特征的选择就很关键了,一般模型最后效果的好坏往往都是跟特征的选择有关系的,因为模型本身的参数并没有太多优 ...

  10. Spark Mllib框架1

    1. 概述 1.1 功能 MLlib是Spark的机器学习(machine learing)库,其目标是使得机器学习的使用更加方便和简单,其具有如下功能: ML算法:常用的学习算法,包括分类.回归.聚 ...

随机推荐

  1. [PHP] yield沟通函数循环内外

    1.yield是函数内外,循环内外沟通用的 , 当你的函数需要返回一个大数组 , 循环的时候需要遍历这个大数组时 , 并且需要多次遍历这个函数的返回值 , 这个是有用的 2.当我也是只需要在一次循环中 ...

  2. Java学习笔记之——IO

    一. IO IO读写 流分类: 按照方向:输入流(读),输出流(写) 按照数据单位:字节流(传输时以字节为单位),字符流(传输时以字符为单位) 按照功能:节点流,过滤流 四个抽象类: InputStr ...

  3. 【学习笔记】tensorflow实现一个简单的线性回归

    目录 准备知识 Tensorflow运算API 梯度下降API 简单的线性回归的实现 建立事件文件 变量作用域 增加变量显示 模型的保存与加载 自定义命令行参数 准备知识 Tensorflow运算AP ...

  4. 史上最全python面试题详解(四)(附带详细答案(关注、持续更新))

    python高级进阶-网络编程和并发(?道题详解) 1.简述 OSI 七层协议. OSI是Open System Interconnection的缩写,意为开放式系统互联. OSI七层协议模型主要是: ...

  5. 深入理解 JavaScript 执行上下文和执行栈

    前言 如果你是一名 JavaScript 开发者,或者想要成为一名 JavaScript 开发者,那么你必须知道 JavaScript 程序内部的执行机制.执行上下文和执行栈是 JavaScript ...

  6. Web前端 页面功能——点击按钮返回顶部的实现方法

    1. 最简单的静态返回顶部,点击直接跳转页面顶部,常见于固定放置在页面底部返回顶部功能 方法一:用命名锚点击返回到顶部预设的id为top的元素 html代码 <a href="#top ...

  7. Sublime Text 快捷键列表

    Sublime Text 快捷键列表 快捷键按类型分列如下: 补充:1.快速的创建一个html页 :ctrl+n创建一个新的文件-->右下角选择文件类型-->输入英文"!&quo ...

  8. 十分钟(小时)学习pandas

    十分钟学习pandas 一.导语 这篇文章从pandas官网翻译:链接,而且也有很多网友翻译过,而我为什么没去看他们的,而是去官网自己艰难翻译呢? 毕竟这是一个学习的过程,别人写的不如自己写的记忆深刻 ...

  9. 一对多Excel自定义函数:SVLOOKUP

    语法规则 该函数的语法规则如下: SVLOOKUP(lookup_value,table_array,col_index_num,nth_appearance,unique_value) 参数 简单说 ...

  10. 如何开启红米手机4X的ROOT超级权限

    红米手机4X通过什么方法拥有了root权限?大家都清楚,Android机器有root权限,如果手机拥有了root相关权限,可以实现更强的功能,举个栗子大家公司的营销部门同事,使用大多数营销软件都需要在 ...