1804.03235-Large scale distributed neural network training through online distillation.md
现有分布式模型训练的模式
- 分布式SGD
- 并行SGD: 大规模训练中,一次的最长时间取决于最慢的机器
- 异步SGD: 不同步的数据,有可能导致权重更新向着未知方向
- 并行多模型 :多个集群训练不同的模型,再组合最终模型,但是会消耗inference运行时
- 蒸馏:流程复杂
- student训练数据集的选择
- unlabeled的数据
- 原始数据
- 留出来的数据
- student训练数据集的选择
协同蒸馏
- using the same architecture for all the models;
- using the same dataset to train all the models; and
- using the distillation loss during training before any model has fully converged.
特点
- 就算thacher和student是完全相同的模型设置,只要其内容足够不同,也是能够获得有效的提升的
- 即是模型未收敛,收益也是有的
- 丢掉teacher和student的区分,互相训练,也是有好处的
- 不是同步的模型也是可以的。
算法简单易懂,而且步骤看上去不是很复杂。
使用out of state模型权重的解释:
- every change in weights leads to a change in gradients, but as training progresses towards convergence, weight updates should substantially change only the predictions on a small subset of the training data;
- weights (and gradients) are not statistically identifiable as different copies of the weights might have arbitrary scaling differences, permuted hidden units, or otherwise rotated or transformed hidden layer feature space so that averaging gradients does not make sense unless models are extremely similar;
- sufficiently out-of-sync copies of the weights will have completely arbitrary differences that change the meaning of individual directions in feature space that are not distinguishable by measuring the loss on the training set;
- in contrast, output units have a clear and consistent meaning enforced by the loss function and the training data.
所以这里似乎是说,随机性的好处?
一种指导性的实用框架设计:
- Each worker trains an independent version of the model on a locally available subset of the training data.
- Occasionally, workers checkpoint their parameters.
- Once this happens, other workers can load the freshest available checkpoints into memory and perform codistillation.
再加上,可以在小一些的集群上使用分布式SGD。
另外论文中提到,这种方式,比起每次直接发送梯度和权重,只需要偶尔载入checkpoint,而且各个模型集群在运算上是完全相互独立的。这个倒是确实能减少一些问题。
但是,如果某个模型垮掉了,完全没收敛呢?
另外,没看出来这种框架哪里简单了,管理模型和checkpoint不是一个简单的事情。
实验结论
20TB的数据,有钱任性
论文中提到,并不是机器越多,最终模型效果越好,似乎32-128是比较合适的,更多了,模型收敛速度和性能不会更好,有时反而会有下降。
论文中的实验结果2a,最好的还是双模型并行,其次是协同蒸馏,最差的是unigram的smooth0.9,label smooth 0.99跟直接训练表现差不多,毕竟只是一个随机噪声。
另外,通过对比相同数据的协同蒸馏2b,和随机数据的协同整理,实验发现,随机数据实际上让模型有更好的表现
3在imagenet上的实验,出现了跟2a差不多的结果。
4中虽然不用非得用最新的模型,但是,协同蒸馏,使用太久远的checkpoint还是会显著降低训练效率的。
欠拟合的模型是有用的,但是过拟合的模型在蒸馏中可能不太有价值。
协同蒸馏比双步蒸馏能更快的收敛,而且更有效率。
3.5中介绍的,也是很多时候面临的问题,因为初始化,训练过程的参数不一样等问题,可能导致两次训练出来的模型的输出有很大区别。例如分类模型,可能上次训练的在某些分类上准确,而这次训练的,在这些分类上就不准确了。模型平均或者蒸馏法能有效避免这个问题。
总结
balabalabala
实验只尝试了两个模型,多个模型的多种拓扑结构也值得尝试。
很值得一读的一个论文。
1804.03235-Large scale distributed neural network training through online distillation.md的更多相关文章
- 论文笔记之:Large Scale Distributed Semi-Supervised Learning Using Streaming Approximation
Large Scale Distributed Semi-Supervised Learning Using Streaming Approximation Google 2016.10.06 官方 ...
- 1 - ImageNet Classification with Deep Convolutional Neural Network (阅读翻译)
ImageNet Classification with Deep Convolutional Neural Network 利用深度卷积神经网络进行ImageNet分类 Abstract We tr ...
- 【转】Principles of training multi-layer neural network using backpropagation
Principles of training multi-layer neural network using backpropagation http://galaxy.agh.edu.pl/~vl ...
- 用matlab训练数字分类的深度神经网络Training a Deep Neural Network for Digit Classification
This example shows how to use Neural Network Toolbox™ to train a deep neural network to classify ima ...
- CheeseZH: Stanford University: Machine Learning Ex4:Training Neural Network(Backpropagation Algorithm)
1. Feedforward and cost function; 2.Regularized cost function: 3.Sigmoid gradient The gradient for t ...
- 【论文考古】Training a 3-Node Neural Network is NP-Complete
今天看到一篇1988年的老文章谈到了训练一个简单网络是NPC问题[1].也就是下面的网络结构,在线性激活函数下,如果要找到参数使得输入数据的标签估计准确,这个问题是一个NPC问题.这个文章的意义在于宣 ...
- 大规模视觉识别挑战赛ILSVRC2015各团队结果和方法 Large Scale Visual Recognition Challenge 2015
Large Scale Visual Recognition Challenge 2015 (ILSVRC2015) Legend: Yellow background = winner in thi ...
- 通过Visualizing Representations来理解Deep Learning、Neural network、以及输入样本自身的高维空间结构
catalogue . 引言 . Neural Networks Transform Space - 神经网络内部的空间结构 . Understand the data itself by visua ...
- (zhuan) Recurrent Neural Network
Recurrent Neural Network 2016年07月01日 Deep learning Deep learning 字数:24235 this blog from: http:/ ...
随机推荐
- Java BitSet解决海量数据去重
先提一个问题,怎么在40亿个整数中找到那个唯一重复的数字? 第一想法就是Set的不可重复性,依次把每个数字放入HashSet中,当放不去进去的时候说明这就是重复的数字,输出这个数字. if(hs.co ...
- css 实现等分布局
目前移动版等分布局最常用的是 flex 等分,pc 端上用得更多则是 float. 假设父元素下有 3 个子元素,每个子元素相隔 24px,子元素等分父元素宽度 实现:float + margin ( ...
- pll时钟延迟为问题
pll时钟延迟为问题 这关系到pll的工作方式,如果pll内部使用的是鉴频器,则输入和输出将没有固定的相位差,就是每次锁定都锁定在某个相位,但每次都不一样.如果使用的是鉴相器,则输入和输出为0相位差. ...
- VS2017编译GDAL(64bit)+解决C#读取Shp数据中文路径的问题
编译GDAL过程比较繁琐,查阅了网上相关资料,同时通过实践,完成GDAL的编译,同时解决了SHP数据中文路径及中文字段乱码的问题,本文以“gdal-2.3.2”版本为例阐述整个编译过程. 一.编译准备 ...
- 海量数据中的TOPK问题小结
1.利用堆找出最大的K个数 首先,先理解下用堆找出最大的K个数的常用解法,例如问题是“从M(M <= 10000)个数中找出最大的K个数” (1)利用最大堆 建立一个N=M大小的大顶堆,然后输出 ...
- guava Lists.transform使用
作用:将一个List中的实体类转化为另一个List中的实体类. 稍微方便一点.例如:将List<Student>转化为List<StudentVo> Student: pack ...
- 廖雪峰Java8JUnit单元测试-2使用JUnit-4超时测试
1.超时测试 可以为JUnit的单个测试设置超时: 超时设置1秒:@Test(timeout=1000),单位为毫秒 2.示例 Leibniz定理:PI/4= 1 - 1/3 + 1/5 - 1/7 ...
- css定义好看的垂直滚动条
滚动条的css样式主要有三部分组成: 1.::-webkit-scrollbar 定义了滚动条整体的样式: 2.::-webkit-scrollbar-thumb 滑块部分: 3. ...
- docker 恶意镜像到容器逃逸影响本机
转载:http://521.li/post/122.html SUSE Linux GmbH高级软件工程师Aleksa Sarai公布了影响Docker, containerd, Podman, CR ...
- 填坑:Java 中的日期转换
我们之前讨论过时间,在Java 中有一些方法会出现横线?比如Date 过期方法. 参考文章:知识点:java一些方法会有横线?以Date 过期方法为例 Java中的日期和时间处理方法 Date类(官方 ...