在Java Web中使用Spark MLlib训练的模型
PMML是一种通用的配置文件,只要遵循标准的配置文件,就可以在Spark中训练机器学习模型,然后再web接口端去使用。目前应用最广的就是基于Jpmml来加载模型在javaweb中应用,这样就可以实现跨平台的机器学习应用了。
训练模型
首先在spark MLlib中使用mllib包下的逻辑回归训练模型:
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
val training = spark.sparkContext
.parallelize(Seq("0,1 2 3 1", "1,2 4 1 5", "0,7 8 3 6", "1,2 5 6 9").map( line => LabeledPoint.parse(line)))
// Run training algorithm to build the model
val model = new LogisticRegressionWithLBFGS()
.setNumClasses(2)
.run(training)
val test = spark.sparkContext
.parallelize(Seq("0,1 2 3 1").map( line => LabeledPoint.parse(line)))
// Compute raw scores on the test set.
val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
val prediction = model.predict(features)
(prediction, label)
}
// Get evaluation metrics.
val metrics = new MulticlassMetrics(predictionAndLabels)
val accuracy = metrics.accuracy
println(s"Accuracy = $accuracy")
// Save and load model
// model.save(spark.sparkContext, "target/tmp/scalaLogisticRegressionWithLBFGSModel")
// val sameModel = LogisticRegressionModel.load(spark.sparkContext,"target/tmp/scalaLogisticRegressionWithLBFGSModel")
model.toPMML(spark.sparkContext, "/tmp/xhl/data/test2")
训练得到的模型保存到hdfs。
PMML模型文件
模型下载到本地,重新命名为xml。
可以看到默认四个特征分别叫做feild_0
,field_1
...目标为target
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML version="4.2" xmlns="http://www.dmg.org/PMML-4_2">
<Header description="logistic regression">
<Application name="Apache Spark MLlib" version="2.2.0"/>
<Timestamp>2018-11-15T10:22:25</Timestamp>
</Header>
<DataDictionary numberOfFields="5">
<DataField name="field_0" optype="continuous" dataType="double"/>
<DataField name="field_1" optype="continuous" dataType="double"/>
<DataField name="field_2" optype="continuous" dataType="double"/>
<DataField name="field_3" optype="continuous" dataType="double"/>
<DataField name="target" optype="categorical" dataType="string"/>
</DataDictionary>
<RegressionModel modelName="logistic regression" functionName="classification" normalizationMethod="logit">
<MiningSchema>
<MiningField name="field_0" usageType="active"/>
<MiningField name="field_1" usageType="active"/>
<MiningField name="field_2" usageType="active"/>
<MiningField name="field_3" usageType="active"/>
<MiningField name="target" usageType="target"/>
</MiningSchema>
<RegressionTable intercept="0.0" targetCategory="1">
<NumericPredictor name="field_0" coefficient="-5.552297758753701"/>
<NumericPredictor name="field_1" coefficient="-1.4863480719075117"/>
<NumericPredictor name="field_2" coefficient="-5.7232298850417855"/>
<NumericPredictor name="field_3" coefficient="8.134075057437393"/>
</RegressionTable>
<RegressionTable intercept="-0.0" targetCategory="0"/>
</RegressionModel>
</PMML>
接口使用
在接口的web工程中引入maven jar:
<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-evaluator -->
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.4.3</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-evaluator-extension -->
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-extension</artifactId>
<version>1.4.3</version>
</dependency>
接口代码中直接读取pmml,使用模型进行预测:
package soundsystem;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
public class PMMLDemo2 {
private Evaluator loadPmml(){
PMML pmml = new PMML();
try(InputStream inputStream = new FileInputStream("/Users/xingoo/Desktop/test2.xml")){
pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
} catch (Exception e) {
e.printStackTrace();
}
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
return modelEvaluatorFactory.newModelEvaluator(pmml);
}
private Object predict(Evaluator evaluator,int a, int b, int c, int d) {
Map<String, Integer> data = new HashMap<String, Integer>();
data.put("field_0", a);
data.put("field_1", b);
data.put("field_2", c);
data.put("field_3", d);
List<InputField> inputFields = evaluator.getInputFields();
//过模型的原始特征,从画像中获取数据,作为模型输入
Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
for (InputField inputField : inputFields) {
FieldName inputFieldName = inputField.getName();
Object rawValue = data.get(inputFieldName.getValue());
FieldValue inputFieldValue = inputField.prepare(rawValue);
arguments.put(inputFieldName, inputFieldValue);
}
Map<FieldName, ?> results = evaluator.evaluate(arguments);
List<TargetField> targetFields = evaluator.getTargetFields();
TargetField targetField = targetFields.get(0);
FieldName targetFieldName = targetField.getName();
ProbabilityDistribution target = (ProbabilityDistribution) results.get(targetFieldName);
System.out.println(a + " " + b + " " + c + " " + d + ":" + target);
return target;
}
public static void main(String args[]){
PMMLDemo2 demo = new PMMLDemo2();
Evaluator model = demo.loadPmml();
demo.predict(model,2,5,6,8);
demo.predict(model,7,9,3,6);
demo.predict(model,1,2,3,1);
demo.predict(model,2,4,1,5);
}
}
得到输出内容:
2 5 6 8:ProbabilityDistribution{result=1, probability_entries=[1=0.9999949538769296, 0=5.046123070395758E-6]}
7 9 3 6:ProbabilityDistribution{result=0, probability_entries=[1=1.1216598160542013E-9, 0=0.9999999988783402]}
1 2 3 1:ProbabilityDistribution{result=0, probability_entries=[1=2.363331367481431E-8, 0=0.9999999763666864]}
2 4 1 5:ProbabilityDistribution{result=1, probability_entries=[1=0.9999999831203591, 0=1.6879640907241367E-8]}
其中result为LR最终的结果,概率为二分类的概率。
参考资料
- 官方文档:https://openscoring.io/
- JPMML官方文档:https://github.com/jpmml/jpmml-evaluator
- jpmml-sklearn:https://github.com/jpmml/jpmml-sklearn
- jpmml-sparkml:https://github.com/jpmml/jpmml-sparkml/tree/master
- 用PMML实现机器学习模型的跨平台上线:http://www.cnblogs.com/pinard/p/9220199.html
- PMML模型文件在机器学习的实践经验:https://blog.csdn.net/hopeztm/article/details/78321700
在Java Web中使用Spark MLlib训练的模型的更多相关文章
- Java Web 中 过滤器与拦截器的区别
过滤器,是在java web中,你传入的request,response提前过滤掉一些信息,或者提前设置一些参数,然后再传入servlet或者struts的 action进行业务逻辑,比如过滤掉非法u ...
- JAVA WEB 中的编码分析
JAVA WEB 中的编码分析 */--> pre.src {background-color: #292b2e; color: #b2b2b2;} pre.src {background-co ...
- Java web中常见编码乱码问题(一)
最近在看Java web中中文编码问题,特此记录下. 本文将会介绍常见编码方式和Java web中遇到中文乱码问题的常见解决方法: 一.常见编码方式: 1.ASCII 码 众所周知,这是最简单的编码. ...
- Java web中常见编码乱码问题(二)
根据上篇记录Java web中常见编码乱码问题(一), 接着记录乱码案例: 案例分析: 2.输出流写入内容或者输入流读取内容时乱码(内容中有中文) 原因分析: a. 如果是按字节写入或读取时乱码, ...
- 深入分析Java Web中的编码问题
编码问题一直困扰着我,每次遇到乱码或者编码问题,网上一查,问题解决了,但是实际的原理并没有搞懂,每次遇到,都是什么头疼. 决定彻彻底底的一次性解决编码问题. 1.为什么要编码 计算机的基本单元是字节, ...
- 解决java web中safari浏览器下载后文件中文乱码问题
解决java web中safari浏览器下载后文件中文乱码问题 String fileName = "测试文件.doc"; String userAgent = request.g ...
- Java Web 中使用ffmpeg实现视频转码、视频截图
Java Web 中使用ffmpeg实现视频转码.视频截图 转载自:[ http://www.cnblogs.com/dennisit/archive/2013/02/16/2913287.html ...
- java web中servlet、jsp、html 互相访问的路径问题
java web中servlet.jsp.html 互相访问的路径问题 在java web种经常出现 404找不到网页的错误,究其原因,一般是访问的路径不对. java web中的路径使用按我的分法可 ...
- java web 中 读取windows图标并显示
java web中读取windows对应文件名的 系统图标 ....显示 1.获取系统图标工具类 package utils; import java.awt.Graphics; import j ...
随机推荐
- SQL Server截取字符串
--SQL Server截取字符串 , Len('hello@163.com')) ,charindex('.','hello@163.com'))
- 使用swiper插件,隐藏swiper后再显示,不会触发自动播放的解决办法
问题: 项目中有一个需求,当点击P1时,两个页面进行轮播.当点击P2时,页面不轮播. 设置好以后,点击P2,再点击P1,此时页面不能自动轮播,只能手动触发. 解决: 在轮播器配置里,配置observe ...
- spring @transactional 注解事务
1.在spring配置文件中引入<tx:>命名空间 <beans xmlns="http://www.springframework.org/schema/beans&qu ...
- docker安装portainer
安装好docker之后,可以使用portainer对容器进到管理 docker安装portainer命令 #这一步可以省略,直接运行可以下一条docker pull portainer #因为dock ...
- P1880 [NOI1995]石子合并-(环形区间dp)
https://www.luogu.org/problemnew/show/P1880 解题过程:本次的题目把石子围成一个环,与排成一列的版本有些不一样,可以在后面数组后面再接上n个元素,表示连续n个 ...
- html 提取 公用部分
在写HTML时,总会遇到一些公用部分,如果每个页面都写那就很麻烦,并且代码量大大增加. 网上查询了几种方法: 1.es6 的 embed 标签. <embed src="header. ...
- 业务数据实体(model) 需要克隆的方法
业务数据实体(model) 需要克隆的时候 可以使用 Json.Deserialize<InquireResult>(Json.Serialize<InquireResult> ...
- 微软Office Online服务安装部署(三)
现在开始配置两台服务器,两台服务器的IP: Server: 10.1.3.89 Client: 10.1.3.92 1.在Client中,.打开网络属性,找到ipv4的配置,将dns 改成域控制器的 ...
- 集群环境下定时调度的解决方案之Quartz集群
集群环境可能出现的问题 在上一篇博客我们介绍了如何在自己的项目中从无到有的添加了Quartz定时调度引擎,其实就是一个Quartz 和Spring的整合过程,很容易实现,但是我们现在企业中项目通常都是 ...
- CentOS 使用yum命令安装出现错误提示”could not retrieve mirrorlist http://mirrorlist.centos.org ***”
执行yum命令时出现以上错误; 解决方法: vi /etc/sysconfig/network-scripts/ifcfg-eth0 这一段为你的网卡修改图中框框部分 然后重启 :reboot