03 Transformer 中的多头注意力(Multi-Head Attention)Pytorch代码实现
3:20 来个赞
24:43 弹幕,是否懂了

QKV 相乘(QKV 同源),QK 相乘得到相似度A,AV 相乘得到注意力值 Z
- 第一步实现一个自注意力机制

自注意力计算
def self_attention(query, key, value, dropout=None, mask=None):
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# mask的操作在QK之后,softmax之前
if mask is not None:
mask.cuda()
scores = scores.masked_fill(mask == 0, -1e9)
self_attn = F.softmax(scores, dim=-1)
if dropout is not None:
self_attn = dropout(self_attn)
return torch.matmul(self_attn, value), self_attn
多头注意力
# PYthon/PYtorch/你看的这个模型的理论
class MultiHeadAttention(nn.Module):
def __init__(self):
super(MultiHeadAttention, self).__init__()
def forward(self, head, d_model, query, key, value, dropout=0.1,mask=None):
"""
:param head: 头数,默认 8
:param d_model: 输入的维度 512
:param query: Q
:param key: K
:param value: V
:param dropout:
:param mask:
:return:
"""
assert (d_model % head == 0)
self.d_k = d_model // head
self.head = head
self.d_model = d_model
self.linear_query = nn.Linear(d_model, d_model)
self.linear_key = nn.Linear(d_model, d_model)
self.linear_value = nn.Linear(d_model, d_model)
# 自注意力机制的 QKV 同源,线性变换
self.linear_out = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(p=dropout)
self.attn = None
# if mask is not None:
# # 多头注意力机制的线性变换层是4维,是把query[batch, frame_num, d_model]变成[batch, -1, head, d_k]
# # 再1,2维交换变成[batch, head, -1, d_k], 所以mask要在第一维添加一维,与后面的self attention计算维度一样
# mask = mask.unsqueeze(1)
n_batch = query.size(0)
# 多头需要对这个 X 切分成多头
# query==key==value
# [b,1,512]
# [b,8,1,64]
# [b,32,512]
# [b,8,32,64]
query = self.linear_query(query).view(n_batch, -1, self.head, self.d_k).transpose(1, 2) # [b, 8, 32, 64]
key = self.linear_key(key).view(n_batch, -1, self.head, self.d_k).transpose(1, 2) # [b, 8, 32, 64]
value = self.linear_value(value).view(n_batch, -1, self.head, self.d_k).transpose(1, 2) # [b, 8, 32, 64]
x, self.attn = self_attention(query, key, value, dropout=self.dropout, mask=mask)
# [b,8,32,64]
# [b,32,512]
# 变为三维, 或者说是concat head
x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.head * self.d_k)
return self.linear_out(x)
03 Transformer 中的多头注意力(Multi-Head Attention)Pytorch代码实现的更多相关文章
- 【译】在Transformer中加入相对位置信息
目录 引言 动机 解决方案 概览 注释 实现 高效实现 结果 结论 参考文献 本文翻译自How Self-Attention with Relative Position Representation ...
- 【注意力机制】Attention Augmented Convolutional Networks
注意力机制之Attention Augmented Convolutional Networks 原始链接:https://www.yuque.com/lart/papers/aaconv 核心内容 ...
- 【学习笔记】注意力机制(Attention)
前言 这一章看啥视频都不好使,啃书就完事儿了,当然了我也没有感觉自己学的特别扎实,不过好歹是有一定的了解了 注意力机制 由于之前的卷积之类的神经网络,选取卷积中最大的那个数,实际上这种行为是没有目的的 ...
- Transformer可解释性:注意力机制注意到了什么?
原创作者 | FLPPED 论文: Self-Attention Attribution: Interpreting Information Interactions Inside Transform ...
- 第五课第四周笔记3:Multi-Head Attention多头注意力
Multi-Head Attention多头注意力 让我们进入并了解多头注意力机制. 符号变得有点复杂,但要记住的事情基本上只是你在上一个视频中学到的自我注意机制的四个大循环. 让我们看一下每次计算自 ...
- AAAI2018中的自注意力机制(Self-attention Mechanism)
近年来,注意力(Attention)机制被广泛应用到基于深度学习的自然语言处理(NLP)各个任务中.随着注意力机制的深入研究,各式各样的attention被研究者们提出,如单个.多个.交互式等等.去年 ...
- 深入理解BERT Transformer ,不仅仅是注意力机制
来源商业新知网,原标题:深入理解BERT Transformer ,不仅仅是注意力机制 BERT是google最近提出的一个自然语言处理模型,它在许多任务 检测上表现非常好. 如:问答.自然语言推断和 ...
- Transformer中引用iqd作为数据源的时候数据预览出现乱码
在cognos开发利用transform建模的过程中导入iqd数据源预览乱码问题,下面先描述一下环境 操作系统版本: [root@enfo212 ~]# cat /proc/version Linux ...
- ICCV2021 | Vision Transformer中相对位置编码的反思与改进
前言 在计算机视觉中,相对位置编码的有效性还没有得到很好的研究,甚至仍然存在争议,本文分析了相对位置编码中的几个关键因素,提出了一种新的针对2D图像的相对位置编码方法,称为图像RPE(IRPE). ...
- 如何诊断RAC系统中的'gc cr multi block request'?
'gc cr multi block request' 是RAC数据库上比较常见的一种等待事件,在RAC 上进行全表扫描(Full Table Scan)或者全索引扫描(Index Fast Full ...
随机推荐
- 【MySQL】二级MySQL考试 救场帮助表
周六去考二级,应用第一题就是添加外键约束 草,写了半天老说语法不对,然后急中生智,觉得默认的库里应该有文档说明表 以下是SQL查询过程: -- 猜测是在mysql库里面 mysql> USE m ...
- 人形机器人的AI技术 —— 将一个大问题拆解为若干个小问题
前文: 人形机器人 -- Figure 01机器人亮相 | OpenAI多模态能力加持 | 与人类流畅对话交互 | 具身智能的GPT-4时刻 所需的AI技术: 人形机器人的软件层面其实有: 视觉模块/ ...
- [工具分享]ClipX超级粘贴板,超级好用
1.背景 话说粘贴.复制是码农们的必备核心技能, 普通码农们当然已经熟练的掌握了普通的粘贴复制.... 但是,你不知道的是,牛逼的架构师已经会使用超级粘贴板了,功能非常强大 ............. ...
- mybatis-plus系统化学习之配置精讲
1.背景 mybatis-plus给出了很多配置, 大部分的配置使用默认的就可以了, 但是还是有很多需要的配置比如: # mybatis-plus相关配置 mybatis-plus: # xml扫描, ...
- 使用 Apache DolphinScheduler 进行 EMR 任务调度
By AWS Team 前言 随着企业规模的扩大,业务数据的激增,我们会使用 Hadoop/Spark 框架来处理大量数据的 ETL/聚合分析作业,⽽这些作业将需要由统一的作业调度平台去定时调度. 在 ...
- QT中TreeWidget树控件的使用
关于Item Widgets中Tree Widget的使用方法! TreeWidget树控件的使用 创建列表头, 该控件有什么属性 QStringList header_list; header_li ...
- 金融、支付行业的开发者不得不知道的float、double计算误差问题
为什么浮点数 float 或 double 运算的时候会有精度丢失的风险呢? <阿里巴巴 Java 开发手册>中提到:"浮点数之间的等值判断,基本数据类型不能用 == 来比较,包 ...
- 线性dp:LeetCode674. 最长连续递增序列
LeetCode674. 最长连续递增序列 阅读本文之前,需要先了解"动态规划方法论",这在我的文章以前有讲过 链接:动态规划方法论 本文之前也讲过一篇文章:最长递增子序列,这道题 ...
- (八)Redis 主从复制、切片集群
一.主从复制 1.主从关系 都说的 Redis 具有高可靠性,这里有两层含义:一是数据尽量少丢失,二是服务尽量少中断.AOF 和 RDB 保证了前者,而对于后者,Redis 的做法就是将一份数据同时保 ...
- 千万级别mysql 分库分表后表分页查询优化方案初探
随着使用的用户群体越来越多,表数据也会随着时间的推移,单表的数据量会越来越大. 以订单表为例,假如每天的订单量在 4 万左右,那么一个月的订单量就是 120 多万,一年就是 1400 多万,随着年数的 ...