PyTorch 中 weight decay 的设置
先介绍一下 Caffe 和 TensorFlow 中 weight decay 的设置:
- 在 Caffe 中,
SolverParameter.weight_decay可以作用于所有的可训练参数, 不妨称为 global weight decay, 另外还可以为各层中的每个可训练参数设置独立的decay_mult, global weight decay 和当前可训练参数的decay_mult共同决定了当前可训练参数的 weight decay. - 在 TensorFlow 中, 某些接口可以为其下创建的可训练参数设置独立的 weight decay (如
slim.conv2d可通过weights_regularizer,bias_regularizer分别为其下定义的 weight 和 bias 设置不同的 regularizer).
在 PyTorch 中, 模块 (nn.Module) 和参数 (nn.Parameter) 的定义没有暴露与 weight decay 设置相关的 argument, 它把 weight decay 的设置放到了 torch.optim.Optimizer (严格地说, 是 torch.optim.Optimizer 的子类, 下同) 中.
在 torch.optim.Optimizer 中直接设置 weight_decay, 其将作用于该 optimizer 负责优化的所有可训练参数 (和 Caffe 中 SolverParameter.weight_decay 的作用类似), 这往往不是所期望的: BatchNorm 层的 \(\gamma\) 和 \(\beta\) 就不应该添加正则化项, 卷积层和全连接层的 bias 也往往不用加正则化项. 幸运地是, torch.optim.Optimizer 支持为不同的可训练参数设置不同的 weight_decay (params 支持 dict 类型), 于是问题转化为如何将不期望添加正则化项的可训练参数 (如 BN 层的可训练参数及卷积层和全连接层的 bias) 从可训练参数列表中分离出来. 笔者借鉴网上的一些方法, 写了一个满足该功能的函数, 没有经过严格测试, 仅供参考.
"""
作者: 采石工
博客: http://www.cnblogs.com/quarryman/
发布时间: 2020年10月21日
版权声明: 自由分享, 保持署名-非商业用途-非衍生, 知识共享3.0协议.
"""
import torch
from torchvision import models
def split_parameters(module):
params_decay = []
params_no_decay = []
for m in module.modules():
if isinstance(m, torch.nn.Linear):
params_decay.append(m.weight)
if m.bias is not None:
params_no_decay.append(m.bias)
elif isinstance(m, torch.nn.modules.conv._ConvNd):
params_decay.append(m.weight)
if m.bias is not None:
params_no_decay.append(m.bias)
elif isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
params_no_decay.extend([*m.parameters()])
elif len(list(m.children())) == 0:
params_decay.extend([*m.parameters()])
assert len(list(module.parameters())) == len(params_decay) + len(params_no_decay)
return params_decay, params_no_decay
def print_parameters_info(parameters):
for k, param in enumerate(parameters):
print('[{}/{}] {}'.format(k+1, len(parameters), param.shape))
if __name__ == '__main__':
model = models.resnet18(pretrained=False)
params_decay, params_no_decay = split_parameters(model)
print_parameters_info(params_decay)
print_parameters_info(params_no_decay)
参考
- https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994
- https://discuss.pytorch.org/t/changing-the-weight-decay-on-bias-using-named-parameters/19132/4
- https://discuss.pytorch.org/t/how-to-set-different-learning-rate-for-weight-and-bias-in-one-layer/13450
- Allow to set 0 weight decay for biases and params in batch norm #1402
版权声明
版权声明:自由分享,保持署名-非商业用途-非衍生,知识共享3.0协议。
如果你对本文有疑问或建议,欢迎留言!转载请保留版权声明!
如果你觉得本文不错, 也可以用微信赞赏一下哈.

