[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

0x00 摘要

在前面的文章之中,我们已经学习了PyTorch 分布式的基本模块,介绍了官方的几个例子,我们接下来会介绍PyTorch的弹性训练,本文是第二篇,重点关注的是如何启动弹性训练,并且可以对系统总体架构有所了解。

弹性训练系列文章如下:

[源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

0x01 重要概念

为了更好的说明(这个说明可能在后面文章也会出现,因为太重要了),我们先总述一下TE 最重要的 Agent 和 Rendezvous 两个概念。

  • Agent :Agent是运行在单节点上的独立后台进程,可以认为是 worker manager 或者 process supervisor,其负责启动worker,监控 worker 运行,捕获woker异常,通过 rendezvous 实现 worker 间的相互发现(比如把状态上报到KVStore),成员变动时候基于 rendezvous 进行变更同步等等。
  • Rendezvous :为了实现弹性训练,需要有一个节点/进程之间彼此发现的机制。Rendezvous就是这个发现机制或者说同步组件。当系统启动或者成员变更时候,所有worker会(重新)集合(rendezvous)以建立一个新的进程组。

我们从源码中取出示意图看看,大家先有一个总体概念。

0x02 分布式运行

2.1 方式改变

2.1.1 原有方式

我们知道,PET是从 PyTorch v1.9 合并进来的,因为合并了弹性训练,所以分布式启动的方式有了很大的改变。

V1.9 之前是使用 torch/distributed/launch.py 进行启动,比如:

python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
--nnodes=2 --node_rank=0 --master_addr="192.168.1.1"
--master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
and all other arguments of your training script)

此处参数含义是:

  • nnodes :是参与训练的节点数目。
  • nproc_per_node :每个节点上运行的进程数目。
  • node_rank :当前节点标识符。
  • master_addrmaster_port 是 master 监听的地址和端口。

当运行时,torch.distributed.launch 会设置一些环境变量,包括 world_sizemaster_addrmaster_port 等等。然后在当前机器上创建 nproc_per_node 个进程,这些进程构成了一个本地组。如果一共有 NODE_SIZE 个机器参与训练,则一共有 NODE_SIZE * TRAINERS_PER_NODE 个进程。如果想启动一个分布式训练任务,则需要在所有的机器上执行相关命令。

2.1.2 目前方式

PyTorch 1.9 使用 torch/distributed/run.py 进行启动。如果依然采用 torch/distributed/launch.py,其实其内部已经透传给 run.py,具体参见代码:

def main(args=None):
logger.warn(
"The module torch.distributed.launch is deprecated "
"and going to be removed in future."
"Migrate to torch.distributed.run"
)
args = parse_args(args)
run(args)

torch.distributed.run是之前torch.distributed.launch的一个超集,提供如下新功能:

  • 容错:通过重新启动所有workers,可以优雅地处理worker故障。
  • 自动:Worker 的RANKWORLD_SIZE 是自动分配的
  • 弹性:允许在最小值和最大值(弹性)之间更改节点数。

为了使用弹性训练,用户代码也需要做一些修改,如果用户的训练脚本已经支持 torch.distributed.launch ,则只需要修改几处就可以使用torch.distributed.run

  • 无需手动传递RANK , WORLD_SIZE , MASTER_ADDR 和 MASTER_PORT。
  • 必须提供rdzv_backendrdzv_endpoint。对于大多数用户来说,这其实就是“c10d”(参见“rendezvous“)。其实这就替代了之前的MASTER_ADDR 和 MASTER_PORT。
  • use_env 参数已被删除。请从 LOCAL_RANK 环境变量中获取local_rank (例如,os.environ["LOCAL_RANK"])。
  • 用户需要确保脚本中有 load_checkpoint(path)save_checkpoint(path) 逻辑,即手动处理Checkpoint。因为当worker失败时,我们将使用最近的checkpoint来恢复现场,重启所有worker。

下面是一个训练脚本的示例,该脚本在每个epoch上设置检查点,因此在失败时最差也只是会丢失一个epoch的训练成果。

  def main():
args = parse_args(sys.argv[1:])
state = load_checkpoint(args.checkpoint_path)
initialize(state) # torch.distributed.run ensure that this will work
# by exporting all the env vars needed to initialize the process group
torch.distributed.init_process_group(backend=args.backend) for i in range(state.epoch, state.total_num_epochs)
for batch in iter(state.dataset)
train(batch, state.model) state.epoch += 1
save_checkpoint(state)

所以,我们接下来看看在新模式之下,如何分布式启动。

2.2 部署

部署一般按照如下方式。

  1. (C10d后端不需要)启动 rendezvous 后端服务器,并获取端点(作为--rdzv_endpoint传递给启动程序脚本)
  2. 单节点多 worker:在主机上启动 launcher 以启动代理进程,代理会创建并监视本地工作组。
  3. 多节点多 worker:在所有节点上使用相同的参数启动 launcher 参加训练。

当使用作业/群集管理器时,多节点作业的入口点命令应为 launcher。

2.3 示例

我们首先通过几个例子来看看如何启动分布式训练。

2.3.1 单节点多worker启动

单节点多worker的启动方式如下,其实就是Standalone 模式,这是分布式模式的一种特例,具体就是针对单机多 Worker 提供了一些便利设置。

python -m torch.distributed.run
--standalone
--nnodes=1
--nproc_per_node=$NUM_TRAINERS
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

2.3.2 容错方式启动

如下是容错方式启动,固定数目workers,没有弹性训练。 --nproc_per_node=$NUM_TRAINERS 一般是 单节点上GPU 个数。

python -m torch.distributed.run
--nnodes=$NUM_NODES
--nproc_per_node=$NUM_TRAINERS
--rdzv_id=$JOB_ID
--rdzv_backend=c10d
--rdzv_endpoint=$HOST_NODE_ADDR
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

HOST_NODE_ADDR, 的格式是: [:] ,指定了 C10d rendezvous 后端所运行的节点地址和端口,这个节点可以是训练集群中任意节点,但是最好找一个高带宽的节点。

2.3.3 弹性方式启动

下面是弹性训练,弹性区间为 (min=1, max=4)。通过指定rdzv参数,可以实现多机训练,具备容错与弹性能力

在多台机器上分别执行以下命令启动:最小节点数为MIN_SIZE,最大为MAX_SIZE,利用etcd服务实现一致性和信息同步。

python -m torch.distributed.run
--nnodes=1:4
--nproc_per_node=$NUM_TRAINERS
--rdzv_id=$JOB_ID
--rdzv_backend=c10d
--rdzv_endpoint=$HOST_NODE_ADDR
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

HOST_NODE_ADDR, 的格式是: [:] ,指定了 C10d rendezvous 后端所运行的节点地址和端口,这个节点可以是训练集群中任意节点,但是最好找一个高带宽的节点。

关于 rendezvous backend,有几点说明:

对于多节点训练,需要指定:

  • --rdzv_id: 一个唯一的 job id,在参与job的所有节点之间共享。
  • --rdzv_backend: torch.distributed.elastic.rendezvous.RendezvousHandler 的一个实现。 (--rdzv_backend默认是static模式,不支持容错和弹性伸缩)
  • --rdzv_endpoint: rendezvous backend 所运行的 endpoint,通常格式为:host:port。就是取代了之前的 master address / port 设置。

目前,以下几种后端可以直接使用,c10d (推荐), etcd-v2, and etcd (legacy) 。为了使用 etcd-v2 或者 etcd,需要搭建一个 v2 api开启的 etcd server (即. --enable-v2)。

0x03 启动脚本

既然以上启动都是用 torch/distributed/run.py,所以我们仔细分析一下这个脚本,该脚本提供三个功能:

  • 依靠"重启所有 workers"来处理 worker 失败;

  • 自动分配 worker 的RANK and WORLD_SIZE

  • 弹性训练,即 node 数目允许在minimum和maximum之间改变;

3.1 参数定义

启动脚本中,一些参数定义如下:

  • Node - 物理实例或容器;映射到与 job manager 所协调的单元。
  • Worker - 分布式训练环境中的worker。
  • WorkerGroup - 执行相同功能的一组worker(例如trainers)。
  • LocalWorkerGroup - 在同一节点上运行的工作组中的workers子集。
    • 一个节点运行 LOCAL_WORLD_SIZE个workers,这些 workers 组成LocalWorkerGroup
    • 节点上所有LocalWorkerGroups组成WorkerGroups
  • RANK - 工作组中worker的rank,是全局rank,可以认为是一个全局GPU资源列表。
    • Rank是不稳定的,在重启之间,本地Workers 会被分配到不同的ranks,所以不要在代码中对RANKLOCAL_RANK的稳定性做任何假设和依赖编码。
    • rendezvous完成后,其所有成员将对工作成员资格以及每个人在其中的角色(role)达成共识。此角色(role)使用一个介于 0 ~ world size 之间的整型来表示,被称之为rank。
  • LOCAL_RANK - 本地工作组中,某个worker 的 rank,可以认为是当前节点上的GPU资源列表。
  • GROUP_RANK - worker group的rank。介于0和“最大节点数”之间的数字。如果每个节点运行一个单一工作组,那GROUP_RANK就是这个节点的rank。
  • ROLE_RANK - 对于具有相同角色worker来说,他们之间共享的rank,角色在“WorkerSpec”中被指定。
  • WORLD_SIZE - 工作组中worker的总数。因为节点会加入/离开,所以WORLD_SIZE会变化,不能依赖 WORLD_SIZE的稳定性进行编码。
  • LOCAL_WORLD_SIZE - 本地工作组的大小,即本地运行的worker数目,等于在torch.distributed.run运行时候指定的--nproc_per_node。目前,torch/distributed/run.py 仅支持同构的 LOCAL_WORLD_SIZE。也就是说,假设所有节点运行相同数量的本地工作者(每个角色)。
  • ROLE_WORLD_SIZE - 具有同样角色的workers总数,在 WorkerSpec之中被指定。
  • rdzv_id - 用户定义的id,用于唯一标识作业的工作组。这个id在每个节点加入特定工作组时候使用。
  • rdzv_backend-rendezvous 的后端(例如“c10d”)。这通常是一个强一致性的键值存储。
  • rdzv_endpoint - rendezvous 后端端点;通常以“<host>:<port>”的形式出现。
  • run_id: 用户定义的id,它唯一地标识分布式应用程序的一个实例。它通常映射到作业id并用于允许节点加入正确的分布式应用程序。
  • TORCHELASTIC_RUN_ID - 与 rendezvous run_id 相等,即唯一的job id。
  • TORCHELASTIC_RESTART_COUNT - 迄今为止,工作组重启的次数。
  • TORCHELASTIC_MAX_RESTARTS - 配置的最大重启数目。

3.2 相关函数/变量

为了更好的理解上面的参数,我们选取部分相关函数/变量看看。

world_size,rank

这两个变量是动态生成的,所以从 state 之中取出。

rank, world_size = self._get_world()

def _get_world(self) -> Tuple[int, int]:
state = self._state_holder.state
return state.participants[self._this_node], len(state.participants)

_pg_group_ranks

该全局变量存储了每个 group 的 global rank 到 local rank 映射信息。

# Process group's global rank to local rank mapping
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}

