pytorch固定BN层参数
背景:基于PyTorch
的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch
相同的测试数据经过主分支输出的结果不同。
原因:未固定主分支BN
层中的running_mean
和running_var
。
解决方法:将需要固定的BN
层状态设置为eval
。
问题示例:
环境:torch
:1.7.0
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.bn1 = nn.BatchNorm2d(6)
self.conv2 = nn.Conv2d(6, 16, 3)
self.bn2 = nn.BatchNorm2d(16)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 5)
def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
# If the size is a square you can only specify a single number
x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
def print_parameter_grad_info(net):
print('-------parameters requires grad info--------')
for name, p in net.named_parameters():
print(f'{name}:\t{p.requires_grad}')
def print_net_state_dict(net):
for key, v in net.state_dict().items():
print(f'{key}')
if __name__ == "__main__":
net = Net()
print_parameter_grad_info(net)
net.requires_grad_(False)
print_parameter_grad_info(net)
torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32)
# print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假设每个epoch只迭代一次
net.train()
pre = net(train_data)
# 计算损失和参数更新等
# ....
# test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)
运行结果:
-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])
可以看到:
net.requires_grad_(False)
已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data
在不同epoch
中前向之后出现了不同的结果。
调用print_net_state_dict
可以看到BN
层中的参数running_mean
和running_var
并没在可优化参数net.parameters
中
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
但在training pahse
的前向过程中,这两个参数被更新了。导致整个网络在freeze
的情况下,同样的测试数据出现了不同的结果
Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default
momentum
of 0.1. source
因此在training phase时对BN层显式设置eval
状态:
if __name__ == "__main__":
net = Net()
net.requires_grad_(False)
torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32)
# print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假设每个epoch只迭代一次
net.train()
net.bn1.eval()
net.bn2.eval()
pre = net(train_data)
# 计算损失和参数更新等
# ....
# test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)
可以看到结果正常了:
epoch:0 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
交流基地:630390733
pytorch固定BN层参数的更多相关文章
- 【转载】 【caffe转向pytorch】caffe的BN层+scale层=pytorch的BN层
原文地址: https://blog.csdn.net/u011668104/article/details/81532592 ------------------------------------ ...
- 【转载】 Caffe BN+Scale层和Pytorch BN层的对比
原文地址: https://blog.csdn.net/elysion122/article/details/79628587 ------------------------------------ ...
- (原)torch中微调某层参数
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6221664.html 参考网址: https://github.com/torch/nn/issues ...
- Tensorflow训练和预测中的BN层的坑
以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...
- 【转载】 Pytorch(1) pytorch中的BN层的注意事项
原文地址: https://blog.csdn.net/weixin_40100431/article/details/84349470 ------------------------------- ...
- Batch Normalization的算法本质是在网络每一层的输入前增加一层BN层(也即归一化层),对数据进行归一化处理,然后再进入网络下一层,但是BN并不是简单的对数据进行求归一化,而是引入了两个参数λ和β去进行数据重构
Batch Normalization Batch Normalization是深度学习领域在2015年非常热门的一个算法,许多网络应用该方法进行训练,并且取得了非常好的效果. 众所周知,深度学习是应 ...
- PyTorch模型读写、参数初始化、Finetune
使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...
- BN层
论文名字:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 论 ...
- 【卷积神经网络】对BN层的解释
前言 Batch Normalization是由google提出的一种训练优化方法.参考论文:Batch Normalization Accelerating Deep Network Trainin ...
随机推荐
- 安装git和lsof
yum install git yum install lsof 查看80端口 lsof -i:80
- 【海思】Hi3531A SPI功能的详细配置以及使用
目录 一.前言 二.SPI管脚信息获取 2.1 SPI_SCLK.SPI_SDI.SPI_SDO管脚复用寄存器 2.2 片选SPI_CSN0-SPI_CSN3管脚寄存器 三.配置和使能与SPI相关的管 ...
- FL Studio新手入门:FL Studio五大常用按钮介绍
我们打开FL Studio编曲软件会发现界面中有好多的菜单和窗口,这些窗口每个都有其单独的功能.今天小编主要给大家详细讲解下FL Studio水果软件的五大常用按钮. 1.首先我,我们双击桌面的水果图 ...
- CentOS下如何用nmon收集系统实时运行状况
#赋予执行权限 chmod +x nmon 执行./nmon可以查看实时的系统状态有提示的,d看磁盘,n看网络,c看cpu #如果不想看实时的,想收集系统长时间运行情况然后分析,可用这个 nohup ...
- JAVA中删除文件夹下及其子文件夹下的某类文件
##定时删除拜访图片 ##cron表达式 秒 分 时 天 月 ? ##每月1日整点执行 CRON1=0 0 0 1 * ? scheduled.enable1=false ##图片路径 filePat ...
- springsecurity实现前后端分离之jwt-资料收集
https://www.jianshu.com/p/5b9f1f4de88d https://www.jianshu.com/p/725d32ab92f8 https://blog.csdn.net/ ...
- 第8.23节 Python中使用sort/sorted排序与“富比较”方法的关系分析
一. 引言 <第8.21节 Python中__lt__.gt__等 "富比较"("rich comparison")方法用途探究>和<第8.2 ...
- 第四章 、PyQt中的信号(signal)和槽(slot)机制以及Designer中的使用
老猿Python博文目录 专栏:使用PyQt开发图形界面Python应用 老猿Python博客地址 一.引言 前面章节其实已经在使用信号和槽了,但是作为Qt中最重要的机制也是Qt区别与其他开发平台的重 ...
- python中的Restful
哇,昨天组里进行总结的时候,小哥哥和小姐姐真是把我给秀到了,跟他们一比,我总结的太垃圾了,嘤嘤嘤.因为我平常不怎么总结,总结的话,有word还有纸质的,现在偏向于纸质,因为可以练练字.个人观点是,掌握 ...
- Java程序员普遍存在的面试问题以及应对之道(新书第一章节摘录)
其实大多数Java开发确实能胜任日常的开发工作,但不少候选人却无法在面试中打动面试官.因为要在短时间的面试中全面展示自己的实力,这很需要技巧,而从当前大多数Java开发的面试现状来看,会面试的候选人不 ...