先介绍一下 Caffe 和 TensorFlow 中 weight decay 的设置:

  • Caffe 中, SolverParameter.weight_decay 可以作用于所有的可训练参数, 不妨称为 global weight decay, 另外还可以为各层中的每个可训练参数设置独立的 decay_mult, global weight decay 和当前可训练参数的 decay_mult 共同决定了当前可训练参数的 weight decay.
  • TensorFlow 中, 某些接口可以为其下创建的可训练参数设置独立的 weight decay (如 slim.conv2d 可通过 weights_regularizer, bias_regularizer 分别为其下定义的 weight 和 bias 设置不同的 regularizer).

在 PyTorch 中, 模块 (nn.Module) 和参数 (nn.Parameter) 的定义没有暴露与 weight decay 设置相关的 argument, 它把 weight decay 的设置放到了 torch.optim.Optimizer (严格地说, 是 torch.optim.Optimizer 的子类, 下同) 中.

torch.optim.Optimizer 中直接设置 weight_decay, 其将作用于该 optimizer 负责优化的所有可训练参数 (和 Caffe 中 SolverParameter.weight_decay 的作用类似), 这往往不是所期望的: BatchNorm 层的 \(\gamma\) 和 \(\beta\) 就不应该添加正则化项, 卷积层和全连接层的 bias 也往往不用加正则化项. 幸运地是, torch.optim.Optimizer 支持为不同的可训练参数设置不同的 weight_decay (params 支持 dict 类型), 于是问题转化为如何将不期望添加正则化项的可训练参数 (如 BN 层的可训练参数及卷积层和全连接层的 bias) 从可训练参数列表中分离出来. 笔者借鉴网上的一些方法, 写了一个满足该功能的函数, 没有经过严格测试, 仅供参考.

"""
作者: 采石工
博客: http://www.cnblogs.com/quarryman/
发布时间: 2020年10月21日
版权声明: 自由分享, 保持署名-非商业用途-非衍生, 知识共享3.0协议.
"""
import torch
from torchvision import models def split_parameters(module):
params_decay = []
params_no_decay = []
for m in module.modules():
if isinstance(m, torch.nn.Linear):
params_decay.append(m.weight)
if m.bias is not None:
params_no_decay.append(m.bias)
elif isinstance(m, torch.nn.modules.conv._ConvNd):
params_decay.append(m.weight)
if m.bias is not None:
params_no_decay.append(m.bias)
elif isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
params_no_decay.extend([*m.parameters()])
elif len(list(m.children())) == 0:
params_decay.extend([*m.parameters()])
assert len(list(module.parameters())) == len(params_decay) + len(params_no_decay)
return params_decay, params_no_decay def print_parameters_info(parameters):
for k, param in enumerate(parameters):
print('[{}/{}] {}'.format(k+1, len(parameters), param.shape)) if __name__ == '__main__':
model = models.resnet18(pretrained=False)
params_decay, params_no_decay = split_parameters(model)
print_parameters_info(params_decay)
print_parameters_info(params_no_decay)

参考

版权声明

版权声明:自由分享,保持署名-非商业用途-非衍生,知识共享3.0协议。

如果你对本文有疑问或建议,欢迎留言!转载请保留版权声明!

如果你觉得本文不错, 也可以用微信赞赏一下哈.