其赋值举例如下:

# Create the global rank to group rank mapping
_pg_group_ranks[pg] = {
global_rank: group_rank
for group_rank, global_rank in enumerate(ranks)
}

group_rank

我们可以利用 global rank 从 _pg_group_ranks 之中提取对应的 local rank。

def _get_group_rank(group: ProcessGroup, rank):
"""
Helper that gets a given group's local rank in the group from a given global
rank.
"""
if group is GroupMember.WORLD:
raise RuntimeError("group.WORLD does not have local rank to global "
"rank mapping")
if group not in _pg_group_ranks:
raise RuntimeError("The given group does not exist")
try:
group_rank = _pg_group_ranks[group][rank]
except KeyError:
raise RuntimeError(f"The global rank {rank} is not part of the group {group}") from None
return group_rank

global_rank

我们可以利用一个 group 的 local rank 获取到其 gloabl rank。

def _get_global_rank(group, group_rank):
"""
Helper that gets a given group's global rank from a given local rank in the
group.
"""
if group is GroupMember.WORLD:
raise RuntimeError("group.WORLD does not have local rank to global "
"rank mapping")
group_rank_map = _pg_group_ranks[group]
for rank, grp_rank in group_rank_map.items():
if grp_rank == group_rank:
return rank
raise RuntimeError("The group rank is not part of the group")

