本节主要介绍的是libFM源码分析的第三部分——libFM的模型处理。

3.1、libFM中FM模型的定义

libFM模型的定义过程中主要包括模型中参数的设置及其初始化,利用模型对样本进行预测。在libFM中,首先定义FM模型,在fm_model类中实现对FM模型的定义,fm_model类在“\libfm-1.42.src\src\fm_core\fm_model.h”中。在定义fm_model类之前,使用到了一些数据类:

#include "../util/matrix.h"
#include "../util/fmatrix.h"

#include "fm_data.h"

数据类的具体定义在“机器学习算法实现解析——libFM之libFM的数据处理部分”中定义。fm_model类的代码如下所示:

// fm_model模型类
class fm_model {
    private:
        DVector<double> m_sum, m_sum_sqr;// 分别对应着交叉项的中的两项
    public: //fm模型中的参数
        double w0;// 常数项
        DVectorDouble w;// 一次项的系数
        DMatrixDouble v;// 交叉项的系数矩阵

    public:
        // 属性
        // the following values should be set:
        uint num_attribute;// 特征的个数

        bool k0, k1;// 是否包含常数项和一次项
        int num_factor;// 交叉项因子的个数

        double reg0;// 常数项的正则参数
        double regw, regv;// 一次项和交叉项的正则系数

        double init_stdev;// 初始化参数时的方差
        double init_mean;// 初始化参数时的均值

        // 函数
        fm_model();// 构造函数,主要完成参数的初始化
        void debug();// debug函数
        void init();// 初始化函数,主要用于生成各维度系数的初始值
        // 对样本进行预测
        double predict(sparse_row<FM_FLOAT>& x);
        double predict(sparse_row<FM_FLOAT>& x, DVector<double> &sum, DVector<double> &sum_sqr);
};

FM模型的一般形式如下所示:

y^:=w0+∑i=1nwixi+∑i=1n−1∑j=i+1n⟨vi,vj⟩xixj

其中,w0为常数项系数,wi为一次项系数,vi和vj为交叉项系数。对于交叉项系数vi,其具体的形式为:

vi=(vi,1,vi,2,⋯,vi,k)1×k

在FM模型的定义中,首先需要分别定义三个参数:w0,w和v。其次,需要定义模型中需要使用到的函数,包括初始化init函数和预测predict函数。

3.2、FM的初始化

完成FM模型初始化过程主要包括两个部分:

  • 构造函数fm_model()
  • init()函数

构造函数fm_model()的具体实现如下所示:

// fm_model类的构造函数
fm_model::fm_model() {
    num_factor = 0;// 交叉项中因子的个数
    init_mean = 0;// 初始化的均值
    init_stdev = 0.01;// 初始化的方差
    reg0 = 0.0;// 常数项的正则化参数
    regw = 0.0;// 一次项的正则化参数
    regv = 0.0;// 交叉项的正则化参数
    k0 = true;// 是否包含常数项
    k1 = true;// 是否包含一次项
}

init()函数的具体实现如下所示:

// 初始化fm模型的函数
void fm_model::init() {
    w0 = 0;// 常数项的系数
    w.setSize(num_attribute);// 设置一次项系数的个数
    v.setSize(num_factor, num_attribute);// 设置交叉项的矩阵大小
    w.init(0);// 初始化一次项系数为0
    v.init(init_mean, init_stdev);// 按照均值和方差初始化交叉项系数
    // 交叉项中的两个参数,设置其大小为num_factor
    m_sum.setSize(num_factor);
    m_sum_sqr.setSize(num_factor);
}

在初始化的过程中,除了基本的数据类型外,还涉及到自定义的三种数据类型,分别为:DVectorDouble,DMatrixDouble和DVector,这三种数据类型在“机器学习算法实现解析——libFM之libFM的数据处理部分”中有详细说明。

3.3、利用FM模型对样本进行预测

在libFM中,fm_model类中实现了两种预测函数,分别为:

double predict(sparse_row<FM_FLOAT>& x);
double predict(sparse_row<FM_FLOAT>& x, DVector<double> &sum, DVector<double> &sum_sqr);

两者的区别主要是下面的函数多了两个参数,一个是sum,另一个是sum_sqr,这两个参数分别对应着交叉项计算过程中的两项。

FM模型中的计算方法为:

y^:=w0+∑i=1nwixi+∑i=1n−1∑j=i+1n⟨vi,vj⟩xixj

