提升方法--AdaBoost

前言

AdaBoost是最经典的提升方法,所谓的提升方法就是一系列弱分类器(分类效果只比随机预测好一点)经过组合提升最后的预测效果。而AdaBoost提升方法是在每次训练弱分类器的时候,提升上一个弱分类器误分类的数据的比重来让本次训练的分类器能够弥补上次分类器的不足。AdaBoost的概念和算法还是很好理解的,而且通过书上的例题可以很明显的感觉用一个很简单(计算量很小)的分类器,经过提升后的最终分类器的效果很好,本篇还是着重实现部分,并且将我在实现时候遇到的问题和思考记录下来。代码地址

AdaBoost算法

输入: 训练数据集$T = { (x_1,y_1), (x_2,y_2),...,(x_N,y_N) } $; 某个弱分类算法(比如训练不足够的感知机,我的实现中就选的迭代次数比较少的感知机)。

输出: 最终分类器G(x)。

  1. 初始化训练数据的权值分布(这个权值每个训练数据都有, 但是并不是意味着训练的时候将数据点乘上这个权值,它的作用主要是在计算误差率的时候用到)。

    \(D_1 = (w_{11}, ...,w_{1i},...,w_{1N}), w_{1i}=\frac{1}{N}, i=1,2,...,N\)

  2. 对每一次提升(也就是弱分类器个数)m:

    • 根据训练数据和权重训练分类器\(G_m(x)\)。这里的权值分布是在训练分类器的时候用到的,具体来说就是在优化目标函数(比如最大似然函数或者误分类率)的时候,考虑每个数据点的权值

    • 计算\(G_m(x)\)在训练数据集上的分类误差率:

      \(e_m = \sum^{N}_{i=1}w_{mi}I(G_m(x_i)\neq y_i)\)

    • 计算\(G_m(x)\)的系数

      \(a_m = 0.5log(\frac{1-e_m}{e_m})\)

    • 更新训练数据的权值分布

      \(D_{m+1}=(w_{m+1,1}, ...,w_{m+1,i},...,w_{m+1,N})\)

      \(w_{m+1,i} = \frac{w_{mi}}{Z_m}exp(-a_my_iG_m(x_i))\)

      \(Z_m = \sum^{N}_{i=1}w_{mi}exp(-a_my_iG_m(x_i))\)

  3. 构建基本分类器的线性组合

    \(G(x) = sign(\sum^{M}_{m=1}a_mG_m(x))\)

说明:对于训练数据的权值问题,是我在实际实现的时候发现的,这个要特别注意。还有就是基本分类器的选择问题,一定要选取比随机预测效果好的分类器,比如二分类问题,一定要选择分类误差率小于0.5的分类器,否则后续无法提升。

C++实现

代码结构

重要代码

这里放上求数据权值的代码

int AdaBoost::computeWeights(Perceptron* classifier) {
vector<double> trainGT;
//由于我的感知机算法(见前面的系列代码)采用的loss函数里面包含训练数据的真值
//于是这里我就通过改变真值的比重来反应训练数据的比重
for(int i =0; i<trainDataGT.size();++i)
trainGT.push_back(trainDataGT[i]*featrWeight[i]);
classifier->setTrainD(trainDataF, trainGT);//将本算法的训练数据设为感知机的训练数据
classifier->setDim(indim);
classifier->train(100, 0.9);
//这里第一个参数是感知机的训练次数,第二个参数是学习率。
//100次迭代时学习的是一个强分类器,直接将全部数据分类正确,经过实验,将训练步数设置为90就可以得到弱分类器
//会有分类错误的情况,但是分类误差率小于0.5。
double erroeRate = 0;
for(int i = 0; i<trainDataF.size();++i) {
if (classifier->predict(trainDataF[i])!=int(trainDataGT[i]))
erroeRate += featrWeight[i];
}
if(erroeRate==0){
if(clsfWeight.size()==0)
clsfWeight.push_back(1);
return 0;
} double clsW;
clsW = 0.5*std::log((1-erroeRate)/erroeRate);
clsfWeight.push_back(clsW); double zm=0;
for(int i = 0; i<trainDataF.size();++i) {
zm+=featrWeight[i]*std::exp(-clsW*trainDataGT[i]*classifier->predict(trainDataF[i]));
} for(int i = 0; i<featrWeight.size();++i ){
featrWeight[i] = featrWeight[i]/zm*std::exp(-clsW*trainDataGT[i]*classifier->predict(trainDataF[i]));
}
return 1;
}

再次强调,更改感知机的训练次数和学习率会有不同的结果,但是我的结果得到的最终的分类器却不如一个训练次数多的强分类器好,可能是因为我的训练数据太小。

这里主要是想练习用c++的指针使用其它类,也可以用其它的分类器,单是之前写那些算法并没有提供被调用的接口(当时并没有想要调用),改了改感知机的代码才勉强能用,以后写代码还是需要多思考。

