As the title says, the official tutorial is at:

https://www.mindspore.cn/tutorial/zh-CN/r1.2/quick_start.html

Download the dataset:

mkdir -p ./datasets/MNIST_Data/train ./datasets/MNIST_Data/test
wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte
wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte
wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte
wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte
tree ./datasets/MNIST_Data
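
If wget is not available, the same four files can be fetched with a short Python script. This is only a convenience sketch that mirrors the wget commands above (same URLs and target directories):

import os
import urllib.request

# Mirror of the wget commands above: fetch the four MNIST files into
# ./datasets/MNIST_Data/train and ./datasets/MNIST_Data/test if missing.
base_url = "https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/"
files = {
    "train": ["train-labels-idx1-ubyte", "train-images-idx3-ubyte"],
    "test": ["t10k-labels-idx1-ubyte", "t10k-images-idx3-ubyte"],
}
for split, names in files.items():
    split_dir = os.path.join("./datasets/MNIST_Data", split)
    os.makedirs(split_dir, exist_ok=True)
    for name in names:
        target = os.path.join(split_dir, name)
        if not os.path.exists(target):
            urllib.request.urlretrieve(base_url + name, target)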

The code I consolidated from the tutorial:

#!/usr/bin/env python
# encoding: UTF-8

# Handle the input hyper-parameters
import os
import argparse

# Set the running context
from mindspore import context

# Dataset preprocessing
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# Build the neural network
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# Save the model parameters during training
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# Libraries needed for model training
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore import Model

parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
args = parser.parse_known_args()[0]

# Set the running context for MindSpore
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    # Define the dataset
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Normalization constants (MNIST mean 0.1307 and standard deviation 0.3081)
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Define the map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the operations to the dataset with map
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # Shuffle, batch and repeat
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """
    LeNet network structure
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # Define the required operations
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # Build the forward network from the operations defined above
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# Instantiate the network
net = LeNet5()
# Define the loss function
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the optimizer
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# Checkpoint configuration:
# save the model parameters every 125 steps and keep at most 15 checkpoint files
config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=15)
# Apply the checkpoint configuration
ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)


def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """Define the training procedure"""
    # Load the training dataset
    ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)


def test_net(network, model, data_path):
    """Define the evaluation procedure"""
    ds_eval = create_dataset(os.path.join(data_path, "test"))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("{}".format(acc))


mnist_path = "./datasets/MNIST_Data"
train_epoch = 1
dataset_size = 1
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)
test_net(net, model, mnist_path)
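
The script above is self-contained. Assuming it is saved as, say, lenet_mnist.py (the filename is arbitrary), it can be run from the directory that contains ./datasets, and the device can be switched with the --device_target argument defined by the parser:

python lenet_mnist.py --device_target=CPU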

Training results:

epoch: 1 step: 125, loss is 2.2982173
epoch: 1 step: 250, loss is 2.296105
epoch: 1 step: 375, loss is 2.3065567
epoch: 1 step: 500, loss is 2.3062077
epoch: 1 step: 625, loss is 2.3096561
epoch: 1 step: 750, loss is 2.2847052
epoch: 1 step: 875, loss is 2.284628
epoch: 1 step: 1000, loss is 1.8122461
epoch: 1 step: 1125, loss is 0.4140602
epoch: 1 step: 1250, loss is 0.25238502
epoch: 1 step: 1375, loss is 0.17819008
epoch: 1 step: 1500, loss is 0.3202765
epoch: 1 step: 1625, loss is 0.12312577
epoch: 1 step: 1750, loss is 0.11027573
epoch: 1 step: 1875, loss is 0.2680659
{'Accuracy': 0.9598357371794872}

Load the trained parameters into the network and run prediction:

This step depends on the training above: it reuses the dataset prepared earlier and the network parameters that have just been trained.

import os
import numpy as np

# Build the neural network
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import Tensor

# Load the model parameters
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Dataset preprocessing
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# Libraries needed for model training
from mindspore.nn import Accuracy
from mindspore import Model


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    # Define the dataset
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Normalization constants (MNIST mean 0.1307 and standard deviation 0.3081)
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Define the map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the operations to the dataset with map
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # Shuffle, batch and repeat
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """
    LeNet network structure
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # Define the required operations
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # Build the forward network from the operations defined above
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# Instantiate the network
net = LeNet5()
# Define the loss function
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the optimizer
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
# Build the model
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

# Load the checkpoint saved during training
param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
# Load the parameters into the network
load_param_into_net(net, param_dict)

_batch_size = 8
# Define the test dataset; batch_size=1 would fetch a single image,
# here a batch of 8 images is taken
mnist_path = "./datasets/MNIST_Data"
ds_test = create_dataset(os.path.join(mnist_path, "test"), batch_size=_batch_size).create_dict_iterator()
data = next(ds_test)

# images are the test images, labels are their actual classes
images = data["image"].asnumpy()
labels = data["label"].asnumpy()

# Use model.predict to predict the class of each image
output = model.predict(Tensor(data['image']))
predicted = np.argmax(output.asnumpy(), axis=1)

# Print the predicted class versus the actual class
for i in range(_batch_size):
    print(f'Predicted: "{predicted[i]}", Actual: "{labels[i]}"')

Run results:
