fm_model是libFM生成的模型

model.ckpt是可以tensorflow serving的模型结构

亲测输出正确。

代码:

 import tensorflow as tf

 # libFM model
def load_fm_model(file_name):
state = ''
fid = 0
max_fid = 0
w0 = 0.0
wj = {}
v = {}
k = 0
with open(file_name) as f:
for line in f:
line = line.rstrip()
if 'global bias W0' in line:
state = 'w0'
fid = 0
continue
elif 'unary interactions Wj' in line:
state = 'wj'
fid = 0
continue
elif 'pairwise interactions Vj,f' in line:
state = 'v'
fid = 0
continue if state == 'w0':
fv = float(line)
w0 = fv
elif state == 'wj':
fv = float(line)
if fv != 0:
wj[fid] = fv
fid += 1
max_fid = max(max_fid, fid)
elif state == 'v':
fv = [float(_v) for _v in line.split(' ')]
k = len(fv)
if any([_v!=0 for _v in fv]):
v[fid] = fv
fid += 1
max_fid = max(max_fid, fid)
return w0, wj, v, k, max_fid _w0, _wj, _v, _k, _max_fid = load_fm_model('libfm_model_file') # max feature_id
n = _max_fid
print 'n', n # vector dimension
k = _k
print 'k', k # write fm algorithm
w0 = tf.constant(_w0)
w1c = tf.constant([_wj.get(fid, 0) for fid in xrange(n)], shape=[n])
w1 = tf.Variable(w1c)
#print 'w1', w1 vec = []
for fid in xrange(n):
vec.append(_v.get(fid, [0]*k))
w2c = tf.constant(vec, shape=[n,k])
w2 = tf.Variable(w2c)
print 'w2', w2 # inputs
x = tf.placeholder(tf.string, [None])
batch = tf.shape(x)[0]
x_s = tf.string_split(x)
inds = tf.stack([tf.cast(x_s.indices[:,0], tf.int64), tf.string_to_number(x_s.values, tf.int64)], axis=1)
x_sparse = tf.sparse.SparseTensor(indices=inds, values=tf.ones([tf.shape(inds)[0]]), dense_shape=[batch,n])
x_ = tf.sparse.to_dense(x_sparse) w2_rep = tf.reshape(tf.tile(w2, [batch,1]), [-1,n,k])
print 'w2_rep', w2_rep x_rep = tf.reshape(tf.tile(tf.reshape(x_, [batch*n, 1]), [1,k]), [-1,n,k])
print 'x_rep', x_rep
x_rep2 = tf.square(x_rep) #print tf.multiply(w2_rep,x_rep)
#print tf.reduce_sum(tf.multiply(w2_rep,x_rep), axis=1)
q = tf.square(tf.reduce_sum(tf.multiply(w2_rep, x_rep), axis=1))
h = tf.reduce_sum(tf.multiply(tf.square(w2_rep), x_rep2), axis=1) y = w0 + tf.reduce_sum(tf.multiply(x_, w1), axis=1) +\
1.0/2 * tf.reduce_sum(q-h, axis=1) saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
#a = sess.run(y, feed_dict={x_:x_train,y_:y_train,batch:70})
#print a
save_path = "./model.ckpt"
tf.saved_model.simple_save(sess, save_path, inputs={"x": x}, outputs={"y": y})

参考:

https://blog.csdn.net/u010159842/article/details/78789355 (开头借鉴此文,但其有不少细节错误)

https://www.tensorflow.org/guide/saved_model

http://nowave.it/factorization-machines-with-tensorflow.html

