keras允许自定义Layer层, 大大方便了一些复杂操作的实现. 也方便了一些novel结构的复用, 提高搭建模型的效率.

实现方法

通过继承keras.engine.Layer类, 重写其中的部分方法, 实现层的自定义. 主要需要实现的方法及其意义有:

  • _ init _(self, **kwargs)

    作为类的初始化方法, 一般将需要传入的自定义参数存为对象的属性. 需要注意的有以下几点:

    • 由于继承Layer类, 所以在处理完自定义的参数之后, 仍可能还有参数需要父类处理, 所以需要调用父类的初始化方法, 将kwargs参数传入:

      class DecayingDropout(Layer):
      def __init__(self, initial_keep_rate=1., decay_interval=10000, decay_rate=0.977,
      noise_shape=None, seed=None, **kwargs):
      super(DecayingDropout, self).__init__(**kwargs)
    • 对象的self.supports_masking方法的作用是本层中是否涉及到使用mask或对mask矩阵进行计算. mask的作用是屏蔽传入Tensor的部分值, 常常在NLP问题中, 对句子padding之后, 不想让填补的0值对应的位置参与运算而使用. 这个参数默认为False, 如果有使用到, 需要将其值为True:

      self.supports_masking = True
  • build(self, input_shape, **kwargs)

    这里是定义权重的地方, 需要注意的有以下几点:

    • 通过self.add_weight方法定义权重, 且需要将权重存为类的属性, 例如:

      self.iterations = self.add_weight(name='iterations', shape=(1,), dtype=K.floatx(),
      initializer='zeros', trainable=False)

      其中self.iterations需要在初始化时设置为None, 符合类编程的习惯. self.add_weight方法有若干参数, 常用的即为上面几个.

    • 由于要求build方法必须设置self.built = True , 而这个方法在父类中实现, 因此, 在方法的最后需要调用:

      super(DecayingDropout, self).build(input_shape)
  • call(self, inputs, **kwargs)

    这里是编写层的功能逻辑的地方, 传入的第一个参数即输入张量, 即调用_ call _方法传入的张量. 除此之外, 需要注意的点有:

    • 如果需要在计算的过程中使用mask, 则需要传入mask参数:

      def call(self, x, mask=None):
      if mask is not None:
      mask = K.repeat(mask, x.shape[-1])
      mask = tf.transpose(mask, [0,2,1])
      mask = K.cast(mask, K.floatx())
      x = x * mask
      return K.sum(x, axis=self.axis) / K.sum(mask, axis=self.axis)
      else:
      return K.mean(x, axis=self.axis)
    • 如果该层在训练和预测时的行为不一样(如Dropout)函数, 需要传入指定参数training, 即使用布尔值指定调用的环境. 例如在Dropout层的源码中, call方法是这样实现的:

      def call(self, inputs, training=None):
      if 0. < self.rate < 1.:
      noise_shape = self._get_noise_shape(inputs) def dropped_inputs():
      return K.dropout(inputs, self.rate, noise_shape,
      seed=self.seed)
      return K.in_train_phase(dropped_inputs, inputs,
      training=training)

      K.in_train_phase() 方法就是用来区别在不同环境调用时, 返回的不同值的. 这个函数通过training参数区别调用环境, 如果是训练环境, 则返回第一个参数对应的结果, 预测环境则返回第二个参数对应的结果. 可以传入函数, 返回这个函数对应的返回结果.

    • 除了计算之外, 这个函数也是更新层内参数的地方, 即build方法中增加的参数. 通过self.add_update方法进行更新, 例如:

      def call(sekf, x):
      self.add_update([K.moving_average_update(self.moving_mean, mean,self.momentum),
      K.moving_average_update(self.moving_variance,variance,self.momentum)],
      inputs)

      或者:

      def call(self, inputs, training=None):
      self.add_update([K.update_add(self.iterations, [1])], inputs)

      可以看到, self.add_update方法传入一个列表, 包含一些列更新的动作. 这些更新的动作需要借助K的一些函数实现, 如K.moving_average_update, K.update_add等等.

      另外还可以传入inputs函数, 作为更新的前提条件.


