1 介绍

1.1 背景

越来越多的手机和平板电脑成为许多人的主要计算设备。这些设备上强大的传感器(包括摄像头、麦克风和GPS),加上它们经常被携带的事实,意味着它们可以访问前所未有的大量数据,其中大部分本质上是私人的。根据这些数据学习的模型持有承诺通过支持更智能的应用程序来大大提高可用性,但数据的敏感性意味着将其存储在集中位置存在风险和责任。

1.2 本文贡献

本文的主要贡献是

  • 将来自移动设备的分散数据的训练问题(联邦学习)确定为一个重要的研究方向;

  • 选择可以应用于该设置的简单实用的算法FedAvg;

  • 对所提出的方法进行广泛的实证评估。

更具体地说,本文介绍了FedAvg算法,它将每个客户端上的局部随机梯度下降(SGD)与执行模型平均的服务器相结合。本文对该算法进行了广泛的实验,证明了它对不平衡非IID数据分布具有鲁棒性,并且可以将在分散数据上训练深度网络所需的通信轮次减少几个数量级。

1.3 联邦学习的理想问题

  • 对真实世界的移动设备上的数据进行训练 比 对数据中心可获得的代理数据进行训练,有明显的优势;
  • 数据是隐私的或数据量很大;
  • 对于监督任务,标签可以从用户交互中自然推断出来。

举例:

  • 图像分类任务。预测哪些照片最有可能在未来被多次查看或分享。用户拍摄的照片是隐私的,但对于本地,用户对照片的删除、共享等行为就是推断出来的标签。
  • 单词预测。用户在手机上输入时,输入法预测下一个单词。输入信息是隐私的,用户选择的下一个单词就是推断出来的标签。

1.4 联邦学习与分布式的对比

  • 非独立同分布:不同用户对移动设备的使用是不同的,因此数据非独立同分布。
  • 不平衡:一些用户会比其他人更频繁地使用服务或应用程序,从而导致本地训练数据的数量不同。
  • 大规模分布式:预计参与优化的客户端数量将远远大于每个客户端的平均实例数量。
  • 通信受限:移动设备有时候离线,或处于缓慢昂贵的连接中。

2 FedAvg

2.1 损失函数

对于机器学习问题,对于样本\((x_i,y_i)\)的损失为\(f_i(w)\),那么全局损失定义为:

\[f(w)\overset{\text{def}}{=}\frac{1}{n}\overset{n}{\sum}\limits_{i=1}f_i(w)
\]

在联邦学习问题中,假设有\(K\)个客户端,第\(k\)个客户端的数据集为\(P_k\),数据集大小\(n_k=|P_k|\)。那么对于客户端\(k\),该客户端数据的损失函数为:

\[F_k(w)=\frac{1}{n_k}\sum\limits_{i\in P_k}f_i(w)
\]

全局的损失函数定义为客户端损失的加权平均:

\[f(w)=\overset{K}{\sum}\limits_{k=1}\frac{n_k}{n}F_k(w)
\]

2.2 通信成本与计算成本

对于数据集中到中心的情况,由于数据量较大,通信成本相对较小,计算成本较大。

通信成本指客户端与中央服务器之间传输数据所需的成本。联邦学习中,会受到移动设备带宽限制,同时客户端通常仅在有电源和有WiFi等情况下愿意参与优化,因此通信成本较大。而设备数据量小、手机有GPU等特性使得计算成本较小。

为了减小通信成本,方法:

  • 增加并行,每轮使用更多客户端(对应“客户端通常仅在有电源和有WiFi等情况下愿意参与优化”限制)。
  • 每个客户端在每个通信轮之间执行更复杂的计算,而不是执行像梯度计算这样的简单计算。

2.3 相关工作

以往工作没有考虑不平衡和非独立同分布数据,以及客户端数量少。

2.4 FedSGD

根据当前的模型\(w_t\)计算梯度\(g_k=\nabla F_k(w_t)\)。由于:

\[\nabla f(w_t)=\nabla[\overset{K}{\sum}\limits_{k=1}\frac{n_k}{n}F_k(w_t)]=\overset{K}{\sum}\limits_{k=1}\frac{n_k}{n}g_k
\]

那么中心服务器聚合梯度并进行更新的结果为:

\[w_{t+1}\leftarrow w_t-\eta\nabla f(w_t)=w_t-\eta\overset{K}{\sum}\limits_{k=1}\frac{n_k}{n}g_k
\]

上式也等价于客户端先在本地做一次梯度更新,中心服务器再对模型进行加权平均:

\[w^k_{t+1}\leftarrow w_t-\eta g_k
\]
\[w_{t+1}\leftarrow \overset{K}{\sum}\limits_{k=1}\frac{n_k}{n}w^k_{t+1}
\]

2.5 FedAvg

写成上述第二种形式后,可以在做平均之前,多次迭代本地更新:

\[w^k\leftarrow w^k-\eta\nabla F_k(w^k)
\]

每个客户端可以多次计算上式得到本地在第\(t\)轮的最终模型,最后中心服务器将这些本地模型进行聚合得到\(w^{t+1}\)。

这就是FedAvg的思想,该算法主要有三个超参数:

  • \(C\):每次选择的客户端的比例
  • \(B\):本地训练时batchsize,当\(B=\infty\),即全批量
  • \(E\):本地训练轮数

当\(B=\infty,E=1\)时,FedAvg和FedSGD等价

这里还定义了每轮的本地更新次数:\(u_k=E\frac{n_k}{B}\),由该公式也可以算出,FedSGD每轮本地更新次数为1。

完整的伪代码:

至此我们可以简单比较FedSGD和FedAvg:

.center { width: auto; display: table; margin-left: auto; margin-right: auto }

算法 local server
FedSGD 计算本轮梯度 收集local的梯度,加权平均后作为server要下降的梯度
FedAvg 多次梯度下降,得到本轮的本地模型 收集local的模型,加权平均后作为本轮得到的模型

3 实验

3.1 模型初始化