group_size

我们可以 _get_group_size 获取到某一个group 的大小。

def _get_group_size(group):
"""
Helper that gets a given group's world size.
"""
if group is GroupMember.WORLD or group is None:
default_pg = _get_default_group()
return default_pg.size()
if group not in _pg_group_ranks:
raise RuntimeError("The given group does not exist")
return len(_pg_group_ranks[group])

nproc_per_node

这个变量可以得到每个node之上支持多少个进程。

def determine_local_world_size(nproc_per_node: str):
try:
logging.info(f"Using nproc_per_node={nproc_per_node}.")
return int(nproc_per_node)
except ValueError:
if nproc_per_node == "cpu":
num_proc = os.cpu_count()
device_type = "cpu"
elif nproc_per_node == "gpu":
if not torch.cuda.is_available():
raise ValueError("Cuda is not available.")
device_type = "gpu"
num_proc = torch.cuda.device_count()
elif nproc_per_node == "auto":
if torch.cuda.is_available():
num_proc = torch.cuda.device_count()
device_type = "gpu"
else:
num_proc = os.cpu_count()
device_type = "cpu"
else:
raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}")
)
return num_proc

3.3 脚本入口

脚本入口主要代码如下,可以看到,其调用到了 elastic_launch 来完成功能,所以我们下一节就要顺藤摸瓜来看看这个函数。

