原文地址:

https://blog.csdn.net/weixin_40759186/article/details/87547795

---------------------------------------------------------------------------------------------------------------

用pytorch做dropout和BN时需要注意的地方

pytorch做dropout:

就是train的时候使用dropout,训练的时候不使用dropout,
pytorch里面是通过net.eval()固定整个网络参数,包括不会更新一些前向的参数,没有dropout,BN参数固定,理论上对所有的validation set都要使用net.eval()
net.train()表示会纳入梯度的计算。

net_dropped = torch.nn.Sequential(
torch.nn.Linear(1, N_HIDDEN),
torch.nn.Dropout(0.5), # drop 50% of the neuron
torch.nn.ReLU(),
torch.nn.Linear(N_HIDDEN, N_HIDDEN),
torch.nn.Dropout(0.5), # drop 50% of the neuron
torch.nn.ReLU(),
torch.nn.Linear(N_HIDDEN, 1),
)

for t in range(500):
pred_drop = net_dropped(x)
loss_drop = loss_func(pred_drop, y) optimizer_drop.zero_grad()
loss_drop.backward()
optimizer_drop.step() if t % 10 == 0:
# change to eval mode in order to fix drop out effect
net_dropped.eval() # parameters for dropout differ from train mode test_pred_drop = net_dropped(test_x) # change back to train mode
net_dropped.train()

pytorch做Batch Normalization:

net.eval()固定整个网络参数,固定BN的参数,moving_mean 和moving_var,不懂这个看下图:

            if self.do_bn:
bn = nn.BatchNorm1d(10, momentum=0.5)
setattr(self, 'bn%i' % i, bn) # IMPORTANT set layer to the Module
self.bns.append(bn) for epoch in range(EPOCH):
print('Epoch: ', epoch)
for net, l in zip(nets, losses):
net.eval() # set eval mode to fix moving_mean and moving_var
pred, layer_input, pre_act = net(test_x) net.train() # free moving_mean and moving_var
plot_histogram(*layer_inputs, *pre_acts)

moving_mean   和   moving_var

用tensorflow做dropout和BN时需要注意的地方

dropout和BN都有一个training的参数表明到底是train还是test, 表明test那dropout就是不dropout,BN就是固定住了BN的参数;

tf_is_training = tf.placeholder(tf.bool, None)  # to control dropout when training and testing

# dropout net
d1 = tf.layers.dense(tf_x, N_HIDDEN, tf.nn.relu)
d1 = tf.layers.dropout(d1, rate=0.5, training=tf_is_training) # drop out 50% of inputs

d2 = tf.layers.dense(d1, N_HIDDEN, tf.nn.relu)
d2 = tf.layers.dropout(d2, rate=0.5, training=tf_is_training) # drop out 50% of inputs

d_out = tf.layers.dense(d2, 1) for t in range(500):
sess.run([o_train, d_train], {tf_x: x, tf_y: y, tf_is_training: True}) # train, set is_training=True if t % 10 == 0:
# plotting
plt.cla()
o_loss_, d_loss_, o_out_, d_out_ = sess.run(
[o_loss, d_loss, o_out, d_out], {tf_x: test_x, tf_y: test_y, tf_is_training: False} # test, set is_training=False
)
    def add_layer(self, x, out_size, ac=None):
x = tf.layers.dense(x, out_size, kernel_initializer=self.w_init, bias_initializer=B_INIT)
self.pre_activation.append(x)
# the momentum plays important rule. the default 0.99 is too high in this case!
if self.is_bn: x = tf.layers.batch_normalization(x, momentum=0.4, training=tf_is_train) # when have BN
out = x if ac is None else ac(x)
return out
 

当BN的training的参数为train时,只是表示BN的参数是可变化的,并不是代表BN会自己更新moving_mean 和moving_var,因为这个操作是前向更新的op,在做train之前必须确保moving_mean 和moving_var更新了,更新moving_mean 和moving_var的操作在tf.GraphKeys.UPDATE_OPS

        # !! IMPORTANT !! the moving_mean and moving_variance need to be updated,
# pass the update_ops with control_dependencies to the train_op
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
self.train = tf.train.AdamOptimizer(LR).minimize(self.loss)

