Scaled Dot-Product Attention是Transformer架构的核心组件,也是现代深度学习中最重要的注意力机制之一。本文将从原理、实现和应用三个方面深入剖析这一机制。

1. 基本原理

Scaled Dot-Product Attention的本质是一种加权求和机制,通过计算查询(Query)与键(Key)的相似度来确定对值(Value)的关注程度。其数学表达式为:

这个公式包含几个关键步骤:

  1. 计算相似度:通过点积(dot product)计算Query和Key的相似度,得到注意力分数(attention scores)
  2. 缩放(Scaling):将点积结果除以$\sqrt{d_k}$进行缩放,其中$d_k$是Key的维度
  3. 应用Mask(可选):在某些情况下(如自回归生成)需要遮盖未来信息
  4. Softmax归一化:将注意力分数通过softmax转换为概率分布
  5. 加权求和:用这些概率对Value进行加权求和

2. 为什么需要缩放(Scaling)?

缩放是Scaled Dot-Product Attention区别于普通Dot-Product Attention的关键。当输入的维度$d_k$较大时,点积的方差也会变大,导致softmax函数梯度变得极小(梯度消失问题)。通过除以$\sqrt{d_k}$,可以将方差控制在合理范围内。

假设Query和Key的各个分量是均值为0、方差为1的独立随机变量,则它们点积的方差为$d_k$。通过除以$\sqrt{d_k}$,可以将方差归一化为1。

3. 代码实现解析

让我们看看PyTorch中Scaled Dot-Product Attention的典型实现:

def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
# 获取key的维度
d_k = query.size(-1) # 计算注意力分数并缩放
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # 应用mask(如果提供)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf')) # 应用softmax得到注意力权重
attn = F.softmax(scores, dim=-1) # 应用dropout(如果提供)
if dropout is not None:
attn = dropout(attn) # 加权求和
return torch.matmul(attn, value), attn

这个函数接受query、key、value三个张量作为输入,可选的mask用于遮盖某些位置,dropout用于正则化。

4. 张量维度分析

假设输入的形状为:

  • Query: [batch_size, seq_len_q, d_k]
  • Key: [batch_size, seq_len_k, d_k]
  • Value: [batch_size, seq_len_k, d_v]

计算过程中各步骤的维度变化:

  1. Key转置后: [batch_size, d_k, seq_len_k]
  2. Query与Key的点积: [batch_size, seq_len_q, seq_len_k]
  3. Softmax后的注意力权重: [batch_size, seq_len_q, seq_len_k]
  4. 最终输出: [batch_size, seq_len_q, d_v]

5. 在Multi-Head Attention中的应用

Scaled Dot-Product Attention是Multi-Head Attention的基础。在Multi-Head Attention中,我们将输入投影到多个子空间,在每个子空间独立计算注意力,然后将结果合并:

class MultiHeadAttention(nn.Module):
def __init__(self, h, d_model, dropout=0.1):
super().__init__()
assert d_model % h == 0
self.d_k = d_model // h
self.h = h
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(dropout) def forward(self, query, key, value, mask=None):
if mask is not None:
mask = mask.unsqueeze(1) nbatches = query.size(0) # 1) 投影并分割成多头
query, key, value = [
l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
for l, x in zip(self.linears, (query, key, value))
] # 2) 应用注意力机制
x, self.attn = scaled_dot_product_attention(query, key, value, mask, self.dropout) # 3) 合并多头结果
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
return self.linears[-1](x)

6. 实际应用场景

Scaled Dot-Product Attention在多种场景下表现出色:

  1. 自然语言处理:捕捉句子中词与词之间的依赖关系
  2. 计算机视觉:关注图像中的重要区域
  3. 推荐系统:建模用户与物品之间的交互
  4. 语音处理:捕捉音频信号中的时序依赖

7. 优势与局限性

优势

  • 计算效率高(可以通过矩阵乘法并行计算)
  • 能够捕捉长距离依赖关系
  • 模型可解释性强(可以可视化注意力权重)

局限性

  • 计算复杂度为O(n²),对于长序列计算开销大
  • 没有考虑位置信息(需要额外的位置编码)
  • 对于某些任务,可能需要结合CNN等结构以捕捉局部特征

8. 总结

Scaled Dot-Product Attention是现代深度学习中的关键创新,通过简单而优雅的设计实现了强大的表达能力。它不仅是Transformer架构的核心,也启发了众多后续工作,如Performer、Linformer等对注意力机制的改进。理解这一机制对于掌握现代深度学习模型至关重要。

通过缩放点积、应用softmax和加权求和这三个简单步骤,Scaled Dot-Product Attention成功地让模型"关注"输入中的重要部分,这也是它能在各种任务中取得卓越表现的关键所在。

9、Scaled Dot-Product Attention应用案例

敬请关注下一篇

