一.利用回归树实现分类

分类也可以用回归树来做,简单说来就是训练与类别数相同的几组回归树,每一组代表一个类别,然后对所有组的输出进行softmax操作将其转换为概率分布,然后再通过交叉熵或者KL一类的损失函数求每颗树相应的负梯度,指导下一轮的训练,以三分类为例,流程如下:

二.softmax+交叉熵损失,及其梯度求解

分类问题,一般会选择用交叉熵作为损失函数,下面对softmax+交叉熵损失函数的梯度做推导:

softmax函数在最大熵那一节已有使用,再回顾一下:

\[softmax([y_1^{hat},y_2^{hat},...,y_n^{hat}])=\frac{1}{\sum_{i=1}^n e^{y_i^{hat}}}[e^{y_1^{hat}},e^{y_2^{hat}},...,e^{y_n^{hat}}]
\]

交叉熵在logistic回归有介绍:

\[cross\_entropy(y,p)=-\sum_{i=1}^n y_ilog p_i
\]

将\(p_i\)替换为\(\frac{e^{y_i^{hat}}}{\sum_{i=1}^n e^{y_i^{hat}}}\)即是我们的损失函数:

\[L(y^{hat},y)=-\sum_{i=1}^n y_ilog \frac{e^{y_i^{hat}}}{\sum_{j=1}^n e^{x_j^{hat}}}\\
=-\sum_{i=1}^n y_i(y_i^{hat}-log\sum_{j=1}^n e^{y_j^{hat}})\\
=log\sum_{i=1}^n e^{y_i^{hat}}-\sum_{i=1}^ny_iy_i^{hat}(由于是onehot展开,所以\sum_{i=1}^n y_i=1)
\]

计算梯度:

\[\frac{\partial L(y^{hat},y)}{\partial y^{hat}}=softmax([y_1^{hat},y_2^{hat},...,y_n^{hat}])-[y_1,y_2,...,y_n]
\]

所以,第一组回归树的拟合目标为\(y_1-\frac{e^{y_1^{hat}}}{\sum_{i=1}^n e^{y_i^{hat}}}\),第二组回归树学习的拟合目标为\(y_2-\frac{e^{y_2^{hat}}}{\sum_{i=1}^n e^{y_i^{hat}}}\),....,第\(n\)组回归树的拟合目标为\(y_n-\frac{e^{y_n^{hat}}}{\sum_{i=1}^n e^{y_i^{hat}}}\)

三.代码实现

