Logistic Regression和Linear Regression实现起来几乎是一样的,有疑问的同学可以看一看Andrew N.g在coursera上的machine learning课程。

我下面会给出相应的代码(代码写的其实很糟糕),只是一个简单的实现,效率不高,收敛速度比较慢,对于下面这样一个简单的数据:

0.0 0
0.1 0
0.7 1
1.0 0
1.1 0
1.3 0
1.4 1
1.7 1
2.1 1
2.2 1

都需要很多很多次迭代,而在weka中几乎是瞬间就训练好了(http://www.cnblogs.com/wzm-xu/p/4139831.html)

所以呢。。。。。我也不知道该说什么了。

LogisticRegression.java

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.lang.Math;

public class LogisticRegression {

    private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 y
    private int row;//训练数据  行数
    private int column;//训练数据 列数

    private double [] theta;//参数theta

    private double alpha;//训练步长
    private int iteration;//迭代次数

    public LogisticRegression(String fileName)
    {
        int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数
        int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数

        trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1
        this.row=rowoffile;
        this.column=columnoffile+1;

        this.alpha = 0.001;//步长默认为0.001
        this.iteration=100000;//迭代次数默认为 100000

        theta = new double [column-1];
        initialize_theta();

        loadTrainDataFromFile(fileName,rowoffile,columnoffile);
    }
    public LogisticRegression(String fileName,double alpha,int iteration)
    {
        int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数
        int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数

        trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1
        this.row=rowoffile;
        this.column=columnoffile+1;

        this.alpha = alpha;
        this.iteration=iteration;

        theta = new double [column-1];
        initialize_theta();

        loadTrainDataFromFile(fileName,rowoffile,columnoffile);
    }

    private int getRowNumber(String fileName)
    {
        int count =0;
        File file = new File(fileName);
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(file));
            while ( reader.readLine() != null)
                count++;
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                }
            }
        }
        return count;

    }

    private int getColumnNumber(String fileName)
    {
        int count =0;
        File file = new File(fileName);
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(file));
            String tempString = reader.readLine();

            if(tempString!=null)
            {
                String [] temp = tempString.split(" ");
                for(String s : temp)
                    if(!s.equals("") && s!=null)
                        count++;
            }

            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                }
            }
        }
        return count;
    }

    private void initialize_theta()//将theta各个参数全部初始化为1.0
    {
        for(int i=0;i<theta.length;i++)
            theta[i]=1.0;
    }

    public void trainTheta()
    {
        int count = 0;
        int iteration = this.iteration;
        while( (count++) < iteration)
        {
            System.out.print(count);
            printTheta();
            //对每个theta i 求 偏导数
            double [] partial_derivative = compute_partial_derivative();//偏导数
                //更新每个theta
            for(int i =0; i< theta.length;i++)
                theta[i]-= alpha * partial_derivative[i];
        }
    }

    private double [] compute_partial_derivative()
    {
        double [] partial_derivative = new double[theta.length];
        for(int j =0;j<theta.length;j++)//遍历,对每个theta求偏导数
        {
            partial_derivative[j]= compute_partial_derivative_for_theta(j);//对 theta j 求 偏导
        }
        return partial_derivative;
    }
    private double compute_partial_derivative_for_theta(int j)
    {
        double sum=0.0;
        for(int i=0;i<row;i++)//遍历 每一行数据
        {
            sum+=h_theta_x_i_minus_y_i_times_x_j_i(i,j);
        }
        return sum/row;
    }
    private double h_theta_x_i_minus_y_i_times_x_j_i(int i,int j)
    {
        double[] oneRow = getRow(i);//取一行数据,前面是feature,最后一个是y
        double result = 0.0;

        result += h_thera_x_i(oneRow);
        result-=oneRow[oneRow.length-1];
        result*=oneRow[j];

        return result;
    }
    private double h_thera_x_i(double [] oneRow)
    {
        double theta_T_x = 0.0;
        for(int k=0;k< (oneRow.length-1);k++)
            theta_T_x += theta[k]*oneRow[k];

        return 1/(1+(Math.exp(0-theta_T_x)));
    }

    private double [] getRow(int i)//从训练数据中取出第i行,i=0,1,2,。。。,(row-1)
    {
        return trainData[i];
    }

    private void loadTrainDataFromFile(String fileName,int row, int column)
    {
        for(int i=0;i< row;i++)//trainData的第一列全部置为1.0(feature x0)
            trainData[i][0]=1.0;

        File file = new File(fileName);
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(file));
            String tempString = null;
            int counter = 0;
            while ( (counter<row) && (tempString = reader.readLine()) != null) {
                String [] tempData = tempString.split(" ");
                int numOfTrainData = 0;
                for(int i=0;i<column;i++)
                {
                    while(tempData[numOfTrainData] == null || tempData[numOfTrainData].equals(""))
                        numOfTrainData++;

                    trainData[counter][i+1]=Double.parseDouble(tempData[numOfTrainData]);
                    numOfTrainData ++;
                }

                counter++;
            }
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                }
            }
        }
    }

    public void printTrainData()
    {
        System.out.println("Train Data:\n");
        for(int i=0;i<column-1;i++)
            System.out.printf("%10s","x"+i+" ");
        System.out.printf("%10s","y"+" \n");
        for(int i=0;i<row;i++)
        {
            for(int j=0;j<column;j++)
            {
                System.out.printf("%10s",trainData[i][j]+" ");
            }
            System.out.println();
        }
        System.out.println();
    }

    public void printTheta()
    {
        System.out.print("theta: ");
        for(double a:theta)
            System.out.print(a+" ");
        System.out.println();
    }

}

