在深度学习领域,Keras是一个高度封装的库并被广泛应用,可以通过调用其内置网络模块(各种网络层)实现针对性的模型结构;当所需要的网络层功能不被包含时,则需要通过自定义网络层或模型实现。

如何在keras框架下自定义层,基本“套路”如下。

一般地,keras中的网络层是一个类,所以自定义层即编写一个类,更为重要的是这个类(即自定义层)需要继承Layer父类,而且需要实现以下四种方法:

  1. __init __ (self, output_dim, **kwargs)

这个方法是用来初始化并自定义自定义层所需的属性,比如output_dim;

此外,该方法需要执行super().__init __(**kwargs),这行代码是执行Layer类中的初始化函数;

当执行上述代码就没有必要去管input_shape,weights,trainable等关键字参数,因为父类(Layer)的初始化函数实现了它们与layer实例的绑定。

  1. build(self, input_shape)

这个方法是用来创建层的权重;

在该方法中,根据之前的继承,通过Layer类的add_weight方法来自定义并添加一个权重矩阵,这个方法需要input_shape参数;

该方法必须设self.built = True,目的是为了保证这个层的权重定义函数build被执行过了;

在built函数中,需要说明这个权重各方面的属性,比如shape、初始化方式以及可训练性等信息。

  1. call(self, x)

这个方法是用来编写层的功能逻辑;

在该方法中,需要关注传入call的第一个参数:输入张量x;x只能是一种形式变量,不能是具体的变量,即它不能被定义;

这个call函数就是该层的计算逻辑,当创建好这个层实例后,该实例可以执行call函数;

可见,这个层的核心应该是一段符号式的输入张量到输出张量的计算过程。

  1. compute_output_shape(self, input_shape)

这个方法是用来保证输出shape是正确的;

这里重写compute_output_shape方法去覆盖父类中的同名方法,来保证输出的shape符合实际;

父类Layer中的compute_output_shape方法直接返回的是input_shape这明显是不对的,所以需要重写该方法。

示例

结合官方文档的例子,给出如下一个自定义层的代码:

使用自定义层,就如同使用keras内置网络层一样,如下图所示:(另外,本例使用kears内置的激活函数层ReLU承接自定义层的输出,从而避免将激活函数的功能加入到自定义层中)

keras自定义网络层的更多相关文章

  1. Keras自定义评估函数

    1. 比较一般的自定义函数: 需要注意的是,不能像sklearn那样直接定义,因为这里的y_true和y_pred是张量,不是numpy数组.示例如下: from keras import backe ...

  2. Keras 自定义层

    1.对于简单的定制操作,可以通过使用layers.core.Lambda层来完成.该方法的适用情况:仅对流经该层的数据做个变换,而这个变换本身没有需要学习的参数. # 切片后再分别进行embeddin ...

  3. keras 自定义 custom 函数

    转自: https://kexue.fm/archives/4493/,感谢分享! Keras是一个搭积木式的深度学习框架,用它可以很方便且直观地搭建一些常见的深度学习模型.在tensorflow出来 ...

  4. 『开发技巧』Keras自定义对象(层、评价函数与损失)

    1.自定义层 对于简单.无状态的自定义操作,你也许可以通过 layers.core.Lambda 层来实现.但是对于那些包含了可训练权重的自定义层,你应该自己实现这种层. 这是一个 Keras2.0  ...

  5. pytorch自定义网络层以及损失函数

    转自:https://blog.csdn.net/dss_dssssd/article/details/82977170 https://blog.csdn.net/dss_dssssd/articl ...

  6. keras自定义padding大小

    1.keras卷积操作中border_mode的实现 def conv_output_length(input_length, filter_size, border_mode, stride): i ...

  7. 【TensorRT】自定义网络层的实现custom layers

    参考 1. Extending TensorRT With Custom Layers; 2. TensorRT Samples: MNIST(Plugin, add a custom layer); ...

  8. keras中保存自定义层和loss

    在keras中保存模型有几种方式: (1):使用callbacks,可以保存训练中任意的模型,或选择最好的模型 logdir = './callbacks' if not os.path.exists ...

  9. [深度应用]·Keras实现Self-Attention文本分类(机器如何读懂人心)

    [深度应用]·Keras实现Self-Attention文本分类(机器如何读懂人心) 配合阅读: [深度概念]·Attention机制概念学习笔记 [TensorFlow深度学习深入]实战三·分别使用 ...

随机推荐

  1. mysql半同步复制跟无损半同步区别

    mysql半同步复制跟无损半同步复制的区别: 无损复制其实就是对semi sync增加了rpl_semi_sync_master_wait_point参数,来控制半同步模式下主库在返回给会话事务成功之 ...

  2. 使用 .NETCore自带框架快速实现依赖注入

    Startup 在Startup的ConfigureServices()中配置DI的接口与其实现 public void ConfigureServices(IServiceCollection se ...

  3. Redisson 分布式锁实战与 watch dog 机制解读

    Redisson 分布式锁实战与 watch dog 机制解读 目录 Redisson 分布式锁实战与 watch dog 机制解读 背景 普通的 Redis 分布式锁的缺陷 Redisson 提供的 ...

  4. eNSP启动设备AR1失败记一次解决步骤

    eNSP稳定版本下载:   微信搜索公众号"疯刘小三" 关注后回复ensp即可获得下载链接地址 eNSP V100R002C00B510 Setup.exe 最近在用eNSp的时候 ...

  5. net.core.somaxconn net.ipv4.tcp_max_syn_backlog

    Linux参数-net.core.somaxconn与net.ipv4.tcp_max_syn_backlog_梁海江的博客-CSDN博客_net.ipv4.tcp_max_syn_backlog h ...

  6. (Oracle)DDL及其数据泵导入导出(impdp/expdp)

    create tablespace ybp_dev datafile 'G:\app\Administrator\oradata\health\ybp_dev1.dbf' size 10m autoe ...

  7. centos7-docker的安装过程

    一.卸载旧版本以及依赖(第一次安装忽略) sudo yum remove docker \ docker-client \ docker-client-latest \ docker-common \ ...

  8. (三)SpringBoot停止服务的方法

    SpringBoot停止服务的方法 第一种:actuator 第二种:context 第三种:进程号 第四种:SpringApplication.exit() 第五种:自定义Controller Sp ...

  9. Spark高级数据分析——纽约出租车轨迹的空间和时间数据分析

    Spark高级数据分析--纽约出租车轨迹的空间和时间数据分析 一.地理空间分析: 二.pom.xml 原文地址:https://www.jianshu.com/p/eb6f3e0c09b5 作者:II ...

  10. hibernate+spring+tomcat启动报错Not supported by BasicDataSource

    最近使用hibernate+spring+jsp的小项目制作过程中出现一些错误,例如: java.lang.UnsupportedOperationException: Not supported b ...