AdaBoost的java实现
目前学了几个ML的分类的经典算法,但是一直想着是否有一种能将这些算法集成起来的,今天看到了AdaBoost,也算是半个集成,感觉这个思路挺好,很像人的训练过程,并且对决策树是一个很好的补充,因为决策树容易过拟合,用AdaBoost可以让一棵很深的决策树将其分开成多棵矮树,后来发现原来这个想法和random forest比较相似,RF的代码等下周有空的时候可以写一下。
这个貌似挺厉害的,看那些专门搞学术的人说是一篇很牛逼的论文证明说可以把弱学习提升到强学习。我这种搞工程的,能知道他的原理,适用范围,能自己写一遍代码,感觉还是比那些读几遍论文只能惶惶其谈的要安心些。
关于AdaBoost的基本概念,通过《机器学习方法》来概要的说下。





import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList; class Stump{
public int dim;
public double thresh;
public String condition;
public double error;
public ArrayList<Integer> labelList;
double factor; public String toString(){
return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList;
}
} class Utils{
//加载数据集
public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{
ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>();
FileInputStream fis=new FileInputStream(filename);
InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
BufferedReader br=new BufferedReader(isr);
String line=""; while((line=br.readLine())!=null){
ArrayList<Double> data=new ArrayList<Double>();
String[] s=line.split(" "); for(int i=0;i<s.length-1;i++){
data.add(Double.parseDouble(s[i]));
}
dataSet.add(data);
}
return dataSet;
} //加载类别
public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{
ArrayList<Integer> labelSet=new ArrayList<Integer>(); FileInputStream fis=new FileInputStream(filename);
InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
BufferedReader br=new BufferedReader(isr);
String line=""; while((line=br.readLine())!=null){
String[] s=line.split(" ");
labelSet.add(Integer.parseInt(s[s.length-1]));
}
return labelSet;
}
//测试用的
public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){
for(ArrayList<Double> data:dataSet){
System.out.println(data);
}
}
//获取最大值,用于求步长
public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){
double max=-9999.0;
for(ArrayList<Double> data:dataSet){
if(data.get(index)>max){
max=data.get(index);
}
}
return max;
}
//获取最小值,用于求步长
public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){
double min=9999.0;
for(ArrayList<Double> data:dataSet){
if(data.get(index)<min){
min=data.get(index);
}
}
return min;
} //获取数据集中以该feature为特征,以thresh和conditions为value的叶子节点的决策树进行划分后得到的预测类别
public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){
ArrayList<Integer> labelList=new ArrayList<Integer>();
if(condition.compareTo("lt")==0){
for(ArrayList<Double> data:dataSet){
if(data.get(feature)<=thresh){
labelList.add(1);
}else{
labelList.add(-1);
}
}
}else{
for(ArrayList<Double> data:dataSet){
if(data.get(feature)>=thresh){
labelList.add(1);
}else{
labelList.add(-1);
}
}
}
return labelList;
}
//求预测类别与真实类别的加权误差
public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){
double error=0; int n=real.size(); for(int i=0;i<fake.size();i++){
if(fake.get(i)!=real.get(i)){
error+=weights.get(i); }
} return error;
}
//构造一棵单节点的决策树,用一个Stump类来存储这些基本信息。
public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){
int featureNum=dataSet.get(0).size(); int rowNum=dataSet.size();
Stump stump=new Stump();
double minError=999.0;
System.out.println("第"+n+"次迭代");
for(int i=0;i<featureNum;i++){
double min=getMin(dataSet,i);
double max=getMax(dataSet,i);
double step=(max-min)/(rowNum);
for(double j=min-step;j<=max+step;j=j+step){
String[] conditions={"lt","gt"};//如果是lt,表示如果小于阀值则为真类,如果是gt,表示如果大于阀值则为正类
for(String condition:conditions){
ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition); double error=Utils.getError(labelList,labelSet,weights);
if(error<minError){
minError=error;
stump.dim=i;
stump.thresh=j;
stump.condition=condition;
stump.error=minError;
stump.labelList=labelList;
stump.factor=0.5*(Math.log((1-error)/error));
} }
} } return stump;
} public static ArrayList<Double> getInitWeights(int n){
double weight=1.0/n;
ArrayList<Double> weights=new ArrayList<Double>();
for(int i=0;i<n;i++){
weights.add(weight);
}
return weights;
}
//更新样本权值
public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){
double Z=0;
ArrayList<Double> newWeights=new ArrayList<Double>();
int row=labelList.size();
double e=Math.E;
double factor=stump.factor;
for(int i=0;i<row;i++){
Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i));
} for(int i=0;i<row;i++){
double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z;
newWeights.add(weight);
}
return newWeights;
}
//对加权误差累加
public static ArrayList<Double> InitAccWeightError(int n){
ArrayList<Double> accError=new ArrayList<Double>();
for(int i=0;i<n;i++){
accError.add(0.0);
}
return accError;
} public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){
ArrayList<Integer> t=stump.labelList;
double factor=stump.factor;
ArrayList<Double> newAccError=new ArrayList<Double>();
for(int i=0;i<t.size();i++){
double a=accerror.get(i)+factor*t.get(i);
newAccError.add(a);
}
return newAccError;
} public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){
ArrayList<Integer> a=new ArrayList<Integer>();
int wrong=0;
for(int i=0;i<accError.size();i++){
if(accError.get(i)>0){
if(labelList.get(i)==-1){
wrong++;
}
}else if(labelList.get(i)==1){
wrong++;
}
}
double error=wrong*1.0/accError.size();
return error;
} public static void showStumpList(ArrayList<Stump> G){
for(Stump s:G){
System.out.println(s);
System.out.println(" ");
}
}
} public class Adaboost { /**
* @param args
* @throws IOException
*/ public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){
int row=labelList.size();
ArrayList<Double> weights=Utils.getInitWeights(row);
ArrayList<Stump> G=new ArrayList<Stump>();
ArrayList<Double> accError=Utils.InitAccWeightError(row);
int n=1;
while(true){
Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵误差率最小的单节点决策树
G.add(stump);
weights=Utils.updateWeights(stump,labelList,weights);//更新权值
accError=Utils.accWeightError(accError,stump);//将加权误差累加,因为这样不用再利用分类器再求了
double error=Utils.calErrorRate(accError,labelList);
if(error<0.001){
break;
}
n++;
}
return G;
} public static void main(String[] args) throws IOException {
// TODO Auto-generated method stub
String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt";
ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file);
ArrayList<Integer> labelSet=Utils.loadLabelSet(file);
ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet);
Utils.showStumpList(G);
System.out.println("finished");
} }
这里的数据采用的是统计学习方法中的数据
0 1
1 1
2 1
3 -1
4 -1
5 -1
6 1
7 1
8 1
9 -1
这里是单个特征的,也可以是多维数据,例如
1.0 2.1 1
2.0 1.1 1
1.3 1.0 -1
1.0 1.0 -1
2.0 1.0 1
AdaBoost的java实现的更多相关文章
- Spark案例分析
一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...
- 机器学习之——集成算法,随机森林,Bootsing,Adaboost,Staking,GBDT,XGboost
集成学习 集成算法 随机森林(前身是bagging或者随机抽样)(并行算法) 提升算法(Boosting算法) GBDT(迭代决策树) (串行算法) Adaboost (串行算法) Stacking ...
- smile——Java机器学习引擎
资源 https://haifengl.github.io/ https://github.com/haifengl/smile 介绍 Smile(统计机器智能和学习引擎)是一个基于Java和Scal ...
- 故障重现(内存篇2),JAVA内存不足导致频繁回收和swap引起的性能问题
背景起因: 记起以前的另一次也是关于内存的调优分享下 有个系统平时运行非常稳定运行(没经历过大并发考验),然而在一次活动后,人数并发一上来后,系统开始卡. 我按经验开始调优,在每个关键步骤的加入如 ...
- Elasticsearch之java的基本操作一
摘要 接触ElasticSearch已经有一段了.在这期间,遇到很多问题,但在最后自己的不断探索下解决了这些问题.看到网上或多或少的都有一些介绍ElasticSearch相关知识的文档,但个人觉得 ...
- 论:开发者信仰之“天下IT是一家“(Java .NET篇)
比尔盖茨公认的IT界领军人物,打造了辉煌一时的PC时代. 2008年,史蒂夫鲍尔默接替了盖茨的工作,成为微软公司的总裁. 2013年他与微软做了最后的道别. 2013年以后,我才真正看到了微软的变化. ...
- 故障重现, JAVA进程内存不够时突然挂掉模拟
背景,服务器上的一个JAVA服务进程突然挂掉,查看产生了崩溃日志,如下: # Set larger code cache with -XX:ReservedCodeCacheSize= # This ...
- 死磕内存篇 --- JAVA进程和linux内存间的大小关系
运行个JAVA 用sleep去hold住 package org.hjb.test; public class TestOnly { public static void main(String[] ...
- 【小程序分享篇 一 】开发了个JAVA小程序, 用于清除内存卡或者U盘里的垃圾文件非常有用
有一种场景, 手机内存卡空间被用光了,但又不知道哪个文件占用了太大,一个个文件夹去找又太麻烦,所以我开发了个小程序把手机所有文件(包括路径下所有层次子文件夹下的文件)进行一个排序,这样你就可以找出哪个 ...
随机推荐
- 【dp】 比较经典的dp poj 1160
转自http://blog.sina.com.cn/s/blog_5dd8fece0100rq7d.html [题目大意]:用数轴描述一条高速公路,有V个村庄,每一个村庄坐落在数轴的某个点上,需要选择 ...
- Spring Boot 系列教程13-注解定时任务
注解 @Scheduled(cron = "0/5 * * * * ?") 相当于原来的xml版本的如下配置 <task:scheduled ref="schedu ...
- db2数据导出导入
C:\Users\yexuxia>set db2instance=TCASHMAN C:\Users\yexuxia>db2(c) Copyright IBM Corporation 19 ...
- tomcat内存优化问题
Java内存组成 1) 堆 运行时数据区域,所有类实例和数组的内存均从此处分配.Java 虚拟机启动时创建.对象的堆内存由称为垃圾回收器 的自动内存管理系统回收. 堆由两部分组成: 其中eden+fr ...
- 学习笔记——状态模式State
状态模式,主要是用于存在大量case判断的操作执行,同时这些case依赖于对象的状态,那么就可以将大量的case判断封装为独立的类. Context: -state,当前状态对象. ChangeSta ...
- Node.js学习 - RESTFul API
REST Representational State Transfer (表述性状态转移), 是一组架构约束条件和原则.满足这些约束条件和原则的应用程序或设计就是RESTful. RESTful W ...
- 转:使用WebDriver过程中遇到的那些问题
在做web项目的自动化端到端测试时主要使用的是Selenium WebDriver来驱动浏览器.Selenium WebDriver的优点是支持的语言多,支持的浏览器多.主流的浏览器Chrome.Fi ...
- php-fpm配置优化
PHP配置文件php-fpm的优化 2013/06/28 php, php-fpm 应用加速与性能调优 评论 6,029 本文所涉及的配置文件名为PHP-fpm.conf,里面比较重要的配置项有如 ...
- java中把list列表转为arrayList以及arraylist数组截取的简单方法
java中把list列表转为arrayList以及arraylist数组截取的简单方法 package xiaobai; import java.util.ArrayList; import java ...
- CentOS 6.5 开机启动指定服务
gedit /etc/rc.d/rc.local #关闭防火墙 service iptables stop #开启samba服务 service smb start #开启ntopng 端口5000 ...