分布式机器学习中,参数服务器(Parameter Server)用于管理和共享模型参数,其基本思想是将模型参数存储在一个或多个中央服务器上,并通过网络将这些参数共享给参与训练的各个计算节点。每个计算节点可以从参数服务器中获取当前模型参数,并将计算结果返回给参数服务器进行更新。

为了保持模型一致性,通常采用下列两种方法:

  1. 将模型参数保存在一个集中的节点上,当一个计算节点要进行模型训练时,可从集中节点获取参数,进行模型训练,然后将更新后的模型推送回集中节点。由于所有计算节点都从同一个集中节点获取参数,因此可以保证模型一致性。
  2. 每个计算节点都保存模型参数的副本,因此要定期强制同步模型副本,每个计算节点使用自己的训练数据分区来训练本地模型副本。在每个训练迭代后,由于使用不同的输入数据进行训练,存储在不同计算节点上的模型副本可能会有所不同。因此,每一次训练迭代后插入一个全局同步的步骤,这将对不同计算节点上的参数进行平均,以便以完全分布式的方式保证模型的一致性,即All-Reduce范式

PS架构

在该架构中,包含两个角色:parameter server和worker

parameter server将被视为master节点在Master/Worker架构,而worker将充当计算节点负责模型训练

整个系统的工作流程分为4个阶段:

  1. Pull Weights: 所有worker从参数服务器获取权重参数
  2. Push Gradients: 每一个worker使用本地的训练数据训练本地模型,生成本地梯度,之后将梯度上传参数服务器
  3. Aggregate Gradients:收集到所有计算节点发送的梯度后,对梯度进行求和
  4. Model Update:计算出累加梯度,参数服务器使用这个累加梯度来更新位于集中服务器上的模型参数

可见,上述的Pull Weights和Push Gradients涉及到通信,首先对于Pull Weights来说,参数服务器同时向worker发送权重,这是一对多的通信模式,称为fan-out通信模式。假设每个节点(参数服务器和工作节点)的通信带宽都为1。假设在这个数据并行训练作业中有N个工作节点,由于集中式参数服务器需要同时将模型发送给N个工作节点,因此每个工作节点的发送带宽(BW)仅为1/N。另一方面,每个工作节点的接收带宽为1,远大于参数服务器的发送带宽1/N。因此,在拉取权重阶段,参数服务器端存在通信瓶颈。

对于Push Gradients来说,所有的worker并发地发送梯度给参数服务器,称为fan-in通信模式,参数服务器同样存在通信瓶颈。

基于上述讨论,通信瓶颈总是发生在参数服务器端,将通过负载均衡解决这个问题

将模型划分为N个参数服务器,每个参数服务器负责更新1/N的模型参数。实际上是将模型参数分片(sharded model)并存储在多个参数服务器上,可以缓解参数服务器一侧的网络瓶颈问题,使得参数服务器之间的通信负载减少,提高整体的通信效率。

代码实现

定义网络结构:

class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu") self.conv1 = nn.Conv2d(1,32,3,1).to(device)
self.dropout1 = nn.Dropout2d(0.5).to(device)
self.conv2 = nn.Conv2d(32,64,3,1).to(device)
self.dropout2 = nn.Dropout2d(0.75).to(device)
self.fc1 = nn.Linear(9216,128).to(device)
self.fc2 = nn.Linear(128,20).to(device)
self.fc3 = nn.Linear(20,10).to(device) def forward(self,x):
x = self.conv1(x)
x = self.dropout1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.dropout2(x)
x = F.max_pool2d(x,2)
x = torch.flatten(x,1) x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x) output = F.log_softmax(x,dim=1) return output

如上定义了一个简单的CNN

实现参数服务器:

class ParamServer(nn.Module):
def __init__(self):
super().__init__()
self.model = Net() if torch.cuda.is_available():
self.input_device = torch.device("cuda:0")
else:
self.input_device = torch.device("cpu") self.optimizer = optim.SGD(self.model.parameters(),lr=0.5) def get_weights(self):
return self.model.state_dict() def update_model(self,grads):
for para,grad in zip(self.model.parameters(),grads):
para.grad = grad self.optimizer.step()
self.optimizer.zero_grad()

get_weights获取权重参数,update_model更新模型,采用SGD优化器

实现worker:

class Worker(nn.Module):
def __init__(self):
super().__init__()
self.model = Net()
if torch.cuda.is_available():
self.input_device = torch.device("cuda:0")
else:
self.input_device = torch.device("cpu") def pull_weights(self,model_params):
self.model.load_state_dict(model_params) def push_gradients(self,batch_idx,data,target):
data,target = data.to(self.input_device),target.to(self.input_device)
output = self.model(data)
data.requires_grad = True
loss = F.nll_loss(output,target)
loss.backward()
grads = [] for layer in self.parameters():
grad = layer.grad
grads.append(grad) print(f"batch {batch_idx} training :: loss {loss.item()}") return grads

Pull_weights获取模型参数,push_gradients上传梯度

训练

训练数据集为MNIST

import torch
from torchvision import datasets,transforms from network import Net
from worker import *
from server import * train_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=True,
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,),(0.3081,))])),
batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=False,
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,),(0.3081,))])),
batch_size=128, shuffle=True) def main():
server = ParamServer()
worker = Worker() for batch_idx, (data,target) in enumerate(train_loader):
params = server.get_weights()
worker.pull_weights(params)
grads = worker.push_gradients(batch_idx,data,target)
server.update_model(grads) print("Done Training") if __name__ == "__main__":
main()

