Transformer架构详解

1. 架构概述

Transformer是一种基于自注意力机制的神经网络架构,由Vaswani等人在2017年的论文《Attention Is All You Need》中首次提出。它彻底改变了自然语言处理领域,逐步取代了传统的RNN和CNN架构。

主要特点

  • 完全基于注意力机制:摒弃了传统的循环和卷积结构
  • 并行计算能力强:克服了RNN的顺序计算限制
  • 长距离依赖捕捉能力:通过自注意力机制有效捕捉任意距离的依赖关系
  • 模块化设计:编码器-解码器结构清晰,易于扩展

2. 核心组件详解

2.1 编码器(Encoder)

结构组成

  1. 输入嵌入层(Input Embedding)

    • 将输入的token转换为稠密向量表示
    • 通常维度为512(d_model)或更大
    • 包含可学习的参数矩阵
  2. 位置编码(Positional Encoding)

    • 解决Transformer缺少位置信息的问题
    • 使用正弦和余弦函数的固定模式:
      PE(pos,2i) = sin(pos/10000^(2i/d_model))
      PE(pos,2i+1) = cos(pos/10000^(2i/d_model))
    • 与输入嵌入相加而非拼接
  3. 编码器层堆叠(Encoder Layers)

    • 通常6层或更多(原始论文使用6层)
    • 每层包含两个子层:
      • 多头自注意力机制(Multi-Head Self-Attention)
      • 前馈神经网络(Feed Forward Network)
    • 每个子层都有残差连接和层归一化

自注意力机制细节

  • 计算过程

    Attention(Q,K,V) = softmax(QK^T/√d_k)V
    • Q(Query), K(Key), V(Value)来自同一输入
    • √d_k缩放防止点积过大导致梯度消失
  • 多头注意力

    • 将Q,K,V投影到h个不同子空间(h通常为8)
    • 并行计算h个注意力头
    • 拼接所有头的结果并通过线性变换

2.2 解码器(Decoder)

结构组成

  1. 输出嵌入层(Output Embedding)

    • 与编码器嵌入层类似但独立
    • 通常共享编码器嵌入层的权重(可选)
  2. 位置编码

    • 与编码器使用相同的实现方式
  3. 解码器层堆叠(Decoder Layers)

    • 通常6层(与编码器层数相同)
    • 每层包含三个子层:
      • 掩码多头自注意力(Masked Multi-Head Self-Attention)
      • 编码器-解码器注意力(Encoder-Decoder Attention)
      • 前馈神经网络
    • 每个子层都有残差连接和层归一化

关键机制

  • 掩码自注意力

    • 防止解码器"偷看"未来信息
    • 通过添加负无穷掩码(-∞)实现
    • 确保位置i只能关注位置≤i的token
  • 编码器-解码器注意力

    • Q来自解码器前一层的输出
    • K,V来自编码器的最终输出
    • 允许解码器关注输入序列的相关部分

2.3 词表与分词器(Tokenizer)

主流分词方法

  1. Byte Pair Encoding (BPE)

    • 通过合并频繁出现的字节对逐步构建词表
    • 平衡词表大小与token序列长度
    • 被GPT系列模型采用
  2. WordPiece

    • 类似BPE但基于概率合并
    • 被BERT等模型采用
  3. SentencePiece

    • 直接从原始文本训练
    • 支持Unicode字符级处理
    • 不依赖预处理分词

词表设计考量

  • 大小通常在30k-100k之间
  • 包含特殊token([CLS],[SEP],[MASK]等)
  • 处理未知token的方式(如[UNK]或字节级回退)

2.4 位置编码的演进

  1. 原始正弦编码

    • 固定模式,无需学习
    • 理论上可外推到任意长度
  2. 可学习位置编码

    • 将位置视为可学习的嵌入
    • 被BERT等模型采用
    • 但受限于训练时见过的最大长度
  3. 相对位置编码

    • 关注token之间的相对距离而非绝对位置
    • 多种变体(T5, Transformer-XL等)
    • 更好的长度外推能力
  4. 旋转位置编码(RoPE)

    • 通过旋转矩阵注入位置信息
    • 被LLaMA等最新模型采用
    • 保持相对位置关系的线性特性

3. 训练机制

3.1 优化目标

  1. 语言建模目标

    • 自回归模型(GPT风格):最大化下一个token的似然
    • 自编码模型(BERT风格):预测被掩码的token
  2. 损失函数

    • 交叉熵损失(Cross-Entropy Loss)
    • 可选项:标签平滑(Label Smoothing)

