AdaBoost的java实现
目前学了几个ML的分类的经典算法,但是一直想着是否有一种能将这些算法集成起来的,今天看到了AdaBoost,也算是半个集成,感觉这个思路挺好,很像人的训练过程,并且对决策树是一个很好的补充,因为决策树容易过拟合,用AdaBoost可以让一棵很深的决策树将其分开成多棵矮树,后来发现原来这个想法和random forest比较相似,RF的代码等下周有空的时候可以写一下。
这个貌似挺厉害的,看那些专门搞学术的人说是一篇很牛逼的论文证明说可以把弱学习提升到强学习。我这种搞工程的,能知道他的原理,适用范围,能自己写一遍代码,感觉还是比那些读几遍论文只能惶惶其谈的要安心些。
关于AdaBoost的基本概念,通过《机器学习方法》来概要的说下。

可以看出,误差率越小的,其权重越大


import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList; class Stump{
public int dim;
public double thresh;
public String condition;
public double error;
public ArrayList<Integer> labelList;
double factor; public String toString(){
return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList;
}
} class Utils{
//加载数据集
public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{
ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>();
FileInputStream fis=new FileInputStream(filename);
InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
BufferedReader br=new BufferedReader(isr);
String line=""; while((line=br.readLine())!=null){
ArrayList<Double> data=new ArrayList<Double>();
String[] s=line.split(" "); for(int i=0;i<s.length-1;i++){
data.add(Double.parseDouble(s[i]));
}
dataSet.add(data);
}
return dataSet;
} //加载类别
public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{
ArrayList<Integer> labelSet=new ArrayList<Integer>(); FileInputStream fis=new FileInputStream(filename);
InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
BufferedReader br=new BufferedReader(isr);
String line=""; while((line=br.readLine())!=null){
String[] s=line.split(" ");
labelSet.add(Integer.parseInt(s[s.length-1]));
}
return labelSet;
}
//测试用的
public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){
for(ArrayList<Double> data:dataSet){
System.out.println(data);
}
}
//获取最大值,用于求步长
public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){
double max=-9999.0;
for(ArrayList<Double> data:dataSet){
if(data.get(index)>max){
max=data.get(index);
}
}
return max;
}
//获取最小值,用于求步长
public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){
double min=9999.0;
for(ArrayList<Double> data:dataSet){
if(data.get(index)<min){
min=data.get(index);
}
}
return min;
} //获取数据集中以该feature为特征,以thresh和conditions为value的叶子节点的决策树进行划分后得到的预测类别
public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){
ArrayList<Integer> labelList=new ArrayList<Integer>();
if(condition.compareTo("lt")==0){
for(ArrayList<Double> data:dataSet){
if(data.get(feature)<=thresh){
labelList.add(1);
}else{
labelList.add(-1);
}
}
}else{
for(ArrayList<Double> data:dataSet){
if(data.get(feature)>=thresh){
labelList.add(1);
}else{
labelList.add(-1);
}
}
}
return labelList;
}
//求预测类别与真实类别的加权误差
public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){
double error=0; int n=real.size(); for(int i=0;i<fake.size();i++){
if(fake.get(i)!=real.get(i)){
error+=weights.get(i); }
} return error;
}
//构造一棵单节点的决策树,用一个Stump类来存储这些基本信息。
public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){
int featureNum=dataSet.get(0).size(); int rowNum=dataSet.size();
Stump stump=new Stump();
double minError=999.0;
System.out.println("第"+n+"次迭代");
for(int i=0;i<featureNum;i++){
double min=getMin(dataSet,i);
double max=getMax(dataSet,i);
double step=(max-min)/(rowNum);
for(double j=min-step;j<=max+step;j=j+step){
String[] conditions={"lt","gt"};//如果是lt,表示如果小于阀值则为真类,如果是gt,表示如果大于阀值则为正类
for(String condition:conditions){
ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition); double error=Utils.getError(labelList,labelSet,weights);
if(error<minError){
minError=error;
stump.dim=i;
stump.thresh=j;
stump.condition=condition;
stump.error=minError;
stump.labelList=labelList;
stump.factor=0.5*(Math.log((1-error)/error));
} }
} } return stump;
} public static ArrayList<Double> getInitWeights(int n){
double weight=1.0/n;
ArrayList<Double> weights=new ArrayList<Double>();
for(int i=0;i<n;i++){
weights.add(weight);
}
return weights;
}
//更新样本权值
public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){
double Z=0;
ArrayList<Double> newWeights=new ArrayList<Double>();
int row=labelList.size();
double e=Math.E;
double factor=stump.factor;
for(int i=0;i<row;i++){
Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i));
} for(int i=0;i<row;i++){
double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z;
newWeights.add(weight);
}
return newWeights;
}
//对加权误差累加
public static ArrayList<Double> InitAccWeightError(int n){
ArrayList<Double> accError=new ArrayList<Double>();
for(int i=0;i<n;i++){
accError.add(0.0);
}
return accError;
} public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){
ArrayList<Integer> t=stump.labelList;
double factor=stump.factor;
ArrayList<Double> newAccError=new ArrayList<Double>();
for(int i=0;i<t.size();i++){
double a=accerror.get(i)+factor*t.get(i);
newAccError.add(a);
}
return newAccError;
} public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){
ArrayList<Integer> a=new ArrayList<Integer>();
int wrong=0;
for(int i=0;i<accError.size();i++){
if(accError.get(i)>0){
if(labelList.get(i)==-1){
wrong++;
}
}else if(labelList.get(i)==1){
wrong++;
}
}
double error=wrong*1.0/accError.size();
return error;
} public static void showStumpList(ArrayList<Stump> G){
for(Stump s:G){
System.out.println(s);
System.out.println(" ");
}
}
} public class Adaboost { /**
* @param args
* @throws IOException
*/ public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){
int row=labelList.size();
ArrayList<Double> weights=Utils.getInitWeights(row);
ArrayList<Stump> G=new ArrayList<Stump>();
ArrayList<Double> accError=Utils.InitAccWeightError(row);
int n=1;
while(true){
Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵误差率最小的单节点决策树
G.add(stump);
weights=Utils.updateWeights(stump,labelList,weights);//更新权值
accError=Utils.accWeightError(accError,stump);//将加权误差累加,因为这样不用再利用分类器再求了
double error=Utils.calErrorRate(accError,labelList);
if(error<0.001){
break;
}
n++;
}
return G;
} public static void main(String[] args) throws IOException {
// TODO Auto-generated method stub
String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt";
ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file);
ArrayList<Integer> labelSet=Utils.loadLabelSet(file);
ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet);
Utils.showStumpList(G);
System.out.println("finished");
} }
这里的数据采用的是统计学习方法中的数据
0 1
1 1
2 1
3 -1
4 -1
5 -1
6 1
7 1
8 1
9 -1
这里是单个特征的,也可以是多维数据,例如
1.0 2.1 1
2.0 1.1 1
1.3 1.0 -1
1.0 1.0 -1
2.0 1.0 1
AdaBoost的java实现的更多相关文章
- Spark案例分析
一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...
- 机器学习之——集成算法,随机森林,Bootsing,Adaboost,Staking,GBDT,XGboost
集成学习 集成算法 随机森林(前身是bagging或者随机抽样)(并行算法) 提升算法(Boosting算法) GBDT(迭代决策树) (串行算法) Adaboost (串行算法) Stacking ...
- smile——Java机器学习引擎
资源 https://haifengl.github.io/ https://github.com/haifengl/smile 介绍 Smile(统计机器智能和学习引擎)是一个基于Java和Scal ...
- 故障重现(内存篇2),JAVA内存不足导致频繁回收和swap引起的性能问题
背景起因: 记起以前的另一次也是关于内存的调优分享下 有个系统平时运行非常稳定运行(没经历过大并发考验),然而在一次活动后,人数并发一上来后,系统开始卡. 我按经验开始调优,在每个关键步骤的加入如 ...
- Elasticsearch之java的基本操作一
摘要 接触ElasticSearch已经有一段了.在这期间,遇到很多问题,但在最后自己的不断探索下解决了这些问题.看到网上或多或少的都有一些介绍ElasticSearch相关知识的文档,但个人觉得 ...
- 论:开发者信仰之“天下IT是一家“(Java .NET篇)
比尔盖茨公认的IT界领军人物,打造了辉煌一时的PC时代. 2008年,史蒂夫鲍尔默接替了盖茨的工作,成为微软公司的总裁. 2013年他与微软做了最后的道别. 2013年以后,我才真正看到了微软的变化. ...
- 故障重现, JAVA进程内存不够时突然挂掉模拟
背景,服务器上的一个JAVA服务进程突然挂掉,查看产生了崩溃日志,如下: # Set larger code cache with -XX:ReservedCodeCacheSize= # This ...
- 死磕内存篇 --- JAVA进程和linux内存间的大小关系
运行个JAVA 用sleep去hold住 package org.hjb.test; public class TestOnly { public static void main(String[] ...
- 【小程序分享篇 一 】开发了个JAVA小程序, 用于清除内存卡或者U盘里的垃圾文件非常有用
有一种场景, 手机内存卡空间被用光了,但又不知道哪个文件占用了太大,一个个文件夹去找又太麻烦,所以我开发了个小程序把手机所有文件(包括路径下所有层次子文件夹下的文件)进行一个排序,这样你就可以找出哪个 ...
随机推荐
- 两个数组各个数相加或相乘变成一个矩阵求第K大
input 1<=T<=20 1<=n<=100000,1<=k<=n*n a1 a2 ... an 0<ai<=10000 b1 b2 ... bn ...
- more分页阅读
相比cat命令,more可以更加灵活的去阅读查看文件. 1.命令格式 more [-dlfpcsu ] [-num ] [+/ pattern] [+ linenum] [file ... ] 2.命 ...
- Notes over compiling..
When compiling VIM on windows, using nmake may be a better choice.. Because so far my attempts to co ...
- stdafx文件介绍
MSDN介绍: These files are used to build a precompiled header file Projname.pch and a precompiled types ...
- Loadrunner之脚本的调试和保存(六)
一.调试脚本 脚本录制完毕后,按F5键或单击菜单上的RUN按钮,可以运行脚本. 在VIRTUAL USER GENERATOR中运行脚本的作用,主要是查看录制的脚本能否正常通过,如果有问题 ...
- 常用的JS页面跳转代码调用大全
一.常规的JS页面跳转代码 1.在原来的窗体中直接跳转用 <script type="text/javascript"> window.location.href=&q ...
- 笔记整理--Linux编程
linux c编程open() read() write()函数的使用方法及实例 | 奶牛博客 - Google Chrome (2013/8/31 17:56:10) 今天把文件IO操作的一些东东整 ...
- 在线的代码托管平台 coding.net ===中国扩展版github
coding.net 是国内新兴的一个项目管理平台,功能主要包括:代码托管.在线运行环境.监控代码质量,兼有一定的社交功能. 在线运行环境支持Java.Ruby.Node.js.PHP.Python. ...
- 表单提交中记得form表单放到table外面
帝国后台按栏目搜索文章时怎么都不生效 控制台查看原来是 栏目的select的值没有提交过去,原来由于form标签在table标签里面,导致js生成的<select>标签提交失败. 解决办 ...
- Linux查看文件夹大小du
du命令参数详解见: http://baike.baidu.com/view/43913.htm 下面我们只对其做简单介绍: 查看linux文件目录的大小和文件夹包含的文件数 统计总数大小 d ...