TestLogisticRegression.java

public class TestLogisticRegression {
    public static void main(String[] args) {
        // TODO Auto-generated method stub
         LogisticRegression m = new LogisticRegression("trainData.txt",0.001,2000000);
         m.printTrainData();
         m.trainTheta();
         m.printTheta();
    }
}

trainData.txt:

0.0 0
0.1 0
0.7 1
1.0 0
1.1 0
1.3 0
1.4 1
1.7 1
2.1 1
2.2 1

感兴趣的同学还可以试一试下面这组数据:

0.0 2.9 0.0
1.0 1.9 0.0
2.0 0.9 0.0
3.0 -0.1 0.0
4.0 -1.1 0.0
0.0 2.5 0.0
1.0 1.5 0.0
2.0 0.5 0.0
3.0 -0.5 0.0
4.0 -1.5 0.0
0.0 2.0 0.0
1.0 1.0 0.0
2.0 0.0 0.0
3.0 -1.0 0.0
4.0 -2.0 0.0
0.0 1.0 0.0
1.0 0.0 0.0
2.0 -1.0 0.0
3.0 -2.0 0.0
4.0 -3.0 0.0
0.2 2.9 1.0
1.2 1.9 1.0
2.2 0.9 1.0
3.2 -0.1 1.0
4.2 -1.1 1.0
1.2 2.9 1.0
2.2 1.9 1.0
3.2 0.9 1.0
4.2 -0.1 1.0
5.2 -1.1 1.0
2.2 2.9 1.0
3.2 1.9 1.0
4.2 0.9 1.0
5.2 -0.1 1.0
6.2 -1.1 1.0
3.0 0.2  0.0
1.0 2.3  0.0
1.0 1.8  1.0
2.0 0.8  1.0

最后训练出来的模型是:h(x)=1/(1+exp(-(-13.9827+4.6001*x1+4.6302*x2)))

