前言

在训练深度学习模型时,常想一窥网络结构中的attention层权重分布,观察序列输入的哪些词或者词组合是网络比较care的。在小论文中主要研究了关于词性POS对输入序列的注意力机制。同时对比实验采取的是words的self-attention机制。

效果

下图主要包含两列:word_attention是self-attention机制的模型训练结果,POS_attention是词性模型的训练结果。

可以看出,相对于word_attention,POS的注意力机制不仅能够捕捉到评价的aspect,也能根据aspect关联的词借助情感语义表达的词性分布,care到相关词性的情感词。

核心代码

可视化样例

# coding: utf-8
def highlight(word, attn):
html_color = '#%02X%02X%02X' % (255, int(255*(1 - attn)), int(255*(1 - attn)))
return '<span style="background-color: {}">{}</span>'.format(html_color, word) def mk_html(seq, attns):
html = ""
for ix, attn in zip(seq, attns):
html += ' ' + highlight(
ix,
attn
)
return html + "<br>" from IPython.display import HTML, display
batch_size = 1
seqs = [["这", "是", "一个", "测试", "样例", "而已"]]
attns = [[0.01, 0.19, 0.12, 0.7, 0.2, 0.1]] for i in range(batch_size):
text = mk_html(seqs[i], attns[i])
display(HTML(text))

接入model

需要在model的返回列表中,添加attention_weight的输出,理论上维度应该和输入序列的长度是一致的。

# load model
import torch
# if you train on gpu, you need to move onto cpu
model = torch.load("../docs/model_chk/2018-11-07-02:45:37", map_location=lambda storage, location: storage) from torch.autograd import Variable
for batch_idx, samples in enumerate(test_loader, 0):
v_word = Variable(samples['word_vec'])
v_final_label = samples['top_label'] model.eval()
final_probs, att_weight = model(v_word, v_pos) batch_words = toWords(samples["word_vec"].numpy(), idx_word) # id转化为word
batch_att = getAtten(batch_words, att_weight.data.numpy()) # 去除padding词,根据words的长度截取attention
labels = toLabel(samples['top_label'].numpy()) # 真实标签
pre_labels = toLabel(final_probs.data.numpy() >= 0.5) # 预测标签 for i in range(len(batch_words)):
text = mk_html(batch_words[i], batch_att[i])
print(labels[i], pre_labels[i])
display(HTML(text))

总结

  • 建议把可视化独立出来,用jupyter-notebook编辑,方便分段调试和copy;同时因为是借助html渲染的,所以需要notebook
  • 项目代码我后期后同步到github上,欢迎一起交流

