目前学了几个ML的分类的经典算法,但是一直想着是否有一种能将这些算法集成起来的,今天看到了AdaBoost,也算是半个集成,感觉这个思路挺好,很像人的训练过程,并且对决策树是一个很好的补充,因为决策树容易过拟合,用AdaBoost可以让一棵很深的决策树将其分开成多棵矮树,后来发现原来这个想法和random forest比较相似,RF的代码等下周有空的时候可以写一下。

这个貌似挺厉害的,看那些专门搞学术的人说是一篇很牛逼的论文证明说可以把弱学习提升到强学习。我这种搞工程的,能知道他的原理,适用范围,能自己写一遍代码,感觉还是比那些读几遍论文只能惶惶其谈的要安心些。

关于AdaBoost的基本概念,通过《机器学习方法》来概要的说下。

bagging和boosting的区别

bagging:是指在原始数据上通过放回抽样,抽出和原始数据大小相等的新数据集(这个性质说明新数据集存在重复的值,而原始数据部分数据值不会出现在新数据集中),并重复该过程选择N个新数据集,这样通过N个分类器对这个N个数据集进行分类,最后选择分类器投票结果中最多类别作为最后的分类结果。
boosting:相比bagging,boosting像是一种串行,bagging是一种并行的,bagging可以对于N个数据集通过N个分类器同时进行分类,并且每个分类器的权重是一样的,但是boosting则相反,boosting是利用一个数据集依次由每个分类器进行分类,而确定每个分类器的权重是加大正确率高的分类器的权重,减少正确率低的分类器的权重。同时为了提高准确率,每次会降低被正确分类的样本的权重,提高没有正确分类的样本的权重。这样做其实比较符合人的决策过程,就是要多训练自己容易做错的题型,并且要多听取正确性高的老师的意见。
 
那么AdaBoost的主要的两个过程就是提高错误分类的样本权重和提高正确率高的分类器的权重。
算法的步骤:
输入:训练集T,弱学习分类器(这里是一个节点的决策树)
输出:最终的分类器G
1 先初始化样本权重值,D1={W11...W1n}W1i=1/n
2 根据样本权重D1以及决策树求分类误差率,并求的最小的误差率em,以及该决策树
  em=
3 计算该分类器的权重
  可以看出,误差率越小的,其权重越大
4 更新各个样本的权重,Dm+1,(用公式编辑器好麻烦。。。 )
  
其中Zm是规范化银子:
  
5 构建基本分类器
  F(X)=
6 计算该分类器下的误差率,如果小于某个阈值就停止,否则从第二步开始迭代
 
