deeplearning4j训练MNIST数据集以及验证
训练模型官方示例
MNIST数据下载地址: http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz
GitHub示例地址: https://github.com/deeplearning4j/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/LeNetMNISTReLu.java
/*******************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.examples.quickstart.modeling.convolution;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.examples.utils.DataUtilities;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
/**
* Implementation of LeNet-5 for handwritten digits image classification on MNIST dataset (99% accuracy)
* <a href="http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf">[LeCun et al., 1998. Gradient based learning applied to document recognition]</a>
* Some minor changes are made to the architecture like using ReLU and identity activation instead of
* sigmoid/tanh, max pooling instead of avg pooling and softmax output layer.
* <p>
* This example will download 15 Mb of data on the first run.
*
* @author hanlon
* @author agibsonccc
* @author fvaleri
* @author dariuszzbyrad
*/
public class LeNetMNISTReLu {
private static final Logger LOGGER = LoggerFactory.getLogger(LeNetMNISTReLu.class);
// private static final String BASE_PATH = System.getProperty("java.io.tmpdir") + "/mnist";
private static final String BASE_PATH = "D:\\Documents\\Downloads\\mnist_png";
private static final String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";
public static void main(String[] args) throws Exception {
// 图片高度
int height = 28; // height of the picture in px
// 图片宽度
int width = 28; // width of the picture in px
// 通道 1 表示 黑白
int channels = 1; // single channel for grayscale images
// 可能出现的结果数量 0-9 10个数字
int outputNum = 10; // 10 digits classification
// 批处理数量
int batchSize = 54; // number of samples that will be propagated through the network in each iteration
// 迭代次数
int nEpochs = 1; // number of training epochs
// 随机数生成器
int seed = 1234; // number used to initialize a pseudorandom number generator.
Random randNumGen = new Random(seed);
LOGGER.info("Data load...");
if (!new File(BASE_PATH + "/mnist_png").exists()) {
LOGGER.debug("Data downloaded from {}", DATA_URL);
String localFilePath = BASE_PATH + "/mnist_png.tar.gz";
if (DataUtilities.downloadFile(DATA_URL, localFilePath)) {
DataUtilities.extractTarGz(localFilePath, BASE_PATH);
}
}
LOGGER.info("Data vectorization...");
// vectorization of train data
File trainData = new File(BASE_PATH + "/mnist_png/training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // use parent directory name as the image label
ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
trainRR.initialize(trainSplit);
// MNIST中的数据
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);
// pixel values from 0-255 to 0-1 (min-max scaling)
DataNormalization imageScaler = new ImagePreProcessingScaler();
imageScaler.fit(trainIter);
trainIter.setPreProcessor(imageScaler);
// vectorization of test data
File testData = new File(BASE_PATH + "/mnist_png/testing");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
testRR.initialize(testSplit);
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
testIter.setPreProcessor(imageScaler); // same normalization for better results
LOGGER.info("Network configuration and training...");
// reduce the learning rate as the number of training epochs increases
// iteration #, learning rate
Map<Integer, Double> learningRateSchedule = new HashMap<>();
learningRateSchedule.put(0, 0.06);
learningRateSchedule.put(200, 0.05);
learningRateSchedule.put(600, 0.028);
learningRateSchedule.put(800, 0.0060);
learningRateSchedule.put(1000, 0.001);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.l2(0.0005) // ridge regression value
.updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))
.weightInit(WeightInit.XAVIER)
.list()
.layer(new ConvolutionLayer.Builder(5, 5)
.nIn(channels)
.stride(1, 1)
.nOut(20)
.activation(Activation.IDENTITY)
.build())
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
.layer(new ConvolutionLayer.Builder(5, 5)
.stride(1, 1) // nIn need not specified in later layers
.nOut(50)
.activation(Activation.IDENTITY)
.build())
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
.layer(new DenseLayer.Builder().activation(Activation.RELU)
.nOut(500)
.build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum)
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(10));
LOGGER.info("Total num of params: {}", net.numParams());
// evaluation while training (the score should go down)
for (int i = 0; i < nEpochs; i++) {
net.fit(trainIter);
LOGGER.info("Completed epoch {}", i);
Evaluation eval = net.evaluate(testIter);
LOGGER.info(eval.stats());
trainIter.reset();
testIter.reset();
}
File ministModelPath = new File(BASE_PATH + "/minist-model.zip");
ModelSerializer.writeModel(net, ministModelPath, true);
LOGGER.info("The MINIST model has been saved in {}", ministModelPath.getPath());
}
}
验证模型
package org.deeplearning4j.examples.quickstart.modeling.convolution;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import java.io.File;
import java.io.IOException;
/**
* @description:
* @author: Mr.Fang
* @create: 2023-07-14 15:06
**/
public class VerifyMNSIT {
public static void main(String[] args) throws IOException {
// 加载训练好的模型
File modelFile = new File("D:\\Documents\\Downloads\\mnist_png\\minist-model.zip");
MultiLayerNetwork model = MultiLayerNetwork.load(modelFile, true);
// 加载待验证的图像
File imageFile = new File("D:\\Documents\\Downloads\\mnist_png\\mnist_png\\testing\\8\\1717.png");
NativeImageLoader loader = new NativeImageLoader(28, 28, 1);
INDArray image = loader.asMatrix(imageFile);
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
scaler.transform(image);
// 对图像进行预测
INDArray output = model.output(image);
int predictedLabel = output.argMax().getInt();
// 在这行代码中,`output.argMax()`用于找到`output`中具有最大值的索引。`output`是一个包含模型的输出概率的NDArray对象。对于MNIST模型,输出是一个长度为10的向量,表示数字0到9的概率分布。
//
//`.argMax()`方法返回具有最大值的索引。例如,如果`output`的值为[0.1, 0.3, 0.2, 0.05, 0.25, 0.05, 0.05, 0.1, 0.05, 0.05],则`.argMax()`将返回索引1,因为在位置1处的值0.3是最大的。
//
//最后,`.getInt()`方法将获取`.argMax()`的结果并将其转换为一个整数,表示预测的标签。在这个例子中,`predictedLabel`将包含模型预测的数字标签。
//
//简而言之,这行代码的作用是找到输出中概率最高的数字标签,以进行预测。
System.out.println("Predicted label: " + predictedLabel);
}
}
输出结果
o.n.l.f.Nd4jBackend - Loaded [CpuBackend] backend
o.n.n.NativeOpsHolder - Number of threads used for linear algebra: 6
o.n.l.c.n.CpuNDArrayFactory - Binary level Generic x86 optimization level AVX/AVX2
o.n.n.Nd4jBlas - Number of threads used for OpenMP BLAS: 6
o.n.l.a.o.e.DefaultOpExecutioner - Backend used: [CPU]; OS: [Windows 10]
o.n.l.a.o.e.DefaultOpExecutioner - Cores: [12]; Memory: [4.0GB];
o.n.l.a.o.e.DefaultOpExecutioner - Blas vendor: [OPENBLAS]
o.n.l.c.n.CpuBackend - Backend build information:
GCC: "12.1.0"
STD version: 201103L
DEFAULT_ENGINE: samediff::ENGINE_CPU
HAVE_FLATBUFFERS
HAVE_OPENBLAS
o.d.n.m.MultiLayerNetwork - Starting MultiLayerNetwork with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE]
Predicted label: 8
deeplearning4j训练MNIST数据集以及验证的更多相关文章
- 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集
上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...
- 使用caffe训练mnist数据集 - caffe教程实战(一)
个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...
- 搭建简单模型训练MNIST数据集
# -*- coding = utf-8 -*- # @Time : 2021/3/16 # @Author : pistachio # @File : test1.py # @Software : ...
- MXNet学习-第一个例子:训练MNIST数据集
一个门外汉写的MXNET跑MNIST的例子,三层全连接层最后验证率是97%左右,毕竟是第一个例子,主要就是用来理解MXNet怎么使用. #导入需要的模块 import numpy as np #num ...
- 【Mxnet】----1、使用mxnet训练mnist数据集
使用自己准备的mnist数据集,将0-9的bmp图像分别放到0-9文件夹下,然后用mxnet训练. 1.制作rec数据集 (1).制作list
- TensorFlow 训练MNIST数据集(2)—— 多层神经网络
在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...
- TensorFlow训练MNIST数据集(1) —— softmax 单层神经网络
1.MNIST数据集简介 首先通过下面两行代码获取到TensorFlow内置的MNIST数据集: from tensorflow.examples.tutorials.mnist import inp ...
- pytorch实现MLP并在MNIST数据集上验证
写在前面 由于MLP的实现框架已经非常完善,网上搜到的代码大都大同小异,而且MLP的实现是deeplearning学习过程中较为基础的一个实验.因此完全可以找一份源码以参考,重点在于照着源码手敲一遍, ...
- TensorFlow初探之简单神经网络训练mnist数据集(TensorFlow2.0代码)
from __future__ import print_function from tensorflow.examples.tutorials.mnist import input_data #加载 ...
- TensorFlow训练MNIST数据集(3) —— 卷积神经网络
前面两篇随笔实现的单层神经网络 和多层神经网络, 在MNIST测试集上的正确率分别约为90%和96%.在换用多层神经网络后,正确率已有很大的提升.这次将采用卷积神经网络继续进行测试. 1.模型基本结构 ...
随机推荐
- 元启发式算法库 MEALPY 初体验-遗传算法为例
简介 官网: MealPY官网 开源许可: (GPL) V3 MEALPY简介 官网简介翻译 MEALPY (MEta-heuristic ALgorithms in PYthon) 是一个提供最新自 ...
- WPF 像CSS一样使用 Font Awesome 图标字体
WPF 像CSS一样使用 Font Awesome 图标字体 编写目的 WPF中使用这种图标字体不免会出现可读性差的问题,现阶段网络上有的大部分实现方式都是建立枚举,我感觉这样后续维护起来有些麻烦,需 ...
- SilentEye qsnctf wp
题目附件(注:文件名为Luminous.jpg) 根据题目提示,使用SilentEye工具 将图片使用SilentEye打开 使用左下角的Decode解密功能 猜测密码为文件名,输入并开始解密 将被加 ...
- Excel 字符串拆分
用 Excel 处理数据时,有时需要对字符串进行拆分.对于比较简单的拆分,使用 Excel 函数可以顺利完成,但碰到一些特殊需求,或者拆分的规则比较复杂时,则很难用 Excel 实现了.这里列出一些拆 ...
- 重新整理 .net core 实践篇—————微服务的桥梁EventBus[三十一]
前言 简单介绍一下EventBus. 正文 EventBus 也就是集成事件,用于服务与服务之间的通信. 比如说我们的订单处理事件,当订单处理完毕后,我们如果通过api马上去调用后续接口. 比如说订单 ...
- 面向切面编程AOP[二](java @EnableAspectJAutoProxy 代码原理)
前言 @EnableAspectJAutoProxy 是启动aop功能的意思,那么里面是什么呢? 正文 @Target({ElementType.TYPE}) @Retention(Retention ...
- marquee实现滚动
marquee的基本语法:<marquee> ... </marquee> 参数:1.滚动方向 (direction):left(左).right(右).up(上).down( ...
- 爱奇艺在 Dubbo 生态下的微服务架构实践
简介: 本文整理自作者于 2020 年云原生微服务大会上的分享<爱奇艺在 Dubbo 生态下的微服务架构实践>,重点介绍了爱奇艺在 Dubbo.Sentinel 等开发框架方面的使用经验以 ...
- 5月25日,阿里云开源 PolarDB-X 将迎来重磅升级发布
简介:2022年5月25日,阿里云开源 PolarDB-X 将升级发布新版本!PolarDB-X 从 2009 年开始服务于阿里巴巴电商核心系统, 2015 年开始对外提供商业化服务,并于 2021 ...
- 深度解析|基于 eBPF 的 Kubernetes 一站式可观测性系统
简介:阿里云 Kubernetes 可观测性是一套针对 Kubernetes 集群开发的一站式可观测性产品.基于 Kubernetes 集群下的指标.应用链路.日志和事件,阿里云 Kubernete ...