自己搭建一个 Tiny Decoder(带 Mask),参考 Transformer Encoder 的结构,并添加 Masked Multi-Head Self-Attention,它是 Decoder 的核心特征之一。


1. 背景与动机

Transformer 架构已成为自然语言处理(NLP)领域的主流。其 Encoder-Decoder 结构广泛应用于机器翻译、文本生成等任务。Decoder 的核心特征是 Masked Multi-Head Self-Attention,它保证了自回归生成时不会"偷看"未来信息。本文将带你从零实现一个最小可运行的 Tiny Decoder,并深入理解其原理。


2. Tiny Decoder 架构简述

一个标准 Transformer Decoder Layer 包括:

  1. Masked Multi-Head Self-Attention
  2. Encoder-Decoder Attention(跨注意力)
  3. Feed Forward Network (FFN)
  4. LayerNorm + Residual Connection

为了简化,我们暂时不引入 Encoder-Decoder Attention,只聚焦于:

Masked Self-Attention + FFN


3. 什么是 Masked Attention?

Masked Attention 的作用是在 Decoder 生成序列时,禁止看到"未来"的 token,防止信息泄露。

用一个 Mask 矩阵来实现,例如:

Mask for length 4:
[[0, -inf, -inf, -inf],
[0, 0, -inf, -inf],
[0, 0, 0, -inf],
[0, 0, 0, 0]]

这个 Mask 会加在 Attention 的 logits 上(即 QKᵗ / sqrt(dk)),将不允许的位置置为 -inf,softmax 之后就是 0。


4. Tiny Decoder 核心代码(简化 PyTorch 实现)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math # 带掩码的多头自注意力机制
class MaskedSelfAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0 # 保证可以均分到每个头 self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个头的维度 # 用一个线性层同时生成 Q、K、V
self.qkv_proj = nn.Linear(d_model, 3 * d_model)
# 输出投影
self.out_proj = nn.Linear(d_model, d_model) def forward(self, x):
# x: (batch, seq_len, d_model)
B, T, C = x.size()
# 生成 Q、K、V,并分头
qkv = self.qkv_proj(x).reshape(B, T, 3, self.num_heads, self.d_k).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # (B, heads, T, d_k) # 计算注意力分数 (QK^T / sqrt(d_k))
attn_logits = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k) # (B, heads, T, T) # 构造下三角 Mask,防止看到未来信息
mask = torch.tril(torch.ones(T, T)).to(x.device)
attn_logits = attn_logits.masked_fill(mask == 0, float('-inf')) # softmax 得到注意力权重
attn = F.softmax(attn_logits, dim=-1)
# 加权求和得到输出
out = attn @ v # (B, heads, T, d_k) # 合并多头
out = out.transpose(1, 2).contiguous().reshape(B, T, C)
# 输出投影
return self.out_proj(out) # 前馈神经网络
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff):
super().__init__()
# 两层全连接+ReLU
self.ff = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
) def forward(self, x):
# 前馈变换
return self.ff(x) # Tiny Decoder 层,包含 Masked Self-Attention 和前馈网络
class TinyDecoderLayer(nn.Module):
def __init__(self, d_model=128, num_heads=4, d_ff=512):
super().__init__()
self.self_attn = MaskedSelfAttention(d_model, num_heads) # 掩码自注意力
self.ff = FeedForward(d_model, d_ff) # 前馈网络
self.norm1 = nn.LayerNorm(d_model) # 层归一化1
self.norm2 = nn.LayerNorm(d_model) # 层归一化2 def forward(self, x):
# x: (batch, seq_len, d_model)
# 先归一化,再做自注意力,并加残差
x = x + self.self_attn(self.norm1(x))
# 再归一化,前馈网络,并加残差
x = x + self.ff(self.norm2(x))
return x

5. 使用示例

x = torch.randn(2, 10, 128)        # Decoder输入
context = torch.randn(2, 15, 128) # Encoder输出
decoder = TinyDecoderLayer()
y = decoder(x, context) # output shape: (2, 10, 128)

