DeepLearning4j (DL4J) 是一个开源的深度学习库,专为 Java 和 Scala 设计。它可以用于构建、训练和部署深度学习模型。以下是关于如何使用 DL4J 的基本指南以及一个简单的模型训练示例。

本例中使用了MNIST数据集,MNIST(modified national institute of standard and technology)数据集是由Yann LeCun及其同事于1994年创建一个大型手写数字数据库(包含0~9十个数字)。MNIST数据集的原始数据来源于美国国家标准和技术研究院(national institute of standard and technology)的两个数据集:special database 1和special database 3。它们分别由NIST的员工和美国高中生手写的0-9的数字组成。原始的这两个数据集由128×128像素的黑白图像组成。LeCun等人将其进行归一化和尺寸调整后得到的是28×28的灰度图像。

DeepLearning4j 使用指南

安装与配置

  1. 环境要求

    • Java Development Kit (JDK) 8 或以上版本
    • Maven(推荐)或 Gradle 用于项目管理
  2. 创建 Maven 项目

    在你的 IDE 中创建一个新的 Maven 项目,并在 pom.xml 文件中添加以下依赖:

    <dependencies>
    <!-- DL4J Core -->
    <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-M1.1</version>
    </dependency>
    <!-- ND4J (Numpy for Java) -->
    <dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-M1.1</version>
    </dependency>
    <!-- DataVec for data preprocessing -->
    <dependency>
    <groupId>org.datavec</groupId>
    <artifactId>datavec-api</artifactId>
    <version>1.0.0-M1.1</version>
    </dependency>
    </dependencies>
  3. 更新 Maven 依赖

    确保你的 IDE 更新了 Maven 依赖,下载所需的库。

简单的模型训练

下面是一个使用 DL4J 训练简单神经网络的示例,目标是对手写数字进行分类(MNIST 数据集)。

代码示例
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions; public class MnistExample {
public static void main(String[] args) throws Exception {
// 加载 MNIST 数据集
DataSetIterator mnistTrain = new MnistDataSetIterator(128, true, 12345); // 配置神经网络
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123)
.updater(new Adam(0.001))
.list()
.layer(0, new DenseLayer.Builder().nIn(784).nOut(256)
.activation(Activation.RELU)
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(256).nOut(10).build())
.build(); // 创建并初始化网络
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(100)); // 每100次迭代输出一次分数 // 训练模型
for (int i = 0; i < 10; i++) { // 训练10个epoch
model.fit(mnistTrain);
} System.out.println("训练完成!"); // 加载 MNIST 测试数据集
DataSetIterator mnistTest = new MnistDataSetIterator(128, false, 12345); // 评估模型
double accuracy = model.evaluate(mnistTest).accuracy();
System.out.println("模型准确率: " + accuracy); // 保存模型到文件
File modelFile = new File("mnist_model.zip");
ModelSerializer.writeModel(model, modelFile, true);
}
}
代码说明
  1. 加载数据集:使用 MnistDataSetIterator 加载 MNIST 数据集。
  2. 配置神经网络
    • 使用 NeuralNetConfiguration.Builder 构建神经网络配置。
    • 添加输入层(DenseLayer)和输出层(OutputLayer)。
  3. 创建和初始化模型:使用 MultiLayerNetwork 创建模型并初始化。
  4. 训练模型:通过循环调用 fit() 方法训练模型。

运行示例

确保你的环境已正确设置,然后运行上述代码。模型将在 MNIST 数据集上进行训练,训练完成后会输出“训练完成!”的信息。

模型评估

在训练完模型后,通常需要对其进行评估,以了解模型在未见数据上的表现。你可以使用测试集来评估模型的准确性和其他性能指标。

保存与加载模型

训练完成后,你可能希望保存模型以便以后使用。DL4J 提供了简单的方法来保存和加载模型。

调整与优化模型

