统计学习方法——实现AdaBoost
Adaboost
适用问题:二分类问题
- 模型:加法模型
\]
- 策略:损失函数为指数函数
\]
- 算法:前向分步算法
\]
特点:AdaBoost算法的特点是通过迭代每次学习一个基本分类器。每次迭代中,提高那些被前一轮分类器错误分类数据的权值,而降低那些被正确分类的数据的权值。最后,AdaBoost将基本分类器的线性组合作为强分类器,其中给分类误差率小的基本分类器以大的权值,给分类误差率大的基本分类器以小的权值。
算法步骤:
1)给每个训练样本(\(x_{1},x_{2},….,x_{N}\))分配权重,初始权重\(w_{1}\)均为1/N。
2)针对带有权值的样本进行训练,得到模型\(G_m\)(初始模型为G1)。
3)计算模型\(G_m\)的误分率\(e_m=\sum_{i=1}^Nw_iI(y_i\not= G_m(x_i))\) (误分率应小于0.5,否则将预测结果翻转即可得到误分率小于0.5的分类器)
4)计算模型\(G_m\)的系数\(\alpha_m=0.5\log[(1-e_m)/e_m]\)
5)根据误分率e和当前权重向量\(w_m\)更新权重向量\(w_{m+1}\)。
6)计算组合模型\(f(x)=\sum_{m=1}^M\alpha_mG_m(x_i)\)的误分率。
7)当组合模型的误分率或迭代次数低于一定阈值,停止迭代;否则,回到步骤2)
提升树
提升树是以分类树或回归树为基本分类器的提升方法。提升树被认为是统计学习中最有效的方法之一。
提升方法:将弱可学习算法提升为强可学习算法。提升方法通过反复修改训练数据的权值分布,构建一系列基本分类器(弱分类器),并将这些基本分类器线性组合,构成一个强分类器。AdaBoost算法是提升方法的一个代表。
AdaBoost源码实现
假设弱分类器由 \(x < v\) 或 \(x > v\) 产生,阈值\(v\)使该分类器在训练集上分类误差率最低。
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline
def create_data():
iris = load_iris() # 鸢尾花数据集
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
data = np.array(df.iloc[:100, [0, 1, -1]]) # 取前一百个数据,只保留前两个特征
for d in data:
if d[-1] == 0:
d[-1] = -1
return data[:, :2], data[:, -1].astype(np.int)
class AdaBoost:
def __init__(self, num_classifier, increment=0.5):
"""
num_classifier: 弱分类器的数量
increment: 在特征上寻找最优切分点时,搜索时每次的增加值(数据稀疏时建议根据样本点来选择)
"""
self.num_classifier = num_classifier
self.increment = increment
def fit(self, X, Y):
self._init_args(X, Y)
# 逐个训练分类器
for m in range(self.num_classifier):
min_error, v_optimal, preds = float('INF'), None, None
direct_split = None
feature_idx = None # 选定的特征的列索引
# 遍历选择特征和切分点使得分类误差最小
for j in range(self.num_feature):
feature_values = self.X[:, j] # 第j个特征对应的所有取值
_ret = self._get_optimal_split(feature_values)
v_split, _direct_split, error, pred_labels = _ret
if error < min_error:
min_error = error
v_optimal = v_split
preds = pred_labels
direct_split = _direct_split
feature_idx = j
# 计算分类型权重alpha
alpha = self._cal_alpha(min_error)
self.alphas.append(alpha)
# 记录当前分类器G(x)
self.classifiers.append((feature_idx, v_optimal, direct_split))
# 更新样本集合权值分布
self._update_weights(alpha, preds)
def predict(self, x):
res = 0.0
for i in range(len(self.classifiers)):
idx, v, direct = self.classifiers[i]
# 输入弱分类器进行分类
if direct == '>':
output = 1 if x[idx] > v else -1
else: # direct == '<'
output = -1 if x[idx] > v else 1
res += self.alphas[i] * output
return 1 if res > 0 else -1 # sign(res)
def score(self, X_test, Y_test):
cnt = 0
for i, x in enumerate(X_test):
if self.predict(x) == Y_test[i]:
cnt += 1
return cnt / len(X_test)
def _init_args(self, X, Y):
self.X = X
self.Y = Y
self.N, self.num_feature = X.shape # N:样本数,num_feature:特征数量
# 初始时每个样本的权重均相同
self.weights = [1/self.N] * self.N
# 弱分类器集合
self.classifiers = []
# 每个分类器G(x)的权重
self.alphas = []
def _update_weights(self, alpha, pred_labels):
# 计算规范化因子Z
Z = self._cal_norm_factor(alpha, pred_labels)
for i in range(self.N):
self.weights[i] = (self.weights[i] *
np.exp(-1*alpha*self.Y[i]*pred_labels[i]) / Z)
def _cal_alpha(self, error):
return 0.5 * np.log((1-error)/error)
def _cal_norm_factor(self, alpha, pred_labels):
return sum([self.weights[i] * np.exp(-1*alpha*self.Y[i]*pred_labels[i])
for i in range(self.N)])
def _get_optimal_split(self, feature_values):
error = float('INF') # 分类误差
pred_labels = [] # 分类结果
v_split_optimal = None # 当前特征的最优切割点
direct_split = None # 最优切割点的判别方向
max_v = max(feature_values)
min_v = min(feature_values)
num_step = (max_v - min_v + self.increment)/self.increment
for i in range(int(num_step)):
# 选取分割点
v_split = min_v + i * self.increment
judge_direct = '>'
preds = [1 if feature_values[k] > v_split else -1
for k in range(len(feature_values))]
# 错误样本加权误差
weight_error = sum([self.weights[k] for k in range(self.N)
if preds[k] != self.Y[k]])
# 计算分类标签翻转后的误差
preds_inv = [-p for p in preds]
weight_error_inv = sum([self.weights[k] for k in range(self.N)
if preds_inv[k] != self.Y[k]])
# 取较小误差的判别方向作为分类器的判别方向
if weight_error_inv < weight_error:
preds = preds_inv
weight_error = weight_error_inv
judge_direct = '<'
if weight_error < error:
error = weight_error
pred_labels = preds
v_split_optimal = v_split
direct_split = judge_direct
return v_split_optimal, direct_split, error, pred_labels
测试模型准确率:
X, Y = create_data()
res = []
for i in range(10):
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)
clf = AdaBoost(num_classifier=50)
clf.fit(X_train, Y_train)
res.append(clf.score(X_test, Y_test))
print('My AdaBoost: {}次的平均准确率: {:.3f}'.format(len(res), sum(res)/len(res)))
My AdaBoost: 10次的平均准确率: 0.970
sklearn库的AdaBoost实例
from sklearn.ensemble import AdaBoostClassifier
res = []
for i in range(10):
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)
clf_sklearn = AdaBoostClassifier(n_estimators=50, learning_rate=0.5)
clf_sklearn.fit(X_train, Y_train)
res.append(clf_sklearn.score(X_test, Y_test))
print('sklearn AdaBoostClassifier: {}次的平均准确率: {:.3f}'.format(
len(res), sum(res)/len(res)))
sklearn AdaBoostClassifier: 10次的平均准确率: 0.945
统计学习方法——实现AdaBoost的更多相关文章
- Adaboost算法的一个简单实现——基于《统计学习方法(李航)》第八章
最近阅读了李航的<统计学习方法(第二版)>,对AdaBoost算法进行了学习. 在第八章的8.1.3小节中,举了一个具体的算法计算实例.美中不足的是书上只给出了数值解,这里用代码将它实现一 ...
- 【NLP】基于统计学习方法角度谈谈CRF(四)
基于统计学习方法角度谈谈CRF 作者:白宁超 2016年8月2日13:59:46 [摘要]:条件随机场用于序列标注,数据分割等自然语言处理中,表现出很好的效果.在中文分词.中文人名识别和歧义消解等任务 ...
- 统计学习方法 --- 感知机模型原理及c++实现
参考博客 Liam Q博客 和李航的<统计学习方法> 感知机学习旨在求出将训练数据集进行线性划分的分类超平面,为此,导入了基于误分类的损失函数,然后利用梯度下降法对损失函数进行极小化,从而 ...
- 统计学习方法笔记--EM算法--三硬币例子补充
本文,意在说明<统计学习方法>第九章EM算法的三硬币例子,公式(9.5-9.6如何而来) 下面是(公式9.5-9.8)的说明, 本人水平有限,怀着分享学习的态度发表此文,欢迎大家批评,交流 ...
- 统计学习方法:KNN
作者:桂. 时间:2017-04-19 21:20:09 链接:http://www.cnblogs.com/xingshansi/p/6736385.html 声明:欢迎被转载,不过记得注明出处哦 ...
- 统计学习方法:罗杰斯特回归及Tensorflow入门
作者:桂. 时间:2017-04-21 21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...
- 统计学习方法:核函数(Kernel function)
作者:桂. 时间:2017-04-26 12:17:42 链接:http://www.cnblogs.com/xingshansi/p/6767980.html 前言 之前分析的感知机.主成分分析( ...
- 统计学习方法学习(四)--KNN及kd树的java实现
K近邻法 1基本概念 K近邻法,是一种基本分类和回归规则.根据已有的训练数据集(含有标签),对于新的实例,根据其最近的k个近邻的类别,通过多数表决的方式进行预测. 2模型相关 2.1 距离的度量方式 ...
- 李航《统计学习方法》CH01
CH01 统计学方法概论 前言 章节目录 统计学习 监督学习 基本概念 问题的形式化 统计学习三要素 模型 策略 算法 模型评估与模型选择 训练误差与测试误差 过拟合与模型选择 正则化与交叉验证 正则 ...
随机推荐
- 关于Java中的对象、类、抽象类、接口、继承之间的联系
关于Java中的对象.类.抽象类.接口.继承之间的联系: 导读: 寒假学习JavaSE基础,其中的概念属实比较多,关联性也比较大,再次将相关的知识点复习一些,并理顺其中的关系. 正文: 举个例子:如果 ...
- HBase 数据存储结构
在HBase中, 从逻辑上来讲数据大概就长这样: 单从图中的逻辑模型来看, HBase 和 MySQL 的区别就是: 将不同的列归属与同一个列族下 支持多版本数据 这看着感觉也没有那么太大的区别呀, ...
- 从头捋了一遍 Java 代理机制,收获颇丰
尽人事,听天命.博主东南大学硕士在读,热爱健身和篮球,乐于分享技术相关的所见所得,关注公众号 @ 飞天小牛肉,第一时间获取文章更新,成长的路上我们一起进步 本文已收录于 「CS-Wiki」Gitee ...
- HDOJ-2181(深搜记录路径)
哈密顿绕行世界问题 HDOJ-2181 1.本题是典型的搜索记录路径的问题 2.主要使用的方法是dfs深搜,在输入的时候对vector进行排序,这样才能按照字典序输出. 3.为了记录路径,我使用的是两 ...
- OpenCV计算机视觉学习(13)——图像特征点检测(Harris角点检测,sift算法)
如果需要处理的原图及代码,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/ComputerVisionPractice 前言 ...
- CVE-2019-2618 任意文件上传
漏洞描述:CVE-2019-2618漏洞主要是利用了WebLogic组件中的DeploymentService接口,该接口支持向服务器上传任意文件.攻击者突破了OAM(Oracle Access Ma ...
- HTML基础速览
HTML概述 HTML ,CSS , JavaScript, JQuery, Vue 的关系 HTML可以写一个简单的前端,但是很丑,所以需要CSS对HTML进行美化 HTML是静态的.JavaScr ...
- FreeBSD 12.2 发布
FreeBSD 团队宣布 FreeBSD 12.2 正式发布,这是 FreeBSD 12 的第三个稳定版本. 本次更新的一些亮点: 引入了对无线网络堆栈的更新和各种驱动程序,以提供更好的 802.11 ...
- 在用free()函数释放指针内存时为何要将其指针置空
在通过free()函数释放指针内存之后讲其指针置空,这样可以避免后面的程序对与该指针非法性的判断所造成的程序崩溃问题.释放空间,指针的值并没有改变,无法直接通过指针自身来进行判断空间是否已经被释放,将 ...
- MyBatis-Plus【踩坑记录01】
不要使用Mybatis原生的SqlSessionFactory,而应使用MybatisSqlSessionFactory. 原因 依赖关系如下 因此会在使用Mybaits-Plus时默认的时Mybat ...