得益于反向传播算法,神经网络计算导数时非常方便,下面代码中演示如何使用LibTorch进行自动微分求导。

进行自动微分运算需要调用函数

torch::autograd::grad(
outputs, // 为某个可微函数的输出 y=f(x) 中的 y
inputs, // 为某个可微函数的输入 y=f(x) 中的 x
grad_outputs,// 雅克比矩阵(此处计算 f'(x),故设置为1,且与x形状相同 )
retain_graph,// 默认值与 create_graph 相同,这里设置为 true即可
create_graph,// 需要设置为 true 以计算高阶导数
allow_unused // 设置为 false 即可
)

在本文示例中,我们计算 \(y=x^2+x\) 在 \(x = 0.1, 0.3, 0.5\) 处的函数值、一阶导数和二阶导数值,根据我们学到的数学知识,很容易计算出下列数据

\(x\) 0.1 0.3 0.5
\(y\) 0.11 0.39 0.75
\(y'\) 1.20 1.60 2.00
\(y''\) 2.00 2.00 2.00

而在LibTorch中调用自动微分计算导数的代码如下所示

#include <iostream>
#include <torch/torch.h> int main(int argc, char* atgv[])
{
std::cout.setf(std::ios::scientific);
std::cout.precision(7); std::vector<float> vec{0.1, 0.3, 0.5};
torch::Tensor x = torch::from_blob(vec.data(), {3}, torch::kFloat).requires_grad_(true);
torch::Tensor y = x * x + x; // y= x^2 + x
auto weight = torch::ones_like(x); std::cout << "x = ";
for (int i = 0; i < 3; ++i)
std::cout << x[i].item<float>() << " ";
std::cout << std::endl; std::cout << "y = "; // 0.11 0.39 0.75
for (int i = 0; i < 3; ++i)
std::cout << y[i].item<float>() << " ";
std::cout << std::endl; // 计算输出一阶导数(y' = 2x + 1)
auto dydx = torch::autograd::grad({y}, {x}, {weight}, true, true, false);
std::cout << "dydx = "; // 1.2 1.6 2.0
for (int i = 0; i < 3; ++i)
std::cout << dydx[0][i].item<float>() << " ";
std::cout << std::endl; // 计算输出二阶导数(y''= 2)
auto d2ydx2 = torch::autograd::grad({dydx[0]}, {x}, {weight});
std::cout << "d2ydx2 = "; // 2.0 2.0 2.0
for (int i = 0; i < 3; ++i)
std::cout << d2ydx2[0][i].item<float>() << " ";
std::cout << std::endl; return 0;
}

计算结果如下图所示,与我们手动计算的结果一致。

LibTorch 自动微分的更多相关文章

  1. 附录D——自动微分(Autodiff)

    本文介绍了五种微分方式,最后两种才是自动微分. 前两种方法求出了原函数对应的导函数,后三种方法只是求出了某一点的导数. 假设原函数是$f(x,y) = x^2y + y +2$,需要求其偏导数$\fr ...

  2. pytorch学习-AUTOGRAD: AUTOMATIC DIFFERENTIATION自动微分

    参考:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autog ...

  3. 自动微分(AD)学习笔记

    1.自动微分(AD) 作者:李济深链接:https://www.zhihu.com/question/48356514/answer/125175491来源:知乎著作权归作者所有.商业转载请联系作者获 ...

  4. <转>如何用C++实现自动微分

    作者:李瞬生转摘链接:https://www.zhihu.com/question/48356514/answer/123290631来源:知乎著作权归作者所有. 实现 AD 有两种方式,函数重载与代 ...

  5. (转)自动微分(Automatic Differentiation)简介——tensorflow核心原理

    现代深度学习系统中(比如MXNet, TensorFlow等)都用到了一种技术——自动微分.在此之前,机器学习社区中很少发挥这个利器,一般都是用Backpropagation进行梯度求解,然后进行SG ...

  6. PyTorch自动微分基本原理

    序言:在训练一个神经网络时,梯度的计算是一个关键的步骤,它为神经网络的优化提供了关键数据.但是在面临复杂神经网络的时候导数的计算就成为一个难题,要求人们解出复杂.高维的方程是不现实的.这就是自动微分出 ...

  7. 【tensorflow2.0】自动微分机制

    神经网络通常依赖反向传播求梯度来更新网络参数,求梯度过程通常是一件非常复杂而容易出错的事情. 而深度学习框架可以帮助我们自动地完成这种求梯度运算. Tensorflow一般使用梯度磁带tf.Gradi ...

  8. PyTorch 自动微分示例

    PyTorch 自动微分示例 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后训练第一个神经网络.autograd 软件包为 Tensors 上的所有算子提供自动微分 ...

  9. PyTorch 自动微分

    PyTorch 自动微分 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后将会去训练的第一个神经网络.该 autograd 软件包为 Tensors 上的所有操作提供 ...

随机推荐

  1. Json多层级动态结构数据解析

    一.工具 (1)GSON Google Gson是一个简单的基于Java的库,用于将Java对象序列化为JSON,反之亦然. 它是由Google开发的一个开源库. 以下几点说明为什么应该使用这个库 - ...

  2. 分布式事务(Seata) 四大模式详解

    前言 在上一节中我们讲解了,关于分布式事务和seata的基本介绍和使用,感兴趣的小伙伴可以回顾一下<别再说你不知道分布式事务了!> 最后小农也说了,下期会带给大家关于Seata中关于sea ...

  3. SSRS筛选器的IN运算(即包含于)用法

    筛选器的IN运算,在Microsoft的官网上没像样儿的例子,不好设置,很容易错 Microsoft上的文档:https://docs.microsoft.com/zh-cn/sql/reportin ...

  4. gnet: 一个轻量级且高性能的 Go 网络框架 使用笔记

    一个偶然的机会接触到了golang,被它的高并发传说所吸引,就开始学这门语言,越学感觉越有意思^_^ 注册了博客园这么多年,第一次写东西,年纪大了,脑子不好使了,就得写下来,记下来,为了自己以后查阅, ...

  5. Jira7.3.8环境搭建

    安装JDK sudo apt-get install openjdk-8-jdk 安装&配置MySQL sudo apt-get install mysql-server 创建jira用户 # ...

  6. Linux 更改家目录下的目录为英文

    export LANG=en_US xdg-user-dirs-gtk-update

  7. 如何学习Vim

    如果你是Linux用户,学习Vim会有很大的好处. 如果你是windows用户,个人建议还是使用vscode. 准备大约40min的学习时间,打开终端,输入下面命令开启自带教程 vimtutor 按操 ...

  8. 漏洞扫描工具nessus、rapid7 insightvm、openvas安装&简单使用

    Rapid7-insightvm 申请试用 申请地址 邮件地址不能用常用邮件,要使用自己域名的邮件,可以使用这个临时邮箱 手机号随便输入,10位以上 提交后会跳转下载页面 安装 安装:./Rapid7 ...

  9. web 前端 基础HTML知识点

    web系统架构体系 B/S(Browser/Server):浏览器实现 优点: 规范.使用方便.本身实现成本低 容易升级.便于维护 缺点: 没有网络,无法使用 保存数据量有限,和服务器交互频率高.耗费 ...

  10. FPS游戏逆向-方框透视(三角函数)

    本套课程主要学习FPS类游戏安全 由于FPS类游戏本身的特性问题,可能产生一些通用的游戏安全问题 在通过逆向与正向对FPS类游戏分析之后,找到其可能出现的不安全点 才能更好的保护游戏不被外部力量侵犯 ...