When training a network with model.train in MindSpore, it is hard to hook intermittent tasks into the training loop. For this we can turn to MindSpore's Callback mechanism.

A Callback lets us run custom operations after every training step of model.train.

The Callback class:
from mindspore.train.callback import Callback
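
For example, a minimal custom callback (a sketch against the r1.2 Callback interface; the class name PrintStepInfo is just an illustrative choice) only needs to subclass Callback and override the hook it cares about:

class PrintStepInfo(Callback):
    # print the epoch, step and loss after every training step
    def step_end(self, run_context):
        cb_params = run_context.original_args()
        print("epoch:", cb_params.cur_epoch_num,
              "step:", cb_params.cur_step_num,
              "loss:", cb_params.net_outputs)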


In the official documentation, a Callback is typically used to record the loss at every step, or to evaluate the model after a certain number of training steps:
Official tutorial:
https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/quick_start/quick_start.html
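
MindSpore also ships ready-made callbacks such as LossMonitor (print the loss every N steps) and ModelCheckpoint (save checkpoints on a schedule), both of which appear in the example below. A minimal sketch of wiring them into model.train (the prefix, directory and step counts simply reuse the values from the example):

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor

loss_cb = LossMonitor(125)  # print the loss every 125 steps
# save a checkpoint every 375 steps, keep at most 16 checkpoint files
config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
ckpt_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory="./models", config=config_ck)
# pass them (together with any custom callback) to model.train as a list:
# model.train(epoch_size, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=False)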

The full example code:
Reference: https://www.cnblogs.com/devilmaycry812839668/p/14971668.html

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import os

import mindspore.nn as nn
from mindspore.nn import Accuracy
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.common.initializer import Normal
from mindspore import Tensor, Model
from mindspore.train.callback import Callback
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test

    Args:
        data_path (str): Data path
        batch_size (int): The number of data records in each group
        repeat_size (int): The number of replicated data records
        num_parallel_workers (int): The number of parallel workers
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    # define some parameters needed for data enhancement and rough justification
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # according to the parameters, generate the corresponding data enhancement method
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # using map to apply operations to a dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # process the generated dataset
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """Lenet network structure."""
    # define the operator required
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    # use the preceding operators to construct networks
    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# custom callback function
class StepLossAccInfo(Callback):
    def __init__(self, model, eval_dataset, steps_loss, steps_eval):
        self.model = model
        self.eval_dataset = eval_dataset
        self.steps_loss = steps_loss
        self.steps_eval = steps_eval
        self.steps = 0

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # cur_step = (cur_epoch-1)*1875 + cb_params.cur_step_num
        self.steps = self.steps + 1
        cur_step = self.steps
        self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
        self.steps_loss["step"].append(str(cur_step))
        if cur_step % 125 == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            self.steps_eval["acc"].append(acc["Accuracy"])


def train_model(_model, _epoch_size, _repeat_size, _mnist_path, _model_path):
    ds_train = create_dataset(os.path.join(_mnist_path, "train"), 32, _repeat_size)
    eval_dataset = create_dataset(os.path.join(_mnist_path, "test"), 32)

    # save the network model and parameters for subsequent fine-tuning
    config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
    # group layers into an object with training and evaluation features
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=_model_path, config=config_ck)

    steps_loss = {"step": [], "loss_value": []}
    steps_eval = {"step": [], "acc": []}
    # collect the steps, loss and accuracy information
    step_loss_acc_info = StepLossAccInfo(_model, eval_dataset, steps_loss, steps_eval)

    _model.train(_epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125), step_loss_acc_info], dataset_sink_mode=False)

    return steps_loss, steps_eval


epoch_size = 1
repeat_size = 1
mnist_path = "./datasets/MNIST_Data"
model_path = "./models/ckpt/mindspore_quick_start/"

# clean up old run files on Linux
os.system('rm -f {0}*.ckpt {0}*.meta {0}*.pb'.format(model_path))

lr = 0.01
momentum = 0.9

# create the network
network = LeNet5()
# define the optimizer
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
# define the loss function
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define the model
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

steps_loss, steps_eval = train_model(model, epoch_size, repeat_size, mnist_path, model_path)

print(steps_loss, steps_eval)

Run results: (output omitted)



The core code:
from mindspore.train.callback import Callback

# custom callback function
class StepLossAccInfo(Callback):
    def __init__(self, model, eval_dataset, steps_loss, steps_eval):
        self.model = model
        self.eval_dataset = eval_dataset
        self.steps_loss = steps_loss
        self.steps_eval = steps_eval
        self.steps = 0

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # cur_step = (cur_epoch-1)*1875 + cb_params.cur_step_num
        self.steps = self.steps + 1
        cur_step = self.steps
        self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
        self.steps_loss["step"].append(str(cur_step))
        if cur_step % 125 == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            self.steps_eval["acc"].append(acc["Accuracy"])
As you can see, after inheriting from the Callback class we can define new functional classes of our own; implementing the step_end method is all that is required.
The run_context argument passed to step_end by default gives access to the step that just finished and the current epoch, as follows:

cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
cur_step = (cur_epoch-1)*1875 + cb_params.cur_step_num


Here, cb_params.cur_epoch_num is the current epoch number, and
cb_params.cur_step_num is the step number within the current epoch.
Note that cb_params.cur_step_num is not the cumulative step count over the whole run, but the step count within the current epoch; that is why the commented-out formula adds (cur_epoch-1)*1875, where 1875 = 60000/32 is the number of steps per epoch for MNIST's 60000 training images at batch size 32 (with drop_remainder=True).
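
Incidentally, step_end is not the only hook: the Callback base class in r1.2 also defines begin, epoch_begin, epoch_end, step_begin and end, all receiving the same run_context. A sketch of the full set:

class HookDemo(Callback):
    def begin(self, run_context):        # once, before training starts
        pass
    def epoch_begin(self, run_context):  # at the start of every epoch
        pass
    def step_begin(self, run_context):   # before every step
        pass
    def step_end(self, run_context):     # after every step
        pass
    def epoch_end(self, run_context):    # at the end of every epoch
        pass
    def end(self, run_context):          # once, after training finishes
        pass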

The loss value of the current step is also available:
cb_params.net_outputs holds the loss value of the step that just finished.
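
Since net_outputs is a Tensor (the code above simply str()-converts it and later parses it back with float()), it can be cleaner to convert it explicitly. Below is a sketch of a defensive conversion; the tuple branch is an assumption for configurations (e.g. with loss scale) where net_outputs is a tuple whose first element is the loss:

import numpy as np
from mindspore import Tensor

def extract_loss(net_outputs):
    # assumption: in some configurations net_outputs is a tuple, loss first
    if isinstance(net_outputs, (tuple, list)):
        net_outputs = net_outputs[0]
    if isinstance(net_outputs, Tensor):
        return float(np.mean(net_outputs.asnumpy()))
    return float(net_outputs)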


=========================================================

The same code as above, extended with plotting:

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import os

import mindspore.nn as nn
from mindspore.nn import Accuracy
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.common.initializer import Normal
from mindspore import Tensor, Model
from mindspore.train.callback import Callback
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test

    Args:
        data_path (str): Data path
        batch_size (int): The number of data records in each group
        repeat_size (int): The number of replicated data records
        num_parallel_workers (int): The number of parallel workers
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    # define some parameters needed for data enhancement and rough justification
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # according to the parameters, generate the corresponding data enhancement method
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # using map to apply operations to a dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # process the generated dataset
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """Lenet network structure."""
    # define the operator required
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    # use the preceding operators to construct networks
    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# custom callback function
class StepLossAccInfo(Callback):
    def __init__(self, model, eval_dataset, steps_loss, steps_eval):
        self.model = model
        self.eval_dataset = eval_dataset
        self.steps_loss = steps_loss
        self.steps_eval = steps_eval
        self.steps = 0

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # cur_step = (cur_epoch-1)*1875 + cb_params.cur_step_num
        self.steps = self.steps + 1
        cur_step = self.steps
        self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
        self.steps_loss["step"].append(str(cur_step))
        if cur_step % 125 == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            self.steps_eval["acc"].append(acc["Accuracy"])


def train_model(_model, _epoch_size, _repeat_size, _mnist_path, _model_path):
    ds_train = create_dataset(os.path.join(_mnist_path, "train"), 32, _repeat_size)
    eval_dataset = create_dataset(os.path.join(_mnist_path, "test"), 32)

    # save the network model and parameters for subsequent fine-tuning
    config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
    # group layers into an object with training and evaluation features
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=_model_path, config=config_ck)

    steps_loss = {"step": [], "loss_value": []}
    steps_eval = {"step": [], "acc": []}
    # collect the steps, loss and accuracy information
    step_loss_acc_info = StepLossAccInfo(_model, eval_dataset, steps_loss, steps_eval)

    # sink mode stays off so that step_end fires after every single step
    _model.train(_epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125), step_loss_acc_info], dataset_sink_mode=False)

    return steps_loss, steps_eval


epoch_size = 1
repeat_size = 1
mnist_path = "./datasets/MNIST_Data"
model_path = "./models/ckpt/mindspore_quick_start/"

# clean up old run files on Linux
os.system('rm -f {0}*.ckpt {0}*.meta {0}*.pb'.format(model_path))

lr = 0.01
momentum = 0.9

# create the network
network = LeNet5()
# define the optimizer
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
# define the loss function
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define the model
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

steps_loss, steps_eval = train_model(model, epoch_size, repeat_size, mnist_path, model_path)

# plot the per-step loss curve
steps = steps_loss["step"]
loss_value = steps_loss["loss_value"]
steps = list(map(int, steps))
loss_value = list(map(float, loss_value))
plt.plot(steps, loss_value, color="red")
plt.xlabel("Steps")
plt.ylabel("Loss_value")
plt.title("Change chart of model loss value")
plt.show()


# plot the accuracy collected every 125 steps
def eval_show(steps_eval):
    plt.xlabel("step number")
    plt.ylabel("Model accuracy")
    plt.title("Model accuracy variation chart")
    plt.plot(steps_eval["step"], steps_eval["acc"], "red")
    plt.show()


eval_show(steps_eval)



