THOR:MindSpore 自研高阶优化器源码分析和实践应用
摘要:这篇文章跟大家分享下THOR的实践应用。THOR算法的部分内容当前已经在MindSpore中开源
本文分享自华为云社区《MindSpore 自研高阶优化器源码分析和实践应用》,原文作者:HWCloudAI 。
这篇文章跟大家分享下THOR的实践应用。THOR算法的部分内容当前已经在MindSpore中开源,源码位置:
https://gitee.com/mindspore/mindspore/blob/master/mindspore/nn/optim/thor.py
MindSpore中使用THOR训练网络非常简单,下面用四行代码先来带大家看一下怎么使用。
from mindspore.nn.optim import THOR #引用二阶优化器 #创建网络
net = Net() #调用优化器
opt = THOR(net, lr, Tensor(damping), config.momentum, config.weight_decay, config.loss_scale,
config.batch_size, split_indices=split_indices) #增加计算图提升性能
model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False,
frequency=config.frequency) #训练网络
model.train(config.epoch_size, dataset, callbacks=cb, sink_size=dataset.get_dataset_size(), dataset_sink_mode=True)
- 导入二阶优化器THOR所需要的包
- 第一行代码常规创建网络
- 第二行代码定义我们使用的优化器THOR
- 第三行代码是为了增加计算图从而使THOR达到更优性能
- 第四行代码训练网络
我们再具体展开介绍下。首先导入MindSpore所需的二阶优化器的包,位于 mindspore.nn.optim
然后创建你所需的网络;接着定义THOR优化器,传入网络信息和THOR所需的超参信息(如学习率,正则化项系数等);
再调用 convert_to_thor_model函数,该函数是通过增加计算图使THOR达到更优性能,什么意思呢,本身网络运行的时候是一张计算图,THOR中会使用过时的二阶信息,通过额外增加一张计算图,两张计算图分别执行更新二阶矩阵和不更新二阶矩阵的操作从而达到更优性能(PS. MindSpore支持动静态图,在这里为了更好的性能使用的是静态图模式,对这块内容比较感兴趣的同学,可以点这个链接:https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/white_paper/MindSpore_white_paper.pdf);
最后,调用model.train就可以开始训练啦。简单介绍了下怎么使用,接下来我们来看下它的源码。
源码分析
init 函数用于THOR的初始化,需要传入THOR所需的超参和网络结构,THOR支持GPU和Ascend,分别为class THOR_GPU(Optimizer)和 class THOR_Ascend(Optimizer),这两个类之间的主要差别是算子不同。下面我们以 class THOR_Ascend(Optimizer)为例,来分析一下。
class THOR_Ascend(Optimizer):
def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
decay_filter=lambda x: x.name not in [], split_indices=None):
params = filter(lambda x: x.requires_grad, net.get_parameters())
super(THOR_Ascend, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum()
self.net = net
self.matrix_A_cov = ParameterTuple(filter(lambda x: 'matrix_A' in x.name, net.get_parameters()))
self.matrix_G_cov = ParameterTuple(filter(lambda x: 'matrix_G' in x.name, net.get_parameters()))
...
MindSpore中所有优化器都继承了 class Optimizer,该基类中定义了一些基本函数(如获取学习率,梯度缩放等)。THOR初始化时将传进去的超参定义为类属性方便调用,并且定义了后续计算会使用到的算子。
也就是说初始化函数的作用就是定义THOR计算所需要用到的算子和变量(Parameter,Tensor等)。
重点介绍下 self.matrix_A_cov , self.matrix_G_cov 。这两个变量是计算二阶梯度所需要的信息,分别为每层输入 的协方差矩阵 和每层输出的一阶导数 的协方差矩阵 ,其中 已经在运行时的前向过程和反向过程中保存下来。
我们再来看下创建THOR时的入参:
- net:本次训练建立的模型;
- learning_rate:学习率超参;
- damping:二阶矩阵中加的正则化项的超参;
- momentum:动量超参;
- weight_decay:权值衰减,用于防止过拟合,默认值为0.0,即不使用权值衰减;loss_scale:用于缩放训练过程中的loss,防止梯度越界,默认值为1.0,即不使用缩放;batch_size:当前训练一个step所使用的数据量,默认为32;
- decay_filter:选择对哪些层做weight decay,当weight_decay>0时起作用;split_indices:这个参数的作用是用于加速allreduce过程。
- _get_Ainv_Ginv_Amax_Gmax_list函数用于计算协方差矩阵A/G的逆,并返回求完逆后的矩阵。具体过程是遍历模型所有层,按层处理,对每一层的协方差矩阵加上正则化项,然后对矩阵进行cholesky分解从而来求逆。当前开源代码THOR中支持全连接层和卷积层的处理。
def _get_Ainv_Ginv_Amax_Gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
matrix_a_max_allreduce, matrix_g_max_allreduce):
"""get matrixA inverse list, matrixG inverse list, matrixA_max list, matrixG_max list"""
for i in range(len(self.params)):
thor_layer_count = self.weight_fim_idx_map[i]
conv_layer_count = self.weight_conv_idx_map[i]
layer_type = self.weight_layerType_idx_map[i]
if layer_type in [Conv, FC, Embedding]:
g = gradients[i]
matrix_A = self.matrix_A_cov[thor_layer_count]
matrix_G = self.matrix_G_cov[thor_layer_count]
matrix_A = F.depend(matrix_A, g)
matrix_G = F.depend(matrix_G, g)
A_shape = self.shape(matrix_A)
A_eye = self.eye(A_shape[0], A_shape[0], mstype.float32)
G_shape = self.shape(matrix_G)
G_eye = self.eye(G_shape[0], G_shape[0], mstype.float32)
if layer_type == Conv:
...
elif layer_type == FC:
matrix_A = matrix_A + damping * A_eye
matrix_A_inv = self.cholesky(matrix_A)
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
- _get_second_gradients函数用于计算最终参数更新方向,在论文中参数更新方向公式为

