一、Transformer

Transformer最开始用于机器翻译任务,其架构是seq2seq的编码器解码器架构。其核心是自注意力机制: 每个输入都可以看到全局信息,从而缓解RNN的长期依赖问题。
输入: (待学习的)输入词嵌入 + 位置编码(相对位置)
编码器结构: 6层编码器: 一层编码器 = 多头注意力+残差(LN) + FFN+残差(LN)
输出:每一个位置上输出预测概率分布(K类类别分布)

1.1 自注意力

分解式

缩放内积注意力
1. 自注意力的优势
         a. 计算开销,计算可并行 (嵌入维度d,序列长度n,计算复杂度O(n^2d))
         b. 建模长期依赖 (稳定训练过程)
2. 自注意力缩放(内积过大,softmax饱和)
We suspect that for large values d_k, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients.To counteract this effect, we scale the dot products by sqrt(d_k)
如上为原文。作者怀疑,如果Q和K的维度特别大,会使得内积后的值也大。从而使softmax进入梯度极小的区域(类似于sigmoid的饱和区域)。 这样容易导致梯度消失。
所以,他们将内积值除以sqrt(d_k),进行一个缩放,而又不破坏相对比例。
 
多头注意力机制(multi-head attention)
Transformer 提出多头注意力机制(不同头结果拼起来,再做线性变换),增强了 attention 层的能力(参数量不变)。解释:
  1. 它扩展了模型关注不同位置的能力。不同注意力头,关注不同的位置。长距离依赖
  2. 多头注意力机制赋予 attention 层多个“子表示空间(训练之后,每组注意力可以看作是把输入的向量映射到一个”子表示空间“)
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)
参数说明如下:
  • embed_dim:最终输出的 K、Q、V 矩阵的维度,这个维度需要和词向量的维度一样
  • num_heads:设置多头注意力的数量。如果设置为 1,那么只使用一组注意力。如果设置为其他数值,那么 num_heads 的值需要能够被 embed_dim 整除
  • dropout:这个 dropout 加在 attention score 后面
定义 MultiheadAttention 的对象后,调用时传入的参数如下。
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)
  • query:对应于 Query 矩阵,形状是 (L,N,E) 。其中 L 是输出序列长度,N 是 batch size,E 是词向量的维度
  • key:对应于 Key 矩阵,形状是 (S,N,E) 。其中 S 是输入序列长度,N 是 batch size,E 是词向量的维度
  • value:对应于 Value 矩阵,形状是 (S,N,E) 。其中 S 是输入序列长度,N 是 batch size,E 是词向量的维度
  • key_padding_mask:如果提供了这个参数,那么计算 attention score 时,忽略 Key 矩阵中某些 padding 元素,不参与计算 attention(序列长度不同)。形状是 (N,S)。其中 N 是 batch size,S 是输入序列长度。
    • 如果 key_padding_mask 是 ByteTensor,那么非 0 元素对应的位置会被忽略
    • 如果 key_padding_mask 是 BoolTensor,那么 True 对应的位置会被忽略
  • attn_mask:计算输出时,忽略某些位置。形状可以是 2D (L,S),或者 3D (N∗numheads,L,S)。其中 L 是输出序列长度,S 是输入序列长度,N 是 batch size。
    • 如果 attn_mask 是 ByteTensor,那么非 0 元素对应的位置会被忽略
    • 如果 attn_mask 是 BoolTensor,那么 True 对应的位置会被忽略
在实际中,K、V 矩阵的长度一样,而 Q 矩阵的序列长度可不一样。这种情况发生在:在解码器部分的encoder-decoder attention层中,Q 矩阵是来自解码器下层,而 K、V 矩阵则是来自编码器的输出。
 

2. Encoder 和 Decoder

 
编码器就是编码器层(多头注意力+(残差+LN),FFN+(残差+LN))的堆叠。
 
解码器

Self-attention layers in the decoder allow each position in the decoder to attend to all positions in the decoder up to and including that position.We need to prevent leftward information flow in the decoder to preserve the auto-regressive property.We implement this inside of scaled dot-product attention by masking out (setting to −∞) all values in the input of the softmax which correspond to illegal connections.

      为了保持自回归的性质,要保持从左往右的顺序 (这种情况下,不能利用要预测的未来来推断过去)。  这里将当前token以后的进行mask (即将注意力得分加上-inf,将其变成无穷小,使其注意力系数极小接近于无) [exp(-inf) = 0]
      GAT也是这样做的,只不过mask的是非邻居结点 (避免信息泄露,从而让模型学不好)。
 

避免信息泄露,在解码器中使用mask:

# mask 不为空,那么就把 mask 为 0 的位置的 attention 分数设置为 -1e10(系数无穷小)
attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
if mask is not None:
  attention = attention.masked_fill(mask == 0, -1e10)
  attention = self.do(torch.softmax(attention, dim=-1))
  x = torch.matmul(attention, V)
 
交叉注意力层(Decoder-Encoder Attention, Decoder到Encoder输出的潜表示)
使用前一层的输出来构造 Query 矩阵,而 Key 矩阵和 Value 矩阵来自于编码器最终的输出(seq2seq都是这样的,预测当前输出时,不仅看之前的输出,同时也对输入隐状态进行关注)

 
预测
每一个位置有一个分类损失;总的损失就是每个位置损失之和。
 
训练
让我们假设输出词汇表只包含 6 个单词(“a”, “am”, “i”, “thanks”, “student”, and “”(“”表示句子末尾))。
 
这种架构本就可以用来做语言模型,只不过这里做了seq2seq的翻译。
如果训练数据中本身就有很多句子对,就可以直接通过语言模型实现翻译,例如GPT架构。
 
学习笔记,配图参考知乎-张贤同学、李宏毅机器学习。