其中,对于交叉项的计算,在FM算法中提出了快速的计算方法,即:

∑i=1n−1∑j=i+1n⟨vi,vj⟩xixj=12∑i=1n∑j=1n⟨vi,vj⟩xixj−12∑i=1n⟨vi,vi⟩xixi=12⎛⎝∑i=1n∑j=1n∑f=1kvi,fvj,fxixj−∑i=1n∑f=1kvi,fvj,fx2i⎞⎠=12∑f=1k⎛⎝(∑i=1nvi,fxi)⋅⎛⎝∑j=1nvj,fxj⎞⎠−∑i=1nv2i,fx2i⎞⎠=12∑f=1k⎛⎝(∑i=1nvi,fxi)2−∑i=1nv2i,fx2i⎞⎠

利用上面的计算公式,libFM中的两个函数的实现如下所示:

// 对样本进行预测,其中x表示的是一行样本
double fm_model::predict(sparse_row<FM_FLOAT>& x) {
    return predict(x, m_sum, m_sum_sqr);
}

double fm_model::predict(sparse_row<FM_FLOAT>& x, DVector<double> &sum, DVector<double> &sum_sqr) {
    double result = 0;// 最终的结果
    // 第一部分
    if (k0) {// 常数项
        result += w0;
    }

    // 第二部分
    if (k1) {// 一次项
        for (uint i = 0; i < x.size; i++) {// 对样本中的每一个特征
            assert(x.data[i].id < num_attribute);// 验证样本的正确性
            // w * x
            result += w(x.data[i].id) * x.data[i].value;
        }
    }

    // 第三部分
    // 交叉项,对应着公式,有两重循环
    for (int f = 0; f < num_factor; f++) {// 外层循环
        sum(f) = 0;
        sum_sqr(f) = 0;
        for (uint i = 0; i < x.size; i++) {
            double d = v(f,x.data[i].id) * x.data[i].value;
            sum(f) += d;
            sum_sqr(f) += d*d;
        }
        result += 0.5 * (sum(f)*sum(f) - sum_sqr(f));// 得到交叉项的值
    }
    return result;
}

在交叉项的计算过程中,sum(f)和sum_sqr(f)与公式中的对应关系为:

3.4、其他

剩下的代码便是debug函数,debug函数用于打印中间的结果,其具体的代码如下所示:

// debug函数,主要用于输出中间调试的结果
void fm_model::debug() {
    std::cout << "num_attributes=" << num_attribute << std::endl;// 特征的个数
    std::cout << "use w0=" << k0 << std::endl;//是否包含常数项
    std::cout << "use w1=" << k1 << std::endl;//是否包含一次项
    std::cout << "dim v =" << num_factor << std::endl;//交叉项中因子的个数
    std::cout << "reg_w0=" << reg0 << std::endl;//常数项的正则化参数
    std::cout << "reg_w=" << regw << std::endl;//一次项的正则化参数
    std::cout << "reg_v=" << regv << std::endl;//交叉项的正则化参数
    std::cout << "init ~ N(" << init_mean << "," << init_stdev << ")" << std::endl;//初始化的均值和初始化的方差
}

参考文献

  • Rendle S. Factorization Machines[C]// IEEE International Conference on Data Mining. IEEE Computer Society, 2010:995-1000.
  • Rendle S. Factorization Machines with libFM[M]. ACM, 2012.

