提升方法--AdaBoost

前言

AdaBoost是最经典的提升方法,所谓的提升方法就是一系列弱分类器(分类效果只比随机预测好一点)经过组合提升最后的预测效果。而AdaBoost提升方法是在每次训练弱分类器的时候,提升上一个弱分类器误分类的数据的比重来让本次训练的分类器能够弥补上次分类器的不足。AdaBoost的概念和算法还是很好理解的,而且通过书上的例题可以很明显的感觉用一个很简单(计算量很小)的分类器,经过提升后的最终分类器的效果很好,本篇还是着重实现部分,并且将我在实现时候遇到的问题和思考记录下来。代码地址

AdaBoost算法

输入: 训练数据集$T = { (x_1,y_1), (x_2,y_2),...,(x_N,y_N) } $; 某个弱分类算法(比如训练不足够的感知机,我的实现中就选的迭代次数比较少的感知机)。

输出: 最终分类器G(x)。

  1. 初始化训练数据的权值分布(这个权值每个训练数据都有, 但是并不是意味着训练的时候将数据点乘上这个权值,它的作用主要是在计算误差率的时候用到)。

    \(D_1 = (w_{11}, ...,w_{1i},...,w_{1N}), w_{1i}=\frac{1}{N}, i=1,2,...,N\)

  2. 对每一次提升(也就是弱分类器个数)m:

    • 根据训练数据和权重训练分类器\(G_m(x)\)。这里的权值分布是在训练分类器的时候用到的,具体来说就是在优化目标函数(比如最大似然函数或者误分类率)的时候,考虑每个数据点的权值

    • 计算\(G_m(x)\)在训练数据集上的分类误差率:

      \(e_m = \sum^{N}_{i=1}w_{mi}I(G_m(x_i)\neq y_i)\)

    • 计算\(G_m(x)\)的系数

      \(a_m = 0.5log(\frac{1-e_m}{e_m})\)

    • 更新训练数据的权值分布

      \(D_{m+1}=(w_{m+1,1}, ...,w_{m+1,i},...,w_{m+1,N})\)

      \(w_{m+1,i} = \frac{w_{mi}}{Z_m}exp(-a_my_iG_m(x_i))\)

      \(Z_m = \sum^{N}_{i=1}w_{mi}exp(-a_my_iG_m(x_i))\)

  3. 构建基本分类器的线性组合

    \(G(x) = sign(\sum^{M}_{m=1}a_mG_m(x))\)

说明:对于训练数据的权值问题,是我在实际实现的时候发现的,这个要特别注意。还有就是基本分类器的选择问题,一定要选取比随机预测效果好的分类器,比如二分类问题,一定要选择分类误差率小于0.5的分类器,否则后续无法提升。

C++实现

代码结构

重要代码

这里放上求数据权值的代码

int AdaBoost::computeWeights(Perceptron* classifier) {
vector<double> trainGT;
//由于我的感知机算法(见前面的系列代码)采用的loss函数里面包含训练数据的真值
//于是这里我就通过改变真值的比重来反应训练数据的比重
for(int i =0; i<trainDataGT.size();++i)
trainGT.push_back(trainDataGT[i]*featrWeight[i]);
classifier->setTrainD(trainDataF, trainGT);//将本算法的训练数据设为感知机的训练数据
classifier->setDim(indim);
classifier->train(100, 0.9);
//这里第一个参数是感知机的训练次数,第二个参数是学习率。
//100次迭代时学习的是一个强分类器,直接将全部数据分类正确,经过实验,将训练步数设置为90就可以得到弱分类器
//会有分类错误的情况,但是分类误差率小于0.5。
double erroeRate = 0;
for(int i = 0; i<trainDataF.size();++i) {
if (classifier->predict(trainDataF[i])!=int(trainDataGT[i]))
erroeRate += featrWeight[i];
}
if(erroeRate==0){
if(clsfWeight.size()==0)
clsfWeight.push_back(1);
return 0;
} double clsW;
clsW = 0.5*std::log((1-erroeRate)/erroeRate);
clsfWeight.push_back(clsW); double zm=0;
for(int i = 0; i<trainDataF.size();++i) {
zm+=featrWeight[i]*std::exp(-clsW*trainDataGT[i]*classifier->predict(trainDataF[i]));
} for(int i = 0; i<featrWeight.size();++i ){
featrWeight[i] = featrWeight[i]/zm*std::exp(-clsW*trainDataGT[i]*classifier->predict(trainDataF[i]));
}
return 1;
}

再次强调,更改感知机的训练次数和学习率会有不同的结果,但是我的结果得到的最终的分类器却不如一个训练次数多的强分类器好,可能是因为我的训练数据太小。

这里主要是想练习用c++的指针使用其它类,也可以用其它的分类器,单是之前写那些算法并没有提供被调用的接口(当时并没有想要调用),改了改感知机的代码才勉强能用,以后写代码还是需要多思考。

