sklearn 可视化模型的训练测试收敛情况和特征重要性
show the code:
# Plot training deviance
def plot_training_deviance(clf, n_estimators, X_test, y_test):
# compute test set deviance
test_score = np.zeros((n_estimators,), dtype=np.float64)
for i, y_pred in enumerate(clf.staged_predict(X_test)):
test_score[i] = clf.loss_(y_test, y_pred)
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title('Deviance')
train_score = clf.train_score_
logging.info("len(train_score): %s" % len(train_score))
logging.info(train_score)
logging.info("len(test_score): %s" % len(test_score))
logging.info(test_score)
plt.plot(np.arange(n_estimators) + 1, train_score, 'b-',
label='Training Set Deviance')
plt.plot(np.arange(n_estimators) + 1, test_score, 'r*', label='Test Set Deviance')
plt.legend(loc='upper right')
plt.xlabel('Boosting Iterations')
plt.ylabel('Deviance')
plt.show() # Plot feature importance
def plot_feature_importance(clf, feature_names):
feature_importance = clf.feature_importances_
# make importances relative to max importance
feature_importance = 100.0 * (feature_importance / feature_importance.max())
sorted_idx = np.argsort(feature_importance)
pos = np.arange(sorted_idx.shape[0]) + .5
plt.subplot(1, 2, 2)
plt.barh(pos, feature_importance[sorted_idx], align='center')
# plt.yticks(pos, feature_names[sorted_idx])
plt.yticks(pos, [feature_names[idx] for idx in sorted_idx])
plt.xlabel('Relative Importance')
plt.title('Variable Importance')
plt.show() class Train(object):
def __init__(self, data_file):
self.data_file = data_file
self.x_fields = ["xxx", "xxx", "xxx"]
self.x_features, self.y_labels = self.load_data() def load_data(self):
x_features, y_labels = [], []
# ......
return x_features, y_labels def train_model(self):
model = GradientBoostingRegressor(random_state=42)
model.fit(self.x_features, self.y_labels)
y_pred = model.predict(self.x_features)
logging.info("mean_squared_error: %.6f" % mean_squared_error(self.y_labels, y_pred))
logging.info("mean_squared_log_error: %.6f" % mean_squared_log_error(self.y_labels, y_pred)) plot_training_deviance(clf=model, n_estimators=model.get_params()["n_estimators"], X_test=self.x_features, y_test=self.y_labels) # 输出feature重要性
logging.info("feature_importances_: %s" % model.feature_importances_)
plot_feature_importance(clf=model, feature_names=self.x_fields)
参考的是sklearn中的样例: Gradient Boosting regression — scikit-learn 0.19.2 documentation
画出的图如下所示:

sklearn 可视化模型的训练测试收敛情况和特征重要性的更多相关文章
- 【集成学习】sklearn中xgboost模块中plot_importance函数(绘图--特征重要性)
直接上代码,简单 # -*- coding: utf-8 -*- """ ################################################ ...
- 使用 TensorBoard 可视化模型、数据和训练
使用 TensorBoard 可视化模型.数据和训练 在 60 Minutes Blitz 中,我们展示了如何加载数据,并把数据送到我们继承 nn.Module 类的模型,在训练数据上训练模型,并在测 ...
- sklearn——train_test_split 随机划分训练集和测试集
sklearn——train_test_split 随机划分训练集和测试集 sklearn.model_selection.train_test_split随机划分训练集和测试集 官网文档:http: ...
- 机器学习使用sklearn进行模型训练、预测和评价
cross_val_score(model_name, x_samples, y_labels, cv=k) 作用:验证某个模型在某个训练集上的稳定性,输出k个预测精度. K折交叉验证(k-fold) ...
- pytorch seq2seq模型训练测试
num_sequence.py """ 数字序列化方法 """ class NumSequence: """ ...
- 学习笔记TF016:CNN实现、数据集、TFRecord、加载图像、模型、训练、调试
AlexNet(Alex Krizhevsky,ILSVRC2012冠军)适合做图像分类.层自左向右.自上向下读取,关联层分为一组,高度.宽度减小,深度增加.深度增加减少网络计算量. 训练模型数据集 ...
- Python 3 利用 Dlib 19.7 和 sklearn机器学习模型 实现人脸微笑检测
0.引言 利用机器学习的方法训练微笑检测模型,给一张人脸照片,判断是否微笑: 使用的数据集中69张没笑脸,65张有笑脸,训练结果识别精度在95%附近: 效果: 图1 示例效果 工程利用pytho ...
- sklearn保存模型的两种方式
sklearn 中模型保存的两种方法 一. sklearn中提供了高效的模型持久化模块joblib,将模型保存至硬盘. from sklearn.externals import joblib # ...
- sklearn保存模型-【老鱼学sklearn】
训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...
随机推荐
- hashset和treeset的区别
hashset 和 treeset的区别 hashset 和 treeset的区别 1.TreeSet 是二差树实现的,Treeset中的数据是自动排好序的,不允许放入null值. 2.HashSet ...
- SLG手游Java服务器的设计与开发——数据管理
文章版权归腾讯GAD所有,禁止匿名转载:禁止商业使用:禁止个人使用. 一.前言 上文介绍了我们的SLG手游的服务器架构设计以及网络通信部分,本文介绍数据管理部分,在数据存储方面,我选择了Mysql.M ...
- 15信号sigaction
信号处理 信号值小于 SIGRTMIN 的信号 (1~31) 都是不可靠信号 某些unix版本中,调用信号函数处理后会自动恢复默认信号处理,所以在信号处理函数中还需要继续调用signal函数设置信号处 ...
- zw版【转发·台湾nvp系列Delphi例程】HALCON TestSubsetRegio
zw版[转发·台湾nvp系列Delphi例程]HALCON TestSubsetRegio procedure TForm1.Button1Click(Sender: TObject);var rg0 ...
- ts 绘制多边形
let bg = this.createBitmapByName("123_png"); this.addChild(bg) bg.x = this.stage.width / 2 ...
- nginx重启 failed (98: Address already in use)
启动nginx服务,无法正常启动,一查log日志,发现如题错误信息. 问题描述:地址已被使用.可能nginx服务卡死了,导致端口占用,出现此错误. 查看端口 netstat -ntpl 杀掉进程 ...
- 逆向及BOF基础实践
逆向及BOF基础实践 20145316 许心远 一.缓冲区溢出基础知识 缓冲区溢出是一种非常普遍.非常危险的漏洞,在各种操作系统.应用软件中广泛存在.利用缓冲区溢出攻击,可以导致程序运行失败.系统宕机 ...
- JSDoc 注释规范
命令名描述 @param @argument 指定参数名和说明来描述一个函数参数@returns 描述函数的返回值@author 指示代码的作者@deprecated 指示一个函数已经废弃,而且在将来 ...
- JavaScript高级内容笔记:原型链、继承、执行上下文、作用域链、闭包
最近在系统的学习JS深层次内容,并稍微整理了一下,作为备忘和后期复习,这里分享给大家,希望对大家有所帮助.如有错误请留言指正,tks. 了解这些问题,我先一步步来看,先从稍微浅显内容说起,然后引出这些 ...
- STM32时钟树
STM32的时钟系统 相较于51单片机,stm32的时钟系统可以说是非常复杂了,我们现在看下面的一张图: 上图说明了时钟的走向,是从左至右的从时钟源一步步的分配给外设时钟.需要注意的是,上图左侧一共有 ...