3.2 优化策略

  1. Adam优化器变种

    • 原始Transformer使用Adam(β1=0.9, β2=0.98, ε=1e-9)
    • 现代变种:AdamW(解耦权重衰减)
  2. 学习率调度

    • 预热学习率(Warmup):
      lrate = d_model^-0.5 * min(step_num^-0.5, step_num*warmup_steps^-1.5)
    • 余弦衰减
    • 线性衰减
  3. 正则化技术

    • Dropout(原始论文使用0.1)
    • 权重衰减(通常0.01)
    • 梯度裁剪(通常1.0)

4. 架构演进与变体

4.1 经典变体

  1. BERT (2018)

    • 仅使用编码器
    • 双向自注意力
    • 掩码语言建模目标
  2. GPT (2018)

    • 仅使用解码器(带掩码)
    • 自回归语言建模
  3. T5 (2019)

    • 编码器-解码器完整结构
    • 将所有任务统一为文本到文本格式

4.2 效率优化方向

  1. 稀疏注意力

    • Longformer(2020):局部+全局注意力
    • BigBird(2020):随机+局部+全局注意力
  2. 混合架构

    • FNet(2021):用傅里叶变换替代注意力
    • Hyena(2023):结合CNN与注意力
  3. 状态空间模型

    • Mamba(2023):选择性状态空间

4.3 现代大语言模型架构

  1. GPT-3/4

    • 纯解码器架构
    • 旋转位置编码
    • 极深(96层以上)和极宽(12288维度)
  2. LLaMA

    • RMSNorm替代LayerNorm
    • SwiGLU激活函数
    • 旋转位置编码
  3. PaLM

    • 并行注意力与前馈计算
    • 共享QKV投影

5. 实践建议

5.1 超参数选择

  1. 维度关系

    • d_model(模型维度):通常512-12288
    • d_ff(前馈层维度):通常4*d_model
    • h(注意力头数):通常d_model/64
  2. 深度选择

    • 基础模型:6-12层
    • 大型模型:24-96层
    • 极深模型:100+层

5.2 训练技巧

  1. 初始化策略

    • 注意力层:Xavier/Glorot初始化
    • 残差连接:初始化为1/√N(N为层数)
  2. 精度选择

    • FP32:传统选择
    • FP16/BF16:现代标准
    • 混合精度训练:常用实践
  3. 批处理策略

    • 动态批处理
    • 梯度累积

6. 关键数学原理

6.1 注意力机制数学

  1. 缩放点积注意力

    Attention(Q,K,V) = softmax(QK^T/√d_k)V
    • Q ∈ ℝ^{n×d_k}, K ∈ ℝ^{m×d_k}, V ∈ ℝ^
    • 计算复杂度:O(nmd_k + nmd_v)
  2. 多头注意力

    MultiHead(Q,K,V) = Concat(head_1,...,head_h)W^O

    其中:

    head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

6.2 前馈网络

FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
  • 通常d_ff = 4*d_model
  • 现代变体使用GELU或SwiGLU

6.3 层归一化

LayerNorm(x) = γ ⊙ (x - μ)/σ + β
  • μ,σ为均值和标准差
  • γ,β为可学习参数

7. 典型应用场景

  1. 文本生成

    • 自回归采样策略(贪心、beam search、top-k/p采样)
    • 温度控制
  2. 机器翻译

    • 编码器处理源语言
    • 解码器生成目标语言
  3. 文本分类

    • 使用[CLS]token或平均池化
    • 接分类头
  4. 问答系统

    • 编码问题与上下文
    • 预测答案起始/结束位置

项目实例

项目实例:https://github.com/toke648/SimpleLLM

一个基于Transformer架构实现的简易LLM大语言模型,完全自定义Transformer,通过大量的问答数据训练实现,支持自定义训练流程以及超参数等

从最基础的训练词表开始到构建Transformer、训练模型、推理,完全从代码上完整的大模型训练全流程

详细介绍了训练流程以及操作方法,并且持续更新

感兴趣的点个star,Ciallo~(∠・ω< )⌒★

欢迎贡献代码

相关文章

Attention Is All You Need:https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf

Attention Is All You Need 解读:https://blog.csdn.net/chengyq116/article/details/106065576/

Attention as Energy Minimization Visualizing Energy Landscapes(大佬博客):https://mcbal.github.io/post/attention-as-energy-minimization-visualizing-energy-landscapes/

.......

(正在施工)