根据评估结果,你可能需要调整模型的超参数或架构。可以尝试以下方法:

  • 增加层数或节点数:增加模型的复杂性。
  • 改变学习率:试验不同的学习率以找到最佳值。
  • 使用不同的激活函数:例如,尝试 LeakyReLUELU
  • 正则化:添加 Dropout 层或 L2 正则化以防止过拟合。

部署模型

如果你打算将模型应用于生产环境,可以考虑将其部署为服务。可以使用以下方式之一:

  • REST API:将模型包装为 RESTful 服务,方便客户端调用。
  • 嵌入式应用:将模型嵌入到 Java 应用程序中,直接进行预测。

模型的测试

使用 Java 和 DeepLearning4j 来训练自己的手写数字图像(例如 0 到 9 的标准图像)是一个很好的项目。下面是一个简单的步骤指南,帮助你实现这个目标。

步骤概述

  1. 准备数据:将你的数字图像准备为合适的格式。
  2. 创建和配置模型:使用 DeepLearning4j 创建神经网络模型。
  3. 训练模型:使用你的图像数据训练模型。
  4. 评估和测试模型:验证模型的性能。

1. 准备数据

首先,你需要将你的 0-9 数字图像准备好。假设你有 10 张图像,每张图像都是 28x28 像素的灰度图像,并且它们存储在本地文件系统中。

模型测试的步骤

步骤 1: 使用 MNIST 数据集训练模型

  1. 加载数据集:使用 MnistDataSetIterator 加载 MNIST 数据集。
  2. 构建模型:根据你的需求,构建一个适合的神经网络模型。
  3. 训练模型:使用 MNIST 数据集对模型进行训练。
  4. 保存模型:将训练好的模型保存到文件中(例如,保存为 .zip 文件)。

步骤 2: 准备手写数字图片

  1. 手写数字:自己手写一个数字 1,并拍照或扫描成图片。
  2. 预处理图片
    • 将图片转换为灰度图像。
    • 调整图片大小为 28x28 像素(MNIST 数据集中的标准尺寸)。
    • 对图像进行归一化处理(通常将像素值缩放到 [0, 1] 范围内)。

步骤 3: 比较手写数字与 MNIST 数据集

  1. 加载保存的模型:从 zip 文件中加载之前训练好的模型。
  2. 预测手写数字:将预处理后的手写数字图片输入到模型中进行预测。
  3. 输出结果:模型将输出手写数字的预测结果。你可以将这个结果与 MNIST 数据集中相应的标签进行比较。

注意事项

  • 数据预处理:确保手写数字的预处理方式与训练时一致,包括图像大小、颜色通道和归一化。
  • 模型评估:在比较之前,可以先在测试集上评估模型的性能,以确保其准确性。
  • 可视化结果:可以通过可视化工具(如 matplotlib)展示手写数字及其预测结果,以便更好地理解模型的表现。

示例代码

以下是一个简单的示例代码框架,展示了如何实现这些步骤

[MnistUtils.java]

