注意:本文相关基础知识不介绍。

给出代码:

from jax import jacfwd, jacrev
import jax.numpy as jnp
def hessian_1(f):
return jacfwd(jacrev(f)) def hessian_2(f):
return jacfwd(jacfwd(f)) def hessian_3(f):
return jacrev(jacfwd(f)) def hessian_4(f):
return jacrev(jacrev(f)) def f(x):
return (x ** 2).sum() print(hessian_1(f)(jnp.ones((100,))))
print(hessian_2(f)(jnp.ones((100,))))
print(hessian_3(f)(jnp.ones((100,))))
print(hessian_4(f)(jnp.ones((100,)))) import time a=time.time()
hessian_1(f)(jnp.ones((100,)))
b=time.time()
print(b-a) hessian_2(f)(jnp.ones((100,)))
c=time.time()
print(c-b) hessian_3(f)(jnp.ones((100,)))
d=time.time()
print(d-b) hessian_4(f)(jnp.ones((100,)))
e=time.time()
print(e-d)

运算结果:

结论(不一定正确):

两次求导均使用后向模式的要比两次求导均使用前向模式的要速度快,并且两次求导使用相同模式的要比两次求导分别使用不同模式的速度要快;

第一次求导使用后向模式,第二次求导使用前向模式,要比第一次求导使用前向模式,第二次求导使用反向模式的速度要快。

修改代码:

from jax import jacfwd, jacrev
import jax.numpy as jnp
from jax import jit def hessian_1(f):
return jacfwd(jacrev(f)) def hessian_2(f):
return jacfwd(jacfwd(f)) def hessian_3(f):
return jacrev(jacfwd(f)) def hessian_4(f):
return jacrev(jacrev(f)) @jit
def f(x):
return (x ** 2).sum() x = jnp.ones((100,)) print(hessian_1(f)(x))
print(hessian_2(f)(x))
print(hessian_3(f)(x))
print(hessian_4(f)(x)) import time a=time.time()
hessian_1(f)(x)
b=time.time()
print(b-a) hessian_2(f)(x)
c=time.time()
print(c-b) hessian_3(f)(x)
d=time.time()
print(d-b) hessian_4(f)(x)
e=time.time()
print(e-d)

运算结果:

得出另一种结论(之所以上下两次结论不同,个人估计是这个函数太过于简单造成的):

(不一定正确)

两次求导均使用后向模式的要比两次求导均使用前向模式的要速度慢;

第一次求导使用后向模式,第二次求导使用前向模式,要比第一次求导使用前向模式,第二次求导使用反向模式的速度要快。

jax框架为例:求hession矩阵时前后向模式的自动求导的性能差别的更多相关文章

  1. [Pytorch框架] 1.4 Autograd:自动求导

    文章目录 Autograd: 自动求导机制 张量(Tensor) 梯度 Autograd: 自动求导机制 PyTorch 中所有神经网络的核心是 autograd 包. 我们先简单介绍一下这个包,然后 ...

  2. EOJ3536 求蛇形矩阵每一行的和---找规律

    题目链接: https://acm.ecnu.edu.cn/problem/3536/ 题目大意: 求蛇形矩阵的每一行的和,数据范围n<=200000. 思路: 由于n数据较大,所以感觉应该是需 ...

  3. Robot Framework测试框架用例脚本设计方法

    Robot Framework介绍 Robot Framework是一个通用的关键字驱动自动化测试框架.测试用例以HTML,纯文本或TSV(制表符分隔的一系列值)文件存储.通过测试库中实现的关键字驱动 ...

  4. Hession矩阵(整理)

    二阶偏导数矩阵也就所谓的赫氏矩阵(Hessian matrix). 一元函数就是二阶导,多元函数就是二阶偏导组成的矩阵. 求向量函数最小值时用的,矩阵正定是最小值存在的充分条件. 经济学中常常遇到求最 ...

  5. 轻松应对并发问题,简易的火车票售票系统,Newbe.Claptrap 框架用例,第一步 —— 业务分析

    Newbe.Claptrap 框架非常适合于解决具有并发问题的业务系统.火车票售票系统,就是一个非常典型的场景用例. 本系列我们将逐步从业务.代码.测试和部署多方面来介绍,如何使用 Newbe.Cla ...

  6. 二维KMP - 求字符矩阵的最小覆盖矩阵 - poj 2185

    Milking Grid Problem's Link:http://poj.org/problem?id=2185 Mean: 给你一个n*m的字符矩阵,让你求这个字符矩阵的最小覆盖矩阵,输出这个最 ...

  7. 一本通1641【例 1】矩阵 A×B

    1641: [例 1]矩阵 A×B sol:矩阵乘法模板.三个for循环 #include <bits/stdc++.h> using namespace std; typedef lon ...

  8. Task 4.2 求一个矩阵的最大子矩阵的和

    任务:输入一个二维整形数组,数组里有正数也有负数.二维数组中连续的一个子矩阵组成一个子数组,每个子数组都有一个和.求所有子数组的和的最大值.要求时间复杂度为O(n). (1)设计思想:把二维矩阵分解成 ...

  9. 解决使用elementUI框架el-upload跨域上传时session丢失问题

    解决方法一: 1.使用elementUI框架el-upload跨域上传时,后端获取不到cookie,后端接口显示未登录,在添加了 with-credentials="true"后依 ...

  10. [深度学习] pytorch学习笔记(1)(数据类型、基础使用、自动求导、矩阵操作、维度变换、广播、拼接拆分、基本运算、范数、argmax、矩阵比较、where、gather)

    一.Pytorch安装 安装cuda和cudnn,例如cuda10,cudnn7.5 官网下载torch:https://pytorch.org/ 选择下载相应版本的torch 和torchvisio ...