import os
os.chdir('../')
from ml_models.tree import CARTRegressor
from ml_models import utils
import copy
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline class GradientBoostingClassifier(object):
def __init__(self, base_estimator=None, n_estimators=10, learning_rate=1.0):
"""
:param base_estimator: 基学习器,允许异质;异质的情况下使用列表传入比如[estimator1,estimator2,...,estimator10],这时n_estimators会失效;
同质的情况,单个estimator会被copy成n_estimators份
:param n_estimators: 基学习器迭代数量
:param learning_rate: 学习率,降低后续基学习器的权重,避免过拟合
"""
self.base_estimator = base_estimator
self.n_estimators = n_estimators
self.learning_rate = learning_rate
if self.base_estimator is None:
# 默认使用决策树桩
self.base_estimator = CARTRegressor(max_depth=2)
# 同质分类器
if type(base_estimator) != list:
estimator = self.base_estimator
self.base_estimator = [copy.deepcopy(estimator) for _ in range(0, self.n_estimators)]
# 异质分类器
else:
self.n_estimators = len(self.base_estimator) # 扩展class_num组分类器
self.expand_base_estimators = [] def fit(self, x, y):
# 将y转one-hot编码
class_num = np.amax(y) + 1
y_cate = np.zeros(shape=(len(y), class_num))
y_cate[np.arange(len(y)), y] = 1 # 扩展分类器
self.expand_base_estimators = [copy.deepcopy(self.base_estimator) for _ in range(class_num)] # 拟合第一个模型
y_pred_score_ = []
# TODO:并行优化
for class_index in range(0, class_num):
self.expand_base_estimators[class_index][0].fit(x, y_cate[:, class_index])
y_pred_score_.append(self.expand_base_estimators[class_index][0].predict(x))
y_pred_score_ = np.c_[y_pred_score_].T
# 计算负梯度
new_y = y_cate - utils.softmax(y_pred_score_)
# 训练后续模型
for index in range(1, self.n_estimators):
y_pred_score = []
for class_index in range(0, class_num):
self.expand_base_estimators[class_index][index].fit(x, new_y[:, class_index])
y_pred_score.append(self.expand_base_estimators[class_index][index].predict(x))
y_pred_score_ += np.c_[y_pred_score].T * self.learning_rate
new_y = y_cate - utils.softmax(y_pred_score_) def predict_proba(self, x):
# TODO:并行优化
y_pred_score = []
for class_index in range(0, len(self.expand_base_estimators)):
estimator_of_index = self.expand_base_estimators[class_index]
y_pred_score.append(
np.sum(
[estimator_of_index[0].predict(x)] +
[self.learning_rate * estimator_of_index[i].predict(x) for i in
range(1, self.n_estimators - 1)] +
[estimator_of_index[self.n_estimators - 1].predict(x)]
, axis=0)
)
return utils.softmax(np.c_[y_pred_score].T) def predict(self, x):
return np.argmax(self.predict_proba(x), axis=1)
#造伪数据
from sklearn.datasets import make_classification
data, target = make_classification(n_samples=100, n_features=2, n_classes=2, n_informative=1, n_redundant=0,
n_repeated=0, n_clusters_per_class=1, class_sep=.5,random_state=21)
# 同质
classifier = GradientBoostingClassifier(base_estimator=CARTRegressor(),n_estimators=10)
classifier.fit(data, target)
utils.plot_decision_function(data, target, classifier)

#异质
from ml_models.linear_model import LinearRegression
classifier = GradientBoostingClassifier(base_estimator=[LinearRegression(),LinearRegression(),LinearRegression(),CARTRegressor(max_depth=2)])
classifier.fit(data, target)
utils.plot_decision_function(data, target, classifier)

《机器学习Python实现_10_06_集成学习_boosting_gbdt分类实现》的更多相关文章

  1. 简单物联网:外网访问内网路由器下树莓派Flask服务器

    最近做一个小东西,大概过程就是想在教室,宿舍控制实验室的一些设备. 已经在树莓上搭了一个轻量的flask服务器,在实验室的路由器下,任何设备都是可以访问的:但是有一些限制条件,比如我想在宿舍控制我种花 ...

  2. 利用ssh反向代理以及autossh实现从外网连接内网服务器

    前言 最近遇到这样一个问题,我在实验室架设了一台服务器,给师弟或者小伙伴练习Linux用,然后平时在实验室这边直接连接是没有问题的,都是内网嘛.但是回到宿舍问题出来了,使用校园网的童鞋还是能连接上,使 ...

  3. 外网访问内网Docker容器

    外网访问内网Docker容器 本地安装了Docker容器,只能在局域网内访问,怎样从外网也能访问本地Docker容器? 本文将介绍具体的实现步骤. 1. 准备工作 1.1 安装并启动Docker容器 ...

  4. 外网访问内网SpringBoot

    外网访问内网SpringBoot 本地安装了SpringBoot,只能在局域网内访问,怎样从外网也能访问本地SpringBoot? 本文将介绍具体的实现步骤. 1. 准备工作 1.1 安装Java 1 ...

  5. 外网访问内网Elasticsearch WEB

    外网访问内网Elasticsearch WEB 本地安装了Elasticsearch,只能在局域网内访问其WEB,怎样从外网也能访问本地Elasticsearch? 本文将介绍具体的实现步骤. 1. ...

  6. 怎样从外网访问内网Rails

    外网访问内网Rails 本地安装了Rails,只能在局域网内访问,怎样从外网也能访问本地Rails? 本文将介绍具体的实现步骤. 1. 准备工作 1.1 安装并启动Rails 默认安装的Rails端口 ...

  7. 怎样从外网访问内网Memcached数据库

    外网访问内网Memcached数据库 本地安装了Memcached数据库,只能在局域网内访问,怎样从外网也能访问本地Memcached数据库? 本文将介绍具体的实现步骤. 1. 准备工作 1.1 安装 ...

  8. 怎样从外网访问内网CouchDB数据库

    外网访问内网CouchDB数据库 本地安装了CouchDB数据库,只能在局域网内访问,怎样从外网也能访问本地CouchDB数据库? 本文将介绍具体的实现步骤. 1. 准备工作 1.1 安装并启动Cou ...

  9. 怎样从外网访问内网DB2数据库

    外网访问内网DB2数据库 本地安装了DB2数据库,只能在局域网内访问,怎样从外网也能访问本地DB2数据库? 本文将介绍具体的实现步骤. 1. 准备工作 1.1 安装并启动DB2数据库 默认安装的DB2 ...

  10. 怎样从外网访问内网OpenLDAP数据库

    外网访问内网OpenLDAP数据库 本地安装了OpenLDAP数据库,只能在局域网内访问,怎样从外网也能访问本地OpenLDAP数据库? 本文将介绍具体的实现步骤. 1. 准备工作 1.1 安装并启动 ...

