现在有四张卡,但是部署在windows10系统上,想尝试下在windows上使用单机多卡进行分布式训练,网上找了一圈硬是没找到相关的文章。以下是踩坑过程。

首先,pytorch的版本必须是大于1.7,这里使用的环境是:

pytorch==1.12+cu11.6
四张4090显卡
python==3.7.6

使用nn.DataParallel进行分布式训练

这一种方式较为简单:

首先我们要定义好使用的GPU的编号,GPU按顺序依次为0,1,2,3。gpu_ids可以通过命令行的形式传入:

gpu_ids = args.gpu_ids.split(',')
gpu_ids = [int(i) for i in gpu_ids]
torch.cuda.set_device('cuda:{}'.format(gpu_ids[0]))

创建模型后用nn.DataParallel进行处理,

 model.cuda()
r_model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])

对,没错,只需要这么两步就行了。需要注意的是保存模型后进行加载时,需要先用nn.DataParallel进行处理,再加载权重,不然参数名没对齐会报错。

checkpoint = torch.load(checkpoint_path)
model.cuda()
r_model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])
r_model.load_state_dict(checkpoint['state_dict'])

如果不使用分布式加载模型,你需要对权重进行映射:

new_start_dict = {}
for k, v in checkpoint['state_dict'].items():
new_start_dict["module." + k] = v
model.load_state_dict(new_start_dict)

使用Distributed进行分布式训练

首先了解一下概念:

node:主机数,单机多卡就一个主机,也就是1。

rank:当前进程的序号,用于进程之间的通讯,rank=0的主机为master节点。

local_rank:当前进程对应的GPU编号。

world_size:总的进程数。

在windows中,我们需要在py文件里面使用:

import os
os.environ["CUDA_VISIBLE_DEVICES]='0,1,3'

来指定使用的显卡。

假设现在我们使用上面的三张显卡,运行时显卡会重新按照0-N进行编号,有:

[38664] rank = 1, world_size = 3, n = 1, device_ids = [1]
[76032] rank = 0, world_size = 3, n = 1, device_ids = [0]
[23208] rank = 2, world_size = 3, n = 1, device_ids = [2]

也就是进程0使用第1张显卡,进行1使用第2张显卡,进程2使用第三张显卡。

有了上述的基本知识,再看看具体的实现。

使用torch.distributed.launch启动

使用torch.distributed.launch启动时,我们必须要在args里面添加一个local_rank参数,也就是:

parser.add_argument("--local_rank", type=int, default=0)

1、初始化:

import torch.distributed as dist

env_dict = {
key: os.environ[key]
for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
}
current_work_dir = os.getcwd()
init_method = f"file:///{os.path.join(current_work_dir, 'ddp_example')}"
dist.init_process_group(backend="gloo", init_method=init_method, rank=int(env_dict["RANK"]),
world_size=int(env_dict["WORLD_SIZE"]))

这里需要重点注意,这种启动方式在环境变量中会存在RANK和WORLD_SIZE,我们可以拿来用。backend必须指定为gloo,init_method必须是file:///,而且每次运行完一次,下一次再运行前都必须删除生成的ddp_example,不然会一直卡住。

2、构建模型并封装

local_rank会自己绑定值,不再是我们--local_rank指定的。

 model.cuda(args.local_rank)
r_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)

3、构建数据集加载器并封装

  train_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, train_file))
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size,
collate_fn=collate.collate_fn, num_workers=4, sampler=train_sampler)

4、计算损失函数

我们把每一个GPU上的loss进行汇聚后计算。

def loss_reduce(self, loss):
rt = loss.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= self.args.local_world_size
return rt loss = self.criterion(outputs, labels)
torch.distributed.barrier()
loss = self.loss_reduce(loss)

注意打印相关信息和保存模型的时候我们通常只需要在local_rank=0时打印。同时,在需要将张量转换到GPU上时,我们需要指定使用的GPU,通过local_rank指定就行,即data.cuda(args.local_rank),保证数据在对应的GPU上进行处理。

5、启动

在windows下需要把换行符去掉,且只变为一行。

