技术背景

在Jax的JIT即时编译中,会追踪每一个Tensor的Shape变化。如果在计算的过程中出现了一些动态Shape的Tensor(Shape大小跟输入的数据有关),那么就无法使用Jax的JIT进行编译优化。最常见的就是numpy.where这种操作,因为这个操作返回的是符合判定条件的Index序号,而不同输入对应的输出Index长度一般是不一致的,因此在Jax的JIT中无法对该操作进行编译。当然,需要特别说明的是,numpy.where这个操作有两种用法,一种是numpy.where(condition, 1, 0)直接给输入打上Mask。另一种用法是numpy.where(condition),这种用法返回的就是一个Index序列,也就是我们需要讨论的应用场景。

普通模式

我们考虑一个比较简单的Toy Model用于测试:

\[E=\sum_{|r_i-r_j|\leq \epsilon}q_iq_j
\]

在不采用即时编译的场景下,Jax的代码可以这么写:

import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false' import numpy as np
np.random.seed(0)
import jax
from jax import numpy as jnp def func(r, q, cutoff=0.2):
dis = jnp.abs(r[:, None] - r[None])
maski, maskj = jnp.where(dis<=cutoff)
qi = q[maski]
qj = q[maskj]
return jnp.sum(qi*qj) N = 100
r = jnp.array(np.random.random(N), jnp.float32)
q = jnp.array(np.random.random(N), jnp.float32) print (func(r, q))
# 1035.7422

那么我们先记住这个输出的结果,因为采用的随机种子是一致的,一会儿可以直接跟JIT的输出结果进行对比。

JIT模式

Jax的JIT模式的使用方法常见的就是三种,一种是在函数头顶加一个装饰器,一种是在函数引用的时候使用jax.jit(function)来调用,最后一种是配合partial偏函数来使用,都不是很复杂。那么这里先用装饰器的形式演示一下Jax中即时编译的用法:

import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false' import numpy as np
np.random.seed(0)
import jax
from jax import numpy as jnp @jax.jit
def func(r, q, cutoff=0.2):
dis = jnp.abs(r[:, None] - r[None])
maski, maskj = jnp.where(dis<=cutoff)
qi = q[maski]
qj = q[maskj]
return jnp.sum(qi*qj) N = 100
r = jnp.array(np.random.random(N), jnp.float32)
q = jnp.array(np.random.random(N), jnp.float32) print (func(r, q))

正如前面所说,因为numpy.where对应的输出是一个动态的Shape,那么在编译阶段就会报错。报错信息如下:

Traceback (most recent call last):
File "/home/dechin/projects/gitee/dechin/tests/jax_mask.py", line 21, in <module>
print (func(r, q))
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/api.py", line 622, in cache_miss
execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/linear_util.py", line 303, in memoized_fun
ans = call(fun, *args)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dechin/projects/gitee/dechin/tests/jax_mask.py", line 12, in func
maski, maskj = jnp.where(dis<=cutoff)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1077, in where
return nonzero(condition, size=size, fill_value=fill_value)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1332, in nonzero
size = core.concrete_or_error(operator.index, size,
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/core.py", line 1278, in concrete_or_error
raise ConcretizationTypeError(val, context)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function func at /home/dechin/projects/gitee/dechin/tests/jax_mask.py:9 for jit. This concrete value was not available in Python because it depends on the value of the argument 'r'. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified. -------------------- The above exception was the direct cause of the following exception: Traceback (most recent call last):
File "/home/dechin/projects/gitee/dechin/tests/jax_mask.py", line 21, in <module>
print (func(r, q))
File "/home/dechin/projects/gitee/dechin/tests/jax_mask.py", line 12, in func
maski, maskj = jnp.where(dis<=cutoff)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1077, in where
return nonzero(condition, size=size, fill_value=fill_value)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1332, in nonzero
size = core.concrete_or_error(operator.index, size,
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function func at /home/dechin/projects/gitee/dechin/tests/jax_mask.py:9 for jit. This concrete value was not available in Python because it depends on the value of the argument 'r'. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

想避免这个报错,要么就是对该函数不做编译(牺牲性能),要么就是自己写一个CUDA算子(增加工作量),再就是我们这里用到的NonZero定长输出的方法(预置条件)。

NonZero的使用

使用Jax的NonZero函数时,也有一点需要注意:虽然NonZero可以做到固定长度的输出,但是这个固定的长度本身也是一个名为size的传入参数。也就是说,NonZero的输出Shape也是要取决于输入参数的。Jax开发时也考虑到了这一点,所以在编译时提供了一个功能可以设置静态参量:static_argnames,例如我们的案例中,将size这个名称的传参设置为静态参量,这样就可以使用Jax的即时编译了:

import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false' import numpy as np
np.random.seed(0)
import jax
from jax import numpy as jnp
from functools import partial @partial(jax.jit, static_argnames='size')
def func(r, q, cutoff=0.2, size=5000):
if q.shape[0] != r.shape[0]+1:
raise ValueError("The q.shape[0] should be equal to r.shape[0]+1") dis = jnp.abs(r[:, None] - r[None])
maski, maskj = jnp.nonzero(jnp.where(dis<=cutoff, 1, 0), size=size, fill_value=-1)
qi = q[maski]
qj = q[maskj]
return jnp.sum(qi*qj) N = 100
r = jnp.array(np.random.random(N), jnp.float32)
q = jnp.array(np.random.random(N), jnp.float32)
pader = jnp.array([0.], jnp.float32)
q = jnp.append(q, pader) print (func(r, q))
# 1035.7422

可以看到,函数用Jax的JIT成功编译,并且输出结果跟前面未编译时候是一致的。当然,这里还用到了一个小技巧,就是NonZero函数输出结果时,不到长度的输出结果会被自动Pad到给定的长度,这里Pad的值使用的是我们给出的fill_value。因为NonZero输出的也是索引,这样我们可以把Pad的这些索引设置为-1,然后在构建参数\(q\)的时候事先在末尾append一个0,这样就可以确保计算的输出结果直接就是正确的。

总结概要

在使用Jax的过程中,有时候会遇到函数输出是一个动态的Shape,这种情况下我们很难利用到Jax的即时编译的功能,不能使得性能最大化。这也是使用Tensor数据结构来计算的一个特点,有好有坏。本文介绍了Jax的另外一个函数NonZero,可以使得我们能够编译那些动态Shape输出的函数。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/nonzero.html

作者ID:DechinPhy

更多原著文章:https://www.cnblogs.com/dechinphy/

请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

Jax中关于NonZero的使用的更多相关文章

  1. Python中Numpy.nonzero()函数

    Numpy.nonzero()返回的是数组中,非零元素的位置.如果是二维数组就是描述非零元素在几行几列,三维数组则是描述非零元素在第几组中的第几行第几列. 举例如下: 二维数组: a = np.arr ...

  2. JAX-MD在近邻表的计算中,使用了什么奇技淫巧?(一)

    技术背景 JAX-MD是一款基于JAX的纯Python高性能分子动力学模拟软件,应该说在纯Python的软件中很难超越其性能.当然,比一部分直接基于CUDA的分子动力学模拟软件性能还是有些差距.而在计 ...

  3. K-均值聚类算法(K-means)

        K-means是一种无监督的学习,将相似的对象归到同一个簇中.可以将一批数据分为K个不同的簇,并且每个簇的中心采用簇中所含样本的均值计算而成.     K-means算法的K值需要由用户指定, ...

  4. Python杂记

    一.函数 1.numpy 模块中的nonzero函数 nonzero返回的数非零元素的下标. 如果输入是单维度的时候它的返回值只有一个:如果输入是多个维度的话,那么它的返回值也是多个维度的.并且的它的 ...

  5. Python 工匠:编写条件分支代码的技巧

    欢迎大家前往腾讯云+社区,获取更多腾讯海量技术实践干货哦~ 本文由鹅厂优文发表于云+社区专栏 作者:朱雷 | 腾讯IEG高级工程师 『Python 工匠』是什么? 我一直觉得编程某种意义是一门『手艺』 ...

  6. Python 编程语言要掌握的技能之一:编写条件分支代码的技巧

    Python 里的分支代码 Python 支持最为常见的 if/else 条件分支语句,不过它缺少在其他编程语言中常见的 switch/case 语句. 除此之外,Python 还为 for/whil ...

  7. GPU随机采样速度比较

    技术背景 随机采样问题,不仅仅只是一个统计学/离散数学上的概念,其实在工业领域也都有非常重要的应用价值/潜在应用价值,具体应用场景我们这里就不做赘述.本文重点在于在不同平台上的采样速率,至于另外一个重 ...

  8. 分子动力学模拟之基于自动微分的LINCS约束

    技术背景 在分子动力学模拟的过程中,考虑到运动过程实际上是遵守牛顿第二定律的.而牛顿第二定律告诉我们,粒子的动力学过程仅跟受到的力场有关系,但是在模拟的过程中,有一些参量我们是不希望他们被更新或者改变 ...

  9. MindSpore多元自动微分

    技术背景 当前主流的深度学习框架,除了能够便捷高效的搭建机器学习的模型之外,其自动并行和自动微分等功能还为其他领域的科学计算带来了模式的变革.本文我们将探索如何用MindSpore去实现一个多维的自动 ...

  10. MindSpore尝鲜之Vmap功能

    技术背景 Vmap是一种在python里面经常提到的向量化运算的功能,比如之前大家常用的就是numba和jax中的向量化运算的接口.虽然numpy中也使用到了向量化的运算,比如计算两个numpy数组的 ...

随机推荐

  1. 每天那么多工作,我为什么能做到 "不忘事" ?

    大家好,我是程序员鱼皮. 我相信很多朋友都遇到过丢失工作.或者忘记事情的情况,尤其是事情一多,就更容易遗漏:而如果在工作中你漏掉了某项任务,需要上级或同事重复提醒你,是很影响别人对你的印象的. 那么如 ...

  2. Kubernetes 初体验

    在 DigitalOcean 创建一个 Kubernetes 集群 下载集群 Config 文件到 ~/.kube 目录 通过环境变量 KUBECONFIG 设置本地 kubectl 工具使用下载的配 ...

  3. uni-app 小程序 前置摄像头

    在小程序拍照的话,uni.chooseImage()可以直接调取摄像头拍照,而如果要调用前置摄像头,这个api就没有提供了. 在查找官方文档发现,可以通过camera有提供这么一个组件,页面内嵌的区域 ...

  4. Java并发编程学习前期知识上篇

    Java并发编程学习前期知识上篇 我们先来看看几个大厂真实的面试题: 从上面几个真实的面试问题来看,我们可以看到大厂的面试都会问到并发相关的问题.所以 Java并发,这个无论是面试还是在工作中,并发都 ...

  5. Unity 中 Color 与 Color32 的区别

    1. 存储方式 Color用四个浮点数(float)来表示RGBA,取值范围均是0到1 举例: var orange = new Color(1f, 0.5f, 0f, 1f); 而Color32使用 ...

  6. Vue SPA项目如何修改网站标题

    直接贴 门户项目代码 // 全局router 直接挂载路由导航守卫 router.beforeEach((to, from, next) => { if (to.meta.title) { va ...

  7. 前端使用 Konva 实现可视化设计器(23)- 绘制曲线、属性面板

    本章分享一下如何使用 Konva 绘制基础图形:曲线,以及属性面板的基本实现思路,希望大家继续关注和支持哈(多求 5 个 Stars 谢谢)! 请大家动动小手,给我一个免费的 Star 吧~ 大家如果 ...

  8. C++的并发编程历史

    多线程环境 并非所有的语言都提供了多线程的环境.即便是C++语言,直到C++11标准之前,也是没有多线程支持的. 在这种情况下,Linux/Unix平台下的开发者通常会使用POSIX Threads, ...

  9. eBPF 概述:第 1 部分:介绍

    1. 前言 有兴趣了解更多关于 eBPF 技术的底层细节?那么请继续移步,我们将深入研究 eBPF 的底层细节,从其虚拟机机制和工具,到在远程资源受限的嵌入式设备上运行跟踪. 注意:本系列博客文章将集 ...

  10. 2023年5月中国数据库排行榜:OTO组合回归育新机,华为高斯蓄势待发展雄心

    路漫漫其修远兮,吾将上下而求索. 2023年5月的 墨天轮中国数据库流行度排行 火热出炉,本月共有262个数据库参与排名.本月排行榜前十变动较大,可以用一句话概括为:openGauss 立足创新夺探花 ...