将libFM模型变换成tensorflow可serving的形式的更多相关文章

  1. object detection模型转换成TensorFlow Lite,在Android应用

    环境 tensorflow = 1.12.0 bazel = 0.18.1 ubuntu = 16.04 python = 3.6.2 安装 bazel (0.18.1) 如果tensorflow是1 ...

  2. javascript将毫秒转换成hh:mm:ss的形式

    function formatMilliseconds(value) { var second = parseInt(value) / 1000; // second var minute = 0; ...

  3. sql数值显示成千分位分隔符的形式

    ), )--带小数点 ), ),'.00','')--不带小数点

  4. 将百度坐标转换的javascript api官方示例改写成传统的回调函数形式

    改写前: 百度地图中坐标转换的JavaScript API示例官方示例如下: var points = [new BMap.Point(116.3786889372559,39.90762965106 ...

  5. http://xx.xxx.xxx.xx:8080/把路径设置成http服务访问的形式

    1.官网下载python安装包(eg:python-3.6.3-embed-win32),并解压文件 2.配置环境变量 3.cmd里查看python版本并设置服务路径 4. 访问查看

  6. 【C/C++】任意大于1的整数分解成素数因子乘积的形式

    // #include<stdio.h> #include<math.h> #include<malloc.h> int isprime(long n); void ...

  7. 【python 数据结构】相同某个字段值的所有数据(整理成数组包字典的形式)

    class MonitoredKeywordMore(APIView): def post(self, request): try: # 设置原生命令并且请求数据 parents_asin = str ...

  8. 21个项目玩转深度学习:基于TensorFlow的实践详解02—CIFAR10图像识别

    cifar10数据集 CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集.一共包含 10 个类别的 ...

  9. 从锅炉工到AI专家(9)

    无监督学习 前面已经说过了无监督学习的概念.无监督学习在实际的工作中应用还是比较多见的. 从典型的应用上说,监督学习比较多用在"分类"上,利用给定的数据,做出一个决策,这个决策在有 ...

随机推荐

  1. c++基础之引用reference

    1.何为引用 简单来说就是,比如你换了个新名字,用新名字叫你,你也会答应 2.引用vs指针 -引用没有null,好比你说你换了个新名字,但是新名字是啥总得有点东西 -一旦引用被初始化后就不可以指到另外 ...

  2. HackTwelve 为背景添加圆角边框

    1.概要:     ShapeDrawable是一个为UI控件添加特效的好工具.这个技巧适用于那些可以添加背景的控件 2.添加圆角边框其实就是添加的背景那里不是直接添加图片,而是添加一个XML文件即可 ...

  3. ComicEnhancerPro 系列教程二十:用“文件比较”看有损、无损

    作者:马健邮箱:stronghorse_mj@hotmail.com 主页:http://www.comicer.com/stronghorse/ 发布:2017.07.23 教程二十:用“文件比较” ...

  4. .net core MVC Filters 过滤器介绍

    一.过滤器的优级依次介绍如下(逐次递减): Authorization Filter ->  Resource Filter -> Acton Filter -> Exception ...

  5. Socket 简易静态服务器 WPF MVVM模式(二)

    command类 标准来说,command会有三种模式,委托命令 准备命令 附加命令 1.DelegateCommand 2.RelayCommand 3.AttachbehaviorCommand ...

  6. Django rest framework框架——APIview源码分析

    一.什么是rest REST其实是一种组织Web服务的架构,而并不是我们想象的那样是实现Web服务的一种新的技术,更没有要求一定要使用HTTP.其目标是为了创建具有良好扩展性的分布式系统. 可用一句话 ...

  7. Azure ASM虚拟机部署反恶意软件-安全扩展

    Azure虚拟机,默认情况下没有安装杀毒软件.如果您有此需求可以通过Azure 扩展进行安装,有关Azure反恶意软件的官方说明请参考:https://docs.azure.cn/zh-cn/secu ...

  8. Unity---动画系统学习(5)---使用MatchTarget来匹配动画

    1. 介绍 做好了走.跑.转弯后,我们就需要来点更加高级的动画了. 我们使用自带动画学习笔记2中的FQVault动画,来控制人物FQ. 在动画学习笔记4的基础上添加Vault动画. 添加一个参数Vau ...

  9. Python3之configparser模块

    1. 简介 configparser用于配置文件解析,可以解析特定格式的配置文件,多数此类配置文件名格式为XXX.ini,例如mysql的配置文件.在python3.X中 模块名为configpars ...

  10. 品味Zookeeper之选举及数据一致性_3

    品味Zookeeper之选举及数据一致性 本文思维导图 前言 为了高可用和数据安全起见,zk集群一般都是由几个节点构成(由n/2+1,投票机制决定,肯定是奇数个节点).多节点证明它们之间肯定会有数据的 ...