【机器学习的Tricks】随机权值平均优化器swa与pseudo-label伪标签
文章来自公众号【机器学习炼丹术】
1 stochastic weight averaging(swa)
- 随机权值平均
- 这是一种全新的优化器,目前常见的有SGB,ADAM,
【概述】:这是一种通过梯度下降改善深度学习泛化能力的方法,而且不会要求额外的计算量,可以用到Pytorch的优化器中。
随机权重平均和随机梯度下降SGD相似,所以我一般吧SWa看成SGD的进阶版本。
1.1 原理与算法
swa算法流程:

【怎么理解】:
- 对\(w_{swa}\)做了一个周期为c的滑动平均。每迭代c次,就会对这个\(w_{swa}\)做一次滑动平均。其他的时间使用SGD进行更新。
- 简单的说,整个流程是模型初始化参数之后,使用SGD进行梯度下降,迭代了c个epoch之后,将模型的参数用加权平均,得到\(w_{SWA}\),然后现在模型的参数就是\(w_{SWA}\),然后再用SGD去梯度下降c个epoch,然后再加权平均出来一个新的\(w_{SWA}\).
SWA加入了周期性滑动平均来限制权重的变化,解决了传统SGD在反向过程中的权重震荡问题。SGD是依靠当前的batch数据进行更新,寻找随机梯度下降随机寻找的样本的梯度下降方向很可能并不是我们想要的方向。
论文中给出了一个图片:

- 绿线是恒定学习率的SGD,效果并不好,直到SGD在训练的过程中所见了学习率,才可以得到一个收敛的结果;
- 而使用Stochastic weight averaging可以在学习率恒定的情况下,快速收敛,而且过程平稳。
1.2 python与实现
这里讲如何在pytorch深度学习框架中加入swa作为优化器:
from torchcontrib.optim import SWA
# training loop
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt = torchcontrib.optim.SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
for _ in range(100):
opt.zero_grad()
loss_fn(model(input), target).backward()
opt.step()
opt.swap_swa_sgd()
如果使用了swa的话,那么lr_schedule这个方法就不需要在使用了,非常的方便。
【关于参数】:
使用swa的时候,就直接通过
torchcontrib.optim.SWA(base_opt,swa_start,swa_greq,swa_lr)
来封装原来的优化器。
- swa_start:是一个整数,表示经过swa_start个steps后,将学习率切换为固定值swa_lr。(在swa_start之前的step中,lr是0.1,在10个steps之后,lr变成0.05)
- swa_freq:在swa_freq个step优化之后,会将对应的权重加到swa滑动平均的结果上,相当于算法中的c;
- 使用opt.swap_swa_sgd()之后,可以将模型的权重替换为swa的滑动平均的权重。
1.3 关于BN
这里有一个问题就是在BatchNorm层训练的时候,BN层中也是有两个训练参数的,使用\(w_{swa}\)重置了模型参数,但是并没有更新BN层的参数,所以如果有bn层的话,还需要加上:
opt.bn_update(train_loader,model)
2 Pseudo-Label
- 伪标签
- 这是一种半监督的方法。其实非常简单,就是对于未标记的数据,许纳泽预测概率最大的标记作为该样本的pseudo-label,然后给未标记数据设置一个权重,在训练过程中慢慢增加未标记数据的权重。
这个方法的loss如下:

非常好理解了,前面一项就是训练集的loss,后面是测试集的loss,然后用一个\(\alpha(t)\)来做权重。
然后这个\(\alpha(t)\)就是随着训练的迭代次数增加而慢慢的线性增加(如果按照原来的论文中的描述):

