分布式机器学习中,参数服务器(Parameter Server)用于管理和共享模型参数,其基本思想是将模型参数存储在一个或多个中央服务器上,并通过网络将这些参数共享给参与训练的各个计算节点。每个计算节点可以从参数服务器中获取当前模型参数,并将计算结果返回给参数服务器进行更新。

为了保持模型一致性,通常采用下列两种方法:

  1. 将模型参数保存在一个集中的节点上,当一个计算节点要进行模型训练时,可从集中节点获取参数,进行模型训练,然后将更新后的模型推送回集中节点。由于所有计算节点都从同一个集中节点获取参数,因此可以保证模型一致性。
  2. 每个计算节点都保存模型参数的副本,因此要定期强制同步模型副本,每个计算节点使用自己的训练数据分区来训练本地模型副本。在每个训练迭代后,由于使用不同的输入数据进行训练,存储在不同计算节点上的模型副本可能会有所不同。因此,每一次训练迭代后插入一个全局同步的步骤,这将对不同计算节点上的参数进行平均,以便以完全分布式的方式保证模型的一致性,即All-Reduce范式

PS架构

在该架构中,包含两个角色:parameter server和worker

parameter server将被视为master节点在Master/Worker架构,而worker将充当计算节点负责模型训练

整个系统的工作流程分为4个阶段:

  1. Pull Weights: 所有worker从参数服务器获取权重参数
  2. Push Gradients: 每一个worker使用本地的训练数据训练本地模型,生成本地梯度,之后将梯度上传参数服务器
  3. Aggregate Gradients:收集到所有计算节点发送的梯度后,对梯度进行求和
  4. Model Update:计算出累加梯度,参数服务器使用这个累加梯度来更新位于集中服务器上的模型参数

可见,上述的Pull Weights和Push Gradients涉及到通信,首先对于Pull Weights来说,参数服务器同时向worker发送权重,这是一对多的通信模式,称为fan-out通信模式。假设每个节点(参数服务器和工作节点)的通信带宽都为1。假设在这个数据并行训练作业中有N个工作节点,由于集中式参数服务器需要同时将模型发送给N个工作节点,因此每个工作节点的发送带宽(BW)仅为1/N。另一方面,每个工作节点的接收带宽为1,远大于参数服务器的发送带宽1/N。因此,在拉取权重阶段,参数服务器端存在通信瓶颈。

对于Push Gradients来说,所有的worker并发地发送梯度给参数服务器,称为fan-in通信模式,参数服务器同样存在通信瓶颈。

基于上述讨论,通信瓶颈总是发生在参数服务器端,将通过负载均衡解决这个问题

将模型划分为N个参数服务器,每个参数服务器负责更新1/N的模型参数。实际上是将模型参数分片(sharded model)并存储在多个参数服务器上,可以缓解参数服务器一侧的网络瓶颈问题,使得参数服务器之间的通信负载减少,提高整体的通信效率。

代码实现

定义网络结构:

class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu") self.conv1 = nn.Conv2d(1,32,3,1).to(device)
self.dropout1 = nn.Dropout2d(0.5).to(device)
self.conv2 = nn.Conv2d(32,64,3,1).to(device)
self.dropout2 = nn.Dropout2d(0.75).to(device)
self.fc1 = nn.Linear(9216,128).to(device)
self.fc2 = nn.Linear(128,20).to(device)
self.fc3 = nn.Linear(20,10).to(device) def forward(self,x):
x = self.conv1(x)
x = self.dropout1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.dropout2(x)
x = F.max_pool2d(x,2)
x = torch.flatten(x,1) x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x) output = F.log_softmax(x,dim=1) return output

如上定义了一个简单的CNN

实现参数服务器:

class ParamServer(nn.Module):
def __init__(self):
super().__init__()
self.model = Net() if torch.cuda.is_available():
self.input_device = torch.device("cuda:0")
else:
self.input_device = torch.device("cpu") self.optimizer = optim.SGD(self.model.parameters(),lr=0.5) def get_weights(self):
return self.model.state_dict() def update_model(self,grads):
for para,grad in zip(self.model.parameters(),grads):
para.grad = grad self.optimizer.step()
self.optimizer.zero_grad()

get_weights获取权重参数,update_model更新模型,采用SGD优化器

实现worker:

class Worker(nn.Module):
def __init__(self):
super().__init__()
self.model = Net()
if torch.cuda.is_available():
self.input_device = torch.device("cuda:0")
else:
self.input_device = torch.device("cpu") def pull_weights(self,model_params):
self.model.load_state_dict(model_params) def push_gradients(self,batch_idx,data,target):
data,target = data.to(self.input_device),target.to(self.input_device)
output = self.model(data)
data.requires_grad = True
loss = F.nll_loss(output,target)
loss.backward()
grads = [] for layer in self.parameters():
grad = layer.grad
grads.append(grad) print(f"batch {batch_idx} training :: loss {loss.item()}") return grads

Pull_weights获取模型参数,push_gradients上传梯度

训练

训练数据集为MNIST

import torch
from torchvision import datasets,transforms from network import Net
from worker import *
from server import * train_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=True,
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,),(0.3081,))])),
batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=False,
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,),(0.3081,))])),
batch_size=128, shuffle=True) def main():
server = ParamServer()
worker = Worker() for batch_idx, (data,target) in enumerate(train_loader):
params = server.get_weights()
worker.pull_weights(params)
grads = worker.push_gradients(batch_idx,data,target)
server.update_model(grads) print("Done Training") if __name__ == "__main__":
main()

