lightgbm用于排序
一.
LTR(learning to rank)经常用于搜索排序中,开源工具中比较有名的是微软的ranklib,但是这个好像是单机版的,也有好长时间没有更新了。所以打算想利用lightgbm进行排序,但网上关于lightgbm用于排序的代码很少,关于回归和分类的倒是一堆。这里我将贴上python版的lightgbm用于排序的代码,里面将包括训练、获取叶结点、ndcg评估、预测以及特征重要度等处理代码,有需要的朋友可以参考一下或进行修改。
其实在使用时,本人也对比了ranklib中的lambdamart和lightgbm,令人映像最深刻的是lightgbm的训练速度非常快,快的起飞。可能lambdamart训练需要几个小时,而lightgbm只需要几分钟,但是后面的ndcg测试都差不多,不像论文中所说的lightgbm精度高一点。lightgbm的训练速度快,我想可能最大的原因要可能是:a.节点分裂用到了直方图,而不是预排序方法;b.基于梯度的单边采样,即行采样;c.互斥特征绑定,即列采样;d.其于leaf-wise决策树生长策略;e.类别特征的支持等
二.代码
第一部分代码块是主代码,后面三个代码块是用到的加载数据和ndcg。运行主代码使用命令如训练模型使用:python lgb.py -train等
完成代码和数据格式放在https://github.com/jiangnanboy/learning_to_rank上面,大家可以参考一下!!!!!
import os
import lightgbm as lgb
from sklearn import datasets as ds
import pandas as pd import numpy as np
from datetime import datetime
import sys
from sklearn.preprocessing import OneHotEncoder def split_data_from_keyword(data_read, data_group, data_feats):
'''
利用pandas
转为lightgbm需要的格式进行保存
:param data_read:
:param data_save:
:return:
'''
with open(data_group, 'w', encoding='utf-8') as group_path:
with open(data_feats, 'w', encoding='utf-8') as feats_path:
dataframe = pd.read_csv(data_read,
sep=' ',
header=None,
encoding="utf-8",
engine='python')
current_keyword = ''
current_data = []
group_size = 0
for _, row in dataframe.iterrows():
feats_line = [str(row[0])]
for i in range(2, len(dataframe.columns) - 1):
feats_line.append(str(row[i]))
if current_keyword == '':
current_keyword = row[1]
if row[1] == current_keyword:
current_data.append(feats_line)
group_size += 1
else:
for line in current_data:
feats_path.write(' '.join(line))
feats_path.write('\n')
group_path.write(str(group_size) + '\n') group_size = 1
current_data = []
current_keyword = row[1]
current_data.append(feats_line) for line in current_data:
feats_path.write(' '.join(line))
feats_path.write('\n')
group_path.write(str(group_size) + '\n') def save_data(group_data, output_feature, output_group):
'''
group与features分别进行保存
:param group_data:
:param output_feature:
:param output_group:
:return:
'''
if len(group_data) == 0:
return
output_group.write(str(len(group_data)) + '\n')
for data in group_data:
# 只包含非零特征
# feats = [p for p in data[2:] if float(p.split(":")[1]) != 0.0]
feats = [p for p in data[2:]]
output_feature.write(data[0] + ' ' + ' '.join(feats) + '\n') # data[0] => level ; data[2:] => feats def process_data_format(test_path, test_feats, test_group):
'''
转为lightgbm需要的格式进行保存
'''
with open(test_path, 'r', encoding='utf-8') as fi:
with open(test_feats, 'w', encoding='utf-8') as output_feature:
with open(test_group, 'w', encoding='utf-8') as output_group:
group_data = []
group = ''
for line in fi:
if not line:
break
if '#' in line:
line = line[:line.index('#')]
splits = line.strip().split()
if splits[1] != group: # qid => splits[1]
save_data(group_data, output_feature, output_group)
group_data = []
group = splits[1]
group_data.append(splits)
save_data(group_data, output_feature, output_group) def load_data(feats, group):
'''
加载数据
分别加载feature,label,query
'''
x_train, y_train = ds.load_svmlight_file(feats)
q_train = np.loadtxt(group)
return x_train, y_train, q_train def load_data_from_raw(raw_data):
with open(raw_data, 'r', encoding='utf-8') as testfile:
test_X, test_y, test_qids, comments = letor.read_dataset(testfile)
return test_X, test_y, test_qids, comments def train(x_train, y_train, q_train, model_save_path):
'''
模型的训练和保存
'''
train_data = lgb.Dataset(x_train, label=y_train, group=q_train)
params = {
'task': 'train', # 执行的任务类型
'boosting_type': 'gbrt', # 基学习器
'objective': 'lambdarank', # 排序任务(目标函数)
'metric': 'ndcg', # 度量的指标(评估函数)
'max_position': 10, # @NDCG 位置优化
'metric_freq': 1, # 每隔多少次输出一次度量结果
'train_metric': True, # 训练时就输出度量结果
'ndcg_at': [10],
'max_bin': 255, # 一个整数,表示最大的桶的数量。默认值为 255。lightgbm 会根据它来自动压缩内存。如max_bin=255 时,则lightgbm 将使用uint8 来表示特征的每一个值。
'num_iterations': 500, # 迭代次数
'learning_rate': 0.01, # 学习率
'num_leaves': 31, # 叶子数
# 'max_depth':6,
'tree_learner': 'serial', # 用于并行学习,‘serial’: 单台机器的tree learner
'min_data_in_leaf': 30, # 一个叶子节点上包含的最少样本数量
'verbose': 2 # 显示训练时的信息
}
gbm = lgb.train(params, train_data, valid_sets=[train_data])
gbm.save_model(model_save_path) def predict(x_test, comments, model_input_path):
'''
预测得分并排序
'''
gbm = lgb.Booster(model_file=model_input_path) # 加载model ypred = gbm.predict(x_test) predicted_sorted_indexes = np.argsort(ypred)[::-1] # 返回从大到小的索引 t_results = comments[predicted_sorted_indexes] # 返回对应的comments,从大到小的排序 return t_results def test_data_ndcg(model_path, test_path):
'''
评估测试数据的ndcg
'''
with open(test_path, 'r', encoding='utf-8') as testfile:
test_X, test_y, test_qids, comments = letor.read_dataset(testfile) gbm = lgb.Booster(model_file=model_path)
test_predict = gbm.predict(test_X) average_ndcg, _ = ndcg.validate(test_qids, test_y, test_predict, 60)
# 所有qid的平均ndcg
print("all qid average ndcg: ", average_ndcg)
print("job done!") def plot_print_feature_importance(model_path):
'''
打印特征的重要度
'''
#模型中的特征是Column_数字,这里打印重要度时可以映射到真实的特征名
feats_dict = {
'Column_0': '特征0名称',
'Column_1': '特征1名称',
'Column_2': '特征2名称',
'Column_3': '特征3名称',
'Column_4': '特征4名称',
'Column_5': '特征5名称',
'Column_6': '特征6名称',
'Column_7': '特征7名称',
'Column_8': '特征8名称',
'Column_9': '特征9名称',
'Column_10': '特征10名称',
}
if not os.path.exists(model_path):
print("file no exists! {}".format(model_path))
sys.exit(0) gbm = lgb.Booster(model_file=model_path) # 打印和保存特征重要度
importances = gbm.feature_importance(importance_type='split')
feature_names = gbm.feature_name() sum = 0.
for value in importances:
sum += value for feature_name, importance in zip(feature_names, importances):
if importance != 0:
feat_id = int(feature_name.split('_')[1]) + 1
print('{} : {} : {} : {}'.format(feat_id, feats_dict[feature_name], importance, importance / sum)) def get_leaf_index(data, model_path):
'''
得到叶结点并进行one-hot编码
'''
gbm = lgb.Booster(model_file=model_path)
ypred = gbm.predict(data, pred_leaf=True) one_hot_encoder = OneHotEncoder()
x_one_hot = one_hot_encoder.fit_transform(ypred)
print(x_one_hot.toarray()[0]) if __name__ == '__main__':
model_path = "保存模型的路径" if len(sys.argv) != 2:
print("Usage: python main.py [-process | -train | -predict | -ndcg | -feature | -leaf]")
sys.exit(0) if sys.argv[1] == '-process':
# 训练样本的格式与ranklib中的训练样本是一样的,但是这里需要处理成lightgbm中排序所需的格式
# lightgbm中是将样本特征和group分开保存为txt的,什么意思呢,看下面解释
'''
feats:
1 1:0.2 2:0.4 ...
2 1:0.2 2:0.4 ...
1 1:0.2 2:0.4 ...
3 1:0.2 2:0.4 ...
group:
2
4
这里group中2表示前2个是一个qid,4表示后两个是一个qid
'''
raw_data_path = '训练样本集路径'
data_feats = '特征保存路径'
data_group = 'group保存路径'
process_data_format(raw_data_path, data_feats, data_group) elif sys.argv[1] == '-train':
# train
train_start = datetime.now()
data_feats = '特征保存路径'
data_group = 'group保存路径'
x_train, y_train, q_train = load_data(data_feats, data_group)
train(x_train, y_train, q_train, model_path)
train_end = datetime.now()
consume_time = (train_end - train_start).seconds
print("consume time : {}".format(consume_time)) elif sys.argv[1] == '-predict':
train_start = datetime.now()
raw_data_path = '需要预测的数据路径'#格式如ranklib中的数据格式
test_X, test_y, test_qids, comments = load_data_from_raw(raw_data_path)
t_results = predict(test_X, comments, model_path)
train_end = datetime.now()
consume_time = (train_end - train_start).seconds
print("consume time : {}".format(consume_time)) elif sys.argv[1] == '-ndcg':
# ndcg
test_path = '测试的数据路径'#评估测试数据的平均ndcg
test_data_ndcg(model_path, test_path) elif sys.argv[1] == '-feature':
plot_print_feature_importance(model_path) elif sys.argv[1] == '-leaf':
#利用模型得到样本叶结点的one-hot表示
raw_data = '测试数据路径'#
with open(raw_data, 'r', encoding='utf-8') as testfile:
test_X, test_y, test_qids, comments = letor.read_dataset(testfile)
get_leaf_index(test_X, model_path)
lightgbm用于排序的更多相关文章
- java中的类实现comparable接口 用于排序
import java.util.Arrays; public class SortApp { public static void main(String[] args) { Student[] s ...
- Treemap 有序的hashmap。用于排序
TreeMap:有固定顺序的hashmap.在需要排序的Map时候才用TreeMap. Map.在数组中我们是通过数组下标来对其内容索引的,键值对. HashMap HashMap 用哈希码快速定位一 ...
- C++11新特性应用--介绍几个新增的便利算法(用于排序的几个算法)
继续C++11在头文件algorithm中添加的算法. 至少我认为,在stl的算法中,用到最多的就是sort了,我们不去探索sort的源代码.就是介绍C++11新增的几个关于排序的函数. 对于一个序列 ...
- XGBoost、LightGBM的详细对比介绍
sklearn集成方法 集成方法的目的是结合一些基于某些算法训练得到的基学习器来改进其泛化能力和鲁棒性(相对单个的基学习器而言)主流的两种做法分别是: bagging 基本思想 独立的训练一些基学习器 ...
- LightGBM大战XGBoost,谁将夺得桂冠?
引 言 如果你是一个机器学习社区的活跃成员,你一定知道 提升机器(Boosting Machine)以及它们的能力.提升机器从AdaBoost发展到目前最流行的XGBoost.XGBoost实际上已经 ...
- LightGBM调参笔记
本文链接:https://blog.csdn.net/u012735708/article/details/837497031. 概述在竞赛题中,我们知道XGBoost算法非常热门,是很多的比赛的大杀 ...
- XGBoost、LightGBM、Catboost总结
sklearn集成方法 bagging 常见变体(按照样本采样方式的不同划分) Pasting:直接从样本集里随机抽取的到训练样本子集 Bagging:自助采样(有放回的抽样)得到训练子集 Rando ...
- 【小程序分享篇 一 】开发了个JAVA小程序, 用于清除内存卡或者U盘里的垃圾文件非常有用
有一种场景, 手机内存卡空间被用光了,但又不知道哪个文件占用了太大,一个个文件夹去找又太麻烦,所以我开发了个小程序把手机所有文件(包括路径下所有层次子文件夹下的文件)进行一个排序,这样你就可以找出哪个 ...
- MS SQL 排序规则总结
排序规则术语 什么是排序规则呢? 排序规则是根据特定语言和区域设置标准指定对字符串数据进行排序和比较的规则.SQL Server 支持在单个数据库中存储具有不同排序规则的对象.MSDN解 ...
随机推荐
- apply 和 call 的用法
apply的用法 语法 func.apply(thisArg, [argsArray]) thisArg 可选的.在func函数运行时使用的this值.请注意,this可能不是该方法看到的实际值:如果 ...
- vscode IIsExpress用法
最近前端调试项目,都要安装IIS,使用IIS Express插件不需要另外在IIS架设站点,方便使用 1.安装IIS Express插件 2.ctrl+shfit+p 启动IIS Express 命令 ...
- 在Linux上安装Zookeeper集群
xl_echo编辑整理,欢迎转载,转载请声明文章来源.欢迎添加echo微信(微信号:t2421499075)交流学习. 百战不败,依不自称常胜,百败不颓,依能奋力前行.——这才是真正的堪称强大!! - ...
- activemq BytesMessage || TextMessage
需求:使用 python 程序向 activemq 的主题推送数据,默认推送的数据类型是 BytesMessage,java 程序那边接收较为麻烦,改为推送 TextMessage 类型的数据 解决方 ...
- 安装jQuery
description jQuery,顾名思义,也就是JavaScript和Query(查询),即辅助JavaScript开发的库.jQuery是一个快速.简洁的JavaScript框架,是继Prot ...
- iOS 开发之模糊效果的五种实现
前言 在iOS开发中我们经常会用到模糊效果使我们的界面更加美观,而iOS本身也提供了几种达到模糊效果的API,如:Core Image,使用Accelerate.Framework中的vImage A ...
- Android笔记(十七) Android中的Service
定义和用途 Service是Android的四大组件之一,一直在后台运行,没有用户界面.Service组件通常用于为其他组件提供后台服务或者监控其他组件的运行状态,例如播放音乐.记录地理位置,监听用户 ...
- 【OF框架】使用OF.WinService项目,添加定时服务,进行创建启动停止删除服务操作
准备 使用框架搭建完成项目,包含OF.WinService项目. 了解Window Service 和定时服务相关知识. 一.添加一个定时服务 第一步:了解项目结构 第二步:创建一个新的Job 第三步 ...
- Luogu P1290 欧几里得的游戏/UVA10368 Euclid's Game
Luogu P1290 欧几里得的游戏/UVA10368 Euclid's Game 对于博弈论的题目没接触过多少,而这道又是比较经典的SG博弈,所以就只能自己来推关系-- 假设我们有两个数$m,n$ ...
- /tmp/supervisor.sock no such file 报错
背景: 在执行 supervisorctl 时,报了这么一个错(如图),查找对应文档后解决,记录下来用来以后遇到使用 解决: 1. 将 supervisord.conf 文件下对应的 /tmp 目录 ...