from torch.distributed.launcher.api import LaunchConfig, elastic_launch

def run(args):
if args.standalone: # 有两种模式:Standalone 模式和分布式模式,这里要判断一下
args.rdzv_backend = "c10d"
args.rdzv_endpoint = "localhost:29400"
args.rdzv_id = str(uuid.uuid4())
log.info(
f"\n**************************************\n"
f"Rendezvous info:\n"
f"--rdzv_backend={args.rdzv_backend} "
f"--rdzv_endpoint={args.rdzv_endpoint} "
f"--rdzv_id={args.rdzv_id}\n"
f"**************************************\n"
) config, cmd, cmd_args = config_from_args(args)
elastic_launch(
config=config,
entrypoint=cmd,
)(*cmd_args) def main(args=None):
args = parse_args(args)
run(args) if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO, format="[%(levelname)s] %(asctime)s %(module)s: %(message)s"
)
main()

0x04 单体总体流程

我们下面就从 elastic_launch 开始,看看在单节点上如何启动运行。我们首先给出一个总体示意图,图上是两个节点,每个节点有一个 agent,agent下面是一个 worker group,组下面是4个worker。

4.1 小例子

我们再从源码中找一个例子来看看,这里只是设置了两个workers。

import uuid
import torch
from torch.distributed.launcher.api import LaunchConfig, elastic_launch def worker_fn(t1, t2):
return torch.add(t1, t2) def main():
t1 = torch.rand((3,3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True) config = LaunchConfig(
min_nodes=2,
max_nodes=4,
nproc_per_node=1,
run_id=str(uuid.uuid4()),
role="trainer",
rdzv_endpoint="localhost:29400",
rdzv_backend="c10d",
max_restarts=1,
monitor_interval=1,
start_method="spawn",
) outputs = elastic_launch(config, worker_fn)(t1, t2) if __name__ == '__main__':
main()

输出如下,可以看到有两个 worker 进程 和一个 agent 进程。

{"name": "torchelastic.worker.status.SUCCEEDED", "source": "WORKER", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": 0, "group_rank": 0, "worker_id": "12172", "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\", \"local_rank\": [0], \"role_rank\": [0], \"role_world_size\": [2]}", "agent_restarts": 0}}