第9讲、深入理解Scaled Dot-Product Attention的更多相关文章

  1. [UCSD白板题] Minimum Dot Product

    Problem Introduction The dot product of two sequences \(a_1,a_2,\cdots,a_n\) and \(b_1,b_2,\cdots,b_ ...

  2. Dot Product

    These are vectors: They can be multiplied using the "Dot Product" (also see Cross Product) ...

  3. FB面经Prepare: Dot Product

    Conduct Dot Product of two large Vectors 1. two pointers 2. hashmap 3. 如果没有额外空间,如果一个很大,一个很小,适合scan小的 ...

  4. CUDA Samples: dot product(使用零拷贝内存)

    以下CUDA sample是分别用C++和CUDA实现的点积运算code,CUDA包括普通实现和采用零拷贝内存实现两种,并对其中使用到的CUDA函数进行了解说,code参考了<GPU高性能编程C ...

  5. 向量点积(Dot Product),向量叉积(Cross Product)

    参考的是<游戏和图形学的3D数学入门教程>,非常不错的书,推荐阅读,老外很喜欢把一个东西解释的很详细. 1.向量点积(Dot Product) 向量点积的结果有什么意义?事实上,向量的点积 ...

  6. 理解numpy dot函数

    python代码 x = np.array([[1,3],[1,4]]) y = np.array([[2,2],[3,1]]) print np.dot(x,y) 结果 [[11 5] [14 6] ...

  7. 理解numpy.dot()

    import numpy.matlib import numpy as np a = np.array([[1,2],[3,4]]) b = np.array([[11,12],[13,14]]) p ...

  8. CUDA Samples: Dot Product

    以下CUDA sample是分别用C++和CUDA实现的两个非常大的向量实现点积操作,并对其中使用到的CUDA函数进行了解说,各个文件内容如下: common.hpp: #ifndef FBC_CUD ...

  9. 视觉slam十四讲个人理解(ch7视觉里程计1)

    参考博文::https://blog.csdn.net/david_han008/article/details/53560736 https://blog.csdn.net/n66040927/ar ...

  10. [论文理解] CBAM: Convolutional Block Attention Module

    CBAM: Convolutional Block Attention Module 简介 本文利用attention机制,使得针对网络有了更好的特征表示,这种结构通过支路学习到通道间关系的权重和像素 ...

随机推荐

  1. 记一次.NET内存居高不下排查解决与启示

    前情 我们有个海外的项目,一共70个服务,前前后后花了超过一年时间完成了云服务迁移和架构调整.主要是架构调整了,原来的docker swarm托管服务,新架构改为Kubernetes托管.几台云服务器 ...

  2. 提供一个纯C语言的图像压缩程序,但是要达到将6MB的图片压缩到100KB以内的要求,有损压缩肯定是必须的。同时,要在速度上有所提升,我可以为您提供一个基于多线程的图像压缩程序。

    提供一个纯C语言的图像压缩程序,但是要达到将6MB的图片压缩到100KB以内的要求,有损压缩肯定是必须的.同时,要在速度上有所提升,我可以为您提供一个基于多线程的图像压缩程序. 首先,我们需要了解一下 ...

  3. 分布式锁—6.Redisson的同步器组件

    大纲 1.Redisson的分布式锁简单总结 2.Redisson的Semaphore简介 3.Redisson的Semaphore源码剖析 4.Redisson的CountDownLatch简介 5 ...

  4. 【数值计算方法】线性方程组迭代算法的Python实现

    线性方程组迭代算法的Python实现 jacobi,GS,SOR迭代法 def JacobiIter(A:np.ndarray, b:np.ndarray, tol:float=1e-5, maxIt ...

  5. js 时间转时间戳

    前言 有时候我们用时间插件,选择好时间后,需要把日期格式转化为时间戳,再传到后台 时间转时间戳 let time = Math.floor(new Date("2014-04-23 18:5 ...

  6. PHP的curl获取header信息

    PHP的curl功能十分强大,简单点说,就是一个PHP实现浏览器的基础. 最常用的可能就是抓取远程数据或者向远程POST数据.但是在这个过程中,调试时,可能会有查看header的必要. echo ge ...

  7. [每日算法 - 华为机试] leetcode690. 员工的重要性

    入口 力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台备战技术面试?力扣提供海量技术面试资源,帮助你高效提升编程技能,轻松拿下世界 IT 名企 Dream Offer.https://le ...

  8. javascript for...in

    在JS中我们最常见的循环语句是for循环语句,一个简单的for循环语句如下: for(var i = 0, n = 100; i < n; i++){ // to do somethings . ...

  9. 11. RabbitMQ 消息队列 Federation (Exchange 交换机和 Queue队列) + Shovel 同步的搭建配置

    11. RabbitMQ 消息队列 Federation (Exchange 交换机和 Queue队列) + Shovel 同步的搭建配置 @ 目录 11. RabbitMQ 消息队列 Federat ...

  10. 使用Python建模量子隧穿

    引言 量子隧穿是量子力学中的一个非常有趣且令人神往的现象.在经典物理学中,我们通常认为粒子必须克服一个势垒才能通过它.但是,在量子力学中,粒子有时可以"穿越"一个势垒,即使它的能量 ...