随机推荐

  1. kettle从入门到精通 第十九课 kettle 资源仓库

    1.kettle 里面的资源仓库的意思就是存放转换(.ktr)或者job(.kjb)文件的地方.通过spoon客户端右上角可以进行设置资源仓库. 2.kettle的资源仓库有三种方式 1)本地文件存储 ...

  2. SpringBoot系列(一)简介。

    概述: Spring Boot 可以简化spring的开发,可以快速创建独立的.产品级的应用程序. 特征: 快速创建独立的 Spring 应用程序 直接嵌入了Tomcat.Jetty或Undertow ...

  3. TRL(Transformer Reinforcement Learning) PPO Trainer 学习笔记

    (1)  PPO Trainer TRL支持PPO Trainer通过RL训练语言模型上的任何奖励信号.奖励信号可以来自手工制作的规则.指标或使用奖励模型的偏好数据.要获得完整的示例,请查看examp ...

  4. Coap 协议学习:1-有关概念

    COAP协议简介 不像人接入互联网的简单方便,由于物联网设备大多都是资源限制型的,有限的CPU.RAM.Flash.网络宽带等.对于这类设备来说,想要直接使用现有网络的TCP和HTTP来实现设备实现信 ...

  5. 机器学习(三)——K最临近方法构建分类模型(matlab)

    K最临近(K-Nearest Neighbors,KNN)方法是一种简单且直观的分类和回归算法,主要用于分类任务.其基本原理是用到表决的方法,找到距离其最近的K个样本,然后通过K个样本的标签进行表决, ...

  6. aach64架构 ubuntu20 桌面版 编译安装ffmpeg难点总结

    [编译安装x264] 这一步基本上没有难点 git clone https://gitee.com/mirrors/x264.git ./configure --enable-shared --ena ...

  7. Java BigDecimal 算术运算

    算术运算 BigDecimal bignum1 = new BigDecimal("10"); BigDecimal bignum2 = new BigDecimal(" ...

  8. Java 散列表HashTable

    什么是散列表hash table和使用场景 什么是散列表 散列表(Hash table,也叫哈希表),是根据关键码值(key value)而直接进行访问的数据结构.它通过把关键码值映射到表中一个位置来 ...

  9. 前端:如何让background背景图片进行CSS自适应

    在设置login背景时,找到了一张这样的图片: 但是设置成login背景时,如果没有做一些css适应设置,图片就变样了,变成了这样: 严重变形了,这就造成了一种理想与现实的差距. 若想解决这个自适应问 ...

  10. 面试官:Dubbo一次RPC调用会经过哪些环节?

    大家好,我是三友~~ 今天继续探秘系列,扒一扒一次RPC请求在Dubbo中经历的核心流程. 本文是基于Dubbo3.x版本进行讲解 一个简单的Demo 这里还是老样子,为了保证文章的完整性和连贯性,方 ...