最近在学习SSD的源码,其中有两个自定的层,特此学习一下并记录。

 import keras.backend as K
from keras.engine.topology import InputSpec
from keras.engine.topology import Layer
import numpy as np class L2Normalization(Layer):
'''
Performs L2 normalization on the input tensor with a learnable scaling parameter
as described in the paper "Parsenet: Looking Wider to See Better" (see references)
and as used in the original SSD model. Arguments:
gamma_init (int): The initial scaling parameter. Defaults to 20 following the
SSD paper. Input shape:
4D tensor of shape `(batch, channels, height, width)` if `dim_ordering = 'th'`
or `(batch, height, width, channels)` if `dim_ordering = 'tf'`. Returns:
The scaled tensor. Same shape as the input tensor.
''' def __init__(self, gamma_init=20, **kwargs):
if K.image_dim_ordering() == 'tf':
self.axis = 3
else:
self.axis = 1
self.gamma_init = gamma_init
super(L2Normalization, self).__init__(**kwargs) def build(self, input_shape):
self.input_spec = [InputSpec(shape=input_shape)]
gamma = self.gamma_init * np.ones((input_shape[self.axis],))
self.gamma = K.variable(gamma, name='{}_gamma'.format(self.name))
self.trainable_weights = [self.gamma]
super(L2Normalization, self).build(input_shape) def call(self, x, mask=None):
output = K.l2_normalize(x, self.axis)
output *= self.gamma
return output

首先说一下这个层是用来做什么的。就是对于每一个通道进行归一化,不过通道使用的是不同的归一化参数,也就是说这个参数是需要进行学习的,因此需要通过 自定义层来完成。

在keras中,每个层都是对象,真的,可以通过dir(Layer对象)来查看具有哪些属性。

具体说来:

__init__():用来进行初始化的(这不是废话么),gamma就是要学习的参数。

bulid():是用来创建这层的权重向量的,也就是要学习的参数“壳”。

33:设置该层的input_spec,这个是通过InputSpec函数来实现。

34:分配权重“壳”的实际空间大小

35,:由于底层使用的Tensorflow来进行实现的,因此这里使用Tensorflow中的variable来保存变量。

36:根据keras官网的要求,可训练的权重是要添加至trainable_weights列表中的

37:我不想说了,官网给的实例都是这么做的。

call():用来进行具体实现操作的。

40:沿着指定的轴对输入数据进行L2正则化

41:使用学习的gamma来对正则化后的数据进行加权

42:将最后的数据最为该层的返回值,这里由于是和输入形式相同的,因此就没有了compute_output_shape函数,如果输入和输出的形式不同,就需要进行输入的调整。

就这样子吧。

keras中自定义Layer的更多相关文章

  1. keras中保存自定义层和loss

    在keras中保存模型有几种方式: (1):使用callbacks,可以保存训练中任意的模型,或选择最好的模型 logdir = './callbacks' if not os.path.exists ...

  2. keras中的loss、optimizer、metrics

    用keras搭好模型架构之后的下一步,就是执行编译操作.在编译时,经常需要指定三个参数 loss optimizer metrics 这三个参数有两类选择: 使用字符串 使用标识符,如keras.lo ...

  3. keras中的mask操作

    使用背景 最常见的一种情况, 在NLP问题的句子补全方法中, 按照一定的长度, 对句子进行填补和截取操作. 一般使用keras.preprocessing.sequence包中的pad_sequenc ...

  4. iOS开发UI篇—CAlayer(自定义layer)

    iOS开发UI篇—CAlayer(自定义layer) 一.第一种方式 1.简单说明 以前想要在view中画东西,需要自定义view,创建一个类与之关联,让这个类继承自UIView,然后重写它的Draw ...

  5. iOS 自定义layer的两种方式

    在iOS中,你能看得见摸得着的东西基本都是UIView,比如一个按钮,一个标签,一个文本输入框,这些都是UIView: 其实UIView之所以能显示在屏幕上,完全是因为它内部的一个图层 在创建UIVi ...

  6. Keras中RNN不定长输入的处理--padding and masking

    在使用RNN based model处理序列的应用中,如果使用并行运算batch sample,我们几乎一定会遇到变长序列的问题. 通常解决变长的方法主要是将过长的序列截断,将过短序列用0补齐到一个固 ...

  7. Keras网络层之“关于Keras的层(Layer)”

    关于Keras的“层”(Layer) 所有的Keras层对象都有如下方法: layer.get_weights():返回层的权重(numpy array) layer.set_weights(weig ...

  8. IOS 自定义Layer(图层)

    方式1: @interface NJViewController () @end @implementation NJViewController - (void)viewDidLoad { [sup ...

  9. iOS开发UI篇—自定义layer

    一.第一种方式 1.简单说明 以前想要在view中画东西,需要自定义view,创建一个类与之关联,让这个类继承自UIView,然后重写它的DrawRect:方法,然后在该方法中画图. 绘制图形的步骤: ...

随机推荐

  1. MAC mysql install

    # linux 查看是否安装mysql   rpm -qa |grep mysql    yum 安装mysql      yum -y install mysql-server 1 download ...

  2. HIVE大数据出现倾斜怎么办

      hive在跑数据时经常会出现数据倾斜的情况,使的作业经常reduce完成在99%后一直卡住,最后的1%花了几个小时都没跑完,通过YARN的管理界面配合日志,可以清楚其中的具体原因,这种情况就很可能 ...

  3. 菩提树下的杨过.Net 的《hadoop 2.6全分布安装》补充版

    对菩提树下的杨过.Net的这篇博客<hadoop 2.6全分布安装>,我真是佩服的五体投地,我第一次见过教程能写的这么言简意赅,但是又能比较准确表述每一步做法的,这篇博客主要就是在他的基础 ...

  4. Python单元测试框架——unittest

    测试的常用规则 一个测试单元必须关注一个很小的功能函数,证明它是正确的: 每个测试单元必须是完全独立的,必须能单独运行.这样意味着每一个测试方法必须重新加载数据,执行完毕后做一些清理工作.通常通过se ...

  5. sql中1=1和1=0的用处

    where 1=1 where 1=1有什么用?在SQL语言中,写这么一句话就跟没写一样. select * from table1 where 1=1与select * from table1完全没 ...

  6. VMWare 安装时报错 tools-windows.msi failed报错解决办法

    1.我用的是7.1.3版本的,到官方网站上下载这个版本的tools安装包 http://softwareupdate.vmware.com/cds/vmw-desktop/ws/7.1.3/32428 ...

  7. Nginx rewrite配置

    rewrite应用 Rewrite模块设置及Wordpress和Discuz的示例.Nginx的Rewrite规则比Apache的简单灵活多了,从下面介绍可见一斑. rewrite配置 Nginx可以 ...

  8. MySQL性能优化之max_connections参数

    很多开发人员都会遇见”MySQL: ERROR 1040: Too many connections”的异常情况,造成这种情况的一种原因是访问量过高,MySQL服务器抗不住,这个时候就要考虑增加从服务 ...

  9. 利用Metasploit进行Linux提权

    利用Metasploit进行Linux提权 Metasploit 拥有msfpayload 和msfencode 这两个工具,这两个工具不但可以生成exe 型后门,一可以生成网页脚本类型的webshe ...

  10. 采用DoGet方式提交中文,乱码产生原因分析及解决办法

    前段时间某功能在测试机器上出现乱码,情况如下:   现象:           调试搜索功能时,通过doGet方法提交到后台的中文参数在本地和开发测试机器上为乱码(Action层),在测试人员测试机器 ...