背景:基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。

原因:未固定主分支BN层中的running_meanrunning_var

解决方法:将需要固定的BN层状态设置为eval

问题示例

环境:torch:1.7.0

# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F class Net(nn.Module): def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.bn1 = nn.BatchNorm2d(6)
self.conv2 = nn.Conv2d(6, 16, 3)
self.bn2 = nn.BatchNorm2d(16)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 5) def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
# If the size is a square you can only specify a single number
x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features def print_parameter_grad_info(net):
print('-------parameters requires grad info--------')
for name, p in net.named_parameters():
print(f'{name}:\t{p.requires_grad}') def print_net_state_dict(net):
for key, v in net.state_dict().items():
print(f'{key}') if __name__ == "__main__":
net = Net() print_parameter_grad_info(net)
net.requires_grad_(False)
print_parameter_grad_info(net) torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32) # print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假设每个epoch只迭代一次
net.train()
pre = net(train_data)
# 计算损失和参数更新等
# .... # test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)

运行结果:

-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])

可以看到:

net.requires_grad_(False)已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data在不同epoch中前向之后出现了不同的结果。

调用print_net_state_dict可以看到BN层中的参数running_meanrunning_var并没在可优化参数net.parameters

bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked

但在training pahse的前向过程中,这两个参数被更新了。导致整个网络在freeze的情况下,同样的测试数据出现了不同的结果

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source

因此在training phase时对BN层显式设置eval状态:

if __name__ == "__main__":
net = Net()
net.requires_grad_(False) torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32) # print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假设每个epoch只迭代一次
net.train()
net.bn1.eval()
net.bn2.eval()
pre = net(train_data)
# 计算损失和参数更新等
# .... # test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)

可以看到结果正常了:

epoch:0 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])

交流基地:630390733

pytorch固定BN层参数的更多相关文章

  1. 【转载】 【caffe转向pytorch】caffe的BN层+scale层=pytorch的BN层

    原文地址: https://blog.csdn.net/u011668104/article/details/81532592 ------------------------------------ ...

  2. 【转载】 Caffe BN+Scale层和Pytorch BN层的对比

    原文地址: https://blog.csdn.net/elysion122/article/details/79628587 ------------------------------------ ...

  3. (原)torch中微调某层参数

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6221664.html 参考网址: https://github.com/torch/nn/issues ...

  4. Tensorflow训练和预测中的BN层的坑

    以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...

  5. 【转载】 Pytorch(1) pytorch中的BN层的注意事项

    原文地址: https://blog.csdn.net/weixin_40100431/article/details/84349470 ------------------------------- ...

  6. Batch Normalization的算法本质是在网络每一层的输入前增加一层BN层(也即归一化层),对数据进行归一化处理,然后再进入网络下一层,但是BN并不是简单的对数据进行求归一化,而是引入了两个参数λ和β去进行数据重构

    Batch Normalization Batch Normalization是深度学习领域在2015年非常热门的一个算法,许多网络应用该方法进行训练,并且取得了非常好的效果. 众所周知,深度学习是应 ...

  7. PyTorch模型读写、参数初始化、Finetune

    使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...

  8. BN层

    论文名字:Batch Normalization: Accelerating Deep Network Training by  Reducing Internal Covariate Shift 论 ...

  9. 【卷积神经网络】对BN层的解释

    前言 Batch Normalization是由google提出的一种训练优化方法.参考论文:Batch Normalization Accelerating Deep Network Trainin ...

随机推荐

  1. 牛逼哄哄的PageHelper分页插件到底是怎么实现的?网友:给我10分钟,给你写一个~

    Hi,各位读者们 PageHelper是一款好用的开源免费的Mybatis第三方物理分页插件,其实我并不想加上好用两个字,但是为了表扬插件作者开源免费的崇高精神,我毫不犹豫的加上了好用一词作为赞美. ...

  2. leetcode 108 和leetcode 109

    //感想:有时候啊,对于一道题目,如果知道那个点在哪,就会非常简单,比如说这两题,将有序的数组转换为二叉搜索树, 有几个点: 1.二叉搜索树:对于某个节点,它的左节点小于它,它的右节点大于它,这是二叉 ...

  3. php bypass disable_function 命令执行 方法汇总简述

    1.使用未被禁用的其他函数 exec,shell_exec,system,popen,proc_open,passthru (python_eval?perl_system ? weevely3 wi ...

  4. redlock分布式锁真的安全吗

    此文是对http://zhangtielei.com/posts/blog-redlock-reasoning-part2.html文章的个人归纳,如有问题请联系删除 什么是redlock redlo ...

  5. python画猫并打包成EXE文件

    因python自带有海龟画图库,尝试给爱猫的小仙女来画个猫咪. 1.代码如下 from turtle import * #两个函数用于画心 def curvemove(): for i in rang ...

  6. Meetings S 题解

    题目描述 题目链接 有两个牛棚位于一维数轴上的点 \(0\) 和 \(L\) 处.同时有 \(N\) 头奶牛位于数轴上不同的位置(将牛棚和奶牛看作点).每头奶牛 \(i\) 初始时位于某个位置 \(x ...

  7. 蓝桥杯——复数幂 (2018JavaAB组第3题)

    18年Java蓝桥杯A组第3题和B组是一样的. 第三题往往比较难. 复数幂 (18JavaAB3) (A.B两卷第三题一样) 设i为虚数单位.对于任意正整数n,(2+3i)^n 的实部和虚部都是整数. ...

  8. 【mq读书笔记】Index索引文件

    1.IndexHeader头部,40字节,记录IndexFile的统计信息: begainTimestamp:该索引文件中包含消息的最小存储时间 endTimestamp:该索引文件中包含消息的最大存 ...

  9. Alpha冲刺——代码规范,冲刺计划

    这个作业属于哪个课程 https://edu.cnblogs.com/campus/fzzcxy/2018SE2/ 这个作业要求在哪里 https://edu.cnblogs.com/campus/f ...

  10. dubbo源码学习(二)dubbo容器启动流程简略分析

    dubbo版本2.6.3 继续之前的dubbo源码阅读,从com.alibaba.dubbo.container.Main.main(String[] args)作为入口 简单的数据一下启动的流程 1 ...