官方文档参考:

https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names

本篇post的主要讲解的是:

jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))



jax.device_put(x, mesh_sharding(P(('b', 'a'), None)))

的不同:

主机的四个CPU情况:

代码:

import os

import functools
from typing import Optional import numpy as np import jax
import jax.numpy as jnp from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding # Create a Sharding object to distribute a value across devices:
sharding = PositionalSharding(mesh_utils.create_device_mesh((4,))) # Create an array of random values:
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, sharding.reshape(2, 2))
jax.debug.visualize_array_sharding(y)

运行结果:


jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))

代码:(行优先的方式展开GPU)

点击查看代码
from typing import Optional
import jax
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils P = PartitionSpec devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('a', 'b')) from jax.sharding import PositionalSharding sharding = PositionalSharding(devices) x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 1)) devices = mesh_utils.create_device_mesh((2, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))
def mesh_sharding(
pspec: PartitionSpec, mesh: Optional[Mesh] = None,
) -> NamedSharding:
if mesh is None:
mesh = default_mesh
return NamedSharding(mesh, pspec) y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)

运行结果:

jax.device_put(x, mesh_sharding(P(('b', 'a'), None)))

代码:(列优先的方式展开GPU)

点击查看代码
from typing import Optional
import jax
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils P = PartitionSpec devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('a', 'b')) from jax.sharding import PositionalSharding sharding = PositionalSharding(devices) x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 1)) devices = mesh_utils.create_device_mesh((2, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))
def mesh_sharding(
pspec: PartitionSpec, mesh: Optional[Mesh] = None,
) -> NamedSharding:
if mesh is None:
mesh = default_mesh
return NamedSharding(mesh, pspec) y = jax.device_put(x, mesh_sharding(P(('b', 'a'), None)))
jax.debug.visualize_array_sharding(y)

运行结果:

Jax计算框架的NamedSharding的reshape —— namedsharding-gives-a-way-to-express-shardings-with-names的更多相关文章

  1. Storm分布式实时流计算框架相关技术总结

    Storm分布式实时流计算框架相关技术总结 Storm作为一个开源的分布式实时流计算框架,其内部实现使用了一些常用的技术,这里是对这些技术及其在Storm中作用的概括介绍.以此为基础,后续再深入了解S ...

  2. Spark Streaming实时计算框架介绍

    随着大数据的发展,人们对大数据的处理要求也越来越高,原有的批处理框架MapReduce适合离线计算,却无法满足实时性要求较高的业务,如实时推荐.用户行为分析等. Spark Streaming是建立在 ...

  3. Storm实时计算框架的编程模式

    storm分布式流式计算框架. nimbus:主进程服务(职责就是任务的分配的,程序的分发) supervisor:工作进程服务(职责就是启动线程池,接受任务,运行任务,报告任务的运行状态) 注意容错 ...

  4. 开源图计算框架GraphLab介绍

    GraphLab介绍 GraphLab 是由CMU(卡内基梅隆大学)的Select 实验室在2010 年提出的一个基于图像处理模型的开源图计算框架.框架使用C++语言开发实现. 该框架是面向机器学习( ...

  5. 大数据计算框架Hadoop, Spark和MPI

    转自:https://www.cnblogs.com/reed/p/7730338.html 今天做题,其中一道是 请简要描述一下Hadoop, Spark, MPI三种计算框架的特点以及分别适用于什 ...

  6. (第4篇)hadoop之魂--mapreduce计算框架,让收集的数据产生价值

    摘要: 通过前面的学习,大家已经了解了HDFS文件系统.有了数据,下一步就要分析计算这些数据,产生价值.接下来我们介绍Mapreduce计算框架,学习数据是怎样被利用的. 博主福利 给大家赠送一套ha ...

  7. Dream_Spark-----Spark 定制版:005~贯通Spark Streaming流计算框架的运行源码

    Spark 定制版:005~贯通Spark Streaming流计算框架的运行源码   本讲内容: a. 在线动态计算分类最热门商品案例回顾与演示 b. 基于案例贯通Spark Streaming的运 ...

  8. 【流处理】Kafka Stream-Spark Streaming-Storm流式计算框架比较选型

    Kafka Stream-Spark Streaming-Storm流式计算框架比较选型 elasticsearch-head Elasticsearch-sql client NLPchina/el ...

  9. 【codenet】代码相似度计算框架调研 -- 把内容与形式分开

    首发于我的gitpages博客 https://helenawang.github.io/2018/10/10/代码相似度计算框架调研 代码相似度计算框架调研 研究现状 代码相似度计算是一个已有40年 ...

  10. Storm:分布式流式计算框架

    Storm是一个分布式的.高容错的实时计算系统.Storm适用的场景: Storm可以用来用来处理源源不断的消息,并将处理之后的结果保存到持久化介质中. 由于Storm的处理组件都是分布式的,而且处理 ...

随机推荐

  1. ABC319题解

    直接从 D 开始了. D 可可爱爱的二分捏. check 就按照题目里写的就行了. 然后 \(l\) 的初值要注意一下,就是 \(\max^{i \le n}_{i=1}a_i\). 代码: #inc ...

  2. 大量索引场景下 Easysearch 和 Elasticsearch 的吞吐量差异

    最近有客户在使用 Elasticsearch 搜索服务时发现集群有掉节点,并且有 master 收集节点信息超时的日志,节点的负载也很高,不只是 data 节点,master 和协调节点的 cpu 使 ...

  3. 蓝屏rtux64w10.sys

    蓝屏rtux64w10.sys 环境: WIN10 +  Realtek USB RTL8156B 2.5G网卡 表现: 局域网复制时,间隔性速度变为0,多次后,最终蓝屏. 解决方法: 更新驱动. 地 ...

  4. 手把手教你免费用Flashduty做消息通知

    为什么需要消息通知? 如果有重要的情况发生,希望能通过各种媒介通知我们.可以举几个例子: 家里燃气费没有了,希望能有短信或者app通知 api频繁500报错,希望及时感知,及时修复 公司网站是http ...

  5. 使用Kimi+Markmap总结网页内容生成思维导图

    AI可以帮助我们更高效地阅读文章进行提炼总结,像上图这张思维导图,就是使用Kimi进行内容提炼,再使用markmap生成思维导图,下面讲解下详细实现步骤: 一.工具准备 Kimi,将文章或一篇网页投给 ...

  6. C++判断当前程序是否运行在Windows展台(Kiosk)模式下

    Windows有一个展台(Kiosk)模式.展台模式可以使Windows作为数字标牌进行使用.具体请参考Windows 展台 配置完展台模式,重启设备后,Windows会以全屏的方式运行展台应用,无法 ...

  7. Simple WPF: WPF 透明窗体和鼠标事件穿透

    一个自定义WPF窗体的解决方案,借鉴了吕毅老师的WPF制作高性能的透明背景的异形窗口一文,并在此基础上增加了鼠标穿透的功能.可以使得透明窗体的鼠标事件穿透到下层,在下层窗体中响应. 这个方法不一定是制 ...

  8. PromQL全方位解读:监控与性能分析的关键技术

    本文全面探索PromQL,从基础语法到高级操作,详细介绍了数据聚合.时间序列分析及内置函数应用,旨在提升用户构建复杂监控策略和性能分析的能力. 关注[TechLeadCloud],分享互联网架构.云服 ...

  9. aach64架构 ubuntu20 桌面版 编译安装ffmpeg难点总结

    [编译安装x264] 这一步基本上没有难点 git clone https://gitee.com/mirrors/x264.git ./configure --enable-shared --ena ...

  10. 洛谷P1747

    这个题被坑麻了,题目居然不给棋盘的范围,评论区居然有人说棋盘是无限大的,我想说的是如果真是这样那么第9个点答案应该是2而不是3,这个棋盘绝对是有大小的. #include<iostream> ...