pytorch笔记:09)Attention机制
刚从图像处理的hole中攀爬出来,刚走一步竟掉到了另一个hole(fire in the hole*▽*)
1.RNN中的attention
pytorch官方教程:https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
首先,RNN的输入大小都是(1,1,hidden_size),即batch=1,seq_len=1,hidden_size=embed_size,相对于传统的encoder-decoder模型,attention机制仅在decoder处有所不同。下面具体看看:
1>保存了rnn每个词向量对应隐藏层的输出状态(encoder_outputs),用于decoder的attention机制
#train代码部分
for ei in range(input_length):
encoder_output, encoder_hidden = encoder(
input_tensor[ei], encoder_hidden)
encoder_outputs[ei] = encoder_output[0, 0]
1
2
3
4
5
2>AttnDecoderRNN的forward
1.输入的input经过embed
embedded = self.embedding(input).view(1, 1, -1)
embedded = self.dropout(embedded)
1
2
2.获取关于输入的attention权重,这里的Q=decoder_rnn的input,K=decoder_rnn的隐藏元
2.1求Q和K相似度的方法有很多,这里让全连接层自己来学习,把embedded和hidden连接在一起经过fc层(部分修改了下)
similarity=self.attn(torch.cat((embedded[0], hidden[0]), 1))
1
2.2 经过softmax获得归一化的权重
attn_weights = F.softmax(similarity, dim=1)
1
3.权重应用于encoder输出的所有词对应的词向量上(对应相乘即可)->获得attention结果
attn_applied = torch.bmm(attn_weights.unsqueeze(0),encoder_outputs.unsqueeze(0))
1
4.把attention结果和decoder的输入cat在一起,使用1个全连接层来融合二者,最终生成带注意力机制的词向量
output = torch.cat((embedded[0], attn_applied[0]), 1)
output = self.attn_combine(output).unsqueeze(0)
1
2
5.根据decoder的上一个输出单词来预测下一个单词,这里多插一句,decoder的首个输入为起始标志符’sos’,其根据encode最后的隐藏元来预测第一个单词,后面依次类推。
output = F.relu(output)
output, hidden = self.gru(output, hidden)
output = F.log_softmax(self.out(output[0]), dim=1)
return output, hidden, attn_weights
1
2
3
4
2.transformer中的attention
“Attention is All You Need”(霸气标题),pytorch代码推荐2篇:
哈佛大学NLP研究组:http://nlp.seas.harvard.edu/2018/04/03/attention.html
台湾小哥的代码(较通俗):https://github.com/jadore801120/attention-is-all-you-need-pytorch:
下面以soft_attention为例(*input和output的attention,仅和self_attention做下区分,第1篇代码标记src_attn,第2篇代码标记dec_enc_attn),soft_attention的目标:给定序列Q(query,长度记为lq,维度dk),键序列K(key,长度记为lk,维度dk),值序列V(value,长度记为lv,维度dv),计算Q和K的相似度权重,最后再乘上V。
下面直接贴上attention-is-all-you-need-pytorch中MultiHeadAttention代码
def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
residual = q
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
#这里把batch和分块数放在一起,便于使用bmm
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)
output = output.view(n_head, sz_b, len_q, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output, attn
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
和RNN中的attention的不同,这里的batch_size和seq_len均不为1,其把序列视为一个整体,求Q和V的相似度可使用点乘(V可以视为上面提及的encoder_outputs),获得的是一个相似度矩阵,比如Q是一个长度为10的序列,K是一个长度为16的序列,其相似度矩阵就是一个10*16的矩阵,再如矩阵第一行表示Q的第一个单词和K序列所有单词的相似度。
similarity:=(lq,dk)∗(dk,lk)=(lq,lk) similarity:=(lq,dk)*(dk,lk)=(lq,lk)
similarity:=(lq,dk)∗(dk,lk)=(lq,lk)
然后,生成带注意力机制的词向量(通常K和V取相同的值,因而有lv=lk),另外上面整合attn_applied和input使用的是cat操作,而这里使用的是残差(类似于unet和resnet),最后使用PositionwiseFeedForward(2个fc层)来融合attn_applied和input,最终生成带注意力机制的词向量。
attention_applied=(lq,lk)∗(lv,dv)=(lq,dv) attention\_applied=(lq,lk)*(lv,dv)=(lq,dv)
attention_applied=(lq,lk)∗(lv,dv)=(lq,dv)
细节部分
在数据预处理部分,对序列s都进行了首尾标记,比如s=’’+ s + ‘’,刚看transform(之前跳过了seq2seq),对下面的代码甚是不解
decoder_input=target_seq[:, :-1] #这里不是去掉终止标记<eos>,去掉的可能是padding_0,只为兼容target_ground_y的序列长度?
encoder_input=input_seq[:, 1:] #encoder的输入序列去掉了起始标记<sos>
target_ground_y= target_seqtrg[:, 1:] #用于计算模型loss的target,去掉了起始标记<sos>
1
2
3
其实在pytorch官方教程中说的比较清楚,看下图
encoder的输入序列和ground_true只需要一个终止符即可,而decoder的输入序列开始必须指定一个起始符,让其根据context预测输出序列的第一个单词,后面根据前一个单词再预测下一个单词,依次类推直到当前预测的单词为终止标记’eos’,才计算loss.
---------------------
作者:PJ-Javis
来源:CSDN
原文:https://blog.csdn.net/jiangpeng59/article/details/84859640
版权声明:本文为博主原创文章,转载请附上博文链接!
pytorch笔记:09)Attention机制的更多相关文章
- Multimodal —— 看图说话(Image Caption)任务的论文笔记(三)引入视觉哨兵的自适应attention机制
在此前的两篇博客中所介绍的两个论文,分别介绍了encoder-decoder框架以及引入attention之后在Image Caption任务上的应用. 这篇博客所介绍的文章所考虑的是生成captio ...
- Multimodal —— 看图说话(Image Caption)任务的论文笔记(二)引入attention机制
在上一篇博客中介绍的论文"Show and tell"所提出的NIC模型采用的是最"简单"的encoder-decoder框架,模型上没有什么新花样,使用CNN ...
- 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...
- 【学习笔记】注意力机制(Attention)
前言 这一章看啥视频都不好使,啃书就完事儿了,当然了我也没有感觉自己学的特别扎实,不过好歹是有一定的了解了 注意力机制 由于之前的卷积之类的神经网络,选取卷积中最大的那个数,实际上这种行为是没有目的的 ...
- DL4NLP —— seq2seq+attention机制的应用:文档自动摘要(Automatic Text Summarization)
两周以前读了些文档自动摘要的论文,并针对其中两篇( [2] 和 [3] )做了presentation.下面把相关内容简单整理一下. 文本自动摘要(Automatic Text Summarizati ...
- 论文笔记:Attention Is All You Need
Attention Is All You Need 2018-04-17 10:35:25 Paper:http://papers.nips.cc/paper/7181-attention-is-a ...
- [NLP/Attention]关于attention机制在nlp中的应用总结
原文链接: https://blog.csdn.net/qq_41058526/article/details/80578932 attention 总结 参考:注意力机制(Attention Mec ...
- Java:并发笔记-09
Java:并发笔记-09 说明:这是看了 bilibili 上 黑马程序员 的课程 java并发编程 后做的笔记 7. 共享模型之工具-2 原理:AQS 原理 对于 AQS 的原理这部分内容,没很好的 ...
- Mongodb源代码阅读笔记:Journal机制
Mongodb源代码阅读笔记:Journal机制 Mongodb源代码阅读笔记:Journal机制 涉及的文件 一些说明 PREPLOGBUFFER WRITETOJOURNAL WRITETODAT ...
随机推荐
- 【转载】HTML5自定义data属性
可能大家在使用jquery mobile时,经常会看到data-role.data-theme等的使用,比如:通过如下代码即可实现页眉的效果: [html] <div data-role= ...
- windows server 2003 修改远程链接端口
服务器默认的远程链接的端口是3389,只能内网访问,外网不能访问,现映射了8400端口给服务器,内外网都可以访问,因此需要修改服务器的远程链接的端口. 运行中 输入:regedit 选择十进制,将33 ...
- jsp请求转发小例子(转载)
在服务器端对客户端请求时行转发对其它的对象,如果jsp网页或Servlet 用三个 jsp网页来演示转发: forword1.jsp, 用来提交表单, 将表单内容提交给 forwrod2.jsp, ...
- JAVA 添加、修改和删除PDF书签
当阅读篇幅较长的PDF文档时,为方便我们再次阅读时快速定位到上一次的阅读位置,可以插入一个书签进行标记:此外,对于文档中已有的书签,我们也可以根据需要进行修改或者删除等操作.本篇文章将通过Java编程 ...
- JavaScript 数组相关基础方法
文章来源于:https://www.cnblogs.com/dolphinX/p/3353590.html 创建数组 构造函数 1.无参构造函数,创建一空数组 var a1=new Array(); ...
- linux下的日志压缩脚本
linux下的日志压缩脚本: #!/bin/bash #第一步:先定义项目列表如下: projects="project-a project-b project-c project-d&qu ...
- TIME-April
一转眼四月份又过了三分之一,现在才开始计划自己的四月还真是对自己太过放松了呀!不过前一段时间都在搞学生会的五四评优答辩,索然不是我喜欢的过程,但是结果还比较令人欢喜.翻掉过去的篇章,展开新的一页. 四 ...
- 题解报告:poj 1195 Mobile phones(二维BIT裸题)
Description Suppose that the fourth generation mobile phone base stations in the Tampere area operat ...
- C#中的预编译指令介绍[转]
原文链接 1.#define和#undef 用法: #define DEBUG #undef DEBUG #define告诉编译器,我定义了一个DEBUG的一个符号,他类似一个变量,但是它没有具体的值 ...
- chromedriver与chrome版本对应
今天把手头有的一些关于selenium测试的资源整理了一下,分享出来. 1. 所有版本chrome下载 是不是很难找到老版本的chrome?博主收集了几个下载chrome老版本的网站,其中哪个下载的是 ...