刚从图像处理的hole中攀爬出来,刚走一步竟掉到了另一个hole(fire in the hole*▽*)

1.RNN中的attention
pytorch官方教程:https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
首先,RNN的输入大小都是(1,1,hidden_size),即batch=1,seq_len=1,hidden_size=embed_size,相对于传统的encoder-decoder模型,attention机制仅在decoder处有所不同。下面具体看看:
1>保存了rnn每个词向量对应隐藏层的输出状态(encoder_outputs),用于decoder的attention机制

#train代码部分
for ei in range(input_length):
encoder_output, encoder_hidden = encoder(
input_tensor[ei], encoder_hidden)
encoder_outputs[ei] = encoder_output[0, 0]
1
2
3
4
5
2>AttnDecoderRNN的forward
1.输入的input经过embed

embedded = self.embedding(input).view(1, 1, -1)
embedded = self.dropout(embedded)
1
2
2.获取关于输入的attention权重,这里的Q=decoder_rnn的input,K=decoder_rnn的隐藏元
2.1求Q和K相似度的方法有很多,这里让全连接层自己来学习,把embedded和hidden连接在一起经过fc层(部分修改了下)

similarity=self.attn(torch.cat((embedded[0], hidden[0]), 1))
1
2.2 经过softmax获得归一化的权重

attn_weights = F.softmax(similarity, dim=1)
1
3.权重应用于encoder输出的所有词对应的词向量上(对应相乘即可)->获得attention结果

attn_applied = torch.bmm(attn_weights.unsqueeze(0),encoder_outputs.unsqueeze(0))
1
4.把attention结果和decoder的输入cat在一起,使用1个全连接层来融合二者,最终生成带注意力机制的词向量

output = torch.cat((embedded[0], attn_applied[0]), 1)
output = self.attn_combine(output).unsqueeze(0)
1
2
5.根据decoder的上一个输出单词来预测下一个单词,这里多插一句,decoder的首个输入为起始标志符’sos’,其根据encode最后的隐藏元来预测第一个单词,后面依次类推。

output = F.relu(output)
output, hidden = self.gru(output, hidden)
output = F.log_softmax(self.out(output[0]), dim=1)
return output, hidden, attn_weights
1
2
3
4
2.transformer中的attention
“Attention is All You Need”(霸气标题),pytorch代码推荐2篇:
哈佛大学NLP研究组:http://nlp.seas.harvard.edu/2018/04/03/attention.html
台湾小哥的代码(较通俗):https://github.com/jadore801120/attention-is-all-you-need-pytorch:

下面以soft_attention为例(*input和output的attention,仅和self_attention做下区分,第1篇代码标记src_attn,第2篇代码标记dec_enc_attn),soft_attention的目标:给定序列Q(query,长度记为lq,维度dk),键序列K(key,长度记为lk,维度dk),值序列V(value,长度记为lv,维度dv),计算Q和K的相似度权重,最后再乘上V。

下面直接贴上attention-is-all-you-need-pytorch中MultiHeadAttention代码

def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
residual = q
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
#这里把batch和分块数放在一起,便于使用bmm
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv

mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)

output = output.view(n_head, sz_b, len_q, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)

output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output, attn
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
和RNN中的attention的不同,这里的batch_size和seq_len均不为1,其把序列视为一个整体,求Q和V的相似度可使用点乘(V可以视为上面提及的encoder_outputs),获得的是一个相似度矩阵,比如Q是一个长度为10的序列,K是一个长度为16的序列,其相似度矩阵就是一个10*16的矩阵,再如矩阵第一行表示Q的第一个单词和K序列所有单词的相似度。
similarity:=(lq,dk)∗(dk,lk)=(lq,lk) similarity:=(lq,dk)*(dk,lk)=(lq,lk)
similarity:=(lq,dk)∗(dk,lk)=(lq,lk)

然后,生成带注意力机制的词向量(通常K和V取相同的值,因而有lv=lk),另外上面整合attn_applied和input使用的是cat操作,而这里使用的是残差(类似于unet和resnet),最后使用PositionwiseFeedForward(2个fc层)来融合attn_applied和input,最终生成带注意力机制的词向量。
attention_applied=(lq,lk)∗(lv,dv)=(lq,dv) attention\_applied=(lq,lk)*(lv,dv)=(lq,dv)
attention_applied=(lq,lk)∗(lv,dv)=(lq,dv)

细节部分
在数据预处理部分,对序列s都进行了首尾标记,比如s=’’+ s + ‘’,刚看transform(之前跳过了seq2seq),对下面的代码甚是不解

decoder_input=target_seq[:, :-1] #这里不是去掉终止标记<eos>,去掉的可能是padding_0,只为兼容target_ground_y的序列长度?
encoder_input=input_seq[:, 1:] #encoder的输入序列去掉了起始标记<sos>
target_ground_y= target_seqtrg[:, 1:] #用于计算模型loss的target,去掉了起始标记<sos>
1
2
3
其实在pytorch官方教程中说的比较清楚,看下图

encoder的输入序列和ground_true只需要一个终止符即可,而decoder的输入序列开始必须指定一个起始符,让其根据context预测输出序列的第一个单词,后面根据前一个单词再预测下一个单词,依次类推直到当前预测的单词为终止标记’eos’,才计算loss.
---------------------
作者:PJ-Javis
来源:CSDN
原文:https://blog.csdn.net/jiangpeng59/article/details/84859640
版权声明:本文为博主原创文章,转载请附上博文链接!

