DeepLearning4j (DL4J) 是一个开源的深度学习库,专为 Java 和 Scala 设计。它可以用于构建、训练和部署深度学习模型。以下是关于如何使用 DL4J 的基本指南以及一个简单的模型训练示例。

本例中使用了MNIST数据集,MNIST(modified national institute of standard and technology)数据集是由Yann LeCun及其同事于1994年创建一个大型手写数字数据库(包含0~9十个数字)。MNIST数据集的原始数据来源于美国国家标准和技术研究院(national institute of standard and technology)的两个数据集:special database 1和special database 3。它们分别由NIST的员工和美国高中生手写的0-9的数字组成。原始的这两个数据集由128×128像素的黑白图像组成。LeCun等人将其进行归一化和尺寸调整后得到的是28×28的灰度图像。

DeepLearning4j 使用指南

安装与配置

  1. 环境要求

    • Java Development Kit (JDK) 8 或以上版本
    • Maven(推荐)或 Gradle 用于项目管理
  2. 创建 Maven 项目

    在你的 IDE 中创建一个新的 Maven 项目,并在 pom.xml 文件中添加以下依赖:

    <dependencies>
    <!-- DL4J Core -->
    <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-M1.1</version>
    </dependency>
    <!-- ND4J (Numpy for Java) -->
    <dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-M1.1</version>
    </dependency>
    <!-- DataVec for data preprocessing -->
    <dependency>
    <groupId>org.datavec</groupId>
    <artifactId>datavec-api</artifactId>
    <version>1.0.0-M1.1</version>
    </dependency>
    </dependencies>
  3. 更新 Maven 依赖

    确保你的 IDE 更新了 Maven 依赖,下载所需的库。

简单的模型训练

下面是一个使用 DL4J 训练简单神经网络的示例,目标是对手写数字进行分类(MNIST 数据集)。

代码示例
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions; public class MnistExample {
public static void main(String[] args) throws Exception {
// 加载 MNIST 数据集
DataSetIterator mnistTrain = new MnistDataSetIterator(128, true, 12345); // 配置神经网络
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123)
.updater(new Adam(0.001))
.list()
.layer(0, new DenseLayer.Builder().nIn(784).nOut(256)
.activation(Activation.RELU)
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(256).nOut(10).build())
.build(); // 创建并初始化网络
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(100)); // 每100次迭代输出一次分数 // 训练模型
for (int i = 0; i < 10; i++) { // 训练10个epoch
model.fit(mnistTrain);
} System.out.println("训练完成!"); // 加载 MNIST 测试数据集
DataSetIterator mnistTest = new MnistDataSetIterator(128, false, 12345); // 评估模型
double accuracy = model.evaluate(mnistTest).accuracy();
System.out.println("模型准确率: " + accuracy); // 保存模型到文件
File modelFile = new File("mnist_model.zip");
ModelSerializer.writeModel(model, modelFile, true);
}
}
代码说明
  1. 加载数据集:使用 MnistDataSetIterator 加载 MNIST 数据集。
  2. 配置神经网络
    • 使用 NeuralNetConfiguration.Builder 构建神经网络配置。
    • 添加输入层(DenseLayer)和输出层(OutputLayer)。
  3. 创建和初始化模型:使用 MultiLayerNetwork 创建模型并初始化。
  4. 训练模型:通过循环调用 fit() 方法训练模型。

运行示例

确保你的环境已正确设置,然后运行上述代码。模型将在 MNIST 数据集上进行训练,训练完成后会输出“训练完成!”的信息。

模型评估

在训练完模型后,通常需要对其进行评估,以了解模型在未见数据上的表现。你可以使用测试集来评估模型的准确性和其他性能指标。

保存与加载模型

训练完成后,你可能希望保存模型以便以后使用。DL4J 提供了简单的方法来保存和加载模型。

调整与优化模型

