pytorch固定BN层参数
背景:基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。
原因:未固定主分支BN层中的running_mean和running_var。
解决方法:将需要固定的BN层状态设置为eval。
问题示例:
环境:torch:1.7.0
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.bn1 = nn.BatchNorm2d(6)
self.conv2 = nn.Conv2d(6, 16, 3)
self.bn2 = nn.BatchNorm2d(16)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 5)
def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
# If the size is a square you can only specify a single number
x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
def print_parameter_grad_info(net):
print('-------parameters requires grad info--------')
for name, p in net.named_parameters():
print(f'{name}:\t{p.requires_grad}')
def print_net_state_dict(net):
for key, v in net.state_dict().items():
print(f'{key}')
if __name__ == "__main__":
net = Net()
print_parameter_grad_info(net)
net.requires_grad_(False)
print_parameter_grad_info(net)
torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32)
# print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假设每个epoch只迭代一次
net.train()
pre = net(train_data)
# 计算损失和参数更新等
# ....
# test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)
运行结果:
-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])
可以看到:
net.requires_grad_(False)已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data在不同epoch中前向之后出现了不同的结果。
调用print_net_state_dict可以看到BN层中的参数running_mean和running_var并没在可优化参数net.parameters中
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
但在training pahse的前向过程中,这两个参数被更新了。导致整个网络在freeze的情况下,同样的测试数据出现了不同的结果
Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default
momentumof 0.1. source
因此在training phase时对BN层显式设置eval状态:
if __name__ == "__main__":
net = Net()
net.requires_grad_(False)
torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32)
# print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假设每个epoch只迭代一次
net.train()
net.bn1.eval()
net.bn2.eval()
pre = net(train_data)
# 计算损失和参数更新等
# ....
# test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)
可以看到结果正常了:
epoch:0 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
交流基地:630390733
pytorch固定BN层参数的更多相关文章
- 【转载】 【caffe转向pytorch】caffe的BN层+scale层=pytorch的BN层
原文地址: https://blog.csdn.net/u011668104/article/details/81532592 ------------------------------------ ...
- 【转载】 Caffe BN+Scale层和Pytorch BN层的对比
原文地址: https://blog.csdn.net/elysion122/article/details/79628587 ------------------------------------ ...
- (原)torch中微调某层参数
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6221664.html 参考网址: https://github.com/torch/nn/issues ...
- Tensorflow训练和预测中的BN层的坑
以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...
- 【转载】 Pytorch(1) pytorch中的BN层的注意事项
原文地址: https://blog.csdn.net/weixin_40100431/article/details/84349470 ------------------------------- ...
- Batch Normalization的算法本质是在网络每一层的输入前增加一层BN层(也即归一化层),对数据进行归一化处理,然后再进入网络下一层,但是BN并不是简单的对数据进行求归一化,而是引入了两个参数λ和β去进行数据重构
Batch Normalization Batch Normalization是深度学习领域在2015年非常热门的一个算法,许多网络应用该方法进行训练,并且取得了非常好的效果. 众所周知,深度学习是应 ...
- PyTorch模型读写、参数初始化、Finetune
使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...
- BN层
论文名字:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 论 ...
- 【卷积神经网络】对BN层的解释
前言 Batch Normalization是由google提出的一种训练优化方法.参考论文:Batch Normalization Accelerating Deep Network Trainin ...
随机推荐
- get、post、
1.get请求 get请求会把参数放在url后面,中间用?隔开,也可以把参数放在请求body中,如果参数中有中文,http传的时候requests框架会将中文换成urlencode编码 2.get和p ...
- [题解] [USACO05JAN]Muddy Fields G
题目TP门 题目大意 在一个 \(R×C\) 的矩阵中,每个点有两个状态:草地和泥地.你需要在泥地里铺 \(1×k\) 木块, \(k\) 为任意整数,求最少要多少木块. 思路 两个横向木块不会互相干 ...
- jdk从1.8降到jdk1.7失败
1.将JAVA_HOME:的路径更改为1.7的相关路径,例如我的:C:\Java\jdk1.7.0_80 2.此时查看版本:Java -version,如果是1.8的版本,就把path路径下的%JAV ...
- macos brew zookeeper,安装后zookeeper启动失败?
一.Zookeeper安装流程 执行如下安装命令: brew install zookeeper 执行截图如下: 安装后查看 zookeeper 安装信息(默认拉取最新版本) brew info zo ...
- 基于CefSharp开发(二)自定义浏览器窗体
上一篇 https://www.cnblogs.com/mchao/p/13914726.html 简单了解了CefSharp引用配置但页面光秃秃的,这一篇着手开发简单浏览器窗体 一.Edge浏览器窗 ...
- scentos7安装redis,以及redis的主从配置
redis的安装 下载redis安装包 wget http://download.redis.io/releases/redis-4.0.6.tar.gz 解压压缩包 tar -zxvf redis- ...
- 记一次bug思考过程:HibernateException: Could not obtain transaction-synchronized Session for current thread
场景:把从客户端提交的任务放到线程池执行 异常:HibernateException: Could not obtain transaction-synchronized Session for cu ...
- moviepy音视频剪辑:使用fl_time进行时间特效处理报错ValueError: Attribute duration not set
专栏:Python基础教程目录 专栏:使用PyQt开发图形界面Python应用 专栏:PyQt+moviepy音视频剪辑实战 专栏:PyQt入门学习 老猿Python博文目录 老猿学5G博文目录 在使 ...
- 第8.15节 Python重写自定义类的__repr__方法
一. 引言 前面两节分别介绍了Python类中的__str__和__repr__方法的作用和语法,所有新式类都支持这两个方法,因为object类实现了这两个方法,但实际上各位开发者在自定义类的过程中, ...
- PostgreSQL 如何忽略事务中错误
在 PostgreSQL 的事务中:执行的SQL遇到错误(书写,约束限制):该事务的已经执行的SQL都会进行rollback.那如何忽略其中的错误.将SQL执行到底?在事务中设置 ON_ERROR_R ...