keras允许自定义Layer层, 大大方便了一些复杂操作的实现. 也方便了一些novel结构的复用, 提高搭建模型的效率.

实现方法

通过继承keras.engine.Layer类, 重写其中的部分方法, 实现层的自定义. 主要需要实现的方法及其意义有:

  • _ init _(self, **kwargs)

    作为类的初始化方法, 一般将需要传入的自定义参数存为对象的属性. 需要注意的有以下几点:

    • 由于继承Layer类, 所以在处理完自定义的参数之后, 仍可能还有参数需要父类处理, 所以需要调用父类的初始化方法, 将kwargs参数传入:

      class DecayingDropout(Layer):
      def __init__(self, initial_keep_rate=1., decay_interval=10000, decay_rate=0.977,
      noise_shape=None, seed=None, **kwargs):
      super(DecayingDropout, self).__init__(**kwargs)
    • 对象的self.supports_masking方法的作用是本层中是否涉及到使用mask或对mask矩阵进行计算. mask的作用是屏蔽传入Tensor的部分值, 常常在NLP问题中, 对句子padding之后, 不想让填补的0值对应的位置参与运算而使用. 这个参数默认为False, 如果有使用到, 需要将其值为True:

      self.supports_masking = True
  • build(self, input_shape, **kwargs)

    这里是定义权重的地方, 需要注意的有以下几点:

    • 通过self.add_weight方法定义权重, 且需要将权重存为类的属性, 例如:

      self.iterations = self.add_weight(name='iterations', shape=(1,), dtype=K.floatx(),
      initializer='zeros', trainable=False)

      其中self.iterations需要在初始化时设置为None, 符合类编程的习惯. self.add_weight方法有若干参数, 常用的即为上面几个.

    • 由于要求build方法必须设置self.built = True , 而这个方法在父类中实现, 因此, 在方法的最后需要调用:

      super(DecayingDropout, self).build(input_shape)
  • call(self, inputs, **kwargs)

    这里是编写层的功能逻辑的地方, 传入的第一个参数即输入张量, 即调用_ call _方法传入的张量. 除此之外, 需要注意的点有:

    • 如果需要在计算的过程中使用mask, 则需要传入mask参数:

      def call(self, x, mask=None):
      if mask is not None:
      mask = K.repeat(mask, x.shape[-1])
      mask = tf.transpose(mask, [0,2,1])
      mask = K.cast(mask, K.floatx())
      x = x * mask
      return K.sum(x, axis=self.axis) / K.sum(mask, axis=self.axis)
      else:
      return K.mean(x, axis=self.axis)
    • 如果该层在训练和预测时的行为不一样(如Dropout)函数, 需要传入指定参数training, 即使用布尔值指定调用的环境. 例如在Dropout层的源码中, call方法是这样实现的:

      def call(self, inputs, training=None):
      if 0. < self.rate < 1.:
      noise_shape = self._get_noise_shape(inputs) def dropped_inputs():
      return K.dropout(inputs, self.rate, noise_shape,
      seed=self.seed)
      return K.in_train_phase(dropped_inputs, inputs,
      training=training)

      K.in_train_phase() 方法就是用来区别在不同环境调用时, 返回的不同值的. 这个函数通过training参数区别调用环境, 如果是训练环境, 则返回第一个参数对应的结果, 预测环境则返回第二个参数对应的结果. 可以传入函数, 返回这个函数对应的返回结果.

    • 除了计算之外, 这个函数也是更新层内参数的地方, 即build方法中增加的参数. 通过self.add_update方法进行更新, 例如:

      def call(sekf, x):
      self.add_update([K.moving_average_update(self.moving_mean, mean,self.momentum),
      K.moving_average_update(self.moving_variance,variance,self.momentum)],
      inputs)

      或者:

      def call(self, inputs, training=None):
      self.add_update([K.update_add(self.iterations, [1])], inputs)

      可以看到, self.add_update方法传入一个列表, 包含一些列更新的动作. 这些更新的动作需要借助K的一些函数实现, 如K.moving_average_update, K.update_add等等.

      另外还可以传入inputs函数, 作为更新的前提条件.


