Federated Learning with PySyft
Steps involved in the Federated Learning Approach
The mobile devices download the current global ML model.
Data is generated on each device as the user interacts with the application that uses the model.
The model is trained on the device with this local data, so the more the user interacts with the application, the better its predictions fit that user's usage.
When the model is due for its scheduled sync with the server, the personalised model trained on the device (its updated weights, not the raw data) is sent to the server.
The server collects the models from all devices and combines them with federated averaging, producing a global model that improves on the previous version.
The improved model is then sent back to all devices, so every user benefits from what was learned from usage on devices around the globe.
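The server-side aggregation described above can be sketched as federated averaging (FedAvg): the server takes a weighted mean of the clients' parameters, weighting each client by how many training samples it used. This is a minimal plain-Python illustration under stated assumptions: each model is represented as a dict mapping parameter names to flat lists of floats, and the function name fed_avg and the toy values are hypothetical, not part of PySyft.

```python
from typing import Dict, List

# Hypothetical representation: parameter name -> flat list of weights
Params = Dict[str, List[float]]


def fed_avg(client_params: List[Params], num_samples: List[int]) -> Params:
    """Weighted average of client parameters (FedAvg aggregation step).

    Each client's weights count in proportion to its number of samples,
    so with equal sample counts this reduces to a plain mean.
    """
    total = sum(num_samples)
    averaged: Params = {}
    for name in client_params[0]:
        length = len(client_params[0][name])
        averaged[name] = [
            sum(params[name][i] * n / total
                for params, n in zip(client_params, num_samples))
            for i in range(length)
        ]
    return averaged


# Two toy clients: "bob" trained on 30 samples, "alice" on 10
bob = {"w": [1.0, 2.0], "b": [0.0]}
alice = {"w": [3.0, 6.0], "b": [4.0]}
global_params = fed_avg([bob, alice], [30, 10])
print(global_params)  # {'w': [1.5, 3.0], 'b': [1.0]}
```

Bob's weights dominate because he contributed three times as much data; when every device holds the same amount of data the result is simply the element-wise mean.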
Installing PySyft
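To install PySyft, it is recommended that you first set up a conda environment. The code in this guide uses the PySyft 0.2.x API (sy.TorchHook, sy.VirtualWorker), which was removed in PySyft 0.3, so create the environment with a Python version that 0.2.x supports, such as 3.7:
conda create -n pysyft python=3.7
conda activate pysyft
conda install jupyter notebook
Then install a 0.2.x release of the package:
pip install "syft<0.3"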
Step-by-step guide to developing a neural network with the federated learning approach
Importing the libraries:
NumPy
PyTorch
PySyft
Pickle
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import time
import copy
import numpy as np
import syft as sy
from syft.frameworks.torch.federated import utils
from syft.workers.websocket_client import WebsocketClientWorker
Initializing the training parameters
Learning rate: 0.001
Epochs: 100
Batch size: 8 (for both training and test)
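class Parser:
    """Stand-in for command-line arguments, handy inside a notebook."""
    def __init__(self):
        self.epochs = 100        # number of federated training rounds
        self.lr = 0.001          # SGD learning rate
        self.test_batch_size = 8
        self.batch_size = 8
        self.log_interval = 10
        self.seed = 1

args = Parser()
torch.manual_seed(args.seed)  # make weight initialisation and shuffling reproducible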
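The Parser class stands in for command-line arguments so the code can run inside a notebook. If you run it as a script instead, the same hyperparameters could come from Python's standard argparse module. This is a hypothetical alternative: the flag names below are assumptions that simply mirror the attributes above.

```python
import argparse


def parse_args(argv=None):
    """Build the training hyperparameters; defaults mirror the Parser class."""
    parser = argparse.ArgumentParser(description="Federated training (sketch)")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--test-batch-size", type=int, default=8)
    parser.add_argument("--log-interval", type=int, default=10)
    parser.add_argument("--seed", type=int, default=1)
    return parser.parse_args(argv)


# Passing an empty list ignores sys.argv, which also works inside Jupyter
args = parse_args([])
print(args.epochs, args.lr, args.batch_size)  # 100 0.001 8
```

argparse turns --batch-size into the attribute batch_size, so the rest of the code can keep using args.batch_size unchanged.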
Dataset Preprocessing
# boston_housing.pickle is assumed to hold ((x_train, y_train), (x_test, y_test)) as NumPy arrays
with open('boston_housing.pickle', 'rb') as f:
    ((x, y), (x_test, y_test)) = pickle.load(f)

x = torch.from_numpy(x).float()
y = torch.from_numpy(y).float()
x_test = torch.from_numpy(x_test).float()
y_test = torch.from_numpy(y_test).float()

# Standardise every feature with statistics from the training set only
mean = x.mean(0, keepdim=True)
dev = x.std(0, keepdim=True)
# Column 3 (CHAS) is a 0/1 indicator, so leave it unscaled
mean[:, 3] = 0.
dev[:, 3] = 1.
x = (x - mean) / dev
x_test = (x_test - mean) / dev

train_set = TensorDataset(x, y)
test_set = TensorDataset(x_test, y_test)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)
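The preprocessing computes the mean and standard deviation on the training set only and reuses them for the test set, so no test statistics leak into training. Column 3 is left unchanged by forcing its mean to 0 and its standard deviation to 1; in the Boston Housing data this column is the binary CHAS indicator. A NumPy sketch of the same idea, on a small made-up array with a hypothetical helper named standardize:

```python
import numpy as np


def standardize(train, test, skip_col=None):
    """Z-score both arrays using statistics from the training set only.

    ddof=1 matches torch.std's default (unbiased) estimator.
    """
    mean = train.mean(axis=0, keepdims=True)
    std = train.std(axis=0, ddof=1, keepdims=True)
    if skip_col is not None:
        # Leave this column untouched (e.g. a 0/1 indicator feature)
        mean[:, skip_col] = 0.0
        std[:, skip_col] = 1.0
    return (train - mean) / std, (test - mean) / std


x_train = np.array([[1.0, 0.0], [3.0, 1.0], [5.0, 0.0]])
x_test = np.array([[3.0, 1.0]])
train_z, test_z = standardize(x_train, x_test, skip_col=1)
print(train_z[:, 0])  # [-1.  0.  1.]
print(test_z)         # [[0. 1.]]
```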
Creating Neural Network with PyTorch
Creating the architecture of the neural network model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Fully connected layers: 13 input features -> 32 -> 24 -> 16 -> 1 price
        self.fc1 = nn.Linear(13, 32)
        self.fc2 = nn.Linear(32, 24)
        self.fc3 = nn.Linear(24, 16)
        self.fc4 = nn.Linear(16, 1)

    def forward(self, x):
        x = x.view(-1, 13)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)  # no activation: regression output
        return x
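Each nn.Linear(in, out) layer holds in × out weights plus out biases, so the size of this network, and therefore of each model update a device sends to the server, can be checked by hand. The short sketch below does the arithmetic for the 13 → 32 → 24 → 16 → 1 architecture; the helper name linear_params is hypothetical.

```python
# Layer widths of Net: 13 inputs -> 32 -> 24 -> 16 -> 1 output
widths = [13, 32, 24, 16, 1]


def linear_params(n_in: int, n_out: int) -> int:
    """Parameters of one fully connected layer: weight matrix plus bias."""
    return n_in * n_out + n_out


per_layer = [linear_params(a, b) for a, b in zip(widths, widths[1:])]
print(per_layer)       # [448, 792, 400, 17]
print(sum(per_layer))  # 1657
```

With PyTorch installed, sum(p.numel() for p in Net().parameters()) should return the same total. At float32 that is 1657 × 4 = 6628 bytes of weights per device per round, ignoring serialisation overhead.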
Connect to the workers or the devices for training
The hook extends PyTorch with PySyft's remote-execution features, so it has to be created before any tensor is sent. Here two virtual workers, bob and alice, simulate the devices inside one Python process.
hook = sy.TorchHook(torch)
bob_worker = sy.VirtualWorker(hook, id="bob")
alice_worker = sy.VirtualWorker(hook, id="alice")
compute_nodes = [bob_worker, alice_worker]
Connecting the data with the remote mobile devices
In a real federated setting the data is already on the devices and never leaves them. For this demonstration we send the training batches to the workers ourselves, so that training happens where the data lives.
remote_dataset = (list(), list())  # one list of (data, target) pointers per worker
for batch_idx, (data, target) in enumerate(train_loader):
    worker = compute_nodes[batch_idx % len(compute_nodes)]  # deal batches out round-robin
    data = data.send(worker)
    target = target.send(worker)
    remote_dataset[batch_idx % len(compute_nodes)].append((data, target))

# One local model and optimizer per worker, both starting from the same weights
bobs_model = Net()
alices_model = Net()
alices_model.load_state_dict(bobs_model.state_dict())
bobs_optimizer = optim.SGD(bobs_model.parameters(), lr=args.lr)
alices_optimizer = optim.SGD(alices_model.parameters(), lr=args.lr)
models = [bobs_model, alices_model]
optimizers = [bobs_optimizer, alices_optimizer]
model = Net()
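The data loop deals batches out round-robin: batch i goes to compute_nodes[i % len(compute_nodes)]. The plain-Python sketch below reproduces that bookkeeping, assuming 404 training samples (the size of the Keras Boston Housing training split; your pickle may differ) and a batch size of 8. It shows that the workers can end up with different numbers of batches, so a training loop must not index past the shortest worker's list.

```python
import math

num_samples = 404  # assumption: size of the training split
batch_size = 8
workers = ["bob", "alice"]

num_batches = math.ceil(num_samples / batch_size)  # last batch may be partial
assignment = {w: [] for w in workers}
for batch_idx in range(num_batches):
    # Same rule as remote_dataset[batch_idx % len(compute_nodes)]
    assignment[workers[batch_idx % len(workers)]].append(batch_idx)

print(num_batches)                                 # 51
print({w: len(b) for w, b in assignment.items()})  # {'bob': 26, 'alice': 25}
usable = min(len(b) for b in assignment.values())
print(usable)                                      # 25
```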
Training the Neural Network
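def update(data, target, model, optimizer):
    # One local SGD step on the worker that holds this batch
    optimizer.zero_grad()
    prediction = model(data)
    loss = F.mse_loss(prediction.view(-1), target)
    loss.backward()
    optimizer.step()
    return model

def train():
    # Send each model to its own worker once per round
    for model, worker in zip(models, compute_nodes):
        model.send(worker)
    # Workers may hold different numbers of batches; stop at the shortest list
    num_batches = min(len(batches) for batches in remote_dataset)
    for data_index in range(num_batches):
        for remote_index in range(len(compute_nodes)):
            data, target = remote_dataset[remote_index][data_index]
            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])
    # Bring the trained models back and average them on the server
    for model in models:
        model.get()
    return utils.federated_avg({
        "bob": models[0],
        "alice": models[1]
    })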
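Stripped of PySyft, one call to train() is a single federated round: each worker runs SGD on its own batches, the server collects the resulting models and averages them, and in a complete loop the averaged model is sent back for the next round. The toy simulation below is only meant to illustrate that flow, not how PySyft executes it internally; the one-parameter model y = w·x, the ground-truth slope TRUE_W, the data and the learning rate are all hypothetical.

```python
import random

random.seed(0)
TRUE_W = 3.0  # hypothetical ground-truth slope the workers try to learn

# Each worker holds its own private data; raw data never leaves the worker
worker_data = {
    name: [(x, TRUE_W * x) for x in [random.uniform(-1, 1) for _ in range(20)]]
    for name in ("bob", "alice")
}


def local_sgd(w: float, data, lr: float, steps: int) -> float:
    """Run plain SGD on the squared error (w*x - y)^2 using only local data."""
    for _ in range(steps):
        for x, y in data:
            grad = 2 * (w * x - y) * x  # derivative of (w*x - y)^2 w.r.t. w
            w -= lr * grad
    return w


global_w = 0.0
for round_idx in range(5):
    # 1. Every worker starts from the current global model and trains locally
    local_ws = [local_sgd(global_w, d, lr=0.1, steps=1) for d in worker_data.values()]
    # 2. The server averages the returned weights (equal data sizes -> plain mean)
    global_w = sum(local_ws) / len(local_ws)
    print(f"round {round_idx + 1}: w = {global_w:.4f}")
```

Only the single weight w travels between workers and server; the (x, y) pairs stay inside worker_data, which is the property federated learning is after.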
def test(federated_model):
    federated_model.eval()
    test_loss = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for data, target in test_loader:
            output = federated_model(data)
            # Sum the squared errors so the average below is per sample
            test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item()
    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}'.format(test_loss))
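test() accumulates the summed squared error over all batches (reduction='sum') and divides by the number of test samples, so the printed value is the per-sample mean squared error no matter how the data is batched. Because the Boston Housing targets are median house prices in thousands of dollars, its square root is an RMSE in that same unit. A plain-Python check of the bookkeeping with made-up predictions:

```python
import math

# Hypothetical predictions and targets, split into batches of size 2 and 1
batches = [([22.0, 30.0], [24.0, 29.0]), ([15.0], [18.0])]

total_sq_error = 0.0
count = 0
for preds, targets in batches:
    # Equivalent of F.mse_loss(..., reduction='sum') for one batch
    total_sq_error += sum((p - t) ** 2 for p, t in zip(preds, targets))
    count += len(targets)

mse = total_sq_error / count   # same as test_loss / len(test_loader.dataset)
rmse = math.sqrt(mse)
print(mse)             # 4.666666666666667
print(round(rmse, 3))  # 2.16
```

Averaging the per-batch means instead would give (2.5 + 9) / 2 = 5.75 here, over-weighting the smaller final batch.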
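for epoch in range(args.epochs):
    start_time = time.time()
    print(f"Epoch Number {epoch + 1}")
    federated_model = train()
    model = federated_model
    test(federated_model)
    # Send the averaged weights back so every worker starts the next round from the global model
    for worker_model in models:
        worker_model.load_state_dict(federated_model.state_dict())
    total_time = time.time() - start_time
    print('Round time (training + communication):', round(total_time, 2), 's\n')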
References:
Federated Learning with PySyft