{"name": "torchelastic.worker.status.SUCCEEDED", "source": "WORKER", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": 1, "group_rank": 0, "worker_id": "3276", "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\", \"local_rank\": [1], \"role_rank\": [1], \"role_world_size\": [2]}", "agent_restarts": 0}}

{"name": "torchelastic.worker.status.SUCCEEDED", "source": "AGENT", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": null, "group_rank": 0, "worker_id": null, "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\"}", "agent_restarts": 0}}

4.2 入口

顺着代码我们深入挖掘一下。elastic_launch 的作用就是启动一个 torchelastic agent,然后通过这个 agent来调用用户程序入口,agent 会启动 worker 进行训练,并且管理 worker 生命周期

class elastic_launch:
"""
Launches an torchelastic agent on the container that invoked the entrypoint. 1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/
``entrypoint`` can be a function or a command.
2. The return value is a map of each worker's output mapped
by their respective global rank.
""" def __init__(
self,
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
):
self._config = config
self._entrypoint = entrypoint def __call__(self, *args, **kwargs):
return launch_agent(self._config, self._entrypoint, list(args)) # 内部会调用用户程序

4.3 启动代理

launch_agent 启动了一个 LocalElasticAgent,调用了其 run 方法。

@record
def launch_agent(
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
args: List[Any],
) -> Dict[int, Any]:
if not config.run_id:
run_id = str(uuid.uuid4().int)
config.run_id = run_id entrypoint_name = _get_entrypoint_name(entrypoint, args) rdzv_parameters = RendezvousParameters(
backend=config.rdzv_backend,
endpoint=config.rdzv_endpoint,
run_id=config.run_id,
min_nodes=config.min_nodes,
max_nodes=config.max_nodes,
**config.rdzv_configs,
) agent = None
rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
master_addr, master_port = _get_addr_and_port(rdzv_parameters)
try:
spec = WorkerSpec( # 1. 得到spec
role=config.role,
local_world_size=config.nproc_per_node,
entrypoint=entrypoint,
args=tuple(args),
rdzv_handler=rdzv_handler, # RendezvousHandler
max_restarts=config.max_restarts,
monitor_interval=config.monitor_interval,
redirects=config.redirects,
tee=config.tee,
master_addr=master_addr,
master_port=master_port,
) cfg = metrics.MetricsConfig(config.metrics_cfg) if config.metrics_cfg else None
metrics.initialize_metrics(cfg) agent = LocalElasticAgent( # 2. 构建代理
spec=spec, start_method=config.start_method, log_dir=config.log_dir
) result = agent.run() # 3. 启动代理
events.record(agent.get_agent_status_event(WorkerState.SUCCEEDED))
if result.is_failed():
# ChildFailedError is treated specially by @record
# if the error files for the failed children exist
# @record will copy the first error (root cause)
# to the error file of the launcher process.
raise ChildFailedError(
name=entrypoint_name,
failures=result.failures,
)
else:
return result.return_values
except ChildFailedError:
raise
except Exception:
if agent:
events.record(agent.get_agent_status_event(WorkerState.FAILED))
else:
events.record(_construct_event(config))
raise
finally:
rdzv_handler.shutdown()

这里有几个关键点:

4.3.1 WorkerSpec

WorkerSpec :这是配置信息,里面包含了代理所需要的某些全局信息,比如 RendezvousHandler,role,entry(用户函数)。

spec = {WorkerSpec}
args = {tuple: 2} (tensor, tensor)
fn = {NoneType} None
local_world_size = {int} 1
master_addr = {NoneType} None
master_port = {NoneType} None
max_restarts = {int} 1
monitor_interval = {int} 1
rdzv_handler = {DynamicRendezvousHandler}
redirects = {Std} Std.NONE
role = {str} 'trainer'
tee = {Std} Std.NONE
entry = worker_fn

代理会从这里提取各种所需信息。比如_start_workers 会从中获取 store。