终于不用打公式了。。。。
附上代码:
  1. import java.io.BufferedReader;
  2. import java.io.FileInputStream;
  3. import java.io.IOException;
  4. import java.io.InputStreamReader;
  5. import java.util.ArrayList;
  6.  
  7. class Stump{
  8. public int dim;
  9. public double thresh;
  10. public String condition;
  11. public double error;
  12. public ArrayList<Integer> labelList;
  13. double factor;
  14.  
  15. public String toString(){
  16. return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList;
  17. }
  18. }
  19.  
  20. class Utils{
  21. //加载数据集
  22. public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{
  23. ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>();
  24. FileInputStream fis=new FileInputStream(filename);
  25. InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
  26. BufferedReader br=new BufferedReader(isr);
  27. String line="";
  28.  
  29. while((line=br.readLine())!=null){
  30. ArrayList<Double> data=new ArrayList<Double>();
  31. String[] s=line.split(" ");
  32.  
  33. for(int i=0;i<s.length-1;i++){
  34. data.add(Double.parseDouble(s[i]));
  35. }
  36. dataSet.add(data);
  37. }
  38. return dataSet;
  39. }
  40.  
  41. //加载类别
  42. public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{
  43. ArrayList<Integer> labelSet=new ArrayList<Integer>();
  44.  
  45. FileInputStream fis=new FileInputStream(filename);
  46. InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
  47. BufferedReader br=new BufferedReader(isr);
  48. String line="";
  49.  
  50. while((line=br.readLine())!=null){
  51. String[] s=line.split(" ");
  52. labelSet.add(Integer.parseInt(s[s.length-1]));
  53. }
  54. return labelSet;
  55. }
  56. //测试用的
  57. public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){
  58. for(ArrayList<Double> data:dataSet){
  59. System.out.println(data);
  60. }
  61. }
  62. //获取最大值,用于求步长
  63. public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){
  64. double max=-9999.0;
  65. for(ArrayList<Double> data:dataSet){
  66. if(data.get(index)>max){
  67. max=data.get(index);
  68. }
  69. }
  70. return max;
  71. }
  72. //获取最小值,用于求步长
  73. public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){
  74. double min=9999.0;
  75. for(ArrayList<Double> data:dataSet){
  76. if(data.get(index)<min){
  77. min=data.get(index);
  78. }
  79. }
  80. return min;
  81. }
  82.  
  83. //获取数据集中以该feature为特征,以thresh和conditions为value的叶子节点的决策树进行划分后得到的预测类别
  84. public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){
  85. ArrayList<Integer> labelList=new ArrayList<Integer>();
  86. if(condition.compareTo("lt")==0){
  87. for(ArrayList<Double> data:dataSet){
  88. if(data.get(feature)<=thresh){
  89. labelList.add(1);
  90. }else{
  91. labelList.add(-1);
  92. }
  93. }
  94. }else{
  95. for(ArrayList<Double> data:dataSet){
  96. if(data.get(feature)>=thresh){
  97. labelList.add(1);
  98. }else{
  99. labelList.add(-1);
  100. }
  101. }
  102. }
  103. return labelList;
  104. }
  105. //求预测类别与真实类别的加权误差
  106. public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){
  107. double error=0;
  108.  
  109. int n=real.size();
  110.  
  111. for(int i=0;i<fake.size();i++){
  112. if(fake.get(i)!=real.get(i)){
  113. error+=weights.get(i);
  114.  
  115. }
  116. }
  117.  
  118. return error;
  119. }
  120. //构造一棵单节点的决策树,用一个Stump类来存储这些基本信息。
  121. public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){
  122. int featureNum=dataSet.get(0).size();
  123.  
  124. int rowNum=dataSet.size();
  125. Stump stump=new Stump();
  126. double minError=999.0;
  127. System.out.println("第"+n+"次迭代");
  128. for(int i=0;i<featureNum;i++){
  129. double min=getMin(dataSet,i);
  130. double max=getMax(dataSet,i);
  131. double step=(max-min)/(rowNum);
  132. for(double j=min-step;j<=max+step;j=j+step){
  133. String[] conditions={"lt","gt"};//如果是lt,表示如果小于阀值则为真类,如果是gt,表示如果大于阀值则为正类
  134. for(String condition:conditions){
  135. ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition);
  136.  
  137. double error=Utils.getError(labelList,labelSet,weights);
  138. if(error<minError){
  139. minError=error;
  140. stump.dim=i;
  141. stump.thresh=j;
  142. stump.condition=condition;
  143. stump.error=minError;
  144. stump.labelList=labelList;
  145. stump.factor=0.5*(Math.log((1-error)/error));
  146. }
  147.  
  148. }
  149. }
  150.  
  151. }
  152.  
  153. return stump;
  154. }
  155.  
  156. public static ArrayList<Double> getInitWeights(int n){
  157. double weight=1.0/n;
  158. ArrayList<Double> weights=new ArrayList<Double>();
  159. for(int i=0;i<n;i++){
  160. weights.add(weight);
  161. }
  162. return weights;
  163. }
  164. //更新样本权值
  165. public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){
  166. double Z=0;
  167. ArrayList<Double> newWeights=new ArrayList<Double>();
  168. int row=labelList.size();
  169. double e=Math.E;
  170. double factor=stump.factor;
  171. for(int i=0;i<row;i++){
  172. Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i));
  173. }
  174.  
  175. for(int i=0;i<row;i++){
  176. double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z;
  177. newWeights.add(weight);
  178. }
  179. return newWeights;
  180. }
  181. //对加权误差累加
  182. public static ArrayList<Double> InitAccWeightError(int n){
  183. ArrayList<Double> accError=new ArrayList<Double>();
  184. for(int i=0;i<n;i++){
  185. accError.add(0.0);
  186. }
  187. return accError;
  188. }
  189.  
  190. public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){
  191. ArrayList<Integer> t=stump.labelList;
  192. double factor=stump.factor;
  193. ArrayList<Double> newAccError=new ArrayList<Double>();
  194. for(int i=0;i<t.size();i++){
  195. double a=accerror.get(i)+factor*t.get(i);
  196. newAccError.add(a);
  197. }
  198. return newAccError;
  199. }
  200.  
  201. public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){
  202. ArrayList<Integer> a=new ArrayList<Integer>();
  203. int wrong=0;
  204. for(int i=0;i<accError.size();i++){
  205. if(accError.get(i)>0){
  206. if(labelList.get(i)==-1){
  207. wrong++;
  208. }
  209. }else if(labelList.get(i)==1){
  210. wrong++;
  211. }
  212. }
  213. double error=wrong*1.0/accError.size();
  214. return error;
  215. }
  216.  
  217. public static void showStumpList(ArrayList<Stump> G){
  218. for(Stump s:G){
  219. System.out.println(s);
  220. System.out.println(" ");
  221. }
  222. }
  223. }
  224.  
  225. public class Adaboost {
  226.  
  227. /**
  228. * @param args
  229. * @throws IOException
  230. */
  231.  
  232. public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){
  233. int row=labelList.size();
  234. ArrayList<Double> weights=Utils.getInitWeights(row);
  235. ArrayList<Stump> G=new ArrayList<Stump>();
  236. ArrayList<Double> accError=Utils.InitAccWeightError(row);
  237. int n=1;
  238. while(true){
  239. Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵误差率最小的单节点决策树
  240. G.add(stump);
  241. weights=Utils.updateWeights(stump,labelList,weights);//更新权值
  242. accError=Utils.accWeightError(accError,stump);//将加权误差累加,因为这样不用再利用分类器再求了
  243. double error=Utils.calErrorRate(accError,labelList);
  244. if(error<0.001){
  245. break;
  246. }
  247. n++;
  248. }
  249. return G;
  250. }
  251.  
  252. public static void main(String[] args) throws IOException {
  253. // TODO Auto-generated method stub
  254. String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt";
  255. ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file);
  256. ArrayList<Integer> labelSet=Utils.loadLabelSet(file);
  257. ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet);
  258. Utils.showStumpList(G);
  259. System.out.println("finished");
  260. }
  261.  
  262. }