python -m torch.distributed.launch \
--nnode=1 \
--node_rank=0 \
--nproc_per_node=3 \
main_distributed.py \
--local_world_size=3 \
--bert_dir="../model_hub/chinese-bert-wwm-ext/" \
--data_dir="./data/cnews/" \
--data_name="cnews" \
--log_dir="./logs/" \
--output_dir="./checkpoints/" \
--num_tags=10 \
--seed=123 \
--max_seq_len=512 \
--lr=3e-5 \
--train_batch_size=64 \
--train_epochs=1 \
--eval_batch_size=64 \
--do_train \
--do_predict \
--do_test

nproc_per_node、local_world_size和GPU的数目保持一致。

使用torch.multiprocessing启动

使用torch.multiprocessing启动和使用torch.distributed.launch启动大体上是差不多的,有一些地方需要注意。

mp.spawn(main_worker, nprocs=args.nprocs, args=(args,))

main_worker是我们的主运行函数,dist.init_process_group要放在这里面,而且第一个参数必须为local_rank。即main_worker(local_rank, args)

nprocs是进程数,也就是使用的GPU数目。

args按顺序传入main_worker真正使用的参数。

其余的就差不多。

启动指令:

python main_mp_distributed.py \
--local_world_size=4 \
--bert_dir="../model_hub/chinese-bert-wwm-ext/" \
--data_dir="./data/cnews/" \
--data_name="cnews" \
--log_dir="./logs/" \
--output_dir="./checkpoints/" \
--num_tags=10 \
--seed=123 \
--max_seq_len=512 \
--lr=3e-5 \
--train_batch_size=64 \
--train_epochs=1 \
--eval_batch_size=64 \
--do_train \
--do_predict \
--do_test

最后需要说明的,假设我们设置的batch_size=64,那么实际上的batch_size = int(batch_size / GPU数目)。

附上完整的基于bert的中文文本分类单机多卡训练代码:https://github.com/taishan1994/pytorch_bert_chinese_text_classification

参考

https://github.com/tczhangzhi/pytorch-distributed

https://murphypei.github.io/blog/2020/09/pytorch-distributed

https://pytorch.org/docs/master/distributed.html?highlight=all_gather#torch.distributed.all_gather

https://github.com/lesliejackson/PyTorch-Distributed-Training

https://github.com/pytorch/examples/blob/ddp-tutorial-code/distributed/ddp/example.py

996黄金一代:[原创][深度][PyTorch] DDP系列第一篇:入门教程

「新生手册」:PyTorch分布式训练 - 知乎 (zhihu.com)

