在用PMML实现机器学习模型的跨平台上线中,我们讨论了使用PMML文件来实现跨平台模型上线的方法,这个方法当然也适用于tensorflow生成的模型,但是由于tensorflow模型往往较大,使用无法优化的PMML文件大多数时候很笨拙,因此本文我们专门讨论下tensorflow机器学习模型的跨平台上线的方法。

1. tensorflow模型的跨平台上线的备选方案

    tensorflow模型的跨平台上线的备选方案一般有三种:即PMML方式,tensorflow serving方式,以及跨语言API方式。

    PMML方式的主要思路在上一篇以及讲过。这里唯一的区别是转化生成PMML文件需要用一个Java库jpmml-tensorflow来完成,生成PMML文件后,跨语言加载模型和其他PMML模型文件基本类似。

    tensorflow serving是tensorflow 官方推荐的模型上线预测方式,它需要一个专门的tensorflow服务器,用来提供预测的API服务。如果你的模型和对应的应用是比较大规模的,那么使用tensorflow serving是比较好的使用方式。但是它也有一个缺点,就是比较笨重,如果你要使用tensorflow serving,那么需要自己搭建serving集群并维护这个集群。所以为了一个小的应用去做这个工作,有时候会觉得麻烦。

    跨语言API方式是本文要讨论的方式,它会用tensorflow自己的Python API生成模型文件,然后用tensorflow的客户端库比如Java或C++库来做模型的在线预测。下面我们会给一个生成生成模型文件并用tensorflow Java API来做在线预测的例子。

2. 训练模型并生成模型文件

    我们这里给一个简单的逻辑回归并生成逻辑回归tensorflow模型文件的例子。

    完整代码参见我的github:https://github.com/ljpzzz/machinelearning/blob/master/model-in-product/tensorflow-java

    首先,我们生成了一个6特征,3分类输出的4000个样本数据。

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.datasets.samples_generator import make_classification
import tensorflow as tf
X1, y1 = make_classification(n_samples=4000, n_features=6, n_redundant=0,
n_clusters_per_class=1, n_classes=3)

    接着我们构建tensorflow的数据流图,这里要注意里面的两个名字,第一个是输入x的名字input,第二个是输出prediction_labels的名字output,这里的这两个名字可以自己取,但是后面会用到,所以要保持一致。

learning_rate = 0.01
training_epochs = 600
batch_size = 100 x = tf.placeholder(tf.float32, [None, 6],name='input') # 6 features
y = tf.placeholder(tf.float32, [None, 3]) # 3 classes W = tf.Variable(tf.zeros([6, 3]))
b = tf.Variable(tf.zeros([3])) # softmax回归
pred = tf.nn.softmax(tf.matmul(x, W) + b, name="softmax")
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) prediction_labels = tf.argmax(pred, axis=1, name="output") init = tf.global_variables_initializer()

    接着就是训练模型了,代码比较简单,毕竟只是一个演示:

sess = tf.Session()
sess.run(init)
y2 = tf.one_hot(y1, 3)
y2 = sess.run(y2) for epoch in range(training_epochs): _, c = sess.run([optimizer, cost], feed_dict={x: X1, y: y2})
if (epoch+1) % 10 == 0:
print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(c)) print ("优化完毕!")
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y2, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
acc = sess.run(accuracy, feed_dict={x: X1, y: y2})
print (acc)

    打印输出我这里就不写了,大家可以自己去试一试。接着就是关键的一步,存模型文件了,注意要用convert_variables_to_constants这个API来保存模型,否则模型参数不会随着模型图一起存下来。

graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
tf.train.write_graph(graph, '.', 'rf.pb', as_text=False)

    至此,我们的模型文件rf.pb已经被保存下来了,下面就是要跨平台上线了。 

3. 模型文件在Java平台上线

    这里我们以Java平台的模型上线为例,C++的API上线我没有用过,这里就不写了。我们需要引入tensorflow的java库到我们工程的maven或者gradle文件。这里给出maven的依赖如下,版本可以根据实际情况选择一个较新的版本。

        <dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.7.0</version>
</dependency>

    接着就是代码了,这个代码会比JPMML的要简单,我给出了4个测试样本的预测例子如下,一定要注意的是里面的input和output要和训练模型的时候对应的节点名字一致。

import org.tensorflow.*;
import org.tensorflow.Graph; import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths; /**
* Created by 刘建平pinard on 2018/7/1.
*/
public class TFjavaDemo {
public static void main(String args[]){
byte[] graphDef = loadTensorflowModel("D:/rf.pb");
float inputs[][] = new float[4][6];
for(int i = 0; i< 4; i++){
for(int j =0; j< 6;j++){
if(i<2) {
inputs[i][j] = 2 * i - 5 * j - 6;
}
else{
inputs[i][j] = 2 * i + 5 * j - 6;
}
}
}
Tensor<Float> input = covertArrayToTensor(inputs);
Graph g = new Graph();
g.importGraphDef(graphDef);
Session s = new Session(g);
Tensor result = s.runner().feed("input", input).fetch("output").run().get(0); long[] rshape = result.shape();
int rs = (int) rshape[0];
long realResult[] = new long[rs];
result.copyTo(realResult); for(long a: realResult ) {
System.out.println(a);
}
}
static private byte[] loadTensorflowModel(String path){
try {
return Files.readAllBytes(Paths.get(path));
} catch (IOException e) {
e.printStackTrace();
}
return null;
} static private Tensor<Float> covertArrayToTensor(float inputs[][]){
return Tensors.create(inputs);
}
}

    我的预测输出是1,1,0,0,供大家参考。

