Reshape

对于的张量x,x.shape=(a, b, c, d)的情况

若调用keras.layer.Reshape(target_shape=(-1, c, d)),

处理后的张量形状为(?, ?, c, d)

若调用tf.reshape(x, shape=[-1, c, d])

处理后的张量形状为(a*b, c, d)

为了在keras代码中实现tf.reshape的效果,用lambda层做,

调用Lambda(lambda x: tf.reshape(x, shape=[-1, c, d]))(x)

nice and cool.

输出Attention的打分

这里,我们希望attention层能够输出attention的score,而不只是计算weighted sum。

在使用时

score = Attention()(x)

weighted_sum = MyMerge()([score, x])

class Attention(Layer):
def __init__(self, **kwargs):
super(Attention, self).__init__(**kwargs) def build(self, input_shape):
assert len(input_shape) == 3
self.w = self.add_weight(name="attention_weight",
shape=(input_shape[-1],
input_shape[-1]),
initializer='uniform',
trainable=True
) self.b = self.add_weight(name="attention_bias",
shape=(input_shape[-1],),
initializer='uniform',
trainable=True
) self.v = self.add_weight(name="attention_v",
shape=(input_shape[-1], 1),
initializer='uniform',
trainable=True
)
super(Attention, self).build(input_shape) def call(self, inputs):
x = inputs
att = K.tanh(K.dot(x, self.w) + self.b)
att = K.softmax(K.dot(att, self.v))
print(att.shape)
return att def compute_output_shape(self, input_shape):
return input_shape[0], input_shape[1], 1 class MyMerge(Layer):
def __init__(self, **kwargs):
super(MyMerge, self).__init__(**kwargs) def call(self, inputs):
att = inputs[0]
x = inputs[1]
att = tf.tile(att, [1, 1, x.shape[-1]])
outputs = tf.multiply(att, x)
outputs = K.sum(outputs, axis=1)
return outputs def compute_output_shape(self, input_shape):
return input_shape[1][0], input_shape[1][2]

keras中Model的嵌套

这边是转载自https://github.com/uhauha2929/examples/blob/master/Hierarchical%20Attention%20Networks%20.ipynb

可以看到,sentEncoder是Model类型,在后面的时候通过TimeDistributed(sentEncoder),当成一个层那样被调用。

embedding_layer = Embedding(len(word_index) + 1,
EMBEDDING_DIM,
input_length=MAX_SENT_LENGTH) sentence_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sentence_input)
l_lstm = Bidirectional(LSTM(100))(embedded_sequences)
sentEncoder = Model(sentence_input, l_lstm) review_input = Input(shape=(MAX_SENTS,MAX_SENT_LENGTH), dtype='int32')
review_encoder = TimeDistributed(sentEncoder)(review_input)
l_lstm_sent = Bidirectional(LSTM(100))(review_encoder)
preds = Dense(2, activation='softmax')(l_lstm_sent)
model = Model(review_input, preds)