如何可视化深度学习网络中Attention层的更多相关文章

  1. 深度学习网络中numpy多维数组的说明

    目前在计算机视觉中应用的数组维度最多有四维,可以表示为 (Batch_size, Row, Column, Channel) 以下将要从二维数组到四维数组进行代码的简单说明: Tips: 1) 在nu ...

  2. 利用Tengine在树莓派上跑深度学习网络

    树莓派是国内比较流行的一款卡片式计算机,但是受限于其硬件配置,用树莓派玩深度学习似乎有些艰难.最近OPENAI为嵌入式设备推出了一款AI框架Tengine,其对于配置的要求相比传统框架降低了很多,我尝 ...

  3. <深度学习优化策略-3> 深度学习网络加速器Weight Normalization_WN

    前面我们学习过深度学习中用于加速网络训练.提升网络泛化能力的两种策略:Batch Normalization(Batch Normalization)和Layer Normalization(LN). ...

  4. 训练深度学习网络时候,出现Nan是什么原因,怎么才能避免?——我自己是因为data有nan的坏数据,clear下解决

    from:https://www.zhihu.com/question/49346370   Harick     梯度爆炸了吧. 我的解决办法一般以下几条:1.数据归一化(减均值,除方差,或者加入n ...

  5. 【神经网络与深度学习】chainer边运行边定义的方法使构建深度学习网络变的灵活简单

    Chainer是一个专门为高效研究和开发深度学习算法而设计的开源框架. 这篇博文会通过一些例子简要地介绍一下Chainer,同时把它与其他一些框架做比较,比如Caffe.Theano.Torch和Te ...

  6. 寻找下一款Prisma APP:深度学习在图像处理中的应用探讨(阅读小结)

    原文链接:https://yq.aliyun.com/articles/61941?spm=5176.100239.bloglist.64.UPL8ec 某会议中的一篇演讲,主要讲述深度学习在图像领域 ...

  7. 自己动手实现深度学习框架-7 RNN层--GRU, LSTM

    目标         这个阶段会给cute-dl添加循环层,使之能够支持RNN--循环神经网络. 具体目标包括: 添加激活函数sigmoid, tanh. 添加GRU(Gate Recurrent U ...

  8. caffe深度学习网络(.prototxt)在线可视化工具:Netscope Editor

    http://ethereon.github.io/netscope/#/editor 网址:http://ethereon.github.io/netscope/#/editor 将.prototx ...

  9. 深度学习网络压缩模型方法总结(model compression)

    两派 1. 新的卷机计算方法 这种是直接提出新的卷机计算方式,从而减少参数,达到压缩模型的效果,例如SqueezedNet,mobileNet SqueezeNet: AlexNet-level ac ...

随机推荐

  1. python学习之BeautifulSoup模块爬图

    BeautifulSoup模块爬图学习HTML文本解析标签定位网上教程多是爬mzitu,此网站反爬限制多了.随意找了个网址,解析速度有些慢.脚本流程:首页获取总页数-->拼接每页URL--> ...

  2. BIT-Count of Range Sum

    2019-12-17 18:56:56 问题描述: 问题求解: 本题个人感觉还是很有难度的,主要的难点在于如何将题目转化为bit计数问题. 首先构建一个presum数组,这个没有问题. 需要对于任意一 ...

  3. Worktile正式入驻飞书,助力企业轻松实现敏捷开发与协作

    企业在敏捷研发中时常面临着交付延期.需求不匹配等问题,如何更高效地完成敏捷研发? Worktile携手飞书,为企业用户提供敏捷开发服务,帮助企业实现软件项目的需求管理.缺陷追踪.迭代规划与推进以及效能 ...

  4. SpringBoot 集成Web

    1,静态资源访问: 在我们开发Web应用的时候,需要引用大量的js.css.图片等静态资源. 默认配置 Spring Boot默认提供静态资源目录位置需置于classpath下,目录名需符合如下规则: ...

  5. thinkphp5源码剖析系列1-类的自动加载机制

    前言 tp5想必大家都不陌生,但是大部分人都停留在应用的层面,我将开启系列随笔,深入剖析tp5源码,以供大家顺利进阶.本章将从类的自动加载讲起,自动加载是tp框架的灵魂所在,也是成熟php框架的必备功 ...

  6. 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST

    目录 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST 下载数据集 加载数据集 构建神经网络 反向传播(BP)算法 进行预测 F1验证 总结 参考 数据挖掘入门系 ...

  7. while实现2-3+4-5+6...+100 的和

    while实现2-3+4-5+6...+100 的和 可以看到规律为2-100内所有奇数都为减法,偶数为加法 设定变量 total=0: count=2 当count为偶数时与total相加,反则相减 ...

  8. HIT软件构造课程3.4总结(Object-Oriented Programming )

    上一节学习了ADT理论,这一节学习ADT的具体实现:OOP 1.基本概念:对象,类,属性,方法 对象 对象是状态和行为的捆绑.java中,状态=成员变量,行为=方法. 类 每个对象都定义了一个类,类定 ...

  9. Python 【基础面试题】

    前言 面试题仅做学习参考,学习者阅后也要用心钻研其中的原理,重要知识需要系统学习.透彻学习,形成自己的知识链.以下五点建议希望对您有帮助,早日拿到一份心仪的offer. 做好细节工作,细致的人运气不会 ...

  10. IOS部分APP使用burpsuite抓不到包原因

    曾经在ios12的时候,iphone通过安装burpsuite的ca证书并开启授权,还可以抓到包,升级到ios13后部分app又回到以前连上代理就断网的情况. 分析:ios(13)+burpsuite ...