pytorch固定BN层参数
背景:基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。
原因:未固定主分支BN层中的running_mean和running_var。
解决方法:将需要固定的BN层状态设置为eval。
问题示例:
环境:torch:1.7.0
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.bn1 = nn.BatchNorm2d(6)
self.conv2 = nn.Conv2d(6, 16, 3)
self.bn2 = nn.BatchNorm2d(16)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 5)
def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
# If the size is a square you can only specify a single number
x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
def print_parameter_grad_info(net):
print('-------parameters requires grad info--------')
for name, p in net.named_parameters():
print(f'{name}:\t{p.requires_grad}')
def print_net_state_dict(net):
for key, v in net.state_dict().items():
print(f'{key}')
if __name__ == "__main__":
net = Net()
print_parameter_grad_info(net)
net.requires_grad_(False)
print_parameter_grad_info(net)
torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32)
# print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假设每个epoch只迭代一次
net.train()
pre = net(train_data)
# 计算损失和参数更新等
# ....
# test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)
运行结果:
-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])
可以看到:
net.requires_grad_(False)已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data在不同epoch中前向之后出现了不同的结果。
调用print_net_state_dict可以看到BN层中的参数running_mean和running_var并没在可优化参数net.parameters中
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
但在training pahse的前向过程中,这两个参数被更新了。导致整个网络在freeze的情况下,同样的测试数据出现了不同的结果
Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default
momentumof 0.1. source
因此在training phase时对BN层显式设置eval状态:
if __name__ == "__main__":
net = Net()
net.requires_grad_(False)
torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32)
# print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假设每个epoch只迭代一次
net.train()
net.bn1.eval()
net.bn2.eval()
pre = net(train_data)
# 计算损失和参数更新等
# ....
# test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)
可以看到结果正常了:
epoch:0 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
交流基地:630390733
pytorch固定BN层参数的更多相关文章
- 【转载】 【caffe转向pytorch】caffe的BN层+scale层=pytorch的BN层
原文地址: https://blog.csdn.net/u011668104/article/details/81532592 ------------------------------------ ...
- 【转载】 Caffe BN+Scale层和Pytorch BN层的对比
原文地址: https://blog.csdn.net/elysion122/article/details/79628587 ------------------------------------ ...
- (原)torch中微调某层参数
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6221664.html 参考网址: https://github.com/torch/nn/issues ...
- Tensorflow训练和预测中的BN层的坑
以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...
- 【转载】 Pytorch(1) pytorch中的BN层的注意事项
原文地址: https://blog.csdn.net/weixin_40100431/article/details/84349470 ------------------------------- ...
- Batch Normalization的算法本质是在网络每一层的输入前增加一层BN层(也即归一化层),对数据进行归一化处理,然后再进入网络下一层,但是BN并不是简单的对数据进行求归一化,而是引入了两个参数λ和β去进行数据重构
Batch Normalization Batch Normalization是深度学习领域在2015年非常热门的一个算法,许多网络应用该方法进行训练,并且取得了非常好的效果. 众所周知,深度学习是应 ...
- PyTorch模型读写、参数初始化、Finetune
使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...
- BN层
论文名字:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 论 ...
- 【卷积神经网络】对BN层的解释
前言 Batch Normalization是由google提出的一种训练优化方法.参考论文:Batch Normalization Accelerating Deep Network Trainin ...
随机推荐
- Contest 1428
A 移动次数是 \(\left|x_1-x_2\right|+\left|y_1-y_2\right|\). 如果 \(x_1\not=x_2\) 且 \(y_1\not=y_2\) 说明要换方向,两 ...
- canal部署
转载: https://blog.csdn.net/qq_30043755/article/details/83539116 最后的binlog最后被封装为这样的一个对象: com.alibaba.o ...
- Spring Framework 5.0简述
从Spring框架5.0开始,Spring需要JDK 8+ (Java SE 8+),并且已经为JDK 9提供了现成的支持. Spring框架还支持依赖注入(JSR 330)和通用注释(JSR 250 ...
- LeetCode 004 Median of Two Sorted Arrays
题目描述:Median of Two Sorted Arrays There are two sorted arrays A and B of size m and n respectively. F ...
- Springboot核心注解
运行文中的代码需要在项目构建中引入springboot 相关依赖. ① @configuration configuration,用来将bean加入到ioc容器.代替传统xml中的bean配置.代码示 ...
- 20200116_centos7.2 下 mysql_5.7修改root密码
1. 需改my.cnf文件 [root@rakinda-iot-platform ~]# vim /etc/my.cnf 2. 新增一行, 登录时跳过密码, 保存后退出, 重启mysql system ...
- vue--axios分装
封装: import axios from 'axios' axios.defaults.baseURL="http://127.0.0.1:8000/" axios.defaul ...
- Visual Studio 2012 Ultimate旗舰版下载地址与序列号
(为了方便个人使用转的的别的帖子的内容,原文链接http://wenku.baidu.com/link?url=acL08J8bTNQ4S5Sd3n3oLN5KJTtrfe8hHuP8aUrNscKN ...
- Python中函数的参数带星号是什么意思?
参数带星号表示支持可变不定数量的参数,这种方法叫参数收集. 星号又可以带1个或2个,带1个表示按位置来收集参数,带2个星号表示按关键字来收集参数. 1.带一个星号的参数收集模式: 这种模式是在函数定义 ...
- Python中的迭代是什么意思?
Python中的迭代是指通过重复执行的代码处理相似的数据集的过程,并且本次迭代的处理数据要依赖上一次的结果继续往下做,上一次产生的结果为下一次产生结果的初始状态,如果中途有任何停顿,都不能算是迭代. ...