jax框架:jax.grad
官方地址:
https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad
这里只给出几个样例代码:
- 设置 allow_int 参数,实现对整数类型求导:
未对整数类型求导:
import jax
def fun(x, y):
print(x, y)
return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])
fun_grad = jax.grad(fun, argnums=(0, ))
x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]
print( fun_grad(x, y) )
正常运行:
对整数类型求导:
import jax
def fun(x, y):
print(x, y)
return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])
fun_grad = jax.grad(fun, argnums=(0, 1))
x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]
print( fun_grad(x, y) )
报错:
通过设置 allow_int 实现对整数类型求导:
import jax
def fun(x, y):
print(x, y)
return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])
fun_grad = jax.grad(fun, argnums=(0, 1), allow_int=True)
x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]
print( fun_grad(x, y) )
未报错运行,但是没有获得争取结果:
应该这么说,在jax中不能对整数类型求导的,虽然这里设置了 allow_int 但是也不能得到正确的对整数类型的求导。
jax框架:jax.grad的更多相关文章
- 分子动力学模拟之SETTLE约束算法
技术背景 在上一篇文章中,我们讨论了在分子动力学里面使用LINCS约束算法及其在具备自动微分能力的Jax框架下的代码实现.约束算法,在分子动力学模拟的过程中时常会使用到,用于固定一些既定的成键关系.例 ...
- InfoQ一波文章:AdaSearch/JAX/TF_Serving/leon.bottou.org/Neural_ODE/NeurIPS_2018最佳论文
和 Nested Partition 有相通之处? 伯克利提出 AdaSearch:一种用于自适应搜索的逐步消除方法 在机器学习领域的诸多任务当中,我们通常希望能够立足预先给定的固定数据集找出问题的答 ...
- 使用jax加速Hamming Distance的计算
技术背景 一般认为Jax是谷歌为了取代TensorFlow而推出的一款全新的端到端可微的框架,但是Jax同时也集成了绝大部分的numpy函数,这就使得我们可以更加简便的从numpy的计算习惯中切换到G ...
- 使用JAX构建强化学习agent并借助TensorFlowLite将其部署到Android应用中
在之前发布文章<一个新 TensorFlow Lite 示例应用:棋盘游戏>中,展示了如何使用 TensorFlow 和 TensorFlow Agents 来训练强化学习 (RL) ag ...
- spring+jax 出现java.io.Serializable is an interface, and JAXB can't handle interfaces
spring+jax 出现java.io.Serializable is an interface, and JAXB can't handle interfaces 原因是我的webservice方 ...
- 活动预告 | Jax Diffusers 社区冲刺线上分享(还有北京线下活动)
我们的 Jax Diffuser 社区冲刺活动已经截止报名,全球有 200 多名参赛选手成功组成了约 70 支队伍共同参赛. 为了帮助参赛者更好的完成自己的项目,也为了与更多社区成员们分享扩散模型和生 ...
- Math Jax开源数学编辑器的使用
首先,这是一个开源免费,同时也可以支持扩展的软件. 使用API文档: 中文网站(http://mathjax-chinese-doc.readthedocs.io/en/latest/index.ht ...
- Jax
The scope of this project is to automate the current Credit Correction process of opening, editing, ...
- java各种框架的比较,分析
Spring 框架 优点 1.提供了一种管理对象的方法,可以把中间层的对象有效地组织起来 2.采用了分层结构,可以增量引入到项目中. 3.代码测试较容易 4.非侵入性,应用程序对Spring API的 ...
- Spring Framework(框架)整体架构 变迁
Spring Framework(框架)整体架构 2018年04月24日 11:16:41 阅读数:1444 标签: Spring框架架构 更多 个人分类: Spring框架 版权声明:本文为博主 ...
随机推荐
- .net6 .net core web api json 遇到 400 错误
环境: .net6 webapi 服务端模型声明 public class TongYiMinPgPayReq { public string mch_no { get; set; } public ...
- C#字符串截取分割方法
字符串操作:分割 Split.连接数组 Join.拼接 Format.截取 Substring.替换 Replace.左填充 PadLeft.右填充 PadRight.删除 Remove 1 //分割 ...
- CICD介绍
1.学习背景 当公司的服务器架构越来越复杂,需要频繁的发布新配置文件,以及新代码: 但是如果机器部署数量较多,发布的效率必然很低: 并且如果代码没有经过测试环境,预生产环境层层测试,最终才到生产环境, ...
- rsync备份任务练习
06-备份任务实战 今天的任务主要以实际备份任务入手,完成综合练习,完成对rsync的综合运用. 先看需求 再讲解 再次动手实践 客户端需求 客户端需求: 1.客户端每天凌晨1点在服务器本地打包备份( ...
- rsync备份
备份工具rsync 备份是太常见.且太重要的一个日常工作了. 备份源码.文档.数据库.等等. 类似cp命令拷贝,但是支持服务器之间的网络拷贝,且保证安全性. 学习背景 超哥游戏公司要每天都要对代码备份 ...
- 判断一个数n是不是快乐数
引言 题目:编写一个算法来判断一个数n是不是快乐数 来源:网友分享的面试算法题 题目描述 [快乐数定义] 对于一个正整数,每一次将该数替换为它每个位置上的数字的平方和. 然后重复这个过程直到这个数变为 ...
- 解锁LLMs的“思考”能力:Chain-of-Thought(CoT) 技术推动复杂推理的新发展
解锁LLMs的"思考"能力:Chain-of-Thought(CoT) 技术推动复杂推理的新发展 1.简介 Chain-of-Thought(CoT)是一种改进的Prompt技术, ...
- 接口签名规则和Java实现签名和验签代码
接口签名规则和Java实现签名和验签代码 签名规则 签名生成的通用步骤如下: 第一步,设所有发送或者接收到的数据为集合M,将集合M内非空参数值的参数按照参数名ASCII码从小到大排序(字典序),使用U ...
- TGI 基准测试
本文主要探讨 TGI 的小兄弟 - TGI 基准测试工具.它能帮助我们超越简单的吞吐量指标,对 TGI 进行更全面的性能剖析,以更好地了解如何根据实际需求对服务进行调优并按需作出最佳的权衡及决策.如果 ...
- C++中UNIX时间戳与日期互转
C++中UNIX时间戳与日期互转 使用time.h头文件 localtime 可以把时间戳转为 tm 结构体, tm结构体中可以格式化输出时间 mktime可以把tm结构体转为时间戳 tm 结构体中: ...