AdaBoost的java实现
目前学了几个ML的分类的经典算法,但是一直想着是否有一种能将这些算法集成起来的,今天看到了AdaBoost,也算是半个集成,感觉这个思路挺好,很像人的训练过程,并且对决策树是一个很好的补充,因为决策树容易过拟合,用AdaBoost可以让一棵很深的决策树将其分开成多棵矮树,后来发现原来这个想法和random forest比较相似,RF的代码等下周有空的时候可以写一下。
这个貌似挺厉害的,看那些专门搞学术的人说是一篇很牛逼的论文证明说可以把弱学习提升到强学习。我这种搞工程的,能知道他的原理,适用范围,能自己写一遍代码,感觉还是比那些读几遍论文只能惶惶其谈的要安心些。
关于AdaBoost的基本概念,通过《机器学习方法》来概要的说下。





- import java.io.BufferedReader;
- import java.io.FileInputStream;
- import java.io.IOException;
- import java.io.InputStreamReader;
- import java.util.ArrayList;
- class Stump{
- public int dim;
- public double thresh;
- public String condition;
- public double error;
- public ArrayList<Integer> labelList;
- double factor;
- public String toString(){
- return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList;
- }
- }
- class Utils{
- //加载数据集
- public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{
- ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>();
- FileInputStream fis=new FileInputStream(filename);
- InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
- BufferedReader br=new BufferedReader(isr);
- String line="";
- while((line=br.readLine())!=null){
- ArrayList<Double> data=new ArrayList<Double>();
- String[] s=line.split(" ");
- for(int i=0;i<s.length-1;i++){
- data.add(Double.parseDouble(s[i]));
- }
- dataSet.add(data);
- }
- return dataSet;
- }
- //加载类别
- public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{
- ArrayList<Integer> labelSet=new ArrayList<Integer>();
- FileInputStream fis=new FileInputStream(filename);
- InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
- BufferedReader br=new BufferedReader(isr);
- String line="";
- while((line=br.readLine())!=null){
- String[] s=line.split(" ");
- labelSet.add(Integer.parseInt(s[s.length-1]));
- }
- return labelSet;
- }
- //测试用的
- public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){
- for(ArrayList<Double> data:dataSet){
- System.out.println(data);
- }
- }
- //获取最大值,用于求步长
- public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){
- double max=-9999.0;
- for(ArrayList<Double> data:dataSet){
- if(data.get(index)>max){
- max=data.get(index);
- }
- }
- return max;
- }
- //获取最小值,用于求步长
- public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){
- double min=9999.0;
- for(ArrayList<Double> data:dataSet){
- if(data.get(index)<min){
- min=data.get(index);
- }
- }
- return min;
- }
- //获取数据集中以该feature为特征,以thresh和conditions为value的叶子节点的决策树进行划分后得到的预测类别
- public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){
- ArrayList<Integer> labelList=new ArrayList<Integer>();
- if(condition.compareTo("lt")==0){
- for(ArrayList<Double> data:dataSet){
- if(data.get(feature)<=thresh){
- labelList.add(1);
- }else{
- labelList.add(-1);
- }
- }
- }else{
- for(ArrayList<Double> data:dataSet){
- if(data.get(feature)>=thresh){
- labelList.add(1);
- }else{
- labelList.add(-1);
- }
- }
- }
- return labelList;
- }
- //求预测类别与真实类别的加权误差
- public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){
- double error=0;
- int n=real.size();
- for(int i=0;i<fake.size();i++){
- if(fake.get(i)!=real.get(i)){
- error+=weights.get(i);
- }
- }
- return error;
- }
- //构造一棵单节点的决策树,用一个Stump类来存储这些基本信息。
- public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){
- int featureNum=dataSet.get(0).size();
- int rowNum=dataSet.size();
- Stump stump=new Stump();
- double minError=999.0;
- System.out.println("第"+n+"次迭代");
- for(int i=0;i<featureNum;i++){
- double min=getMin(dataSet,i);
- double max=getMax(dataSet,i);
- double step=(max-min)/(rowNum);
- for(double j=min-step;j<=max+step;j=j+step){
- String[] conditions={"lt","gt"};//如果是lt,表示如果小于阀值则为真类,如果是gt,表示如果大于阀值则为正类
- for(String condition:conditions){
- ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition);
- double error=Utils.getError(labelList,labelSet,weights);
- if(error<minError){
- minError=error;
- stump.dim=i;
- stump.thresh=j;
- stump.condition=condition;
- stump.error=minError;
- stump.labelList=labelList;
- stump.factor=0.5*(Math.log((1-error)/error));
- }
- }
- }
- }
- return stump;
- }
- public static ArrayList<Double> getInitWeights(int n){
- double weight=1.0/n;
- ArrayList<Double> weights=new ArrayList<Double>();
- for(int i=0;i<n;i++){
- weights.add(weight);
- }
- return weights;
- }
- //更新样本权值
- public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){
- double Z=0;
- ArrayList<Double> newWeights=new ArrayList<Double>();
- int row=labelList.size();
- double e=Math.E;
- double factor=stump.factor;
- for(int i=0;i<row;i++){
- Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i));
- }
- for(int i=0;i<row;i++){
- double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z;
- newWeights.add(weight);
- }
- return newWeights;
- }
- //对加权误差累加
- public static ArrayList<Double> InitAccWeightError(int n){
- ArrayList<Double> accError=new ArrayList<Double>();
- for(int i=0;i<n;i++){
- accError.add(0.0);
- }
- return accError;
- }
- public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){
- ArrayList<Integer> t=stump.labelList;
- double factor=stump.factor;
- ArrayList<Double> newAccError=new ArrayList<Double>();
- for(int i=0;i<t.size();i++){
- double a=accerror.get(i)+factor*t.get(i);
- newAccError.add(a);
- }
- return newAccError;
- }
- public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){
- ArrayList<Integer> a=new ArrayList<Integer>();
- int wrong=0;
- for(int i=0;i<accError.size();i++){
- if(accError.get(i)>0){
- if(labelList.get(i)==-1){
- wrong++;
- }
- }else if(labelList.get(i)==1){
- wrong++;
- }
- }
- double error=wrong*1.0/accError.size();
- return error;
- }
- public static void showStumpList(ArrayList<Stump> G){
- for(Stump s:G){
- System.out.println(s);
- System.out.println(" ");
- }
- }
- }
- public class Adaboost {
- /**
- * @param args
- * @throws IOException
- */
- public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){
- int row=labelList.size();
- ArrayList<Double> weights=Utils.getInitWeights(row);
- ArrayList<Stump> G=new ArrayList<Stump>();
- ArrayList<Double> accError=Utils.InitAccWeightError(row);
- int n=1;
- while(true){
- Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵误差率最小的单节点决策树
- G.add(stump);
- weights=Utils.updateWeights(stump,labelList,weights);//更新权值
- accError=Utils.accWeightError(accError,stump);//将加权误差累加,因为这样不用再利用分类器再求了
- double error=Utils.calErrorRate(accError,labelList);
- if(error<0.001){
- break;
- }
- n++;
- }
- return G;
- }
- public static void main(String[] args) throws IOException {
- // TODO Auto-generated method stub
- String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt";
- ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file);
- ArrayList<Integer> labelSet=Utils.loadLabelSet(file);
- ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet);
- Utils.showStumpList(G);
- System.out.println("finished");
- }
- }
这里的数据采用的是统计学习方法中的数据
- 0 1
- 1 1
- 2 1
- 3 -1
- 4 -1
- 5 -1
- 6 1
- 7 1
- 8 1
- 9 -1
这里是单个特征的,也可以是多维数据,例如
- 1.0 2.1 1
- 2.0 1.1 1
- 1.3 1.0 -1
- 1.0 1.0 -1
- 2.0 1.0 1
AdaBoost的java实现的更多相关文章
- Spark案例分析
一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...
- 机器学习之——集成算法,随机森林,Bootsing,Adaboost,Staking,GBDT,XGboost
集成学习 集成算法 随机森林(前身是bagging或者随机抽样)(并行算法) 提升算法(Boosting算法) GBDT(迭代决策树) (串行算法) Adaboost (串行算法) Stacking ...
- smile——Java机器学习引擎
资源 https://haifengl.github.io/ https://github.com/haifengl/smile 介绍 Smile(统计机器智能和学习引擎)是一个基于Java和Scal ...
- 故障重现(内存篇2),JAVA内存不足导致频繁回收和swap引起的性能问题
背景起因: 记起以前的另一次也是关于内存的调优分享下 有个系统平时运行非常稳定运行(没经历过大并发考验),然而在一次活动后,人数并发一上来后,系统开始卡. 我按经验开始调优,在每个关键步骤的加入如 ...
- Elasticsearch之java的基本操作一
摘要 接触ElasticSearch已经有一段了.在这期间,遇到很多问题,但在最后自己的不断探索下解决了这些问题.看到网上或多或少的都有一些介绍ElasticSearch相关知识的文档,但个人觉得 ...
- 论:开发者信仰之“天下IT是一家“(Java .NET篇)
比尔盖茨公认的IT界领军人物,打造了辉煌一时的PC时代. 2008年,史蒂夫鲍尔默接替了盖茨的工作,成为微软公司的总裁. 2013年他与微软做了最后的道别. 2013年以后,我才真正看到了微软的变化. ...
- 故障重现, JAVA进程内存不够时突然挂掉模拟
背景,服务器上的一个JAVA服务进程突然挂掉,查看产生了崩溃日志,如下: # Set larger code cache with -XX:ReservedCodeCacheSize= # This ...
- 死磕内存篇 --- JAVA进程和linux内存间的大小关系
运行个JAVA 用sleep去hold住 package org.hjb.test; public class TestOnly { public static void main(String[] ...
- 【小程序分享篇 一 】开发了个JAVA小程序, 用于清除内存卡或者U盘里的垃圾文件非常有用
有一种场景, 手机内存卡空间被用光了,但又不知道哪个文件占用了太大,一个个文件夹去找又太麻烦,所以我开发了个小程序把手机所有文件(包括路径下所有层次子文件夹下的文件)进行一个排序,这样你就可以找出哪个 ...
随机推荐
- Firebase 相关
谷歌在 2016年 I/O 大会上推出了 Firebase 的新版本.Firebase 平台提供了为移动端(iOS和Android)和 Web 端创建后端架构的完整解决方案. 从一开始的移动后端即服务 ...
- USB鼠标线序
鼠标线断了,找了个废弃的手机充电线接上,特记录线序如下: 红————白 白————橙绿————绿黑————蓝
- Win7与Ubuntu双系统时卸载Ubuntu的方法
Win7与Ubuntu双系统时卸载Ubuntu的方法 [日期:2010-03-26] 来源:Ubuntu社区 作者:Ubuntu编辑 [字体:大 中 小] 1. 下载MBRFix工具,放 ...
- AndroidStudio使用注意事项
今天在引入GitHUb上的开源框架时,写好依赖后编译时,报以下错误: Error:Execution failed for task ':app:processDebugResources'.> ...
- SQL Server 索引维护sql语句
使用以下脚本查看数据库索引碎片的大小情况: 复制代码代码如下: DBCC SHOWCONTIG WITH FAST, TABLERESULTS, ALL_INDEXES, NO_INFOMSGS 以 ...
- USER-AGENT是什么
USER-AGENT是什么? USER-AGENT:记录请求所来自的浏览器. User-Agent分析网站 http://www.useragentstring.com/ 通过解析User-Agent ...
- 【转】PHP代码审计
PHP代码审计 目录 1. 概述3 2. 输入验证和输出显示3 2.1 命令注入4 2.2 跨站脚本4 2.3 文件包含5 2.4 代码注入5 2.5 SQL注入6 2.6 XPath注入6 2.7 ...
- UISegmentedControl 踩坑
@interface JLMyContactsViewController () @property (nonatomic, strong) UIImageView *navImageView; ...
- android 1.6 launcher研究之自定义ViewGroup (转 2011.06.03(二)——— android 1.6 launcher研究之自定义ViewGroup )
2011.06.03(2)——— android 1.6 launcher研究之自定义ViewGroup2011.06.03(2)——— android 1.6 launcher研究之自定义ViewG ...
- 笔记整理--HTTP Header 详解
HTTP Header 详解 2013/09/21 | 分类: IT技术 | 0 条评论 | 标签: HTTP 分享到:36 原文出处: zcmhi HTTP(HyperTextTransferPro ...