1.对于简单的定制操作,可以通过使用layers.core.Lambda层来完成。该方法的适用情况:仅对流经该层的数据做个变换,而这个变换本身没有需要学习的参数.

# 切片后再分别进行embedding和average pooling
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation,Reshape
from keras.layers import merge
from keras.utils import plot_model
from keras.layers import *
from keras.models import Model def get_slice(x, index):
return x[:, index] keep_num = 3
field_lens = 90
input_field = Input(shape=(keep_num, field_lens))
avg_pools = []
for n in range(keep_num):
block = Lambda(get_slice,output_shape=(1,field_lens),arguments={'index':n})(input_field)
x_emb = Embedding(input_dim=100, output_dim=200, input_length=field_lens)(block)
x_avg = GlobalAveragePooling1D()(x_emb)
avg_pools.append(x_avg)
output = concatenate([p for p in avg_pools])
model = Model(input_field, output)
plot_model(model, to_file='model/lambda.png',show_shapes=True) plt.figure(figsize=(21, 12))
im = plt.imread('model/lambda.png')
plt.imshow(im)

这里用Lambda定义了一个对张量进行切片操作的层

2.对于具有可训练权重的定制层,需要自己来实现。

from keras import backend as K
from keras.engine.topology import Layer
import numpy as np class MyLayer(Layer): def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(MyLayer, self).__init__(**kwargs) def build(self, input_shape):
# Create a trainable weight variable for this layer.
self.kernel = self.add_weight(name='kernel',
shape=(input_shape[1], self.output_dim),
initializer='uniform',
trainable=True)
super(MyLayer, self).build(input_shape) # Be sure to call this somewhere! def call(self, x):
return K.dot(x, self.kernel) def compute_output_shape(self, input_shape):
return (input_shape[0], self.output_dim)

参考:

Writing your own Keras layers Keras官方文档中文文档

keras Lambda自定义层实现数据的切片,Lambda传参数

keras中自定义Layer

如何利用Keras的扩展性

Keras 自定义层的更多相关文章

  1. Keras处理已保存模型中的自定义层(或其他自定义对象)

    如果要加载的模型包含自定义层或其他自定义类或函数,则可以通过 custom_objects 参数将它们传递给加载机制: from keras.models import load_model # 假设 ...

  2. 『开发技巧』Keras自定义对象(层、评价函数与损失)

    1.自定义层 对于简单.无状态的自定义操作,你也许可以通过 layers.core.Lambda 层来实现.但是对于那些包含了可训练权重的自定义层,你应该自己实现这种层. 这是一个 Keras2.0  ...

  3. keras Lambda 层

    Lambda层 keras.layers.core.Lambda(function, output_shape=None, mask=None, arguments=None) 本函数用以对上一层的输 ...

  4. keras自定义网络层

    在深度学习领域,Keras是一个高度封装的库并被广泛应用,可以通过调用其内置网络模块(各种网络层)实现针对性的模型结构:当所需要的网络层功能不被包含时,则需要通过自定义网络层或模型实现. 如何在ker ...

  5. MXNET:深度学习计算-自定义层

    虽然 Gluon 提供了大量常用的层,但有时候我们依然希望自定义层.本节将介绍如何使用 NDArray 来自定义一个 Gluon 的层,从而以后可以被重复调用. 不含模型参数的自定义层 我们先介绍如何 ...

  6. 『MXNet』第四弹_Gluon自定义层

    一.不含参数层 通过继承Block自定义了一个将输入减掉均值的层:CenteredLayer类,并将层的计算放在forward函数里, from mxnet import nd, gluon from ...

  7. Keras网络层之“关于Keras的层(Layer)”

    关于Keras的“层”(Layer) 所有的Keras层对象都有如下方法: layer.get_weights():返回层的权重(numpy array) layer.set_weights(weig ...

  8. 从头学pytorch(十一):自定义层

    自定义layer https://www.cnblogs.com/sdu20112013/p/12132786.html一文里说了怎么写自定义的模型.本篇说怎么自定义层. 分两种: 不含模型参数的la ...

  9. keras中保存自定义层和loss

    在keras中保存模型有几种方式: (1):使用callbacks,可以保存训练中任意的模型,或选择最好的模型 logdir = './callbacks' if not os.path.exists ...

随机推荐

  1. springmvc.xml和applicationContext.xml配置的特点

    1:springmvc.xml配置要点 一般它主要配置Controller的组件扫描器和视图解析器 下为:springmvc.xml文件 <?xml version="1.0" ...

  2. log4j2按日志级别输出到指定文件

    在项目中,可能会产生非常多的日志记录,为了方便日志分析,一般可以将日志按级别输出到指定文件,本次就先说说log4j2的实现吧: 1.先加入log4j2依赖包 2.写一个java类进行测试,类文件中仅仅 ...

  3. ArrayBlockingQueue,LinkedBlockingQueue分析

    BlockingQueue接口定义了一种阻塞的FIFO queue,每一个BlockingQueue都有一个容量,让容量满时往BlockingQueue中添加数据时会造成阻塞,当容量为空时取元素操作会 ...

  4. 【Network architecture】Rethinking the Inception Architecture for Computer Vision(inception-v3)论文解析

    目录 0. paper link 1. Overview 2. Four General Design Principles 3. Factorizing Convolutions with Larg ...

  5. Lucene 更新、删除、分页操作以及IndexWriter优化

    更新操作如下: 注意:通过lukeall-1.0.0.jar 查看软件,我们可以看到,更新其实是先删除在插入, 前面我们知道索引库中有两部分的内容组成,一个是索引文件,另一个是目录文件, 目前我们更新 ...

  6. 关于javascript以及jquery如何打开文件

    其实很简单, <input type="file" id="file" mce_style="display:none"> 这个 ...

  7. javaScript tips —— z-index 对事件机制的影响

    demo // DOM结构 class App extends React.Component { componentDidMount() { const div1 = document.getEle ...

  8. js删除数组中某一项或几项的几种方法

    1:js中的splice方法 splice(index,len,[item])    注释:该方法会改变原始数组. splice有3个参数,它也可以用来替换/删除/添加数组内某一个或者几个值 inde ...

  9. w3c标准盒模型与IE传统模型的区别

    一.盒子模型(box model) 在HTML文档中的每个元素被描绘为矩形盒子.确定其大小,属性——比如颜色.背景.边框,及其位置是渲染引擎的目标. CSS下这些矩形盒子由标准盒模型描述.这个模型描述 ...

  10. 初识HTML和CSS

    HTML 1.一套规则,浏览器认识的规则. 2.开发者: 学习Html规则 开发后台程序: - 写Html文件(充当模板的作用) ****** - 数据库获取数据,然后替换到html文件的指定位置(W ...