【大语言模型基础】-详解Transformer原理的更多相关文章

  1. 学习《深度学习与计算机视觉算法原理框架应用》《大数据架构详解从数据获取到深度学习》PDF代码

    <深度学习与计算机视觉 算法原理.框架应用>全书共13章,分为2篇,第1篇基础知识,第2篇实例精讲.用通俗易懂的文字表达公式背后的原理,实例部分提供了一些工具,很实用. <大数据架构 ...

  2. Java基础学习总结(33)——Java8 十大新特性详解

    Java8 十大新特性详解 本教程将Java8的新特新逐一列出,并将使用简单的代码示例来指导你如何使用默认接口方法,lambda表达式,方法引用以及多重Annotation,之后你将会学到最新的API ...

  3. 深入浅出DOM基础——《DOM探索之基础详解篇》学习笔记

    来源于:https://github.com/jawil/blog/issues/9 之前通过深入学习DOM的相关知识,看了慕课网DOM探索之基础详解篇这个视频(在最近看第三遍的时候,准备记录一点东西 ...

  4. Android中Canvas绘图基础详解(附源码下载) (转)

    Android中Canvas绘图基础详解(附源码下载) 原文链接  http://blog.csdn.net/iispring/article/details/49770651   AndroidCa ...

  5. Python学习二:词典基础详解

    作者:NiceCui 本文谢绝转载,如需转载需征得作者本人同意,谢谢. 本文链接:http://www.cnblogs.com/NiceCui/p/7862377.html 邮箱:moyi@moyib ...

  6. 三剑客基础详解(grep、sed、awk)

    目录 三剑客基础详解 三剑客之grep详解 1.通配符 2.基础正则 3.grep 讲解 4.拓展正则 5.POSIX字符类 三剑客之sed讲解 1.sed的执行流程 2.语法格式 三剑客之Awk 1 ...

  7. Dom探索之基础详解

    认识DOM DOM级别 注::DOM 0级标准实际并不存在,只是历史坐标系的一个参照点而已,具体的说,它指IE4.0和Netscape Navigator4.0最初支持的DHTML. 节点类型 注:1 ...

  8. javaScript基础详解(1)

    javaScript基础详解 首先讲javaScript的摆放位置:<script> 与 </script> 可以放在head和body之间,也可以body中或者head中 J ...

  9. Python学习一:序列基础详解

    作者:NiceCui 本文谢绝转载,如需转载需征得作者本人同意,谢谢. 本文链接:http://www.cnblogs.com/NiceCui/p/7858473.html 邮箱:moyi@moyib ...

  10. java继承基础详解

    java继承基础详解 继承是一种由已存在的类型创建一个或多个子类的机制,即在现有类的基础上构建子类. 在java中使用关键字extends表示继承关系. 基本语法结构: 访问控制符 class 子类名 ...

随机推荐

  1. Hadoop3 No FileSystem for scheme "hdfs"

    Hadoop3 No FileSystem for scheme "hdfs" 异常信息: org.apache.hadoop.fs.UnsupportedFileSystemEx ...

  2. 神经网络优化篇:详解Batch Norm 为什么奏效?(Why does Batch Norm work?)

    Batch Norm 为什么奏效? 为什么Batch归一化会起作用呢? 一个原因是,已经看到如何归一化输入特征值\(x\),使其均值为0,方差1,它又是怎样加速学习的,有一些从0到1而不是从1到100 ...

  3. 【链表】双向链表的介绍和基本操作(C语言实现)【保姆级别详细教学】

    双向链表 文章目录 前言 双向链表的基本介绍 一些链表的分类 带头双向循环链表的基本结构 双向链表的实现 结点的定义.头指针的创建 开辟结点接口 初始化头结点接口 打印接口 尾插接口 尾删接口 头插接 ...

  4. Leetcode刷题第七天-回溯-哈希

    332:重新岸炮行程 链接:332. 重新安排行程 - 力扣(LeetCode) 机场字典:{起飞机场:[到达机场的列表]} 去重:到达机场列表,i>0时,当前机场和上一个机场相等,contin ...

  5. 小团队如何妙用 JuiceFS

    早些年还在 ENJOY 的时候, 就已经在用 JuiceFS, 并且一路伴随着我工作过的四家小公司, 这玩意对我来说, 已经成了理所应当不可或缺的基础设施, 对于我服务过的小团队而言, 更是实实在在的 ...

  6. 《ASP.NET Core 与 RESTful API 开发实战》-- (第7章)-- 读书笔记(中)

    第 7 章 高级主题 7.2 并发 当两个用户获取同一个资源后,再同时修改该资源,就会导致并发问题 常见实现并发的方法有以下两种: 保守式并发控制,每次修改资源,都锁定资源 开放式并发控制,每次修改资 ...

  7. 【SpringBootStarter】自定义全局加解密组件

    [SpringBootStarter] 目的 了解SpringBoot Starter相关概念以及开发流程 实现自定义SpringBoot Starter(全局加解密) 了解测试流程 优化 最终引用的 ...

  8. ubuntu之jupyter notebook配置

    ubuntu之jupyter notebook配置 安装jupyter: 前提:安装pip pip install jupyter jupyter notebook 配置: 生成配置文件: jupyt ...

  9. NC208250 牛牛的最美味和最不美味的零食

    题目链接 题目 题目描述 牛牛为了减(吃)肥(好),希望对他的零食序列有更深刻的了解,所以他把他的零食排成一列,然后对每一个零食的美味程度都打了分,现在他有可能执行两种操作: eat k:吃掉当前的第 ...

  10. python 中异常类型总结

    异常类型: 异常名称 描述BaseException             所有异常的基类SystemExit                   解释器请求退出KeyboardInterrupt  ...