统计学习方法c++实现之七 提升方法--AdaBoost的更多相关文章

  1. 机器学习理论提升方法AdaBoost算法第一卷

    AdaBoost算法内容来自<统计学习与方法>李航,<机器学习>周志华,以及<机器学习实战>Peter HarringTon,相互学习,不足之处请大家多多指教! 提 ...

  2. 模型提升方法adaBoost

    他通过改变训练样本的权重,学习多个分类器,并将这些分类器进行线性组合,提高分类的性能. adaboost提高那些被前一轮弱分类器错误分类样本的权重,而降低那些被正确分类样本的权重,这样使得,那些没有得 ...

  3. 提升方法-AdaBoost

    提升方法通过改变训练样本的权重,学习多个分类器(弱分类器/基分类器)并将这些分类器进行线性组合,提高分类的性能. AdaBoost算法的特点是不改变所给的训练数据,而不断改变训练数据权值的分布,使得训 ...

  4. 机器学习——提升方法AdaBoost算法,推导过程

    0提升的基本方法 对于分类的问题,给定一个训练样本集,求比较粗糙的分类规则(弱分类器)要比求精确的分类的分类规则(强分类器)容易的多.提升的方法就是从弱分类器算法出发,反复学习,得到一系列弱分类器(又 ...

  5. Adaboost算法的一个简单实现——基于《统计学习方法(李航)》第八章

    最近阅读了李航的<统计学习方法(第二版)>,对AdaBoost算法进行了学习. 在第八章的8.1.3小节中,举了一个具体的算法计算实例.美中不足的是书上只给出了数值解,这里用代码将它实现一 ...

  6. 08_提升方法_AdaBoost算法

    今天是2020年2月24日星期一.一个又一个意外因素串连起2020这不平凡的一年,多么希望时间能够倒退.曾经觉得电视上科比的画面多么熟悉,现在全成了陌生和追忆. GitHub:https://gith ...

  7. 组合方法(ensemble method) 与adaboost提升方法

    组合方法: 我们分类中用到非常多经典分类算法如:SVM.logistic 等,我们非常自然的想到一个方法.我们是否可以整合多个算法优势到解决某一个特定分类问题中去,答案是肯定的! 通过聚合多个分类器的 ...

  8. Boosting(提升方法)之AdaBoost

    集成学习(ensemble learning)通过构建并结合多个个体学习器来完成学习任务,也被称为基于委员会的学习. 集成学习构建多个个体学习器时分两种情况:一种情况是所有的个体学习器都是同一种类型的 ...

  9. 统计学习方法 AdaBoost

    提升方法的基本思路 在概率近似正确(probably approximately correct,PAC)学习的框架中, 一个概念(一个类),如果存在一个多项式的学习算法能够学习它,并且正确率很高,那 ...

随机推荐

  1. Qt如何设置应用ico图标

    第一步,创建ico文件.将ico图标文件复制到工程文件夹目录中(注意必须是图标文件,任何格式的改后缀都不行) ,重命名为"myico.ico“.然后在该目录中右击,新建文本文档,并输入一行代 ...

  2. Oracle 数据库创建(图形界面操作)

    Oracle 创建数据库图文分解: 1. 选择所有程序->Oracle-OraDb11g_home1->Configuration and Migration Tools -> Da ...

  3. HttpClient的包含注意事项

    HttpClient 功能介绍 以下列出的是 HttpClient 提供的主要的功能,要知道更多详细的功能可以参见 HttpClient 的主页. 实现了所有 HTTP 的方法(GET,POST,PU ...

  4. Hibernate三种状态;query查询;ResultTransformer转换为pojo对象;能够将query语句写在xml中;Criteria查询;ProjectionList总和/f分组等函数

    版权声明:本文为博主原创文章,未经博主同意不得转载. https://blog.csdn.net/u010026901/article/details/24256091 Session操作过程中的po ...

  5. 批量删除Redis中的数据

    测试环境上是docker安装的redis,生产上使用的是阿里云Redis服务,需要批量清理生产上的数据. 阿里云提供了BS结构的工具管理Redis,但是不能全选批量删除,只能脚本删除,方法是在测试环境 ...

  6. 2668: [cqoi2012]交换棋子

    Description 有一个n行m列的黑白棋盘,你每次可以交换两个相邻格子(相邻是指有公共边或公共顶点)中的棋子,最终达到目标状态.要求第i行第j列的格子只能参与mi,j次交换. Input 第一行 ...

  7. gcd?人生赢家!

    题目背景 原创:b2019dy gcd是一个热爱游戏的人 题目描述 gcd最近在玩一个有趣的游戏 我们把这个游戏抽象成一张图,图上有n个点,我们需要寻找总计m件宝物,它们分布在图上,对于每件宝物而言, ...

  8. 7、Android---网络技术

    玩手机不能上网是单机的时代 而且现在的流量也出了无限使用 几乎网络离不开人们的日常生活 7.1.WebView的用法 遇到一些特殊的请求 在程序中展示一些网页 加载和显示网页都是浏览器的任务 在不打开 ...

  9. 二进制包 vs. 源代码包

    在ROS中, 我们可能经常会遇到缺少相关的ROS依赖的问题. 有些时候你编译或者运行一些ROS程序, 系统会提示找不到XXX功能包. 如果是缺少ROS的依赖, 通常可以用以下命令来安装: $ sudo ...

  10. K2 BPM介绍(1)

    K2 BPM介绍(1) 官网访问地址: 中文官网 英文官网 它是一个强大的BPM产品 K2 BPM详解 产品特性 与任何内容集成 Integrate with Anything 功能丰富的窗体 Fea ...