线性回归和逻辑回归的实现大体一致,将其抽象出一个抽象类Regression,包含整体流程,其中有三个抽象函数,将在线性回归和逻辑回归中重写。

  将样本设为Sample类,其中采用数组作为特征的存储形式。

1. 样本类Sample

 public class Sample {

     double[] features;
int feaNum; // the number of sample's features
double value; // value of sample in regression
int label; // class of sample public Sample(int number) {
feaNum = number;
features = new double[feaNum];
} public void outSample() {
System.out.println("The sample's features are:");
for(int i = 0; i < feaNum; i++) {
System.out.print(features[i] + " ");
}
System.out.println();
System.out.println("The label is: " + label);
System.out.println("The value is: " + value);
}
}

2. 抽象类Regression

public abstract class Regression {

    double[] theta; //parameters
int paraNum; //the number of parameters
double rate; //learning rate
Sample[] sam; // samples
int samNum; // the number of samples
double th; // threshold value /**
* initialize the samples
* @param s : training set
* @param num : the number of training samples
*/
public void Initialize(Sample[] s, int num) {
samNum = num;
sam = new Sample[samNum];
for(int i = 0; i < samNum; i++) {
sam[i] = s[i];
}
} /**
* initialize all parameters
* @param para : theta
* @param learning_rate
* @param threshold
*/
public void setPara(double[] para, double learning_rate, double threshold) {
paraNum = para.length;
theta = para;
rate = learning_rate;
th = threshold;
} /**
* predicte the value of sample s
* @param s : prediction sample
* @return : predicted value
*/
public abstract double PreVal(Sample s); /**
* calculate the cost of all samples
* @return : the cost
*/
public abstract double CostFun(); /**
* update the theta
*/
public abstract void Update(); public void OutputTheta() {
System.out.println("The parameters are:");
for(int i = 0; i < paraNum; i++) {
System.out.print(theta[i] + " ");
}
System.out.println(CostFun());
}
}

3. 线性回归LinearRegression

public class LinearRegression extends Regression{

    public double PreVal(Sample s) {
double val = 0;
for(int i = 0; i < paraNum; i++) {
val += theta[i] * s.features[i];
}
return val;
} public double CostFun() {
double sum = 0;
for(int i = 0; i < samNum; i++) {
double d = PreVal(sam[i]) - sam[i].value;
sum += Math.pow(d, 2);
}
return sum / (2*samNum);
} public void Update() {
double former = 0; // the cost before update
double latter = CostFun(); // the cost after updatedouble[] p = new double[paraNum];
do {
former = latter;
//update theta
for(int i = 0; i < paraNum; i++) {
// for theta[i]
double d = 0;
for(int j = 0; j < samNum; j++) {
d += (PreVal(sam[j]) - sam[j].value) * sam[j].features[i];
}
p[i] -= (rate * d) / samNum;
}
theta = p;
latter = CostFun();

         if(former - latter < 0){
          System.out.println("α is larger!!!");
          break;
        }

      }while(former - latter > th);

    }

}

4. 逻辑回归LogisticRegression

public class LogisticRegression extends Regression{

    public double PreVal(Sample s) {
double val = 0;
for(int i = 0; i < paraNum; i++) {
val += theta[i] * s.features[i];
}
return 1/(1 + Math.pow(Math.E, -val));
} public double CostFun() {
double sum = 0;
for(int i = 0; i < samNum; i++) {
double p = PreVal(sam[i]);
double d = Math.log(p) * sam[i].label + (1 - sam[i].label) * Math.log(1 - p);
sum += d;
}
return -1 * (sum / samNum);
} public void Update() {
double former = 0; // the cost before update
double latter = CostFun(); // the cost after update
double d = 0;
double[] p = new double[paraNum];
do {
former = latter;
//update theta
for(int i = 0; i < paraNum; i++) {
// for theta[i]
double d = 0;
for(int j = 0; j < samNum; j++) {
d += (PreVal(sam[j]) - sam[j].value) * sam[j].features[i];
}
p[i] -= (rate * d) / samNum;
}
latter = CostFun();

         if(former - latter < 0){
          System.out.println("α is larger!!!");
          break;
         }

      }while(former - latter > th);

         theta = p;
}
}

5. 使用的线性回归样本