,所以代码实际实现的方式为

,代码如下
def _get_second_gradients(self, new_grads, damping_step, gradients):
"""get second gradients for thor"""
params_len = len(self.params)
for i in range(params_len):
...
else:
...
elif layer_type == FC:
temp_a = self.matrix_A_cov[thor_layer_count]
temp_g = self.matrix_G_cov[thor_layer_count]
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
construct函数是在网络训练过程中会实际执行的内容,该函数中包含了上述两个函数_get_Ainv_Ginv_Amax_Gmax_list和_get_second_gradients的调用,该函数完成了二阶矩阵的计算和梯度更新方向的调整。
def construct(self, gradients):
params = self.params
moments = self.moments
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
if self.thor:
matrix_A_allreduce = ()
matrix_G_allreduce = ()
matrix_A_max_allreduce = ()
matrix_G_max_allreduce = ()
matrix_A_allreduce, matrix_G_allreduce, matrix_A_max_allreduce, matrix_G_max_allreduce = \
self._get_Ainv_Ginv_Amax_Gmax_list(gradients, damping_step, matrix_A_allreduce, matrix_G_allreduce,
matrix_A_max_allreduce, matrix_G_max_allreduce) #计算A/G的逆
...
new_grads = ()
for i in range(len(self.params)):
...
if self.conv_layer_count > 0:#有卷积层时的处理
...
else: #都是全连接层时的处理
if layer_type == Embedding:
...
elif layer_type == FC:
temp_a = matrix_A_allreduce[thor_layer_count]
temp_g = matrix_G_allreduce[thor_layer_count]
fake_A = self.assign(self.matrix_A_cov[thor_layer_count], temp_a)
fake_G = self.assign(self.matrix_G_cov[thor_layer_count], temp_g)
g = F.depend(g, fake_A)#确保执行顺序
g = F.depend(g, fake_G)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)#将一阶方向变为二阶方向
g = self.cast(g, mstype.float32)
elif layer_type == LayerNorm:
g = self._process_layernorm(damping_step, g)
new_grads = new_grads + (g,)
gradients = new_grads #计算后得到的更新方向
else: #该分支表示使用过时二阶信息更新参数
new_grads = ()
gradients = self._get_second_gradients(new_grads, damping_step, gradients) #调用_get_second_gradients函数计算方向
...
THOR的实践应用
在这一节中跟大家分享下THOR的实践应用,举了两个例子分别为ResNet50和BERT,这两个例子的代码也已开源,链接如下:ResNet50:https://gitee.com/mindspore/mindspore/blob/master/model_zoo/official/cv/resnet/train.pyBERT:https://gitee.com/mindspore/mindspore/blob/master/model_zoo/official/nlp/bert/run_pretrain.py
ResNet50[1]
优化器的调用方式与文中开头提到的一致,在这个例子中把具体训练过程给展开了。
首先创建了网络训练需要的训练集和网络定义为ResNet50;随后设置THOR所需要用到的超参策略,其他超参值设定可去该目录下的src/config.py中修改;接着创建THOR优化器,并传入设置的超参值;然后转换模型保存二阶所需信息;最后就可以训练网络了。
from mindspore.nn.optim import Momentum, THOR #引用二阶优化器
from src.resnet import resnet50 as resnet
from mindspore.train.model import Model
...
if __name__ == '__main__':
...
#创建网络训练过程中的训练集
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
batch_size=config.batch_size, target=target, distribute=args_opt.run_distribute)
step_size = dataset.get_dataset_size() #创建resnet50模型
net = resnet(class_num=config.class_num)
...
# init lr
if cfg.optimizer == "Thor":
#设置超参值
from src.lr_generator import get_thor_lr
lr = get_thor_lr(0, config.lr_init, config.lr_decay, config.lr_end_epoch, step_size, decay_epochs=39)
# define loss, model
if target == "Ascend":
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) #高层抽象,集成网络模型的训练和测试
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
if cfg.optimizer == "Thor" and args_opt.dataset == "imagenet2012":
from src.lr_generator import get_thor_damping #设置超参damping
damping = get_thor_damping(0, config.damping_init, config.damping_decay, 70, step_size) #用于通信时的并行加速
split_indices = [26, 53] #创建THOR优化器
opt = THOR(net, lr, Tensor(damping), config.momentum, config.weight_decay, config.loss_scale,
config.batch_size, split_indices=split_indices) #增加计算图提升性能
model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False,
frequency=config.frequency)
...
#训练网络
model.train(config.epoch_size - config.pretrain_epoch_size, dataset, callbacks=cb,
sink_size=dataset.get_dataset_size(), dataset_sink_mode=dataset_sink_mode)
最后输入