根据评估结果,你可能需要调整模型的超参数或架构。可以尝试以下方法:

  • 增加层数或节点数:增加模型的复杂性。
  • 改变学习率:试验不同的学习率以找到最佳值。
  • 使用不同的激活函数:例如,尝试 LeakyReLUELU
  • 正则化:添加 Dropout 层或 L2 正则化以防止过拟合。

部署模型

如果你打算将模型应用于生产环境,可以考虑将其部署为服务。可以使用以下方式之一:

  • REST API:将模型包装为 RESTful 服务,方便客户端调用。
  • 嵌入式应用:将模型嵌入到 Java 应用程序中,直接进行预测。

模型的测试

使用 Java 和 DeepLearning4j 来训练自己的手写数字图像(例如 0 到 9 的标准图像)是一个很好的项目。下面是一个简单的步骤指南,帮助你实现这个目标。

步骤概述

  1. 准备数据:将你的数字图像准备为合适的格式。
  2. 创建和配置模型:使用 DeepLearning4j 创建神经网络模型。
  3. 训练模型:使用你的图像数据训练模型。
  4. 评估和测试模型:验证模型的性能。

1. 准备数据

首先,你需要将你的 0-9 数字图像准备好。假设你有 10 张图像,每张图像都是 28x28 像素的灰度图像,并且它们存储在本地文件系统中。

模型测试的步骤

步骤 1: 使用 MNIST 数据集训练模型

  1. 加载数据集:使用 MnistDataSetIterator 加载 MNIST 数据集。
  2. 构建模型:根据你的需求,构建一个适合的神经网络模型。
  3. 训练模型:使用 MNIST 数据集对模型进行训练。
  4. 保存模型:将训练好的模型保存到文件中(例如,保存为 .zip 文件)。

步骤 2: 准备手写数字图片

  1. 手写数字:自己手写一个数字 1,并拍照或扫描成图片。
  2. 预处理图片
    • 将图片转换为灰度图像。
    • 调整图片大小为 28x28 像素(MNIST 数据集中的标准尺寸)。
    • 对图像进行归一化处理(通常将像素值缩放到 [0, 1] 范围内)。

步骤 3: 比较手写数字与 MNIST 数据集

  1. 加载保存的模型:从 zip 文件中加载之前训练好的模型。
  2. 预测手写数字:将预处理后的手写数字图片输入到模型中进行预测。
  3. 输出结果:模型将输出手写数字的预测结果。你可以将这个结果与 MNIST 数据集中相应的标签进行比较。

注意事项

  • 数据预处理:确保手写数字的预处理方式与训练时一致,包括图像大小、颜色通道和归一化。
  • 模型评估:在比较之前,可以先在测试集上评估模型的性能,以确保其准确性。
  • 可视化结果:可以通过可视化工具(如 matplotlib)展示手写数字及其预测结果,以便更好地理解模型的表现。

示例代码

以下是一个简单的示例代码框架,展示了如何实现这些步骤

[MnistUtils.java]