6. 进阶扩展

6.1 添加 Encoder-Decoder Attention

Encoder-Decoder Attention 允许 Decoder 在生成时参考 Encoder 的输出(即源语言信息),是机器翻译等任务的关键。其实现方式与 Self-Attention 类似,只是 Q 来自 Decoder,K/V 来自 Encoder。

伪代码:

class CrossAttention(nn.Module):
def __init__(self, d_model, num_heads):
# ...同 MaskedSelfAttention ...
def forward(self, x, context):
# x: (B, T_dec, d_model), context: (B, T_enc, d_model)
# Q from x, K/V from context
# ...实现...

在 Decoder Layer 中插入:

self.cross_attn = CrossAttention(d_model, num_heads)
# forward:
x = x + self.cross_attn(self.norm_cross(x), context)

6.2 多层 Decoder 堆叠

实际应用中,Decoder 通常由多层堆叠而成:

class TinyDecoder(nn.Module):
def __init__(self, num_layers, d_model, num_heads, d_ff):
super().__init__()
self.layers = nn.ModuleList([
TinyDecoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)
])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x

6.3 加入 Positional Encoding

Transformer 不具备序列顺序感知能力,需加上 Positional Encoding:

class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:x.size(1)]

7. 完整训练例子(伪代码)

# 假设有输入数据 input_seq, target_seq
x = embedding(input_seq)
x = pos_encoding(x)
decoder = TinyDecoder(num_layers=2, d_model=128, num_heads=4, d_ff=512)
output = decoder(x)
# 计算 loss, 反向传播

8. 小结

  • Decoder 的关键是 Masked Self-Attention,通过 tril 的下三角掩码防止泄漏未来信息。
  • 可以用 torch.tril 快速构造下三角 Mask。
  • Decoder 层和 Encoder 类似,但注意力机制加了 Mask,而且通常会多出 Encoder-Decoder Attention。
  • 可扩展为多层、加入位置编码、跨注意力等,逐步构建完整的 Transformer Decoder。

    *如果不加 Mask,允许 Decoder 看到未来 token,会导致模型训练"作弊",推理时表现极差,生成文本质量低下,模型失去实际应用价值。因此,Masked Self-Attention 是保证自回归生成和模型泛化能力的关键机制。

9. 参考资料

