Tiny Encoder Transformer:极简文本分类与注意力可视化实战

项目简介

本项目实现了一个极简版的 Transformer Encoder 文本分类器,并通过 Streamlit 提供了交互式可视化界面。用户可以输入任意文本,实时查看模型的分类结果及注意力权重热力图,直观理解 Transformer 的内部机制。项目采用 HuggingFace 的多语言 BERT 分词器,支持中英文等多种语言输入,适合教学、演示和轻量级 NLP 应用开发。


主要功能

  • 多语言支持:集成 HuggingFace bert-base-multilingual-cased 分词器,支持 100+ 语言。
  • 极简 Transformer 结构:自定义实现位置编码、单层/多层 Transformer Encoder、分类头,结构清晰,便于学习和扩展。
  • 注意力可视化:可实时展示输入文本的注意力热力图和每个 token 被关注的占比,帮助理解模型关注机制。
  • 高效演示:训练时仅用 AG News 数据集的前 200 条数据,并只训练 10 个 batch,保证页面加载和交互速度。

代码结构与核心实现

1. 数据加载与预处理

使用 HuggingFace datasets 库加载 AG News 数据集,并用 BERT 分词器对文本进行编码:

from datasets import load_dataset
from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
dataset = load_dataset("ag_news")
dataset["train"] = dataset["train"].select(range(200)) # 只用前200条数据 def encode(example):
tokens = tokenizer(
example["text"],
padding="max_length",
truncation=True,
max_length=64,
return_tensors="pt"
)
return {
"input_ids": tokens["input_ids"].squeeze(0),
"label": example["label"]
} encoded_train = dataset["train"].map(encode)

2. Tiny Encoder 模型结构

模型包含词嵌入层、位置编码、若干 Transformer Encoder 层和分类头,支持输出每层的注意力权重:

import torch.nn as nn

class PositionalEncoding(nn.Module):
# ... 位置编码实现,见下文详细代码 ... class TransformerEncoderLayerWithTrace(nn.Module):
# ... 支持 trace 的单层 Transformer Encoder,见下文详细代码 ... class TinyEncoderClassifier(nn.Module):
# ... 嵌入、位置编码、编码器堆叠、分类头,见下文详细代码 ...

3. 训练流程

采用交叉熵损失和 Adam 优化器,仅训练 10 个 batch,极大提升演示速度:

import torch.optim as optim
from torch.utils.data import DataLoader train_loader = DataLoader(encoded_train, batch_size=16, shuffle=True)
model = TinyEncoderClassifier(...)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3) model.train()
for i, batch in enumerate(train_loader):
if i >= 10: # 只训练10个batch
break
input_ids = batch["input_ids"]
labels = batch["label"]
logits, _ = model(input_ids)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

4. Streamlit 可视化界面

  • 提供文本输入框,用户可输入任意文本。
  • 实时推理并展示分类结果。
  • 可视化 Transformer 第一层各个注意力头的权重热力图和每个 token 被关注的占比(条形图)。
import streamlit as st
import seaborn as sns
import matplotlib.pyplot as plt user_input = st.text_input("请输入文本:", "We all have a home called China.")
if user_input:
# ... 推理与注意力可视化代码,见下文详细代码 ...

训练与推理流程详解

  1. 数据加载与预处理

    • 加载 AG News 数据集,仅取前 200 条样本。
    • 用多语言 BERT 分词器编码文本,填充/截断到 64 长度。
  2. 模型结构

    • 词嵌入层将 token id 映射为向量。
    • 位置编码为每个 token 添加可区分的位置信息。
    • 堆叠若干 Transformer Encoder 层,支持输出注意力权重。
    • 分类头对第一个 token 的输出做分类(类似 BERT 的 [CLS])。
  3. 训练流程

    • 损失函数为交叉熵,优化器为 Adam。
    • 只训练 1 个 epoch,且只训练 10 个 batch,保证演示速度。
  4. 推理与可视化

    • 用户输入文本,模型输出预测类别编号。
    • 可视化注意力热力图和每个 token 被关注的占比,直观展示模型关注点。

适用场景

  • Transformer 原理教学与可视化演示
  • 注意力机制理解与分析
  • 多语言文本分类任务的快速原型开发
  • NLP 课程、讲座、实验室演示

完整案例说明:


Tiny Encoder

1. 代码主要功能