Keras实现Hierarchical Attention Network时的一些坑的更多相关文章

  1. Hierarchical Attention Based Semi-supervised Network Representation Learning

    Hierarchical Attention Based Semi-supervised Network Representation Learning 1. 任务 给定:节点信息网络 目标:为每个节 ...

  2. Dual Attention Network for Scene Segmentation

    Dual Attention Network for Scene Segmentation 原始文档 https://www.yuque.com/lart/papers/onk4sn 在本文中,我们通 ...

  3. 语义分割之Dual Attention Network for Scene Segmentation

    Dual Attention Network for Scene Segmentation 在本文中,我们通过 基于自我约束机制捕获丰富的上下文依赖关系来解决场景分割任务.       与之前通过多尺 ...

  4. Paper | Residual Attention Network for Image Classification

    目录 1. 相关工作 2. Residual Attention Network 2.1 Attention残差学习 2.2 自上而下和自下而上 2.3 正则化Attention 最近看了些关于att ...

  5. A Survey of Model Compression and Acceleration for Deep Neural Network时s

    A Survey of Model Compression and Acceleration for Deep Neural Network时s 本文全面概述了深度神经网络的压缩方法,主要可分为参数修 ...

  6. 注意力机制---Attention、local Attention、self Attention、Hierarchical attention

    一.编码-解码架构 目的:解决语音识别.机器翻译.知识问答等输出输入序列长度不相等的任务. C是输入的一个表达(representation),包含了输入序列的有效信息. 它可能是一个向量,也可能是一 ...

  7. Residual Attention Network for Image Classification(CVPR 2017)详解

    一.Residual Attention Network 简介 这是CVPR2017的一篇paper,是商汤.清华.香港中文和北邮合作的文章.它在图像分类问题上,首次成功将极深卷积神经网络与人类视觉注 ...

  8. 论文解读(FedGAT)《Federated Graph Attention Network for Rumor Detection》

    论文信息 论文标题:Federated Graph Attention Network for Rumor Detection论文作者:Huidong Wang, Chuanzheng Bai, Ji ...

  9. 关于pyinstaller打包程序时设置icon时的一个坑

    关于pyinstaller打包程序时设置icon时的一个坑     之前在用pyinstaller打包程序的时候遇到了关于设置图标的一点小问题,无论在后面加--icon 或是-i都出现报错.查了下st ...

随机推荐

  1. mysql quick query row count using sql

    1. command show table status like '{table-name}'; 2. sample mysql> use inventory; Database change ...

  2. 【CSS3练习】transform 2d变形实例练习

    transform 2d变形实例练习:练习了旋转 倾斜 缩放的功能 <!DOCTYPE html> <html lang="en"> <head> ...

  3. .net core 使用SignalR实现实时通信

    这几天在研究SignalR,网上大部分的例子都是聊天室,我的需求是把服务端的信息发送给前端展示.并且需要实现单个用户推送. 用户登录我用的是ClaimsIdentity,这里就不多解释,如果不是很了解 ...

  4. 安装Windows和Ubuntu双系统

    发现关注消息 安装Windows和Ubuntu双系统     安装Windows和Ubuntu双系统 0.552016.12.10 15:54:41字数 2101阅读 6644 这几天开始动手做毕设啦 ...

  5. RabbitMQ官方教程四 Routing(GOLANG语言实现)

    在上一教程中,我们构建了一个简单的日志记录系统. 我们能够向许多消费者广播日志消息. 在本教程中,我们将向其中添加功能-我们将使仅订阅消息的子集成为可能. 例如,我们将只能将严重错误消息定向到日志文件 ...

  6. python解包

    概念 python的解包可以这样来理解:把元素给拆分并把其赋值给自己所需要的变量,因此元素应该是一个可迭代对象. 形式 简单版本 下面展示的是解包的基本形式,根据长度赋值给对应多的变量. name_l ...

  7. Appium移动自动化测试-----(十)appium API 之上下文操作

    其实上下文的操作主要针对于混合应用.啥是混合应用,简单来说就是APP用里面嵌入网页.Android上的浏览器就属于混合应用. 1.获取当前上下文 方法: getContext() 获取当前所有的可用的 ...

  8. 使用Wifi pineapple(菠萝派)进行Wi-Fi钓鱼攻击

    简介: WiFi Pineapple 是由国外无线安全审计公司Hak5开发并售卖的一款无线安全测试神器. 特性: 用作 Wi-Fi 中间人攻击测试平台 一整套的针对 AP 的渗透测试套件 基于 WEB ...

  9. websockify文档

    一.官网地址 地址:https://github.com/novnc/websockify 二.开启代理 1.单台服务器 python /usr/local/websockify/websockify ...

  10. 记一次EFCore类型转换错误及解决方案

    一  背景 今天在使用EntityFrameworkCore 查询的时候在调试的时候总是提示如下错误:Unable to cast object of type 'System.Data.SqlTyp ...