转自:

博客

http://blog.csdn.net/google19890102/article/details/45532745/

github

https://github.com/zhaozhiyong19890102/Python-Machine-Learning-Algorithm/tree/master/Chapter_3%20Factorization%20Machine

一、因子分解机FM的模型

   因子分解机(Factorization Machine, FM)是由Steffen Rendle提出的一种基于矩阵分解的机器学习算法。

1、因子分解机FM的优势

    对于因子分解机FM来说,最大的特点是对于稀疏的数据具有很好的学习能力。现实中稀疏的数据很多,例如作者所举的推荐系统的例子便是一个很直观的具有稀疏特点的例子。

2、因子分解机FM的模型

    对于度为2的因子分解机FM的模型为:
其中,参数表示的是两个大小为的向量和向量的点积:
其中,表示的是系数矩阵的第维向量,且称为超参数。在因子分解机FM模型中,前面两部分是传统的线性模型,最后一部分将两个互异特征分量之间的相互关系考虑进来。
    因子分解机FM也可以推广到高阶的形式,即将更多互异特征分量之间的相互关系考虑进来。

二、因子分解机FM算法

    因子分解机FM算法可以处理如下三类问题:
  1. 回归问题(Regression)
  2. 二分类问题(Binary Classification)
  3. 排序(Ranking)

在这里主要介绍回归问题和二分类问题。

1、回归问题(Regression)

    在回归问题中,直接使用作为最终的预测结果。在回归问题中使用最小均方误差(the least square error)作为优化的标准,即
其中,表示样本的个数。

2、二分类问题(Binary Classification)

    与Logistic回归类似,通过阶跃函数,如Sigmoid函数,将映射成不同的类别。在二分类问题中使用logit loss作为优化的标准,即
其中,表示的是阶跃函数Sigmoid。具体形式为:

三、因子分解机FM算法的求解过程

1、交叉项系数

    在基本线性回归模型的基础上引入交叉项,如下:
  表示共有n个特征:
 
若是这种直接在交叉项的前面加上交叉项系数的方式在稀疏数据的情况下存在一个很大的缺陷,即在对于观察样本中未出现交互的特征分量,不能对相应的参数进行估计。
    对每一个特征分量引入辅助向量,利用对交叉项的系数进行估计,即
这就对应了一种矩阵的分解。对值的限定,对FM的表达能力有一定的影响。

2、模型的求解

这里要求出,主要采用了如公式求出交叉项。具体过程如下:

注:上式中: 

,且,倒数第二行中,将 j 换成 i,原式不变,所以能得到倒数第一行的形式。

3、基于随机梯度的方式求解

对于回归问题:
对于二分类问题:
 
最终交叉项要估计的参数每一个是:Vi,f
有n个特征, 每个特征有k个分量,那交叉项的参数个数就是:n*k。

四、实验(求解二分类问题)