该脚本实现了一个基于 Transformer Encoder 的文本分类模型,并通过 Streamlit 提供了可视化界面,

支持输入一句话并展示模型的分类结果及注意力权重热力图。

2. 主要模块说明

  • Tokenizer 初始化

    • 使用 HuggingFace 的多语言 BERT Tokenizer 对输入文本进行分词和编码。
  • 模型结构
    • 包含词嵌入层、位置编码、若干 Transformer Encoder 层(带注意力权重 trace)、分类器。
  • 数据处理与训练
    • 加载 AG News 数据集,编码文本,训练模型并保存。
    • 若已存在训练好的模型则直接加载。
  • Streamlit 可视化
    • 提供文本输入框,实时推理并展示分类结果。
    • 可视化 Transformer 第一层各个注意力头的权重热力图。

3. 数据流向说明

  1. 输入

    • 用户在 Streamlit 网页输入一句英文(或多语言)文本。
  2. 分词与编码
    • Tokenizer 将文本转为固定长度的 token id 序列(input_ids)。
  3. 模型推理
    • input_ids 输入 TinyEncoderClassifier,经过嵌入、位置编码、若干 Transformer 层,输出 logits(分类结果)和注意力权重(trace)。
  4. 分类输出
    • 取 logits 最大值作为类别预测,显示在网页上。
  5. 注意力可视化
    • 取第一层注意力权重,分别绘制每个 head 的热力图,帮助理解模型关注的 token 关系。

4. 适用场景

  • 适合教学、演示 Transformer 注意力机制和文本分类原理。
  • 可扩展用于多语言文本分类任务。

