为什么要有 Multi-Head Attention?

单个 Attention 机制虽然可以捕捉句子中不同词之间的关系,但它只能关注一种角度或模式

Multi-Head 的作用是:

多个头 = 多个视角同时观察序列的不同关系

例如:

  • 一个头可能专注主语和动词的关系;
  • 另一个头可能专注宾语和介词;
  • 还有的可能学习句法结构或时态变化。

这些头的表示最终会被拼接(concatenate)后再线性变换整合成更丰富的上下文表示。

技术深入:Multi-Head Attention 计算过程

Multi-Head Attention 的计算过程如下:

  1. 对输入 X 进行线性变换得到 Q、K、V 矩阵
  2. 将 Q、K、V 分割成 h 个头
  3. 每个头独立计算 Attention
  4. 拼接所有头的输出
  5. 最后进行一次线性变换
# 伪代码实现
def multi_head_attention(X, h=8):
# 线性变换获得 Q, K, V
Q = X @ W_q # [batch_size, seq_len, d_model]
K = X @ W_k
V = X @ W_v # 分割成多头
Q_heads = split_heads(Q, h) # [batch_size, h, seq_len, d_k]
K_heads = split_heads(K, h)
V_heads = split_heads(V, h) # 每个头独立计算 attention
attn_outputs = []
for i in range(h):
attn_output = scaled_dot_product_attention(
Q_heads[:, i], K_heads[:, i], V_heads[:, i]
)
attn_outputs.append(attn_output) # 拼接所有头的输出
concat_output = concatenate(attn_outputs) # [batch_size, seq_len, d_model] # 最后的线性变换
output = concat_output @ W_o return output

如何判断多少个头(h)?

Transformer 默认将 d_model(模型维度)均分给每个头。

设:

  • d_model = 512:模型的总嵌入维度
  • h = 8:头数

那么每个头的维度为:

d_k = d_model // h = 512 // 8 = 64

一般要求:

d_model 必须能被 h 整除。

参数计算

Multi-Head Attention 中的参数量:

  • 输入投影矩阵:3 × (d_model × d_model) = 3d_model²
  • 输出投影矩阵:d_model × d_model = d_model²

总参数量:4 × d_model²

例如,当 d_model = 512 时,参数量约为 100 万。


头的数量怎么选?

头数 h 每头维度 d_k 适用情境
1 全部 基线,最弱(没多视角)
4 中等 小模型,如 tiny Transformer
8 64 标准配置,如原始 Transformer
16 更细粒度 大模型中常见,如 BERT-large

实际训练中:

  • 小任务(toy 或翻译教学):用 2 或 4 个头就够了。
  • 真实 NLP 任务:建议使用 8 个头(Transformer-base 规范)。
  • 太多头而模型参数不足时,效果可能反而下降(每头维度太小)。

头数与性能关系

研究表明,头数与模型性能并非简单的线性关系:

  • 头数过少:无法捕捉多种语言模式
  • 头数适中:性能最佳
  • 头数过多:每个头的维度变小,表达能力下降

实验发现

Michel et al. (2019) 的研究《Are Sixteen Heads Really Better than One?》发现:

  1. 在训练好的模型中,并非所有头都同等重要
  2. 大多数情况下,可以剪枝掉一部分头而不显著影响性能
  3. 不同层的头有不同的作用,底层头和顶层头往往更为重要

Multi-Head Attention 的优势

  1. 并行计算:所有头可以并行计算,提高训练效率
  2. 多角度表示:捕捉不同类型的依赖关系
  3. 信息冗余:多头提供冗余信息,增强模型鲁棒性
  4. 注意力分散:防止单一头过度关注某些模式

总结一句话

Multi-Head 的本质是多角度捕捉词与词的关系,提升模型对上下文的理解能力。头数越多,观察角度越多,但每个头的维度会减小,需注意平衡。


Attention 可视化

不同头学习到的注意力模式各不相同。以下是一个英语句子在 8 头注意力机制下的可视化示例:

可以看到:

  • 头1:关注相邻词的关系
  • 头2:捕捉主语-谓语关系
  • 头3:识别句法结构
  • 头4:连接相关实体
  • 其他头:各自专注于不同的语言特征

这种多角度的观察使得 Transformer 能够全面理解文本的语义和结构。


️ Streamlit 交互式可视化案例

想要直观地理解 Multi-Head Attention?以下是一个使用 Streamlit 构建的交互式可视化案例,让你可以实时探索不同头的注意力模式:

import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertModel # 页面设置
st.set_page_config(page_title="Multi-Head Attention 可视化", layout="wide")
st.title("Multi-Head Attention 可视化工具") # 加载预训练模型
@st.cache_resource
def load_model():
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertModel.from_pretrained('bert-base-chinese', output_attentions=True)
return tokenizer, model tokenizer, model = load_model() # 用户输入
user_input = st.text_area("请输入一段文本进行分析:",
"Transformer是一种强大的神经网络架构,它使用了Multi-Head Attention机制。",
height=100) # 处理文本
if user_input:
# 分词并获取注意力权重
inputs = tokenizer(user_input, return_tensors="pt")
outputs = model(**inputs) # 获取所有层的注意力权重
attentions = outputs.attentions # tuple of tensors, one per layer # 选择层
layer_idx = st.slider("选择Transformer层:", 0, len(attentions)-1, 0) # 获取选定层的注意力权重
layer_attentions = attentions[layer_idx].detach().numpy() # 获取头数
num_heads = layer_attentions.shape[1] # 选择头
head_idx = st.slider("选择注意力头:", 0, num_heads-1, 0) # 获取选定头的注意力权重
head_attention = layer_attentions[0, head_idx] # 获取标记
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # 可视化
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(head_attention,
xticklabels=tokens,
yticklabels=tokens,
cmap="YlGnBu",
ax=ax)
plt.title(f"第 {layer_idx+1} 层,第 {head_idx+1} 个头的注意力权重")
st.pyplot(fig) # 显示注意力模式分析
st.subheader("注意力模式分析") # 计算每个词的平均注意力
avg_attention = head_attention.mean(axis=0)
top_indices = np.argsort(avg_attention)[-3:][::-1] st.write("这个注意力头主要关注的词:")
for idx in top_indices:
st.write(f"- {tokens[idx]}: {avg_attention[idx]:.4f}") # 添加交互式功能
if st.checkbox("显示所有头的对比"):
st.subheader("所有头的注意力对比") # 为每个头创建一个小型热力图
# 计算行列数以适应任意数量的头
num_cols = 4
num_rows = (num_heads + num_cols - 1) // num_cols # 向上取整
fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 3*num_rows))
axes = axes.flatten() for h in range(num_heads):
sns.heatmap(layer_attentions[0, h],
xticklabels=[] if h < (num_heads-num_cols) else tokens,
yticklabels=[] if h % num_cols != 0 else tokens,
cmap="YlGnBu",
ax=axes[h])
axes[h].set_title(f"头 {h+1}") # 隐藏未使用的子图
for h in range(num_heads, len(axes)):
axes[h].axis('off') plt.tight_layout()
st.pyplot(fig) # 添加解释
st.markdown("""
### 如何解读这个可视化: - 颜色越深表示注意力权重越高
- 纵轴代表查询词(当前词)
- 横轴代表键词(被关注的词)
- 每个头学习不同的关注模式 通过调整滑块,你可以探索不同层和不同头的注意力模式,观察模型如何理解文本中的关系。
""") # 运行说明
st.sidebar.markdown("""
## 使用说明 1. 在文本框中输入你想分析的文本
2. 使用滑块选择要查看的层和注意力头
3. 查看热力图了解词与词之间的注意力关系
4. 勾选"显示所有头的对比"可以同时查看所有头的模式 这个工具帮助你直观理解 Multi-Head Attention 的工作原理和不同头的功能分工。
""") # 代码说明
with st.expander("查看完整代码实现"):
st.code("""
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertModel # 页面设置
st.set_page_config(page_title="Multi-Head Attention 可视化", layout="wide")
st.title("Multi-Head Attention 可视化工具") # 加载预训练模型
@st.cache_resource
def load_model():
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertModel.from_pretrained('bert-base-chinese', output_attentions=True)
return tokenizer, model tokenizer, model = load_model() # 用户输入和可视化逻辑
# ...此处省略,与上面代码相同
""") ### 如何运行这个可视化工具 1. 安装必要的依赖:
```bash
pip install streamlit torch transformers matplotlib seaborn
  1. 将上面的代码保存为 attention_viz.py

  2. 运行 Streamlit 应用:

streamlit run attention_viz.py



这个交互式工具让你可以:

  • 输入任意文本并查看注意力分布
  • 选择不同的 Transformer 层和注意力头
  • 直观对比不同头学习到的不同模式
  • 分析哪些词获得了最高的注意力权重

通过这个可视化工具,你可以亲自探索 Multi-Head Attention 的工作原理,加深对这一机制的理解。

第8讲、Multi-Head Attention 的核心机制与实现细节的更多相关文章

  1. MFC六大核心机制之二:运行时类型识别(RTTI)

    上一节讲的是MFC六大核心机制之一:MFC程序的初始化,本节继续讲解MFC六大核心机制之二:运行时类型识别(RTTI). typeid运算子 运行时类型识别(RTTI)即是程序执行过程中知道某个对象属 ...

  2. multi lstm attention 坑一个

    multi lstm attention时序之间,inputs维度是1024,加上attention之后维度是2018,输出1024,时序之间下次再转成2048的inputs 但是如果使用multi ...

  3. MFC六大核心机制之一:MFC程序的初始化

    很多做软件开发的人都有一种对事情刨根问底的精神,例如我们一直在用的MFC,很方便,不用学太多原理性的知识就可以做出各种窗口程序,但喜欢钻研的朋友肯定想知道,到底微软帮我们做了些什么,让我们在它的框架下 ...

  4. Qt核心机制与原理

    转:  https://blog.csdn.net/light_in_dark/article/details/64125085 ★了解Qt和C++的关系 ★掌握Qt的信号/槽机制的原理和使用方法 ★ ...

  5. Spark大数据处理 之 从WordCount看Spark大数据处理的核心机制(1)

    大数据处理肯定是分布式的了,那就面临着几个核心问题:可扩展性,负载均衡,容错处理.Spark是如何处理这些问题的呢?接着上一篇的"动手写WordCount",今天要做的就是透过这个 ...

  6. Qt核心机制和原理

    转:http://blog.csdn.net/light_in_dark/article/details/64125085 ★了解Qt和C++的关系 ★掌握Qt的信号/槽机制的原理和使用方法 ★了解Q ...

  7. Spark大数据处理 之 从WordCount看Spark大数据处理的核心机制(2)

    在上一篇文章中,我们讲了Spark大数据处理的可扩展性和负载均衡,今天要讲的是更为重点的容错处理,这涉及到Spark的应用场景和RDD的设计来源. Spark的应用场景 Spark主要针对两种场景: ...

  8. 如何优雅的写UI——(1)MFC六大核心机制-程序初始化

    很多做软件开发的人都有一种对事情刨根问底的精神,例如我们一直在用的MFC,很方便,不用学太多原理性的知识就可以做出各种窗口程序,但喜欢钻研的朋友肯定想知道,到底微软帮我们做了些什么,让我们在它的框架下 ...

  9. MFC六大核心机制

    MFC六大核心机制概述 我们选择了C++,主要是因为它够艺术.够自由,使用它我们可以实现各种想法,而MFC将多种可灵活使用的功能封装起来,我们岂能忍受这种“黑盒”操作?于是研究分析MFC的核心机制成为 ...

  10. JAVA基础之两种核心机制

    突然之间需要学习Java,学校里学的东西早就忘记了,得用最短的时间把Java知识理顺,重点还是J2EE,毕竟所有的ava项目中95%都是J2EE,还是先从基础的J2SE学起吧....... 首先是了解 ...

随机推荐

  1. windows 配置java发布环境

    一.jdk安装 1.下载jdk安装文件 2.在"系统变量"下"新建"选项"JAVA_HOME"值为:"jdk"文件夹路径 ...

  2. php全文搜索代码

    在PHP中实现全文搜索,你可以使用多种方法,具体取决于你的数据存储方式和需求.如果你的数据存储在MySQL数据库中,你可以利用MySQL的全文搜索功能(FULLTEXT).如果你需要更复杂的搜索功能, ...

  3. Docker镜像的内部机制

    Docker镜像的内部机制 镜像就是一个打包文件,里面包含了应用程序还有它运行所依赖的环境,例如文件系统.环境变量.配置参数等等. 环境变量.配置参数这些东西还是比较简单的,随便用一个 manifes ...

  4. Go new函数 例子解析答疑

    package main import "fmt" func main() { p1 :=new(int) *p1 =1 fmt.Println("p1",p1 ...

  5. linux安装protoc

    protobuf 是做什么的? 专业的解答: Protocol Buffers 是一种轻便高效的结构化数据存储格式,可用于结构化数据串行化,很适合做数据存储或 RPC 数据交换格式.它可用于通讯协议. ...

  6. 基础命令:dd、tar、ln、find、逻辑符号、alisa别名、md5sun校验、lrzsz文件上传下载、wget

    目录 3.0 dd读取.转换并输出数据 3.1 压缩 (tar.zip).解压缩(tar xf.unzip) 3.2 ln软硬链接 3.2.1 软链接: 3.2.2 硬链接: 3.3 find文件查找 ...

  7. 区块链特辑——solidity语言基础(二)

    Solidity语法基础学习 四.函数类型: 函数 Function function FnName [V] [SM] [return (--)] {} ·[V]:Visibility,可见性: ·[ ...

  8. 【Linux】3.7 定时任务调度

    3.7定时任务调度 1. 任务调度原理 crond任务调度:crontab进行定时任务调度 使用方法:crontab [选项] crontab [选项] -e:编辑crontab定时任务 -i:查询c ...

  9. 【SpringMVC】使用 @RequestMapping 映射请求

    使用 @RequestMapping 映射请求 Spring MVC 使用 @RequestMapping 注解为控制器指定可以处理哪些 URL 请求 在控制器的类定义及方法定义处都可标注 @Requ ...

  10. nodejs中使用websockets

    websockets介绍 websockets这个新协议为客户端提供了一个更快.更有效的通信线路.像HTTP一样,websockets运行在TCP连接之上,但是它们更快,因为我们不必每次都打开一个新的 ...