如何可视化深度学习网络中Attention层
前言
在训练深度学习模型时,常想一窥网络结构中的attention层权重分布,观察序列输入的哪些词或者词组合是网络比较care的。在小论文中主要研究了关于词性POS对输入序列的注意力机制。同时对比实验采取的是words的self-attention机制。

效果
下图主要包含两列:word_attention是self-attention机制的模型训练结果,POS_attention是词性模型的训练结果。
可以看出,相对于word_attention,POS的注意力机制不仅能够捕捉到评价的aspect,也能根据aspect关联的词借助情感语义表达的词性分布,care到相关词性的情感词。

核心代码
可视化样例
# coding: utf-8
def highlight(word, attn):
html_color = '#%02X%02X%02X' % (255, int(255*(1 - attn)), int(255*(1 - attn)))
return '<span style="background-color: {}">{}</span>'.format(html_color, word)
def mk_html(seq, attns):
html = ""
for ix, attn in zip(seq, attns):
html += ' ' + highlight(
ix,
attn
)
return html + "<br>"
from IPython.display import HTML, display
batch_size = 1
seqs = [["这", "是", "一个", "测试", "样例", "而已"]]
attns = [[0.01, 0.19, 0.12, 0.7, 0.2, 0.1]]
for i in range(batch_size):
text = mk_html(seqs[i], attns[i])
display(HTML(text))
接入model
需要在model的返回列表中,添加attention_weight的输出,理论上维度应该和输入序列的长度是一致的。
# load model
import torch
# if you train on gpu, you need to move onto cpu
model = torch.load("../docs/model_chk/2018-11-07-02:45:37", map_location=lambda storage, location: storage)
from torch.autograd import Variable
for batch_idx, samples in enumerate(test_loader, 0):
v_word = Variable(samples['word_vec'])
v_final_label = samples['top_label']
model.eval()
final_probs, att_weight = model(v_word, v_pos)
batch_words = toWords(samples["word_vec"].numpy(), idx_word) # id转化为word
batch_att = getAtten(batch_words, att_weight.data.numpy()) # 去除padding词,根据words的长度截取attention
labels = toLabel(samples['top_label'].numpy()) # 真实标签
pre_labels = toLabel(final_probs.data.numpy() >= 0.5) # 预测标签
for i in range(len(batch_words)):
text = mk_html(batch_words[i], batch_att[i])
print(labels[i], pre_labels[i])
display(HTML(text))
总结
- 建议把可视化独立出来,用jupyter-notebook编辑,方便分段调试和copy;同时因为是借助html渲染的,所以需要notebook
- 项目代码我后期后同步到github上,欢迎一起交流
如何可视化深度学习网络中Attention层的更多相关文章
- 深度学习网络中numpy多维数组的说明
目前在计算机视觉中应用的数组维度最多有四维,可以表示为 (Batch_size, Row, Column, Channel) 以下将要从二维数组到四维数组进行代码的简单说明: Tips: 1) 在nu ...
- 利用Tengine在树莓派上跑深度学习网络
树莓派是国内比较流行的一款卡片式计算机,但是受限于其硬件配置,用树莓派玩深度学习似乎有些艰难.最近OPENAI为嵌入式设备推出了一款AI框架Tengine,其对于配置的要求相比传统框架降低了很多,我尝 ...
- <深度学习优化策略-3> 深度学习网络加速器Weight Normalization_WN
前面我们学习过深度学习中用于加速网络训练.提升网络泛化能力的两种策略:Batch Normalization(Batch Normalization)和Layer Normalization(LN). ...
- 训练深度学习网络时候,出现Nan是什么原因,怎么才能避免?——我自己是因为data有nan的坏数据,clear下解决
from:https://www.zhihu.com/question/49346370 Harick 梯度爆炸了吧. 我的解决办法一般以下几条:1.数据归一化(减均值,除方差,或者加入n ...
- 【神经网络与深度学习】chainer边运行边定义的方法使构建深度学习网络变的灵活简单
Chainer是一个专门为高效研究和开发深度学习算法而设计的开源框架. 这篇博文会通过一些例子简要地介绍一下Chainer,同时把它与其他一些框架做比较,比如Caffe.Theano.Torch和Te ...
- 寻找下一款Prisma APP:深度学习在图像处理中的应用探讨(阅读小结)
原文链接:https://yq.aliyun.com/articles/61941?spm=5176.100239.bloglist.64.UPL8ec 某会议中的一篇演讲,主要讲述深度学习在图像领域 ...
- 自己动手实现深度学习框架-7 RNN层--GRU, LSTM
目标 这个阶段会给cute-dl添加循环层,使之能够支持RNN--循环神经网络. 具体目标包括: 添加激活函数sigmoid, tanh. 添加GRU(Gate Recurrent U ...
- caffe深度学习网络(.prototxt)在线可视化工具:Netscope Editor
http://ethereon.github.io/netscope/#/editor 网址:http://ethereon.github.io/netscope/#/editor 将.prototx ...
- 深度学习网络压缩模型方法总结(model compression)
两派 1. 新的卷机计算方法 这种是直接提出新的卷机计算方式,从而减少参数,达到压缩模型的效果,例如SqueezedNet,mobileNet SqueezeNet: AlexNet-level ac ...
随机推荐
- 【前端词典】这些功能其实不需要 JS,CSS 就能搞定
前言 今天我们大家介绍一些你可能乍一眼以为一定需要 JavaScript 才能完成的功能,其实 CSS 就能完成,甚至更加简单. 内容已经发布在 gitHub 了,欢迎围观 Star,更多文章都在 g ...
- effective-java学习笔记---使用限定通配符来增加 API 的灵活性31
在你的 API 中使用通配符类型,虽然棘手,但使得 API 更加灵活. 如果编写一个将被广泛使用的类库,正确使用通配符类型应该被认为是强制性的. 记住基本规则: producer-extends, c ...
- Python第十章-模块和包
模块和包 我们以前的代码都是写在一个文件中, 而且代码也比较短. 假设我们现在要写一个大的系统, 不可能把代码只写到一个文件中, 迫切想把代码写到不同的文件中, 并且能够在一个文件使用另一个文件中代码 ...
- iOS 缩小 ipa 大小
一.爱奇艺 爱奇艺移动应用优化之路:如何让崩溃率小于千分之二 iOS8 对于 App 的 text 段有 60MB 的限制: 超过 200MB 的 App 需要连接 WIFI 下载(之前是 150MB ...
- Reactor模式和Proactor模式
Reactor 主线程往epoll内核事件表中注册socket上的读就绪事件 主线程调用epoll_wait等待socket上有数据可读 当socket上有数据可读时,epoll_wait通知主线程, ...
- Jupyter修改主题,字体,字号-教程
cmd控制台安装主题工具包:主题更换工具详解 pip install --upgrade jupyterthemes 查看可用主题: jt -l 设定主题: jt -t 主题名称 恢复默认主题: jt ...
- Java 对象容器
一.ArrayList 容器 1.记事本 package booknote; import java.util.ArrayList; public class NoteBook { private A ...
- Spring-Cloud-Netflix-Eureka注册中心
TOC 概述 eureka是Netflix的子模块之一,也是一个核心的模块 eureka里有2个组件: 一个是EurekaServer(一个独立的项目) 这个是用于定位服务以实现中间层服务器的负载平衡 ...
- js之for与forEach循环的区别
回武汉打卡第四天,武汉加油,逆战必胜!今天咱们探讨一下for循环和forEach()循环的区别. 首先,for循环在最开始执行循环的时候,会建立一个循环变量i,之后每次循环都是操作这个变量,也就是说它 ...
- Pointer Lock API(3/3):一个Demo
简单的Demo演练 点击跳转至Code Pen以查看演示和源码 完整代码 <!DOCTYPE HTML> <html lang="en-US"> <h ...