即可运行脚本啦。
BERT[2]
BERT中步骤与ResNet50差不多。首先创建了网络训练需要的训练集和网络定义为BERT;随后设置THOR所需要用到的超参策略,其他超参值设定可去该目录下的src/config.py中修改;优化器创建时传入BERT设定的超参值,本例中创建时传入了:

表示做weight decay操作时排除LN层和FC中的bias参数;然后转换模型保存二阶所需信息;最后就可以训练网络了。
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay, THOR #引用二阶优化器
from src import BertNetworkWithLoss
...
def _get_optimizer(args_opt, network):
"""get bert optimizer, support Lamb, Momentum, AdamWeightDecay."""
if cfg.optimizer == 'Lamb':
...
elif cfg.optimizer == "Thor":
from src.utils import get_bert_thor_lr, get_bert_thor_damping #设置lr和damping的超参值
lr = get_bert_thor_lr(cfg.Thor.lr_max, cfg.Thor.lr_min, cfg.Thor.lr_power, cfg.Thor.lr_total_steps)
damping = get_bert_thor_damping(cfg.Thor.damping_max, cfg.Thor.damping_min, cfg.Thor.damping_power,
cfg.Thor.damping_total_steps)
split_indices = None #设置并行加速方式
if bert_net_cfg.num_hidden_layers == 12:
if bert_net_cfg.use_relative_positions:
split_indices = [29, 58, 87, 116, 145, 174, 203, 217]
else:
split_indices = [28, 55, 82, 109, 136, 163, 190, 205]
elif bert_net_cfg.num_hidden_layers == 24:
if bert_net_cfg.use_relative_positions:
split_indices = [30, 90, 150, 210, 270, 330, 390, 421]
else:
split_indices = [38, 93, 148, 203, 258, 313, 368, 397] #创建优化器
optimizer = THOR(network, lr, damping, cfg.Thor.momentum,
cfg.Thor.weight_decay, cfg.Thor.loss_scale, cfg.batch_size,
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
split_indices=split_indices)
...
return optimizer
def run_pretrain():
...
#创建数据集
ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
#网络和损失函数创建
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True) ...
#加载初始checkpoint
if args_opt.load_checkpoint_path:
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
load_param_into_net(net_with_loss, param_dict) #动态loss缩放
if args_opt.enable_lossscale == "true":
... #固定loss缩放值
else:
#反向过程梯度计算过程创建
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer) #创建网络
model = Model(net_with_grads) #增加计算图提升性能
model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer,
frequency=cfg.Thor.frequency)
#网络训练
model.train(new_repeat_count, ds, callbacks=callback,
dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
if __name__ == '__main__':
set_seed(0)
最后输入

即可运行脚本啦.至此高阶优化器系列的内容就结束啦,该系列总共有三篇文章分别从优化器的背景,MindSpore自研优化器的介绍和MindSpore 高阶优化器THOR 的源码分析&实践应用这三个内容来跟大家分享,如有不足之处欢迎大家批评指正。同时也欢迎大家到MindSpore开源社区中一起玩耍。
参考文献:
[1]He K, Zhang X, Ren S, et al. Deep residual learning for image recognition[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.
[2]Devlin J, Chang M W, Lee K, et al. Bert: Pre-training of deep bidirectional transformers for language understanding[J]. arXiv preprint arXiv:1810.04805, 2018.
THOR:MindSpore 自研高阶优化器源码分析和实践应用的更多相关文章
- MindSpore 高阶优化器
MindSpore 高阶优化器 MindSpore自研优化器THOR(Trace-based Hardware-driven layer-ORiented Natural Gradient Desce ...
- Linux 内核调度器源码分析 - 初始化
导语 上篇系列文 混部之殇-论云原生资源隔离技术之CPU隔离(一) 介绍了云原生混部场景中CPU资源隔离核心技术:内核调度器,本系列文章<Linux内核调度器源码分析>将从源码的角度剖析内 ...
- linux调度器源码分析 - 运行(四)
本文为原创,转载请注明:http://www.cnblogs.com/tolimit/ 引言 之前的文章已经将调度器的数据结构.初始化.加入进程都进行了分析,这篇文章将主要说明调度器是如何在程序稳定运 ...
- linux调度器源码分析 - 初始化(二)
本文为原创,转载请注明:http://www.cnblogs.com/tolimit/ 引言 上期文章linux调度器源码分析 - 概述(一)已经把调度器相关的数据结构介绍了一遍,本篇着重通过代码说明 ...
- 一步步实现windows版ijkplayer系列文章之三——Ijkplayer播放器源码分析之音视频输出——音频篇
一步步实现windows版ijkplayer系列文章之一--Windows10平台编译ffmpeg 4.0.2,生成ffplay 一步步实现windows版ijkplayer系列文章之二--Ijkpl ...
- 一步步实现windows版ijkplayer系列文章之二——Ijkplayer播放器源码分析之音视频输出——视频篇
一步步实现windows版ijkplayer系列文章之一--Windows10平台编译ffmpeg 4.0.2,生成ffplay 一步步实现windows版ijkplayer系列文章之二--Ijkpl ...
- struts2拦截器源码分析
前面博客我们介绍了开发struts2应用程序的基本流程(开发一个struts2的实例),通过前面我们知道了struts2实现请求转发和配置文件加载都是拦截器进行的操作,这也就是为什么我们要在web.x ...
- OkHttp3 拦截器源码分析
OkHttp 拦截器流程源码分析 在这篇博客 OkHttp3 拦截器(Interceptor) ,我们已经介绍了拦截器的作用,拦截器是 OkHttp 提供的对 Http 请求和响应进行统一处理的强大机 ...
- linux调度器源码分析 - 概述(一)
本文为原创,转载请注明:http://www.cnblogs.com/tolimit/ 引言 调度器作为操作系统的核心部件,具有非常重要的意义,其随着linux内核的更新也不断进行着更新.本系列文章通 ...
- mapTask并行度优化及源码分析
mapTask并行度的决定机制 一个job的map阶段并行度由客户端在提交job时决定,而客户端对map阶段并行度的规划的基本逻辑为:将待处理数据执行逻辑切片(即按照一个特定切片大小,将待处理数据划分 ...
随机推荐
- Kubernetes跨StorageClass迁移,切换Rainbond默认SC
基于主机安装或基于Kubernetes安装的 Rainbond 集群(均使用默认参数安装),默认使用的共享文件存储是 NFS ,以 Pod 方式运行在 Kubernetes 中,但这种方式也有一些无法 ...
- P9073 [WC/CTS2023] 楼梯 题解
题目链接 简要题意 有一块楼梯,这里指的楼梯是倒着的,正过来看上一层宽度一定小于等于这一层宽度,并且由格子组成,你需要对其进行增删和恢复某一历史版本的操作,并回答这块楼梯是否有固定格数的子楼梯. 题目 ...
- 每天5分钟复习OpenStack(八)存储虚拟化
KVM存储虚拟化是通过存储池(Storage Pool)和卷(Volume)来管理的.Storage Pool 是宿主机上可以看到的一片存储空间,可以是多种类型,Volume 是在 Storage P ...
- linux其他命令(查找,软链接,打包和压缩,软件安装)笔记
1,查找文件 * 是通配符,代表任意字符,0到多个. find 路径 -name "*.txt" : 查找在路径下所有以 .txt 结尾的文件. 2,软链接 (1)将桌面目 ...
- 《流畅的Python》 读书笔记 第8章_对象引用、可变性和垃圾回收
第8章_对象引用.可变性和垃圾回收 本章的主题是对象与对象名称之间的区别.名称不是对象,而是单独的东西 name = 'wuxianfeng' # name是对象名称 'wuxianfeng'是个st ...
- 从旺店通·企业奇门到用友U8通过接口集成数据
从旺店通·企业奇门到用友U8通过接口集成数据 接入系统:旺店通·企业奇门 慧策(原旺店通)是一家技术驱动型智能零售服务商,基于云计算PaaS.SaaS模式,以一体化智能零售解决方案,帮助零售企业数字化 ...
- 金蝶对接电商ERP库存数据,实现监听库存变化
金蝶云星空实时库存专题 通过向金蝶库存单据注册Python脚本,用于实时监听库存单据审核/反审核,并且将数据发送到轻易云系统集成平台 .通过集成平台将数据分发到对应的目标系统. 向金蝶的库存单据注册脚 ...
- Python 中的单下划线和双下划线
哈喽大家好,我是咸鱼 当我们在学习 Python 的时候,可能会经常遇到单下划线 _ 和双下划线 __ 这两种命名方式 单下划线 _ 和双下划线 __ 不仅仅是只是一种简单的命名习惯,它们在 Pyth ...
- 文心一言 VS 讯飞星火 VS chatgpt (146)-- 算法导论12.2 1题
一.用go语言,假设一棵二叉搜索树中的结点在1到 1000 之间,现在想要查找数值为 363 的结点.下面序列中哪个不是查找过的序列? a.2,252,401,398,330,344,397,363. ...
- windows端口被占用怎么办?
简单只需要按照一下命令查找到对应的端口kill掉就好了 1.查看本机所有的端口信息 netstat -ano 2.查看本机指定端口信息 netstat -ano | findstr "端口号 ...