keras BatchNormalization 之坑
任务简述:最近做一个图像分类的任务, 一开始拿vgg跑一个baseline,输出看起来很正常:

随后,我尝试其他的一些经典的模型架构,比如resnet50, xception,但训练输出显示明显异常:

val_loss 一直乱蹦,val_acc基本不发生变化。
检查了输入数据没发现问题,因此怀疑是网络构造有问题, 对比了vgg同xception, resnet在使用layer上的异同,认为问题可能出在BN层上,将vgg添加了BN层之后再训练果然翻车。

翻看keras BN 的源码, 原来keras 的BN层的call函数里面有个默认参数traing, 默认是None。此参数意义如下:
training=False/0, 训练时通过每个batch的移动平均的均值、方差去做批归一化,测试时拿整个训练集的均值、方差做归一化
training=True/1/None,训练时通过当前batch的均值、方差去做批归一化,测试时拿整个训练集的均值、方差做归一化
当training=None时,训练和测试的批归一化方式不一致,导致validation的输出指标翻车。
当training=True时,拿训练完的模型预测一个样本和预测一个batch的样本的差异非常大,也就是预测的结果根据batch的大小会不同!导致模型结果无法准确评估!也是个坑!
用keras的BN时切记要设置training=False!!!
def build_model():
Inputs = Input(shape=intput_shape, name='input')
x_tmp = Lambda(lambda c: tf.image.rgb_to_grayscale(c))(Inputs)
x_tmp = Conv2D(64, (3, 3), activation='relu')(x_tmp)
x_tmp = Conv2D(64, (3, 3), activation='relu')(x_tmp)
x_tmp = BatchNormalization(x_tmp, training=False)
x_tmp = MaxPooling2D(pool_size=(2, 2))(x_tmp) x_tmp = Flatten()(x_tmp)
x_tmp = Dense(128, activation='relu')(x_tmp)
outputs = Dense(10, activation='softmax')(x_tmp)
model = Model(Inputs, outputs)
return model
参考:
https://arxiv.org/pdf/1502.03167v3.pdf
https://github.com/keras-team/keras/blob/master/keras/layers/normalization.py#L16
keras BatchNormalization 之坑的更多相关文章
- win10+anaconda安装tensorflow和keras遇到的坑小结
win10下利用anaconda安装tensorflow和keras的教程都大同小异(针对CPU版本,我的gpu是1050TI的MAX-Q,不知为啥一直没安装成功),下面简单说下步骤. 一 Anaco ...
- tensorflow 2.0 技巧 | 自定义tf.keras.Model的坑
自定义tf.keras.Model需要注意的点 model.save() subclass Model 是不能直接save的,save成.h5,但是能够save_weights,或者save_form ...
- tf.keras遇见的坑:Output tensors to a Model must be the output of a TensorFlow `Layer`
经过网上查找,找到了问题所在:在使用keras编程模式是,中间插入了tf.reshape()方法便遇到此问题. 解决办法:对于遇到相同问题的任何人,可以使用keras的Lambda层来包装张量流操作, ...
- keras用法
关于Keras的“层”(Layer) 所有的Keras层对象都有如下方法: layer.get_weights():返回层的权重(numpy array) layer.set_weights(weig ...
- 『计算机视觉』Mask-RCNN_推断网络其二:基于ReNet101的FPN共享网络暨TensorFlow和Keras交互简介
零.参考资料 有关FPN的介绍见『计算机视觉』FPN特征金字塔网络. 网络构架部分代码见Mask_RCNN/mrcnn/model.py中class MaskRCNN的build方法的"in ...
- [Tensorflow] 使用 Mask_RCNN 完成目标检测与实例分割,同时输出每个区域的 Feature Map
Mask_RCNN-2.0 网页链接:https://github.com/matterport/Mask_RCNN/releases/tag/v2.0 Mask_RCNN-master(matter ...
- Windows 下安装 tensorflow & keras & opencv 的避坑指南!
安装 Anaconda3 关键的一步: conda update pip 下面再去安装各种你需要的包,一般不会再报错. pip install -U tensorflow pip install -U ...
- Keras实现Hierarchical Attention Network时的一些坑
Reshape 对于的张量x,x.shape=(a, b, c, d)的情况 若调用keras.layer.Reshape(target_shape=(-1, c, d)), 处理后的张量形状为(?, ...
- Keras + Flask 提供接口服务的坑~~~
最近在搞Keras,训练完的模型要提供个预测服务出来.就想了个办法,通过Flask提供一个http服务,后来发现也能正常跑,但是每次预测都需要加载模型,效率非常低. 然后就把模型加载到全局,每次要用的 ...
随机推荐
- 【Oracle】to_data() to_char()用法解析
1.转换函数 与date操作关系最大的就是两个转换函数:to_date(),to_char() to_date() 作用将字符类型按一定格式转化为日期类型: 具体用法:to_dat ...
- C语言中左值和右值的区别(C语言学习笔记)
重要的内容要重复强调: C语言的术语Ivalue指用于识别或定位一个存储位置的标识符.( 注意:左值同时还必须是可改变的) 其实rvalue的发明完全是为了搭配lvalue , rvalue你可以理解 ...
- 一. SpringCloud简介与微服务架构
1. 微服务架构 1.1 微服务架构理解 微服务架构(Microservice Architecture)是一种架构概念,旨在通过将功能分解到各个离散的服务中以实现对解决方案的解耦.你可以将其看作是在 ...
- SAP系统跨平台字符编码转换
SAP系统在进行了夸平台的迁移,可能会遇到操作系统层文件编码不同,导致SAP系统无法识别或者乱码的问题.例如SAP系统从AIX平台迁移到linux平台,SAP应用服务器的编码会发生变化,从4102变化 ...
- InnoDB事务篇
1.解决数据更新丢失的问题 1)LBCC:基于锁的并发控制.让操作串行化执行.效率低. 2)MVCC:基于版本的并发控制.使用快照形式.效率高.读写不冲突.主流数据库都是使用的MVCC. 2.Inno ...
- ES 2021 来了,详细解读5个新特性,附案例
ES 2021是世界上最受欢迎的编程语言的最新版本〜 本次迭代中包含了五个新特性,让我们来一睹为快. 1.全部替换replaceAll: js默认的replace 方法仅替换字符串中一个模式的第一个实 ...
- 研发过程及工具支撑 DevOps 工具链集成
https://mp.weixin.qq.com/s/NYm63nkCymIV3DbL4O01dg 腾讯重新定义敏捷 |Q推荐 小智 InfoQ 2020-09-03 敏捷开发奠基人 Robert C ...
- 获取控制台的错误信息 onerror
js 获取控制台的错误信息 https://www.bbsmax.com/A/Vx5ML2NmJN/ <!DOCTYPE html> <html lang="en" ...
- goroutine 分析 协程的调度和执行顺序 并发写 run in the same address space 内存地址 闭包 存在两种并发 确定性 非确定性的 Go 的协程和通道理所当然的支持确定性的并发方式(
package main import ( "fmt" "runtime" "sync" ) const N = 26 func main( ...
- Understanding go.sum and go.mod file in Go
https://golangbyexample.com/go-mod-sum-module/ Understanding go.sum and go.mod file in Go (Golang) – ...