除此之外, 还有一些常常需要重新定义的方法:

  • get_config(self):

    返回层的一些参数. 对于自定义的参数, 需要在此指定返回:

    def get_config(self):
    config = {'initial_keep_rate': self.initial_keep_rate,
    'decay_interval': self.decay_interval,
    'decay_rate': self.decay_rate,
    'noise_shape': self.noise_shape,
    'seed': self.seed}
    base_config = super(DecayingDropout, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
  • compute_output_shape(input_shape) :

    计算输出shape. input_shape是输入数据的shape.

  • compute_mask(self, input, input_mask=None):

    计算输出的mask, 其中input_mask为输入的mask. 需要注意的有:

    • 如果input_mask为None, 说明上一层没有mask. 可以在本层创建一个新的mask矩阵.

    • 如果以后的层不需要使用mask, 返回None即可, 之后就不存在mask矩阵了

      def compute_mask(self, input, input_mask=None):
      # need not to pass the mask to next layers
      return None
    • 如果经过本层, mask矩阵没有变化, 不用实现该函数, 只需要在初始化时, 指定self.supports_masking = True即可.

参考资料

编写你自己的Keras层

Keras编写自定义层--以GroupNormalization为例

Keras自定义实现带masking的meanpooling层

Keras实现支持masking的Flatten层

Layer层自定义的更多相关文章

  1. 如果layer层在iframe下不居中滚动

    需要在layer前面加上parent.layer. 2.运用layer层的步骤: 1.引入1.8版本以上的jquery文件 <script type="text/javascript& ...

  2. 当music-list向上滑动的时候,设置layer层,随其滚动,覆盖图片,往下滚动时候,图片随着展现出来

    1.layer层代码: <div class="bg-layer" ref="layer"></div> 2.在mounted()的时候 ...

  3. 3.4 常用的两种 layer 层 3.7 字体与文本

    3.4 常用的两种 layer 层  //在cocos2d-x中,经常使用到的两种 layer 层 : CCLayer 和 CCLayerColor //CCLayer 的创建 CCLayer* la ...

  4. caffe layer层cpp、cu调试经验和相互关系

    对于layer层的cpp文件,你可以用LOG和printf.cout进行调试,cu文件不能使用LOG,可以使用cout,printf. 对于softmaxloss的layer层,既有cpp文件又有cu ...

  5. ZBrush中Layer层笔刷介绍

    本文我们来介绍ZBrush®中的Layer层笔刷,该笔刷是一种类似梯田效果的笔刷,常用来制作鳞甲和花纹图腾.他还可以用一个固定的数值抬高或降低模型的表面,当笔刷在重合时,笔画重叠部分不会再次位移,这使 ...

  6. layer层、modal模拟窗 单独测试页面

    layer_test.jsp <%@ page language="java" import="java.util.*" pageEncoding=&qu ...

  7. layer.alert自定义关闭回调事件

    在项目应用中,遇到自定义关闭layer.alert弹出层,即在关闭layer.alert时,可以自动触发关闭时的事件, 具体方法为: layer.alert('爱心提示!', function(){ ...

  8. [Cocos2d-x For WP8]Layer 层

        层(CCLayer) 从概念上说,层就是场景里的背景. CCLayer同样是CCNode的子类,通常用addChild方法添加子节点.CCLayer对象定义了可描绘的区域,定义了描绘的规则.C ...

  9. 非常好的分页组建layPage和 layer层特效

    http://layer.layui.com/ http://sentsin.com/layui/laypage/

随机推荐

  1. paramiko 模块 ---- python2.7

    模拟远程执行命令: ? 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 import paramiko   #设置日志记录 paramiko ...

  2. 按钮改变和控制div的形状的html,JavaScript代码

    <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...

  3. 第一篇 Scrum冲刺博客

    一.Alpha任务认领 冯荣新 任务 预计时间 搜索框 0.5h 首页轮播图 0.5h 分类导航 2h 商品列表 2h 商品详情轮播图 0.5h 商品底部工具栏 1h 购物车列表 1.5h 购物车工具 ...

  4. Linux教学资源服务器构建

    1. 需求分析 1.1 课题简介 随着计算机互联网的迅速发展,大多数学校已经实现教学的信息化,从传统的黑板教学方式转变为现阶段的多媒体教学,教学的资源,素材课件,甚至学生的作业也都实现数字化,为了实现 ...

  5. SpringBoot启动注解源码流程学习总结

  6. [CSP-S2019]括号树 题解

    CSP-S2 2019 D1T2 刚开考的时候先大概浏览了一遍题目,闻到一股浓浓的stack气息 调了差不多1h才调完,加上T1用了1.5h+ 然而T3还是没写出来,滚粗 思路分析 很容易想到的常规操 ...

  7. Java 的开发效率究竟比 C++ 高在哪里?

    有几个原因     大师助手解决你的烦恼1. 语言上,Java是一个比C++更容易parse得多的语言,所以相应的工具链IDE会更容易做,无论多大的Java的项目,就是新手写完都不会有编译错误.但是写 ...

  8. Ellxir

    API: elixir https://hexdocs.pm/elixir/Module.html#content API: erlang http://www.cnerlang.com/api.ht ...

  9. TinkPHP5.1开发注意事项

    1.新下载的框架文件,开发前先开启调试配置 config目录下app.php文件 // 应用调试模式 'app_debug'              => true, 2.每新建一个方法,都要 ...

  10. 剑指 Offer 56 - I. 数组中数字出现的次数

    题目描述 一个整型数组 nums 里除两个数字之外,其他数字都出现了两次.请写程序找出这两个只出现一次的数字.要求时间复杂度是\(O(n)\),空间复杂度是\(O(1)\). 示例1: 输入:nums ...