分布式机器学习(Parameter Server)的更多相关文章

  1. 分布式机器学习系统笔记(一)——模型并行,数据并行,参数平均,ASGD

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术.应用感兴趣的同学加入. 文章索引::"机器学 ...

  2. 百度DMLC分布式深度机器学习开源项目(简称“深盟”)上线了如xgboost(速度快效果好的Boosting模型)、CXXNET(极致的C++深度学习库)、Minerva(高效灵活的并行深度学习引擎)以及Parameter Server(一小时训练600T数据)等产品,在语音识别、OCR识别、人脸识别以及计算效率提升上发布了多个成熟产品。

    百度为何开源深度机器学习平台?   有一系列领先优势的百度却选择开源其深度机器学习平台,为何交底自己的核心技术?深思之下,却是在面对业界无奈时的远见之举.   5月20日,百度在github上开源了其 ...

  3. 【分布式计算】MapReduce的替代者-Parameter Server

    原文:http://blog.csdn.net/buptgshengod/article/details/46819051 首先还是要声明一下,这个文章是我在入职阿里云1个月以来,对于分布式计算的一点 ...

  4. MXNet之ps-lite及parameter server原理

    MXNet之ps-lite及parameter server原理 ps-lite框架是DMLC组自行实现的parameter server通信框架,是DMLC其他项目的核心,例如其深度学习框架MXNE ...

  5. 转:Parameter Server 详解

    Parameter Server 详解   本博客仅为作者记录笔记之用,不免有很多细节不对之处. 还望各位看官能够见谅,欢迎批评指正. 更多相关博客请猛戳:http://blog.csdn.net/c ...

  6. Adam:大规模分布式机器学习框架

    引子 转载请注明:http://blog.csdn.net/stdcoutzyx/article/details/46676515 又是好久没写博客,记得有一次看Ng大神的訪谈录,假设每周读三篇论文, ...

  7. 分布式机器学习框架:MxNet 前言

           原文连接:MxNet和Caffe之间有什么优缺点一.前言: Minerva: 高效灵活的并行深度学习引擎 不同于cxxnet追求极致速度和易用性,Minerva则提供了一个高效灵活的平台 ...

  8. [Distributed ML] Parameter Server & Ring All-Reduce

    Resource ParameterServer入门和理解[较为详细,涉及到另一个框架:ps-lite] 一文读懂「Parameter Server」的分布式机器学习训练原理 并行计算与机器学习[很有 ...

  9. parameter server学习

    关于parameter server的学习: https://www.zybuluo.com/Dounm/note/517675 机器学习系统相比于其他系统而言,有一些自己的独特特点.例如: 迭代性: ...

  10. 分布式机器学习框架:CXXNet

    caffe是很优秀的dl平台.影响了后面很多相关框架.        cxxnet借鉴了很多caffe的思想.相比之下,cxxnet在实现上更加干净,例如依赖很少,通过mshadow的模板化使得gpu ...

随机推荐

  1. 【2019CCPC秦皇岛:A】Angle Beats 分类讨论 (unordered_map 加 hash)

    题意:n个给定点,q个询问点,每次询问给出一个坐标A,问从n中选定两个点B,C,有多少种方案使得ABC是个直角三角形. 思路:直角三角形能想的就那几个,枚举边,枚举顶点,这个题都行,写的枚举顶点的,A ...

  2. AD域安全攻防实践(附攻防矩阵图)

    以域控为基础架构,通过域控实现对用户和计算机资源的统一管理,带来便利的同时也成为了最受攻击者重点攻击的集权系统. 01.攻击篇 针对域控的攻击技术,在Windows通用攻击技术的基础上自成一套技术体系 ...

  3. P7213 [JOISC2020] 最古の遺跡 3 乱写

    不想写题解了,把写在草稿纸上的东西整理了一下 感谢 crashed 大佬的题解与对本人问题的回答,没有他我就不会搞懂这道神仙计数题.

  4. C_C++常用函数汇总

    1 string.h.cstring(C) (1)字符串连接函数 strcat.strncat strcat(char[ ], const char[ ]) strncat(char[ ], cons ...

  5. Centos 6 部署PPTP服务

    前言:PPTP使用一个TCP连接对隧道进行维护,使用通用路由封装(GRE)技术把数据封装成PPP数据桢通过隧道传送.可以对封装PPP桢中的负载数据进行加密或压缩. 注意:PPTP协议已经被IOS系统所 ...

  6. Vue+ElementUI动态显示el-table某列(值和颜色)的方法

    方法一:结合 template scope组件和 v-if 语法判断 例1:值 <el-table-column prop="status" label="车辆状态 ...

  7. 【牛客小白月赛69】题解与分析A-F【蛋挞】【玩具】【开题顺序】【旅游】【等腰三角形(easy)】【等腰三角形(hard)】

    比赛传送门:https://ac.nowcoder.com/acm/contest/52441 感觉整体难度有点偏大. 作者:Eriktse 简介:19岁,211计算机在读,现役ACM银牌选手力争以通 ...

  8. C++ 猜数字

    #include <iostream> #include <random> #include <limits> namespace random { std::ra ...

  9. python入门教程之十三错误和异常

    作为 Python 初学者,在刚学习 Python 编程时,经常会看到一些报错信息,在前面我们没有提及,这章节我们会专门介绍. Python 有两种错误很容易辨认:语法错误和异常. Python as ...

  10. MySQL数据库与Nacos搭建监控服务

    目录 Nacos部署 项目环境 快速开始 nacos2.2.0版本配置说明 MySQL部署 安装方式 Linux平台(CentOS-Stream-9)部署MySQL 调试防火墙管理工具 MySQL用户 ...