在深度学习领域,Keras是一个高度封装的库并被广泛应用,可以通过调用其内置网络模块(各种网络层)实现针对性的模型结构;当所需要的网络层功能不被包含时,则需要通过自定义网络层或模型实现。

如何在keras框架下自定义层,基本“套路”如下。

一般地,keras中的网络层是一个类,所以自定义层即编写一个类,更为重要的是这个类(即自定义层)需要继承Layer父类,而且需要实现以下四种方法:

  1. __init __ (self, output_dim, **kwargs)

这个方法是用来初始化并自定义自定义层所需的属性,比如output_dim;

此外,该方法需要执行super().__init __(**kwargs),这行代码是执行Layer类中的初始化函数;

当执行上述代码就没有必要去管input_shape,weights,trainable等关键字参数,因为父类(Layer)的初始化函数实现了它们与layer实例的绑定。

  1. build(self, input_shape)

这个方法是用来创建层的权重;

在该方法中,根据之前的继承,通过Layer类的add_weight方法来自定义并添加一个权重矩阵,这个方法需要input_shape参数;

该方法必须设self.built = True,目的是为了保证这个层的权重定义函数build被执行过了;

在built函数中,需要说明这个权重各方面的属性,比如shape、初始化方式以及可训练性等信息。

  1. call(self, x)

这个方法是用来编写层的功能逻辑;

在该方法中,需要关注传入call的第一个参数:输入张量x;x只能是一种形式变量,不能是具体的变量,即它不能被定义;

这个call函数就是该层的计算逻辑,当创建好这个层实例后,该实例可以执行call函数;

可见,这个层的核心应该是一段符号式的输入张量到输出张量的计算过程。

  1. compute_output_shape(self, input_shape)

这个方法是用来保证输出shape是正确的;

这里重写compute_output_shape方法去覆盖父类中的同名方法,来保证输出的shape符合实际;

父类Layer中的compute_output_shape方法直接返回的是input_shape这明显是不对的,所以需要重写该方法。

示例

结合官方文档的例子,给出如下一个自定义层的代码:

使用自定义层,就如同使用keras内置网络层一样,如下图所示:(另外,本例使用kears内置的激活函数层ReLU承接自定义层的输出,从而避免将激活函数的功能加入到自定义层中)

keras自定义网络层的更多相关文章

  1. Keras自定义评估函数

    1. 比较一般的自定义函数: 需要注意的是,不能像sklearn那样直接定义,因为这里的y_true和y_pred是张量,不是numpy数组.示例如下: from keras import backe ...

  2. Keras 自定义层

    1.对于简单的定制操作,可以通过使用layers.core.Lambda层来完成.该方法的适用情况:仅对流经该层的数据做个变换,而这个变换本身没有需要学习的参数. # 切片后再分别进行embeddin ...

  3. keras 自定义 custom 函数

    转自: https://kexue.fm/archives/4493/,感谢分享! Keras是一个搭积木式的深度学习框架,用它可以很方便且直观地搭建一些常见的深度学习模型.在tensorflow出来 ...

  4. 『开发技巧』Keras自定义对象(层、评价函数与损失)

    1.自定义层 对于简单.无状态的自定义操作,你也许可以通过 layers.core.Lambda 层来实现.但是对于那些包含了可训练权重的自定义层,你应该自己实现这种层. 这是一个 Keras2.0  ...

  5. pytorch自定义网络层以及损失函数

    转自:https://blog.csdn.net/dss_dssssd/article/details/82977170 https://blog.csdn.net/dss_dssssd/articl ...

  6. keras自定义padding大小

    1.keras卷积操作中border_mode的实现 def conv_output_length(input_length, filter_size, border_mode, stride): i ...

  7. 【TensorRT】自定义网络层的实现custom layers

    参考 1. Extending TensorRT With Custom Layers; 2. TensorRT Samples: MNIST(Plugin, add a custom layer); ...

  8. keras中保存自定义层和loss

    在keras中保存模型有几种方式: (1):使用callbacks,可以保存训练中任意的模型,或选择最好的模型 logdir = './callbacks' if not os.path.exists ...

  9. [深度应用]·Keras实现Self-Attention文本分类(机器如何读懂人心)

    [深度应用]·Keras实现Self-Attention文本分类(机器如何读懂人心) 配合阅读: [深度概念]·Attention机制概念学习笔记 [TensorFlow深度学习深入]实战三·分别使用 ...

随机推荐

  1. Kubernetes 存储简介

    存储分类结构图 半持久化存储 1.EmptyDir EmptyDir是一个空目录,生命周期和所属的 Pod 是完全一致的,EmptyDir的用处是,可以在同一 Pod 内的不同容器之间共享工作过程中产 ...

  2. 从零搭建一个IdentityServer——集成Asp.net core Identity

    前面的文章使用Asp.net core 5.0以及IdentityServer4搭建了一个基础的验证服务器,并实现了基于客户端证书的Oauth2.0授权流程,以及通过access token访问被保护 ...

  3. HA工作机制

    HA工作机制 HA:高可用(7*24小时不中断服务) 主要的HA是针对集群的master节点的,即namenode和resourcemanager,毕竟DataNode挂掉之后影响 不是特别大,重启就 ...

  4. (转载)微软数据挖掘算法:Microsoft 时序算法(5)

    前言 本篇文章同样是继续微软系列挖掘算法总结,前几篇主要是基于状态离散值或连续值进行推测和预测,所用的算法主要是三种:Microsoft决策树分析算法.Microsoft聚类分析算法.Microsof ...

  5. 【PC Basic】CPU、核、多线程的那些事儿

    一.CPU与核的概念 1.半导体中名词[Wafer][Chip][Die]中文名字和用途 Wafer--晶圆 wafer 即为图片所示的晶圆,由纯硅(Si)构成.一般分为6英寸.8英寸.12英寸规格不 ...

  6. 深入浅出Java线程池:使用篇

    前言 很高兴遇见你~ 借助于很多强大的框架,现在我们已经很少直接去管理线程,框架的内部都会为我们自动维护一个线程池.例如我们使用最多的okHttp以及他的封装框架Retrofit,线程封装框架RxJa ...

  7. Kubernetes (yaml 文件详解)

    # yaml格式的pod定义文件完整内容:apiVersion: v1       #必选,版本号,例如v1kind: Pod       #必选,Podmetadata:       #必选,元数据 ...

  8. Jenkins安装部署项目

    Jenkins安装部署项目 配置JDK git maven 部署到服务器 一.新建任务 二.配置jenkins 三.添加构建信息 四.应用.保存 五.踩坑填坑记录 5.1没有jar包的情况 5.2无法 ...

  9. mysql int类型 int(11) 和int(2)区别

    CREATE TABLE `learn` ( `id` int(11) unsigned NOT NULL, `exp` int(2) DEFAULT 0, PRIMARY KEY (`id`)) E ...

  10. P1046 陶陶摘苹果 Python实现

    题目描述 陶陶家的院子里有一棵苹果树,每到秋天树上就会结出1010个苹果.苹果成熟的时候,陶陶就会跑去摘苹果.陶陶有个3030厘米高的板凳,当她不能直接用手摘到苹果的时候,就会踩到板凳上再试试. 现在 ...