第9.2讲、Tiny Decoder(带 Mask)详解与实战
自己搭建一个 Tiny Decoder(带 Mask),参考 Transformer Encoder 的结构,并添加 Masked Multi-Head Self-Attention,它是 Decoder 的核心特征之一。
1. 背景与动机
Transformer 架构已成为自然语言处理(NLP)领域的主流。其 Encoder-Decoder 结构广泛应用于机器翻译、文本生成等任务。Decoder 的核心特征是 Masked Multi-Head Self-Attention,它保证了自回归生成时不会"偷看"未来信息。本文将带你从零实现一个最小可运行的 Tiny Decoder,并深入理解其原理。
2. Tiny Decoder 架构简述
一个标准 Transformer Decoder Layer 包括:
- Masked Multi-Head Self-Attention
- Encoder-Decoder Attention(跨注意力)
- Feed Forward Network (FFN)
- LayerNorm + Residual Connection

为了简化,我们暂时不引入 Encoder-Decoder Attention,只聚焦于:
Masked Self-Attention + FFN
3. 什么是 Masked Attention?
Masked Attention 的作用是在 Decoder 生成序列时,禁止看到"未来"的 token,防止信息泄露。
用一个 Mask 矩阵来实现,例如:
Mask for length 4:
[[0, -inf, -inf, -inf],
[0, 0, -inf, -inf],
[0, 0, 0, -inf],
[0, 0, 0, 0]]
这个 Mask 会加在 Attention 的 logits 上(即 QKᵗ / sqrt(dk)),将不允许的位置置为 -inf,softmax 之后就是 0。
4. Tiny Decoder 核心代码(简化 PyTorch 实现)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# 带掩码的多头自注意力机制
class MaskedSelfAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0 # 保证可以均分到每个头
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个头的维度
# 用一个线性层同时生成 Q、K、V
self.qkv_proj = nn.Linear(d_model, 3 * d_model)
# 输出投影
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x):
# x: (batch, seq_len, d_model)
B, T, C = x.size()
# 生成 Q、K、V,并分头
qkv = self.qkv_proj(x).reshape(B, T, 3, self.num_heads, self.d_k).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # (B, heads, T, d_k)
# 计算注意力分数 (QK^T / sqrt(d_k))
attn_logits = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k) # (B, heads, T, T)
# 构造下三角 Mask,防止看到未来信息
mask = torch.tril(torch.ones(T, T)).to(x.device)
attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))
# softmax 得到注意力权重
attn = F.softmax(attn_logits, dim=-1)
# 加权求和得到输出
out = attn @ v # (B, heads, T, d_k)
# 合并多头
out = out.transpose(1, 2).contiguous().reshape(B, T, C)
# 输出投影
return self.out_proj(out)
# 前馈神经网络
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff):
super().__init__()
# 两层全连接+ReLU
self.ff = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
def forward(self, x):
# 前馈变换
return self.ff(x)
# Tiny Decoder 层,包含 Masked Self-Attention 和前馈网络
class TinyDecoderLayer(nn.Module):
def __init__(self, d_model=128, num_heads=4, d_ff=512):
super().__init__()
self.self_attn = MaskedSelfAttention(d_model, num_heads) # 掩码自注意力
self.ff = FeedForward(d_model, d_ff) # 前馈网络
self.norm1 = nn.LayerNorm(d_model) # 层归一化1
self.norm2 = nn.LayerNorm(d_model) # 层归一化2
def forward(self, x):
# x: (batch, seq_len, d_model)
# 先归一化,再做自注意力,并加残差
x = x + self.self_attn(self.norm1(x))
# 再归一化,前馈网络,并加残差
x = x + self.ff(self.norm2(x))
return x
5. 使用示例
x = torch.randn(2, 10, 128) # Decoder输入
context = torch.randn(2, 15, 128) # Encoder输出
decoder = TinyDecoderLayer()
y = decoder(x, context) # output shape: (2, 10, 128)
6. 进阶扩展
6.1 添加 Encoder-Decoder Attention
Encoder-Decoder Attention 允许 Decoder 在生成时参考 Encoder 的输出(即源语言信息),是机器翻译等任务的关键。其实现方式与 Self-Attention 类似,只是 Q 来自 Decoder,K/V 来自 Encoder。
伪代码:
class CrossAttention(nn.Module):
def __init__(self, d_model, num_heads):
# ...同 MaskedSelfAttention ...
def forward(self, x, context):
# x: (B, T_dec, d_model), context: (B, T_enc, d_model)
# Q from x, K/V from context
# ...实现...
在 Decoder Layer 中插入:
self.cross_attn = CrossAttention(d_model, num_heads)
# forward:
x = x + self.cross_attn(self.norm_cross(x), context)
6.2 多层 Decoder 堆叠
实际应用中,Decoder 通常由多层堆叠而成:
class TinyDecoder(nn.Module):
def __init__(self, num_layers, d_model, num_heads, d_ff):
super().__init__()
self.layers = nn.ModuleList([
TinyDecoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)
])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
6.3 加入 Positional Encoding
Transformer 不具备序列顺序感知能力,需加上 Positional Encoding:
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:x.size(1)]
7. 完整训练例子(伪代码)
# 假设有输入数据 input_seq, target_seq
x = embedding(input_seq)
x = pos_encoding(x)
decoder = TinyDecoder(num_layers=2, d_model=128, num_heads=4, d_ff=512)
output = decoder(x)
# 计算 loss, 反向传播
8. 小结
- Decoder 的关键是 Masked Self-Attention,通过
tril的下三角掩码防止泄漏未来信息。 - 可以用
torch.tril快速构造下三角 Mask。 - Decoder 层和 Encoder 类似,但注意力机制加了 Mask,而且通常会多出 Encoder-Decoder Attention。
- 可扩展为多层、加入位置编码、跨注意力等,逐步构建完整的 Transformer Decoder。
*如果不加 Mask,允许 Decoder 看到未来 token,会导致模型训练"作弊",推理时表现极差,生成文本质量低下,模型失去实际应用价值。因此,Masked Self-Attention 是保证自回归生成和模型泛化能力的关键机制。
9. 参考资料
第9.2讲、Tiny Decoder(带 Mask)详解与实战的更多相关文章
- 第三节:带你详解Java的操作符,控制流程以及数组
前言 大家好,给大家带来带你详解Java的操作符,控制流程以及数组的概述,希望你们喜欢 操作符 算数操作符 一般的 +,-,*,/,还有两个自增 自减 ,以及一个取模 % 操作符. 这里的操作算法,一 ...
- MySQL5.6的4个自带库详解
MySQL5.6的4个自带库详解 1.information_schema详细介绍: information_schema数据库是MySQL自带的,它提供了访问数据库元数据的方式.什么是元数据呢?元数 ...
- IntelliJ IDEA 快捷键说明大全(中英对照、带图示详解)
因为觉得网络上的 idea 快捷键不够详尽,所以特别编写了此篇文章,方便大家使用 idea O(∩_∩)O~ 其中的英文说明来自于 idea 的官网资料,中文说明主要来自于自己的领会和理解,英文说明只 ...
- Java线程池带图详解
线程池作为Java中一个重要的知识点,看了很多文章,在此以Java自带的线程池为例,记录分析一下.本文参考了Java并发编程:线程池的使用.Java线程池---addWorker方法解析.线程池.Th ...
- Java自带命令详解
1. 背景 给一个系统定位问题的时候,知识.经验是关键基础,数据(运行日志.异常堆栈.GC日志.线程快照[threaddump / javacore文件].堆转储快照[heapdump / hprof ...
- OPENGL_三角形带GL_TRIANGLE_STRIP详解
使用三角形带原因:减少顶点传递,渲染时api向显卡传输数据量是瓶颈,用较好的传递方法传递一个三角形最少可以少于一个点. 点的顺序根据奇数,偶数不一样的原因:保持所有三角形法线在同一方向. 原文:htt ...
- L2-014. 列车调度(带图详解)
L2-014. 列车调度 火车站的列车调度铁轨的结构如下图所示. Figure 两端分别是一条入口(Entrance)轨道和一条出口(Exit)轨道,它们之间有N条平行的轨道.每趟列车从入口可以选 ...
- Sprite Atlas与Sprite Mask详解
https://www.sohu.com/a/169409304_280780 Unity 2017.1正式发布后,带来了一批能帮助大家更加简化工作流的新功能.今天这篇文章,将由Unity技术经理成亮 ...
- java中带继承类的加载顺序详解及实战
一.背景: 在面试中,在java基础方面,类的加载顺序经常被问及,很多时候我们是搞不清楚到底类的加载顺序是怎么样的,那么今天我们就来看看带有继承的类的加载顺序到底是怎么一回事?在此记下也方便以后复习巩 ...
- Android 自带 camera 详解
在本文中 需要考虑的问题 概述 Manifest声明 使用内置的摄像头应用程序 捕获图像的intent 捕获视频的intent 接收摄像头intent的结果 创建摄像头应用程序 检测摄像头硬件 访问摄 ...
随机推荐
- 三分钟掌握音视频处理 | 在 Rust 中优雅地使用 FFmpeg
前言 音视频处理看似高深莫测,但在开发中,我们或多或少都会遇到相关需求,比如视频格式转换.剪辑.添加水印.音频提取等. FFmpeg 作为行业标准,几乎无所不能,很多流行的软件(如 VLC.YouTu ...
- 【vscode】vscode配置Java
[vscode]vscode配置Java 前言 配环境,需要记录,避免反复踩坑. 步骤 step1:官网走 配环境为什么不直接上官网教程,Visual Studio Code - Co ...
- linux clickhouse 密码设置
默认密码 clickhouse 安装好之后,系统默认的登录账号密码是 /etc/clickhouse-server/users.d/default-password.xml 文件中配置的,默认密码是 ...
- go 判断数组下标是否存在
举例 现在需要判断命令行是否传了参数,即 os.Args[1] 是否存在 如果使用下述的判断: func main() { fmt.Println(os.Args[1]) } 会报错:index ou ...
- Golang 入门 : 转换
Go中数学运算和比较运算要求包含的值具有相同的类型.如果不是的话,则在尝试运行代码时会报错. 为变量分配新值也是如此.如果所赋值的类型与变量的声明类型不匹配,也会报错. 解决方法是使用转换,它允许你将 ...
- harbor
一篇带你了解私有仓库 Harbor 的搭建 一.Harbor简介 虽然Docker官方提供了公共的镜像仓库,但是从安全和效率等方面考虑,部署我们私有环境内的Registry也是非常必要的. Harbo ...
- Momentum Contrast for Unsupervised Visual Representation Learning论文精读
目录 Birth of MoCo Supervised Learning Contrastive Learning MoCo Dictionary Limits of the early learni ...
- 比较 HashSet、LinkedHashSet 和 TreeSet 三者的异同
比较 HashSet.LinkedHashSet 和 TreeSet 三者的异同HashSet.LinkedHashSet 和 TreeSet 都是 Set 接口的实现类,都能保证元素唯一,并且都不是 ...
- static修饰成员变量的特点及static修饰成员变量内存图解-java se进阶 day01
1.static介绍 static是静态的意思,它可以用于修饰成员变量和成员方法 2.static的特点 1.被static修饰了的成员变量,可以被类中的所有对象所共享 虽然stu02没有给schoo ...
- 【Linux】5.4 Shell数组
Shell数组 数组中可以存放多个值.Bash Shell 只支持一维数组(不支持多维数组),初始化时不需要定义数组大小(与 PHP 类似). 1. 数组赋值 与大部分编程语言类似,数组元素的下标由0 ...