Transformer架构介绍+从零搭建预训练模型项目的更多相关文章

  1. 从零搭建一个SpringCloud项目之Feign搭建

    从零搭建一个SpringCloud项目之Feign搭建 工程简述 目的:实现trade服务通过feign调用user服务的功能.因为trade服务会用到user里的一些类和接口,所以抽出了其他服务需要 ...

  2. 从零搭建一个IdentityServer——项目搭建

    本篇文章是基于ASP.NET CORE 5.0以及IdentityServer4的IdentityServer搭建,为什么要从零搭建呢?IdentityServer4本身就有很多模板可以直接生成一个可 ...

  3. 从零搭建react hooks项目(github有源代码)

    前言 首先这是一个react17的项目,包含项目中常用的路由.状态管理.less及全局变量配置.UI等等一系列的功能,开箱即用,是为了以后启动项目方便,特地做的基础框架,在这里分享出来. 这里写一下背 ...

  4. 从零搭建vue3.0项目架构(附带代码、步骤详解)

    前言: GitHub上我开源了vue-cli.vue-cli3两个库,文章末尾会附上GitHub仓库地址.这次把2.0的重新写了一遍,优化了一下.然后按照2.0的功能和代码,按照vue3.0的语法,完 ...

  5. Ionic01 简单介绍、环境搭建、创建项目、项目结构、创建组件、创建页面、子页面跳转

    1 Ionic 基本介绍 Ionic 是一款基于 Angular.Cordova 的强大的 HTML5 移动应用开发框架 , 可以快速创建一个跨平台的移动应用.可以快速开发移动 App.移动端 WEB ...

  6. 使用Vue脚手架(vue-cli)从零搭建一个vue项目(包含vue项目结构展示)

    注:在搭建项目之前,请先安装一些全局的工具(如:node,vue-cli等) node安装:去node官网(https://nodejs.org/en/)下载并安装node即可,安装node以后就可以 ...

  7. 从零搭建一个SpringCloud项目之Config(五)

    配置中心 一.配置中心服务端 新建项目study-config-server 引入依赖 <dependency> <groupId>org.springframework.cl ...

  8. 从零搭建一个SpringCloud项目之Zuul(四)

    整合Zuul 为什么要使用Zuul? 易于监控 易于认证 减少客户端与各个微服务之间的交互次数 引入依赖 <dependency> <groupId>org.springfra ...

  9. 从零搭建一个IdentityServer——目录(更新中...)

    从零搭建一个IdentityServer--项目搭建 从零搭建一个IdentityServer--集成Asp.net core Identity 从零搭建一个IdentityServer--初识Ope ...

  10. 从零搭建Pytorch模型教程(三)搭建Transformer网络

    ​ 前言 本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍. ...

随机推荐

  1. OpenDeepWiki:AI驱动的代码知识库文档生成技术深度解析

    项目地址 Git仓库: https://github.com/AIDotNet/OpenDeepWiki 在线体验: https://opendeepwiki.com 本文档基于: 当前本地仓库分析 ...

  2. Java 解算法:合并区间

    题目:以数组 intervals 表示若干个区间的集合,其中单个区间为 intervals[i] = [starti, endi] .请你合并所有重叠的区间,并返回 一个不重叠的区间数组,该数组需恰好 ...

  3. 12Java基础之多态

    多态 多态是在继承/实现情况下的一种现象,表现为:对象多态.行为多态. 多态存在的条件 有继承关系 子类重写父类的方法 父类引用指向子类对象 多态的一个注意事项 多态是对象.行为的多态,Java中的属 ...

  4. [置顶] WHO AM I ?

    我是一名又菜又爱玩的做题家(什么也不会的大学生 会在这里发出我觉得有意思的题目以及一些好玩的trick 同时如果手机端的同学想要看算法竞赛比赛题解的话可以移步至我的知乎 link

  5. HTTP请求头中表示代理IP地址的属性及获取情况

    博客:https://www.emanjusaka.com 公众号:emanjusaka的编程栈 by emanjusaka from https://www.emanjusaka.com/archi ...

  6. JS slice(0);克隆数组

    // 克隆 const cloneArr1 = arr1.slice(0);

  7. Java变量与常量全解析(包含常量类、interface 与 final 的比较)

    ​ Java中的变量 变量是Java程序中最基本的存储单元,用于存储数据值.变量在程序运行期间其值可以改变.变量必须先声明后使用. 变量声明语法: 数据类型 变量名 [= 初始值]; 变量分类: 局部 ...

  8. TCP拥塞控制及常见算法(美团)

    TCP拥塞控制:确保网络中数据流量合理传输,避免网络拥塞崩溃的重要机制. TCP 拥塞控制的原理:TCP 通过监测网络中的拥塞迹象,如分组丢失.延迟增加等,来调整发送端的数据发送速率.当网络出现拥塞时 ...

  9. [题解]P2444 [POI2000] 病毒

    P2444 [POI2000] 病毒 题目核心是多模式匹配,所以考虑用对所有模式串建立AC自动机. 我们把自动机上,存在一个模式串作为前缀的节点,称作"危险节点". 如果无限长的安 ...

  10. MySQL 误操作时进行数据恢复

    binlog2sql binlog2sql 是一款用于解析 binlog 的工具, 可以从MySQL binlog解析出你要的SQL. 根据不同选项,你可以得到原始SQL.回滚SQL.去除主键的INS ...