机器学习算法实现解析——libFM之libFM的训练过程概述
本节主要介绍的是libFM源码分析的第四部分——libFM的训练。
FM模型的训练是FM模型的核心的部分。
4.1、libFM中训练过程的实现
在FM模型的训练过程中,libFM源码中共提供了四种训练的方法,分别为:Stochastic Gradient Descent(SGD),Adaptive SGD(ASGD),Alternating Least Squares(ALS)和Markov Chain Monte Carlo(MCMC),其中ALS是MCMC的特殊形式,实际上其实现的就是SGD,ASGD和MCMC三种训练方法,三者的类之间的关系如下图所示:
FM模型训练的父类为fm_learn,其定义在文件fm_learn.h中,fm_learn_sgd类和fm_learn_mcmc类分别继承自fm_learn类。其中,fm_learn_sgd是基于梯度的实现方法,fm_learn_mcmc是基于蒙特卡洛的实现方法。
fm_learn_sgd_element类和fm_learn_sgd_element_adapt_reg类是fm_learn_sgd类的子类,是两种具体的基于梯度方法的实现,分别为SGD和ASGD。
fm_learn_mcmc_simultaneous类是fm_learn_mcmc类的子类,是具体的基于蒙特卡洛方法的实现。
4.2、训练过程的父类
在所有的训练过程中,fm_learn类为所有模型训练类的父类。
4.2.1、头文件
#include <cmath>
#include "Data.h"
#include "../../fm_core/fm_model.h"
#include "../../util/rlog.h"
#include "../../util/util.h"
4.2.2、第一部分的protected属性和方法
在这部分中定义了交叉项中需要用到两个数据,分别为sum和sum_sqr,这两个数的具体使用可以参见“机器学习算法实现解析——libFM之libFM的模型处理部分”。除此之外,还定义了预测predict_case函数,具体代码如下所示:
protected:
DVector<double> sum, sum_sqr;// FM模型的交叉项中的两项
DMatrix<double> pred_q_term;
// this function can be overwritten (e.g. for MCMC)
// 预测,使用的是fm_model中的predict函数
virtual double predict_case(Data& data) {
return fm->predict(data.data->getRow());
}
其中,预测predict_case函数使用的是fm_model类中的predict函数,对于该函数,可以参见“机器学习算法实现解析——libFM之libFM的模型处理部分”。
4.2.3、第二部分的public属性和方法
在这部分中,主要构造函数fm_learn函数,初始化init函数以及评估evaluate函数,其具体代码如下所示:
public:
DataMetaInfo* meta;
fm_model* fm;// 对应的fm模型
double min_target;// 设置的预测值的最小值
double max_target;// 设置的预测值的最大值
// task用于区分不同的任务:0表示的是回归,1表示的是分类
int task; // 0=regression, 1=classification
// 定义两个常量,分别表示的是回归和分类
const static int TASK_REGRESSION = 0;
const static int TASK_CLASSIFICATION = 1;
Data* validation;// 验证数据集
RLog* log;// 日志指针
// 构造函数,初始化变量,实例化的过程在main函数中
fm_learn() { log = NULL; task = 0; meta = NULL;}
virtual void init() {
// 日志
if (log != NULL) {
if (task == TASK_REGRESSION) {
log->addField("rmse", std::numeric_limits<double>::quiet_NaN());
log->addField("mae", std::numeric_limits<double>::quiet_NaN());
} else if (task == TASK_CLASSIFICATION) {
log->addField("accuracy", std::numeric_limits<double>::quiet_NaN());
} else {
throw "unknown task";
}
log->addField("time_pred", std::numeric_limits<double>::quiet_NaN());
log->addField("time_learn", std::numeric_limits<double>::quiet_NaN());
log->addField("time_learn2", std::numeric_limits<double>::quiet_NaN());
log->addField("time_learn4", std::numeric_limits<double>::quiet_NaN());
}
// 设置交叉项中的两项的大小
sum.setSize(fm->num_factor);
sum_sqr.setSize(fm->num_factor);
pred_q_term.setSize(fm->num_factor, meta->num_relations + 1);
}
// 对数据的评估
virtual double evaluate(Data& data) {
assert(data.data != NULL);// 检查数据不为空
if (task == TASK_REGRESSION) {// 回归
return evaluate_regression(data);// 调用回归的评价方法
} else if (task == TASK_CLASSIFICATION) {// 分类
return evaluate_classification(data);// 调用分类的评价放啊
} else {
throw "unknown task";
}
}
在评估evaluate函数中,根据task的值判断是分类问题还是回归问题,分别调用第四部分中的evaluate_regression和evaluate_classification函数。
4.2.4、第三部分的public属性和方法
在这部分中分别定义了模型的训练函数,模型的预测函数和debug输出函数,代码的具体过程如下所示:
public:
// 模型的训练过程
virtual void learn(Data& train, Data& test) { }
// 纯虚函数
virtual void predict(Data& data, DVector<double>& out) = 0;
// debug函数,用于打印中间的结果
virtual void debug() {
std::cout << "task=" << task << std::endl;
std::cout << "min_target=" << min_target << std::endl;
std::cout << "max_target=" << max_target << std::endl;
}
其中模型的训练learn函数没有定义具体的实现,由上述的继承关系,其具体的训练过程在具体的子类中实现;模型的预测predict函数是一个纯虚函数。对于纯虚函数的概念,可以参见;最后一个函数是一个debug函数,debug函数用于打印中间的结果。
4.2.5、第四部分的protected属性和方法
在这部分中定义了两个评价函数,分别用于处理分类问题和回归问题,代码的具体过程如下所示:
protected:
// 对分类问题的评价
virtual double evaluate_classification(Data& data) {
int num_correct = 0;// 准确类别的个数
double eval_time = getusertime();
for (data.data->begin(); !data.data->end(); data.data->next()) {
double p = predict_case(data);// 对样本进行预测
// 利用预测值的符号与原始标签值的符号是否相同,若相同,则预测是准确的
if (((p >= 0) && (data.target(data.data->getRowIndex()) >= 0)) || ((p < 0) && (data.target(data.data->getRowIndex()) < 0))) {
num_correct++;
}
}
eval_time = (getusertime() - eval_time);
// log the values
// log文件
if (log != NULL) {
log->log("accuracy", (double) num_correct / (double) data.data->getNumRows());
log->log("time_pred", eval_time);
}
return (double) num_correct / (double) data.data->getNumRows();// 返回准确率
}
// 对回归问题的评价
virtual double evaluate_regression(Data& data) {
double rmse_sum_sqr = 0;// 误差的平方和
double mae_sum_abs = 0;// 误差的绝对值之和
double eval_time = getusertime();
for (data.data->begin(); !data.data->end(); data.data->next()) {
// 取出每一条样本
double p = predict_case(data);// 计算该样本的预测值
p = std::min(max_target, p);// 防止预测值超出最大限制
p = std::max(min_target, p);// 防止预测值超出最小限制
double err = p - data.target(data.data->getRowIndex());// 得到预测值与真实值之间的误差
rmse_sum_sqr += err*err;// 计算误差平方和
mae_sum_abs += std::abs((double)err);// 计算误差的绝对值之和
}
eval_time = (getusertime() - eval_time);
// log the values
// log文件
if (log != NULL) {
log->log("rmse", std::sqrt(rmse_sum_sqr/data.data->getNumRows()));
log->log("mae", mae_sum_abs/data.data->getNumRows());
log->log("time_pred", eval_time);
}
return std::sqrt(rmse_sum_sqr/data.data->getNumRows());// 返回均方根误差
}
其中,在分类问题中,使用的评价标准是准确率:
在回归问题中,使用的评价标准是均方根误差:
其中,y^表示的是对样本的预测值,y表示的是样本的原始标签,#(y^⋅y>0)表示的是预测值y^与原始标签y同号的样本的个数(原始标签y∈{−1,1}),m表示的是样本的个数。
在对样本进行预测时用到了predict_case函数,该函数在“第一部分的protected属性和方法“中定义。在回归问题中,为预测值设置了最大的上限(std::max(min_target, p))和最小的下限(std::min(max_target, p))。为了能够记录时间,代码中使用到了getusertime函数,该函数的定义在util.h文件中。
参考文献
- Rendle S. Factorization Machines[C]// IEEE International Conference on Data Mining. IEEE Computer Society, 2010:995-1000.
- Rendle S. Factorization Machines with libFM[M]. ACM, 2012.
机器学习算法实现解析——libFM之libFM的训练过程概述的更多相关文章
- 机器学习算法实现解析——libFM之libFM的训练过程之Adaptive Regularization
本节主要介绍的是libFM源码分析的第五部分之二--libFM的训练过程之Adaptive Regularization的方法. 5.3.Adaptive Regularization的训练方法 5. ...
- 机器学习算法实现解析——libFM之libFM的训练过程之SGD的方法
本节主要介绍的是libFM源码分析的第五部分之一--libFM的训练过程之SGD的方法. 5.1.基于梯度的模型训练方法 在libFM中,提供了两大类的模型训练方法,一类是基于梯度的训练方法,另一类是 ...
- 机器学习算法实现解析——libFM之libFM的模型处理部分
本节主要介绍的是libFM源码分析的第三部分--libFM的模型处理. 3.1.libFM中FM模型的定义 libFM模型的定义过程中主要包括模型中参数的设置及其初始化,利用模型对样本进行预测.在li ...
- 机器学习算法实现解析——word2vec源代码解析
在阅读本文之前,建议首先阅读"简单易学的机器学习算法--word2vec的算法原理"(眼下还没公布).掌握例如以下的几个概念: 什么是统计语言模型 神经概率语言模型的网络结构 CB ...
- 谷歌BERT预训练源码解析(三):训练过程
目录前言源码解析主函数自定义模型遮蔽词预测下一句预测规范化数据集前言本部分介绍BERT训练过程,BERT模型训练过程是在自己的TPU上进行的,这部分我没做过研究所以不做深入探讨.BERT针对两个任务同 ...
- 机器学习算法与Python实践之(四)支持向量机(SVM)实现
机器学习算法与Python实践之(四)支持向量机(SVM)实现 机器学习算法与Python实践之(四)支持向量机(SVM)实现 zouxy09@qq.com http://blog.csdn.net/ ...
- 机器学习算法与Python实践之(五)k均值聚类(k-means)
机器学习算法与Python实践这个系列主要是参考<机器学习实战>这本书.因为自己想学习Python,然后也想对一些机器学习算法加深下了解,所以就想通过Python来实现几个比较常用的机器学 ...
- 机器学习算法与Python实践之(七)逻辑回归(Logistic Regression)
http://blog.csdn.net/zouxy09/article/details/20319673 机器学习算法与Python实践之(七)逻辑回归(Logistic Regression) z ...
- 机器学习算法( 五、Logistic回归算法)
一.概述 这会是激动人心的一章,因为我们将首次接触到最优化算法.仔细想想就会发现,其实我们日常生活中遇到过很多最优化问题,比如如何在最短时间内从A点到达B点?如何投入最少工作量却获得最大的效益?如何设 ...
随机推荐
- linux网络基础设置 以及 软件安装
ifconfig #查看所有已激活的网卡信息 临时配置 #yum install net-tools -y 默认ifconfig是没有安装的,可能需要安装 ifconfig eth0 #查看单独一块网 ...
- centos7 上搭建私有云
OwnCloud环境搭建 一. 环境搭建 1. 环境需求 服务器操作系统:Centos7.0 外网服务器操作系统:Centos7.0 Php版本号:5.4.16 Mysql版本号:5.5.52 Apa ...
- javascript Date对象 之 设置时间
之前对js的date对象总是感觉熟悉,而不愿细细深究其所以然,所以每当自己真正应用起来的时候,总会糊里糊涂的,今日花费2个小时的时间仔细钻研了一下,感觉 豁然开朗,故,以此记录,一来 供以后查阅,二来 ...
- Linux静默安装Oracle
打算在云服务器上装oracle服务,以前DBA美眉都是在图形化界面下安装,这次抓瞎了.赶紧上网查查,静默安装可以解决问题.于是乎赶紧开始部署,过程如下.安装环境:操作系统:CentOS 7内存:11G ...
- Steema TeeChart Pro VCL FMX 2017.20 Full Suorce在Delphi XE10下的安装
一.首先将压缩包TeeChart Pro VCL FMX 2017.20 FS.rar解压到一个目录,比如 E:\Application\Steema TeeChart Pro VCL FMX 201 ...
- Docker基于容器创建镜像
docker commit -m "提交信息" -a "作者信息" imgId imgName
- nginx 中location和root、alias
nginx指定文件路径有两种方式root和alias,这两者的用法区别 root与alias主要区别在于nginx如何解释location后面的uri,这会使两者分别以不同的方式将请求映射到服务器文件 ...
- sklearn学习笔记之开始
简介 自2007年发布以来,scikit-learn已经成为Python重要的机器学习库了.scikit-learn简称sklearn,支持包括分类.回归.降维和聚类四大机器学习算法.还包含了特征 ...
- 使用springmvc时报错org.springframework.beans.NullValueInNestedPathException: Invalid property 'department' of bean class [com.atguigu.springmvc.crud.entities.Employee]:
使用springmvc时报错 org.springframework.beans.NullValueInNestedPathException: Invalid property 'departmen ...
- CSS设置小技巧
水平居中 对于元素的水平居中,有三种情况: 行内元素(文字.图片等):text-align: center; 定宽块状元素(有设置宽度的block元素):margin: 0 auto; 不定宽块状元素 ...