分布式机器学习(Parameter Server)的更多相关文章

  1. 分布式机器学习系统笔记(一)——模型并行,数据并行,参数平均,ASGD

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术.应用感兴趣的同学加入. 文章索引::"机器学 ...

  2. 百度DMLC分布式深度机器学习开源项目(简称“深盟”)上线了如xgboost(速度快效果好的Boosting模型)、CXXNET(极致的C++深度学习库)、Minerva(高效灵活的并行深度学习引擎)以及Parameter Server(一小时训练600T数据)等产品,在语音识别、OCR识别、人脸识别以及计算效率提升上发布了多个成熟产品。

    百度为何开源深度机器学习平台?   有一系列领先优势的百度却选择开源其深度机器学习平台,为何交底自己的核心技术?深思之下,却是在面对业界无奈时的远见之举.   5月20日,百度在github上开源了其 ...

  3. 【分布式计算】MapReduce的替代者-Parameter Server

    原文:http://blog.csdn.net/buptgshengod/article/details/46819051 首先还是要声明一下,这个文章是我在入职阿里云1个月以来,对于分布式计算的一点 ...

  4. MXNet之ps-lite及parameter server原理

    MXNet之ps-lite及parameter server原理 ps-lite框架是DMLC组自行实现的parameter server通信框架,是DMLC其他项目的核心,例如其深度学习框架MXNE ...

  5. 转:Parameter Server 详解

    Parameter Server 详解   本博客仅为作者记录笔记之用,不免有很多细节不对之处. 还望各位看官能够见谅,欢迎批评指正. 更多相关博客请猛戳:http://blog.csdn.net/c ...

  6. Adam:大规模分布式机器学习框架

    引子 转载请注明:http://blog.csdn.net/stdcoutzyx/article/details/46676515 又是好久没写博客,记得有一次看Ng大神的訪谈录,假设每周读三篇论文, ...

  7. 分布式机器学习框架:MxNet 前言

           原文连接:MxNet和Caffe之间有什么优缺点一.前言: Minerva: 高效灵活的并行深度学习引擎 不同于cxxnet追求极致速度和易用性,Minerva则提供了一个高效灵活的平台 ...

  8. [Distributed ML] Parameter Server & Ring All-Reduce

    Resource ParameterServer入门和理解[较为详细,涉及到另一个框架:ps-lite] 一文读懂「Parameter Server」的分布式机器学习训练原理 并行计算与机器学习[很有 ...

  9. parameter server学习

    关于parameter server的学习: https://www.zybuluo.com/Dounm/note/517675 机器学习系统相比于其他系统而言,有一些自己的独特特点.例如: 迭代性: ...

  10. 分布式机器学习框架:CXXNet

    caffe是很优秀的dl平台.影响了后面很多相关框架.        cxxnet借鉴了很多caffe的思想.相比之下,cxxnet在实现上更加干净,例如依赖很少,通过mshadow的模板化使得gpu ...

随机推荐

  1. Jira使用浅谈篇二

    本篇参考:https://university.atlassian.com/student/collection/850385/path/1083901 本篇接上文,上文已经对项目设置了一个基础的配置 ...

  2. Kafka 管理【主题、分区、消费者组】

    更多内容,前往 IT-BLOG 主题操作 使用 kafka-topics.sh 工具可以执行主题的大部分操作.可以用它创建.修改.删除和查看集群里的主题.要使用该工具的全部功能,需要通过 --zook ...

  3. 使用 ApplicationContextAware 定义 SpringContextHolder 类

    需求:使用 @autowired注入一些对象,但发现不可以直接使用@Autowired,因为方法是static的,要使用该方法当前对象也必须是static,正常情况下@Autowired无法注入静态的 ...

  4. Windows10绿色植物主题Kemikal

    给大家分享一个Windows10的主题,Kemikal主题,内置8张绿色植物树木的壁纸.使用这个主题前需要破解系统主题文件. 想要完整的使用这个主题,需要下载安装下方的三个文件: Windows10主 ...

  5. Alchemy Nft黑客松任务(第一周)

    Alchemy是什么项目? 2019年12月,Alchemy完成1500万美元A轮融资,资方为Pantera Capital,斯坦福大学,Coinbase,三星等. 2021年4月,Alchemy以5 ...

  6. 有一个公网IP地址

    这几天在家里拉了一条300M+的宽带,但是遇到了一些坑,本文就简单说明一下如下: 突发此次需求是这样的:阿里云有台服务器公网带宽是1M的,虽说带宽小,但是数据中心的服务器显然是稳定的,只是带宽太小,有 ...

  7. [Linux]Linux大文件已删除,但df查看已使用的空间并未减少解决【待续】

    1 问题描述 X 参考文献 Linux大文件已删除,但df查看已使用的空间并未减少解决 - ChinaUnix linux磁盘空间未及时释放 - 博客园 linux磁盘目录占用空间分析工具之ncdu ...

  8. 【JSOI2008】最大值

    [JSOI2008]最大值 线段树裸题!动态RMQ. 这道题的操作是直接在序列末尾添加数值,所以连\(push_{down}\),以及建树什么的都不用了.. 这真是写过的最简短的一道\(seg_{tr ...

  9. subprocess,哈希,日志模块

    hashlib模块: # 1. 先确定你要使用的加密方式: md系列,sha系列 md5 = hashlib.md5() # 指定加密方式 # 2. 进行明文数据的加密 data = 'hello12 ...

  10. 【观察者设计模式详解】C/Java/JS/Go/Python/TS不同语言实现

    简介 观察者模式(Observer Pattern)是一种行为型模式.它定义对象间的一种一对多的依赖关系,当一个对象的状态发生改变时,所有依赖于它的对象都得到通知并被自动更新. 观察者模式使用三个类S ...