联邦平均算法(Federated Averaging Algorithm,FedAvg)
设一共有\(K\)个客户机,
中心服务器初始化模型参数,执行若干轮(round),每轮选取至少1个至多\(K\)个客户机参与训练,接下来每个被选中的客户机同时在自己的本地根据服务器下发的本轮(\(t\)轮)模型\(w_t\)用自己的数据训练自己的模型\(w^k_{t+1}\),上传回服务器。服务器将收集来的各客户机的模型根据各方样本数量用加权平均的方式进行聚合,得到下一轮的模型\(w_{t+1}\):
& \qquad w_{t+1} \leftarrow \sum^K_{k=1} \frac{n_k}{n} w^k_{t+1} \qquad\qquad //n_k为客户机k上的样本数量,n为所有被选中客户机的总样本数量\\
\end{aligned}
\]
【伪代码】
& 算法1:Federated\ Averaging算法(FedAvg)。 \\
& K个客户端编号为k;B,E,\eta分别代表本地的minibatch\ size,epochs,学习率learning\ rate \\
& \\
& 服务器执行:\\
& \quad 初始化w_0 \\
& \quad for \ 每轮t=1,2,...,do \\
& \qquad m \leftarrow max(C \cdot K,1) \qquad\qquad //C为比例系数 \\
& \qquad S_t \leftarrow (随机选取m个客户端) \\
& \qquad for \ 每个客户端k \in S_t 同时\ do \\
& \qquad \qquad w^k_{t+1} \leftarrow 客户端更新(k,w_t) \\
& \qquad w_{t+1} \leftarrow \sum^K_{k=1} \frac{n_k}{n} w^k_{t+1} \qquad\qquad //n_k为客户机k上的样本数量,n为所有被选中客户机的总样本数量\\
& \\
& 客户端更新(k,w): \qquad \triangleright 在客户端k上运行 \\
& \quad \beta \leftarrow (将P_k分成若干大小为B的batch) \qquad\qquad //P_k为客户机k上数据点的索引集,P_k大小为n_k \\
& \quad for\ 每个本地的epoch\ i(1\sim E) \ do \\
& \qquad for\ batch\ b \in \beta \ do \\
& \qquad \qquad w \leftarrow w-\eta \triangledown l(w;b) \qquad\qquad //\triangledown 为计算梯度,l(w;b)为损失函数\\
& \quad 返回w给服务器
\end{aligned}
\]
为了增加客户机计算量,可以在中心服务器做聚合(加权平均)操作前在每个客户机上多迭代更新几次。计算量由三个参数决定:
- \(C\),每一轮(round)参与计算的客户机比例。
- \(E(epochs)\),每一轮每个客户机投入其全部本地数据训练一遍的次数。
- \(B(batch size)\),用于客户机更新的batch大小。\(B=\infty\)表示batch为全部样本,此时就是full-batch梯度下降了。
当\(E=1\ B=\infty\)时,对应的就是FedSGD,即每一轮客户机一次性将所有本地数据投入训练,更新模型参数。
对于一个有着\(n_k\)个本地样本的客户机\(k\)来说,每轮的本地更新次数为\(u_k=E\cdot \frac{n_k}{B}\)。
参考文献:
- H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. Y. Arcas, “Communication-efficient learning of deep networks from decentralized data,” in Proc. AISTATS, 2016, pp. 1273–1282.
联邦平均算法(Federated Averaging Algorithm,FedAvg)的更多相关文章
- 谷歌的网页排序算法(PageRank Algorithm)
本文将介绍谷歌的网页排序算法(PageRank Algorithm),以及它如何从250亿份网页中捞到与你的搜索条件匹配的结果.它的匹配效果如此之好,以至于“谷歌”(google)今天已经成为一个被广 ...
- 联邦学习(Federated Learning)
联邦学习简介 联邦学习(Federated Learning)是一种新兴的人工智能基础技术,在 2016 年由谷歌最先提出,原本用于解决安卓手机终端用户在本地更新模型的问题,其设计目标是 ...
- 维特比算法(Viterbi Algorithm)
寻找最可能的隐藏状态序列(Finding most probable sequence of hidden states) 对于一个特殊的隐马尔科夫模型(HMM)及一个相应的观察序列,我们常常希望 ...
- 图像处理之泛洪填充算法(Flood Fill Algorithm)
泛洪填充算法(Flood Fill Algorithm) 泛洪填充算法又称洪水填充算法是在很多图形绘制软件中常用的填充算法,最熟悉不过就是 windows paint的油漆桶功能.算法的原理很简单,就 ...
- 隐马尔科夫模型,第三种问题解法,维比特算法(biterbi) algorithm python代码
上篇介绍了隐马尔科夫模型 本文给出关于问题3解决方法,并给出一个例子的python代码 回顾上文,问题3是什么, 下面给出,维比特算法(biterbi) algorithm 下面通过一个具体例子,来说 ...
- 图像处理------泛洪填充算法(Flood Fill Algorithm) 油漆桶功能
泛洪填充算法(Flood Fill Algorithm) 泛洪填充算法又称洪水填充算法是在很多图形绘制软件中常用的填充算法,最熟悉不过就是 windows paint的油漆桶功能.算法的原理很简单,就 ...
- HMM隐马尔科夫算法(Hidden Markov Algorithm)初探
1. HMM背景 0x1:概率模型 - 用概率分布的方式抽象事物的规律 机器学习最重要的任务,是根据一些已观察到的证据(例如训练样本)来对感兴趣的未知变量(例如类别标记)进行估计和推测. 概率模型(p ...
- 一致性哈希算法(Consistent Hashing Algorithm)
一致性哈希算法(Consistent Hashing Algorithm) 浅谈一致性Hash原理及应用 在讲一致性Hash之前我们先来讨论一个问题. 问题:现在有亿级用户,每日产生千万级订单,如 ...
- EM算法(Expectation Maximization Algorithm)
EM算法(Expectation Maximization Algorithm) 1. 前言 这是本人写的第一篇博客(2013年4月5日发在cnblogs上,现在迁移过来),是学习李航老师的< ...
随机推荐
- Cobalt Strike之LINK木马
在同一目录下 新建一个exp.ps1 一个test.txt exp.ps1代码 $file = Get-Content "test.txt" $WshShell = New-Obj ...
- loj536「LibreOJ Round #6」花札(二分图博弈)
loj536「LibreOJ Round #6」花札(二分图博弈) loj 题解时间 很明显是二分图博弈. 以某个点为起点,先手必胜的充要条件是起点一定在最大匹配中. 判断方法是看起点到该点的边有流量 ...
- DLink 815路由器栈溢出漏洞分析与复现
DLink 815路由器栈溢出漏洞分析与复现 qemu模拟环境搭建 固件下载地址 File DIR-815_FIRMWARE_1.01.ZIP - Firmware for D-link DIR-81 ...
- 手撕代码:leetcode 309最佳买卖股票时机含冷冻期
转载于:https://segmentfault.com/a/1190000014746613 给定一个整数数组,其中第i个元素代表了第i天的股票价格. 设计一个算法计算出最大利润.在满足以下约束条件 ...
- Java并发机制(7)--线程池ThreadPoolExecutor的使用
Java并发编程:线程池的使用整理自:博客园-海子-http://www.cnblogs.com/dolphin0520/p/3932921.html 1.什么是线程池,为什么要使用线程池: 1.1. ...
- Oracle入门基础(五)一一多表查询
SQL> --等值连接 SQL> --查询员工信息:员工号 姓名 月薪 部门名称 SQL> set linesize 80 SQL> desc dept 名称 是否为空? 类型 ...
- 两个链表有一个交点,如何在时间复杂度 O(n) 和 空间复杂度 O(1) 的条件下实现?_字节跳动面试题
输入两个链表,找出它们的第一个公共结点 我们可以首先遍历两个链表得到它们的长度,就能知道哪个链表比较长, 我们可以首先遍历两个链表得到它们的长度,就能知道哪个链表比较长,以及长的链表比短的链表多几个结 ...
- java-网络通信--socket实现多人聊天(基于命令行)
先编写最简答的服务器 思路 1编写一个实现Runnable接口的静态内部类 ServerC,便于区分每个客户端 1.1 获取客户端数据函数 public String remsg() 1.2 转发消息 ...
- elasticsearch 5.6.7在线安装ik分词,亲测有效
官网的在线安装命令 ./bin/elasticsearch-plugin install https://github.com/medcl/elasticsearch-analysis-ik/rele ...
- 《每周一点canvas动画》——3D点线与水波动画
<每周一点canvas动画>--差分函数的妙用 每周一点canvas动画代码文件 好像上次更新还是十一前,这唰唰唰的就过去大半个月了,现在才更新实在不好意思.这次我们不涉及canvas 3 ...