机器学习算法与实现 之 Logistic Regression---java实现的更多相关文章

  1. Python机器学习算法 — 逻辑回归(Logistic Regression)

    逻辑回归--简介 逻辑回归(Logistic Regression)就是这样的一个过程:面对一个回归或者分类问题,建立代价函数,然后通过优化方法迭代求解出最优的模型参数,然后测试验证我们这个求解的模型 ...

  2. 机器学习总结之逻辑回归Logistic Regression

    机器学习总结之逻辑回归Logistic Regression 逻辑回归logistic regression,虽然名字是回归,但是实际上它是处理分类问题的算法.简单的说回归问题和分类问题如下: 回归问 ...

  3. 【机器学习】逻辑回归(Logistic Regression)

    注:最近开始学习<人工智能>选修课,老师提纲挈领的介绍了一番,听完课只了解了个大概,剩下的细节只能自己继续摸索. 从本质上讲:机器学习就是一个模型对外界的刺激(训练样本)做出反应,趋利避害 ...

  4. 机器学习入门11 - 逻辑回归 (Logistic Regression)

    原文链接:https://developers.google.com/machine-learning/crash-course/logistic-regression/ 逻辑回归会生成一个介于 0 ...

  5. 机器学习技法:05 Kernel Logistic Regression

    Roadmap Soft-Margin SVM as Regularized Model SVM versus Logistic Regression SVM for Soft Binary Clas ...

  6. Coursera台大机器学习技法课程笔记05-Kernel Logistic Regression

    这一节主要讲的是如何将Kernel trick 用到 logistic regression上. 从另一个角度来看soft-margin SVM,将其与 logistic regression进行对比 ...

  7. 机器学习之逻辑回归(Logistic Regression)

    1. Classification 这篇文章我们来讨论分类问题(classification problems),也就是说你想预测的变量 y 是一个离散的值.我们会使用逻辑回归算法来解决分类问题. 之 ...

  8. 机器学习基石笔记:10 Logistic Regression

    线性分类中的是非题------>概率题, 设置概率阈值后,大于等于该值的为O,小于改值的为X.------>逻辑回归. O为1,X为0: 逻辑回归假设: 逻辑函数/S型函数:光滑,单调, ...

  9. 吴恩达机器学习笔记14-逻辑回归(Logistic Regression)

    在分类问题中,你要预测的变量

随机推荐

  1. 关于前台主键输入错误对后台hibernate方法的影响

    由于前台输入时开始不小心打错了主键为value=“${conf_id}”/ 导致后台得到的主键不是数字“1”而是“1/”所以到后台就算是进的updata方法结果运行的却是添加方法 原因可能是传入的对象 ...

  2. python解析XML之ElementTree

    #coding=utf-8 from xml.etree import ElementTree as ET tree=ET.parse('test.xml') root = tree.getroot( ...

  3. mysql问题总结,远程登录

    http://blog.sina.com.cn/s/blog_4550f3ca0101axzd.html 更改mysql数据库的数据库名 http://tech.sina.com.cn/s/s/200 ...

  4. 转载 Deep learning:三(Multivariance Linear Regression练习)

    前言: 本文主要是来练习多变量线性回归问题(其实本文也就3个变量),参考资料见网页:http://openclassroom.stanford.edu/MainFolder/DocumentPage. ...

  5. A*算法实现

    A* 算法非常简单.算法维护两个集合:OPEN 集和 CLOSED 集.OPEN 集包含待检测节点.初始状态,OPEN集仅包含一个元素:开始位置.CLOSED集包含已检测节点.初始状态,CLOSED集 ...

  6. action参数绑定

    thinkPHP支持操作方法的参数绑定功能 action参数通过直接绑定URL中的变量作为操作方法的参数,可以简化方法的定义甚至路由的简析. 原理是把URL的中参数(不包括模块,控制器和操作名)和控制 ...

  7. android 多个shortCut快捷方式实现以及对58同城快捷方式的实现思路的研究

    这几天,项目中有个新需求,需要按照模块添加不同的快捷方式到桌面上,从而方便用户的使用.特意进行了研究并分析了下58上面桌面快捷方式的实现. 首先多个shortcut的实现: <activity ...

  8. zabbix agent自动安装脚本

    #!/bin/bash #desc: used for autoinstall zabbix client #说明:本脚本旨在批量安装zabbix_agent,在一个服务器上放好软件和配置文件,执行本 ...

  9. DW常用

    Dreamweaver代码 基本结构标签: <HTML>,表示该文件为HTML文件 <HEAD>,包含文件的标题,使用的脚本,样式定义等 <TITLE>---< ...

  10. RocketMQ源码 — 三、 Producer消息发送过程

    Producer 消息发送 producer start producer启动过程如下图 public void start(final boolean startFactory) throws MQ ...