Java 调用 PaddleDetection 模型
介绍
训练好的模型要给业务调用,deepjavalibrary/djl:Java 中与引擎无关的深度学习框架 (github.com) 可以完成这件事,它支持使用 Java 调用 PyTorch、TensorFlow、MXNet、ONNX、PaddlePaddle 等引擎的模型(也支持部分引擎的模型构建和训练),本文只介绍调用 PaddlePaddle 引擎的模型调用。
调用模型流程:
- 导出模型(我更喜欢 ONNX 格式,它在 CPU 上推理也挺快的,可以快速测试,但有的算子不支持导出),确认模型输入输出
- 编写 Java 加载模型以及处理输入输出的代码
PaddleDetection 模型导出
导出模型
Anaconda 配置一个 PaddleDetection 的环境,cpu 版本即可(paddlepaddle==2.2.2),下载 PaddleDetection 工程,修改工程中 configs/runtime.yml
的属性 use_gpu
为 false
。
下面以 configs/pphuman/pedestrian_yolov3/pedestrian_yolov3_darknet.yml
为例介绍整个流程,导出模型:
$ python tools/export_model.py -c configs/pphuman/pedestrian_yolov3/pedestrian_yolov3_darknet.yml -o weights=https://paddledet.bj.bcebos.com/models/pedestrian_yolov3_darknet.pdparams --output_dir pedestrian_yolov3_darknet
再转换为 ONNX:
$ paddle2onnx --model_dir pedestrian_yolov3_darknet/pedestrian_yolov3_darknet --model_filename model.pdmodel --params_filename model.pdiparams --opset_version 11 --save_file pedestrianYolov3.onnx --enable_onnx_checker True
确认输入输出
在 PaddleDetection 模型导出教程 中查看模型输入输出参数,再通过 Netorn 打开前面导出的 ONNX 模型详细确认
Java 读取模型及推理
依赖
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
</dependency>
<!--混合引擎,因为有的引擎 NDArray 不支持-->
<dependency>
<groupId>ai.djl.mxnet</groupId>
<artifactId>mxnet-engine</artifactId>
</dependency>
<dependency>
<groupId>ai.djl.onnxruntime</groupId>
<artifactId>onnxruntime-engine</artifactId>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
</dependency>
<!--使用 openpnp 的 opencv 加快图片读取-->
<dependency>
<groupId>ai.djl.opencv</groupId>
<artifactId>opencv</artifactId>
</dependency>
</dependencies>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>bom</artifactId>
<version>0.20.0</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
处理输入输出
确定输入参数为图片原形状 im_shape、图片(需要归一化)image、比例 scale_factor,输出为预测框和预测数量,参数详细说明见前面提到的 PaddleDetection 模型导出教程中的说明。
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
// 非批量输入输出应实现 NoBatchifyTranslator 接口,而不是 Translator
public class PedestrianTranslator implements NoBatchifyTranslator<Image, DetectedObjects> {
private final Pipeline pipeline;
private final float threshold;
private final List<String> classes;
private final float imageWidth = 608f;
private final float imageHeight = 608f;
public PedestrianTranslator(float threshold) {
// 定义图片预处理过程
pipeline = new Pipeline();
pipeline.add(new Resize((int) imageWidth, (int) imageHeight)) // resize 为模型图片输入格式,变成 608 * 608 * 3,HWC
.add(new ToTensor()) // HWC -> CHW
.add(new Normalize(new float[]{0.485f, 0.456f, 0.406f}, new float[]{0.229f, 0.224f, 0.225f})) // 归一化
.add(array -> array.expandDims(0)); // CHW -> NCHW
// 预测阈值
this.threshold = threshold;
// 类别
classes = Collections.singletonList("pedestrian");
}
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
// 内存管理器,负责 NDArray 的内存回收
NDManager manager = ctx.getNDManager();
// 通过构造函数定义好的管道把图片转换到模型需要的图片格式。NDList 是一个集合,与 List<NDArray> 类似
NDList ndList = pipeline.transform(new NDList(input.toNDArray(manager, Image.Flag.COLOR)));
// 添加原图尺寸参数
ndList.add(0, manager.create(new float[]{input.getHeight(), input.getWidth()}).expandDims(0));
// 添加原图片尺寸与输入图片尺寸的比值
ndList.add(manager.create(new float[]{input.getHeight() / 608f, input.getWidth() / 608f}).expandDims(0));
return ndList;
}
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
// 获取第一个参数预测结果,第二个预测数量没什么用
NDArray result = list.get(0);
/*
result demo:
ND: (3, 6) cpu() float32
[[ 0. , 0.9759, 10.0805, 276.1631, 298.1623, 586.246 ],
[ 0. , 0.955 , 486.306 , 221.0572, 585.966 , 480.4897],
[ 0. , 0.8031, 295.0543, 206.104 , 395.3066, 485.3789],
]
*/
// 获取类别
int[] classIndices = result.get(":, 0").toType(DataType.INT32, true).flatten().toIntArray();
// 获取置信度
double[] probs = result.get(":, 1").toType(DataType.FLOAT64, true).toDoubleArray();
// 获取预测的目标数量
int detected = Math.toIntExact(probs.length);
// 获取矩形框左上角 x 坐标比例(第 2 列)
NDArray xMin = result.get(":, 2:3").clip(0, imageWidth).div(imageWidth);
// 获取矩形框左上角 y 坐标比例(第 3 列)
NDArray yMin = result.get(":, 3:4").clip(0, imageHeight).div(imageHeight);
// 获取矩形框右上角 x 坐标比例(第 4 列)
NDArray xMax = result.get(":, 4:5").clip(0, imageWidth).div(imageWidth);
// 获取矩形框右上角 y 坐标比例(第 5 列)
NDArray yMax = result.get(":, 5:6").clip(0, imageHeight).div(imageHeight);
// 转为可以直接绘制的数据,分别是矩形框左上角的 x 和 y 坐标、矩形框的宽和高,均为比例
float[] boxX = xMin.toFloatArray();
float[] boxY = yMin.toFloatArray();
float[] boxWidth = xMax.sub(xMin).toFloatArray();
float[] boxHeight = yMax.sub(yMin).toFloatArray();
// 封装成 DetectedObjects 对象输出
List<String> retClasses = new ArrayList<>(detected);
List<Double> retProbs = new ArrayList<>(detected);
List<BoundingBox> retBB = new ArrayList<>(detected);
for (int i = 0; i < detected; i++) {
// 类别不存在或者置信度低于预测阈值则跳过
if (classIndices[i] < 0 || probs[i] < threshold) {
continue;
}
retClasses.add(classes.get(0));
retProbs.add(probs[i]);
retBB.add(new Rectangle(boxX[i], boxY[i], boxWidth[i], boxHeight[i]));
}
return new DetectedObjects(retClasses, retProbs, retBB);
}
}
这里涉及的 NDArray 操作比较多,使用官方实现的 Transform 和 Pipeline 可以简化代码,不过手动调 NDImageUtils 更清晰。简单说几个 API:
- expandDims:增加维度,比如 Pipeline 的一个 Transform Lambda 将 CHW 前面加一个维度变成 NCHW
- get:查看 NDIndex API(方法注释上均有代码样例说明)、百度 numpy 索引切片或 NDArray 教程,搞懂
:
和,
- clip:限制数值,数值越界就取该方法传入的值
加载模型
import ai.djl.MalformedModelException;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import java.io.IOException;
import java.nio.file.Paths;
public class Models {
public static ZooModel<Image, DetectedObjects> getModel() throws ModelNotFoundException, MalformedModelException, IOException {
return Criteria.builder()
.optEngine("OnnxRuntime") // 选择引擎
.setTypes(Image.class, DetectedObjects.class) // 设置输入输出
.optModelPath(Paths.get("D:\\Repository\\Github\\PaddleDetection\\pedestrian_yolov3_darknet.onnx")) // 设置模型地址。Jar 包、Zip 包根据 API 自行配置
.optProgress(new ProgressBar()) // 进度条
.optTranslator(new PedestrianTranslator(.5f)) // 默认的转换器,不是线程安全的
.build().loadModel();
}
}
推理
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
public class Inference {
public static void main(String[] args) throws IOException, MalformedModelException, TranslateException, ModelNotFoundException {
String imageFilePath = "C:\\Users\\DELL\\Desktop\\2.png";
// 加载模型
try (ZooModel<Image, DetectedObjects> model = Models.getModel()) {
// 新建一个推理,使用 GPU
try (Predictor<Image, DetectedObjects> predictor = model.newPredictor(Device.gpu())) {
Image image = ImageFactory.getInstance().fromFile(Paths.get(imageFilePath));
// 推理
DetectedObjects result = predictor.predict(image);
// 绘制矩形框
image.drawBoundingBoxes(result);
image.save(Files.newOutputStream(Paths.get("output.png")), "png");
}
}
}
}
CPU GPU 配置
没有配置 cuda 的话自动下载 CPU 所需的文件,有 cuda 的话会自动寻找匹配 cuda 版本的文件,目前官网上的 cuda 版本是 10.2 和 11.2。
也可以通过配置 jar 来指定 CPU 还是 GPU,以 ONNX 为例(详见DJL Hybrid engines ONNX):
<dependency>
<groupId>ai.djl.onnxruntime</groupId>
<artifactId>onnxruntime-engine</artifactId>
<version>0.20.0</version>
<scope>runtime</scope>
<exclusions>
<exclusion>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime_gpu</artifactId>
<version>1.13.1</version>
<scope>runtime</scope>
</dependency>
注意
- 最需要知道的是导出的模型的输入和输出,否则不知道怎么写 Translator
- DJL 运行所需的文件挺大的,它会在第一次运行时下载,网卡流量在动就等会吧(在
/${HOME}/.djl.ai/
下) - 通常第一次推理比较慢,建议预热一次
- 多线程建议每个线程一个 Predictor
Jupyter Notebook
附上可以直接运行的 notebook:d2l/paddledetection.ipynb at master · hligaty/d2l (github.com)。Maven 下载依赖比较慢,建议手动下载依赖放到 /${HOME}/.ivy2/cache/
下。
参考与推荐
PaddleDetection 模型导出为 ONNX 格式教程
frankfliu/IJava:用于执行Java代码的Jupyter内核
Java 调用 PaddleDetection 模型的更多相关文章
- 机器学习——Java调用sklearn生成好的Logistic模型进行鸢尾花的预测
机器学习是python语言的长处,而Java在web开发方面更具有优势,如何通过java来调用python中训练好的模型进行在线的预测呢?在java语言中去调用python构建好的模型主要有三种方法: ...
- Java进阶(五)Java I/O模型从BIO到NIO和Reactor模式
原创文章,同步发自作者个人博客,http://www.jasongj.com/java/nio_reactor/ Java I/O模型 同步 vs. 异步 同步I/O 每个请求必须逐个地被处理,一个请 ...
- java线程内存模型,线程、工作内存、主内存
转自:http://rainyear.iteye.com/blog/1734311 java线程内存模型 线程.工作内存.主内存三者之间的交互关系图: key edeas 所有线程共享主内存 每个线程 ...
- java 调用webservice的各种方法总结
java 调用webservice的各种方法总结 几种流行的开源WebService框架Axis1,Axis2,Xfire,CXF,JWS比较 方法一:创建基于JAX-WS的webservice(包括 ...
- Java虚拟机内存模型及垃圾回收监控调优
Java虚拟机内存模型及垃圾回收监控调优 如果你想理解Java垃圾回收如果工作,那么理解JVM的内存模型就显的非常重要.今天我们就来看看JVM内存的各不同部分及如果监控和实现垃圾回收调优. JVM内存 ...
- Indri查询命令及Java调用并保存结果
查询参数 index Indri索引库路径.在参数文件中像/path/to/repository这样指定,在命令行中像-index=/path/to/repository这样指定.该参数可以设置多次来 ...
- Java I/O 模型的演进
什么是同步?什么是异步?阻塞和非阻塞又有什么区别?本文先从 Unix 的 I/O 模型讲起,介绍了5种常见的 I/O 模型.而后再引出 Java 的 I/O 模型的演进过程,并用实例说明如何选择合适的 ...
- Java虚拟机--内存模型与线程
Java虚拟机--内存模型与线程 高速缓存:处理器要与内存交互,如读取.存储运算结果,而计算机的存储设备和处理器的运算速度差异巨大,所以加入一层读写速度和处理器接近的高速缓存来作为内存和处理器之间的缓 ...
- Java网络编程和NIO详解3:IO模型与Java网络编程模型
Java网络编程和NIO详解3:IO模型与Java网络编程模型 基本概念说明 用户空间与内核空间 现在操作系统都是采用虚拟存储器,那么对32位操作系统而言,它的寻址空间(虚拟存储空间)为4G(2的32 ...
- OSGI 面向Java的动态模型系统
OSGI (面向Java的动态模型系统) OSGi(Open Service Gateway Initiative)技术是Java动态化模块化系统的一系列规范.OSGi一方面指维护OSGi规范的OSG ...
随机推荐
- 关于解决windows安装gcc g++环境 mingw失败
前言 这几天学习c++,为了详细了解编译过程我没有安装vs全家桶,当然使用命令行是最好的方法. 但是为了解决这个网络问题折腾了我很久,经过我研究发现,其实就是到固定网站下载几个压缩格式的文件,然后解压 ...
- MyEclipse连接MySQL
在官网http://www.mysql.com/downloads/下载数据库连接驱动 本文中使用驱动版本为mysql-connector-java-5.1.40 一.创建一个java测试项目MySQ ...
- python-py文件打包成exe可执行文件
方法一::打包完成后可以直接被他人使用,他人不用安装python环境的 可以使用pyinstaller模块实现将python项目打包成exe执行文件 """ 先安装模块 ...
- java反射基础知识整理
目录 1.反射机制的作用 2.获取一个类的实例 3.使用Class.forName()方法加载类的静态代码块 4.获取配置文件的路径 5.java反编译 5.1.获取类中的成员变量 5.2.通过类名反 ...
- 求和【第十三届蓝桥杯省赛C++A/C组 , 第十三届蓝桥杯省赛JAVAA组】
求和 给定 \(n\) 个整数 \(a1,a2,⋅⋅⋅,an\),求它们两两相乘再相加的和,即 \(S=a1⋅a2+a1⋅a3+⋅⋅⋅+a1⋅an+a2⋅a3+⋅⋅⋅+an−2⋅an−1+an−2⋅a ...
- 安装pytorch-gpu的经验与教训
首先说明 本文并不是安装教程,网上有很多,这里只是自己遇到的一些问题 我是以前安装的tensorflow-gpu的,但是发现现在的学术论文大部分都是用pytorch复现的,因此才去安装的pytorch ...
- Joplin修改笔记存储位置
默认存储路径 笔记的默认保存位置可以通过 工具 > 选项 > 通用选项 ,在最上方可以看到路径 使用Windows快捷方式启动 在Joplin的快捷方式上右击,选择属性,然后选择快捷方式选 ...
- [阿里云]I+的一些探索
I+是阿里云的关系网络分析,万物皆可联 使用中遇到的一些问题,特记录如下: 1.添加数据源 这个数据源是用于数据落地的存储,所以一定要选择<是> 2.配置对象信息 这一步就像是创建一个表来 ...
- SOFAJRaft依赖框架Disruptor浅析
Disruptor是英国外汇交易公司LMAX开发的一个高性能队列,研发的初衷是解决内存队列的延迟问题.与Kafka.RabbitMQ用于服务间的消息队列不同,disruptor一般用于线程间消息的传递 ...
- do while 出口條件循環