x0 x1 x2 x3 x4 y
1 2104 5 1 45 460
1 1416 3 2 40 232
1 1534 3 2 30 315
1 852 2 1 36 178
1 1254 3 3 45 321
1 987 2 2 35 241
1 1054 3 2 30 287
1 645 2 3 25 87
1 542 2 1 30 94
1 1065 3 1 25 241
1 2465 7 2 50 687
1 2410 6 1 45 654
1 1987 4 2 45 436
1 457 2 3 35 65
1 587 2 2 25 54
1 468 2 1 40 87
1 1354 3 1 35 215
1 1587 4 1 45 345
1 1789 4 2 35 325
1 2500 8 2 40 720

6. 线性回归测试

import java.io.IOException;
import java.io.RandomAccessFile; public class Test { public static void main(String[] args) throws IOException {
//read Sample.txt
Sample[] sam = new Sample[25];
int w = 0; long filePoint = 0;
String s;
RandomAccessFile file = new RandomAccessFile("resource//LinearSample.txt", "r");
long fileLength = file.length(); while(filePoint < fileLength) {
s = file.readLine();
//s --> sample
String[] sub = s.split(" ");
sam[w] = new Sample(sub.length - 1);
for(int i = 0; i < sub.length; i++) {
if(i == sub.length - 1) {
sam[w].value = Double.parseDouble(sub[i]);
}
else {
sam[w].features[i] = Double.parseDouble(sub[i]);
}
}//for
w++;
filePoint = file.getFilePointer();
}//while read file LinearRegression lr = new LinearRegression();
double[] para = {0,0,0,0,0};
double rate = 0.5;
double th = 0.001;
lr.Initialize(sam, w);
lr.setPara(para, rate, th);
lr.Update();
lr.OutputTheta();
} }

7. 使用的逻辑回归样本

x0 x1 x2 class
1 0.23 0.35 0
1 0.32 0.24 0
1 0.6 0.12 0
1 0.36 0.54 0
1 0.02 0.89 0
1 0.36 -0.12 0
1 -0.45 0.62 0
1 0.56 0.42 0
1 0.4 0.56 0
1 0.46 0.51 0
1 1.2 0.32 1
1 0.6 0.9 1
1 0.32 0.98 1
1 0.2 1.3 1
1 0.15 1.36 1
1 0.54 0.98 1
1 1.36 1.05 1
1 0.22 1.65 1
1 1.65 1.54 1
1 0.25 1.68 1

8. 逻辑回归测试

import java.io.IOException;
import java.io.RandomAccessFile; public class Test { public static void main(String[] args) throws IOException {
//read Sample.txt
Sample[] sam = new Sample[25];
int w = 0; long filePoint = 0;
String s;
RandomAccessFile file = new RandomAccessFile("resource//LogisticSample.txt", "r");
long fileLength = file.length(); while(filePoint < fileLength) {
s = file.readLine();
//s --> sample
String[] sub = s.split(" ");
sam[w] = new Sample(sub.length - 1);
for(int i = 0; i < sub.length; i++) {
if(i == sub.length - 1) {
sam[w].label = Integer.parseInt(sub[i]);
}
else {
sam[w].features[i] = Double.parseDouble(sub[i]);
}
}//for
//sam[w].outSample();
w++;
filePoint = file.getFilePointer();
}//while read file LogisticRegression lr = new LogisticRegression();
double[] para = {0,0,0};
double rate = 0.5;
double th = 0.001;
lr.Initialize(sam, w);
lr.setPara(para, rate, th);
lr.Update();
lr.OutputTheta();
} }