/**
* @author lind
* @date 2025/1/7 14:27
* @since 1.0.0
*/
public class MnistUtils {
/**
* 将图像转换为灰度图像
*
* @param original
* @return
*/
private static BufferedImage convertToGrayscale(BufferedImage original) {
BufferedImage grayImage = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_BYTE_GRAY);
Graphics g = grayImage.getGraphics();
g.drawImage(original, 0, 0, null);
g.dispose();
return grayImage;
} /**
* 调整图像大小
*
* @param original
* @param width
* @param height
* @return
*/
private static BufferedImage resizeImage(BufferedImage original, int width, int height) {
Image scaledImage = original.getScaledInstance(width, height, Image.SCALE_SMOOTH);
BufferedImage resizedImage = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
Graphics2D g2d = resizedImage.createGraphics();
g2d.drawImage(scaledImage, 0, 0, null);
g2d.dispose();
return resizedImage;
} /**
* 加载图像
*
* @param fileName
* @return
*/
public static INDArray loadGrayImg(String fileName) {
try {
// 1. 加载图片
BufferedImage originalImage = ImageIO.read(new File(fileName));
// 2. 转换为灰度图像
BufferedImage grayImage = convertToGrayscale(originalImage);
// 3. 调整大小为 28x28 像素
BufferedImage resizedImage = resizeImage(grayImage, 28, 28);
// 4. 进行归一化处理
return normalizeImage(resizedImage);
} catch (IOException e) {
e.printStackTrace();
}
return null;
} /**
* 对图像进行归一化处理并生成 INDArray
*
* @param image
* @return
*/
private static INDArray normalizeImage(BufferedImage image) {
int width = image.getWidth();
int height = image.getHeight();
double[] normalizedData = new double[width * height]; // 创建一维数组 for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
// 获取灰度值(0-255)
int grayValue = image.getRGB(x, y) & 0xFF; // 只取灰度部分
// 归一化到 [0, 1] 范围
normalizedData[y * width + x] = grayValue / 255.0; // 填充一维数组
}
} // 将一维数组转换为 INDArray,并添加批次维度
INDArray indArray = Nd4j.create(normalizedData).reshape(1, 784); // reshape to [1, 784]
return indArray;
}
}

[MnistTest.java]

public static void main(String[] args) throws IOException {
// 加载已训练的模型
MultiLayerNetwork model = MultiLayerNetwork.load(new File("E:\\github\\lind-deeplearning4j\\mnist_model.zip"), true);
// 测试图像路径
String testImagePath = "d:\\dlj4\\img\\";
// 假设你有10个测试图像,命名为 0.png 到 9.png,当我从MNIST数据集网站下载9张图片后,这个大模型确实可以给我识别出来
for (int i = 0; i <= 3; i++) {
String fileName = testImagePath + i + ".png";
System.out.println("fileName=" + fileName);
INDArray testImage = loadGrayImg(fileName);
INDArray output = model.output(testImage); // 进行预测 // 获取预测结果
int predictedClass = Nd4j.argMax(output, 1).getInt(0);
System.out.println("测试图像 " + i + " 的预测结果: " + predictedClass);
}
}

模型测试结果,它会根据0-3的图片,将图片上面的数字分析出来,这个事实上是根据我们训练的MINIST数据集得到的结果

deeplearning4j~实现简单模型训练和测试的更多相关文章

  1. Windows下mnist数据集caffemodel分类模型训练及测试

    1. MNIST数据集介绍 MNIST是一个手写数字数据库,样本收集的是美国中学生手写样本,比较符合实际情况,大体上样本是这样的: MNIST数据库有以下特性: 包含了60000个训练样本集和1000 ...

  2. fcn模型训练及测试

    1.模型下载 1)下载新版caffe: https://github.com/BVLC/caffe 2)下载fcn代码: https://github.com/shelhamer/fcn.berkel ...

  3. 搭建简单模型训练MNIST数据集

    # -*- coding = utf-8 -*- # @Time : 2021/3/16 # @Author : pistachio # @File : test1.py # @Software : ...

  4. Caffe系列4——基于Caffe的MNIST数据集训练与测试(手把手教你使用Lenet识别手写字体)

    基于Caffe的MNIST数据集训练与测试 原创:转载请注明https://www.cnblogs.com/xiaoboge/p/10688926.html  摘要 在前面的博文中,我详细介绍了Caf ...

  5. ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试

    http://www.cnblogs.com/denny402/p/5852983.html ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试   刚开始学习tf时,我们从 ...

  6. windows+caffe(四)——创建模型并编写配置文件+训练和测试

    1.模型就用程序自带的caffenet模型,位置在 models/bvlc_reference_caffenet/文件夹下, 将需要的两个配置文件,复制到myfile文件夹内 2. 修改solver. ...

  7. 超简单!pytorch入门教程(五):训练和测试CNN

    我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...

  8. 使用caffemodel模型(由mnist训练)测试单张手写数字样本

    caffe中训练和测试mnist数据集都是批处理,可以反馈识别率,但是看不到单张样本的识别效果,这里使用windows自带的画图工具手写制作0~9的测试数字,然后使用caffemodel模型识别. 1 ...

  9. 【新人赛】阿里云恶意程序检测 -- 实践记录10.13 - Google Colab连接 / 数据简单查看 / 模型训练

    1. 比赛介绍 比赛地址:阿里云恶意程序检测新人赛 这个比赛和已结束的第三届阿里云安全算法挑战赛赛题类似,是一个开放的长期赛. 2. 前期准备 因为训练数据量比较大,本地CPU跑不起来,所以决定用Go ...

  10. Caffe学习系列(12):训练和测试自己的图片

    学习caffe的目的,不是简单的做几个练习,最终还是要用到自己的实际项目或科研中.因此,本文介绍一下,从自己的原始图片到lmdb数据,再到训练和测试模型的整个流程. 一.准备数据 有条件的同学,可以去 ...

