jax框架:jax.grad
官方地址:
https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad

这里只给出几个样例代码:
- 设置 allow_int 参数,实现对整数类型求导:
未对整数类型求导:
import jax
def fun(x, y):
print(x, y)
return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])
fun_grad = jax.grad(fun, argnums=(0, ))
x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]
print( fun_grad(x, y) )
正常运行:

对整数类型求导:
import jax
def fun(x, y):
print(x, y)
return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])
fun_grad = jax.grad(fun, argnums=(0, 1))
x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]
print( fun_grad(x, y) )
报错:

通过设置 allow_int 实现对整数类型求导:
import jax
def fun(x, y):
print(x, y)
return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])
fun_grad = jax.grad(fun, argnums=(0, 1), allow_int=True)
x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]
print( fun_grad(x, y) )
未报错运行,但是没有获得争取结果:

应该这么说,在jax中不能对整数类型求导的,虽然这里设置了 allow_int 但是也不能得到正确的对整数类型的求导。
jax框架:jax.grad的更多相关文章
- 分子动力学模拟之SETTLE约束算法
技术背景 在上一篇文章中,我们讨论了在分子动力学里面使用LINCS约束算法及其在具备自动微分能力的Jax框架下的代码实现.约束算法,在分子动力学模拟的过程中时常会使用到,用于固定一些既定的成键关系.例 ...
- InfoQ一波文章:AdaSearch/JAX/TF_Serving/leon.bottou.org/Neural_ODE/NeurIPS_2018最佳论文
和 Nested Partition 有相通之处? 伯克利提出 AdaSearch:一种用于自适应搜索的逐步消除方法 在机器学习领域的诸多任务当中,我们通常希望能够立足预先给定的固定数据集找出问题的答 ...
- 使用jax加速Hamming Distance的计算
技术背景 一般认为Jax是谷歌为了取代TensorFlow而推出的一款全新的端到端可微的框架,但是Jax同时也集成了绝大部分的numpy函数,这就使得我们可以更加简便的从numpy的计算习惯中切换到G ...
- 使用JAX构建强化学习agent并借助TensorFlowLite将其部署到Android应用中
在之前发布文章<一个新 TensorFlow Lite 示例应用:棋盘游戏>中,展示了如何使用 TensorFlow 和 TensorFlow Agents 来训练强化学习 (RL) ag ...
- spring+jax 出现java.io.Serializable is an interface, and JAXB can't handle interfaces
spring+jax 出现java.io.Serializable is an interface, and JAXB can't handle interfaces 原因是我的webservice方 ...
- 活动预告 | Jax Diffusers 社区冲刺线上分享(还有北京线下活动)
我们的 Jax Diffuser 社区冲刺活动已经截止报名,全球有 200 多名参赛选手成功组成了约 70 支队伍共同参赛. 为了帮助参赛者更好的完成自己的项目,也为了与更多社区成员们分享扩散模型和生 ...
- Math Jax开源数学编辑器的使用
首先,这是一个开源免费,同时也可以支持扩展的软件. 使用API文档: 中文网站(http://mathjax-chinese-doc.readthedocs.io/en/latest/index.ht ...
- Jax
The scope of this project is to automate the current Credit Correction process of opening, editing, ...
- java各种框架的比较,分析
Spring 框架 优点 1.提供了一种管理对象的方法,可以把中间层的对象有效地组织起来 2.采用了分层结构,可以增量引入到项目中. 3.代码测试较容易 4.非侵入性,应用程序对Spring API的 ...
- Spring Framework(框架)整体架构 变迁
Spring Framework(框架)整体架构 2018年04月24日 11:16:41 阅读数:1444 标签: Spring框架架构 更多 个人分类: Spring框架 版权声明:本文为博主 ...
随机推荐
- 算法金 | AI 基石,无处不在的朴素贝叶斯算法
大侠幸会,在下全网同名「算法金」 0 基础转 AI 上岸,多个算法赛 Top 「日更万日,让更多人享受智能乐趣」 历史上,许多杰出人才在他们有生之年默默无闻, 却在逝世后被人们广泛追忆和崇拜. 18世 ...
- INFINI Easysearch 与兆芯完成产品兼容互认证
近日,极限科技旗下软件产品 INFINI Easysearch 搜索引擎软件 V1.0 与兆芯完成兼容性测试,功能与稳定性良好,并获得兆芯产品兼容互认证书. 此次兼容适配基于银河麒麟高级服务器操作系统 ...
- python-API开发zk客户端
前面于超老师讲完了,zk运维的基本命令行玩法,更多的还是开发需要通过代码和zk结合处理. 大多数场景是java后端去操作. 这里我们以运维更友好的python来学习. 1.kazoo模块 zookee ...
- Spring扩展———自定义bean组件注解
引言 Java 注解(Annotation)又称 Java 标注,是 JDK5.0 引入的一种注释机制. Java 语言中的类.方法.变量.参数和包等都可以被标注.和 Javadoc 不同,Java ...
- 图片预加载需要token认证的地址处理
1.添加函数修改img的属性: /** * * @param {*} idName 传入的id,获取改img的dom,添加相应的数学 */ export const proxyImg = (idNam ...
- mysql 判断字符串结尾
mysql 判断字符串结尾 CREATE TABLE `tbl_str` ( `id` INT DEFAULT NULL, `Str` VARCHAR(30) DEFAULT NULL) INSERT ...
- Lru-k在Rust中的实现及源码解析
LRU-K 是一种缓存淘汰算法,旨在改进传统的LRU(Least Recently Used,最近最少使用)算法的性能.将其中高频的数据达到K次访问移入到另一个队列进行保护. 算法思想 LRU-K中的 ...
- uboot 修改代码 增加 环境变量
--- title: uboot修改代码增加环境变量 date: 2019-12-27 21:26:39 categories: tags: - uboot --- 以"tftp下载kern ...
- 网络OSI七层模型及各层作用 tcp-ip
背景 虽然说以前学习计算机网络的时候,学过了,但为了更好地学习一些物联网协议(MQTT.CoAP.LWM2M.OPC),需要重新复习一下. OSI七层模型 七层模型,亦称OSI(Open System ...
- Vue2 整理(三):高级篇
前言 基础篇链接:https://www.cnblogs.com/xiegongzi/p/15782921.html 组件化开发篇链接:https://www.cnblogs.com/xiegongz ...