【转载】 深度学习总结:用pytorch做dropout和Batch Normalization时需要注意的地方,用tensorflow做dropout和BN时需要注意的地方,的更多相关文章

  1. 深度学习面试题21:批量归一化(Batch Normalization,BN)

    目录 BN的由来 BN的作用 BN的操作阶段 BN的操作流程 BN可以防止梯度消失吗 为什么归一化后还要放缩和平移 BN在GoogLeNet中的应用 参考资料 BN的由来 BN是由Google于201 ...

  2. Hinton“深度学习之父”和“神经网络先驱”,新论文Capsule将推翻自己积累了30年的学术成果时

    Hinton“深度学习之父”和“神经网络先驱”,新论文Capsule将推翻自己积累了30年的学术成果时 在论文中,Capsule被Hinton大神定义为这样一组神经元:其活动向量所表示的是特定实体类型 ...

  3. 深度学习基础系列(九)| Dropout VS Batch Normalization? 是时候放弃Dropout了

    Dropout是过去几年非常流行的正则化技术,可有效防止过拟合的发生.但从深度学习的发展趋势看,Batch Normalizaton(简称BN)正在逐步取代Dropout技术,特别是在卷积层.本文将首 ...

  4. windows10环境下安装深度学习环境anaconda+pytorch+CUDA+cuDDN

    步骤零:安装anaconda.opencv.pytorch(这些不详细说明).复制运行代码,如果没有报错,说明已经可以了.不过大概率不行,我的会报错提示AssertionError: Torch no ...

  5. 常用深度学习框架(keras,pytorch.cntk,theano)conda 安装--未整理

    版本查询 cpu tensorflow conda env list source activate tensorflow python import tensorflow as tf 和 tf.__ ...

  6. 深度学习之入门Pytorch(1)------基础

    目录: Pytorch数据类型:Tensor与Storage 创建张量 tensor与numpy数组之间的转换 索引.连接.切片等 Tensor操作[add,数学运算,转置等] GPU加速 自动求导: ...

  7. 【深度学习】基于Pytorch的ResNet实现

    目录 1. ResNet理论 2. pytorch实现 2.1 基础卷积 2.2 模块 2.3 使用ResNet模块进行迁移学习 1. ResNet理论 论文:https://arxiv.org/pd ...

  8. 动手学深度学习11- 多层感知机pytorch简洁实现

    多层感知机的简洁实现 定义模型 读取数据并训练数据 损失函数 定义优化算法 小结 多层感知机的简洁实现 import torch from torch import nn from torch.nn ...

  9. 动手学深度学习8-softmax分类pytorch简洁实现

    定义和初始化模型 softamx和交叉熵损失函数 定义优化算法 训练模型 import torch from torch import nn from torch.nn import init imp ...

随机推荐

  1. [LeetCode] 113. Path Sum II ☆☆☆(二叉树所有路径和等于给定的数)

    LeetCode 二叉树路径问题 Path SUM(①②③)总结 Path Sum II leetcode java 描述 Given a binary tree and a sum, find al ...

  2. Linux查看用户属于哪些组/查看用户组下有哪些用户

    一.关于/etc/group格式的讨论 在说/etc/group格式的时候,网上很多文章都会说是“组名:组密码:组ID:组下用户列表”,这说法对了解/etc/group格式是没问题的,但如果碰到“查看 ...

  3. Mybatis排序无效问题解决

    在 mybatis 的 xml中,为一个SQL语句配置order by 子句时,需要这个排序的字段是前端传递过来的,而且排序的顺序(升序 OR 降序)也是由前端传递过来的.对于这种需求,我起初写成了下 ...

  4. Mysql索引引起的死锁

    提到索引,首先想到的是效率提高,查询速度提升,不知不觉都会有一种心理趋向,管它三七二十一,先上个索引提高一下效率..但是索引其实也是暗藏杀机的... 今天压测带优化项目,开着Jmeter高并发访问项目 ...

  5. 变量和基本类型——复合类型,const限定符,处理类型

    一.复合类型 复合类型是指基于其他类型定义的类型.C++语言有几种复合类型,包括引用和指针. 1.引用 引用并非对象,它只是为一个已存在的对象所起的另外一个名字. 除了以下2种情况,其他所有引用的类型 ...

  6. Spring Boot + Spring Cloud 实现权限管理系统 (集成 Shiro 框架)

    Apache Shiro 优势特点 它是一个功能强大.灵活的,优秀开源的安全框架. 它可以处理身份验证.授权.企业会话管理和加密. 它易于使用和理解,相比Spring Security入门门槛低. 主 ...

  7. React中禁止chrome填充密码表单

    当 input 的 type="password" 时,chrome浏览器会以 type="password" 为标识记住输入的用户名和密码, 如果chrome ...

  8. 银联支付 Asp.Net 对接开发内容简介

    银联对接开发主要包含测试环境以及生产环境两部分. 其中程序开发部分测试以及生产是相同的. 不同的是,测试环境与生产环境请求支付的Url地址,以及分别使用的证书不同. 一.配置部分 1,测试环境证书获取 ...

  9. python3 线程池-threadpool模块与concurrent.futures模块

    多种方法实现 python 线程池 一. 既然多线程可以缩短程序运行时间,那么,是不是线程数量越多越好呢? 显然,并不是,每一个线程的从生成到消亡也是需要时间和资源的,太多的线程会占用过多的系统资源( ...

  10. HTML(五)选择器--伪类选择器

    HTML代码 <body> <a href="www.baidu.com">www.baidu.com</a> </body> CS ...