keras自定义网络层
在深度学习领域,Keras是一个高度封装的库并被广泛应用,可以通过调用其内置网络模块(各种网络层)实现针对性的模型结构;当所需要的网络层功能不被包含时,则需要通过自定义网络层或模型实现。
如何在keras框架下自定义层,基本“套路”如下。
一般地,keras中的网络层是一个类,所以自定义层即编写一个类,更为重要的是这个类(即自定义层)需要继承Layer父类,而且需要实现以下四种方法:
- __init __ (self, output_dim, **kwargs)
这个方法是用来初始化并自定义自定义层所需的属性,比如output_dim;
此外,该方法需要执行super().__init __(**kwargs),这行代码是执行Layer类中的初始化函数;
当执行上述代码就没有必要去管input_shape,weights,trainable等关键字参数,因为父类(Layer)的初始化函数实现了它们与layer实例的绑定。
- build(self, input_shape)
这个方法是用来创建层的权重;
在该方法中,根据之前的继承,通过Layer类的add_weight方法来自定义并添加一个权重矩阵,这个方法需要input_shape参数;
该方法必须设self.built = True,目的是为了保证这个层的权重定义函数build被执行过了;
在built函数中,需要说明这个权重各方面的属性,比如shape、初始化方式以及可训练性等信息。
- call(self, x)
这个方法是用来编写层的功能逻辑;
在该方法中,需要关注传入call的第一个参数:输入张量x;x只能是一种形式变量,不能是具体的变量,即它不能被定义;
这个call函数就是该层的计算逻辑,当创建好这个层实例后,该实例可以执行call函数;
可见,这个层的核心应该是一段符号式的输入张量到输出张量的计算过程。
- compute_output_shape(self, input_shape)
这个方法是用来保证输出shape是正确的;
这里重写compute_output_shape方法去覆盖父类中的同名方法,来保证输出的shape符合实际;
父类Layer中的compute_output_shape方法直接返回的是input_shape这明显是不对的,所以需要重写该方法。
示例
结合官方文档的例子,给出如下一个自定义层的代码:

使用自定义层,就如同使用keras内置网络层一样,如下图所示:(另外,本例使用kears内置的激活函数层ReLU承接自定义层的输出,从而避免将激活函数的功能加入到自定义层中)

keras自定义网络层的更多相关文章
- Keras自定义评估函数
1. 比较一般的自定义函数: 需要注意的是,不能像sklearn那样直接定义,因为这里的y_true和y_pred是张量,不是numpy数组.示例如下: from keras import backe ...
- Keras 自定义层
1.对于简单的定制操作,可以通过使用layers.core.Lambda层来完成.该方法的适用情况:仅对流经该层的数据做个变换,而这个变换本身没有需要学习的参数. # 切片后再分别进行embeddin ...
- keras 自定义 custom 函数
转自: https://kexue.fm/archives/4493/,感谢分享! Keras是一个搭积木式的深度学习框架,用它可以很方便且直观地搭建一些常见的深度学习模型.在tensorflow出来 ...
- 『开发技巧』Keras自定义对象(层、评价函数与损失)
1.自定义层 对于简单.无状态的自定义操作,你也许可以通过 layers.core.Lambda 层来实现.但是对于那些包含了可训练权重的自定义层,你应该自己实现这种层. 这是一个 Keras2.0 ...
- pytorch自定义网络层以及损失函数
转自:https://blog.csdn.net/dss_dssssd/article/details/82977170 https://blog.csdn.net/dss_dssssd/articl ...
- keras自定义padding大小
1.keras卷积操作中border_mode的实现 def conv_output_length(input_length, filter_size, border_mode, stride): i ...
- 【TensorRT】自定义网络层的实现custom layers
参考 1. Extending TensorRT With Custom Layers; 2. TensorRT Samples: MNIST(Plugin, add a custom layer); ...
- keras中保存自定义层和loss
在keras中保存模型有几种方式: (1):使用callbacks,可以保存训练中任意的模型,或选择最好的模型 logdir = './callbacks' if not os.path.exists ...
- [深度应用]·Keras实现Self-Attention文本分类(机器如何读懂人心)
[深度应用]·Keras实现Self-Attention文本分类(机器如何读懂人心) 配合阅读: [深度概念]·Attention机制概念学习笔记 [TensorFlow深度学习深入]实战三·分别使用 ...
随机推荐
- python的Counter类
python的Counter类 Counter 集成于 dict 类,因此也可以使用字典的方法,此类返回一个以元素为 key .元素个数为 value 的 Counter 对象集合 from coll ...
- MySQL调优之索引优化
一.索引基本知识 1.索引的优点 1.减少了服务器需要扫描的数据量 2.帮助服务器避免排序和临时表 例子: select * from emp orde by sal desc; 那么执行顺序: 所以 ...
- ES 2021 来了,详细解读5个新特性,附案例
ES 2021是世界上最受欢迎的编程语言的最新版本〜 本次迭代中包含了五个新特性,让我们来一睹为快. 1.全部替换replaceAll: js默认的replace 方法仅替换字符串中一个模式的第一个实 ...
- ProbabilityStatistics
class ProbabilityStatistics: @staticmethoddef simulation_of_probability(v, ratio=10000): assert v &g ...
- Centos7服务器安装Docker及Docker镜像加速,Docker删除
Centos7服务器安装Docker及Docker镜像加速,Docker删除 1.Centos7服务器安装Docker 1.1 root账户登录,查看内核版本如下 1.1.1 卸载服务器旧版本Dock ...
- Vue结合Element UI实战
创建工程 1. 创建一个名为hello-vue的工程 vue init webpack hello-vue 2. 安装依赖 需要安装 vue-router.element-ui.sass-loader ...
- DEDECMS:更改dede提示框标题
在include文件夹下,修改common.func.php文件里的 <title>DedeCMS提示信息</title> 修改成自己想要的标题. 将 <b>Ded ...
- 动态代理+静态代理+cglib代理 详解
代理定义:代理(Proxy):是一种设计模式,提供了对目标对象另外的访问方式;即通过代理对象访问目标对象.好处是:可以在目标对象实现的基础上,增强额外的功能操作,即扩展目标对象的功能. 动态代理+静态 ...
- ST在keil下开发时候文件options配置的一些小技巧
作者:良知犹存 转载授权以及围观:欢迎添加微信公众号:Conscience_Remains 总述 这是之前ST芯片载keil下开发时候总结的一些代码文件options配置小笔记,虽然不是很复杂 ...
- Kwp2000协议的应用(程序原理篇)
作者:良知犹存 转载授权以及围观:欢迎添加微信:becom_me 总述 接上篇文章Kwp2000协议的应用(硬件原理使用篇),本篇针对kwp2000协议标准的服务ID详细介绍,以及针对程序实现 ...