GDBT 可以解决分类和回归问题

回归问题

def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3, min_impurity_decrease=0.,
min_impurity_split=None, init=None, random_state=None,
max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None,
warm_start=False, presort='auto')

示例

import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.datasets import make_friedman1
from sklearn.ensemble import GradientBoostingRegressor X, y = make_friedman1(n_samples=1200, random_state=0, noise=1.0)
X_train, X_test = X[:200], X[200:]
y_train, y_test = y[:200], y[200:] ### 损失函数
# 如果损失函数为 误差绝对值,L=|y-f(x)|,负梯度为 sign(y-f(x)),即要么1,要么-1,sklearn 中对应为 loss='lad'
# 如果损失函数为 huber,sklearn 中对应为 loss='huber'
# 如果损失函数为 均方误差,sklearn 中对应为 loss='ls'
est = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=1, random_state=0, loss='huber').fit(X_train, y_train) pred = est.predict(X_test)
error = mean_squared_error(pred, y_test) print(max(y_test), min(y_test)) # (27.214332670044374, 0.8719243023544349)
print(error)
# loss='ls' 5.009154859960321
# loss='lad' 5.817510629608294
# loss='huber' 4.690823542377095

分类问题

def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3, min_impurity_decrease=0.,
min_impurity_split=None, init=None,
random_state=None, max_features=None, verbose=0,
max_leaf_nodes=None, warm_start=False,
presort='auto')

示例

from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, mean_squared_error
from time import time
import numpy as np
import pandas as pd
import mnist if __name__ == "__main__":
# 读取Mnist数据集, 测试GBDT的分类模型
mnistSet = mnist.loadLecunMnistSet()
train_X, train_Y, test_X, test_Y = mnistSet[0], mnistSet[1], mnistSet[2], mnistSet[3] m, n = np.shape(train_X)
idx = range(m)
np.random.shuffle(idx) # 使用PCA降维
# num = 30000
# pca = PCA(n_components=0.9, whiten=True, random_state=0)
# for i in range(int(np.ceil(1.0 * m / num))):
# minEnd = min((i + 1) * num, m)
# sub_idx = idx[i * num:minEnd]
# train_pca_X = pca.fit_transform(train_X[sub_idx])
# print np.shape(train_pca_X) print "\n**********测试GradientBoostingClassifier类**********"
t = time()
# param_grid1 = {"n_estimators": range(1000, 2001, 100)}
# param_grid2 = {'max_depth': range(30, 71, 10), 'min_samples_split': range(4, 9, 2)}
# param_grid3 = {'min_samples_split': range(4, 9, 2), 'min_samples_leaf': range(3, 12, 2)}
# param_grid4 = {'subsample': np.arange(0.6, 1.0, 0.05)}
# model = GridSearchCV(
# estimator=GradientBoostingClassifier(max_features=90, max_depth=40, min_samples_split=8, learning_rate=0.1,
# n_estimators=1800),
# param_grid=param_grid4, cv=3)
# # 拟合训练数据集
# model.fit(train_X, train_Y)
# print "最好的参数是:%s, 此时的得分是:%0.2f" % (model.best_params_, model.best_score_)
model = GradientBoostingClassifier(max_features=90, max_depth=40, min_samples_split=8, min_samples_leaf=3,
n_estimators=1200, learning_rate=0.05, subsample=0.95)
# 拟合训练数据集
model.fit(train_X, train_Y)
# 预测训练集
train_Y_hat = model.predict(train_X[idx])
print "训练集精确度: ", accuracy_score(train_Y[idx], train_Y_hat)
# 预测测试集
test_Y_hat = model.predict(test_X)
print "测试集精确度: ", accuracy_score(test_Y, test_Y_hat)
print "总耗时:", time() - t, "秒"

参考资料:

https://github.com/haidawyl/Mnist  各种模型的用法