PyTorch 中 weight decay 的设置的更多相关文章

  1. 在神经网络中weight decay

    weight decay(权值衰减)的最终目的是防止过拟合.在损失函数中,weight decay是放在正则项(regularization)前面的一个系数,正则项一般指示模型的复杂度,所以weigh ...

  2. pytorch中tensorboardX的用法

    在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...

  3. 转pytorch中训练深度神经网络模型的关键知识点

    版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_42279044/articl ...

  4. 详解Pytorch中的网络构造,模型save和load,.pth权重文件解析

    转载:https://zhuanlan.zhihu.com/p/53927068 https://blog.csdn.net/wangdongwei0/article/details/88956527 ...

  5. PyTorch中使用深度学习(CNN和LSTM)的自动图像标题

    介绍 深度学习现在是一个非常猖獗的领域 - 有如此多的应用程序日复一日地出现.深入了解深度学习的最佳方法是亲自动手.尽可能多地参与项目,并尝试自己完成.这将帮助您更深入地掌握主题,并帮助您成为更好的深 ...

  6. 第五章——Pytorch中常用的工具

    2018年07月07日 17:30:40 __矮油不错哟 阅读数:221   1. 数据处理 数据加载 ImageFolder DataLoader加载数据 sampler:采样模块 1. 数据处理 ...

  7. Pytorch中RoI pooling layer的几种实现

    Faster-RCNN论文中在RoI-Head网络中,将128个RoI区域对应的feature map进行截取,而后利用RoI pooling层输出7*7大小的feature map.在pytorch ...

  8. pytorch中如何使用DataLoader对数据集进行批处理

    最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络. pytorch中有很方便的dataloader函数来方便我们进行批处 ...

  9. (原)CNN中的卷积、1x1卷积及在pytorch中的验证

    转载请注明处处: http://www.cnblogs.com/darkknightzh/p/9017854.html 参考网址: https://pytorch.org/docs/stable/nn ...

随机推荐

  1. Nginx反代MogileFS集群

    上一篇博文我们主要聊了下分布式文件系统MogileFS的组件以及部署使用,回顾请参考https://www.cnblogs.com/qiuhom-1874/p/13677279.html:今天我们主要 ...

  2. MySQL的共享锁阻塞会话案例浅析输入日志标题

        这是问题是一个网友遇到的问题:一个UPDATE语句产生的共享锁阻塞了其他会话的案例,对于这个案例,我进一步分析.总结和衍化了相关问题.下面分析如有不对的地方,敬请指正.下面是初始化环境和数据的 ...

  3. MyBatis学习(一)初识MyBatis

    一.MyBatis简介 MyBatis 本是apache的一个开源项目iBatis, 2010年这个项目由apache software foundation 迁移到了google code,并且改名 ...

  4. Codeforces Round #673 (Div. 2)

    [Codeforces Round #673 (Div. 2) ] 题目链接# A. Copy-paste 思路: 贪心的策略.每次只加上最小的就可以了 #include<bits/stdc++ ...

  5. git push 提交时出错 the remote end hung up unexpectedly

    错误原因 与远程服务的连接中断,但是检查发现origin还在,可能是文件太大,缓存不够,增加缓存大小 解决方案 专案目录 >.git >config 在末尾增加如下代码 [http] po ...

  6. RabbitMQ小记(二)

    1.RabbitMQ相关介绍 (1)RabbitMQ整体上是一个生产者和消费者模型,主要负责接收.存储.转发消息.RabbitMQ整体结构图如下: (2)生产者:发送消息的一方,生产者创建一条消息,发 ...

  7. sipp3.6对freeswitch进行压力测试

    一.安装sipp 1.下载地址: https://github-production-release-asset-2e65be.s3.amazonaws.com/13161657/99df6100-9 ...

  8. Hyper-V Server + Windows Admin Center

    2020年的十一黄金周是双节,偶然间得知再出现双节可能要几十年之后了,很可惜我并没有出去游玩的打算.所以假期没什么事,就来研究下Hyper Server + Windows Admin Center. ...

  9. 051 01 Android 零基础入门 01 Java基础语法 05 Java流程控制之循环结构 13 Eclipse下程序调试——debug入门1

    051 01 Android 零基础入门 01 Java基础语法 05 Java流程控制之循环结构 13 Eclipse下程序调试--debug入门1 本文知识点: 程序调试--debug入门1 程序 ...

  10. Python数据结构与算法之图的广度优先与深度优先搜索算法示例

    本文实例讲述了Python数据结构与算法之图的广度优先与深度优先搜索算法.分享给大家供大家参考,具体如下: 根据维基百科的伪代码实现: 广度优先BFS: 使用队列,集合 标记初始结点已被发现,放入队列 ...