统计学习方法c++实现之七 提升方法--AdaBoost的更多相关文章

  1. 机器学习理论提升方法AdaBoost算法第一卷

    AdaBoost算法内容来自<统计学习与方法>李航,<机器学习>周志华,以及<机器学习实战>Peter HarringTon,相互学习,不足之处请大家多多指教! 提 ...

  2. 模型提升方法adaBoost

    他通过改变训练样本的权重,学习多个分类器,并将这些分类器进行线性组合,提高分类的性能. adaboost提高那些被前一轮弱分类器错误分类样本的权重,而降低那些被正确分类样本的权重,这样使得,那些没有得 ...

  3. 提升方法-AdaBoost

    提升方法通过改变训练样本的权重,学习多个分类器(弱分类器/基分类器)并将这些分类器进行线性组合,提高分类的性能. AdaBoost算法的特点是不改变所给的训练数据,而不断改变训练数据权值的分布,使得训 ...

  4. 机器学习——提升方法AdaBoost算法,推导过程

    0提升的基本方法 对于分类的问题,给定一个训练样本集,求比较粗糙的分类规则(弱分类器)要比求精确的分类的分类规则(强分类器)容易的多.提升的方法就是从弱分类器算法出发,反复学习,得到一系列弱分类器(又 ...

  5. Adaboost算法的一个简单实现——基于《统计学习方法(李航)》第八章

    最近阅读了李航的<统计学习方法(第二版)>,对AdaBoost算法进行了学习. 在第八章的8.1.3小节中,举了一个具体的算法计算实例.美中不足的是书上只给出了数值解,这里用代码将它实现一 ...

  6. 08_提升方法_AdaBoost算法

    今天是2020年2月24日星期一.一个又一个意外因素串连起2020这不平凡的一年,多么希望时间能够倒退.曾经觉得电视上科比的画面多么熟悉,现在全成了陌生和追忆. GitHub:https://gith ...

  7. 组合方法(ensemble method) 与adaboost提升方法

    组合方法: 我们分类中用到非常多经典分类算法如:SVM.logistic 等,我们非常自然的想到一个方法.我们是否可以整合多个算法优势到解决某一个特定分类问题中去,答案是肯定的! 通过聚合多个分类器的 ...

  8. Boosting(提升方法)之AdaBoost

    集成学习(ensemble learning)通过构建并结合多个个体学习器来完成学习任务,也被称为基于委员会的学习. 集成学习构建多个个体学习器时分两种情况:一种情况是所有的个体学习器都是同一种类型的 ...

  9. 统计学习方法 AdaBoost

    提升方法的基本思路 在概率近似正确(probably approximately correct,PAC)学习的框架中, 一个概念(一个类),如果存在一个多项式的学习算法能够学习它,并且正确率很高,那 ...

随机推荐

  1. 我的Java之旅——答答租车系统的改进

    之前的答答租车系统虽然可以实现项目的要求,但是没有用Java面向对象,今天用面向对象的三大特性封装.继承和多态来改进原来的代码.题目和之前的代码参考上篇博客,这里不再述说. 改进后的代码: Vehic ...

  2. 8年前,令我窒息的Java socket体验学习

    本来已经放弃编程了,那时我发誓再也不去IT培训班了,如果找不到工作,我就去工地上打工.可心有不甘,老是惦记着,我不想天天面对生产线,做一个丧失思考能力的操作工,可后来呀,还是走上了程序员之路...这么 ...

  3. 【bzoj5016】[Snoi2017]一个简单的询问 莫队算法

    题目描述 给你一个长度为N的序列ai,1≤i≤N和q组询问,每组询问读入l1,r1,l2,r2,需输出 get(l,r,x)表示计算区间[l,r]中,数字x出现了多少次. 输入 第一行,一个数字N,表 ...

  4. luogu【模板】三维偏序(陌上花开)

    嘟嘟嘟 很显然我开始学\(CDQ\)分治了. 我刚开始学的时候看了一篇博客,上面全是一些抽象的概念,看完后真是一头雾水,最后还不得不抄了这题的代码. 但这样可不行呀-- 于是我就不打算再扣那篇博客,而 ...

  5. 分布式缓存技术redis系列(一)——redis简介以及linux上的安装

    redis简介 redis是NoSQL(No Only SQL,非关系型数据库)的一种,NoSQL是以Key-Value的形式存储数据.当前主流的分布式缓存技术有redis,memcached,ssd ...

  6. Oracle(一)执行计划

    目录 一.什么是执行计划 二.如何查看执行计划 三.如何读懂执行计划 1. 执行顺序的原则 2. 执行计划中字段解释 3. 谓词说明 4. JOIN方式 4.1 HASH JOIN(散列连接) 4.2 ...

  7. K2使用Nginx做负载均衡

    K2使用Nginx做负载均衡 K2目前是支持Load Balancing这种方式,来做负载均衡,也可以使用F5来做负载均衡,但这次我使用nginx来实现K2的负载均衡 下载nginx 请下载nginx ...

  8. jQuery----each()方法

    jquery中有隐式迭代,不需要我们再次对某些元素进行操作.但是如果涉及到不同元素有不同操作,需要进行each遍历.本文利用10个li设置不同的透明度的案例,对each方法进行说明. 语法: $(元素 ...

  9. php数组函数array_column:不用循环就能提取多维数组内容

    作为一个有多年PHP开发经验的码农,我也是前段时间才发现PHP处理数组有这么好用的函数, 至此之前,我处理数组的数据基本都是使用循环,记录一下两个函数的用法: array_column() 函数 返回 ...

  10. 【转】CSDN离线网页html文件自动跳转

    问题: 最近使用OneNote2016剪辑csdn的文章时,发现一些公式/文本框不能被正确识别,所以离线保存网页的html文件. 但是每次打开html文件,都会自动跳转的CSDN主页,即使断网,也会自 ...