机器学习算法与实现 之 Logistic Regression---java实现
Logistic Regression和Linear Regression实现起来几乎是一样的,有疑问的同学可以看一看Andrew N.g在coursera上的machine learning课程。
我下面会给出相应的代码(代码写的其实很糟糕),只是一个简单的实现,效率不高,收敛速度比较慢,对于下面这样一个简单的数据:
0.0 0
0.1 0
0.7 1
1.0 0
1.1 0
1.3 0
1.4 1
1.7 1
2.1 1
2.2 1
都需要很多很多次迭代,而在weka中几乎是瞬间就训练好了(http://www.cnblogs.com/wzm-xu/p/4139831.html)
所以呢。。。。。我也不知道该说什么了。
LogisticRegression.java
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.lang.Math;
public class LogisticRegression {
private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 y
private int row;//训练数据 行数
private int column;//训练数据 列数
private double [] theta;//参数theta
private double alpha;//训练步长
private int iteration;//迭代次数
public LogisticRegression(String fileName)
{
int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的 行数
int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的 列数
trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1
this.row=rowoffile;
this.column=columnoffile+1;
this.alpha = 0.001;//步长默认为0.001
this.iteration=100000;//迭代次数默认为 100000
theta = new double [column-1];
initialize_theta();
loadTrainDataFromFile(fileName,rowoffile,columnoffile);
}
public LogisticRegression(String fileName,double alpha,int iteration)
{
int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的 行数
int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的 列数
trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1
this.row=rowoffile;
this.column=columnoffile+1;
this.alpha = alpha;
this.iteration=iteration;
theta = new double [column-1];
initialize_theta();
loadTrainDataFromFile(fileName,rowoffile,columnoffile);
}
private int getRowNumber(String fileName)
{
int count =0;
File file = new File(fileName);
BufferedReader reader = null;
try {
reader = new BufferedReader(new FileReader(file));
while ( reader.readLine() != null)
count++;
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e1) {
}
}
}
return count;
}
private int getColumnNumber(String fileName)
{
int count =0;
File file = new File(fileName);
BufferedReader reader = null;
try {
reader = new BufferedReader(new FileReader(file));
String tempString = reader.readLine();
if(tempString!=null)
{
String [] temp = tempString.split(" ");
for(String s : temp)
if(!s.equals("") && s!=null)
count++;
}
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e1) {
}
}
}
return count;
}
private void initialize_theta()//将theta各个参数全部初始化为1.0
{
for(int i=0;i<theta.length;i++)
theta[i]=1.0;
}
public void trainTheta()
{
int count = 0;
int iteration = this.iteration;
while( (count++) < iteration)
{
System.out.print(count);
printTheta();
//对每个theta i 求 偏导数
double [] partial_derivative = compute_partial_derivative();//偏导数
//更新每个theta
for(int i =0; i< theta.length;i++)
theta[i]-= alpha * partial_derivative[i];
}
}
private double [] compute_partial_derivative()
{
double [] partial_derivative = new double[theta.length];
for(int j =0;j<theta.length;j++)//遍历,对每个theta求偏导数
{
partial_derivative[j]= compute_partial_derivative_for_theta(j);//对 theta j 求 偏导
}
return partial_derivative;
}
private double compute_partial_derivative_for_theta(int j)
{
double sum=0.0;
for(int i=0;i<row;i++)//遍历 每一行数据
{
sum+=h_theta_x_i_minus_y_i_times_x_j_i(i,j);
}
return sum/row;
}
private double h_theta_x_i_minus_y_i_times_x_j_i(int i,int j)
{
double[] oneRow = getRow(i);//取一行数据,前面是feature,最后一个是y
double result = 0.0;
result += h_thera_x_i(oneRow);
result-=oneRow[oneRow.length-1];
result*=oneRow[j];
return result;
}
private double h_thera_x_i(double [] oneRow)
{
double theta_T_x = 0.0;
for(int k=0;k< (oneRow.length-1);k++)
theta_T_x += theta[k]*oneRow[k];
return 1/(1+(Math.exp(0-theta_T_x)));
}
private double [] getRow(int i)//从训练数据中取出第i行,i=0,1,2,。。。,(row-1)
{
return trainData[i];
}
private void loadTrainDataFromFile(String fileName,int row, int column)
{
for(int i=0;i< row;i++)//trainData的第一列全部置为1.0(feature x0)
trainData[i][0]=1.0;
File file = new File(fileName);
BufferedReader reader = null;
try {
reader = new BufferedReader(new FileReader(file));
String tempString = null;
int counter = 0;
while ( (counter<row) && (tempString = reader.readLine()) != null) {
String [] tempData = tempString.split(" ");
int numOfTrainData = 0;
for(int i=0;i<column;i++)
{
while(tempData[numOfTrainData] == null || tempData[numOfTrainData].equals(""))
numOfTrainData++;
trainData[counter][i+1]=Double.parseDouble(tempData[numOfTrainData]);
numOfTrainData ++;
}
counter++;
}
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e1) {
}
}
}
}
public void printTrainData()
{
System.out.println("Train Data:\n");
for(int i=0;i<column-1;i++)
System.out.printf("%10s","x"+i+" ");
System.out.printf("%10s","y"+" \n");
for(int i=0;i<row;i++)
{
for(int j=0;j<column;j++)
{
System.out.printf("%10s",trainData[i][j]+" ");
}
System.out.println();
}
System.out.println();
}
public void printTheta()
{
System.out.print("theta: ");
for(double a:theta)
System.out.print(a+" ");
System.out.println();
}
}
TestLogisticRegression.java
public class TestLogisticRegression {
public static void main(String[] args) {
// TODO Auto-generated method stub
LogisticRegression m = new LogisticRegression("trainData.txt",0.001,2000000);
m.printTrainData();
m.trainTheta();
m.printTheta();
}
}
trainData.txt:
0.0 0 0.1 0 0.7 1 1.0 0 1.1 0 1.3 0 1.4 1 1.7 1 2.1 1 2.2 1
感兴趣的同学还可以试一试下面这组数据:
0.0 2.9 0.0 1.0 1.9 0.0 2.0 0.9 0.0 3.0 -0.1 0.0 4.0 -1.1 0.0 0.0 2.5 0.0 1.0 1.5 0.0 2.0 0.5 0.0 3.0 -0.5 0.0 4.0 -1.5 0.0 0.0 2.0 0.0 1.0 1.0 0.0 2.0 0.0 0.0 3.0 -1.0 0.0 4.0 -2.0 0.0 0.0 1.0 0.0 1.0 0.0 0.0 2.0 -1.0 0.0 3.0 -2.0 0.0 4.0 -3.0 0.0 0.2 2.9 1.0 1.2 1.9 1.0 2.2 0.9 1.0 3.2 -0.1 1.0 4.2 -1.1 1.0 1.2 2.9 1.0 2.2 1.9 1.0 3.2 0.9 1.0 4.2 -0.1 1.0 5.2 -1.1 1.0 2.2 2.9 1.0 3.2 1.9 1.0 4.2 0.9 1.0 5.2 -0.1 1.0 6.2 -1.1 1.0 3.0 0.2 0.0 1.0 2.3 0.0 1.0 1.8 1.0 2.0 0.8 1.0
最后训练出来的模型是:h(x)=1/(1+exp(-(-13.9827+4.6001*x1+4.6302*x2)))
机器学习算法与实现 之 Logistic Regression---java实现的更多相关文章
- Python机器学习算法 — 逻辑回归(Logistic Regression)
逻辑回归--简介 逻辑回归(Logistic Regression)就是这样的一个过程:面对一个回归或者分类问题,建立代价函数,然后通过优化方法迭代求解出最优的模型参数,然后测试验证我们这个求解的模型 ...
- 机器学习总结之逻辑回归Logistic Regression
机器学习总结之逻辑回归Logistic Regression 逻辑回归logistic regression,虽然名字是回归,但是实际上它是处理分类问题的算法.简单的说回归问题和分类问题如下: 回归问 ...
- 【机器学习】逻辑回归(Logistic Regression)
注:最近开始学习<人工智能>选修课,老师提纲挈领的介绍了一番,听完课只了解了个大概,剩下的细节只能自己继续摸索. 从本质上讲:机器学习就是一个模型对外界的刺激(训练样本)做出反应,趋利避害 ...
- 机器学习入门11 - 逻辑回归 (Logistic Regression)
原文链接:https://developers.google.com/machine-learning/crash-course/logistic-regression/ 逻辑回归会生成一个介于 0 ...
- 机器学习技法:05 Kernel Logistic Regression
Roadmap Soft-Margin SVM as Regularized Model SVM versus Logistic Regression SVM for Soft Binary Clas ...
- Coursera台大机器学习技法课程笔记05-Kernel Logistic Regression
这一节主要讲的是如何将Kernel trick 用到 logistic regression上. 从另一个角度来看soft-margin SVM,将其与 logistic regression进行对比 ...
- 机器学习之逻辑回归(Logistic Regression)
1. Classification 这篇文章我们来讨论分类问题(classification problems),也就是说你想预测的变量 y 是一个离散的值.我们会使用逻辑回归算法来解决分类问题. 之 ...
- 机器学习基石笔记:10 Logistic Regression
线性分类中的是非题------>概率题, 设置概率阈值后,大于等于该值的为O,小于改值的为X.------>逻辑回归. O为1,X为0: 逻辑回归假设: 逻辑函数/S型函数:光滑,单调, ...
- 吴恩达机器学习笔记14-逻辑回归(Logistic Regression)
在分类问题中,你要预测的变量
随机推荐
- 事务(JDBC、Hibernate、Spring)
如果不用spring管理事务,我们自己写代码来操作事务.那么这个代码怎么写要看底层怎么访问数据库了. 当采用原生JDBC访问数据库时,操作事务需要使用java.sql.Connection的API.开 ...
- Android中获取网络数据时的分页加载
//此实在Fragment中实现的,黄色部分为自动加载,红色部分是需要注意的和手动加载, 蓝色部分是睡眠时间,自我感觉不用写 ,还有就是手动加载时,不知道为什么进去后显示的就是最后一行,求大神 ...
- lucene 中关于Store.YES 关于Store.NO的解释
总算搞明白 lucene 中关于Store.YES 关于Store.NO的解释了 一直对Lucene Store.YES不太理解,网上多数的说法是存储字段,NO为不存储. 这样的解释有点郁闷:字面意 ...
- 使用Java的BlockingQueue实现生产者-消费者
http://tonl.iteye.com/blog/1936391 使用Java的BlockingQueue实现生产者-消费者 博客分类: Java JavaBlockingQueue阻塞队列 B ...
- .Net 中的反射(查看基本类型信息)
反射概述 和Type类 1.反射的作用 简单来说,反射提供这样几个能力:1.查看和遍历类型(及其成员)的基本信息和程序集元数据(metadata):2.迟绑定(Late-Binding)方法和属性.3 ...
- 深入理解React、Redux
深入理解React.ReduReact+Redux非常精炼,良好运用将发挥出极强劲的生产力.但最大的挑战来自于函数式编程(FP)范式.在工程化过程中,架构(顶层)设计将是一个巨大的挑战.要不然做出来的 ...
- Linux查看文件夹大小du
du命令参数详解见: http://baike.baidu.com/view/43913.htm 下面我们只对其做简单介绍: 查看linux文件目录的大小和文件夹包含的文件数 统计总数大小 d ...
- Spring + iBATIS完整示例
最近研究了一下Spring + iBATIS.发现看别人的例子是一回事,自己写一个完整的应用又是另外一回事.自己受够了网上贴的一知半解的代码. iBATIS是一个持久化框架,封面了sql过程,虽然sq ...
- 转:web_submit_data和web_submit_form的差别
在LoadRunner中有两个常用函数:Web_submit_form和Web_submit_data,在群里有人问这两个函数有什么区别.为什么会有两个不同却功能相似的函数.区别在哪里. 首先,从工具 ...
- Ubuntu 12.04 中文输入法
Ubuntu 12.04 中文输入法 [日期:2012-07-28] 来源:Linux社区 作者:lqhbupt [字体:大 中 小] Ubuntu上的输入法主要有小小输入平台(支持拼音/二笔/ ...