use_agent_store = spec.rdzv_handler.get_backend() == "static"

此时逻辑为:

+--------------------------+      +---------------------------------------------------+
|LocalElasticAgent | | WorkerSpec |
| | | |
| WorkerSpec +--------------> | rdzv_handler = {DynamicRendezvousHandler} --------+
| | | | |
| rdzv_run_id | | entry = worker_fn | |
| | | | |
| store | | role = {str} 'trainer' | |
| | | | |
| | +---------------------------------------------------+ |
| | |
| | |
| | |
| | |
| | +-----------------------------------------+ |
+--------------------------+ |DynamicRendezvousHandler | |
| | |
| | |
| _settings: RendezvousSettings | <---+
| |
| _store: Store |
| |
| _state_holder: _RendezvousStateHolder |
| |
| _op_executor: _RendezvousOpExecutor |
| |
+-----------------------------------------+

4.3.2 WorkerGroup

WorkerGroup 代表了一个工作组。WorkerGroup 作为一个整体来管理多个 workers,进行批量处理。

class WorkerGroup:
"""
Represents the set of ``Worker`` instances for the given ``WorkerSpec``
managed by ``ElasticAgent``. Whether the worker group contains cross
instance workers or not depends on the implementation of the agent.
""" __slots__ = ["spec", "workers", "store", "group_rank", "group_world_size", "state"] def __init__(self, spec: WorkerSpec):
self.spec = spec
self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)] # assigned after rdzv
self.store = None
self.group_rank = None
self.group_world_size = None self.state = WorkerState.INIT

在SimpleElasticAgent 初始化之中,会建立一个 WorkerGroup。

class SimpleElasticAgent(ElasticAgent):
"""
An ``ElasticAgent`` that manages workers (``WorkerGroup``)
for a single ``WorkerSpec`` (e.g. one particular type of worker role).
""" def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
self._worker_group = WorkerGroup(spec)
self._remaining_restarts = self._worker_group.spec.max_restarts
self._store = None
self._exit_barrier_timeout = exit_barrier_timeout
self._total_execution_time = 0

具体如下:

+-----------------------------+      +------------------------------------------------+
| LocalElasticAgent | | WorkerSpec |
| | | |
| +------------------------+ | | rdzv_handler = {DynamicRendezvousHandler} -------+
| |WorkerGroup | | | | |
| | spec +--------------> | entry = worker_fn | |
| | workers | | | | |
| | store | | | role = {str} 'trainer' | |
| | group_rank | | | | |
| | group_world_size | | +------------------------------------------------+ |
| | | | |
| +------------------------+ | |
| | |
| rdzv_run_id | |
| store | +-----------------------------------------+ |
| | |DynamicRendezvousHandler | |
+-----------------------------+ | | |
| | |
| _settings: RendezvousSettings | <--+
| |
| _store: Store |
| |
| _state_holder: _RendezvousStateHolder |
| |
| _op_executor: _RendezvousOpExecutor |
| |
+-----------------------------------------+

4.4 代理运行

SimpleElasticAgent 是 LocalElasticAgent 的基类,所以会先运行到WorkerSpec.run 方法这里,run方法则调用了 _invoke_run。

    @prof
def run(self, role: str = DEFAULT_ROLE) -> RunResult:
start_time = time.monotonic()
try:
result = self._invoke_run(role) # 调用
self._total_execution_time = int(time.monotonic() - start_time)
self._record_metrics(result)
self._record_worker_events(result)
return result
finally:
# record the execution time in case there were any exceptions during run.
self._total_execution_time = int(time.monotonic() - start_time)
self._shutdown()

4.5 代理主循环

代理在 invoke_run 之中做如下操作:

  • 启动 _initialize_workers,这里会使用 _rendezvous 构建一个 rendezvous,然后调用 _start_workers 启动 workers。
  • 进入 while True 循环,在循环之中:
    • 通过 _monitor_workers 定期轮训用户程序运行情况,得到客户进程运行结果,然后依据情况作出判断。

      • 如果程序正常结束,则返回。
      • 如果程序出错,则重试,即重启所有 workers,如果重试次数达到依然有问题,就结束所有workers。
      • 如果节点成员关系有变化,比如scale up就会有新的节点在waiting,这时候就重启所有workers。
    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
