第9.3讲、Tiny Transformer: 极简版Transformer
简介
极简版的 Transformer 编码器-解码器(Seq2Seq)结构,适合用于学习、实验和小型序列到序列(如翻译、摘要)任务。
该实现包含了位置编码、多层编码器、多层解码器、训练与推理流程,代码简洁易懂,便于理解 Transformer 的基本原理。
主要结构
- PositionalEncoding:正弦/余弦位置编码,为输入embedding添加位置信息。
- TransformerEncoderLayerWithTrace:单层编码器,含自注意力和前馈网络。
- TinyTransformer:多层堆叠的编码器。
- TransformerDecoderLayer:单层解码器,含自注意力、交叉注意力和前馈网络。
- TransformerDecoder:多层堆叠的解码器。
- TinyTransformerSeq2Seq:编码器-解码器整体结构。
- Seq2SeqDataset:简单的序列到序列数据集。
- train:训练循环。
- greedy_decode:贪婪解码推理。
- generate_subsequent_mask:生成自回归mask。
依赖环境
- Python 3.7+
- torch >= 1.10
安装 PyTorch(以 CPU 版本为例):
pip install torch
用法示例
1. 构建模型
from demo import TinyTransformerSeq2Seq
src_vocab_size = 1000
trg_vocab_size = 1000
model = TinyTransformerSeq2Seq(src_vocab_size, trg_vocab_size)
2. 构造数据集和 DataLoader
from demo import Seq2SeqDataset
from torch.utils.data import DataLoader
src_data = torch.randint(0, src_vocab_size, (100, 10)) # 100个样本,每个10个token
trg_data = torch.randint(0, trg_vocab_size, (100, 12)) # 100个样本,每个12个token
dataset = Seq2SeqDataset(src_data, trg_data)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
3. 训练模型
import torch.optim as optim
import torch.nn as nn
from demo import train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=0)
epoch_loss = train(model, dataloader, optimizer, criterion, tgt_pad_idx=0)
print(f"Train loss: {epoch_loss}")
4. 推理(贪婪解码)
from demo import greedy_decode
src = torch.randint(0, src_vocab_size, (1, 10)).to(device)
sos_idx = 1 # 假设1为<sos>
eos_idx = 2 # 假设2为<eos>
max_len = 12
output = greedy_decode(model, src, sos_idx, eos_idx, max_len)
print("Output token ids:", output)
注意事项
- 该实现为教学/实验用途,未包含完整的mask、权重初始化、分布式训练等工业级细节。
- 需要自行准备合适的训练数据和词表。
- 若需工业级NLP任务,建议使用 HuggingFace Transformers。
tiny Transformer案例代码1:
import torch
import torch.nn as nn
import torch.optim as optim
import math
from torch.utils.data import DataLoader, Dataset
# =========================
# 位置编码模块
# =========================
class PositionalEncoding(nn.Module):
"""
为输入的embedding添加位置信息,帮助模型捕捉序列顺序。
"""
def __init__(self, d_model, max_len=512):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, max_len, d_model)
self.register_buffer('pe', pe)
def forward(self, x):
# x: (batch, seq_len, d_model)
return x + self.pe[:, :x.size(1)]
# =========================
# 编码器层
# =========================
class TransformerEncoderLayerWithTrace(nn.Module):
"""
单层Transformer编码器,带自注意力和前馈网络。
"""
def __init__(self, d_model, nhead, dim_feedforward):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(0.1)
def forward(self, src, src_mask=None):
# src: (batch, seq_len, d_model)
src2, attn_weights = self.self_attn(src, src, src, attn_mask=src_mask)
src = self.norm1(src + self.dropout(src2))
src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))
src = self.norm2(src + self.dropout(src2))
return src, attn_weights
# =========================
# 编码器
# =========================
class TinyTransformer(nn.Module):
"""
多层堆叠的Transformer编码器。
"""
def __init__(self, src_vocab_size, tgt_vocab_size, d_model, nhead, dim_feedforward, max_len, num_layers):
super().__init__()
self.embedding = nn.Embedding(src_vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, max_len)
self.layers = nn.ModuleList([
TransformerEncoderLayerWithTrace(d_model, nhead, dim_feedforward)
for _ in range(num_layers)
])
def forward(self, src, trace=False):
# src: (batch, seq_len)
src = self.embedding(src)
src = self.pos_encoder(src)
attn_weights_all = []
for layer in self.layers:
src, attn_weights = layer(src)
if trace:
attn_weights_all.append(attn_weights)
return src, attn_weights_all
# =========================
# 解码器层
# =========================
class TransformerDecoderLayer(nn.Module):
"""
单层Transformer解码器,含自注意力、交叉注意力和前馈网络。
"""
def __init__(self, d_model, nhead, dim_feedforward):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(0.1)
def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
# tgt: (batch, tgt_seq_len, d_model)
# memory: (batch, src_seq_len, d_model)
tgt2, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)
tgt = self.norm1(tgt + self.dropout(tgt2))
tgt2, _ = self.cross_attn(tgt, memory, memory, attn_mask=memory_mask)
tgt = self.norm2(tgt + self.dropout(tgt2))
tgt2 = self.linear2(self.dropout(torch.relu(self.linear1(tgt))))
tgt = self.norm3(tgt + self.dropout(tgt2))
return tgt
# =========================
# 解码器
# =========================
class TransformerDecoder(nn.Module):
"""
多层堆叠的Transformer解码器。
"""
def __init__(self, vocab_size, d_model, nhead, dim_feedforward, num_layers, max_len=512):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, max_len)
self.layers = nn.ModuleList([
TransformerDecoderLayer(d_model, nhead, dim_feedforward)
for _ in range(num_layers)
])
self.out_proj = nn.Linear(d_model, vocab_size)
def forward(self, tgt_ids, memory, tgt_mask=None):
# tgt_ids: (batch, tgt_seq_len)
x = self.embedding(tgt_ids)
x = self.pos_encoder(x)
for layer in self.layers:
x = layer(x, memory, tgt_mask)
return self.out_proj(x)
# =========================
# Seq2Seq整体模型
# =========================
class TinyTransformerSeq2Seq(nn.Module):
"""
编码器-解码器结构的Transformer模型。
"""
def __init__(self, src_vocab_size, tgt_vocab_size, d_model=64, heads=4, d_ff=128, num_layers=2, max_len=64):
super().__init__()
self.encoder = TinyTransformer(src_vocab_size, tgt_vocab_size, d_model, heads, d_ff, max_len, num_layers)
self.decoder = TransformerDecoder(
vocab_size=tgt_vocab_size,
d_model=d_model,
nhead=heads,
dim_feedforward=d_ff,
num_layers=num_layers,
max_len=max_len
)
def forward(self, src_ids, tgt_input_ids, tgt_mask=None):
# src_ids: (batch, src_seq_len)
# tgt_input_ids: (batch, tgt_seq_len)
memory, _ = self.encoder(src_ids, trace=False)
logits = self.decoder(tgt_input_ids, memory, tgt_mask)
return logits
# =========================
# 工具函数: 生成自回归mask
# =========================
def generate_subsequent_mask(size):
"""
生成自回归mask,防止解码器看到未来的信息。
"""
return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)
# =========================
# Toy数据集
# =========================
class Seq2SeqDataset(Dataset):
"""
简单的序列到序列数据集。
"""
def __init__(self, src_data, tgt_data):
self.src = src_data
self.tgt = tgt_data
def __len__(self):
return len(self.src)
def __getitem__(self, idx):
return self.src[idx], self.tgt[idx]
# =========================
# 训练循环
# =========================
def train(model, dataloader, optimizer, criterion, tgt_pad_idx):
"""
训练模型一个epoch。
"""
model.train()
total_loss = 0
for src, tgt in dataloader:
src, tgt = src.to(device), tgt.to(device)
tgt_input = tgt[:, :-1]
tgt_output = tgt[:, 1:]
tgt_mask = generate_subsequent_mask(tgt_input.size(1)).to(device)
logits = model(src, tgt_input, tgt_mask)
loss = criterion(logits.view(-1, logits.size(-1)), tgt_output.reshape(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
# =========================
# 推理(贪婪解码)
# =========================
def greedy_decode(model, src, sos_idx, eos_idx, max_len):
"""
贪婪解码:每步选择概率最大的token,直到eos或最大长度。
"""
model.eval()
src = src.to(device)
memory, _ = model.encoder(src, trace=False)
tgt = torch.ones((src.size(0), 1), dtype=torch.long).fill_(sos_idx).to(device)
for _ in range(max_len - 1):
tgt_mask = generate_subsequent_mask(tgt.size(1)).to(device)
out = model.decoder(tgt, memory, tgt_mask)
next_token = out[:, -1, :].argmax(dim=-1).unsqueeze(1)
tgt = torch.cat([tgt, next_token], dim=1)
if (next_token == eos_idx).all():
break
return tgt
# =========================
# 设备选择
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
第9.3讲、Tiny Transformer: 极简版Transformer的更多相关文章
- Underscore源码阅读极简版入门
看了网上的一些资料,发现大家都写得太复杂,让新手难以入门.于是写了这个极简版的Underscore源码阅读. 源码: https://github.com/hanzichi/underscore-an ...
- js消除小游戏(极简版)
js小游戏极简版 (1) 基础布局 <div class = "box"> <p></p> <div class="div&qu ...
- SimpleThreadPool极简版
package com.dwz.concurrency.chapter13; import java.util.ArrayList; import java.util.LinkedList; impo ...
- 极简版ASP.NET Core学习路径及教程
绝承认这是一个七天速成教程,即使有这个效果,我也不愿意接受这个名字.嗯. 这个路径分为两块: 实践入门 理论延伸 有了ASP.NET以及C#的知识以及项目经验,我们几乎可以不再需要了解任何新的知识就开 ...
- Vue Virtual Dom 和 Diff原理(面试必备) 极简版
我又来了,这是Vue面试三板斧的最后一招,当然也是极其简单了,先说Virtual Dom,来一句概念: 用js来模拟DOM中的节点.传说中的虚拟DOM. 再来一张图: 是不是一下子秒懂 没懂再来一张 ...
- Vue数据双向绑定(面试必备) 极简版
我又来吹牛逼了,这次我们简单说一下vue的数据双向绑定,我们这次不背题,而是要你理解这个流程,保证读完就懂,逢人能讲,面试必过,如果没做到,请再来看一遍,走起: 介绍双向数据之前,我们先解释几个名词: ...
- 极简版 react+webpack 脚手架
目录结构 asset/ css/ img/ src/ entry.js ------------------------ 入口文件 .babelrc index.html package.json w ...
- 【极简版】SpringBoot+SpringData JPA 管理系统
前言 只有光头才能变强. 文本已收录至我的GitHub仓库,欢迎Star:https://github.com/ZhongFuCheng3y/3y 在上一篇中已经讲解了如何从零搭建一个SpringBo ...
- cookie——登录注册极简版
本实例旨在最直观地说明如何利用cookie完成登录注册功能,忽略正则验证. index.html <!doctype html> <html lang="en"& ...
- 极简版 卸载 home 扩充 根分区--centos7 xfs 文件格式
1. 查看文件系统 df -Th 2. 关闭正常连接 /home的用户 fuser /home 3. 卸载 /home的挂载点 umount /home 4.删除home的lv 注意 lv的名称的写法 ...
随机推荐
- php-fpm自动重启 解决方案
环境:Mac.php7.1.nginx 现象:killall php-fpm,php-fpm自动重启 共有如下几种解决方案: 1.检查php-fpm.conf的deamonize模式是否开启 2.查找 ...
- LCP 11. 期望个数统计
地址:https://leetcode-cn.com/problems/qi-wang-ge-shu-tong-ji/ <?php /** 某互联网公司一年一度的春招开始了,一共有 n 名面试者 ...
- php站点导入大mysql文件(linux系统)
问题描述:站点数据多,mysql导出后大于1G,使用phpmyadmin,导入一半报错,白白浪费等待时间,使用navicat 导入,执行时间过长提示错误 解决思路:1.拆分mysql文件,分批次导入, ...
- pandas 操作excel
一 Series 什么是series 相当于表格中的行和列,不同的设置可以按行或列排序 2.series 创建 空的series import pandas as pd s2=pd.Series() ...
- Web前端入门第 17 问:前端开发编辑器及插件推荐
HELLO,这里是大熊学习前端开发的入门笔记. 本系列笔记基于 windows 系统. 虽然说 Web 前端开发用记事本也能玩,但正常的开发者绝不用记事本玩(大佬除外). 想想要用记事本扣一个淘宝.京 ...
- 工作面试必备:SQL 中的各种连接 JOIN 的区别总结!
前言 尽管大多数开发者在日常工作中经常用到Join操作,如Inner Join.Left Join.Right Join等,但在面对特定查询需求时,选择哪种Join类型以及如何使用On和Where子句 ...
- NumPy学习4
今天学习NumPy相关数组操作 NumPy 中包含了一些处理数组的常用方法,大致可分为以下几类:(1)数组变维操作(2)数组转置操作(3)修改数组维度操作(4)连接与分割数组操作 numpy_test ...
- 通过局域网访问连接 vite 或 Django 之类的项目
博客地址:https://www.cnblogs.com/zylyehuo/ step1 将 vite 或 Django 类的项目启动 ip 设置为 0.0.0.0:端口 step2 查询本机电脑在当 ...
- MySQL-InnoDB行锁
InnoDB的锁类型 InnoDB存储引擎支持行锁,锁类型有两种: 共享锁(S锁) 排他锁(X锁) S和S不互斥,其他均互斥. 除了这两种锁以外,innodb还支持一种锁,叫做意向锁. 那么什么是意向 ...
- soapUI参数化总结
1.新建项目目录 以获取用户贡献等级为例,目录如下: 2.添加DataSource和DataSource Loop 选中Test Step右键分别新建DataSource和DataSource Loo ...