线性回归

  • 需求:从文件读取数据对,计算回归函数及系数
  • 实现1:使用Apache Commons Math的SimpleRegression类,定义函数getData从文件读取数据对,并返回已填充数据的SimpleRegression对象

 1 import java.io.File;
2 import java.io.FileNotFoundException;
3 import java.util.Scanner;
4 import org.apache.commons.math3.stat.regression.SimpleRegression;
5
6 public class Example1 {
7 public static void main(String[] args) {
8 SimpleRegression sr = getData("data/Data1.dat");
9 double m = sr.getSlope();
10 double b = sr.getIntercept();
11 double r = sr.getR(); // correlation coefficient
12 double r2 = sr.getRSquare();
13 double sse = sr.getSumSquaredErrors();
14 double tss = sr.getTotalSumSquares();
15
16 System.out.printf("y = %.6fx + %.4f%n", m, b);
17 System.out.printf("r = %.6f%n", r);
18 System.out.printf("r2 = %.6f%n", r2);
19 System.out.printf("EV = %.5f%n", tss - sse);
20 System.out.printf("UV = %.4f%n", sse);
21 System.out.printf("TV = %.3f%n", tss);
22 }
23
24 public static SimpleRegression getData(String data) {
25 SimpleRegression sr = new SimpleRegression();
26 try {
27 Scanner fileScanner = new Scanner(new File(data));
28 fileScanner.nextLine(); // read past title line
29 int n = fileScanner.nextInt();
30 fileScanner.nextLine(); // read past line of labels
31 fileScanner.nextLine(); // read past line of labels
32 for (int i = 0; i < n; i++) {
33 String line = fileScanner.nextLine();
34 Scanner lineScanner = new Scanner(line).useDelimiter("\\t");
35 double x = lineScanner.nextDouble();
36 double y = lineScanner.nextDouble();
37 sr.addData(x, y);
38 }
39 } catch (FileNotFoundException e) {
40 System.err.println(e);
41 }
42 return sr;
43 }
44 }
  • 实现2:不借助回归类,直接累加求和量(Σx、Σx²、Σy、Σy²、Σxy)并由此计算各统计量

 1 import java.io.File;
2 import java.io.FileNotFoundException;
3 import java.util.Scanner;
4
/**
 * Simple linear regression computed directly from the running sums
 * Σx, Σx², Σy, Σy² and Σxy, without a library regression class.
 */
public class Example2 {
    /** Data file read when no path is supplied on the command line. */
    private static final String DEFAULT_DATA_FILE = "data/Data1.dat";

    // Running sums and point count filled in by getData(); reset on every call.
    private static double sX=0, sXX=0, sY=0, sYY=0, sXY=0;
    private static int n=0;

    /**
     * Reads the data, derives the regression statistics and prints them.
     *
     * @param args optional; {@code args[0]} overrides the default data file path
     */
    public static void main(String[] args) {
        getData(args.length > 0 ? args[0] : DEFAULT_DATA_FILE);
        double m = (n*sXY - sX*sY)/(n*sXX - sX*sX);
        double b = sY/n - m*sX/n;
        double r2 = m*m*(n*sXX - sX*sX)/(n*sYY - sY*sY);
        // r carries the sign of the slope; sqrt(r2) on its own is never negative.
        double r = Math.signum(m)*Math.sqrt(r2);
        double tv = sYY - sY*sY/n;
        double mX = sX/n; // mean value of x
        double ev = (sXX - 2*mX*sX + n*mX*mX)*m*m; // m² · Σ(x - mean x)²
        double uv = tv - ev;

        System.out.printf("y = %.6fx + %.4f%n", m, b);
        System.out.printf("r = %.6f%n", r);
        System.out.printf("r2 = %.6f%n", r2);
        System.out.printf("EV = %.5f%n", ev);
        System.out.printf("UV = %.4f%n", uv);
        System.out.printf("TV = %.3f%n", tv);
    }

    /**
     * Reads the data file and stores the point count and running sums in the static fields.
     *
     * <p>The file is consumed as: a title line, a line starting with the row count n,
     * one further line (column labels), then n lines each holding x and y separated by
     * a tab. If the file cannot be opened, the error is printed to standard error and
     * all sums stay zero.
     *
     * @param data path of the data file
     */
    public static void getData(String data) {
        // Start from zero so repeated calls do not add to sums from an earlier file.
        sX = sXX = sY = sYY = sXY = 0;
        n = 0;
        try (Scanner fileScanner = new Scanner(new File(data))) {
            fileScanner.nextLine(); // read past title line
            n = fileScanner.nextInt();
            fileScanner.nextLine(); // read past the remainder of the line holding n
            fileScanner.nextLine(); // read past line of labels
            for (int i = 0; i < n; i++) {
                String line = fileScanner.nextLine();
                try (Scanner lineScanner = new Scanner(line).useDelimiter("\\t")) {
                    double x = lineScanner.nextDouble();
                    double y = lineScanner.nextDouble();
                    sX += x;
                    sXX += x*x;
                    sY += y;
                    sYY += y*y;
                    sXY += x*y;
                }
            }
        } catch (FileNotFoundException e) {
            System.err.println(e);
        }
    }
}

y = 0.882279x + 18.8739
r = 0.935222
r2 = 0.874641
EV = 1423.35676
UV = 204.0042
TV = 1627.361

  • 实现3:将数据读取与统计量计算封装在辅助类Data中,实例化后用Swing面板RegressionPanel绘制散点图和回归直线

Example3.java

 1 import java.io.File;
2 import javax.swing.JFrame;
3
4 public class Example3 {
5 public static void main(String[] args) {
6 Data data = new Data(new File("data/Data1.dat"));
7 JFrame frame = new JFrame(data.getTitle());
8 frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
9 RegressionPanel panel = new RegressionPanel(data);
10 frame.add(panel);
11 frame.pack();
12 frame.setSize(500, 422);
13 frame.setResizable(false);
14 frame.setLocationRelativeTo(null); // center frame on screen
15 frame.setVisible(true);
16 }
17 }

Data.java

  1 import java.io.File;
2 import java.io.FileNotFoundException;
3 import java.util.Scanner;
4
/**
 * A two-variable data set read from a text file, together with its summary statistics
 * and least-squares regression line y = slope·x + intercept.
 *
 * <p>The file is read as: a title line; the number of points n; the x and y column
 * names; then n (x, y) value pairs. Everything after the title is whitespace-separated.
 */
public class Data {
    private String title,xName, yName;
    private int n;
    // Empty (never null) so the accessors are safe even when the file could not be read.
    private double[] x = new double[0], y = new double[0];
    private double sX, sXX, sY, sYY, sXY, minX, minY, maxX, maxY;
    private double meanX, meanY, slope, intercept, corrCoef;

    /**
     * Reads the data set from {@code inputFile} and computes its statistics.
     *
     * <p>If the file cannot be opened, the error is printed to standard error and the
     * object stays empty: n is 0, the title and names are null, and x and y are empty.
     *
     * @param inputFile file in the layout described on the class
     */
    public Data(File inputFile) {
        try (Scanner input = new Scanner(inputFile)) {
            readHeader(input);
            readPoints(input);
            computeStatistics();
        } catch (FileNotFoundException e) {
            System.err.println(e);
        }
    }

    /** Reads the title line, the point count and the two column names. */
    private void readHeader(Scanner input) {
        title = input.nextLine();
        n = input.nextInt();
        xName = input.next();
        yName = input.next();
        input.nextLine(); // discard the rest of the column-name line
    }

    /** Reads n (x, y) pairs, accumulating the running sums and the extreme values. */
    private void readPoints(Scanner input) {
        x = new double[n];
        y = new double[n];
        minX = minY = Double.POSITIVE_INFINITY;
        maxX = maxY = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < n; i++) {
            double xi = x[i] = input.nextDouble();
            double yi = y[i] = input.nextDouble();
            sX += xi;
            sXX += xi*xi;
            sY += yi;
            sYY += yi*yi;
            sXY += xi*yi;
            minX = (xi < minX? xi: minX);
            minY = (yi < minY? yi: minY);
            maxX = (xi > maxX? xi: maxX);
            maxY = (yi > maxY? yi: maxY);
        }
    }

    /** Derives means, regression slope/intercept and the correlation coefficient from the sums. */
    private void computeStatistics() {
        meanX = sX/n;
        meanY = sY/n;
        slope = (n*sXY - sX*sY)/(n*sXX - sX*sX);
        intercept = meanY - slope*meanX;
        // r = slope·sqrt(Sxx/Syy) keeps the sign of the slope.
        corrCoef = slope*Math.sqrt((n*sXX - sX*sX)/(n*sYY - sY*sY));
    }

    /** @return the title line of the file, or null if it could not be read */
    public String getTitle() {
        return title;
    }

    /** @return the name of the x column, or null if it could not be read */
    public String getXName() {
        return xName;
    }

    /** @return the name of the y column, or null if it could not be read */
    public String getYName() {
        return yName;
    }

    /** @return the number of data points */
    public int getN() {
        return n;
    }

    /** @return a copy of the x values, in file order (empty, never null) */
    public double[] getX() {
        return x.clone();
    }

    /** @return a copy of the y values, in file order (empty, never null) */
    public double[] getY() {
        return y.clone();
    }

    /** @return the mean of the x values */
    public double getMeanX() {
        return meanX;
    }

    /** @return the mean of the y values */
    public double getMeanY() {
        return meanY;
    }

    /** @return the least-squares slope */
    public double getSlope() {
        return slope;
    }

    /** @return the least-squares intercept */
    public double getIntercept() {
        return intercept;
    }

    /** @return the (signed) correlation coefficient r */
    public double getCorrCoef() {
        return corrCoef;
    }

    /** @return a new n×2 array whose rows are the (x, y) pairs */
    public double[][] getTable() {
        double[][] table = new double[n][2];
        for (int i = 0; i < n; i++) {
            table[i][0] = x[i];
            table[i][1] = y[i];
        }
        return table;
    }

    /** @return the smallest x value */
    public double getMinX() {
        return minX;
    }

    /** @return the smallest y value */
    public double getMinY() {
        return minY;
    }

    /** @return the largest x value */
    public double getMaxX() {
        return maxX;
    }

    /** @return the largest y value */
    public double getMaxY() {
        return maxY;
    }
}

RegressionPanel.java

import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.Graphics2D;
import javax.swing.JPanel;

/**
 * Draws a {@code Data} set as a scatter plot on a light-gray grid, overlaid with its
 * regression line y = slope·x + intercept.
 *
 * <p>Drawing uses fixed pixel dimensions (WIDTH × HEIGHT). Data coordinates are mapped
 * into a plot area inset by MARGIN on every side, plus BUFFER on the left and bottom
 * where the axis labels go.
 */
public class RegressionPanel extends JPanel {
    private static final int WIDTH=500, HEIGHT=400, BUFFER=28, MARGIN=40;
    private final Data data;
    private final double xMin, xMax, yMin, yMax, xRange, yRange, gWidth, gHeight;
    private final double slope, intercept;

    /**
     * Captures the extents and regression coefficients of {@code data} for painting.
     *
     * @param data the data set to plot
     */
    public RegressionPanel(Data data) {
        this.data = data;
        this.setSize(WIDTH, HEIGHT);
        // Layout managers and Window.pack() consult the preferred size, not setSize.
        this.setPreferredSize(new Dimension(WIDTH, HEIGHT));
        this.xMin = data.getMinX();
        this.xMax = data.getMaxX();
        this.yMin = data.getMinY();
        this.yMax = data.getMaxY();
        this.slope = data.getSlope();
        this.intercept = data.getIntercept();
        this.xRange = xMax - xMin;
        this.yRange = yMax - yMin;
        this.gWidth = WIDTH - 2*MARGIN - BUFFER;
        this.gHeight = HEIGHT - 2*MARGIN - BUFFER;
        setBackground(Color.WHITE);
    }

    @Override
    public void paintComponent(Graphics g) {
        super.paintComponent(g);
        Graphics2D g2 = (Graphics2D)g;
        drawGrid(g2);
        drawPoints(g2, data.getX(), data.getY());
        drawLine(g2);
    }

    /** Draws labelled vertical and horizontal grid lines at power-of-ten spacing. */
    private void drawGrid(Graphics2D g2) {
        g2.setStroke(new BasicStroke(1));
        int xd = gridStep(xRange);
        int x0 = dToI(xd*Math.floor(xMin/xd));
        int xn = dToI(xd*Math.ceil(xMax/xd));
        for (int xi = x0; xi <= xn; xi += xd) {
            g2.setColor(Color.LIGHT_GRAY);
            int p = f(xi);
            g2.drawLine(p, 0, p, HEIGHT-18); // vertical lines
            g2.setColor(Color.BLACK);
            g2.drawString(""+xi, p-8, HEIGHT-4);
        }
        // The horizontal grid is derived from the y extent and y spacing.
        int yd = gridStep(yRange);
        int y0 = dToI(yd*Math.floor(yMin/yd));
        int yn = dToI(yd*Math.ceil(yMax/yd));
        for (int yi = y0; yi <= yn; yi += yd) {
            g2.setColor(Color.LIGHT_GRAY);
            int q = g(yi);
            g2.drawLine(BUFFER, q, WIDTH, q); // horizontal lines
            g2.setColor(Color.BLACK);
            g2.drawString((yi<100?" ":"")+yi, 2, q+5);
        }
    }

    /**
     * Returns the grid spacing for a data range: the largest power of ten not exceeding
     * {@code range}, but never less than 1. Grid values are ints, so a spacing that
     * rounded to 0 would make the drawing loop never advance.
     */
    private int gridStep(double range) {
        double step = Math.pow(10, Math.floor(Math.log10(range)));
        return Math.max(1, dToI(step));
    }

    /** Draws each (x[i], y[i]) as a filled 6-pixel dot. */
    private void drawPoints(Graphics2D g2, double[] x, double[] y) {
        if (x == null || y == null) {
            return;
        }
        g2.setColor(Color.BLACK);
        for (int i = 0; i < x.length; i++) {
            int u = f(x[i]);
            int v = g(y[i]);
            g2.fillOval(u-3, v-3, 6, 6); // coordinates are at NW corners
        }
    }

    /** Draws the regression line across the full width of the plot, in blue. */
    private void drawLine(Graphics2D g2) {
        g2.setColor(Color.BLUE);
        g2.setStroke(new BasicStroke(2));
        int p0 = BUFFER;
        int q0 = g(yLine(fInv(p0)));
        int p1 = WIDTH;
        int q1 = g(yLine(fInv(p1)));
        g2.drawLine(p0, q0, p1, q1);
    }

    /** Evaluates the regression line at {@code x}. */
    private double yLine(double x) {
        return slope*x + intercept;
    }

    /** Rounds a double to the nearest int pixel coordinate. */
    private int dToI(double x) {
        return (int)Math.round(x);
    }

    /** Maps a data x value to a horizontal pixel coordinate. */
    private int f(double x) {
        return dToI((x - xMin)*gWidth/xRange) + BUFFER + MARGIN;
    }

    /** Maps a data y value to a vertical pixel coordinate (pixel y grows downward). */
    private int g(double y) {
        return dToI(gHeight - (y - yMin)*gHeight/yRange) + MARGIN;
    }

    /** Inverse of {@link #f}: maps a horizontal pixel coordinate back to a data x value. */
    private double fInv(int p) {
        return (p - BUFFER - MARGIN)*xRange/gWidth + xMin;
    }

    /** Inverse mapping from a vertical pixel coordinate to a data y value (currently unused). */
    private double gInv(int q) {
        return yMin + (gHeight + MARGIN - q)*yRange/gHeight;
    }
}

多项式回归

  • 需求:已知车速与刹车距离的数据,用二次多项式拟合刹车距离与车速的关系,并预测车速为55时的刹车距离
  • 实现:用最小二乘法推导正规方程组,再用LU分解求解该方程组

 1 import org.apache.commons.math3.linear.*;
2
3 public class Example4 {
4 static double[] x = {20, 30, 40, 50, 60, 70};
5 static double[] y = {52, 87, 136, 203, 290, 394};
6 static int n = y.length; // 6
7
8 public static void main(String[] args) {
9 double[][] a = new double[3][3];
10 double[] w = new double[3];
11 deriveNormalEquations(a, w);
12 printNormalEquations(a, w);
13 double[] b = solveNormalEquations(a, w);
14 printResults(b);
15 }
16
17 public static void deriveNormalEquations(double[][] a, double[] w) {
18 for (int i = 0; i < n; i++) {
19 double xi = x[i];
20 double yi = y[i];
21 a[0][0] = n;
22 a[0][1] = a[1][0] += xi;
23 a[0][2] = a[1][1] = a[2][0] += xi*xi;
24 a[1][2] = a[2][1] += xi*xi*xi;
25 a[2][2] += xi*xi*xi*xi;
26 w[0] += yi;
27 w[1] += xi*yi;
28 w[2] += xi*xi*yi;
29 }
30 }
31
32 public static void printNormalEquations(double[][] a, double[] w) {
33 for (int i = 0; i < 3; i++) {
34 System.out.printf("%8.0fb0 + %6.0fb1 + %8.0fb2 = %7.0f%n",
35 a[i][0], a[i][1], a[i][2], w[i]);
36 }
37 }
38
39 /* Solves the matrix equation a*b = w for b[], representing a[]
40 as RealMatrix m and b[] as RealVector v:
41 */
42 private static double[] solveNormalEquations(double[][] a, double[] w) {
43 RealMatrix m = new Array2DRowRealMatrix(a, false);
44 LUDecomposition lud = new LUDecomposition(m);
45 DecompositionSolver solver = lud.getSolver();
46 RealVector v = new ArrayRealVector(w, false);
47 return solver.solve(v).toArray();
48 }
49
50 private static void printResults(double[] b) {
51 System.out.printf("f(t) = %.2f + %.3ft + %.5ft^2%n", b[0], b[1], b[2]);
52 System.out.printf("f(55) = %.1f%n", f(55, b));
53 }
54
55 private static double f(double t, double[] b) {
56 return b[0] + b[1]*t + b[2]*t*t;
57 }
58 }

6b0 + 270b1 + 13900b2 = 1162
270b0 + 13900b1 + 783000b2 = 64220
13900b0 + 783000b1 + 46750000b2 = 3798800
f(t) = 40.73 + -1.170t + 0.08875t^2
f(55) = 244.8

多元线性回归

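  • 需求:因变量依赖于多个自变量(此处为两个自变量s和t),求线性回归函数 f(s, t) = b0 + b1·s + b2·t
  • 实现:直接构造并求解正规方程组(Example5),或使用Apache Commons Math的OLSMultipleLinearRegression类(Example6)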

Example5.java

 1 import org.apache.commons.math3.linear.*;
2
3 public class Example5 {
4 static double[] x = {10, 9, 12, 10, 9, 10, 8, 11};
5 static double[] y = {59, 57, 61, 52, 48, 55, 51, 62};
6 static double[] z = {71, 68, 76, 56, 57, 77, 55, 67};
7 static int n = z.length; // 8
8
9 public static void main(String[] args) {
10 double[][] a = new double[3][3];
11 double[] w = new double[3];
12 deriveNormalEquations(a, w);
13 printNormalEquations(a, w);
14 double[] b = solveNormalEquations(a, w);
15 printResults(b);
16 }
17
18 public static void deriveNormalEquations(double[][] a, double[] w) {
19 for (int i = 0; i < n; i++) {
20 double xi = x[i];
21 double yi = y[i];
22 double zi = z[i];
23 a[0][0] = n;
24 a[0][1] = a[1][0] += xi;
25 a[0][2] = a[2][0] += yi;
26 a[1][1] += xi*xi;
27 a[1][2] = a[2][1] += xi*yi;
28 a[2][2] += yi*yi;
29 w[0] += zi;
30 w[1] += xi*zi;
31 w[2] += yi*zi;
32 }
33 }
34
35 public static void printNormalEquations(double[][] a, double[] w) {
36 for (int i = 0; i < 3; i++) {
37 System.out.printf("%6.0fx0 + %4.0fx1 + %5.0fx2 = %5.0f%n",
38 a[i][0], a[i][1], a[i][2], w[i]);
39 }
40 }
41
42 private static double[] solveNormalEquations(double[][] a, double[] w) {
43 RealMatrix m = new Array2DRowRealMatrix(a, false);
44 LUDecomposition lud = new LUDecomposition(m);
45 DecompositionSolver solver = lud.getSolver();
46 RealVector v = new ArrayRealVector(w, false);
47 return solver.solve(v).toArray();
48 }
49
50 private static void printResults(double[] b) {
51 System.out.printf("f(s, t) = %.2f + %.2fs + %.2ft%n", b[0], b[1], b[2]);
52 System.out.printf("f(10, 59) = %.1f%n", f(10, 59, b));
53 System.out.printf("f(9, 57) = %.1f%n", f(9, 57, b));
54 System.out.printf("f(11, 64) = %.1f%n", f(11, 64, b));
55 }
56
57 private static double f(double s, double t, double[] b) {
58 return b[0] + b[1]*s + b[2]*t;
59 }
60 }

Example6.java

 1 import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
2
3 public class Example6 {
4 static double[][] x = { {10, 59}, {9, 57}, {12, 61}, {10, 52}, {9, 48},
5 {10, 55}, {8, 51}, {11, 62} };
6 static double[] y = {71, 68, 76, 56, 57, 77, 55, 67};
7
8 public static void main(String[] args) {
9 OLSMultipleLinearRegression mlr = new OLSMultipleLinearRegression();
10 mlr.newSampleData(y, x);
11 double[] b = mlr.estimateRegressionParameters();
12 printResults(b);
13 }
14
15 private static void printResults(double[] b) {
16 System.out.printf("f(s, t) = %.2f + %.2fs + %.2ft%n", b[0], b[1], b[2]);
17 System.out.printf("f(10, 59) = %.1f%n", f(10, 59, b));
18 System.out.printf("f(9, 57) = %.1f%n", f(9, 57, b));
19 System.out.printf("f(11, 64) = %.1f%n", f(11, 64, b));
20 }
21
22 private static double f(double s, double t, double[] b) {
23 return b[0] + b[1]*s + b[2]*t;
24 }
25 }

8x0 + 79x1 + 445x2 = 527
79x0 + 791x1 + 4427x2 = 5254
445x0 + 4427x1 + 24929x2 = 29543
f(s, t) = -5.75 + 1.55s + 1.01t
f(10, 59) = 69.5
f(9, 57) = 65.9
f(11, 64) = 76.1

[Java] 数据分析 -- 回归分析的更多相关文章

  1. [Java] 数据分析 -- 大数据

    单词计数 需求:输入小说文本,输出每个单词出现的次数 实现:分map.combine.reduce三个阶段实现 1 /* Data Analysis with Java 2 * John R. Hub ...

  2. [Java] 数据分析 -- NoSQL数据库

    MongoDB概念:与关系型数据库对应 database(数据库):数据库 collection(集合):表 document(文档):行 field(域):列/字段 注意事项 文档是一组键值(key ...

  3. [Java]数据分析--聚类

    距离度量 需求:计算两点间的欧几里得距离.曼哈顿距离.切比雪夫距离.堪培拉距离 实现:利用commons.math3库相应函数 1 import org.apache.commons.math3.ml ...

  4. [Java] 数据分析--分类

    ID3算法 思路:分类算法的输入为训练集,输出为对数据进行分类的函数.ID3算法为分类函数生成分类树 需求:对水果训练集的一个维度(是否甜)进行预测 实现:决策树,熵函数,ID3,weka库 J48类 ...

  5. [Java] 数据分析--统计

    二项分布 需求:5个四面体筛子,筛子三面绿色,一面红色,模拟1000000次,统计每次试验红色落地筛子个数的分布 实现:用循环实现5个筛子和1000000次试验,定义函数numRedDown模拟5个筛 ...

  6. [Java]数据分析--数据可视化

    时间序列 需求:将一组字符顺序添加到时间序列中 实现:定义时间序列类TimeSeries,包含静态类Entry表示序列类中的各项,以及add,get,iterator,entry方法 TimeSeri ...

  7. [Java] 数据分析--数据预处理

    数据结构 键-值对:HashMap 1 import java.io.File; 2 import java.io.FileNotFoundException; 3 import java.util. ...

  8. Spark案例分析

    一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...

  9. 一元线性回归分析及java实现

    http://blog.csdn.net/hwwn2009/article/details/38414911 一元线性回归分析及java实现 2014-08-07 11:02 1072人阅读 评论(0 ...

随机推荐

  1. 面试高频题:说一说对Spring和SpringMvc父子容器的理解?

    引言 以前写了几篇关于SpringBoot的文章<面试高频题:springBoot自动装配的原理你能说出来吗>.<保姆级教程,手把手教你实现一个SpringBoot的starter& ...

  2. 【分布式】SpringCloud(3)--Eureka服务注册与发现

    1.Eureka概述 1.1.什么是Eureka Eureka是Netflix的一个子模块.基于REST的服务,用于定位服务,以实现云端中间层服务发现和故障转移. 只需要使用服务的标识符,就可以访问到 ...

  3. PAT (Basic Level) Practice (中文) 1050 螺旋矩阵 (25 分) 凌宸1642

    PAT (Basic Level) Practice (中文) 1050 螺旋矩阵 (25 分) 目录 PAT (Basic Level) Practice (中文) 1050 螺旋矩阵 (25 分) ...

  4. 定制开发——GitHub 热点速览 v.21.15

    作者:HelloGitHub-小鱼干 自定义 或者说 定制 是本周 GitHub 热点的最佳写照.比如,lipgloss 这个项目,可以让你自己定义终端样式,五彩斑斓的黑终端来一个.接着,是 Appl ...

  5. CSS3新增了哪些新特性

    一.是什么 css,即层叠样式表(Cascading Style Sheets)的简称,是一种标记语言,由浏览器解释执行用来使页面变得更为美观 css3是css的最新标准,是向后兼容的,CSS1/2的 ...

  6. Leedcode算法专题训练(动态规划)

    递归和动态规划都是将原问题拆成多个子问题然后求解,他们之间最本质的区别是,动态规划保存了子问题的解,避免重复计算. 斐波那契数列 1. 爬楼梯 70. Climbing Stairs (Easy) L ...

  7. centos7.4 卸载python2.7.5安装python3.6.3版本

    CentOS 中默认安装了 2.7的Python,为了使用新版 python,可以对旧版本进行升级.但是由于很多基本的命令.软件包都依赖旧版本,比如:yum等.所以,在更新 Python 时,建议不要 ...

  8. JavaFX获取屏幕尺寸

    1 awt Dimension screenSize = Toolkit.getDefaultToolkit().getScreenSize(); double width = screenSize. ...

  9. Ball

    玉 図のように二股に分かれている容器があります.1 から 10 までの番号が付けられた10 個の玉を容器の開口部 A から落とし.左の筒 B か右の筒 C に玉を入れます.板 D は支点 E を中心に ...

  10. 从苏宁电器到卡巴斯基第15篇:我在苏宁电器当营业员 VII

    我们苹果的倒班制度 当年我在苏宁的时候,实行的是单休制度,而且只能选择在周一到周五其中的某一天,因为周六周日顾客比较多,是不允许休息的.尽管是单休,但并不表示我们在上班的时候每天都要完完整整地上八小时 ...