1、实验的代码:

  1. #coding:UTF-8
  2. from __future__ import division
  3. from math import exp
  4. from numpy import *
  5. from random import normalvariate#正态分布
  6. from datetime import datetime
  7. trainData = 'E://data//diabetes_train.txt'
  8. testData = 'E://data//diabetes_test.txt'
  9. featureNum = 8
  10. def loadDataSet(data):
  11. dataMat = []
  12. labelMat = []
  13. fr = open(data)#打开文件
  14. for line in fr.readlines():
  15. currLine = line.strip().split()
  16. #lineArr = [1.0]
  17. lineArr = []
  18. for i in xrange(featureNum):
  19. lineArr.append(float(currLine[i + 1]))
  20. dataMat.append(lineArr)
  21. labelMat.append(float(currLine[0]) * 2 - 1)
  22. return dataMat, labelMat
  23. def sigmoid(inx):
  24. return 1.0 / (1 + exp(-inx))
  25. def stocGradAscent(dataMatrix, classLabels, k, iter):
  26. #dataMatrix用的是mat, classLabels是列表
  27. m, n = shape(dataMatrix)
  28. alpha = 0.01
  29. #初始化参数
  30. w = zeros((n, 1))#其中n是特征的个数
  31. w_0 = 0.    #截距项
  32. v = normalvariate(0, 0.2) * ones((n, k))   #交叉项
  33. for it in xrange(iter):
  34. print it
  35. for x in xrange(m):#随机优化,对每一个样本而言的
  36. inter_1 = dataMatrix[x] * v
  37. inter_2 = multiply(dataMatrix[x], dataMatrix[x]) * multiply(v, v)#multiply对应元素相乘
  38. #完成交叉项
  39. interaction = sum(multiply(inter_1, inter_1) - inter_2) / 2.
  40. p = w_0 + dataMatrix[x] * w + interaction#计算预测的输出
  41. loss = sigmoid(classLabels[x] * p[0, 0]) - 1
  42. print loss
  43. w_0 = w_0 - alpha * loss * classLabels[x]
  44. for i in xrange(n):
  45. if dataMatrix[x, i] != 0:
  46. w[i, 0] = w[i, 0] - alpha * loss * classLabels[x] * dataMatrix[x, i]
  47. for j in xrange(k):
  48. v[i, j] = v[i, j] - alpha * loss * classLabels[x] * (dataMatrix[x, i] * inter_1[0, j] - v[i, j] * dataMatrix[x, i] * dataMatrix[x, i])
  49. return w_0, w, v
  50. def getAccuracy(dataMatrix, classLabels, w_0, w, v):
  51. m, n = shape(dataMatrix)
  52. allItem = 0
  53. error = 0
  54. result = []
  55. for x in xrange(m):
  56. allItem += 1
  57. inter_1 = dataMatrix[x] * v
  58. inter_2 = multiply(dataMatrix[x], dataMatrix[x]) * multiply(v, v)#multiply对应元素相乘
  59. #完成交叉项
  60. interaction = sum(multiply(inter_1, inter_1) - inter_2) / 2.
  61. p = w_0 + dataMatrix[x] * w + interaction#计算预测的输出
  62. pre = sigmoid(p[0, 0])
  63. result.append(pre)
  64. if pre < 0.5 and classLabels[x] == 1.0:
  65. error += 1
  66. elif pre >= 0.5 and classLabels[x] == -1.0:
  67. error += 1
  68. else:
  69. continue
  70. print result
  71. return float(error) / allItem
  72. if __name__ == '__main__':
  73. dataTrain, labelTrain = loadDataSet(trainData)
  74. dataTest, labelTest = loadDataSet(testData)
  75. date_startTrain = datetime.now()
  76. print "开始训练"
  77. w_0, w, v = stocGradAscent(mat(dataTrain), labelTrain, 20, 200)
  78. print "训练准确性为:%f" % (1 - getAccuracy(mat(dataTrain), labelTrain, w_0, w, v))
  79. date_endTrain = datetime.now()
  80. print "训练时间为:%s" % (date_endTrain - date_startTrain)
  81. print "开始测试"
  82. print "测试准确性为:%f" % (1 - getAccuracy(mat(dataTest), labelTest, w_0, w, v))

2、实验结果:

五、几点疑问

    在传统的非稀疏数据集上,有时效果并不是很好。在实验中,我有一点处理,即在求解Sigmoid函数的过程中,在有的数据集上使用了带阈值的求法:
  1. def sigmoid(inx):
  2. #return 1.0 / (1 + exp(-inx))
  3. return 1. / (1. + exp(-max(min(inx, 15.), -15.)))

六 图片