/**
* @author lind
* @date 2025/1/7 14:27
* @since 1.0.0
*/
public class MnistUtils {
/**
* 将图像转换为灰度图像
*
* @param original
* @return
*/
private static BufferedImage convertToGrayscale(BufferedImage original) {
BufferedImage grayImage = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_BYTE_GRAY);
Graphics g = grayImage.getGraphics();
g.drawImage(original, 0, 0, null);
g.dispose();
return grayImage;
} /**
* 调整图像大小
*
* @param original
* @param width
* @param height
* @return
*/
private static BufferedImage resizeImage(BufferedImage original, int width, int height) {
Image scaledImage = original.getScaledInstance(width, height, Image.SCALE_SMOOTH);
BufferedImage resizedImage = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
Graphics2D g2d = resizedImage.createGraphics();
g2d.drawImage(scaledImage, 0, 0, null);
g2d.dispose();
return resizedImage;
} /**
* 加载图像
*
* @param fileName
* @return
*/
public static INDArray loadGrayImg(String fileName) {
try {
// 1. 加载图片
BufferedImage originalImage = ImageIO.read(new File(fileName));
// 2. 转换为灰度图像
BufferedImage grayImage = convertToGrayscale(originalImage);
// 3. 调整大小为 28x28 像素
BufferedImage resizedImage = resizeImage(grayImage, 28, 28);
// 4. 进行归一化处理
return normalizeImage(resizedImage);
} catch (IOException e) {
e.printStackTrace();
}
return null;
} /**
* 对图像进行归一化处理并生成 INDArray
*
* @param image
* @return
*/
private static INDArray normalizeImage(BufferedImage image) {
int width = image.getWidth();
int height = image.getHeight();
double[] normalizedData = new double[width * height]; // 创建一维数组 for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
// 获取灰度值(0-255)
int grayValue = image.getRGB(x, y) & 0xFF; // 只取灰度部分
// 归一化到 [0, 1] 范围
normalizedData[y * width + x] = grayValue / 255.0; // 填充一维数组
}
} // 将一维数组转换为 INDArray,并添加批次维度
INDArray indArray = Nd4j.create(normalizedData).reshape(1, 784); // reshape to [1, 784]
return indArray;
}
}

[MnistTest.java]

public static void main(String[] args) throws IOException {
// 加载已训练的模型
MultiLayerNetwork model = MultiLayerNetwork.load(new File("E:\\github\\lind-deeplearning4j\\mnist_model.zip"), true);
// 测试图像路径
String testImagePath = "d:\\dlj4\\img\\";
// 假设你有10个测试图像,命名为 0.png 到 9.png,当我从MNIST数据集网站下载9张图片后,这个大模型确实可以给我识别出来
for (int i = 0; i <= 3; i++) {
String fileName = testImagePath + i + ".png";
System.out.println("fileName=" + fileName);
INDArray testImage = loadGrayImg(fileName);
INDArray output = model.output(testImage); // 进行预测 // 获取预测结果
int predictedClass = Nd4j.argMax(output, 1).getInt(0);
System.out.println("测试图像 " + i + " 的预测结果: " + predictedClass);
}
}

模型测试结果,它会根据0-3的图片,将图片上面的数字分析出来,这个事实上是根据我们训练的MINIST数据集得到的结果

deeplearning4j~实现简单模型训练和测试的更多相关文章

  1. Windows下mnist数据集caffemodel分类模型训练及测试

    1. MNIST数据集介绍 MNIST是一个手写数字数据库,样本收集的是美国中学生手写样本,比较符合实际情况,大体上样本是这样的: MNIST数据库有以下特性: 包含了60000个训练样本集和1000 ...

  2. fcn模型训练及测试

    1.模型下载 1)下载新版caffe: https://github.com/BVLC/caffe 2)下载fcn代码: https://github.com/shelhamer/fcn.berkel ...

  3. 搭建简单模型训练MNIST数据集

    # -*- coding = utf-8 -*- # @Time : 2021/3/16 # @Author : pistachio # @File : test1.py # @Software : ...

  4. Caffe系列4——基于Caffe的MNIST数据集训练与测试(手把手教你使用Lenet识别手写字体)

    基于Caffe的MNIST数据集训练与测试 原创:转载请注明https://www.cnblogs.com/xiaoboge/p/10688926.html  摘要 在前面的博文中,我详细介绍了Caf ...

  5. ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试

    http://www.cnblogs.com/denny402/p/5852983.html ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试   刚开始学习tf时,我们从 ...

  6. windows+caffe(四)——创建模型并编写配置文件+训练和测试

    1.模型就用程序自带的caffenet模型,位置在 models/bvlc_reference_caffenet/文件夹下, 将需要的两个配置文件,复制到myfile文件夹内 2. 修改solver. ...

  7. 超简单!pytorch入门教程(五):训练和测试CNN

    我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...

  8. 使用caffemodel模型(由mnist训练)测试单张手写数字样本

    caffe中训练和测试mnist数据集都是批处理,可以反馈识别率,但是看不到单张样本的识别效果,这里使用windows自带的画图工具手写制作0~9的测试数字,然后使用caffemodel模型识别. 1 ...

  9. 【新人赛】阿里云恶意程序检测 -- 实践记录10.13 - Google Colab连接 / 数据简单查看 / 模型训练

    1. 比赛介绍 比赛地址:阿里云恶意程序检测新人赛 这个比赛和已结束的第三届阿里云安全算法挑战赛赛题类似,是一个开放的长期赛. 2. 前期准备 因为训练数据量比较大,本地CPU跑不起来,所以决定用Go ...

  10. Caffe学习系列(12):训练和测试自己的图片

    学习caffe的目的,不是简单的做几个练习,最终还是要用到自己的实际项目或科研中.因此,本文介绍一下,从自己的原始图片到lmdb数据,再到训练和测试模型的整个流程. 一.准备数据 有条件的同学,可以去 ...