线性、逻辑回归的java实现的更多相关文章

  1. 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型

    MNIST 被喻为深度学习中的Hello World示例,由Yann LeCun等大神组织收集的一个手写数字的数据集,有60000个训练集和10000个验证集,是个非常适合初学者入门的训练集.这个网站 ...

  2. 逻辑回归代码demo

    程序所用文件:https://files.cnblogs.com/files/henuliulei/%E5%9B%9E%E5%BD%92%E5%88%86%E7%B1%BB%E6%95%B0%E6%8 ...

  3. PRML读书会第四章 Linear Models for Classification(贝叶斯marginalization、Fisher线性判别、感知机、概率生成和判别模型、逻辑回归)

    主讲人 planktonli planktonli(1027753147) 19:52:28 现在我们就开始讲第四章,第四章的内容是关于 线性分类模型,主要内容有四点:1) Fisher准则的分类,以 ...

  4. 逻辑回归&线性支持向量机

    代码: # -*- coding: utf-8 -*- """ Created on Tue Jul 17 10:13:20 2018 @author: zhen &qu ...

  5. 关于逻辑回归是否线性?sigmoid

    from :https://www.zhihu.com/question/29385169/answer/44177582 逻辑回归的模型引入了sigmoid函数映射,是非线性模型,但本质上又是一个线 ...

  6. 逻辑回归的相关问题及java实现

    本讲主要说下逻辑回归的相关问题和详细的实现方法 1. 什么是逻辑回归 逻辑回归是线性回归的一种,那么什么是回归,什么是线性回归 回归指的是公式已知,对公式中的未知參数进行预计,注意公式必须是已知的,否 ...

  7. 机器学习---三种线性算法的比较(线性回归,感知机,逻辑回归)(Machine Learning Linear Regression Perceptron Logistic Regression Comparison)

    最小二乘线性回归,感知机,逻辑回归的比较:   最小二乘线性回归 Least Squares Linear Regression 感知机 Perceptron 二分类逻辑回归 Binary Logis ...

  8. 通俗地说逻辑回归【Logistic regression】算法(一)

    在说逻辑回归前,还是得提一提他的兄弟,线性回归.在某些地方,逻辑回归算法和线性回归算法是类似的.但它和线性回归最大的不同在于,逻辑回归是作用是分类的. 还记得之前说的吗,线性回归其实就是求出一条拟合空 ...

  9. 逻辑回归 Logistic Regression

    逻辑回归(Logistic Regression)是广义线性回归的一种.逻辑回归是用来做分类任务的常用算法.分类任务的目标是找一个函数,把观测值匹配到相关的类和标签上.比如一个人有没有病,又因为噪声的 ...

随机推荐

  1. (转) The Incredible PyTorch

    转自:https://github.com/ritchieng/the-incredible-pytorch The Incredible PyTorch What is this? This is ...

  2. Bytomd 助记词恢复密钥体验指南

    比原项目仓库: Github地址:https://github.com/Bytom/bytom Gitee地址:https://gitee.com/BytomBlockchain/bytom 背景知识 ...

  3. 【C#】可空类型 NullAble<T>

    在实际编写代码时候 ,  会遇到很多场景, 需要将值置成空, 比如发货日期, 有可能是没有. 在没有可空类型之前, 程序都是用 魔值, 即为一个minValue或者常量, 来代表这个值为空, 也有用一 ...

  4. 在mybatis中resultMap与resultType的区别

    MyBatis中在查询进行select映射的时候,返回类型可以用resultType,也可以用resultMapresultType是直接表示返回类型的,而resultMap则是对外部ResultMa ...

  5. 17秋 SDN课程 第三次上机作业

    SDN 第三次上机作业 1.创建拓扑 2.利用OVS命令下发流表,实现vlan功能 3.利用OVS命令查看流表 s1: s2: 4.验证性测试 5.Wireshark 抓包验证

  6. SAP S/4 HANA

    通常说到SAP,指的都是SAP Business Suite/R3(ECC) 这款产品. 那么SAP S/4 HANA与SAP R3究竟有什么不同呢? 简单地说,S/4 HANA是下一代的R/3和SA ...

  7. java笔试总结

    1. Java的IO操作中有面向字节(Byte)和面向字符(Character)两种方式.面向字节的操作为以8位为单位对二进制的数据进行操作,对数据不进行转换,这些类都是InputStream和Out ...

  8. Mysql 函数使用记录(二)——ELT()、FIELD()、IFNULL()

    昨天在对一业务修改的过程中想到用DECODE()来实现效果,转眼发现目前使用的是Mysql库,经过查阅,最终用ELT().FIELD().IFNULL()函数来实现需求.现对其做一个记录. 语法: E ...

  9. 蚂蚁金服“定损宝”现身AI顶级会议NeurIPS

    小蚂蚁说: 长期以来,车险定损(通过现场拍摄定损照片确定车辆损失,以作为保险公司理赔的依据)是车险理赔中最为重要的操作环节.以往传统保险公司的车险处理流程,一般为报案.现场查勘.提交理赔材料.审核.最 ...

  10. Educational Codeforces Round 25 E. Minimal Labels 拓扑排序+逆向建图

    E. Minimal Labels time limit per test 1 second memory limit per test 256 megabytes input standard in ...