sklearn-GDBT的更多相关文章

  1. sklearn的常用函数以及参数

    sklearn可实现的函数或者功能可分为如下几个方面 1.分类算法2.回归算法3.聚类算法4.降维算法5.模型优化6.文本预处理 其中分类算法和回归算法又叫监督学习,聚类算法和降维算法又叫非监督学习 ...

  2. 机器学习之sklearn——EM

    GMM计算更新∑k时,转置符号T应该放在倒数第二项(这样计算出来结果才是一个协方差矩阵) from sklearn.mixture import GMM    GMM中score_samples函数第 ...

  3. 机器学习之sklearn——聚类

    生成数据集方法:sklearn.datasets.make_blobs(n_samples,n_featurs,centers)可以生成数据集,n_samples表示个数,n_features表示特征 ...

  4. 机器学习之sklearn——SVM

    sklearn包对于SVM可输出支持向量,以及其系数和数目: print '支持向量的数目: ', clf.n_support_ print '支持向量的系数: ', clf.dual_coef_ p ...

  5. 使用sklearn做单机特征工程

    目录 1 特征工程是什么?2 数据预处理 2.1 无量纲化 2.1.1 标准化 2.1.2 区间缩放法 2.1.3 标准化与归一化的区别 2.2 对定量特征二值化 2.3 对定性特征哑编码 2.4 缺 ...

  6. 使用sklearn进行集成学习——实践

    系列 <使用sklearn进行集成学习——理论> <使用sklearn进行集成学习——实践> 目录 1 Random Forest和Gradient Tree Boosting ...

  7. 【原】关于使用sklearn进行数据预处理 —— 归一化/标准化/正则化

    一.标准化(Z-Score),或者去除均值和方差缩放 公式为:(X-mean)/std  计算时对每个属性/每列分别进行. 将数据按期属性(按列进行)减去其均值,并处以其方差.得到的结果是,对于每个属 ...

  8. sklearn 增量学习 数据量大

    问题 实际处理和解决机器学习问题过程中,我们会遇到一些"大数据"问题,比如有上百万条数据,上千上万维特征,此时数据存储已经达到10G这种级别.这种情况下,如果还是直接使用传统的方式 ...

  9. 使用sklearn优雅地进行数据挖掘【转】

    目录 1 使用sklearn进行数据挖掘 1.1 数据挖掘的步骤 1.2 数据初貌 1.3 关键技术2 并行处理 2.1 整体并行处理 2.2 部分并行处理3 流水线处理4 自动化调参5 持久化6 回 ...

  10. Sklearn库例子——决策树分类

    Sklearn上关于决策树算法使用的介绍:http://scikit-learn.org/stable/modules/tree.html 1.关于决策树:决策树是一个非参数的监督式学习方法,主要用于 ...

随机推荐

  1. 免费馅饼~-~ (hdu 1176

    当我准备要写这个随笔的时候是需要勇气的. 掉馅饼嘛,肯定是坑. (hdu1176 话说,gameboy人品太好,放学回家路上有馅饼可捡.还就在0~10这11个位置里,当馅饼开始掉的时候,gameboy ...

  2. Android_activity实现一个简单的新建联系表

    项目展示: 第一个Activity用于显示联系人信息 第二个Activity输入联系人信息 要求: 运行“新建联系人”程序,结果如下图所示: 点击“新建联系人”按钮,打开输入信息界面并输入姓名.公司. ...

  3. django 网站上传资源的显示与配置

    1.  上传资源的配置 1. 首先在项目里创建一个名称叫media的文件夹专门保存用户上传 2. settings.py文件配置上传资源的路径 # 上传资源路径,如果图片,上传文件等 MEDIA_UR ...

  4. zookeeper系列(八)zookeeper客户端的底层详解

    作者:leesf    掌控之中,才会成功:掌控之外,注定失败.出处:http://www.cnblogs.com/leesf456/p/6098255.html 尊重原创,共同学习进步:  一.前言 ...

  5. Linux 变量 $$ $! $? $- $# $* $@ $0 $n

    [参考文章]:linux中shell变量$#,$@,$0,$1,$2的含义解释 1. 变量说明 1.1 $$ Shell本身的PID(ProcessID) 1.2 $! Shell最后运行的后台Pro ...

  6. 自定义一个数组对象工具demo

    <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/ ...

  7. 爬虫 lxml 模块

    Xpath 在 XML 文档中查找信息的语言, 同样适用于 HTML 辅助工具 Xpath Helper Chrome插件  快捷键 Ctrl + shift + x XML Quire xpath ...

  8. [idea]创建一个控制台程序

    新建项目时,选择JBoss即可.

  9. OpenStack 虚拟机启动流程 UML 分析(内含 UML 源码)

    目录 文章目录 目录 前言 API 请求 Nova API 阶段 Nova Conductor 阶段 Nova Scheduler 阶段 Nova Compute 阶段(计算节点资源分配部分) Nov ...

  10. JAVA记事本的图形用户界面应用程序含加密

    JAVA记事本的图形用户界面应用程序 加密 题目简介: 整体分析: 实验代码: import java.awt.EventQueue; import java.awt.event.ActionEven ...