机器学习算法实现解析——libFM之libFM的模型处理部分的更多相关文章

  1. 机器学习算法实现解析——libFM之libFM的训练过程之Adaptive Regularization

    本节主要介绍的是libFM源码分析的第五部分之二--libFM的训练过程之Adaptive Regularization的方法. 5.3.Adaptive Regularization的训练方法 5. ...

  2. 机器学习算法实现解析——libFM之libFM的训练过程之SGD的方法

    本节主要介绍的是libFM源码分析的第五部分之一--libFM的训练过程之SGD的方法. 5.1.基于梯度的模型训练方法 在libFM中,提供了两大类的模型训练方法,一类是基于梯度的训练方法,另一类是 ...

  3. 机器学习算法实现解析——libFM之libFM的训练过程概述

    本节主要介绍的是libFM源码分析的第四部分--libFM的训练. FM模型的训练是FM模型的核心的部分. 4.1.libFM中训练过程的实现 在FM模型的训练过程中,libFM源码中共提供了四种训练 ...

  4. 机器学习算法实现解析——word2vec源代码解析

    在阅读本文之前,建议首先阅读"简单易学的机器学习算法--word2vec的算法原理"(眼下还没公布).掌握例如以下的几个概念: 什么是统计语言模型 神经概率语言模型的网络结构 CB ...

  5. scikit-learn中的机器学习算法封装——kNN

    接前面 https://www.cnblogs.com/Liuyt-61/p/11738399.html 回过头来看这张图,什么是机器学习?就是将训练数据集喂给机器学习算法,在上面kNN算法中就是将特 ...

  6. 机器学习算法与Python实践之(四)支持向量机(SVM)实现

    机器学习算法与Python实践之(四)支持向量机(SVM)实现 机器学习算法与Python实践之(四)支持向量机(SVM)实现 zouxy09@qq.com http://blog.csdn.net/ ...

  7. 机器学习算法与Python实践之(五)k均值聚类(k-means)

    机器学习算法与Python实践这个系列主要是参考<机器学习实战>这本书.因为自己想学习Python,然后也想对一些机器学习算法加深下了解,所以就想通过Python来实现几个比较常用的机器学 ...

  8. 机器学习算法与Python实践之(七)逻辑回归(Logistic Regression)

    http://blog.csdn.net/zouxy09/article/details/20319673 机器学习算法与Python实践之(七)逻辑回归(Logistic Regression) z ...

  9. 机器学习算法( 五、Logistic回归算法)

    一.概述 这会是激动人心的一章,因为我们将首次接触到最优化算法.仔细想想就会发现,其实我们日常生活中遇到过很多最优化问题,比如如何在最短时间内从A点到达B点?如何投入最少工作量却获得最大的效益?如何设 ...

随机推荐

  1. CNN学习笔记:目标函数

    CNN学习笔记:目标函数 分类任务中的目标函数 目标函数,亦称损失函数或代价函数,是整个网络模型的指挥棒,通过样本的预测结果与真实标记产生的误差来反向传播指导网络参数学习和表示学习. 假设某分类任务共 ...

  2. Canvas:橡皮筋线条绘制

    Canvas:橡皮筋线条绘制 效果演示 实现要点 事件监听 [说明]: 在Canvas中检测鼠标事件是非常简单的,可以在canvas中添加一个事件监听器,当事件发生时,浏览器就会调用这个监听器. 我们 ...

  3. maven依赖排除、顺序原则、版本统一管理

    <dependency> <groupId>org.springframework</groupId> <artifactId>spring-core& ...

  4. 菩提树下的杨过.Net 的《hadoop 2.6全分布安装》补充版

    对菩提树下的杨过.Net的这篇博客<hadoop 2.6全分布安装>,我真是佩服的五体投地,我第一次见过教程能写的这么言简意赅,但是又能比较准确表述每一步做法的,这篇博客主要就是在他的基础 ...

  5. linux:查看磁盘硬件信息hdparm,smartctl

    smartctl 命令 这个一个用于控制和监控支持smart技术的硬盘的命令.通常配合 -a 选项我们可以查看到比较详尽的硬盘信息(比如序列号.硬盘容量.已运行时间.硬盘健康状况等).用法如下: sm ...

  6. Entity FrameWork Code First 之 MVC4 数据库初始化策略用法

    通过启用迁移和更新数据库可以很容易的生成一张表.但是对数据库修改之后,通过数据迁移就没那么好实现了. 这里用到数据库生成策略,进行对数据库操作: 一.3种主要数据库生成策略 1 CreateDatab ...

  7. AppLocker Pro FAQ

    How to use AppLocker Pro: 1. Start AppLocker Pro, create a password.2. In the main console, click &q ...

  8. spark 参数调优

    调整partition数量,每次reduece和distict的时候都应该调整,数量太大和太小都不好,通常来讲保证一个partition的大小在1-2G左右为宜 调整excutors 调整core 调 ...

  9. ThinkPHP3.2添加scws中文分词

    前言 前一段时间,公司网站做站内搜索,只简单针对输入的文字进行搜索,作全匹配检索,搜索出来的内容很少.如何达到模糊搜索,匹配到更多的内容成了需要解决的问题.于是,今天想到可以做分词检索,如何对输入的一 ...

  10. 配置可对外链接的Redis

    链接服务器的Redis telnet 192.168.1.200 6379 Trying 192.168.1.200... telnet: Unable to connect to remote ho ...