技术背景

在MindSpore深度学习框架中,我们可以向construct函数传输必备参数或者关键字参数,这跟普通的Python函数没有什么区别。但是对于MindSpore中的自定义反向传播bprop函数,因为标准化格式决定了最后的两位函数输入必须是必备参数outdout用于接收函数值和导数值。那么对于一个自定义的反向传播函数而言,我们有可能要传入多个参数。例如这样的一个案例:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnp class Net(nn.Cell):
def bprop(self, x, y=1, out, dout):
return msnp.cos(x) + y
def construct(self, x, y=1):
return msnp.sin(x) + y x = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0))

但是因为在Python的函数传参规则下,必备参数必须放在关键字参数之前,也就是out和dout这两个参数要放在前面,否则就会出现这样的报错:

  File "test_rand.py", line 53
def bprop(self, x, y=1, out, dout):
^
SyntaxError: non-default argument follows default argument

按照普通Python函数的传参规则,我们可以把y这个关键字参数的放到最后面去:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnp class Net(nn.Cell):
def bprop(self, x, out, dout, y=1):
return msnp.cos(x) + y
def construct(self, x, y=1):
return msnp.sin(x) + y x = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0))

经过这一番调整之后,我们发现没有报错了,可以正常输出结果,但是这个结果似乎不太正常:

[1.0015925] (Tensor(shape=[1], dtype=Float32, value= [ 1.59254798e-03]), Tensor(shape=[1], dtype=Float32, value= [ 1.25169754e-06]))

因为这里x传入了一个近似的\(\pi\),所以在construct函数计算函数值时,得到的结果应该是\(\sin(\pi)+y\),那么这里面\(y\)取\(0\)和\(1\)所得到的结果都是对的。但是关键问题在反向传播函数的计算,原本应该是\(\cos(\pi)+y=y-1\),但是在这里输入的\(y=0\),而导数的计算结果却是\(0\)而不是正确结果\(-1\)。这就说明,在MindSpore的自定义反向传播函数中,并不支持传入关键字参数。

解决方案

刚好前面写了一篇关于PyTorch的文章,这篇文章中提到的两个Issue就针对此类问题。受到这两个Issue的启发,我们在MindSpore中如果需要自定义反向传播函数,可以这么写:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnp class Net(nn.Cell):
def bprop(self, x, y, out, dout):
return msnp.cos(x) + y if y is not None else msnp.cos(x)
def construct(self, x, y=1):
return msnp.sin(x) + y x = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0))

简单来说就是,把原本要传给bprop的关键字参数,转换成必备参数的方式进行传入,然后做一个条件判断:当给定了该输入的时候,执行计算一,如果不给定参数值,或者给一个None,执行计算二。上述代码的执行结果如下所示:

[1.0015925] (Tensor(shape=[1], dtype=Float32, value= [ 1.59254798e-03]), Tensor(shape=[1], dtype=Float32, value= [-9.99998748e-01]))

这里输出的结果都是正确的。

当然,这里因为我们其实是强行把关键字参数按照顺序变成了必备参数进行输入,所以在顺序上一定要严格遵守bprop所定义的必备参数的顺序,否则计算结果也会出错:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnp class Net(nn.Cell):
def bprop(self, x, w, y, out, dout):
return w*msnp.cos(x) + y if y is not None else msnp.cos(x)
def construct(self, x, w=1, y=1):
return msnp.sin(x) + y x = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0, w=2))

输出的结果为:

[1.0015925] (Tensor(shape=[1], dtype=Float32, value= [ 1.59254798e-03]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))

那么很显然,这个结果就是因为在执行函数时给定的关键字参数跟必备参数顺序不一致,所以才出错的。

总结概要

继上一篇文章从Torch的两个Issue中找到一些类似的问题之后,可以发现深度学习框架对于自定义反向传播函数中的传参还是比较依赖于必备参数,而不是关键字参数,MindSpore深度学习框架也是如此。但是我们可以使用一些临时的解决方案,对此问题进行一定程度上的规避,只要能够自定义的传参顺序传入关键字参数即可。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/bprop-kwargs.html

作者ID:DechinPhy

更多原著文章:https://www.cnblogs.com/dechinphy/

请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

参考链接

  1. https://www.cnblogs.com/dechinphy/p/18179248/torch