第9.2讲、Tiny Decoder(带 Mask)详解与实战的更多相关文章

  1. 第三节:带你详解Java的操作符,控制流程以及数组

    前言 大家好,给大家带来带你详解Java的操作符,控制流程以及数组的概述,希望你们喜欢 操作符 算数操作符 一般的 +,-,*,/,还有两个自增 自减 ,以及一个取模 % 操作符. 这里的操作算法,一 ...

  2. MySQL5.6的4个自带库详解

    MySQL5.6的4个自带库详解 1.information_schema详细介绍: information_schema数据库是MySQL自带的,它提供了访问数据库元数据的方式.什么是元数据呢?元数 ...

  3. IntelliJ IDEA 快捷键说明大全(中英对照、带图示详解)

    因为觉得网络上的 idea 快捷键不够详尽,所以特别编写了此篇文章,方便大家使用 idea O(∩_∩)O~ 其中的英文说明来自于 idea 的官网资料,中文说明主要来自于自己的领会和理解,英文说明只 ...

  4. Java线程池带图详解

    线程池作为Java中一个重要的知识点,看了很多文章,在此以Java自带的线程池为例,记录分析一下.本文参考了Java并发编程:线程池的使用.Java线程池---addWorker方法解析.线程池.Th ...

  5. Java自带命令详解

    1. 背景 给一个系统定位问题的时候,知识.经验是关键基础,数据(运行日志.异常堆栈.GC日志.线程快照[threaddump / javacore文件].堆转储快照[heapdump / hprof ...

  6. OPENGL_三角形带GL_TRIANGLE_STRIP详解

    使用三角形带原因:减少顶点传递,渲染时api向显卡传输数据量是瓶颈,用较好的传递方法传递一个三角形最少可以少于一个点. 点的顺序根据奇数,偶数不一样的原因:保持所有三角形法线在同一方向. 原文:htt ...

  7. L2-014. 列车调度(带图详解)

    L2-014. 列车调度   火车站的列车调度铁轨的结构如下图所示. Figure 两端分别是一条入口(Entrance)轨道和一条出口(Exit)轨道,它们之间有N条平行的轨道.每趟列车从入口可以选 ...

  8. Sprite Atlas与Sprite Mask详解

    https://www.sohu.com/a/169409304_280780 Unity 2017.1正式发布后,带来了一批能帮助大家更加简化工作流的新功能.今天这篇文章,将由Unity技术经理成亮 ...

  9. java中带继承类的加载顺序详解及实战

    一.背景: 在面试中,在java基础方面,类的加载顺序经常被问及,很多时候我们是搞不清楚到底类的加载顺序是怎么样的,那么今天我们就来看看带有继承的类的加载顺序到底是怎么一回事?在此记下也方便以后复习巩 ...

  10. Android 自带 camera 详解

    在本文中 需要考虑的问题 概述 Manifest声明 使用内置的摄像头应用程序 捕获图像的intent 捕获视频的intent 接收摄像头intent的结果 创建摄像头应用程序 检测摄像头硬件 访问摄 ...

随机推荐

  1. rabbitmq的基本使用

    使用MQ的三大作用:1.同步变异步2.流量削峰3.解耦降低服务间的耦合性要不要使用MQ,需不需要使用MQ依据项目的需要做选择. 使用场景: 例如:注册用户时候,发送激活邮件.监控应用中抛出的异常,邮件 ...

  2. Redis集群(cluster模式)搭建(三主三从)

    上一篇搭建了一主二从,并加入了哨兵,任何一个节点挂掉都不影响正常使用,实现了高可用.仍然存在一个问题,一主二从每个节点都存储着全部数据,随着业务庞大,数据量会超过节点容量,即便是redis可以配置清理 ...

  3. git clone加速

    使用github的镜像网站进行访问,github.com.cnpmjs.org,我们将原本的网站中的github.com 进行替换.

  4. Content-Encoding:br 是一种什么编码格式?

    一.前言 在之前测试HTTP应答的压缩过程中无意间发现在Google浏览器下出现了 Content-Encoding:br 这种的编码格式,当时我就纳闷了,前面不是一直在研究GZip压缩吗?br压缩又 ...

  5. Nginx 配置 HTTPS 完整过程

    配置站点使用 https,并且将 http 重定向至 https. 1. nginx 的 ssl 模块安装 查看 nginx 是否安装 http_ssl_module 模块. $ /usr/local ...

  6. BundleFusion+WIN11+VS2019 + CUDA11.7环境配置

    BundleFusion+WIN11+VS2019环境配置 Step1 一开始会提示你重定解决方案,点是即可,如果点错了,也可以在这里再点一次: 简要记录一下环境的配置过程,刚下载下来BundleFu ...

  7. 【python-数据分析】pandas数据提取

    import pandas as pd 1. 直接索引 df = pd.DataFrame({'AdmissionDate': ['2021-01-25','2021-01-22','2021-01- ...

  8. Microsoft.NETCore.App 版本不一致导致的运行失败

    场景重现 今天新建了一个 ASP.NET Core 的项目, 通过 Web Deploy 顺利发布到IIS上后, 但访问时出现如下异常: 异常原因 通过手动执行dotnet命令发现运行框架版本不一致? ...

  9. Linux 关机与重启命令

    关机命令 我们可以使用以下三种命令来关机 Linux : 1.立刻关机(需要root用户) shutdown -h now 10 分钟后自动关机 shutdown -h 10 2.立刻关机 halt ...

  10. springboot将vo生成文件到目录

    依赖 org.springframework spring-mock 2.0.8 com.alibaba fastjson 1.2.62 service实现 public RestResponseBo ...