聚合参数\(\theta\):以\(\theta w+(1-\theta)w^{'}\)对两个模型进行聚合,得到最终模型。

左图是使用两个初始模型\(w,w^{'}\)训练不同数据得到的损失,右图是两模型使用同一个\(w\)初始化训练不同数据,可以看出右边损失较小,且当\(\theta=0.5\)效果最好。因此在联邦学习实验中,每个客户端需要共享相同的初始化模型。

3.2 数据集和训练任务

选取大小适中的数据集,以便研究超参数

第一个任务是MNIST数字识别,使用两个模型:

  • 多层感知机。2个隐藏层,每个隐藏层有200个单元,使用ReLU激活。

    199210个参数:图像为\(28\times 28\),转为一维后是784。第一层\(784*200+偏置200\),第二层\(200*200+偏置200\),第三层\(200*10+偏置10\)

  • \(32*5*5\)卷积+\(2*2\)最大池化+\(64*5*5\)卷积+\(2*2\)最大池化+512单元全连接+ReLU+Softmax

数据集划分:

  • iid:划分100个客户端,每个客户端接收600张图。
  • 非iid:先按数字对图片进行排序,并划分成200个大小为300的碎片,给100个客户端每个分2个碎片,即每个客户端分到的数据只包含两个数字。

分出来的数据集有iid和非iid,但都是平衡的。

第二个任务是字符预测,使用LSTM,读取一行字符预测下一个字符。

数据集是莎士比亚全集,每个说话角色为一个客户端,共1146个。每个客户端,前80%的行是训练集,后20%行是测试集。

数据集划分:

  • iid:将所有文字平均划分给每个客户端。
  • 非iid:每个客户端仅有该角色说的话。

学习率设置在\(10^{\frac{1}{3}}\)到\(10^{\frac{1}{6}}\)区间。

3.3 增加并行性

\(C\)控制并行量,因此先改变\(C\)。

实验记录了MLP达到97测试集准确率和CNN达到99测试准确率所需要的通信轮数。

使用小批量,当\(C=0.1\)时效果就已经较好。为平衡计算效率和收敛速度,之后实验固定\(C=0.1\)。

3.4 增加每个客户端的计算量

在FedAvg算法部分,我们已经指出,每轮本地更新次数为\(u_k=E\frac{n_k}{B}\)。在实验中设置独立同分布的更新次数为期望更新次数,即\(u=E\frac{n}{kB}\)。

首先,对于两种任务,增加\(B\)都减少了通信轮数。

对于MNIST任务,iid效果比非iid更显著。实际生活我们设备上的数字也不会是规律性的,因此这种情况是该方法鲁棒的论证。

对于莎士比亚数据集,非iid效果很好,而这代表了我们在现实生活的数据分布(不同的人说话数量相差很大)。推测是某些客户端有较大的数据集,使本地训练更具有价值。

3.5 FedSGD vs FedAvg

可以看出,FedAvg不仅减少通信轮数,还提高了测试精度(蓝色实线是FedSGD)。推测是模型平均会产生类似dropout正则化的收益。

3.6 是否能过度优化客户端

对于非常大的本地迭代次数,FedAvg可能会停滞或发散。这一结果表明,对于某些模型,尤其是在收敛的后期阶段,减少每轮的本地计算量(即减小E或增大B)可能是有益的,就像衰减学习率一样。

3.7 CIFAR实验

数据集包含50000个训练数据和10000个测试数据,将其平均划分给100个客户端,每个客户端包含500个训练数据和100个测试数据。

使用的模型为两个卷积层+两个全连接层+一个线性变换层。

图像会经过裁剪为\(24*24\)、左右反转、调整对比度、亮度等预处理。

单机的SGD对比10个客户端的FedSGD和FedAvg:

现有的模型,对CIFAR分类任务的测试精度已经很高,但这里只要达到80%左右即可,原因是本文的目标是评估FedAvg方法,而非提高CIFAR测试精度。

不同学习率的影响:

3.8 大规模LSTM实验

为了证明方法在现实世界问题上有效,还在大规模的预测下一个单词任务上进行了实验。

训练数据集由来自大型社交网络的1000万个公开帖子组成。按作者对帖子进行了分组,总共有超过500,000名客户。文中将每个客户端的数据集限制为最多5000个单词,并对10000个作者的数据进行了测试。

原文链接:https://arxiv.org/abs/1602.05629

联邦学习开山之作Communication-Efficient Learning of Deep Networks from Decentralized Data的更多相关文章

  1. 【论文考古】联邦学习开山之作 Communication-Efficient Learning of Deep Networks from Decentralized Data

    B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas, "Communication-Efficient Learni ...

  2. Communication-Efficient Learning of Deep Networks from Decentralized Data

    郑重声明:原文参见标题,如有侵权,请联系作者,将会撤销发布! Proceedings of the 20th International Conference on Artificial Intell ...

  3. Deep Learning 8_深度学习UFLDL教程:Stacked Autocoders and Implement deep networks for digit classification_Exercise(斯坦福大学深度学习教程)

    前言 1.理论知识:UFLDL教程.Deep learning:十六(deep networks) 2.实验环境:win7, matlab2015b,16G内存,2T硬盘 3.实验内容:Exercis ...

  4. [译]深度神经网络的多任务学习概览(An Overview of Multi-task Learning in Deep Neural Networks)

    译自:http://sebastianruder.com/multi-task/ 1. 前言 在机器学习中,我们通常关心优化某一特定指标,不管这个指标是一个标准值,还是企业KPI.为了达到这个目标,我 ...

  5. 【流行前沿】联邦学习 Partial Model Averaging in Federated Learning: Performance Guarantees and Benefits

    Sunwoo Lee, , Anit Kumar Sahu, Chaoyang He, and Salman Avestimehr. "Partial Model Averaging in ...

  6. Self-Taught Learning to Deep Networks

    In this section, we describe how you can fine-tune and further improve the learned features using la ...

  7. 联邦学习:按混合分布划分Non-IID样本

    我们在博文<联邦学习:按病态独立同分布划分Non-IID样本>中学习了联邦学习开山论文[1]中按照病态独立同分布(Pathological Non-IID)划分样本. 在上一篇博文< ...

  8. 联邦学习 Federated Learning 相关资料整理

    本文链接:https://blog.csdn.net/Sinsa110/article/details/90697728代码微众银行+杨强教授团队的联邦学习FATE框架代码:https://githu ...

  9. 联邦学习(Federated Learning)

    联邦学习简介        联邦学习(Federated Learning)是一种新兴的人工智能基础技术,在 2016 年由谷歌最先提出,原本用于解决安卓手机终端用户在本地更新模型的问题,其设计目标是 ...

  10. Federal Learning(联邦学习)认知

    本人是学生党,同时也是小菜鸡一枚,撞运气有机会能够给老师当项目助理,在这个过程中肯定会学到一些有趣的知识,就在此平台上记录一下,在知识点方面有不对的还请各位指正. What(什么是联邦学习?) 联邦学 ...

随机推荐

  1. 在 Mac 上使用 X11

    有时我们需要在服务器上运行一个 GUI 程序,然而我们是通过 SSH 连接到服务器的,看不到图形界面,怎么办呢?我们可以通过 X11 将 GUI 程序的界面转发到本地. 在 Mac 上使用 X11 需 ...

  2. LaTeX 书写函数

    \[\text{text 模式} \] \[\mathrm{mathrm 模式} \] \[\textit{textit 模式} \] \[\operatorname{operatorname 模式} ...

  3. 【Python自动化】之特殊的自动化定位操作

    今天有时间了,想好好的把之前遇到过的自动化问题总结一下,以后有新的总结再更新 目录: 一.上传文件(4.11) 二.下拉框选择(4.11) 1.Select下拉框 2.非Select下拉框 三.下拉框 ...

  4. Allen基因图谱:python Aabgen的安装

    1. abagen 使用教程的官方链接:abagen: A toolbox for the Allen Brain Atlas genetics data - abagen 0.1.3-doc+0.g ...

  5. WebShell流量特征检测_中国菜刀篇

    80后用菜刀,90后用蚁剑,95后用冰蝎和哥斯拉,以phpshell连接为例,本文主要是对这四款经典的webshell管理工具进行流量分析和检测. 什么是一句话木马? 1.定义 顾名思义就是执行恶意指 ...

  6. 【YashanDB数据库】yasql登录有特殊字符@导致无法登录

    问题备机 Linux bash shell环境下,使用yasql登录数据库没有使用转义导致登录失败.报错信息如下 问题分析 linux特殊字符转义问题,多加几层转义可以解决问题. 解决办法 su - ...

  7. Maven高级——分模块开发与设计

    分模块开发的意义 将原始模块按照功能拆分成若干个子模块,方便模块间的相互调用,接口共享 分模块开发 创建Maven工程 书写模块代码 注意:分模块开发需要先针对模块功能进行设计,再进行编码.不会先将工 ...

  8. MyBatis——案例——查询-查询详情

      查询-查询详情 (根据id获取商品全部信息(即商品对象))          1.编写Mapper接口方法:Brand selectById(int id);            2.编写SQL ...

  9. Passwords

    详见 此处 Header File 0*)1190*+0**0).0970)/0)/111105000

  10. Flutter TextField 的高度问题

    示例 先来看一个例子:假设我们要做一个表单,左边是提示文字,右边是输入框 给出代码: Row( crossAxisAlignment: CrossAxisAlignment.center, child ...