随机推荐

  1. 游戏引擎数学库 Plane

    0 前言 平面的表达方式有很多,常用的就两种.向量形式的点法式,标量形式的平面方程.两者可以互相转化. \[(\mathbf{p}-\mathbf{p_0})\cdot\mathbf{n}=0 \] ...

  2. nginx配置tomcat的反向代理记录二,根据访问的路径跳转到不同端口的tomcat服务器

    实现效果:使用 nginx 反向代理,根据访问的路径跳转到不同端口的服务中. 设置nginx 监听端口为 9001,访问 http://192.168.17.129:9001/vod/ 直接跳转到 1 ...

  3. CF#795 D - Max GEQ Sum O(n)做法

    D - Max GEQ Sum input 3 4 -1 1 -1 2 5 -1 2 -3 2 -1 3 2 3 -1 output YES YES NO 题目大意:给你一个数组a,让你任意一个字段均 ...

  4. 使用PYNQ生成PWM波控制舵机/步进电机/机械臂

    使用PYNQ生成PWM波控制舵机/步进电机/机械臂 在开始这个工程之前,你需要PYNQ-Z2的板卡文件,约束文件,原理图作为参考,你可以在我上传的资源里下载. 当然,这个工程也适用于PYNQ-Z1,只 ...

  5. 基于surging的木舟平台如何构建起微服务

    一.概述 木舟平台分为微服务平台和物联网平台, 上面几篇都是介绍如何通过网络组件接入设备,那么此篇文章就细致介绍下在木舟平台下如何构建微服务. 木舟 (Kayak) 是什么? 木舟(Kayak)是基于 ...

  6. Air780E如何发送SMS?一文详解!

    ​ 今天一起来学习使用合宙低功耗4G模组Air780E发送SMS短消息: 一.SMS简介 SMS(短消息服务,ShortMessageService)功能主要用于在蜂窝网络中传输短消息. 在4G网络中 ...

  7. 【一步步开发AI运动小程序】十五、AI运动识别中,如何判断人体站位的远近?

    [云智AI运动识别小程序插件],可以为您的小程序,赋于人体检测识别.运动检测识别.姿态识别检测AI能力.本地原生识别引擎,无需依赖任何后台或第三方服务,有着识别速度快.体验佳.扩展性强.集成快.成本低 ...

  8. python模块导入规则(相对导入和绝对导入)

    python模块可以相对导入和绝对导入,但这两者是不能随意替换使用的.本文主要讨论工作目录下模块之间的导入规则.其中相对导入前面有一个'.',表示从该脚本所在目录开始索引,而绝对导入前面没有'.',表 ...

  9. 反编译工具之Jadx

    jadx 是一款功能强大的反编译工具,使用起来简单方便(拖拽式操作),不光提供了命令行程序,还提供了 GUI 程序.一般情况下,我们直接使用 GUI 程序就可以了. jadx 支持 Windows.L ...

  10. 十亿手机号去重-BitSet

    思路:使用Java自带BitSet函数,将手机号分为两段(15555555555 -> 155+55555555) public class demo { public static void ...