MindSpore反向传播配置关键字参数的更多相关文章

  1. 深度学习原理与框架-神经网络结构与原理 1.得分函数 2.SVM损失函数 3.正则化惩罚项 4.softmax交叉熵损失函数 5. 最优化问题(前向传播) 6.batch_size(批量更新权重参数) 7.反向传播

    神经网络由各个部分组成 1.得分函数:在进行输出时,对于每一个类别都会输入一个得分值,使用这些得分值可以用来构造出每一个类别的概率值,也可以使用softmax构造类别的概率值,从而构造出loss值, ...

  2. 一个batch的数据如何做反向传播

    一个batch的数据如何做反向传播 对于一个batch内部的数据,更新权重我们是这样做的: 假如我们有三个数据,第一个数据我们更新一次参数,不过这个更新只是在我们脑子里,实际的参数没有变化,然后使用原 ...

  3. 神经网络(NN)+反向传播算法(Backpropagation/BP)+交叉熵+softmax原理分析

    神经网络如何利用反向传播算法进行参数更新,加入交叉熵和softmax又会如何变化? 其中的数学原理分析:请点击这里.

  4. Django---路由系统,URLconf的配置,正则表达式的说明(位置参数),分组命名(捕获关键字参数),传递额外的参数给视图,命名url和url的反向解析,url名称空间

    Django---路由系统,URLconf的配置,正则表达式的说明(位置参数),分组命名(捕获关键字参数),传递额外的参数给视图,命名url和url的反向解析,url名称空间 一丶URLconf配置 ...

  5. 深度学习原理与框架-卷积神经网络基本原理 1.卷积层的前向传播 2.卷积参数共享 3. 卷积后的维度计算 4. max池化操作 5.卷积流程图 6.卷积层的反向传播 7.池化层的反向传播

    卷积神经网络的应用:卷积神经网络使用卷积提取图像的特征来进行图像的分类和识别       分类                        相似图像搜索                        ...

  6. 神经网络(9)--如何求参数: backpropagation algorithm(反向传播算法)

    Backpropagation algorithm(反向传播算法) Θij(l) is a real number. Forward propagation 上图是给出一个training examp ...

  7. <反向传播(backprop)>梯度下降法gradient descent的发展历史与各版本

    梯度下降法作为一种反向传播算法最早在上世纪由geoffrey hinton等人提出并被广泛接受.最早GD由很多研究团队各自发表,可他们大多无人问津,而hinton做的研究完整表述了GD方法,同时hin ...

  8. [源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播

    [源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播 目录 [源码解析] PyTorch 分布式(13) ----- Distribu ...

  9. 一文弄懂神经网络中的反向传播法——BackPropagation

    最近在看深度学习的东西,一开始看的吴恩达的UFLDL教程,有中文版就直接看了,后来发现有些地方总是不是很明确,又去看英文版,然后又找了些资料看,才发现,中文版的译者在翻译的时候会对省略的公式推导过程进 ...

  10. Backpropagation反向传播算法(BP算法)

    1.Summary: Apply the chain rule to compute the gradient of the loss function with respect to the inp ...

随机推荐

  1. list集合中的实现类LinkedList

    LinkedList: 底层是一个双向链表,方便数据的频繁出入.便于快速插入,删除元素,不太方便进行查询 toArray(): 以正确的顺序(从第一个到最后一个素)返回一个包含此列表中所有元素的数组 ...

  2. #dp#洛谷 3244 [HNOI2015]落忆枫音

    题目 分析 每个有入度的点可以选择任意一个父节点组成一棵树,那么原来的答案就是 \(\prod_{i=2}^ndeg[i]\) 现在多了一条边,如果边的终点是1或者它是一个自环那么可以不用管这条边. ...

  3. 2020.02.05【NOIP提高组】模拟A 组

    [toc] CF293B Distinct Paths=JZOJ 4012 CF261E Maxim and Calculator=JZOJ 4010 JZOJ 2292 PPMM 题目 满足队列出入 ...

  4. gRPC入门学习之旅(五)

    gRPC入门学习之旅(一) gRPC入门学习之旅(二) gRPC入门学习之旅(三) gRPC入门学习之旅(四) 通过之前的文章,我们已经创建了gRPC的服务端应用程序,那么应该如何来使用这个服务端应用 ...

  5. JDK13的新特性:AppCDS详解

    目录 简介 基本步骤 JDK class文件归档 创建JDK class-data archive 使用JDK class-data archive启动应用程序 运行时间对比 应用程序class文件归 ...

  6. C++ 异常和错误处理机制:如何使您的程序更加稳定和可靠

    在C++编程中,异常处理和错误处理机制是非常重要的.它们可以帮助程序员有效地处理运行时错误和异常情况.本文将介绍C++中的异常处理和错误处理机制. 什么是异常处理? 异常处理是指在程序执行过程中发生异 ...

  7. HDD与你相约深圳,一起探讨创新开发与运营增长

    12月14日,HUAWEI Developer Day(以下简称HDD)将在深圳与广大开发者见面.本次HDD共设有主论坛.两个分论坛及两个闭门会议,期待各位开发者前来参加. 精彩预告 01·主论坛 在 ...

  8. 爱奇艺携手HMS Core,为用户打造更流畅的沉浸式观影体验

    本文分享于HMS Core开发者论坛<[开发者说]爱奇艺携手HMS Core,为用户打造更流畅.更沉浸的观影体验>采访稿整理. 爱奇艺是国内领先的视频播放平台,通过接入HMS Core H ...

  9. c++ 暂停2秒,等待2秒

    std::chrono::milliseconds stopTime(2000); std::this_thread::sleep_for(stopTime);

  10. mysql交集查询按照时间范围查询myBatis

    查询  开始时间 --结束时间 <if test="searchParam.startTime != null and searchParam.endTime != null" ...