import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
import streamlit as st
import seaborn as sns
import matplotlib.pyplot as plt # ============================
# 位置编码模块
# ============================
class PositionalEncoding(nn.Module):
"""
位置编码模块:为输入的 token 序列添加可区分位置信息。
使用正弦和余弦函数生成不同频率的编码。
"""
def __init__(self, d_model, max_len=512):
super().__init__()
# 创建一个 (max_len, d_model) 的全零张量,用于存储位置编码
pe = torch.zeros(max_len, d_model)
# 生成位置索引 (max_len, 1)
position = torch.arange(0, max_len).unsqueeze(1)
# 计算每个维度对应的分母项(不同频率)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
# 偶数位置用 sin,奇数位置用 cos
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# 增加 batch 维度,形状变为 (1, max_len, d_model)
pe = pe.unsqueeze(0)
# 注册为 buffer,模型保存时一同保存,但不是参数
self.register_buffer('pe', pe) def forward(self, x):
"""
输入:x,形状为 (batch, seq_len, d_model)
输出:加上位置编码后的张量,形状同输入
"""
return x + self.pe[:, :x.size(1)] # ============================
# 单层 Transformer Encoder,支持输出注意力权重
# ============================
class TransformerEncoderLayerWithTrace(nn.Module):
"""
单层 Transformer Encoder,支持输出注意力权重。
包含多头自注意力、前馈网络、残差连接和层归一化。
"""
def __init__(self, d_model, nhead, dim_feedforward):
super().__init__()
# 多头自注意力层
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
# 前馈网络第一层
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(0.1)
# 前馈网络第二层
self.linear2 = nn.Linear(dim_feedforward, d_model)
# 层归一化
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Dropout 层
self.dropout1 = nn.Dropout(0.1)
self.dropout2 = nn.Dropout(0.1) def forward(self, src, trace=False):
"""
前向传播。
参数:
src: 输入序列,形状为 (batch, seq_len, d_model)
trace: 是否返回注意力权重
返回:
src: 输出序列
attn_weights: 注意力权重(如果 trace=True)
"""
# 多头自注意力,attn_weights 形状为 (batch, nhead, seq_len, seq_len)
attn_output, attn_weights = self.self_attn(src, src, src, need_weights=trace)
# 残差连接 + 层归一化
src2 = self.dropout1(attn_output)
src = self.norm1(src + src2)
# 前馈网络
src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))
# 残差连接 + 层归一化
src = self.norm2(src + self.dropout2(src2))
# 返回输出和注意力权重(可选)
return src, attn_weights if trace else None # ============================
# Tiny Transformer 分类模型
# ============================
class TinyEncoderClassifier(nn.Module):
"""
Tiny Transformer 分类模型:
包含嵌入层、位置编码、若干 Transformer 编码器层和分类头。
支持输出每层的注意力权重。
"""
def __init__(self, vocab_size, d_model, n_heads, d_ff, num_layers, max_len, num_classes):
super().__init__()
# 词嵌入层,将 token id 映射为向量
self.embedding = nn.Embedding(vocab_size, d_model)
# 位置编码模块
self.pos_encoder = PositionalEncoding(d_model, max_len)
# 堆叠多个 Transformer 编码器层
self.layers = nn.ModuleList([
TransformerEncoderLayerWithTrace(d_model, n_heads, d_ff) for _ in range(num_layers)
])
# 分类头,对第一个 token 的输出做分类
self.classifier = nn.Linear(d_model, num_classes) def forward(self, input_ids, trace=False):
"""
前向传播。
参数:
input_ids: 输入 token id,形状为 (batch, seq_len)
trace: 是否输出注意力权重
返回:
logits: 分类输出 (batch, num_classes)
traces: 每层的注意力权重(可选)
"""
# 词嵌入
x = self.embedding(input_ids)
# 加位置编码
x = self.pos_encoder(x)
traces = []
# 依次通过每一层 Transformer 编码器
for layer in self.layers:
x, attn = layer(x, trace=trace)
if trace:
traces.append({"attn_map": attn})
# 只取第一个 token 的输出做分类(类似 BERT 的 [CLS])
logits = self.classifier(x[:, 0])
return logits, traces if trace else None # ============================
# 模型构建与训练函数,显式使用CPU
# ============================
@st.cache_resource(show_spinner=False)
def build_and_train_model(d_model, n_heads, d_ff, num_layers):
device = torch.device('cpu') # 显式指定使用CPU tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
dataset = load_dataset("ag_news")
dataset["train"] = dataset["train"].select(range(200)) # 只用前200条数据
MAX_LEN = 64 def encode(example):
tokens = tokenizer(example["text"], padding="max_length", truncation=True, max_length=MAX_LEN, return_tensors="pt")
return {"input_ids": tokens["input_ids"].squeeze(0), "label": example["label"]} encoded_train = dataset["train"].map(encode)
encoded_train.set_format(type="torch")
train_loader = DataLoader(encoded_train, batch_size=16, shuffle=True) model = TinyEncoderClassifier(
vocab_size=tokenizer.vocab_size,
d_model=d_model,
n_heads=n_heads,
d_ff=d_ff,
num_layers=num_layers,
max_len=MAX_LEN,
num_classes=4
).to(device) # 模型放到CPU criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3) model.train()
for epoch in range(1): # 训练1个epoch
for i, batch in enumerate(train_loader):
if i >= 10: # 只训练10个batch
break
input_ids = batch["input_ids"].to(device) # 输入转到CPU
labels = batch["label"].to(device)
logits, _ = model(input_ids)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step() return model, tokenizer # ============================
# Streamlit 页面设置
# ============================
st.set_page_config(page_title="TinyEncoder")
st.title(" Tiny Encoder Transformer") # 固定模型参数
# d_model: 隐藏层维度,
# n_heads: 注意力头数,
# d_ff: 前馈层维度,
# num_layers: Transformer 层数
d_model = 64
n_heads = 2
d_ff = 128
num_layers = 1 # 构建并训练模型
with st.spinner("模型构建中..."):
model, tokenizer = build_and_train_model(d_model, n_heads, d_ff, num_layers) # ============================
# 推理与注意力权重可视化
# ============================
model.eval()
device = torch.device('cpu')
model.to(device) user_input = st.text_input("请输入文本:", "We all have a home called China.")
if user_input:
tokens = tokenizer(user_input, return_tensors="pt", max_length=64, padding="max_length", truncation=True)
input_ids = tokens["input_ids"].to(device) # 放CPU with torch.no_grad():
logits, traces = model(input_ids, trace=True) pred_class = torch.argmax(logits, dim=-1).item()
st.markdown(f"### 预测类别编号: `{pred_class}`")
if traces:
attn_map = traces[0]["attn_map"]
if attn_map is not None:
seq_len = input_ids.shape[1]
token_list = tokenizer.convert_ids_to_tokens(input_ids[0])
if '[PAD]' in token_list:
valid_len = token_list.index('[PAD]')
else:
valid_len = seq_len
token_list = token_list[:valid_len]
if attn_map.dim() == 4:
# [batch, heads, seq_len, seq_len]
heads = attn_map.size(1)
fig, axes = plt.subplots(1, heads, figsize=(5 * heads, 3))
if heads == 1:
axes = [axes]
for i in range(heads):
matrix = attn_map[0, i][:valid_len, :valid_len].cpu().detach().numpy()
sns.heatmap(matrix, ax=axes[i], cbar=False, xticklabels=token_list, yticklabels=token_list)
axes[i].set_title(f"Head {i}")
axes[i].tick_params(labelsize=6)
# 显示每个 token 被关注的占比
attn_sum = matrix.sum(axis=0)
attn_ratio = attn_sum / attn_sum.sum()
fig2, ax2 = plt.subplots(figsize=(5, 2))
ax2.bar(range(valid_len), attn_ratio)
ax2.set_xticks(range(valid_len))
ax2.set_xticklabels(token_list, rotation=90, fontsize=6)
ax2.set_title(f"Head {i} Token Attention Ratio")
st.pyplot(fig2)
st.pyplot(fig)
elif attn_map.dim() == 3:
# [heads, seq_len, seq_len]
heads = attn_map.size(0)
fig, axes = plt.subplots(1, heads, figsize=(5 * heads, 3))
if heads == 1:
axes = [axes]
for i in range(heads):
matrix = attn_map[i][:valid_len, :valid_len].cpu().detach().numpy()
sns.heatmap(matrix, ax=axes[i], cbar=False, xticklabels=token_list, yticklabels=token_list)
axes[i].set_title(f"Head {i}")
axes[i].tick_params(labelsize=6)
# 显示每个 token 被关注的占比
attn_sum = matrix.sum(axis=0)
attn_ratio = attn_sum / attn_sum.sum()
fig2, ax2 = plt.subplots(figsize=(5, 2))
ax2.bar(range(valid_len), attn_ratio)
ax2.set_xticks(range(valid_len))
ax2.set_xticklabels(token_list, rotation=90, fontsize=6)
ax2.set_title(f"Head {i} Token Attention Ratio")
st.pyplot(fig2)
st.pyplot(fig)
elif attn_map.dim() == 2:
# [seq_len, seq_len]
fig, ax = plt.subplots(figsize=(5, 3))
sns.heatmap(attn_map[:valid_len, :valid_len].cpu().detach().numpy(), ax=ax, cbar=False, xticklabels=token_list, yticklabels=token_list)
ax.set_title("Attention Map")
ax.tick_params(labelsize=6)
st.pyplot(fig)
# 显示每个 token 被关注的占比
matrix = attn_map[:valid_len, :valid_len].cpu().detach().numpy()
attn_sum = matrix.sum(axis=0)
attn_ratio = attn_sum / attn_sum.sum()
fig2, ax2 = plt.subplots(figsize=(5, 2))
ax2.bar(range(valid_len), attn_ratio)
ax2.set_xticks(range(valid_len))
ax2.set_xticklabels(token_list, rotation=90, fontsize=6)
ax2.set_title("Token Attention Ratio")
st.pyplot(fig2)
else:
st.warning("注意力权重维度异常,无法可视化。")

