Keras实现Hierarchical Attention Network时的一些坑
Reshape
对于的张量x,x.shape=(a, b, c, d)的情况
若调用keras.layer.Reshape(target_shape=(-1, c, d)),
处理后的张量形状为(?, ?, c, d)
若调用tf.reshape(x, shape=[-1, c, d])
处理后的张量形状为(a*b, c, d)
为了在keras代码中实现tf.reshape的效果,用lambda层做,
调用Lambda(lambda x: tf.reshape(x, shape=[-1, c, d]))(x)
nice and cool.
输出Attention的打分
这里,我们希望attention层能够输出attention的score,而不只是计算weighted sum。
在使用时
score = Attention()(x)
weighted_sum = MyMerge()([score, x])
class Attention(Layer):
def __init__(self, **kwargs):
super(Attention, self).__init__(**kwargs)
def build(self, input_shape):
assert len(input_shape) == 3
self.w = self.add_weight(name="attention_weight",
shape=(input_shape[-1],
input_shape[-1]),
initializer='uniform',
trainable=True
)
self.b = self.add_weight(name="attention_bias",
shape=(input_shape[-1],),
initializer='uniform',
trainable=True
)
self.v = self.add_weight(name="attention_v",
shape=(input_shape[-1], 1),
initializer='uniform',
trainable=True
)
super(Attention, self).build(input_shape)
def call(self, inputs):
x = inputs
att = K.tanh(K.dot(x, self.w) + self.b)
att = K.softmax(K.dot(att, self.v))
print(att.shape)
return att
def compute_output_shape(self, input_shape):
return input_shape[0], input_shape[1], 1
class MyMerge(Layer):
def __init__(self, **kwargs):
super(MyMerge, self).__init__(**kwargs)
def call(self, inputs):
att = inputs[0]
x = inputs[1]
att = tf.tile(att, [1, 1, x.shape[-1]])
outputs = tf.multiply(att, x)
outputs = K.sum(outputs, axis=1)
return outputs
def compute_output_shape(self, input_shape):
return input_shape[1][0], input_shape[1][2]
keras中Model的嵌套
这边是转载自https://github.com/uhauha2929/examples/blob/master/Hierarchical%20Attention%20Networks%20.ipynb
可以看到,sentEncoder是Model类型,在后面的时候通过TimeDistributed(sentEncoder),当成一个层那样被调用。
embedding_layer = Embedding(len(word_index) + 1,
EMBEDDING_DIM,
input_length=MAX_SENT_LENGTH)
sentence_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sentence_input)
l_lstm = Bidirectional(LSTM(100))(embedded_sequences)
sentEncoder = Model(sentence_input, l_lstm)
review_input = Input(shape=(MAX_SENTS,MAX_SENT_LENGTH), dtype='int32')
review_encoder = TimeDistributed(sentEncoder)(review_input)
l_lstm_sent = Bidirectional(LSTM(100))(review_encoder)
preds = Dense(2, activation='softmax')(l_lstm_sent)
model = Model(review_input, preds)
Keras实现Hierarchical Attention Network时的一些坑的更多相关文章
- Hierarchical Attention Based Semi-supervised Network Representation Learning
Hierarchical Attention Based Semi-supervised Network Representation Learning 1. 任务 给定:节点信息网络 目标:为每个节 ...
- Dual Attention Network for Scene Segmentation
Dual Attention Network for Scene Segmentation 原始文档 https://www.yuque.com/lart/papers/onk4sn 在本文中,我们通 ...
- 语义分割之Dual Attention Network for Scene Segmentation
Dual Attention Network for Scene Segmentation 在本文中,我们通过 基于自我约束机制捕获丰富的上下文依赖关系来解决场景分割任务. 与之前通过多尺 ...
- Paper | Residual Attention Network for Image Classification
目录 1. 相关工作 2. Residual Attention Network 2.1 Attention残差学习 2.2 自上而下和自下而上 2.3 正则化Attention 最近看了些关于att ...
- A Survey of Model Compression and Acceleration for Deep Neural Network时s
A Survey of Model Compression and Acceleration for Deep Neural Network时s 本文全面概述了深度神经网络的压缩方法,主要可分为参数修 ...
- 注意力机制---Attention、local Attention、self Attention、Hierarchical attention
一.编码-解码架构 目的:解决语音识别.机器翻译.知识问答等输出输入序列长度不相等的任务. C是输入的一个表达(representation),包含了输入序列的有效信息. 它可能是一个向量,也可能是一 ...
- Residual Attention Network for Image Classification(CVPR 2017)详解
一.Residual Attention Network 简介 这是CVPR2017的一篇paper,是商汤.清华.香港中文和北邮合作的文章.它在图像分类问题上,首次成功将极深卷积神经网络与人类视觉注 ...
- 论文解读(FedGAT)《Federated Graph Attention Network for Rumor Detection》
论文信息 论文标题:Federated Graph Attention Network for Rumor Detection论文作者:Huidong Wang, Chuanzheng Bai, Ji ...
- 关于pyinstaller打包程序时设置icon时的一个坑
关于pyinstaller打包程序时设置icon时的一个坑 之前在用pyinstaller打包程序的时候遇到了关于设置图标的一点小问题,无论在后面加--icon 或是-i都出现报错.查了下st ...
随机推荐
- 如何在Windows Server 2008 R2中更改桌面图标
如何在Windows Server 2008 R2中更改桌面图标 Windows Server 2008 R2 已经在 MSDN 和 TechNet Plus 订阅上公布,gOxiA 在第一时间下载并 ...
- vue-cli3 按需加载loading,服务的方式调用
安装 babel-plugin-component npm install babel-plugin-component -S 安装element-ui npm install element-ui ...
- SQL命令如何分发到集群的各节点
有些数据库集群的规模是很大的,有上百个节点,那么维护SQL命令如何快速分发给各个节点,例如:要加个字段,逐个节点操作那是十分低效,枯燥的. TreeSoft增加了[SQL分发]功能,简单配置,可以快速 ...
- windows server 2012 r2 无法安装 .net 3.5
服务器需安装SQL 2012 ,因需安装.net3.5,没有想到2012出于安全竟然不让手动安装,对于源文件也是把控比较严,折腾了好一会儿才解决问题 有参才一下powershell等安装命令,均失败. ...
- 《ucore lab1 exercise6》实验报告
资源 ucore在线实验指导书 我的ucore实验代码 题目:完善中断初始化和处理 请完成编码工作和回答如下问题: 中断描述符表(也可简称为保护模式下的中断向量表)中一个表项占多少字节?其中哪几位代表 ...
- centos 用户组操作
adduser testuser //新建testuser 用户 passwd testuser //给testuser 用户设置密码 useradd -g testgroup testuser // ...
- [转帖]C语言计算时间函数 & 理解linux time命令的输出中“real”“user”“sys”的真正含义
C语言计算时间函数 & 理解linux time命令的输出中“real”“user”“sys”的真正含义 https://blog.csdn.net/willyang519/article/d ...
- maven基础和基本使用
maven介绍 Maven是基于项目对象模型(POM project object model)实现的,可以通过一小段描述信息(配置)来管理项目的构建,报告和文档的软件项目管理工具. 具体作用: 项目 ...
- Python调用API接口的几种方式
Python调用API接口的几种方式 相信做过自动化运维的同学都用过API接口来完成某些动作.API是一套成熟系统所必需的接口,可以被其他系统或脚本来调用,这也是自动化运维的必修课. 本文主要介绍py ...
- Python【常用的数据类型】
int, float, string整数,浮点数,字符串----------------------------------------字符串(string)用引号括起来的文本 >>& ...