随机推荐

  1. dotnet core微服务框架Jimu介绍

    jimu是一个基于.Net6.0 简单易用的微服务框架,参考了很多开源库以及想法,使用了大量的开源库(如 DotNetty, consul.net, Flurl.Http, Json.net, Log ...

  2. OpenGL常用函数整理

    常用函数 颜色设置 glClear(GL_COLOR_BUFFER_BIT); //清空颜色,GL_COLOR_BUFFER_BIT是颜色缓冲区 glClearColor(R,G,B,A); //设置 ...

  3. Paimon lookup store 实现

    Lookup Store 主要用于 Paimon 中的 Lookup Compaction 以及 Lookup join 的场景. 会将远程的列存文件在本地转化为 KV 查找的格式. Hash htt ...

  4. 基于ctfshow的信息收集思路与CTF实战

    本文靶场来源于CTFshow,并不完全按照靶机的顺序排列,而是以测试操作为导向,按博主个人理解排列. 1. 前端源码 在CTF中,先看源代码是个好习惯,出题者经常会在源代码中以注释的形式提供一些提示 ...

  5. Redis学习笔记整理

    一.Redis概述 1.redis简介 Redis(REmote DIctionary Server 远程字典服务器)是一款开源的,用ANSI C编写.支持网络.基于内存.亦可持久化的日志型.Key- ...

  6. .NET Core 委托底层原理浅谈

    简介 .NET通过委托来提供回调函数机制,与C/C++不同的是,委托确保回调是类型安全,且允许多播委托.并支持调用静态/实例方法. 简单来说,C++的函数指针有如下功能限制,委托作为C#中的上位替代, ...

  7. 《Java开发手册》-部分编码规范分享

    0. 前言 本文来自<阿里巴巴Java开发手册>,以下内容均根据自己偏好摘抄.总结.分享. 1. 编程规约 包名单数,类名复数.例如:com.tao.util.JsonUtils.java ...

  8. 网站刚上线,就被 DDoS 攻击炸了!

    今天是一个值得纪念的日子,你打开一罐可乐,看着自己刚刚上线的小网站,洋洋得意. 这是你第一次做的网站,上线之后,网站访问量突飞猛进:没过多久,你就拿到了千万的风投,迎娶了女神,走上了人生巅峰... 害 ...

  9. H5扫码

    1.前言 H5可以获取视频流,并通过video元素进行播放 可以canvas对视频进行定时截图,然后使用插件对图片进行二维码解析 也可以直接对视频进行二维码解析(推荐) 解析二维码的插件为qr-sca ...

  10. pyc文件花指令

    pyc花指令 常见的python花指令形式有两种:单重叠指令和多重叠指令. 以下以python3.8为例,指令长度为2字节. 单重叠指令: 例如pyc经过反编译后得到的东西为 0 JUMP_ABSOL ...