设一共有\(K\)个客户机,

中心服务器初始化模型参数,执行若干轮(round),每轮选取至少1个至多\(K\)个客户机参与训练,接下来每个被选中的客户机同时在自己的本地根据服务器下发的本轮(\(t\)轮)模型\(w_t\)用自己的数据训练自己的模型\(w^k_{t+1}\),上传回服务器。服务器将收集来的各客户机的模型根据各方样本数量用加权平均的方式进行聚合,得到下一轮的模型\(w_{t+1}\):

\[\begin{aligned}
& \qquad w_{t+1} \leftarrow \sum^K_{k=1} \frac{n_k}{n} w^k_{t+1} \qquad\qquad //n_k为客户机k上的样本数量,n为所有被选中客户机的总样本数量\\
\end{aligned}
\]

【伪代码】

\[\begin{aligned}
& 算法1:Federated\ Averaging算法(FedAvg)。 \\
& K个客户端编号为k;B,E,\eta分别代表本地的minibatch\ size,epochs,学习率learning\ rate \\
& \\
& 服务器执行:\\
& \quad 初始化w_0 \\
& \quad for \ 每轮t=1,2,...,do \\
& \qquad m \leftarrow max(C \cdot K,1) \qquad\qquad //C为比例系数 \\
& \qquad S_t \leftarrow (随机选取m个客户端) \\
& \qquad for \ 每个客户端k \in S_t 同时\ do \\
& \qquad \qquad w^k_{t+1} \leftarrow 客户端更新(k,w_t) \\
& \qquad w_{t+1} \leftarrow \sum^K_{k=1} \frac{n_k}{n} w^k_{t+1} \qquad\qquad //n_k为客户机k上的样本数量,n为所有被选中客户机的总样本数量\\
& \\
& 客户端更新(k,w): \qquad \triangleright 在客户端k上运行 \\
& \quad \beta \leftarrow (将P_k分成若干大小为B的batch) \qquad\qquad //P_k为客户机k上数据点的索引集,P_k大小为n_k \\
& \quad for\ 每个本地的epoch\ i(1\sim E) \ do \\
& \qquad for\ batch\ b \in \beta \ do \\
& \qquad \qquad w \leftarrow w-\eta \triangledown l(w;b) \qquad\qquad //\triangledown 为计算梯度,l(w;b)为损失函数\\
& \quad 返回w给服务器
\end{aligned}
\]

为了增加客户机计算量,可以在中心服务器做聚合(加权平均)操作前在每个客户机上多迭代更新几次。计算量由三个参数决定:

  • \(C\),每一轮(round)参与计算的客户机比例。
  • \(E(epochs)\),每一轮每个客户机投入其全部本地数据训练一遍的次数。
  • \(B(batch size)\),用于客户机更新的batch大小。\(B=\infty\)表示batch为全部样本,此时就是full-batch梯度下降了。

当\(E=1\ B=\infty\)时,对应的就是FedSGD,即每一轮客户机一次性将所有本地数据投入训练,更新模型参数。

对于一个有着\(n_k\)个本地样本的客户机\(k\)来说,每轮的本地更新次数为\(u_k=E\cdot \frac{n_k}{B}\)。

参考文献:

  1. H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. Y. Arcas, “Communication-efficient learning of deep networks from decentralized data,” in Proc. AISTATS, 2016, pp. 1273–1282.