fm 讲解加代码的更多相关文章

  1. 简单的自动化使用--使用selenium实现学习通网站的刷慕课程序。注释空格加代码大概200行不到

    简单的自动化使用--使用selenium实现学习通网站的刷慕课程序.注释空格加代码大概200行不到 相见恨晚啊 github地址 环境Python3.6 + pycharm + chrom浏览器 + ...

  2. [洛谷P3376题解]网络流(最大流)的实现算法讲解与代码

    [洛谷P3376题解]网络流(最大流)的实现算法讲解与代码 更坏的阅读体验 定义 对于给定的一个网络,有向图中每个的边权表示可以通过的最大流量.假设出发点S水流无限大,求水流到终点T后的最大流量. 起 ...

  3. [CodeIgniter4]讲解-加载静态页

    讲解 本教程旨在向您介绍CodeIgniter框架和MVC体系结构的基本原理.它将向您展示如何以逐步的方式构造基本的CodeIgniter应用程序. 在本教程中,您将创建一个基本的新闻应用程序.您将从 ...

  4. Java核心技术及面试指南的视频讲解和代码下载位置

    都是百度云盘,均无密码 代码下载位置: https://pan.baidu.com/s/1I44ob0vygMxvmj2BoNioAQ 视频讲解位置: https://pan.baidu.com/s/ ...

  5. 扩展欧几里得(ex_gcd),中国剩余定理(CRT)讲解 有代码

    扩展欧几里得算法 求逆元就不说了. ax+by=c 这个怎么求,很好推. 设d=gcd(a,b) 满足d|c方程有解,否则无解. 扩展欧几里得求出来的解是 x是 ax+by=gcd(a,b)的解. 对 ...

  6. 傻瓜式的go modules的讲解和代码,及gomod能不能引入另一个gomod和gomod的use of internal package xxxx not allowed

    一 国内关于gomod的文章,哪怕是使用了百度 -csdn,依然全是理论,虽然golang的使用者大多是大神但是也有像我这样的的弱鸡是不是? 所以,我就写个傻瓜式教程了. github地址:https ...

  7. Rainbond 对接 Istio 原理讲解和代码实现分析

    一.背景 现有的 ServiceMesh 框架有很多,如 Istio.linkerd等.对于用户而言,在测试环境下,需要达到的效果是快.开箱即用.但在生产环境下,可能又有熔断.延时注入等需求.那么单一 ...

  8. C++工厂方法模式讲解和代码示例

    在C++中使用模式 使用示例: 工厂方法模式在 C++ 代码中得到了广泛使用. 当你需要在代码中提供高层次的灵活性时, 该模式会非常实用. 识别方法: 工厂方法可通过构建方法来识别, 它会创建具体类的 ...

  9. Vue学习之--------组件嵌套以及VueComponent的讲解(代码实现)(2022/7/23)

    欢迎加入刚建立的社区:http://t.csdn.cn/Q52km 加入社区的好处: 1.专栏更加明确.便于学习 2.覆盖的知识点更多.便于发散学习 3.大家共同学习进步 3.不定时的发现金红包(不多 ...

随机推荐

  1. HttpPostedFile类

    在研究HttpRequest的时候,搞文件上传的时候,经常碰到返回HttpPostedFile对象的情况,这个对象才是真正包含文件内容的东西. 经常要获取的最重要的内容是FileName属性与Sava ...

  2. innotop监控mysql

    InnoTop 是一个系统活动报告,类似于Linux性能工具,它与Linux的top命令相仿,并参考mytop工具而设计. 它专门用后监控InnoDB性能和MySQL服务器.主要用于监控事务,死锁,外 ...

  3. django orm 常用查询筛选

    大于.大于等于 __gt 大于 __gte 大于等于 User.objects.filter(age__gt=10) // 查询年龄大于10岁的用户 User.objects.filter(age__ ...

  4. Lucene.Net 入门级实例 浅显易懂。。。

    Lucene.Net 阅读目录 开始 Lucene简介 效果图 Demo文件说明 简单使用 重点类的说明 存在问题 调整后 Lucene.Net博文与资源下载 做过站内搜索的朋友应该对Lucene.N ...

  5. BASIC-9_蓝桥杯_特殊回文数

    示例代码: #include <stdio.h> int main(void){ int n = 0 ; scanf("%d",&n); int i = 0 ; ...

  6. 【VS】使用vs2017自带的诊断工具(Diagnostic Tools)诊断程序的内存问题

    前言 一般来说.NET程序员是不用担心内存分配释放问题的,因为有垃圾收集器(GC)会自动帮你处理.但是GC只能收集那些不再使用的内存(根据对象是否被其它活动的对象所引用)来确定.所以如果代码编写不当的 ...

  7. my sql 只展示 前10条数据的写法

    select * from 表 where 条件 limit 10 这里想看多少条 limit 后面的数字就是多少

  8. excel 怎么添加超链接

    1.只能对单元格添加超链接 2.如果要对单元格里面个别字做成超链接,可以使用图形工具,设置一个图形在里面,对这个图形做超链接 参考:https://jingyan.baidu.com/article/ ...

  9. 2018-2019 20165226 网络对抗 Exp1+ 逆向进阶

    2018-2019 20165226 网络对抗 Exp1+ 逆向进阶 目录 一.实验内容介绍 二.64位shellcode的编写及注入 三.ret2lib及rop的实践 四.问题与思考 一.实验内容介 ...

  10. Vue 组件以及生命周期函数

    组件相当于母版的功能 新建.vue文件,手动完善 <template><div>根节点</div></template> <script>& ...