# NOTE: currently only works for a single role spec = self._worker_group.spec
role = spec.role self._initialize_workers(self._worker_group) # 启动worker
monitor_interval = spec.monitor_interval
rdzv_handler = spec.rdzv_handler while True:
assert self._worker_group.state != WorkerState.INIT
# 定期监控
time.sleep(monitor_interval)
# 监控客户程序运行情况
run_result = self._monitor_workers(self._worker_group) # 得到进程运行结果
state = run_result.state
self._worker_group.state = state put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
put_metric(f"workers.{role}.{state.name.lower()}", 1) if state == WorkerState.SUCCEEDED:
# 程序正常结束
self._exit_barrier()
return run_result
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
# 程序出错
if self._remaining_restarts > 0: # 重试
self._remaining_restarts -= 1
self._restart_workers(self._worker_group)
else:
self._stop_workers(self._worker_group) # 重试次数达到,结束workers
self._worker_group.state = WorkerState.FAILED
self._exit_barrier()
return run_result
elif state == WorkerState.HEALTHY:
# 节点成员关系有变化,比如scale up,就会有新节点waiting
# membership changes do not count as retries
num_nodes_waiting = rdzv_handler.num_nodes_waiting()
group_rank = self._worker_group.group_rank
# 如果有新的节点在waiting,就重启所有workers
if num_nodes_waiting > 0:
self._restart_workers(self._worker_group)
else:
raise Exception(f"[{role}] Worker group in {state.name} state")

于是最终逻辑如下:

+----------------------------------------------+
| LocalElasticAgent |
| | +---------------------------------------------------+
| rdzv_run_id | | WorkerSpec |
| | | |
| store +------------------------+ | | rdzv_handler = {DynamicRendezvousHandler} +-------+
| |WorkerGroup | | | | |
| _pcontext | spec +------------> | entry = worker_fn | |
| | workers | | | | |
| | store | | | role = {str} 'trainer' | |
| | group_rank | | | | |
| | group_world_size | | +---------------------------------------------------+ |
| | | | |
| +------------------------+ | |
| +----------------------------------------+ | |
| | _invoke_run | | |
| | | | +-----------------------------------------+ |
| | _initialize_workers +------------------------+ |DynamicRendezvousHandler | |
| | | | | | | |
| | | | | | | |
| | while True: | | | | _settings: RendezvousSettings | <---+
| | _monitor_workers(_worker_group) | | | | |
| | + | | | | _store: Store |
| | | _pcontext.wait | | | | |
| | | | | | | _state_holder: _RendezvousStateHolder |
| +----------------------------------------+ | | | |
| | | | | _op_executor: _RendezvousOpExecutor |
+----------------------------------------------+ | | |
| | +-----------------------------------------+
| |
v v
+-------------------------------------------------+
| +------------+ +------------+ +------------+ |
| |Process | |Process | |Process | |
| | | | | | | |
| | work_fn | | work_fn | | work_fn | |
| | | | | | | |
| +------------+ +------------+ +------------+ |
+-------------------------------------------------+

手机如下:

至此,脚本如何启动和单体流程我们分析完毕,下一篇我们来具体分析代理。

0xFF 参考

