MindSpore多元自动微分
技术背景
当前主流的深度学习框架,除了能够便捷高效的搭建机器学习的模型之外,其自动并行和自动微分等功能还为其他领域的科学计算带来了模式的变革。本文我们将探索如何用MindSpore去实现一个多维的自动微分,并且得到该多元函数的雅可比矩阵。
函数形式与雅可比矩阵形式
首先我们给定一个比较简单的z关于自变量x的函数形式(其中y和I是一些参数):
\]
比如我们考虑一个3*3的z,我们最终需要计算的是这样一个雅可比矩阵:
\left[
\begin{array}{l}
\frac{\partial z_0}{\partial x_0} & \frac{\partial z_0}{\partial x_1} & \frac{\partial z_0}{\partial x_2}\\
\frac{\partial z_1}{\partial x_0} & \frac{\partial z_1}{\partial x_1} & \frac{\partial z_1}{\partial x_2}\\
\frac{\partial z_2}{\partial x_0} & \frac{\partial z_2}{\partial x_1} & \frac{\partial z_2}{\partial x_2}
\end{array}
\right]
\]
假如我们给定一些简单的初始值:
y=[1,3,2]
\]
那么理论上我们应该得到的结果是:
\begin{array}{l}
1 & 0 & 0\\
0 & 0 & 3\\
0 & 2 & 0\\
\end{array}
\right]
\]
接下来我们看看如何在MindSpore的自动微分框架下实现这一功能。
初步尝试Grad自动微分
我们先按照上一章节中的公式的基本内容,直接写一个Net函数用于表示z,然后再用grad函数对其进行微分,代码内容如下所示:
from mindspore import nn, Tensor, ops
from mindspore.ops.functional import grad
import numpy as np
from mindspore import numpy as msnp
class Net(nn.Cell):
def __init__(self, y, index):
super(Net, self).__init__()
self.y = y
self.index = index
self.norm = nn.Norm(-1)
def construct(self, x):
return self.y[self.index]*x
x = Tensor(np.array([1,2,3]).astype(np.float32))
y = Tensor(np.array([[1],[2],[3]]).astype(np.float32))
index = Tensor(np.array([0,2,1]).astype(np.int32))
shape = (y.shape[0], x.shape[0])
output = grad(Net(y,index))(x)
print(output)
# [6. 6. 6.]
在这个案例中,我们得到的结果,首先维度就不对,我们理想中的雅可比矩阵应该是3*3大小的,可见MindSpore中自动微分的逻辑是把其中的一个维度进行了加和,类似于这样的形式:
\frac{\partial z_0}{\partial x_0}+\frac{\partial z_1}{\partial x_0}+\frac{\partial z_2}{\partial x_0}, \frac{\partial z_0}{\partial x_1}+\frac{\partial z_1}{\partial x_1}+\frac{\partial z_2}{\partial x_1}, \frac{\partial z_0}{\partial x_2}+\frac{\partial z_1}{\partial x_2}+\frac{\partial z_2}{\partial x_2}
\right]
\]
所以为了得到我们的结果,需要对输入的x进行扩维。
尝试扩维输入的自动微分
在MindSpore中提供了BroadcastTo这样的接口,可以自动的在扩展维度填充待扩展张量的元素,我们需要把x的最外层维度扩展到与参数y一致,在这个案例中就是3*3的维度,具体代码实现如下所示:
from mindspore import nn, Tensor, ops
from mindspore.ops.functional import grad
import numpy as np
from mindspore import numpy as msnp
class Net(nn.Cell):
def __init__(self, y, index):
super(Net, self).__init__()
self.y = y
self.index = index
self.norm = nn.Norm(-1)
def construct(self, x):
return self.y[self.index]*x
x = Tensor(np.array([1,2,3]).astype(np.float32))
y = Tensor(np.array([[1],[2],[3]]).astype(np.float32))
index = Tensor(np.array([0,2,1]).astype(np.int32))
shape = (y.shape[0], x.shape[0])
output = grad(Net(y,index))(ops.BroadcastTo(shape)(x))
print(output)
'''
[[1. 1. 1.]
[3. 3. 3.]
[2. 2. 2.]]
'''
从这个输出结果中我们发现,虽然维度上是被扩展成功了,但是那些本该为0的位置却出现了非0元素,这说明在自动微分计算的过程中,我们输入的参数y也被自动的Broadcast了,而实际上正确的计算过程中是不能使用Broadcast的。
为参数添加Mask
上一个章节中说道,如果利用Tensor本身的自动Broadcast会导致输入参数被扩维,会得到一个错误的微分结果。因此这里我们手动对输入参数进行正确的扩维,这个过程是添加一个Mask矩阵,用于标记每一个参数所对应的位置。这里我们假设输入一个这样的Mask矩阵:
\begin{array}{l}
1 & 0 & 0\\
0 & 0 & 1\\
0 & 1 & 0
\end{array}
\right]
\]
这样理论上最终微分结果的非0元素应该跟这个矩阵是一致的,相关代码如下所示:
from mindspore import nn, Tensor, ops
from mindspore.ops.functional import grad
import numpy as np
from mindspore import numpy as msnp
class Net(nn.Cell):
def __init__(self, y, index, size):
super(Net, self).__init__()
self.y = y
self.index = index
self.norm = nn.Norm(-1)
self.mask = msnp.zeros((y.shape[0],size))
self.mask[msnp.arange(self.index.shape[0]),self.index] = 1
def construct(self, x):
return self.mask*self.y[self.index]*x
x = Tensor(np.array([1,2,3]).astype(np.float32))
y = Tensor(np.array([[1],[2],[3]]).astype(np.float32))
index = Tensor(np.array([0,2,1]).astype(np.int32))
shape = (y.shape[0], x.shape[0])
output = grad(Net(y,index,x.shape[0]))(ops.BroadcastTo(shape)(x))
print(output)
'''
[[1. 0. 0.]
[0. 0. 3.]
[0. 2. 0.]]
'''
这里我们看到得到的结果就是正确的了。当然,需要说明的是,虽然这个案例只是非常简单的内容,但是这里给出的如何去计算多维函数的自动微分的方法,同样也适用于一些更加复杂的网络和函数。
总结概要
在本文中通过一个实际函数案例的多次尝试,给出了得到预期结果的一种解决方案。虽然MindSpore框架本身提供了Jvp和Vjp等功能,但是实际上和Grad没有太大的区别,只是用Tuple的形式增加了输入的一个维度。如果可以使用纯Tensor的输入,用这种Mask加上Grad或者GradOperation的方案会更加简单一些。同时我也尝试过使用HyperMap(类似于Jax中的vmap)来解决这个问题,只需要写好一条对z求导的函数形式,就可以自动对这个求导过程进行扩维,两者的结果是一致的。但是MindSpore的HyperMap在Graph模式下兼容效果不是很好,建议非必要不尝试。
版权声明
本文首发链接为:https://www.cnblogs.com/dechinphy/p/jvp.html
作者ID:DechinPhy
更多原著文章请参考:https://www.cnblogs.com/dechinphy/
打赏专用链接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html
腾讯云专栏同步:https://cloud.tencent.com/developer/column/91958
MindSpore多元自动微分的更多相关文章
- MindSpore:自动微分
MindSpore:自动微分 作为一款「全场景 AI 框架」,MindSpore 是人工智能解决方案的重要组成部分,与 TensorFlow.PyTorch.PaddlePaddle 等流行深度学习框 ...
- 附录D——自动微分(Autodiff)
本文介绍了五种微分方式,最后两种才是自动微分. 前两种方法求出了原函数对应的导函数,后三种方法只是求出了某一点的导数. 假设原函数是$f(x,y) = x^2y + y +2$,需要求其偏导数$\fr ...
- pytorch学习-AUTOGRAD: AUTOMATIC DIFFERENTIATION自动微分
参考:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autog ...
- 自动微分(AD)学习笔记
1.自动微分(AD) 作者:李济深链接:https://www.zhihu.com/question/48356514/answer/125175491来源:知乎著作权归作者所有.商业转载请联系作者获 ...
- <转>如何用C++实现自动微分
作者:李瞬生转摘链接:https://www.zhihu.com/question/48356514/answer/123290631来源:知乎著作权归作者所有. 实现 AD 有两种方式,函数重载与代 ...
- (转)自动微分(Automatic Differentiation)简介——tensorflow核心原理
现代深度学习系统中(比如MXNet, TensorFlow等)都用到了一种技术——自动微分.在此之前,机器学习社区中很少发挥这个利器,一般都是用Backpropagation进行梯度求解,然后进行SG ...
- PyTorch自动微分基本原理
序言:在训练一个神经网络时,梯度的计算是一个关键的步骤,它为神经网络的优化提供了关键数据.但是在面临复杂神经网络的时候导数的计算就成为一个难题,要求人们解出复杂.高维的方程是不现实的.这就是自动微分出 ...
- 【tensorflow2.0】自动微分机制
神经网络通常依赖反向传播求梯度来更新网络参数,求梯度过程通常是一件非常复杂而容易出错的事情. 而深度学习框架可以帮助我们自动地完成这种求梯度运算. Tensorflow一般使用梯度磁带tf.Gradi ...
- PyTorch 自动微分示例
PyTorch 自动微分示例 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后训练第一个神经网络.autograd 软件包为 Tensors 上的所有算子提供自动微分 ...
随机推荐
- [ unittest ] 文档粗读
参考: https://blog.csdn.net/ljl6158999/article/details/80994979 1.概念提出 unittest最初灵感来自于Junit,它有着和其他单元测试 ...
- c# - 常量定义与赋值
1.前言 c#与Java很相似,但是不一样,又与js(JavaScript)相似,但是也不一样,所以我认为c#是Java和 js的孩子. 2.常量定义 字符串: const string = &quo ...
- [网络编程] 自己构建一个cgi.FieldStorage()的对象
问题描述: 通常cgi.FieldStorage()返回一个类似于Python字典的对象. 在cgi框架中必须通过浏览器发送表单过来才能接受消息 那么我该怎么进行本地调试呢? 或者说在没有搭建好一整套 ...
- spring cloud 与spring boot的版本对应总结
1.前言 一开始不理解为什么使用 spring boot 高版本 ,却没有对应的spring cloud版本 ,还以为最高版本的 spring cloud 会向上兼容 . 这个坑 ,没有人告诉我, ...
- 网络协议学习笔记(八)DNS协议和HttpDNS协议
概述 上一篇主要讲解了流媒体协议和p2p协议,现在我给大家讲解一下关于DNS和HttpDNS的相关知识. DNS协议:网络世界的地址簿 在网络世界,也是这样的.你肯定记得住网站的名称,但是很难记住网站 ...
- 浅解XXE与Portswigger Web Sec
XXE与Portswigger Web Sec 相关链接: 博客园 安全脉搏 FreeBuf 简介XML XML,可扩展标记语言,标准通用标记语言的子集.XML的简单易于在任何应用程序 ...
- [STM32F4xx 学习] SPI与nRF24L01+的应用
前面已经总结过STM32Fxx的特点和传输过程,下面以nRF24L01+ 2.4GHz无线收发器为例,来说明如何使用SPI. 一.nRF24L01+ 2.4GHz无线收发器的介绍 1. 主要特性 全球 ...
- 【VictoriaMetrics】vm单机版和vm-storage的查询功能的对比
1.vm-storage源码调用表 文件 行号 函数 说明 app/vmstorage/main.go 53 main 入口94行调用srv.RunVMSelect() app/vmstorage/t ...
- 【代码分享】用redis+lua实现多个集合取交集并过滤,类似于: select key from set2 where key in (select key from set1) and value>=xxx
redis中的zset结构可以看成一个个包含数值的集合,或者认为是一个关系数据库中用列存储方式存储的一列. 需求 假设我有这样一个数据筛选需求,用SQL表示为: select key from set ...
- 在linux下编译android下的opencv,使用cmake的方法
#前一篇帖子实验了build_sdk.py来编译opencv,失败了.#本篇尝试使用cmake来编译#感谢这篇帖子提供的指导:https://www.cnblogs.com/jojodru/p/100 ...