Transformer 结构分析
self-attetion
1. 输入
X.shape == (batch\_size, seq\_len, embedding\_dim)
\]
2. 计算Q,K,V
K = Linear(X) = XW_{K} \\
V = Linear(X) = XW_{V} \\
\\
W == (embedding\_dim, embedding\_dim) \\
Q, K, V == (batch\_size, seq\_len, embedding\_dim)
\]
3. 处理多头
将最后一维(embedding_dim) 拆成h份,需要保证embedding_dim能够被h整除。每个tensor的最后两个维度表示一个头,QKV各自都有h个头,接下来需要把这些头分别进行计算
4. 计算
按顺序取出上图中的一组QKV,计算:
\]
\]
\]
\]
(1)计算得到各个字之间的关系(相似度).这里的d的维度是
(batch_size, h, seq_len, embedding_dim) * (batch_size, h, embedding_dim, seq_len)==>(batch_size, h, seq_len, seq_len)
。QKV分别有batch_size * h
个矩阵,可以认为是在一个(batch_size, h)
的棋盘中,每个位置放置了一个大小为(seq_len, embedding_dim)
的矩阵。这里的前两个维度不变只是把棋盘中对应位置的矩阵拿出来做矩阵乘法,并把结果再放回到棋盘中。(2)用mask矩阵遮盖掉超出句子长度的部分。将句子中用来pading的字符全部替换成 inf, 这样 计算softmax的时候它们的值会为0,就不会参与到接下来与V的计算当中
(3) \(d_k\) 是为了改变已经偏离的方差。我的理解是,由于矩阵转置后相乘会有很多内积运算,而内积运算将\(d_k\)个数相加时会改变数据的分布。而这个分布的趋势是 \(mean=0, variance=d_k\)。为了使方差回归到1,把所有结果都除上一个\(\sqrt{d_k}\),这样求平方时会抵消已有的方差\(d_k\)
# 均值为0,方差为1
a = np.random.randn(2,3000)
b = np.random.randn(3000,2)
c = a.dot(b) print(np.var(a))
print(np.mean(c))
print(np.var(c)) # 1.0262973662546435
# 25.625943965792157
# 1347.432397285718
To illustrate why the dot products get large, assume that the components of q and k are independent random variables with > mean 0 and variance 1. Then their dot product, \(q \cdot k=\sum_{i=1}^{d_{k}} q_{i} k_{i}\), has mean 0 and variance dk.
(4)计算各个词义所占的比例 \(d \cdot v\),按照权重融合了各个字的语义。最后将多个头的结果拼接成一个完成的embedding作为self-attendion的输出。
(batch_size, h, seq_len, seq_len)
*batch_size, h, seq_len, embedding/h
部分代码如下:
# (batch, seq_len, h, embed/head) -> (batch, h, seq_len, embed/head)
q = self.qry(y).view(y.size(0), y.size(1), self.head, -1).transpose(1, 2)
k = self.key(x).view(x.size(0), x.size(1), self.head, -1).transpose(1, 2)
v = self.val(x).view(x.size(0), x.size(1), self.head, -1).transpose(1, 2)
d = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1)) # 相似度 (batch , h, seq, seq)
d = d.masked_fill(m, -float('inf')) # 把所有为true的地方替换成inf,这里是遮盖掉句子内部的pad
a = F.softmax(d, dim=-1) # (batch , h, seq, seq)
# (batch , h, seq_len, seq_len) * (batch, h, seq_len, embedding/h)
# => (batch, h, seq_len, embedding/h)
# => (batch, seq_len, h, embedding/h)
c = torch.matmul(a, v).transpose(1, 2)
# (batch, seq_len, embedding)
c = c.contiguous().view(c.size(0), c.size(1), -1)
结构图
Encoder的完整过程:
1). 字向量与位置编码:
\]
\]
2). 自注意力机制:
\]
\]
\]
\]
3). 残差连接与\(Layer \ Normalization\)
\]
\]
4). 两层线性映射并用激活函数激活, 比如说\(ReLU\):
\]
5). 重复3).:
\]
\]
\]
Decoder 的完整过程
1). 输入数据
- 输入y的embedding:
X \in ({batch\_size * seq\_len * embed\_dim} )
\]
- encoder层的输出
h \in (batch\_size * seq\_len * embed\_dim)
\]
- mx: x的mask;遮盖住pad的部分,替换为inf,这样计算softmax就会变成0,不会影响后面的计算
def get_pad(self, x):
"""
根据句子的实际长度获取句子的句子的mask。用于计算attention的mask,它不是对角矩阵
维度是 (batch, head, seq_len, seq_len)
:param x:
:return: mask (batch, head, seq_len, seq_len)
"""
seq_len = x.size(1)
pad = (x == 0)
for _ in range(2):
pad = torch.unsqueeze(pad, dim=1)
return pad.repeat(1, self.head, seq_len, 1)
- my: y的mask;用于mask-self-attention,先经过和x的一样的mask过程,再用对角矩阵进行mask,这样在进行训练的时候,只能看到当前字和当前字之前的字。这里的mask是一个对角矩阵,它的形状类似下面这样:
torch.triu(torch.ones(seq_len, seq_len).byte(), diagonal=1) # [0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
# [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
# [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
# [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
# [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
# [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
# [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
# [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
def get_att(head: int, seq_len: int):
"""
计算mask self attention的mask,对角矩阵
:param head: int
:param seq_len: int
:return:
"""
# 上三角矩阵, 不保留对角线
att = torch.triu(torch.ones(seq_len, seq_len).byte(), diagonal=1)
for _ in range(2):
# torch.squeeze() 删掉维度为1的维度:(1,3)==> (3)
# torch.unsqueeze() 扩充维度,在指定位置加上维数为1的维度:(3)==> (1,3)
att = torch.unsqueeze(att, dim=0)
# 像瓦片一样平铺
return att.repeat(1, head, 1, 1)
2). 多层 decoder Layer结构
- mask-self-attention + 残差 + LayerNorm; y经过mask之后含义已经改变,每一行表示当前词和之前的语义,表示的是某一时刻的可以获得的语义。比如0时刻只能获得第一个单词的语义,而第二个时刻可以获得前两个单词的语义。
mask_self_attention
得到的结果,每一行就是一个时刻包含的语义关系,表示我当前已经翻译出的单词的语义。
y = LayerNorm(y + r)
\]
- self-attention + 残差 + LayerNorm,这里每一层decoder layer的数据都来自encoder的输出x,x经过变换生成K,V,用当前的y计算得到Q。然后计算Q和K的相似度再应用到V上就是结果; 这里的 \(Q_y, K_x, V_x\)就类似于seq2seq中的attention,把每个时刻的y和所有的x进行内积运算,找到每个x的权重再从所有的x中抽取需要的信息。一个\(Q_y\)已经包含了decoder中的所有时刻。最后得到的结果表示的是,每个时刻应该从encoder中抽取哪些信息。\(y_0\)的shape是
(batch_size, h, seq_len, embedding/h)
.
y = LayerNorm(y + r)
\]
- 激活层:
y = LayerNorm(y_0 + y)
\]
class DecodeLayer(nn.Module):
def __init__(self, embed_len, head):
super(DecodeLayer, self).__init__()
self.head = head
self.qrys = nn.ModuleList([nn.Linear(embed_len, embed_len / head) for _ in range(2)])
self.keys = nn.ModuleList([nn.Linear(embed_len, embed_len / head) for _ in range(2)])
self.vals = nn.ModuleList([nn.Linear(embed_len, embed_len / head) for _ in range(2)])
self.lal = nn.Sequential(nn.Linear(embed_len, embed_len),
nn.ReLU(),
nn.Linear(embed_len, embed_len))
self.lns = nn.ModuleList([nn.LayerNorm(embed_len) for _ in range(3)])
def mul_att(self, x, y, m, i):
# q (batch, seq_len, head, embed/head) -> (batch, head, seq_len, embed/head)
q = self.qrys[i](y).view(y.size(0), y.size(1), self.head, -1).transpose(1, 2)
k = self.keys[i](x).view(x.size(0), x.size(1), self.head, -1).transpose(1, 2)
v = self.vals[i](x).view(x.size(0), x.size(1), self.head, -1).transpose(1, 2)
# (batch, head, seq_len, embed/head)
d = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
d = d.masked_fill(m, -float('inf'))
a = F.softmax(d, dim=-1)
# (batch , h, seq_len, seq_len) * (batch, h, seq_len, embedding/h)
# => (batch, h, seq_len, embedding/h)
# => (batch, seq_len, h, embedding/h)
c = torch.matmul(a, v).transpose(1, 2)
c = c.contiguous().view(c.size(0), c.size(1), -1)
return c
def forward(self, y, x, my, mx):
"""
:param y: 带上positional encoder的embedding。 (batch, seq_len, embedding)
:param x: encoder的输出 (batch, seq_len, embedding)
:param my: y 的mask (batch, head, seq_len, seq_len)
:param mx: x 的mask (batch, head, seq_len, seq_len)
:return:
"""
r = y # 暂时保存用于计算残差网络
y = self.mul_att(y, y, my, 0)
y = self.lns[0](y + r)
r = y
y = self.mul_att(x, y, mx, 1)
y = self.lns[1](y + r)
r = y
y = self.lal(y)
return self.lns[2](y + r)
3)输出:
logits = softmax(y)
\]
待补充
Transformer 结构分析的更多相关文章
- Spatial Transformer Networks(空间变换神经网络)
Reference:Spatial Transformer Networks [Google.DeepMind]Reference:[Theano源码,基于Lasagne] 闲扯:大数据不如小数据 这 ...
- ABBYY PDF Transformer+怎么标志注释
ABBYY PDF Transformer+是一款可创建.编辑.添加注释及将PDF文件转换为其他可编辑格式的通用工具,可用来在PDF页面的任何位置添加注释(关于如何通过ABBYY PDF Transf ...
- OAF_文件系列6_实现OAF导出XML文件javax.xml.parsers/transformer(案例)
20150803 Created By BaoXinjian
- 泛函编程(27)-泛函编程模式-Monad Transformer
经过了一段时间的学习,我们了解了一系列泛函数据类型.我们知道,在所有编程语言中,数据类型是支持软件编程的基础.同样,泛函数据类型Foldable,Monoid,Functor,Applicative, ...
- Facebook的体系结构分析---外文转载
Facebook的体系结构分析---外文转载 From various readings and conversations I had, my understanding of Facebook's ...
- Android项目目录结构分析
Android项目目录结构分析 1.HelloWorld项目的目录结构1.1.src文件夹1.2.gen文件夹1.3.Android 2.1文件夹1.4.assets 1.5.res文件夹1.6.An ...
- 【转载】nedmalloc结构分析
原文:nedmalloc结构分析 nedmalloc是一个跨平台的高性能多线程内存分配库,很多库都使用它,例如:OGRE.现在我们来看看nedmalloc的实现 (以WIN32部分为例) 位操作 ...
- 如何用Transformer+从PDF文档编辑数据
ABBYY PDF Transformer+是一款可创建.编辑.添加注释及将PDF文件转换为其他可编辑格式的通用工具,可使用该软件从PDF文档编辑机密信息,然后再发布它们,文本和图像均可编辑,本文将为 ...
- ABBYY PDF Transformer+ Pro支持全世界189种语言
ABBYY PDF Transformer+ Pro版支持189种语言,包括我们人类的自然语言.人造语言以及正式语言.受支持的语言可能会因产品的版本不同而各异.本文具体列举了所有ABBYY PDF T ...
随机推荐
- 在 C# 中使用 Span<T> 和 Memory<T> 编写高性能代码
目录 在 C# 中使用 Span 和 Memory 编写高性能代码 .NET 中支持的内存类型 .NET Core 2.1 中新增的类型 访问连续内存: Span 和 Memory Span 介绍 C ...
- 第六十九篇:vue项目的运行过程
好家伙, 1.vue的目录结构分析 来看看项目的目录 (粗略的大概的解释) 2.vue项目的运行流程 在工程化项目中,vue要做的事情很单纯:通过main.js把App.vue渲染到index.htm ...
- Linux常用基础命令一
一.目录操作 进入路径 cd [目录地址] 切换回主目录 cd 返回上一个路径 cd - 打印当前路径 pwd 列出目录下文件 ls ---查看只包含非隐藏文件 ls -a -----查看目录下所有文 ...
- 员工离职困扰?来看AI如何解决,基于人力资源分析的 ML 模型构建全方案 ⛵
作者:韩信子@ShowMeAI 数据分析实战系列:https://www.showmeai.tech/tutorials/40 机器学习实战系列:https://www.showmeai.tech/t ...
- KDB_Database_Link 使用介绍
kdb_database_link 是 KingbaseES 为了兼容oracle 语法而开发的跨数据库访问扩展,可用于访问KingbaseES, Postgresql , Oracle .以下分别介 ...
- git reset总结
git reset git 的重置操作 有三种模式:hard.mixed(默认).soft 1. hard 用法 hard会重置stage区和工作区,和移动代码库上HEAD 和branch的指针所指向 ...
- Docker_删除所有容器
删除所有容器 docker rm `docker ps -aq`
- FreeSql 导入数据的各种场景总结 [C#.NET ORM]
前言 导入数据这种脏活.累活,相信大家多多少少都有经历,常见的场景有: 同服务器从A表导数据到B表 批量导入新数据 批量新增或更新数据 跨服务器从A表导数据到B表 每种场景有自己的特点,我们一般会根据 ...
- logstash接受checkpoint防火墙日志并用ruby分词
直接上logstahs配置文件 input{ syslog{ type => "syslog" port => 514 } } filter { grok { matc ...
- java的URI和URL的关系
java的URI和URL到底是什么 在我们做开发时,经常有URI和URL弄混的问题,如果当时直接看URI和URL的源码就不可能弄混.首先我总结一下URI和URL的关系:他们的关系是:URL是一种特殊的 ...