第9.1讲、Tiny Encoder Transformer:极简文本分类与注意力可视化实战的更多相关文章

  1. Vim,极简使用教程,让你瞬间脱离键鼠切换的痛苦

    注:看大家对Vim仇恨极大,其实它只是一种文本操作方式,可以减少键鼠的切换,从而让编辑文本的操作更迅捷.并不等同于IDE,在我看来,它们是两个是包含关系,IDE可以有Vim编辑模式.Vim或许可以通过 ...

  2. 在Web应用中接入微信支付的流程之极简清晰版

    在Web应用中接入微信支付的流程之极简清晰版 背景: 在Web应用中接入微信支付,我以为只是调用几个API稍作调试即可. 没想到微信的API和官方文档里隐坑无数,致我抱着怀疑人生的心情悲愤踩遍了丫们布 ...

  3. 在Web应用中接入微信支付的流程之极简清晰版 (转)

    在Web应用中接入微信支付的流程之极简清晰版 背景: 在Web应用中接入微信支付,我以为只是调用几个API稍作调试即可. 没想到微信的API和官方文档里隐坑无数,致我抱着怀疑人生的心情悲愤踩遍了丫们布 ...

  4. 文本分类实战(八)—— Transformer模型

    1 大纲概述 文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类.总共有以下系列: word2vec预训练词向量 te ...

  5. Mongodb极简实践

    MongoDB 极简实践入门 1. 为什么用MongoDB? 传统的计算机应用大多使用关系型数据库来存储数据,比如大家可能熟悉的MySql, Sqlite等等,它的特点是数据以表格(table)的形式 ...

  6. MongoDB 极简实践入门

    原作者StevenSLXie; 原链接(https://github.com/StevenSLXie/Tutorials-for-Web-Developers/blob/master/MongoDB% ...

  7. 极简Node教程-七天从小白变大神(一:你需要Express)

    如果说用一句话来概括Node那就是:它开启了JavaScript服务器端语言. Node系列的文章并不会从一开始长篇概论的讲Node的历史,安装,以及其他很琐碎的事情.只会专门介绍关于Node或者准确 ...

  8. Spring Boot (七): Mybatis极简配置

    Spring Boot (七): Mybatis极简配置 1. 前言 ORM 框架的目的是简化编程中的数据库操作,经过这么多年的发展,基本上活到现在的就剩下两家了,一个是宣称可以不用写 SQL 的 H ...

  9. 一个基于protobuf的极简RPC

    前言 RPC采用客户机/服务器模式实现两个进程之间的相互通信,socket是RPC经常采用的通信手段之一.当然,除了socket,RPC还有其他的通信方法:http.管道...网络开源的RPC框架也比 ...

  10. 极简估值教程——第一篇 速判估值与PEG的推导

    来自盛京剑客的雪球原创专栏 一.极简速判估值怎么判? 很简单.简单到粗暴. 用PEG PEG=PE/(g*100)=1.0 什么意思? PE市盈率,g未来收益增长率,PEG为1.0合理估值,大于1.0 ...

随机推荐

  1. idea中启动web、jsp项目

    1. idea打开项目 选择要打开的项目的根目录 2. 项目配置 配置jdk modules配置 添加web 添加依赖 删除爆红的依赖 添加依赖目录或者jar 配置web.xml 配置lib 如果没有 ...

  2. 一文搞懂 MCP Servers

    一文搞懂 MCP Servers 什么是MCP MCP概念 MCP(Model Context Protocol,模型上下文协议)是由 Anthropic 提出并于 2024 年 11 月开源的一种通 ...

  3. MySQL2022.3.2

    创建库 CREATE DATABASE STUDENT; 创建表 CREATE TABLE STU(SNO INT PRIMARY KEY,//主键SNM CHAR(2) NOT NULL,//不能为 ...

  4. class com.ttsx.activity.item.dao.entity.RoleMenu ,Not found @TableId annotation, Cannot use Mybatis-Plus 'xxById' Method. 报错解决办法

    启动项目的时候,有几条WARN警告,如图: 引起原因: 是因为数据表实体类 没有di导致的. 例如: 解决办法: 增加实体id字段!或可以忽略!不影响程序!

  5. 迅速理解 LCS 最长公共子序列问题

    在算法与数据结构的经典问题中,最长公共子序列(Longest Common Subsequence,简称 LCS)问题占据着重要的地位.给定两个序列,我们需要找到它们最长的公共子序列,而子序列要求保持 ...

  6. DevEco Studio 常用设置【自用】

    设置为中文 API参考设置悬浮 始终定位打开的文件,单击预览免打开 保存时自动格式化和热更新 属性单独一行

  7. 1、从DeepSeek API调用到Semantic Kernel集成:深度解析聊天机器人开发全链路

    引言:AI时代下的聊天机器人开发范式演进 在生成式AI技术爆发的当下,基于大语言模型(LLM)的聊天机器人开发已形成标准化技术链路.本文将结合DeepSeek API与微软Semantic Kerne ...

  8. python正则表达式笔记2

    由 '\' 和一个字符组成的特殊序列在以下列出. 如果普通字符不是ASCII数位或者ASCII字母,那么正则样式将匹配第二个字符.比如,\$ 匹配字符 '$'. \number匹配数字代表的组合.每个 ...

  9. Docker top 命令

    Docker 命令大全Docker 命令大全docker top :查看容器中运行的进程信息,支持 ps 命令参数.语法docker top [OPTIONS] CONTAINER [ps OPTIO ...

  10. K8S 问题排查: cgroup 内存泄露问题

    Posted on 2019年12月6日Leave a comment Contents [hide] 1 前言 2 现象 3 原因 4 解决方案 4.1 方案一 4.2 方案二 4.3 方案三 5 ...