4. 一点小结

    对于tensorflow来说,模型上线一般选择tensorflow serving或者client API库来上线,前者适合于较大的模型和应用场景,后者则适合中小型的模型和应用场景。因此算法工程师使用在产品之前需要做好选择和评估。

(欢迎转载,转载请注明出处。欢迎沟通交流: liujianping-ok@163.com)

tensorflow机器学习模型的跨平台上线的更多相关文章

  1. 用PMML实现机器学习模型的跨平台上线

    在机器学习用于产品的时候,我们经常会遇到跨平台的问题.比如我们用Python基于一系列的机器学习库训练了一个模型,但是有时候其他的产品和项目想把这个模型集成进去,但是这些产品很多只支持某些特定的生产环 ...

  2. 用PMML实现python机器学习模型的跨平台上线

    python信用评分卡(附代码,博主录制) https://study.163.com/course/introduction.htm?courseId=1005214003&utm_camp ...

  3. Kubernetes入门(四)——如何在Kubernetes中部署一个可对外服务的Tensorflow机器学习模型

    机器学习模型常用Docker部署,而如何对Docker部署的模型进行管理呢?工业界的解决方案是使用Kubernetes来管理.编排容器.Kubernetes的理论知识不是本文讨论的重点,这里不再赘述, ...

  4. 使用pmml实现跨平台部署机器学习模型

    一.概述   对于由Python训练的机器学习模型,通常有pickle和pmml两种部署方式,pickle方式用于在python环境中的部署,pmml方式用于跨平台(如Java环境)的部署,本文叙述的 ...

  5. 使用pmml跨平台部署机器学习模型Demo——房价预测

      基于房价数据,在python中训练得到一个线性回归的模型,在JavaWeb中加载模型完成房价预测的功能. 一. 训练.保存模型 工具:PyCharm-2017.Python-39.sklearn2 ...

  6. 利用tensorflow编写自己的机器学习模型主要步骤

    利用tensorlow编写自己的机器学习模型主要分为两个阶段: 第一阶段:建立模型或者建立网络结构 1.定义输入和输出所需要的占位符 2.定义权重 3.定义具体的模型接口 4.定义损失函数 5.定义优 ...

  7. Tensorflow Serving 模型部署和服务

    http://blog.csdn.net/wangjian1204/article/details/68928656 本文转载自:https://zhuanlan.zhihu.com/p/233614 ...

  8. 使用ML.NET + ASP.NET Core + Docker + Azure Container Instances部署.NET机器学习模型

    本文将使用ML.NET创建机器学习分类模型,通过ASP.NET Core Web API公开它,将其打包到Docker容器中,并通过Azure Container Instances将其部署到云中. ...

  9. TensorFlow机器学习实战指南之第二章

    一.计算图中的操作 在这个例子中,我们将结合前面所学的知识,传入一个列表到计算图中的操作,并打印返回值: 声明张量和占位符.这里,创建一个numpy数组,传入计算图操作: import tensorf ...

随机推荐

  1. Python 3.5 filter

    filter(F, L) F: 函数.L:范围 filter的功能是:用函数F把L范围内的参数做过滤 通常和list一起使用,把过滤后的参数做成列表 list(filter(lambda n:not ...

  2. typescript 安装

    1,全局安装 cnpm install typescript -g (tsc -v) 2,初始化 tsc --init 3,自动编译(hbuilder) 工具-插件安装-浏览eclipse插件市场-搜 ...

  3. js数组和对象相等判断、拷贝详解(结合几个现象讲解引用数据类型的趣事)

    序言 最近遇到几个js引用数据类型造成的bug,今天结合bug详细分析一下,避免以后再犯,也希望能帮大家提个醒,强化js基本功. 目录 1.浅拷贝.深拷贝,解决变量赋值相互影响问题 2.判断2个数组. ...

  4. 拼接SQL执行语句时,对单引号的处理

    例: declare @SQL nvarchar(1000); declare @str nvarchar(100); set @str='Joe''s NB'; // 打印出来的应该是这样:Joe' ...

  5. 在C++中,setw(int n)

    setw(int n)用来控制输出间隔例如:cout<<'s'<<setw(8)<<'a'<<endl;则在屏幕显示s        a //s与a之间 ...

  6. php 将一个或多个二维数组组合成一个二维数组并根据某个字段排序排序

    最近再写项目的时候,碰到一个问题:如何将一个或多个二维数组组合成一个二维数组并根据某个字段排序排序:实在是想不到哪个php库中有哪个函数能实现,只能自己写一个了,将代码写出来后,发现自己的代码繁琐,并 ...

  7. C语言面试题分类->排序算法

    1.选择排序. 每次将最小的数,与剩余数做比较.找到更小的,做交换. 时间复杂度:O(n²) 空间复杂度:O(1) 优缺点:耗时但内存空间使用小. void selectSort(int *p,int ...

  8. Java中的队列同步器AQS

    一.AQS概念 1.队列同步器是用来构建锁或者其他同步组件的基础框架,使用一个int型变量代表同步状态,通过内置的队列来完成线程的排队工作. 2.下面是JDK8文档中对于AQS的部分介绍 public ...

  9. Java Web每天学之Servlet的原理解析

    Java Web每天学之Servlet的工作原理解析,上海尚学堂Java技术文章Java Web系列之二上一篇文章Java Web每天学之Servlet的工作原理解析是之一,欢迎点击阅读. Servl ...

  10. SUSE12Sp3-.NET Core 2.2.1 runtime安装

    1.安装libicu依赖 1.在线安装 sudo mkdir /usr/local/dotnet #创建目录 cd /usr/local/dotnet sudo wget https://downlo ...