windows下使用pytorch进行单机多卡分布式训练的更多相关文章

  1. windows下安装pytorch

    安装: https://blog.csdn.net/xiangxianghehe/article/details/80103095 Windows下通过pip安装PyTorch 0.4.0 impor ...

  2. 云原生的弹性 AI 训练系列之二:PyTorch 1.9.0 弹性分布式训练的设计与实现

    背景 机器学习工作负载与传统的工作负载相比,一个比较显著的特点是对 GPU 的需求旺盛.在之前的文章中介绍过(https://mp.weixin.qq.com/s/Nasm-cXLtJObjLwLQH ...

  3. Pytorch使用分布式训练,单机多卡

    pytorch的并行分为模型并行.数据并行 左侧模型并行:是网络太大,一张卡存不了,那么拆分,然后进行模型并行训练. 右侧数据并行:多个显卡同时采用数据训练网络的副本. 一.模型并行 二.数据并行 数 ...

  4. Windows 下单机最大TCP连接数

    在做Socket 编程时,我们经常会要问,单机最多可以建立多少个 TCP 连接,本文将介绍如何调整系统参数来调整单机的最大TCP连接数. Windows 下单机的TCP连接数有多个参数共同决定,下面一 ...

  5. windows下配置cuda9.0和pytorch

    今天看了看pytorch官网竟然支持windows了,赶紧搞一个. 下载cuda 9.0  https://developer.nvidia.com/cuda-downloads 下载anaconda ...

  6. PyTorch在64位Windows下的Conda包(转载)

    PyTorch在64位Windows下的Conda包 昨天发了一篇PyTorch在64位Windows下的编译过程的文章,有朋友觉得能不能发个包,这样就不用折腾了.于是,这个包就诞生了.感谢@晴天14 ...

  7. redis在Windows下以后台服务一键搭建哨兵(主从复制)模式(单机)

    redis在Windows下以后台服务一键搭建哨兵(主从复制)模式(单机) 一.概述 此教程介绍如何在windows系统中单机布置redis哨兵模式(主从复制),同时要以后台服务的模式运行.布置以脚本 ...

  8. redis在Windows下以后台服务一键搭建集群(单机--伪集群)

    redis在Windows下以后台服务一键搭建集群(单机--伪集群) 一.概述 此教程介绍如何在windows系统中同一台机器上布置redis伪集群,同时要以后台服务的模式运行.布置以脚本的形式,一键 ...

  9. windows下PyTorch安装之路记录

    最近两天被windows下pytorch的安装给搞得很烦了,不过在今天终于安装成功了,如下图所示 下面详细说下此次安装的详细记录吧.我的电脑环境是Windows10+cuda9.0+cudnn7.1. ...

  10. Windows下nacos单机部分发现的坑

    一.下载nacos的地址: https://github.com/alibaba/nacos/releases 下载 nacos-server-1.3.2.tar.gz    就好 二.在Window ...

随机推荐

  1. JavaSSM

    Day1221 一.IT行业分类 前端 用户界面,眼睛能看到的,视觉效果比较. html5.css和css3.javascript.jquery.技术基础 bootstrap(css框架).vue.j ...

  2. swift 应用内切换语言

    1:在project info中的locations添加需要的语言 2:创建Localizable.strings文件 点击右边的localization勾选需要的语言 3:创建InfoPlist.s ...

  3. Abp Abp.AspNetZeroCore 2.0.0 2.1.1 Path

    纯手工修改,移除校验代码可调试. 将文件复制到 %userprofile%\.nuget\packages\abp.aspnetzerocore 目录中 替换对应的文件 Abp.AspNetZeroC ...

  4. Date 对象 定时器

    日期对象 Date 概述:date是表示日期时间的对象,主要的方法是获取时间和设置日期时间. date声明 使用new Date声明 有4种方式 1.不设参数 是获取当前的本地时间 var date ...

  5. 【python】第一模块 步骤五 第一课、内存管理机制

    第一课.内存管理机制 一.课程介绍 1.1 课程概要 课程概要 赋值语句的内存分析 垃圾回收机制 内存管理机制 课程目标 掌握赋值语句内存分析方法 掌握id()和is()的使用 了解python的垃圾 ...

  6. 2003031118—李伟—Python数据分析第七周作业—MySQL的安装以及使用

    项目    MySQL的安装以及使用 课程班级博客链接 20级数据班(本) 这个作业要求链接 作业要求 博客名称 2003031118-李伟-Python数据分析第七周作业-MySQL的安装以及使用 ...

  7. 通过富文本编辑器操作HTML页面

    <pre id="list_css" class="brush:css;toolbar:false">/*外部css,多个换行*/ https:// ...

  8. openvas漏洞扫描:使用openvas时扫描漏洞时,报告中显示的数据与数据库数据不同

    使用openvas设备进行漏洞扫描时,报告中的漏洞数量与readis数据库中查找到的漏洞数量不同 原因是,openvas的代码中默认在报告中显示的最小质量检测为70%.如图: 上图详细链接为:http ...

  9. OSIDP-内存管理-07

    专业术语 页框:内存中固定长度的块. 页:外存中固定长度的块. 段:外存中可变长度的块. 内存管理需求 重定位:程序从内存换出到外存后,再换回内存时,在内存空间中的位置和原先的位置有极大可能不相同.此 ...

  10. K8S中Pod概念

    一.资源限制 Pod 是 kubernetes 中最小的资源管理组件,Pod 也是最小化运行容器化应用的资源对象.一个 Pod 代表着集群中运行的一个进程.kubernetes 中其他大多数组件都是围 ...