GDBT 可以解决分类和回归问题

回归问题

def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3, min_impurity_decrease=0.,
min_impurity_split=None, init=None, random_state=None,
max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None,
warm_start=False, presort='auto')

示例

import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.datasets import make_friedman1
from sklearn.ensemble import GradientBoostingRegressor X, y = make_friedman1(n_samples=1200, random_state=0, noise=1.0)
X_train, X_test = X[:200], X[200:]
y_train, y_test = y[:200], y[200:] ### 损失函数
# 如果损失函数为 误差绝对值,L=|y-f(x)|,负梯度为 sign(y-f(x)),即要么1,要么-1,sklearn 中对应为 loss='lad'
# 如果损失函数为 huber,sklearn 中对应为 loss='huber'
# 如果损失函数为 均方误差,sklearn 中对应为 loss='ls'
est = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=1, random_state=0, loss='huber').fit(X_train, y_train) pred = est.predict(X_test)
error = mean_squared_error(pred, y_test) print(max(y_test), min(y_test)) # (27.214332670044374, 0.8719243023544349)
print(error)
# loss='ls' 5.009154859960321
# loss='lad' 5.817510629608294
# loss='huber' 4.690823542377095

分类问题

def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3, min_impurity_decrease=0.,
min_impurity_split=None, init=None,
random_state=None, max_features=None, verbose=0,
max_leaf_nodes=None, warm_start=False,
presort='auto')

示例

from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, mean_squared_error
from time import time
import numpy as np
import pandas as pd
import mnist if __name__ == "__main__":
# 读取Mnist数据集, 测试GBDT的分类模型
mnistSet = mnist.loadLecunMnistSet()
train_X, train_Y, test_X, test_Y = mnistSet[0], mnistSet[1], mnistSet[2], mnistSet[3] m, n = np.shape(train_X)
idx = range(m)
np.random.shuffle(idx) # 使用PCA降维
# num = 30000
# pca = PCA(n_components=0.9, whiten=True, random_state=0)
# for i in range(int(np.ceil(1.0 * m / num))):
# minEnd = min((i + 1) * num, m)
# sub_idx = idx[i * num:minEnd]
# train_pca_X = pca.fit_transform(train_X[sub_idx])
# print np.shape(train_pca_X) print "\n**********测试GradientBoostingClassifier类**********"
t = time()
# param_grid1 = {"n_estimators": range(1000, 2001, 100)}
# param_grid2 = {'max_depth': range(30, 71, 10), 'min_samples_split': range(4, 9, 2)}
# param_grid3 = {'min_samples_split': range(4, 9, 2), 'min_samples_leaf': range(3, 12, 2)}
# param_grid4 = {'subsample': np.arange(0.6, 1.0, 0.05)}
# model = GridSearchCV(
# estimator=GradientBoostingClassifier(max_features=90, max_depth=40, min_samples_split=8, learning_rate=0.1,
# n_estimators=1800),
# param_grid=param_grid4, cv=3)
# # 拟合训练数据集
# model.fit(train_X, train_Y)
# print "最好的参数是:%s, 此时的得分是:%0.2f" % (model.best_params_, model.best_score_)
model = GradientBoostingClassifier(max_features=90, max_depth=40, min_samples_split=8, min_samples_leaf=3,
n_estimators=1200, learning_rate=0.05, subsample=0.95)
# 拟合训练数据集
model.fit(train_X, train_Y)
# 预测训练集
train_Y_hat = model.predict(train_X[idx])
print "训练集精确度: ", accuracy_score(train_Y[idx], train_Y_hat)
# 预测测试集
test_Y_hat = model.predict(test_X)
print "测试集精确度: ", accuracy_score(test_Y, test_Y_hat)
print "总耗时:", time() - t, "秒"

参考资料:

https://github.com/haidawyl/Mnist  各种模型的用法

