MindSpore 计算框架 模型参数 和 优化器 参数的重新载入
本文主要内容源于:
======================================================================
本地加载模型
用于推理验证
针对仅推理场景可以使用load_checkpoint把参数直接加载到网络中,以便进行后续的推理验证。
示例代码如下:
resnet = ResNet50()
load_checkpoint("resnet50-2_32.ckpt", net=resnet)
dateset_eval = create_dataset(os.path.join(mnist_path, "test"), 32, 1) # define the test dataset
loss = CrossEntropyLoss()
model = Model(resnet, loss, metrics={"accuracy"})
acc = model.eval(dataset_eval)
load_checkpoint方法会把参数文件中的网络参数加载到模型中。加载后,网络中的参数就是CheckPoint保存的。eval方法会验证训练后模型的精度。
用于迁移学习
针对任务中断再训练及微调(Fine Tune)场景,可以加载网络参数和优化器参数到模型中。
示例代码如下:
# return a parameter dict for model
param_dict = load_checkpoint("resnet50-2_32.ckpt")
resnet = ResNet50()
opt = Momentum(resnet.trainable_params(), 0.01, 0.9)
# load the parameter into net
load_param_into_net(resnet, param_dict)
# load the parameter into optimizer
load_param_into_net(opt, param_dict)
loss = SoftmaxCrossEntropyWithLogits()
model = Model(resnet, loss, opt)
model.train(epoch, dataset)
load_checkpoint方法会返回一个参数字典。load_param_into_net会把参数字典中相应的参数加载到网络或优化器中。
================================================================
由上面内容可以知道,以下两个函数:
load_checkpoint
load_param_into_net
可以把保存为ckpt文件中的参数重新加载到网络和优化器中。
给出demo, 数据文件下载参考前文:
模型参数 和 优化器参数的保存:
#!/usr/bin python
# encoding:UTF-8 """" 对输入的超参数进行处理 """
import os
import argparse """ 设置运行的背景context """
from mindspore import context """ 对数据集进行预处理 """
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype """ 构建神经网络 """
import mindspore.nn as nn
from mindspore.common.initializer import Normal """ 训练时对模型参数的保存 """
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig """ 导入模型训练需要的库 """
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore import Model import os
os.system('rm -f *.ckpt *.meta') parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU']) args = parser.parse_known_args()[0] # 为mindspore设置运行背景context
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
# 定义数据集
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081 # 定义所需要操作的map映射
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32) # 使用map映射函数,将数据操作应用到数据集
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers) # 进行shuffle、batch、repeat操作
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size) return mnist_ds class LeNet5(nn.Cell):
"""
Lenet网络结构
""" def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
# 定义所需要的运算
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten() def construct(self, x):
# 使用定义好的运算构建前向网络
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x # 实例化网络
net = LeNet5() # 定义损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') # 定义优化器
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9) # 设置模型保存参数
# 每125steps保存一次模型参数,最多保留15个文件
config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=15)
# 应用模型保存参数
ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
"""定义训练的方法"""
# 加载训练数据集
ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode) def test_net(network, model, data_path):
"""定义验证的方法"""
ds_eval = create_dataset(os.path.join(data_path, "test"))
acc = model.eval(ds_eval, dataset_sink_mode=False)
print("{}".format(acc)) mnist_path = "./datasets/MNIST_Data"
train_epoch = 1
dataset_size = 1
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)
test_net(net, model, mnist_path)
生成的参数文件:

其中, ckpt 类型的文件保存的是 网络参数 和 优化器参数, 而 .meta 文件保存的是计算图的编译后的文件,不过 meta 文件具体怎么用这里还是不了解的,具体深入关注可以参考帖子:
https://bbs.huaweicloud.com/forum/forum.php?mod=viewthread&tid=138966&page=1#pid1240965
在网络和优化器初始化后(不载入备份的网络参数 和 优化器参数情况下), 打印优化器的最后一个参数, moments.fc3.bias :
import os
import numpy as np """ 构建神经网络 """
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import Tensor # 导入模型参数
from mindspore.train.serialization import load_checkpoint, load_param_into_net """ 对数据集进行预处理 """
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype """ 导入模型训练需要的库 """
from mindspore.nn import Accuracy
from mindspore import Model
from mindspore import context context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
# 定义数据集
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081 # 定义所需要操作的map映射
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32) # 使用map映射函数,将数据操作应用到数据集
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers) # 进行shuffle、batch、repeat操作
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size) return mnist_ds class LeNet5(nn.Cell):
"""
Lenet网络结构
"""
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
# 定义所需要的运算
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten() def construct(self, x):
# 使用定义好的运算构建前向网络
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x # 实例化网络
net = LeNet5()
# 定义损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# 定义优化器
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
# 构建模型
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) # 加载已经保存的用于测试的模型
param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
# 加载参数到网络中
load_param_into_net(net, param_dict)
# 加载参数到优化器中
#load_param_into_net(net_opt, param_dict) _batch_size = 8
# 定义测试数据集,batch_size设置为1,则取出一张图片
mnist_path = "./datasets/MNIST_Data"
ds_test = create_dataset(os.path.join(mnist_path, "test"), batch_size=_batch_size)
print(model.eval(ds_test)) print(type(net.parameters_and_names()))
for i, j in net_opt.parameters_and_names():
print(i)
if i == "moments.fc3.bias":
print(Tensor(j))
运行结果:
WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.
[WARNING] ME(13133:139644169384064,MainProcess):2021-07-12-03:29:50.183.802 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.
{'Accuracy': 0.9594}
<class 'generator'>
learning_rate
conv1.weight
conv2.weight
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias
momentum
moments.conv1.weight
moments.conv2.weight
moments.fc1.weight
moments.fc1.bias
moments.fc2.weight
moments.fc2.bias
moments.fc3.weight
moments.fc3.bias
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
可以看到,优化器的最后一个参数为全0, 那么载入备份的参数后呢:
import os
import numpy as np """ 构建神经网络 """
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import Tensor # 导入模型参数
from mindspore.train.serialization import load_checkpoint, load_param_into_net """ 对数据集进行预处理 """
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype """ 导入模型训练需要的库 """
from mindspore.nn import Accuracy
from mindspore import Model
from mindspore import context context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
# 定义数据集
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081 # 定义所需要操作的map映射
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32) # 使用map映射函数,将数据操作应用到数据集
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers) # 进行shuffle、batch、repeat操作
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size) return mnist_ds class LeNet5(nn.Cell):
"""
Lenet网络结构
"""
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
# 定义所需要的运算
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten() def construct(self, x):
# 使用定义好的运算构建前向网络
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x # 实例化网络
net = LeNet5()
# 定义损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# 定义优化器
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
# 构建模型
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) # 加载已经保存的用于测试的模型
param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
# 加载参数到网络中
load_param_into_net(net, param_dict)
# 加载参数到优化器中
load_param_into_net(net_opt, param_dict) _batch_size = 8
# 定义测试数据集,batch_size设置为1,则取出一张图片
mnist_path = "./datasets/MNIST_Data"
ds_test = create_dataset(os.path.join(mnist_path, "test"), batch_size=_batch_size)
print(model.eval(ds_test)) print(type(net.parameters_and_names()))
for i, j in net_opt.parameters_and_names():
print(i)
if i == "moments.fc3.bias":
print(Tensor(j))
运行结果:
WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.
[WARNING] ME(13292:140444628824192,MainProcess):2021-07-12-03:31:50.471.228 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.
{'Accuracy': 0.9594}
<class 'generator'>
learning_rate
conv1.weight
conv2.weight
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias
momentum
moments.conv1.weight
moments.conv2.weight
moments.fc1.weight
moments.fc1.bias
moments.fc2.weight
moments.fc2.bias
moments.fc3.weight
moments.fc3.bias
[-0.00917954 0.00276246 -0.01406308 0.01492264 -0.01100682 -0.0692124
0.02251344 0.00341095 0.03600671 0.02384563]
可以看到,载入备份的优化器参数后,打印结果与之前不同了。
MindSpore 计算框架 模型参数 和 优化器 参数的重新载入的更多相关文章
- PyTorch官方中文文档:torch.optim 优化器参数
内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...
- Oracle 查看相关优化器参数
select x.ksppinm name, y.ksppstvl value, y.ksppstdf isdefault, decode(bitand(y.ksppstvf, 7), 1, 'MOD ...
- keras RAdam优化器使用教程, keras加载模型包含自定义优化器报错 如何解决?
本文首发于个人博客https://kezunlin.me/post/c691f02b/,欢迎阅读最新内容! python keras RAdam tutorial and load custom op ...
- [源码解析] PyTorch分布式优化器(1)----基石篇
[源码解析] PyTorch分布式优化器(1)----基石篇 目录 [源码解析] PyTorch分布式优化器(1)----基石篇 0x00 摘要 0x01 从问题出发 1.1 示例 1.2 问题点 0 ...
- [源码解析] PyTorch分布式优化器(2)----数据并行优化器
[源码解析] PyTorch分布式优化器(2)----数据并行优化器 目录 [源码解析] PyTorch分布式优化器(2)----数据并行优化器 0x00 摘要 0x01 前文回顾 0x02 DP 之 ...
- optimizer_mode优化器模式
查询优化器最主要的工作就是接受输入的SQL以及各种环境参数.配置参数,生成合适的SQL执行计划(Execution Plan). Query Optimizer一共经历了两个历史阶段: RBO: Ru ...
- Python带参数的装饰器
在装饰器函数里传入参数 # -*- coding: utf-8 -*- # 2017/12/2 21:38 # 这不是什么黑魔法,你只需要让包装器传递参数: def a_decorator_passi ...
- MindSpore 高阶优化器
MindSpore 高阶优化器 MindSpore自研优化器THOR(Trace-based Hardware-driven layer-ORiented Natural Gradient Desce ...
- [源码解析] PyTorch分布式优化器(3)---- 模型并行
[源码解析] PyTorch分布式优化器(3)---- 模型并行 目录 [源码解析] PyTorch分布式优化器(3)---- 模型并行 0x00 摘要 0x01 前文回顾 0x02 单机模型 2.1 ...
- QuantLib 金融计算——数学工具之优化器
目录 QuantLib 金融计算--数学工具之优化器 概述 Optimizer Constraint OptimizationMethod EndCriteria 示例 Rosenbrock 问题 校 ...
随机推荐
- SRE Google 运维解密读书笔记一:SRE 方法论概述
SRE Google 运维解密,是 SRE 领域的启蒙之作,讲述了 Google 的 SRE 实践,SRE 就是从 Google 流传出来的.本文是读书笔记,第一篇,概述 SRE 方法论.帮大家把书读 ...
- 增补博客 第五篇 python 电子算盘
[题目描述]设计一个电子算盘.要求绘制电子算盘界面,设计并实现打珠算过程(界面参考如下图示).界面右侧要求以图形绘制的方式绘制自画像,注意不能是图像文件显示的形式. 图 电子算盘参考界面示意 [练习要 ...
- 订单号规则,不能重复。redis去重 redis集合set应用
订单号规则,不能重复.redis去重 redis集合set应用 redis锁定商品解决并发售卖问题 RedisUtil工具类https://www.cnblogs.com/oktokeep/p/179 ...
- golang 所有关键字的列表及释义归类
golang 所有关键字的列表及释义归类,截至1.18版本. [控制结构] if : 条件语句,基于布尔表达式的值决定是否执行特定的代码块. else. else if : 用在 if 语句 ...
- 『手写Mybatis』实现映射器的注册和使用
前言 如何面对复杂系统的设计? 我们可以把 Spring.MyBatis.Dubbo 这样的大型框架或者一些公司内部的较核心的项目,都可以称为复杂的系统. 这样的工程也不在是初学编程手里的玩具项目,没 ...
- MoneyPrinterPlus:AI自动短视频生成工具-阿里云配置详解
MoneyPrinterPlus是一个很好的自动短视频生成工具,虽然是一个非常好的工具,但是有些小伙伴可能不太清楚具体应该如何配置才能让它跑起来. 因为MoneyPrinterPlus依赖一些具体的配 ...
- 使用Microsoft.SemanticKernel基于本地运行的Ollama大语言模型实现Agent调用函数
大语言模型的发展日新月异,记得在去年这个时候,函数调用还是gpt-4的专属.到今年本地运行的大模型无论是推理能力还是文本的输出质量都已经非常接近gpt-4了.而在去年gpt-4尚未发布函数调用时,智能 ...
- 基于 JuiceFS 构建高校 AI 存储方案:高并发、系统稳定、运维简单
中山大学的 iSEE 实验室(Intelligence Science and System) Lab)在进行深度学习任务时,需要处理大量小文件读取.在高并发读写场景下,原先使用的 NFS 性能较低, ...
- C#/.NET这些实用的技巧和知识点你都知道吗?
前言 今天大姚给大家分享一些C#/.NET中的实用的技巧和知识点,它们可以帮助我们提升代码质量和编程效率,希望可以帮助到有需要的同学. .NET使用CsvHelper快速读取和写入CSV文件 本文主要 ...
- Linux Driver : i2c-gpio
# Linux Driver : i2c-gpio https://www.cnblogs.com/haoxing990/p/4718834.html https://blog.csdn.net/ji ...