[PyTorch Elastic源码阅读](

[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程的更多相关文章

  1. [源码解析] PyTorch 分布式之弹性训练(3)---代理

    [源码解析] PyTorch 分布式之弹性训练(3)---代理 目录 [源码解析] PyTorch 分布式之弹性训练(3)---代理 0x00 摘要 0x01 总体背景 1.1 功能分离 1.2 Re ...

  2. [源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

    [源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑 目录 [源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑 0x00 ...

  3. [源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎

    [源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎 目录 [源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎 0x00 摘要 0x0 ...

  4. [源码解析] PyTorch 分布式之弹性训练(6)---监控/容错

    [源码解析] PyTorch 分布式之弹性训练(6)---监控/容错 目录 [源码解析] PyTorch 分布式之弹性训练(6)---监控/容错 0x00 摘要 0x01 总体逻辑 1.1 Node集 ...

  5. [源码解析] PyTorch 分布式之弹性训练(7)---节点变化

    [源码解析] PyTorch 分布式之弹性训练(7)---节点变化 目录 [源码解析] PyTorch 分布式之弹性训练(7)---节点变化 0x00 摘要 0x01 变化方式 1.1 Scale-d ...

  6. [源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

    [源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路 目录 [源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路 0x00 摘要 0x01 痛点 0x02 难点 0 ...

  7. [源码解析] PyTorch 分布式(1)------历史和概述

    [源码解析] PyTorch 分布式(1)------历史和概述 目录 [源码解析] PyTorch 分布式(1)------历史和概述 0x00 摘要 0x01 PyTorch分布式的历史 1.1 ...

  8. [源码解析] PyTorch分布式(5) ------ DistributedDataParallel 总述&如何使用

    [源码解析] PyTorch 分布式(5) ------ DistributedDataParallel 总述&如何使用 目录 [源码解析] PyTorch 分布式(5) ------ Dis ...

  9. [源码解析] PyTorch 分布式(18) --- 使用 RPC 的分布式管道并行

    [源码解析] PyTorch 分布式(18) --- 使用 RPC 的分布式管道并行 目录 [源码解析] PyTorch 分布式(18) --- 使用 RPC 的分布式管道并行 0x00 摘要 0x0 ...

随机推荐

  1. 【豆科基因组】豇豆Cowpea,Vigna unguiculata [L.] Walp.基因组2019PJ

    目录 来源 结果 基因组大小估计 采用stitching方法组装 修改豇豆染色体编号 基因注释和重复DNA 豇豆遗传多样性 SNP和INDEL Vu03 上 4.2-Mb 染色体倒位的鉴定 与其他暖季 ...

  2. 59. Divide Two Integers

    Divide Two Integers My Submissions QuestionEditorial Solution Total Accepted: 66073 Total Submission ...

  3. 完美png图片添加水印类

    完美png图片添加水印类 被添加水印图片和水印图片都可以是png,保证透明无色背景,可调节透明度 <?phpclass Imgshuiyin{ /* 缩略图相关常量定义 */ const THU ...

  4. 什么是DDL,DML,DCL

    转载自  https://www.2cto.com/database/201610/555167.html DML.DDL.DCL区别 . 总体解释: DML(data manipulation la ...

  5. Shell $()、${}、$[]、$(())

    目录 Shell中的 $().${}.$[].$(()) $().${} 替换 ${} 变量内容的替换.删除.取代 数组 $[].$(()) 运算符 Shell中的 $().${}.$[].$(()) ...

  6. 浅讲.Net 6 之 WebApplicationBuilder

    介绍 .Net 6为我们带来的一种全新的引导程序启动的方式.与之前的拆分成Program.cs和Startup不同,整个引导启动代码都在Program.cs中. WebApplicationBuild ...

  7. A Child's History of England.21

    There was one tall Norman Knight who rode before the Norman army on a prancing horse, throwing up hi ...

  8. [云原生]Docker - 镜像

    目录 Docker镜像 获取镜像 列出本地镜像 创建镜像 方法一:修改已有镜像 方法二:通过Dockerfile构建镜像 方法三:从本地文件系统导入 上传镜像 保存和载入镜像 移除本地镜像 镜像的实现 ...

  9. 容器的分类与各种测试(二)——vector部分用法

    向量 vector 是一种对象实体, 能够容纳许多其他类型相同的元素, 因此又被称为容器. 与string相同, vector 同属于STL(Standard Template Library, 标准 ...

  10. 【swift】CoreData Crash(崩溃)(Failed to call designated initializer on NSManagedObject class)

    感谢另一篇博客:https://blog.csdn.net/devday/article/details/6577985 里面的图片和介绍,发现问题如他描述的一样,没有bundle 我的Xcode版本 ...