得益于反向传播算法,神经网络计算导数时非常方便,下面代码中演示如何使用LibTorch进行自动微分求导。

进行自动微分运算需要调用函数

torch::autograd::grad(
outputs, // 为某个可微函数的输出 y=f(x) 中的 y
inputs, // 为某个可微函数的输入 y=f(x) 中的 x
grad_outputs,// 雅克比矩阵(此处计算 f'(x),故设置为1,且与x形状相同 )
retain_graph,// 默认值与 create_graph 相同,这里设置为 true即可
create_graph,// 需要设置为 true 以计算高阶导数
allow_unused // 设置为 false 即可
)

在本文示例中,我们计算 \(y=x^2+x\) 在 \(x = 0.1, 0.3, 0.5\) 处的函数值、一阶导数和二阶导数值,根据我们学到的数学知识,很容易计算出下列数据

\(x\) 0.1 0.3 0.5
\(y\) 0.11 0.39 0.75
\(y'\) 1.20 1.60 2.00
\(y''\) 2.00 2.00 2.00

而在LibTorch中调用自动微分计算导数的代码如下所示

#include <iostream>
#include <torch/torch.h> int main(int argc, char* atgv[])
{
std::cout.setf(std::ios::scientific);
std::cout.precision(7); std::vector<float> vec{0.1, 0.3, 0.5};
torch::Tensor x = torch::from_blob(vec.data(), {3}, torch::kFloat).requires_grad_(true);
torch::Tensor y = x * x + x; // y= x^2 + x
auto weight = torch::ones_like(x); std::cout << "x = ";
for (int i = 0; i < 3; ++i)
std::cout << x[i].item<float>() << " ";
std::cout << std::endl; std::cout << "y = "; // 0.11 0.39 0.75
for (int i = 0; i < 3; ++i)
std::cout << y[i].item<float>() << " ";
std::cout << std::endl; // 计算输出一阶导数(y' = 2x + 1)
auto dydx = torch::autograd::grad({y}, {x}, {weight}, true, true, false);
std::cout << "dydx = "; // 1.2 1.6 2.0
for (int i = 0; i < 3; ++i)
std::cout << dydx[0][i].item<float>() << " ";
std::cout << std::endl; // 计算输出二阶导数(y''= 2)
auto d2ydx2 = torch::autograd::grad({dydx[0]}, {x}, {weight});
std::cout << "d2ydx2 = "; // 2.0 2.0 2.0
for (int i = 0; i < 3; ++i)
std::cout << d2ydx2[0][i].item<float>() << " ";
std::cout << std::endl; return 0;
}

计算结果如下图所示,与我们手动计算的结果一致。

LibTorch 自动微分的更多相关文章

  1. 附录D——自动微分(Autodiff)

    本文介绍了五种微分方式,最后两种才是自动微分. 前两种方法求出了原函数对应的导函数,后三种方法只是求出了某一点的导数. 假设原函数是$f(x,y) = x^2y + y +2$,需要求其偏导数$\fr ...

  2. pytorch学习-AUTOGRAD: AUTOMATIC DIFFERENTIATION自动微分

    参考:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autog ...

  3. 自动微分(AD)学习笔记

    1.自动微分(AD) 作者:李济深链接:https://www.zhihu.com/question/48356514/answer/125175491来源:知乎著作权归作者所有.商业转载请联系作者获 ...

  4. <转>如何用C++实现自动微分

    作者:李瞬生转摘链接:https://www.zhihu.com/question/48356514/answer/123290631来源:知乎著作权归作者所有. 实现 AD 有两种方式,函数重载与代 ...

  5. (转)自动微分(Automatic Differentiation)简介——tensorflow核心原理

    现代深度学习系统中(比如MXNet, TensorFlow等)都用到了一种技术——自动微分.在此之前,机器学习社区中很少发挥这个利器,一般都是用Backpropagation进行梯度求解,然后进行SG ...

  6. PyTorch自动微分基本原理

    序言:在训练一个神经网络时,梯度的计算是一个关键的步骤,它为神经网络的优化提供了关键数据.但是在面临复杂神经网络的时候导数的计算就成为一个难题,要求人们解出复杂.高维的方程是不现实的.这就是自动微分出 ...

  7. 【tensorflow2.0】自动微分机制

    神经网络通常依赖反向传播求梯度来更新网络参数,求梯度过程通常是一件非常复杂而容易出错的事情. 而深度学习框架可以帮助我们自动地完成这种求梯度运算. Tensorflow一般使用梯度磁带tf.Gradi ...

  8. PyTorch 自动微分示例

    PyTorch 自动微分示例 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后训练第一个神经网络.autograd 软件包为 Tensors 上的所有算子提供自动微分 ...

  9. PyTorch 自动微分

    PyTorch 自动微分 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后将会去训练的第一个神经网络.该 autograd 软件包为 Tensors 上的所有操作提供 ...

随机推荐

  1. 一次 Keepalived 高可用的事故,让我重学了一遍它!

    原文首发: 你好,我是悟空. 前言 上次我们遇到了一个 MySQL 故障的事故,这次我又遇到了另外一个奇葩的问题: Keepalived 高可用组件的虚拟 IP 持续漂移,导致 MySQL 主从不断切 ...

  2. colab运行.py文件

    !python split_data.py

  3. Linux for CentOS 下的 keepalived 安装与卸载以及相关命令操作之详细教程

    百度百科解释: keepalived 是一个类似于 layer3, 4 & 7 交换机制的软件,也就是我们平时说的第 3 层.第 4 层和第 7 层交换.Keepalived 的作用是检测 w ...

  4. APISpace 尾号限行API接口 免费好用

    尾号限行是一种为了缓解城市交通压力而催生的交通制度,措施实施以后对城市交通拥堵起到缓解作用.每个地区的尾号限行规定都有所不同,具体的以当地的为准.   尾号限行API,提供已知所有执行限行政策的共计6 ...

  5. SpringBoot到底是什么?

    摘要:Spring Boot是由Pivotal团队提供的全新框架,其设计目的是用来简化新Spring应用的初始搭建以及开发过程. 本文分享自华为云社区<SpringBoot到底是什么?如何理解p ...

  6. 基于ABP实现DDD--领域逻辑和应用逻辑

      本文主要介绍了多应用层的问题,包括原因和实现.通过理解介绍了如何区分领域逻辑和应用逻辑,哪些是正确的实践,哪些是不推荐的或者错误的实践. 一.多应用层的问题 1.多应用层介绍   不知道你们是否会 ...

  7. api.versioning 版本控制 自动识别最高版本和多Area但同名Contoller问题解决办法

    Microsoft.AspNetCore.Mvc.Versioning //引入程序集 .net core 下面api的版本控制作用不需要多说,可以查阅https://www.cnblogs.com/ ...

  8. 2537-springsecurity系列--关于session的管理2-session缓存和共享

    版本信息 <parent> <groupId>org.springframework.boot</groupId> <artifactId>spring ...

  9. sql语句实现每日资格设置

    CREATE TABLE 'iDayAuth'( 'openid' VARCHAR(16) NOT NULL , 'iStamp' INT(10) NOT NULL, 'iDayAuth' SMALL ...

  10. Nginx 限制上传文件的大小。responded with a status of 413 (Request Entity Too Large)

    # 限制请求体的大小,若超过所设定的大小,返回413错误. client_max_body_size 50m; # 读取请求头的超时时间,若超过所设定的大小,返回408错误. client_heade ...