随机推荐

  1. JAVA基础(二)—— 常用的类与方法

    JAVA基础(二)-- 常用的类与方法 1 Math类 abs ceil floor 绝对值 大于等于该浮点数的最小整数 小于等于该浮点数的最大整数 max min round 两参数中较大的 两参数 ...

  2. Spring中的依赖查找和依赖注入

    作者:Grey 原文地址: 语雀 博客园 依赖查找 Spring IoC 依赖查找分为以下几种方式 根据 Bean 名称查找 实时查找 延迟查找 根据 Bean 类型查找 单个 Bean 对象 集合 ...

  3. 大话Spark(7)-源码之Master主备切换

    Master作为Spark Standalone模式中的核心,如果Master出现异常,则整个集群的运行情况和资源都无法进行管理,整个集群将处于无法工作的状态. Spark在设计的时候考虑到了这种情况 ...

  4. 一文帮你搞懂 Android 文件描述符

    介绍文件描述符的概念以及工作原理,并通过源码了解 Android 中常见的 FD 泄漏. 一.什么是文件描述符? 文件描述符是在 Linux 文件系统的被使用,由于Android基 于Linux 系统 ...

  5. renren-fast部署发布教程(tomcat)

    renren-fast部署发布教程(tomcat) 说明:renren的开发文档需要付费,官方的生产部署介绍相对比较简单,因此记录自己的部署过程 为了方便,前后端我都部署在同一台linux服务器上,其 ...

  6. RateLimiter源码解析

    RateLimiter是Guava包提供的限流器,采用了令牌桶算法,特定是均匀地向桶中添加令牌,每次消费时也必须持有令牌,否则就需要等待.应用场景之一是限制消息消费的速度,避免消息消费过快而对下游的数 ...

  7. WorkSkill 面试之 字节跳动一面

  8. ASP.Net Core中处理异常的几种方法

    本文将介绍在ASP.Net Core中处理异常的几种方法 1使用开发人员异常页面(The developer exception page) 2配置HTTP错误代码页 Configuring stat ...

  9. 设计vue3的请求实体工厂

    设计一个vue3的请求实体工厂 目录 设计一个vue3的请求实体工厂 描述 实现 构建一个基础请求方法 创建具体请求的方法 下面是对请求的声明文件 下面是请求的定义 generateRequest对请 ...

  10. 计算机体系结构——CH5 标量处理机

    计算机体系结构--CH5 标量处理机 右键点击查看图像,查看清晰图像 X-mind 计算机体系结构--CH5 标量处理机 先行控制技术 指令得重叠执行方式 顺序执行方式 一次重叠执行方式 二次重叠技术 ...