PyTorch 中 weight decay 的设置的更多相关文章
- 在神经网络中weight decay
weight decay(权值衰减)的最终目的是防止过拟合.在损失函数中,weight decay是放在正则项(regularization)前面的一个系数,正则项一般指示模型的复杂度,所以weigh ...
- pytorch中tensorboardX的用法
在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...
- 转pytorch中训练深度神经网络模型的关键知识点
版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_42279044/articl ...
- 详解Pytorch中的网络构造,模型save和load,.pth权重文件解析
转载:https://zhuanlan.zhihu.com/p/53927068 https://blog.csdn.net/wangdongwei0/article/details/88956527 ...
- PyTorch中使用深度学习(CNN和LSTM)的自动图像标题
介绍 深度学习现在是一个非常猖獗的领域 - 有如此多的应用程序日复一日地出现.深入了解深度学习的最佳方法是亲自动手.尽可能多地参与项目,并尝试自己完成.这将帮助您更深入地掌握主题,并帮助您成为更好的深 ...
- 第五章——Pytorch中常用的工具
2018年07月07日 17:30:40 __矮油不错哟 阅读数:221 1. 数据处理 数据加载 ImageFolder DataLoader加载数据 sampler:采样模块 1. 数据处理 ...
- Pytorch中RoI pooling layer的几种实现
Faster-RCNN论文中在RoI-Head网络中,将128个RoI区域对应的feature map进行截取,而后利用RoI pooling层输出7*7大小的feature map.在pytorch ...
- pytorch中如何使用DataLoader对数据集进行批处理
最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络. pytorch中有很方便的dataloader函数来方便我们进行批处 ...
- (原)CNN中的卷积、1x1卷积及在pytorch中的验证
转载请注明处处: http://www.cnblogs.com/darkknightzh/p/9017854.html 参考网址: https://pytorch.org/docs/stable/nn ...
随机推荐
- Java知识点JUC总结
JUC:java.util.concurrent (Java并发编程工具类) 一般面试提问:面向对象和高级语法.Java集合类.Java多线程.JUC 和高并发.Java IO和 NIO 获取多线程的 ...
- hystrix(2) metrics
上一节讲到了hystrix提供的五个功能,这一节我们首先来讲hystrix中提供实时执行metrics信息的实现.为什么先讲metrics,因为很多功能都是基于metrics的数据来实现的,它是很多功 ...
- 【Jenkins】三、设置定时任务
1.点击工程(Test1), 选择左侧的配置 2.选择"构建触发器"下面的"定时构建" 3.填写定时规则(这里设置每隔30分钟执行一次) 4.定时规则语法字段 ...
- php第四天-字符串
0x01 字符串 1.1 字符串的处理方式 在不同的语言中,字符串的处理方式不同:在C中字符串是作为字节数组处理的:在Java中字符串是作为对象处理的:而在php中则把字符串作为基本数据类型来处理. ...
- 机器学习-线性规划(LP)
线性规划问题 首先引入如下的问题: 假设食物的各种营养成分.价格如下表: Food Energy(能量) Protein(蛋白质) Calcium(钙) Price Oatmeal(燕麦) 110 4 ...
- 刷题[MRCTF2020]Ezpop
解题思路 打开一看直接是代码审计的题,就嗯审.最近可能都在搞反序列化,先把反序列化的题刷烂,理解理解 代码审计 Welcome to index.php <?php //flag is in f ...
- 有没有异常处理翻车过的,绩效还被打了C
絮叨 因为程序异常处理问题,就在前几天龙叔的服务挂了几秒钟. 完了,马上季度末打绩效,竟然在这里翻车了,心如刀绞啊. 虽然没有影响到用户体验,但是找到问题并解决掉问题是工程师日常追求之一. 作为一个优 ...
- Processing 状态量控制动画技巧
之前在CSDN上发表过: https://blog.csdn.net/fddxsyf123/article/details/62848357
- Oracle 11G RAC11.2.0.4 + Redhat7.3安装手册
安装思路: 1.安装两台redhat7 linux系统 2.网络配置(双网卡,public,vip,private,scan) 3.存储配置(内存配置,ASM共享存储:6块5G共享盘udev,根目录留 ...
- 神奇的字符串匹配:扩展KMP算法
引言 一个算是冷门的算法(在竞赛上),不过其算法思想值得深究. 前置知识 kmp的算法思想,具体可以参考 → Click here trie树(字典树). 正文 问题定义:给定两个字符串 S 和 T( ...