除此之外, 还有一些常常需要重新定义的方法:

  • get_config(self):

    返回层的一些参数. 对于自定义的参数, 需要在此指定返回:

    def get_config(self):
    config = {'initial_keep_rate': self.initial_keep_rate,
    'decay_interval': self.decay_interval,
    'decay_rate': self.decay_rate,
    'noise_shape': self.noise_shape,
    'seed': self.seed}
    base_config = super(DecayingDropout, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
  • compute_output_shape(input_shape) :

    计算输出shape. input_shape是输入数据的shape.

  • compute_mask(self, input, input_mask=None):

    计算输出的mask, 其中input_mask为输入的mask. 需要注意的有:

    • 如果input_mask为None, 说明上一层没有mask. 可以在本层创建一个新的mask矩阵.

    • 如果以后的层不需要使用mask, 返回None即可, 之后就不存在mask矩阵了

      def compute_mask(self, input, input_mask=None):
      # need not to pass the mask to next layers
      return None
    • 如果经过本层, mask矩阵没有变化, 不用实现该函数, 只需要在初始化时, 指定self.supports_masking = True即可.

参考资料

编写你自己的Keras层

Keras编写自定义层--以GroupNormalization为例

Keras自定义实现带masking的meanpooling层

Keras实现支持masking的Flatten层

Layer层自定义的更多相关文章

  1. 如果layer层在iframe下不居中滚动

    需要在layer前面加上parent.layer. 2.运用layer层的步骤: 1.引入1.8版本以上的jquery文件 <script type="text/javascript& ...

  2. 当music-list向上滑动的时候,设置layer层,随其滚动,覆盖图片,往下滚动时候,图片随着展现出来

    1.layer层代码: <div class="bg-layer" ref="layer"></div> 2.在mounted()的时候 ...

  3. 3.4 常用的两种 layer 层 3.7 字体与文本

    3.4 常用的两种 layer 层  //在cocos2d-x中,经常使用到的两种 layer 层 : CCLayer 和 CCLayerColor //CCLayer 的创建 CCLayer* la ...

  4. caffe layer层cpp、cu调试经验和相互关系

    对于layer层的cpp文件,你可以用LOG和printf.cout进行调试,cu文件不能使用LOG,可以使用cout,printf. 对于softmaxloss的layer层,既有cpp文件又有cu ...

  5. ZBrush中Layer层笔刷介绍

    本文我们来介绍ZBrush®中的Layer层笔刷,该笔刷是一种类似梯田效果的笔刷,常用来制作鳞甲和花纹图腾.他还可以用一个固定的数值抬高或降低模型的表面,当笔刷在重合时,笔画重叠部分不会再次位移,这使 ...

  6. layer层、modal模拟窗 单独测试页面

    layer_test.jsp <%@ page language="java" import="java.util.*" pageEncoding=&qu ...

  7. layer.alert自定义关闭回调事件

    在项目应用中,遇到自定义关闭layer.alert弹出层,即在关闭layer.alert时,可以自动触发关闭时的事件, 具体方法为: layer.alert('爱心提示!', function(){ ...

  8. [Cocos2d-x For WP8]Layer 层

        层(CCLayer) 从概念上说,层就是场景里的背景. CCLayer同样是CCNode的子类,通常用addChild方法添加子节点.CCLayer对象定义了可描绘的区域,定义了描绘的规则.C ...

  9. 非常好的分页组建layPage和 layer层特效

    http://layer.layui.com/ http://sentsin.com/layui/laypage/

随机推荐

  1. PowerJob 在线日志饱受好评的秘诀:小但实用的分布式日志系统

    本文适合有 Java 基础知识的人群 作者:HelloGitHub-Salieri HelloGitHub 推出的<讲解开源项目>系列. 项目地址: https://github.com/ ...

  2. e3mall商城的归纳总结3之后台商品节点、认识nginx

    一  后台商品节点 大家都知道后台创建商品的时候需要选择商品的分类,而这个商品的分类就就像一棵树一样,一层包含一层又包含一层.因此这里用的框架是easyUiTree.该分类前端使用的是异步加载模式(指 ...

  3. 你可能不了解的java枚举

    枚举在java里也算个老生长谈的内容了,每当遇到一组需要类举的数据时我们都会自然而然地使用枚举类型: public enum Color { RED, GREEN, BLUE, YELLOW; pub ...

  4. Nodejs模块:fs

    /** * @description fs模块常用api */ // fs所有的文件操作都是异步IO,如果要以同步的方式去调用,都会加一个在原同步api的基础上加Sync // 同步的方式会在最后传入 ...

  5. Content Security Policy (CSP)内容安全策略总结

    跨域脚本攻击 XSS 是最常见.危害最大的网页安全漏洞. 为了防止它们,要采取很多编程措施,非常麻烦.很多人提出,能不能根本上解决问题,浏览器自动禁止外部注入恶意脚本?这就是"网页安全政策& ...

  6. 从String中移除空白字符的多种方式!?

    字符串,是Java中最常用的一个数据类型了.我们在日常开发时候会经常使用字符串做很多的操作.比如字符串的拼接.截断.替换等. 这一篇文章,我们介绍一个比较常见又容易被忽略的一个操作,那就是移除字符串中 ...

  7. Q200510-01: 求部门工资最高的员工

    问题: 求部门工资最高的员工 Employee 表包含所有员工信息,每个员工有其对应的 Id, salary 和 department Id. +----+-------+--------+----- ...

  8. 20190923-04Linux用户管理命令 000 012

     useradd 添加新用户 1.基本语法 useradd 用户名 (功能描述:添加新用户) useradd -g 组名 用户名 (功能描述:添加新用户到某个组) 2.案例实操 (1)添加一个用户 [ ...

  9. Ajax跨域解决方案大全

    题纲 关于跨域,有N种类型,本文只专注于ajax请求跨域(,ajax跨域只是属于浏览器"同源策略"中的一部分,其它的还有Cookie跨域iframe跨域,LocalStorage跨 ...

  10. RabbitMQ和Kafka的高可用集群原理

    前言 小伙伴们,通过前边文章的阅读,相信大家已经对RocketMQ的基本原理有了一个比较深入的了解,那么大家对当前比较常用的RabbitMQ和Kafka是不是也有兴趣了解一些呢,了解的多一些也不是坏事 ...