sklearn-GDBT的更多相关文章

  1. sklearn的常用函数以及参数

    sklearn可实现的函数或者功能可分为如下几个方面 1.分类算法2.回归算法3.聚类算法4.降维算法5.模型优化6.文本预处理 其中分类算法和回归算法又叫监督学习,聚类算法和降维算法又叫非监督学习 ...

  2. 机器学习之sklearn——EM

    GMM计算更新∑k时,转置符号T应该放在倒数第二项(这样计算出来结果才是一个协方差矩阵) from sklearn.mixture import GMM    GMM中score_samples函数第 ...

  3. 机器学习之sklearn——聚类

    生成数据集方法:sklearn.datasets.make_blobs(n_samples,n_featurs,centers)可以生成数据集,n_samples表示个数,n_features表示特征 ...

  4. 机器学习之sklearn——SVM

    sklearn包对于SVM可输出支持向量,以及其系数和数目: print '支持向量的数目: ', clf.n_support_ print '支持向量的系数: ', clf.dual_coef_ p ...

  5. 使用sklearn做单机特征工程

    目录 1 特征工程是什么?2 数据预处理 2.1 无量纲化 2.1.1 标准化 2.1.2 区间缩放法 2.1.3 标准化与归一化的区别 2.2 对定量特征二值化 2.3 对定性特征哑编码 2.4 缺 ...

  6. 使用sklearn进行集成学习——实践

    系列 <使用sklearn进行集成学习——理论> <使用sklearn进行集成学习——实践> 目录 1 Random Forest和Gradient Tree Boosting ...

  7. 【原】关于使用sklearn进行数据预处理 —— 归一化/标准化/正则化

    一.标准化(Z-Score),或者去除均值和方差缩放 公式为:(X-mean)/std  计算时对每个属性/每列分别进行. 将数据按期属性(按列进行)减去其均值,并处以其方差.得到的结果是,对于每个属 ...

  8. sklearn 增量学习 数据量大

    问题 实际处理和解决机器学习问题过程中,我们会遇到一些"大数据"问题,比如有上百万条数据,上千上万维特征,此时数据存储已经达到10G这种级别.这种情况下,如果还是直接使用传统的方式 ...

  9. 使用sklearn优雅地进行数据挖掘【转】

    目录 1 使用sklearn进行数据挖掘 1.1 数据挖掘的步骤 1.2 数据初貌 1.3 关键技术2 并行处理 2.1 整体并行处理 2.2 部分并行处理3 流水线处理4 自动化调参5 持久化6 回 ...

  10. Sklearn库例子——决策树分类

    Sklearn上关于决策树算法使用的介绍:http://scikit-learn.org/stable/modules/tree.html 1.关于决策树:决策树是一个非参数的监督式学习方法,主要用于 ...

随机推荐

  1. DB 分库分表(4):多数据源的事务处理

    系统经sharding改造之后,原来单一的数据库会演变成多个数据库,如何确保多数据源同时操作的原子性和一致性是不得不考虑的一个问题.总体上看,目前对于一个分布式系统的事务处理有三种方式:分布式事务.基 ...

  2. 我的Linux vim配置文件

    map <F9> :call SaveInputData()<CR> func! SaveInputData() exec "tabnew" exec 'n ...

  3. 踩坑:VScode 集成 eslint 插件

    本文以 Vue 官方脚手架 Vue-cli 为例: 1. 创建 Vue 项目 注意:Vue-cli 默认给出了 eslint 配置,一路回车即可.最后在安装模块的时候,选择直接安装!我用淘宝镜像安装时 ...

  4. 尚学堂requireJs课程---3、私有和公有属性和方法

    尚学堂requireJs课程---3.私有和公有属性和方法 一.总结 一句话总结: 在 [模块] 的基础上,在return对象里面的方法和属性就是公有的(因为外部可以访问),不在的就是私有的 < ...

  5. Python定时框架 Apscheduler 详解【转】

    内容来自网络: https://www.cnblogs.com/luxiaojun/p/6567132.html 在平常的工作中几乎有一半的功能模块都需要定时任务来推动,例如项目中有一个定时统计程序, ...

  6. Android jni/ndk编程一:jni初级认识与实战体验

    Android平台很多地方都可以看到jni的身影,比如之前接触到一个投屏的项目,主要的代码是c/c++写的,然后通过Jni供Java层调用;另外,就拿Android系统中的Service来说,很多的S ...

  7. 关于排查python内存泄露的简单总结

    这次的内存泄露问题是发生在多线程场景下的. 各种工具都试过了,gc,objgraph, pdb,pympler等,仍然没有找到问题所在. pdb感觉用起来很方便,可以调试代码,对原来的代码无侵入性. ...

  8. RN性能优化(重新探索react吧)

    最近做RN遇到了一些性能瓶颈,逼着自己不得不做一些优化 已经做过,或者尝试过得优化方案: 1.点击效果防止重复点击. 2.左右两边分别用两个异步栈进行更新,这样能让右边的缓慢不影响左边的更新. 3.I ...

  9. 【SQL】MySQL---using的用法

    学习记录: mysql中using的用法为: using()用于两张表的join查询,要求using()指定的列在两个表中均存在,并使用之用于join的条件

  10. 利用Calendar类判断是平年还是闰年

    package com.bgs.Math; import java.util.Calendar; import java.util.Scanner; /*###14.21_常见对象(如何获取任意年份是 ...