第9.3讲、Tiny Transformer: 极简版Transformer
简介
极简版的 Transformer 编码器-解码器(Seq2Seq)结构,适合用于学习、实验和小型序列到序列(如翻译、摘要)任务。
该实现包含了位置编码、多层编码器、多层解码器、训练与推理流程,代码简洁易懂,便于理解 Transformer 的基本原理。
主要结构
- PositionalEncoding:正弦/余弦位置编码,为输入embedding添加位置信息。
- TransformerEncoderLayerWithTrace:单层编码器,含自注意力和前馈网络。
- TinyTransformer:多层堆叠的编码器。
- TransformerDecoderLayer:单层解码器,含自注意力、交叉注意力和前馈网络。
- TransformerDecoder:多层堆叠的解码器。
- TinyTransformerSeq2Seq:编码器-解码器整体结构。
- Seq2SeqDataset:简单的序列到序列数据集。
- train:训练循环。
- greedy_decode:贪婪解码推理。
- generate_subsequent_mask:生成自回归mask。
依赖环境
- Python 3.7+
- torch >= 1.10
安装 PyTorch(以 CPU 版本为例):
pip install torch
用法示例
1. 构建模型
from demo import TinyTransformerSeq2Seq
src_vocab_size = 1000
trg_vocab_size = 1000
model = TinyTransformerSeq2Seq(src_vocab_size, trg_vocab_size)
2. 构造数据集和 DataLoader
from demo import Seq2SeqDataset
from torch.utils.data import DataLoader
src_data = torch.randint(0, src_vocab_size, (100, 10)) # 100个样本,每个10个token
trg_data = torch.randint(0, trg_vocab_size, (100, 12)) # 100个样本,每个12个token
dataset = Seq2SeqDataset(src_data, trg_data)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
3. 训练模型
import torch.optim as optim
import torch.nn as nn
from demo import train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=0)
epoch_loss = train(model, dataloader, optimizer, criterion, tgt_pad_idx=0)
print(f"Train loss: {epoch_loss}")
4. 推理(贪婪解码)
from demo import greedy_decode
src = torch.randint(0, src_vocab_size, (1, 10)).to(device)
sos_idx = 1 # 假设1为<sos>
eos_idx = 2 # 假设2为<eos>
max_len = 12
output = greedy_decode(model, src, sos_idx, eos_idx, max_len)
print("Output token ids:", output)
注意事项
- 该实现为教学/实验用途,未包含完整的mask、权重初始化、分布式训练等工业级细节。
- 需要自行准备合适的训练数据和词表。
- 若需工业级NLP任务,建议使用 HuggingFace Transformers。
tiny Transformer案例代码1:
import torch
import torch.nn as nn
import torch.optim as optim
import math
from torch.utils.data import DataLoader, Dataset
# =========================
# 位置编码模块
# =========================
class PositionalEncoding(nn.Module):
"""
为输入的embedding添加位置信息,帮助模型捕捉序列顺序。
"""
def __init__(self, d_model, max_len=512):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, max_len, d_model)
self.register_buffer('pe', pe)
def forward(self, x):
# x: (batch, seq_len, d_model)
return x + self.pe[:, :x.size(1)]
# =========================
# 编码器层
# =========================
class TransformerEncoderLayerWithTrace(nn.Module):
"""
单层Transformer编码器,带自注意力和前馈网络。
"""
def __init__(self, d_model, nhead, dim_feedforward):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(0.1)
def forward(self, src, src_mask=None):
# src: (batch, seq_len, d_model)
src2, attn_weights = self.self_attn(src, src, src, attn_mask=src_mask)
src = self.norm1(src + self.dropout(src2))
src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))
src = self.norm2(src + self.dropout(src2))
return src, attn_weights
# =========================
# 编码器
# =========================
class TinyTransformer(nn.Module):
"""
多层堆叠的Transformer编码器。
"""
def __init__(self, src_vocab_size, tgt_vocab_size, d_model, nhead, dim_feedforward, max_len, num_layers):
super().__init__()
self.embedding = nn.Embedding(src_vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, max_len)
self.layers = nn.ModuleList([
TransformerEncoderLayerWithTrace(d_model, nhead, dim_feedforward)
for _ in range(num_layers)
])
def forward(self, src, trace=False):
# src: (batch, seq_len)
src = self.embedding(src)
src = self.pos_encoder(src)
attn_weights_all = []
for layer in self.layers:
src, attn_weights = layer(src)
if trace:
attn_weights_all.append(attn_weights)
return src, attn_weights_all
# =========================
# 解码器层
# =========================
class TransformerDecoderLayer(nn.Module):
"""
单层Transformer解码器,含自注意力、交叉注意力和前馈网络。
"""
def __init__(self, d_model, nhead, dim_feedforward):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(0.1)
def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
# tgt: (batch, tgt_seq_len, d_model)
# memory: (batch, src_seq_len, d_model)
tgt2, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)
tgt = self.norm1(tgt + self.dropout(tgt2))
tgt2, _ = self.cross_attn(tgt, memory, memory, attn_mask=memory_mask)
tgt = self.norm2(tgt + self.dropout(tgt2))
tgt2 = self.linear2(self.dropout(torch.relu(self.linear1(tgt))))
tgt = self.norm3(tgt + self.dropout(tgt2))
return tgt
# =========================
# 解码器
# =========================
class TransformerDecoder(nn.Module):
"""
多层堆叠的Transformer解码器。
"""
def __init__(self, vocab_size, d_model, nhead, dim_feedforward, num_layers, max_len=512):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, max_len)
self.layers = nn.ModuleList([
TransformerDecoderLayer(d_model, nhead, dim_feedforward)
for _ in range(num_layers)
])
self.out_proj = nn.Linear(d_model, vocab_size)
def forward(self, tgt_ids, memory, tgt_mask=None):
# tgt_ids: (batch, tgt_seq_len)
x = self.embedding(tgt_ids)
x = self.pos_encoder(x)
for layer in self.layers:
x = layer(x, memory, tgt_mask)
return self.out_proj(x)
# =========================
# Seq2Seq整体模型
# =========================
class TinyTransformerSeq2Seq(nn.Module):
"""
编码器-解码器结构的Transformer模型。
"""
def __init__(self, src_vocab_size, tgt_vocab_size, d_model=64, heads=4, d_ff=128, num_layers=2, max_len=64):
super().__init__()
self.encoder = TinyTransformer(src_vocab_size, tgt_vocab_size, d_model, heads, d_ff, max_len, num_layers)
self.decoder = TransformerDecoder(
vocab_size=tgt_vocab_size,
d_model=d_model,
nhead=heads,
dim_feedforward=d_ff,
num_layers=num_layers,
max_len=max_len
)
def forward(self, src_ids, tgt_input_ids, tgt_mask=None):
# src_ids: (batch, src_seq_len)
# tgt_input_ids: (batch, tgt_seq_len)
memory, _ = self.encoder(src_ids, trace=False)
logits = self.decoder(tgt_input_ids, memory, tgt_mask)
return logits
# =========================
# 工具函数: 生成自回归mask
# =========================
def generate_subsequent_mask(size):
"""
生成自回归mask,防止解码器看到未来的信息。
"""
return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)
# =========================
# Toy数据集
# =========================
class Seq2SeqDataset(Dataset):
"""
简单的序列到序列数据集。
"""
def __init__(self, src_data, tgt_data):
self.src = src_data
self.tgt = tgt_data
def __len__(self):
return len(self.src)
def __getitem__(self, idx):
return self.src[idx], self.tgt[idx]
# =========================
# 训练循环
# =========================
def train(model, dataloader, optimizer, criterion, tgt_pad_idx):
"""
训练模型一个epoch。
"""
model.train()
total_loss = 0
for src, tgt in dataloader:
src, tgt = src.to(device), tgt.to(device)
tgt_input = tgt[:, :-1]
tgt_output = tgt[:, 1:]
tgt_mask = generate_subsequent_mask(tgt_input.size(1)).to(device)
logits = model(src, tgt_input, tgt_mask)
loss = criterion(logits.view(-1, logits.size(-1)), tgt_output.reshape(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
# =========================
# 推理(贪婪解码)
# =========================
def greedy_decode(model, src, sos_idx, eos_idx, max_len):
"""
贪婪解码:每步选择概率最大的token,直到eos或最大长度。
"""
model.eval()
src = src.to(device)
memory, _ = model.encoder(src, trace=False)
tgt = torch.ones((src.size(0), 1), dtype=torch.long).fill_(sos_idx).to(device)
for _ in range(max_len - 1):
tgt_mask = generate_subsequent_mask(tgt.size(1)).to(device)
out = model.decoder(tgt, memory, tgt_mask)
next_token = out[:, -1, :].argmax(dim=-1).unsqueeze(1)
tgt = torch.cat([tgt, next_token], dim=1)
if (next_token == eos_idx).all():
break
return tgt
# =========================
# 设备选择
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

第9.3讲、Tiny Transformer: 极简版Transformer的更多相关文章
- Underscore源码阅读极简版入门
看了网上的一些资料,发现大家都写得太复杂,让新手难以入门.于是写了这个极简版的Underscore源码阅读. 源码: https://github.com/hanzichi/underscore-an ...
- js消除小游戏(极简版)
js小游戏极简版 (1) 基础布局 <div class = "box"> <p></p> <div class="div&qu ...
- SimpleThreadPool极简版
package com.dwz.concurrency.chapter13; import java.util.ArrayList; import java.util.LinkedList; impo ...
- 极简版ASP.NET Core学习路径及教程
绝承认这是一个七天速成教程,即使有这个效果,我也不愿意接受这个名字.嗯. 这个路径分为两块: 实践入门 理论延伸 有了ASP.NET以及C#的知识以及项目经验,我们几乎可以不再需要了解任何新的知识就开 ...
- Vue Virtual Dom 和 Diff原理(面试必备) 极简版
我又来了,这是Vue面试三板斧的最后一招,当然也是极其简单了,先说Virtual Dom,来一句概念: 用js来模拟DOM中的节点.传说中的虚拟DOM. 再来一张图: 是不是一下子秒懂 没懂再来一张 ...
- Vue数据双向绑定(面试必备) 极简版
我又来吹牛逼了,这次我们简单说一下vue的数据双向绑定,我们这次不背题,而是要你理解这个流程,保证读完就懂,逢人能讲,面试必过,如果没做到,请再来看一遍,走起: 介绍双向数据之前,我们先解释几个名词: ...
- 极简版 react+webpack 脚手架
目录结构 asset/ css/ img/ src/ entry.js ------------------------ 入口文件 .babelrc index.html package.json w ...
- 【极简版】SpringBoot+SpringData JPA 管理系统
前言 只有光头才能变强. 文本已收录至我的GitHub仓库,欢迎Star:https://github.com/ZhongFuCheng3y/3y 在上一篇中已经讲解了如何从零搭建一个SpringBo ...
- cookie——登录注册极简版
本实例旨在最直观地说明如何利用cookie完成登录注册功能,忽略正则验证. index.html <!doctype html> <html lang="en"& ...
- 极简版 卸载 home 扩充 根分区--centos7 xfs 文件格式
1. 查看文件系统 df -Th 2. 关闭正常连接 /home的用户 fuser /home 3. 卸载 /home的挂载点 umount /home 4.删除home的lv 注意 lv的名称的写法 ...
随机推荐
- 机器学习 | 强化学习(2) | 动态规划求解(Planning by Dynamic Programming)
动态规划求解(Planning by Dynamic Programming) 动态规划概论 动态(Dynamic):序列性又或是时序性的问题部分 规划(Programming):最优化一个程序(Pr ...
- 赶快检查,木马可能已经植入服务器,Redis未授权访问漏洞记录,redis的key值出现backup要谨慎
问题描述:为图省事,很多时候我们在使用redis的时候会使用默认空密码,这就增加了安全隐患,如果有下属情况,那赶快去检查下redis,木马或许已经植入服务器,应尽快处理: 1.redis绑定在 0.0 ...
- ABAQUS 中的一些约定
目录 自由度notation Axisymmetric elements Activation of degrees of freedom Internal variables in Abaqus/S ...
- [第一章]ABAQUS CM插件中文手册
ABAQUS Composite Modeler User Manual(zh-CN) Dassault Systèmes, 2018 注: 源文档的交叉引用链接,本文无效 有些语句英文表达更易理解, ...
- 震惊!C++程序真的从main开始吗?99%的程序员都答错了
嘿,朋友们好啊!我是小康.今天咱们来聊一个看似简单,但实际上99%的C++程序员都答错的问题:C++程序真的是从main函数开始执行的吗? 如果你毫不犹豫地回答"是",那恭喜你,你 ...
- go declared and not used
Go语言在代码规范中定义未使用的变量会报"declared and not used"错误 package main import "fmt" func mai ...
- 基础指令:mkdir、ls、cd、pwd、touch、rm、mv、cp、echo、cat、关机与重启
目录 1. 创建目录 2. 查看目录内容 3. 进入指定目录(传送) 4. 显示当前所在位置 5. 创建文件 6. 删除文件或目录 7. 移动文件 8. 复制文件或目录 9. echo输出信息到屏幕 ...
- thinkphp6 使用自定义命令,生成数据库视图
在 ThinkPHP 命令行工具中,你可以为选项设置 别名,通过为选项指定一个简短的别名来简化命令输入.例如,如果你希望 --force-recreate 选项有一个简短的别名 -f,你可以通过在 a ...
- 事件监听、焦点--java进阶day03
1.事件 按钮是组件,点击后就会重新游戏 对于这种点击了组件之后,有逻辑触发的操作,就是事件 2.事件中的专有名词 绑定监听也就是绑定监视,是真正组织代码逻辑的地方 要有绑定监听就需要监听器,今天学习 ...
- 【Java】内部类详解
说起内部类这个词,想必很多人都不陌生,但是又会觉得不熟悉.原因是平时编写代码时可能用到的场景不多,用得最多的是在有事件监听的情况下,并且即使用到也很少去总结内部类的用法.今天我们就来一探究竟. 一.内 ...