这里的数据采用的是统计学习方法中的数据

  1. 0 1
  2. 1 1
  3. 2 1
  4. 3 -1
  5. 4 -1
  6. 5 -1
  7. 6 1
  8. 7 1
  9. 8 1
  10. 9 -1

这里是单个特征的,也可以是多维数据,例如

  1. 1.0 2.1 1
  2. 2.0 1.1 1
  3. 1.3 1.0 -1
  4. 1.0 1.0 -1
  5. 2.0 1.0 1

AdaBoost的java实现的更多相关文章

  1. Spark案例分析

    一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...

  2. 机器学习之——集成算法,随机森林,Bootsing,Adaboost,Staking,GBDT,XGboost

    集成学习 集成算法 随机森林(前身是bagging或者随机抽样)(并行算法) 提升算法(Boosting算法) GBDT(迭代决策树) (串行算法) Adaboost (串行算法) Stacking ...

  3. smile——Java机器学习引擎

    资源 https://haifengl.github.io/ https://github.com/haifengl/smile 介绍 Smile(统计机器智能和学习引擎)是一个基于Java和Scal ...

  4. 故障重现(内存篇2),JAVA内存不足导致频繁回收和swap引起的性能问题

    背景起因: 记起以前的另一次也是关于内存的调优分享下   有个系统平时运行非常稳定运行(没经历过大并发考验),然而在一次活动后,人数并发一上来后,系统开始卡. 我按经验开始调优,在每个关键步骤的加入如 ...

  5. Elasticsearch之java的基本操作一

    摘要   接触ElasticSearch已经有一段了.在这期间,遇到很多问题,但在最后自己的不断探索下解决了这些问题.看到网上或多或少的都有一些介绍ElasticSearch相关知识的文档,但个人觉得 ...

  6. 论:开发者信仰之“天下IT是一家“(Java .NET篇)

    比尔盖茨公认的IT界领军人物,打造了辉煌一时的PC时代. 2008年,史蒂夫鲍尔默接替了盖茨的工作,成为微软公司的总裁. 2013年他与微软做了最后的道别. 2013年以后,我才真正看到了微软的变化. ...

  7. 故障重现, JAVA进程内存不够时突然挂掉模拟

    背景,服务器上的一个JAVA服务进程突然挂掉,查看产生了崩溃日志,如下: # Set larger code cache with -XX:ReservedCodeCacheSize= # This ...

  8. 死磕内存篇 --- JAVA进程和linux内存间的大小关系

    运行个JAVA 用sleep去hold住 package org.hjb.test; public class TestOnly { public static void main(String[] ...

  9. 【小程序分享篇 一 】开发了个JAVA小程序, 用于清除内存卡或者U盘里的垃圾文件非常有用

    有一种场景, 手机内存卡空间被用光了,但又不知道哪个文件占用了太大,一个个文件夹去找又太麻烦,所以我开发了个小程序把手机所有文件(包括路径下所有层次子文件夹下的文件)进行一个排序,这样你就可以找出哪个 ...

随机推荐

  1. Firebase 相关

    谷歌在 2016年 I/O 大会上推出了 Firebase 的新版本.Firebase 平台提供了为移动端(iOS和Android)和 Web 端创建后端架构的完整解决方案. 从一开始的移动后端即服务 ...

  2. USB鼠标线序

    鼠标线断了,找了个废弃的手机充电线接上,特记录线序如下: 红————白          白————橙绿————绿黑————蓝

  3. Win7与Ubuntu双系统时卸载Ubuntu的方法

    Win7与Ubuntu双系统时卸载Ubuntu的方法 [日期:2010-03-26] 来源:Ubuntu社区  作者:Ubuntu编辑 [字体:大 中 小]       1. 下载MBRFix工具,放 ...

  4. AndroidStudio使用注意事项

    今天在引入GitHUb上的开源框架时,写好依赖后编译时,报以下错误: Error:Execution failed for task ':app:processDebugResources'.> ...

  5. SQL Server 索引维护sql语句

    使用以下脚本查看数据库索引碎片的大小情况: 复制代码代码如下: DBCC SHOWCONTIG WITH FAST, TABLERESULTS, ALL_INDEXES, NO_INFOMSGS  以 ...

  6. USER-AGENT是什么

    USER-AGENT是什么? USER-AGENT:记录请求所来自的浏览器. User-Agent分析网站 http://www.useragentstring.com/ 通过解析User-Agent ...

  7. 【转】PHP代码审计

    PHP代码审计 目录 1. 概述3 2. 输入验证和输出显示3 2.1 命令注入4 2.2 跨站脚本4 2.3 文件包含5 2.4 代码注入5 2.5 SQL注入6 2.6 XPath注入6 2.7 ...

  8. UISegmentedControl 踩坑

    @interface JLMyContactsViewController () @property (nonatomic, strong)   UIImageView *navImageView; ...

  9. android 1.6 launcher研究之自定义ViewGroup (转 2011.06.03(二)——— android 1.6 launcher研究之自定义ViewGroup )

    2011.06.03(2)——— android 1.6 launcher研究之自定义ViewGroup2011.06.03(2)——— android 1.6 launcher研究之自定义ViewG ...

  10. 笔记整理--HTTP Header 详解

    HTTP Header 详解 2013/09/21 | 分类: IT技术 | 0 条评论 | 标签: HTTP 分享到:36 原文出处: zcmhi HTTP(HyperTextTransferPro ...