【一些关于pseudo-label的杂谈】
这个方法提出在2013年,然后再2015年作者用entropy信息熵来证明这个方法的有效性。但是证明过程较为牵强。这个伪标签我在2017年的一个项目中想到了,但是不知道可行不可行自己当时也无法进行证明,就作罢了,没想到现在看到同样的方法在2013年就提出来了。有点五味杂陈哈哈。
参考文献:
Izmailov, Pavel, et al. "Averaging weights leads to wider optima and better generalization." arXiv preprint arXiv:1803.05407 (2018).
Grandvalet, Yves, and Yoshua Bengio. "Semi-supervised learning by entropy minimization." Advances in neural information processing systems. 2005.
【机器学习的Tricks】随机权值平均优化器swa与pseudo-label伪标签的更多相关文章
- 区间第k大问题 权值线段树 hdu 5249
先说下权值线段树的概念吧 权值平均树 就是指区间维护值为这个区间内点出现次数和的线段树 用这个加权线段树 解决第k大问题就很方便了 int query(int l,int r,int rt,int k ...
- UVA 11090 Going in Cycle!! 环平均权值(bellman-ford,spfa,二分)
题意: 给定一个n个点m条边的带权有向图,求平均权值最小的回路的平均权值? 思路: 首先,图中得有环的存在才有解,其次再解决这个最小平均权值为多少.一般这种就是二分猜平均权值了,因为环在哪也难以找出来 ...
- GA:GA优化BP神经网络的初始权值、阈值,从而增强BP神经网络的鲁棒性—Jason niu
global p global t global R % 输入神经元个数,此处是6个 global S1 % 隐层神经元个数,此处是10个 global S2 % 输出神经元个数,此处是4个 glob ...
- The Minimum Cycle Mean in a Digraph 《有向图中的最小平均权值回路》 Karp
文件链接 Karp在1977年的论文,讲述了一种\(O(nm)\)的算法,用来求有向强连通图中最小平均权值回路(具体问题请参照这里) 本人翻译(有删改): 首先任取一个节点 \(s\) ,定义 \(F ...
- POJ 2018 Best Cow Fences (二分答案构造新权值 or 斜率优化)
$ POJ~2018~Best~Cow~ Fences $(二分答案构造新权值) $ solution: $ 题目大意: 给定正整数数列 $ A $ ,求一个平均数最大的长度不小于 $ L $ 的子段 ...
- 深度学习原理与框架-Tensorflow卷积神经网络-cifar10图片分类(代码) 1.tf.nn.lrn(局部响应归一化操作) 2.random.sample(在列表中随机选值) 3.tf.one_hot(对标签进行one_hot编码)
1.tf.nn.lrn(pool_h1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # 局部响应归一化,使用相同位置的前后的filter进行响应归一化操作 参数 ...
- 机器学习:集成学习:随机森林.GBDT
集成学习(Ensemble Learning) 集成学习的思想是将若干个学习器(分类器&回归器)组合之后产生一个新学习器.弱分类器(weak learner)指那些分类准确率只稍微好于随机猜测 ...
- 【BZOJ-2892&1171】强袭作战&大sz的游戏 权值线段树+单调队列+标记永久化+DP
2892: 强袭作战 Time Limit: 50 Sec Memory Limit: 512 MBSubmit: 45 Solved: 30[Submit][Status][Discuss] D ...
- NOIp 2014 #2 联合权值 Label:图论 !!!未AC
题目描述 无向连通图G 有n 个点,n - 1 条边.点从1 到n 依次编号,编号为 i 的点的权值为W i ,每条边的长度均为1 .图上两点( u , v ) 的距离定义为u 点到v 点的最短距离. ...
随机推荐
- TCP Wrappers(简单防火墙)---限制IP登录ssh
1.TCP Wrappers 简介 TCP_ Wrappers是- 一个工作在第四层(传输层)的的安全工具,对有状态连接(TCP)的特定服务进行安全检测并实现访问控制,界定方式是凡是调用libwrap ...
- 代码注入——c++代码注入
代码注入之——c++代码注入 0x00 代码注入和DLL注入的区别 DLL注入后DLL会通过线程常驻在某个process中,而代码注入完成之后立即消失. 代码注入体积小,不占内存 0x01 通过c ...
- HashMap等集合初始化时应制定初始化大小
阿里巴巴开发规范中,推荐用户在初始化HashMap时,应指定集合初始值大小. 一.原因 这个不用多想,肯定是效率问题,那为什么会造成效率问题呢? 当我们new一个HashMap没有对其容量进行初始化的 ...
- SpringBoot执行定时任务@Scheduled
SpringBoot执行定时任务@Scheduled 在做项目时,需要一个定时任务来接收数据存入数据库,后端再写一个接口来提供该该数据的最新的那一条. 数据保持最新:设计字段sign的值(0,1)来设 ...
- python数据处理(五)之数据清洗:研究、匹配与格式化
1 前言 保持数据格式一致以及可读,否则数据不可能正确合并 清洗数据的过程中记下清洗过程的每一步,方便数据回溯以及过程复用 2 数据清洗基础知识 2.1 找出需要清洗的数据 仔细观察文件,观察数据字段 ...
- python 并发专题(十一):基础部分补充(三)线程
1. 背景 理论上来说:单个进程的多线程可以利用多核. 但是,开发Cpython解释器的程序员,给进入解释器的线程加了锁. 2. 加锁的原因: 当时都是单核时代,而且cpu价格非常贵. 如果不加全局解 ...
- python 装饰器(四):装饰器基础(三)叠放装饰器,参数化装饰器
叠放装饰器 示例 7-19 演示了叠放装饰器的方式:@lru_cache 应用到 @clock 装饰fibonacci 得到的结果上.在示例 7-21 中,模块中最后一个函数应用了两个 @htmliz ...
- Python函数04/生成器/推导式/内置函数
Python函数04/生成器/推导式/内置函数 目录 Python函数04/生成器/推导式/内置函数 内容大纲 1.生成器 2.推导式 3.内置函数(一) 4.今日总结 5.今日练习 内容大纲 1.生 ...
- 开源利器分享:BitBar 坐看今天你的项目涨了多少 star
今天开头我想叨叨几句,我个人最近的感受.在这个信息爆炸,互联网的时代里.我的周遭总是充斥者着各种让人能产生焦虑的信息, 我不知道有没有小伙伴和我一样,看到各种神通广大.游戏人生的大侠,低头看看自己当前 ...
- sql中in和exists的原理及使用场景。
在我们的工作中可能会遇到这样的情形: 我们需要查询a表里面的数据,但是要以b表作为约束. 举个例子,比如我们需要查询订单表中的数据,但是要以用户表为约束,也就是查询出来的订单的user_id要在用户表 ...