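Linear regression and logistic regression are implemented in largely the same way, so the shared workflow is factored into an abstract class, Regression. It declares three abstract methods (PreVal, CostFun, and Update) that LinearRegression and LogisticRegression each override.

Samples are represented by a Sample class, which stores the features in an array.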

1. The Sample class

public class Sample {

    double[] features; // feature vector
    int feaNum;        // the number of the sample's features
    double value;      // target value of the sample (used in regression)
    int label;         // class of the sample (used in classification)

    public Sample(int number) {
        feaNum = number;
        features = new double[feaNum];
    }

    public void outSample() {
        System.out.println("The sample's features are:");
        for (int i = 0; i < feaNum; i++) {
            System.out.print(features[i] + " ");
        }
        System.out.println();
        System.out.println("The label is: " + label);
        System.out.println("The value is: " + value);
    }
}

2. The abstract class Regression

public abstract class Regression {

    double[] theta; // parameters
    int paraNum;    // the number of parameters
    double rate;    // learning rate
    Sample[] sam;   // samples
    int samNum;     // the number of samples
    double th;      // threshold on the cost decrease (stopping criterion)

    /**
     * Initialize the samples.
     * @param s   training set
     * @param num the number of training samples
     */
    public void Initialize(Sample[] s, int num) {
        samNum = num;
        sam = new Sample[samNum];
        for (int i = 0; i < samNum; i++) {
            sam[i] = s[i];
        }
    }

    /**
     * Initialize all parameters.
     * @param para          initial theta
     * @param learning_rate learning rate
     * @param threshold     stop once the cost decreases by no more than this
     */
    public void setPara(double[] para, double learning_rate, double threshold) {
        paraNum = para.length;
        theta = para;
        rate = learning_rate;
        th = threshold;
    }

    /**
     * Predict the value of sample s.
     * @param s the sample to predict
     * @return the predicted value
     */
    public abstract double PreVal(Sample s);

    /**
     * Calculate the cost over all samples.
     * @return the cost
     */
    public abstract double CostFun();

    /**
     * Update theta.
     */
    public abstract void Update();

    public void OutputTheta() {
        System.out.println("The parameters are:");
        for (int i = 0; i < paraNum; i++) {
            System.out.print(theta[i] + " ");
        }
        System.out.println();
        System.out.println("The cost is: " + CostFun());
    }
}

3. Linear regression: LinearRegression

public class LinearRegression extends Regression {

    @Override
    public double PreVal(Sample s) {
        double val = 0;
        for (int i = 0; i < paraNum; i++) {
            val += theta[i] * s.features[i];
        }
        return val;
    }

    @Override
    public double CostFun() {
        double sum = 0;
        for (int i = 0; i < samNum; i++) {
            double d = PreVal(sam[i]) - sam[i].value;
            sum += d * d;
        }
        return sum / (2 * samNum);
    }

    @Override
    public void Update() {
        double former;             // the cost before the update
        double latter = CostFun(); // the cost after the update
        do {
            former = latter;
            // Build the new theta in a fresh array so that every component is
            // computed from the current theta (simultaneous update).
            double[] p = new double[paraNum];
            for (int i = 0; i < paraNum; i++) {
                // gradient component for theta[i]
                double d = 0;
                for (int j = 0; j < samNum; j++) {
                    d += (PreVal(sam[j]) - sam[j].value) * sam[j].features[i];
                }
                p[i] = theta[i] - (rate * d) / samNum;
            }
            theta = p;
            latter = CostFun();

            if (former - latter < 0) {
                System.out.println("The cost increased: the learning rate α is too large!");
                break;
            }
        } while (former - latter > th);
    }
}
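For reference, the formulas that CostFun and Update in LinearRegression are meant to implement can be written out as follows. This is the standard batch gradient descent formulation, with notation chosen here for illustration: m stands for samNum, α for rate, i indexes parameters, j indexes samples, and h_θ(x) is the value returned by PreVal.

```latex
h_\theta(x) = \sum_{i} \theta_i x_i

J(\theta) = \frac{1}{2m} \sum_{j=1}^{m} \left( h_\theta(x^{(j)}) - y^{(j)} \right)^2

\theta_i := \theta_i - \frac{\alpha}{m} \sum_{j=1}^{m} \left( h_\theta(x^{(j)}) - y^{(j)} \right) x_i^{(j)}
\quad \text{(all } \theta_i \text{ updated simultaneously)}
```

The loop stops once J decreases by no more than th between two iterations, or as soon as J increases.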

4. Logistic regression: LogisticRegression

public class LogisticRegression extends Regression {

    @Override
    public double PreVal(Sample s) {
        double val = 0;
        for (int i = 0; i < paraNum; i++) {
            val += theta[i] * s.features[i];
        }
        return 1 / (1 + Math.exp(-val)); // sigmoid
    }

    @Override
    public double CostFun() {
        double sum = 0;
        for (int i = 0; i < samNum; i++) {
            double p = PreVal(sam[i]);
            sum += sam[i].label * Math.log(p) + (1 - sam[i].label) * Math.log(1 - p);
        }
        return -1 * (sum / samNum);
    }

    @Override
    public void Update() {
        double former;             // the cost before the update
        double latter = CostFun(); // the cost after the update
        do {
            former = latter;
            // Build the new theta in a fresh array so that every component is
            // computed from the current theta (simultaneous update).
            double[] p = new double[paraNum];
            for (int i = 0; i < paraNum; i++) {
                // gradient component for theta[i]; the target is the class label
                double d = 0;
                for (int j = 0; j < samNum; j++) {
                    d += (PreVal(sam[j]) - sam[j].label) * sam[j].features[i];
                }
                p[i] = theta[i] - (rate * d) / samNum;
            }
            // theta must change inside the loop, otherwise the cost never moves
            theta = p;
            latter = CostFun();

            if (former - latter < 0) {
                System.out.println("The cost increased: the learning rate α is too large!");
                break;
            }
        } while (former - latter > th);
    }
}
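The logistic case differs from the linear one only in the hypothesis and the cost. As a reference sketch in the same illustrative notation (m for samNum, α for rate, y for the 0/1 class label, i indexing parameters and j indexing samples), the intended formulas are:

```latex
h_\theta(x) = \frac{1}{1 + e^{-\theta^{T} x}}

J(\theta) = -\frac{1}{m} \sum_{j=1}^{m} \left[ y^{(j)} \log h_\theta(x^{(j)}) + \left(1 - y^{(j)}\right) \log\left(1 - h_\theta(x^{(j)})\right) \right]

\theta_i := \theta_i - \frac{\alpha}{m} \sum_{j=1}^{m} \left( h_\theta(x^{(j)}) - y^{(j)} \right) x_i^{(j)}
```

The update rule has the same form as in linear regression, which is why the two classes can share the Regression skeleton. Note that the target in the gradient is the class label y, not the regression value.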

5. Sample data for linear regression

x0 x1 x2 x3 x4 y
1 2104 5 1 45 460
1 1416 3 2 40 232
1 1534 3 2 30 315
1 852 2 1 36 178
1 1254 3 3 45 321
1 987 2 2 35 241
1 1054 3 2 30 287
1 645 2 3 25 87
1 542 2 1 30 94
1 1065 3 1 25 241
1 2465 7 2 50 687
1 2410 6 1 45 654
1 1987 4 2 45 436
1 457 2 3 35 65
1 587 2 2 25 54
1 468 2 1 40 87
1 1354 3 1 35 215
1 1587 4 1 45 345
1 1789 4 2 35 325
1 2500 8 2 40 720

6. Linear regression test

import java.io.IOException;
import java.io.RandomAccessFile;

public class Test {

    public static void main(String[] args) throws IOException {
        // Read the samples. The file is expected to hold only the data rows
        // (no header line), separated by single spaces.
        Sample[] sam = new Sample[25];
        int w = 0;
        String s;
        try (RandomAccessFile file = new RandomAccessFile("resource//LinearSample.txt", "r")) {
            while ((s = file.readLine()) != null) {
                // s --> sample: the last column is the target value, the rest are features
                String[] sub = s.split(" ");
                sam[w] = new Sample(sub.length - 1);
                for (int i = 0; i < sub.length; i++) {
                    if (i == sub.length - 1) {
                        sam[w].value = Double.parseDouble(sub[i]);
                    } else {
                        sam[w].features[i] = Double.parseDouble(sub[i]);
                    }
                }
                w++;
            }
        }

        LinearRegression lr = new LinearRegression();
        double[] para = {0, 0, 0, 0, 0};
        // Note: with the unscaled features listed above (x1 is in the thousands),
        // a rate of 0.5 makes the very first step overshoot, so Update() hits its
        // cost-increase check and stops. Scale the features or use a much smaller rate.
        double rate = 0.5;
        double th = 0.001;
        lr.Initialize(sam, w);
        lr.setPara(para, rate, th);
        lr.Update();
        lr.OutputTheta();
    }
}
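Because the feature columns in the linear sample differ by orders of magnitude (x1 is in the thousands, x3 is between 1 and 3), gradient descent on the raw values is very sensitive to the learning rate. One common remedy is to standardize each feature column before training. The following is a minimal sketch of that idea, not part of the original post: the class FeatureScaling and its method standardize are hypothetical names introduced here, and the sketch assumes the features have already been loaded into a double[][] with one row per sample.

```java
import java.util.Arrays;

// Hypothetical helper, not part of the original post.
class FeatureScaling {

    /**
     * Returns a z-score standardized copy of x (rows = samples, columns = features).
     * Columns with (near) zero spread, such as the constant x0 = 1, are copied unchanged.
     */
    static double[][] standardize(double[][] x) {
        int m = x.length;
        int n = x[0].length;
        double[][] z = new double[m][n];
        for (int j = 0; j < n; j++) {
            double mean = 0;
            for (int i = 0; i < m; i++) {
                mean += x[i][j];
            }
            mean /= m;

            double var = 0;
            for (int i = 0; i < m; i++) {
                var += (x[i][j] - mean) * (x[i][j] - mean);
            }
            double std = Math.sqrt(var / m); // population standard deviation

            for (int i = 0; i < m; i++) {
                z[i][j] = (std < 1e-12) ? x[i][j] : (x[i][j] - mean) / std;
            }
        }
        return z;
    }

    public static void main(String[] args) {
        // First three rows of the linear regression sample (columns x0..x4).
        double[][] x = {
            {1, 2104, 5, 1, 45},
            {1, 1416, 3, 2, 40},
            {1, 1534, 3, 2, 30}
        };
        for (double[] row : standardize(x)) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

The scaled rows could then be copied into Sample.features before calling Initialize. The learned theta would apply to the scaled features, so any new sample has to be scaled with the same per-column mean and standard deviation before prediction.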

7. Sample data for logistic regression

x0 x1 x2 class
1 0.23 0.35 0
1 0.32 0.24 0
1 0.6 0.12 0
1 0.36 0.54 0
1 0.02 0.89 0
1 0.36 -0.12 0
1 -0.45 0.62 0
1 0.56 0.42 0
1 0.4 0.56 0
1 0.46 0.51 0
1 1.2 0.32 1
1 0.6 0.9 1
1 0.32 0.98 1
1 0.2 1.3 1
1 0.15 1.36 1
1 0.54 0.98 1
1 1.36 1.05 1
1 0.22 1.65 1
1 1.65 1.54 1
1 0.25 1.68 1

8. Logistic regression test

import java.io.IOException;
import java.io.RandomAccessFile;

public class Test {

    public static void main(String[] args) throws IOException {
        // Read the samples. The file is expected to hold only the data rows
        // (no header line), separated by single spaces.
        Sample[] sam = new Sample[25];
        int w = 0;
        String s;
        try (RandomAccessFile file = new RandomAccessFile("resource//LogisticSample.txt", "r")) {
            while ((s = file.readLine()) != null) {
                // s --> sample: the last column is the class label, the rest are features
                String[] sub = s.split(" ");
                sam[w] = new Sample(sub.length - 1);
                for (int i = 0; i < sub.length; i++) {
                    if (i == sub.length - 1) {
                        sam[w].label = Integer.parseInt(sub[i]);
                    } else {
                        sam[w].features[i] = Double.parseDouble(sub[i]);
                    }
                }
                // sam[w].outSample();
                w++;
            }
        }

        LogisticRegression lr = new LogisticRegression();
        double[] para = {0, 0, 0};
        double rate = 0.5;
        double th = 0.001;
        lr.Initialize(sam, w);
        lr.setPara(para, rate, th);
        lr.Update();
        lr.OutputTheta();
    }
}
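The test above only prints the fitted parameters and the final cost. To turn the model into a classifier, the sigmoid output is usually compared against a cutoff such as 0.5. The sketch below illustrates that step under stated assumptions: the class LogisticPredict and its methods are hypothetical names introduced here, and the theta in main is an illustrative placeholder, not the values the training code produces.

```java
// Hypothetical helper, not part of the original post.
class LogisticPredict {

    /** Sigmoid of the linear score theta . x, i.e. the estimated probability of class 1. */
    static double probability(double[] theta, double[] x) {
        double score = 0;
        for (int i = 0; i < theta.length; i++) {
            score += theta[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-score));
    }

    /** Predicts class 1 when the probability reaches the cutoff, otherwise class 0. */
    static int classify(double[] theta, double[] x, double cutoff) {
        return probability(theta, x) >= cutoff ? 1 : 0;
    }

    /** Fraction of samples whose predicted class equals the given label. */
    static double accuracy(double[] theta, double[][] xs, int[] labels, double cutoff) {
        int correct = 0;
        for (int i = 0; i < xs.length; i++) {
            if (classify(theta, xs[i], cutoff) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / xs.length;
    }

    public static void main(String[] args) {
        // Placeholder parameters for illustration only (not fitted values).
        double[] theta = {-3.0, 1.0, 3.0};
        // Two rows from the logistic sample: one of class 0, one of class 1.
        double[][] xs = {
            {1, 0.23, 0.35},
            {1, 0.2, 1.3}
        };
        int[] labels = {0, 1};
        System.out.println("accuracy = " + accuracy(theta, xs, labels, 0.5));
    }
}
```

In practice the theta passed in would be the array held by LogisticRegression after Update() has finished, and accuracy would ideally be measured on samples that were not used for training.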