pytorch笔记:09)Attention机制的更多相关文章

  1. Multimodal —— 看图说话(Image Caption)任务的论文笔记(三)引入视觉哨兵的自适应attention机制

    在此前的两篇博客中所介绍的两个论文,分别介绍了encoder-decoder框架以及引入attention之后在Image Caption任务上的应用. 这篇博客所介绍的文章所考虑的是生成captio ...

  2. Multimodal —— 看图说话(Image Caption)任务的论文笔记(二)引入attention机制

    在上一篇博客中介绍的论文"Show and tell"所提出的NIC模型采用的是最"简单"的encoder-decoder框架,模型上没有什么新花样,使用CNN ...

  3. 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)

    [说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...

  4. 【学习笔记】注意力机制(Attention)

    前言 这一章看啥视频都不好使,啃书就完事儿了,当然了我也没有感觉自己学的特别扎实,不过好歹是有一定的了解了 注意力机制 由于之前的卷积之类的神经网络,选取卷积中最大的那个数,实际上这种行为是没有目的的 ...

  5. DL4NLP —— seq2seq+attention机制的应用:文档自动摘要(Automatic Text Summarization)

    两周以前读了些文档自动摘要的论文,并针对其中两篇( [2] 和 [3] )做了presentation.下面把相关内容简单整理一下. 文本自动摘要(Automatic Text Summarizati ...

  6. 论文笔记:Attention Is All You Need

    Attention Is All You Need 2018-04-17 10:35:25  Paper:http://papers.nips.cc/paper/7181-attention-is-a ...

  7. [NLP/Attention]关于attention机制在nlp中的应用总结

    原文链接: https://blog.csdn.net/qq_41058526/article/details/80578932 attention 总结 参考:注意力机制(Attention Mec ...

  8. Java:并发笔记-09

    Java:并发笔记-09 说明:这是看了 bilibili 上 黑马程序员 的课程 java并发编程 后做的笔记 7. 共享模型之工具-2 原理:AQS 原理 对于 AQS 的原理这部分内容,没很好的 ...

  9. Mongodb源代码阅读笔记:Journal机制

    Mongodb源代码阅读笔记:Journal机制 Mongodb源代码阅读笔记:Journal机制 涉及的文件 一些说明 PREPLOGBUFFER WRITETOJOURNAL WRITETODAT ...

随机推荐

  1. 淘淘商城项目_同步索引库问题分析 + ActiveMQ介绍/安装/使用 + ActiveMQ整合spring + 使用ActiveMQ实现添加商品后同步索引库_匠心笔记

    文章目录 1.同步索引库问题分析 2.ActiveM的介绍 2.1.什么是ActiveMQ 2.2.ActiveMQ的消息形式 3.ActiveMQ的安装 3.1.安装环境 3.2.安装步骤 4.Ac ...

  2. 关于eclipse部署项目后,在tomcat中的webapps文件夹下没有项目

    转自:https://blog.csdn.net/yang505581644/article/details/78802316 一.发现问题 在eclipse中新建Dynamic Web Projec ...

  3. 图的遍历---------开始开始-------o(∩_∩)o 哈哈

    图的遍历 深度优先搜索(Depth First Search , DFS) --深度优先搜索--我的理解就是分身术的另一种实现方法---用分身术将所有能看到的路都走一遍----这就是深度搜索--- 下 ...

  4. springMVC RedirectAttributes

    @Controller public class TestController { @RequestMapping("/redirectDemo") public String r ...

  5. javascript---DOM大编程2

    编程挑战 现在利用之前我们学过的JavaScript知识,实现选项卡切换的效果. 效果图: 文字素材: 房产: 275万购昌平邻铁三居 总价20万买一居    200万内购五环三居 140万安家东三环 ...

  6. Python3中使用urllib的方法详解(header,代理,超时,认证,异常处理)_python

    我们可以利用urllib来抓取远程的数据进行保存哦,以下是python3 抓取网页资源的多种方法,有需要的可以参考借鉴. 1.最简单 import urllib.request response = ...

  7. 2015湘潭市第七届大学生程序设计竞赛 —— Fraction

    题目大意: 小数化分数,但是分母限制在[1,1000],很明显的枚举,但是在赛场上的时候傻逼了,无论怎么枚举,怎么二分就是wa,wa到死···········. (ps:我要给出题人寄刀片~~~~), ...

  8. MyEclipse常用设置记录

    MyEclipse版本:MyEclipse 2014 Blue版本. 设置内容: 1.内存优化 <MyEclipse_ROOT>/myeclipse-blue.ini文件 主要修改-vma ...

  9. Jmeter+Jenkins+Ant自动化集成环境搭建

    搭建环境: JDK:jdk1.8.0_92 Ant:apache-ant-1.9.7 Jmeter: apache-jmeter-3.0 Jenkins:jenkins-2.19.3 具体环境配置 1 ...

  10. ssm(Spring、Springmvc、Mybatis)实战之淘淘商城-第五天(非原创)

    文章大纲 一.课程介绍二.前台系统(门户系统)搭建介绍三.前台系统(门户系统)搭建实战四.js请求跨域解决五.项目源码与资料下载六.参考文章   一.课程介绍 一共14天课程(1)第一天:电商行业的背 ...