联邦平均算法(Federated Averaging Algorithm,FedAvg)的更多相关文章

  1. 谷歌的网页排序算法(PageRank Algorithm)

    本文将介绍谷歌的网页排序算法(PageRank Algorithm),以及它如何从250亿份网页中捞到与你的搜索条件匹配的结果.它的匹配效果如此之好,以至于“谷歌”(google)今天已经成为一个被广 ...

  2. 联邦学习(Federated Learning)

    联邦学习简介        联邦学习(Federated Learning)是一种新兴的人工智能基础技术,在 2016 年由谷歌最先提出,原本用于解决安卓手机终端用户在本地更新模型的问题,其设计目标是 ...

  3. 维特比算法(Viterbi Algorithm)

      寻找最可能的隐藏状态序列(Finding most probable sequence of hidden states) 对于一个特殊的隐马尔科夫模型(HMM)及一个相应的观察序列,我们常常希望 ...

  4. 图像处理之泛洪填充算法(Flood Fill Algorithm)

    泛洪填充算法(Flood Fill Algorithm) 泛洪填充算法又称洪水填充算法是在很多图形绘制软件中常用的填充算法,最熟悉不过就是 windows paint的油漆桶功能.算法的原理很简单,就 ...

  5. 隐马尔科夫模型,第三种问题解法,维比特算法(biterbi) algorithm python代码

    上篇介绍了隐马尔科夫模型 本文给出关于问题3解决方法,并给出一个例子的python代码 回顾上文,问题3是什么, 下面给出,维比特算法(biterbi) algorithm 下面通过一个具体例子,来说 ...

  6. 图像处理------泛洪填充算法(Flood Fill Algorithm) 油漆桶功能

    泛洪填充算法(Flood Fill Algorithm) 泛洪填充算法又称洪水填充算法是在很多图形绘制软件中常用的填充算法,最熟悉不过就是 windows paint的油漆桶功能.算法的原理很简单,就 ...

  7. HMM隐马尔科夫算法(Hidden Markov Algorithm)初探

    1. HMM背景 0x1:概率模型 - 用概率分布的方式抽象事物的规律 机器学习最重要的任务,是根据一些已观察到的证据(例如训练样本)来对感兴趣的未知变量(例如类别标记)进行估计和推测. 概率模型(p ...

  8. 一致性哈希算法(Consistent Hashing Algorithm)

    一致性哈希算法(Consistent Hashing Algorithm) 浅谈一致性Hash原理及应用   在讲一致性Hash之前我们先来讨论一个问题. 问题:现在有亿级用户,每日产生千万级订单,如 ...

  9. EM算法(Expectation Maximization Algorithm)

    EM算法(Expectation Maximization Algorithm) 1. 前言   这是本人写的第一篇博客(2013年4月5日发在cnblogs上,现在迁移过来),是学习李航老师的< ...

随机推荐

  1. 手把手带你使用Paint in 3D和Photon撸一个在线涂鸦画板

    Paint in 3D Paint in 3D用于在游戏内和编辑器里绘制所有物体.所有功能已经过深度优化,在WebGL.移动端.VR 以及更多平台用起来都非常好用! 它支持标准管线,以及 LWRP.H ...

  2. 99%的人都搞错了的java方法区存储内容,通过可视化工具HSDB和代码示例一次就弄明白了

    https://zhuanlan.zhihu.com/p/269134063  番茄番茄我是西瓜 那是我日夜思念深深爱着的人啊~ 已关注   6 人赞同了该文章 前言 本篇是java内存区域管理系列教 ...

  3. win10 doskey宏命令定义,类似于Linux的alias别名命令

    doskey 命令别名=命令 例如:doskey echo2 = echo $1 这里的$1是占位符. 如果想删除,直接赋予空值即可:例如:doskey echo2= 总的来说把 https://do ...

  4. maven常用命令含义

    今天在开发过程中,对一个mapper.xml文件的sql进行了改动,重启tomcat后发现没有生效,首先考虑是不是远程服务开启着,导致代码没有走本地,确认远程服务是关闭的,的确是本地修改没有生效,于是 ...

  5. 为什么 Thread 类的 sleep()和 yield ()方法是静态的?

    Thread 类的 sleep()和 yield()方法将在当前正在执行的线程上运行.所以在其他处于等待状态的线程上调用这些方法是没有意义的.这就是为什么这些方法是静态的.它们可以在当前正在执行的线程 ...

  6. LIKE 声明中的%和_是什么意思?

    %对应于 0 个或更多字符,_只是 LIKE 语句中的一个字符. 如何在 Unix 和 MySQL 时间戳之间进行转换? UNIX_TIMESTAMP 是从 MySQL 时间戳转换为 Unix 时间戳 ...

  7. SVG是什么?

    SVG表示(scalable vector graphics)可缩放矢量图形.这是一个基于文本的图形语言,它可以绘制使用文本.线.点等的图形,因此可以轻巧又快速地渲染.

  8. ubuntu 安装 mysql mariadb

    本教程面向Ubuntu服务器,适用于Ubuntu的任何LTS版本,包括Ubuntu 14.04,Ubuntu 16.04,Ubuntu 18.04,甚至非LTS版本(如Ubuntu 17.10和其他基 ...

  9. Python的数据基础库Numpy怎样对数组进行排序

    Numpy怎样对数组排序 Numpy给数组排序的三个方法: numpy.sort:返回排序后数组的拷贝 array.sort:原地排序数组而不是返回拷贝 numpy.argsort:间接排序,返回的是 ...

  10. w3schools网站的HTML教程之HTML编辑器

    使用记事本或文本编辑器编写 HTML HTML 可以使用如下专业的 HTML 编辑器进行编辑: Microsoft WebMatrix Sublime Text 然而,我们推荐使用记事本(PC)或文本 ...