keras BatchNormalization 之坑
任务简述:最近做一个图像分类的任务, 一开始拿vgg跑一个baseline,输出看起来很正常:

随后,我尝试其他的一些经典的模型架构,比如resnet50, xception,但训练输出显示明显异常:

val_loss 一直乱蹦,val_acc基本不发生变化。
检查了输入数据没发现问题,因此怀疑是网络构造有问题, 对比了vgg同xception, resnet在使用layer上的异同,认为问题可能出在BN层上,将vgg添加了BN层之后再训练果然翻车。

翻看keras BN 的源码, 原来keras 的BN层的call函数里面有个默认参数traing, 默认是None。此参数意义如下:
training=False/0, 训练时通过每个batch的移动平均的均值、方差去做批归一化,测试时拿整个训练集的均值、方差做归一化
training=True/1/None,训练时通过当前batch的均值、方差去做批归一化,测试时拿整个训练集的均值、方差做归一化
当training=None时,训练和测试的批归一化方式不一致,导致validation的输出指标翻车。
当training=True时,拿训练完的模型预测一个样本和预测一个batch的样本的差异非常大,也就是预测的结果根据batch的大小会不同!导致模型结果无法准确评估!也是个坑!
用keras的BN时切记要设置training=False!!!
def build_model():
Inputs = Input(shape=intput_shape, name='input')
x_tmp = Lambda(lambda c: tf.image.rgb_to_grayscale(c))(Inputs)
x_tmp = Conv2D(64, (3, 3), activation='relu')(x_tmp)
x_tmp = Conv2D(64, (3, 3), activation='relu')(x_tmp)
x_tmp = BatchNormalization(x_tmp, training=False)
x_tmp = MaxPooling2D(pool_size=(2, 2))(x_tmp) x_tmp = Flatten()(x_tmp)
x_tmp = Dense(128, activation='relu')(x_tmp)
outputs = Dense(10, activation='softmax')(x_tmp)
model = Model(Inputs, outputs)
return model
参考:
https://arxiv.org/pdf/1502.03167v3.pdf
https://github.com/keras-team/keras/blob/master/keras/layers/normalization.py#L16
keras BatchNormalization 之坑的更多相关文章
- win10+anaconda安装tensorflow和keras遇到的坑小结
win10下利用anaconda安装tensorflow和keras的教程都大同小异(针对CPU版本,我的gpu是1050TI的MAX-Q,不知为啥一直没安装成功),下面简单说下步骤. 一 Anaco ...
- tensorflow 2.0 技巧 | 自定义tf.keras.Model的坑
自定义tf.keras.Model需要注意的点 model.save() subclass Model 是不能直接save的,save成.h5,但是能够save_weights,或者save_form ...
- tf.keras遇见的坑:Output tensors to a Model must be the output of a TensorFlow `Layer`
经过网上查找,找到了问题所在:在使用keras编程模式是,中间插入了tf.reshape()方法便遇到此问题. 解决办法:对于遇到相同问题的任何人,可以使用keras的Lambda层来包装张量流操作, ...
- keras用法
关于Keras的“层”(Layer) 所有的Keras层对象都有如下方法: layer.get_weights():返回层的权重(numpy array) layer.set_weights(weig ...
- 『计算机视觉』Mask-RCNN_推断网络其二:基于ReNet101的FPN共享网络暨TensorFlow和Keras交互简介
零.参考资料 有关FPN的介绍见『计算机视觉』FPN特征金字塔网络. 网络构架部分代码见Mask_RCNN/mrcnn/model.py中class MaskRCNN的build方法的"in ...
- [Tensorflow] 使用 Mask_RCNN 完成目标检测与实例分割,同时输出每个区域的 Feature Map
Mask_RCNN-2.0 网页链接:https://github.com/matterport/Mask_RCNN/releases/tag/v2.0 Mask_RCNN-master(matter ...
- Windows 下安装 tensorflow & keras & opencv 的避坑指南!
安装 Anaconda3 关键的一步: conda update pip 下面再去安装各种你需要的包,一般不会再报错. pip install -U tensorflow pip install -U ...
- Keras实现Hierarchical Attention Network时的一些坑
Reshape 对于的张量x,x.shape=(a, b, c, d)的情况 若调用keras.layer.Reshape(target_shape=(-1, c, d)), 处理后的张量形状为(?, ...
- Keras + Flask 提供接口服务的坑~~~
最近在搞Keras,训练完的模型要提供个预测服务出来.就想了个办法,通过Flask提供一个http服务,后来发现也能正常跑,但是每次预测都需要加载模型,效率非常低. 然后就把模型加载到全局,每次要用的 ...
随机推荐
- ctfhub技能树—文件上传—文件头检查
打开靶机 尝试上传一个php文件 抓包修改 放包 制作图片马 上传图片马,并修改文件类型为png 测试连接 查找flag 成功拿到flag
- python中hmac模块的使用
hmac(hex-based message authentication code)算法在计算哈希的过程中混入了key(实际上就是加盐),和hashlib模块中的普通加密算法相比,它能够防止密码被撞 ...
- windows_myql 安装与卸载详细讲解,
windows_myql 安装 注意: 安装前把 所有杀毒软件,安全卫士等关闭. 打开下载的mysql安装文件双击解压缩,运行"mysql-5.5.40-win64.msi". 注 ...
- InnoDB的主键选择与插入优化
索引的存放方式MyISAM和InnoDB存储引擎在MySQL中,不同存储引擎对索引的实现方式是不同的,总结下MyISAM和InnoDB两个存储引擎的索引实现方式.MyISAM引擎使用B+Tree作为索 ...
- html简单基础
标签语法 标签的语法: <标签名 属性1="属性值1" 属性2="属性值2"-->内容部分</标签名> <标签名 属性1=&quo ...
- centos下解压rar文件,Linux解压tar.gz和tar.bz2的命令
1.下载:根据主机系统下载合适的版本,当前64为centos系统演示下载: wget http://www.rarlab.com/rar/rarlinux-x64-5.3.0.tar.gz 2.解压安 ...
- 成为一名优秀的Java程序员9+难以置信的公式
成为一名优秀的Java程序员 成为一名优秀的Java程序员并不重要,但是首先您应该了解基本的编程语言. 好吧,你知道那太好了.我们应该一步一步地精通Java编程,并应遵循所有说明,改进Java的编程逻 ...
- 慕课网金职位 Java工程师2020 百度网盘下载
百度网盘链接:https://pan.baidu.com/s/1xshLRO3ru0LAsQQ0pE67Qg 提取码:bh9f 如果失效加我微信:610060008[视频不加密,资料代码齐全,超清一手 ...
- 《》——8幅图图解Java机制
String对象不可改变的特性 String s = "abcd"; s = s.concat"ef"; equals()与hashCode()方法协作约定 H ...
- linux下安装 zookeeper-3.4.9并搭建集群环境
本文主要记录作者在实践过程中实现在centos7环境下安装zookeeper并搭建集群的详细步骤,关于zookeeper本文将不做详细介绍,安装步骤详情如下: 前提准备:3台linux服务器(因为zo ...