LLaMA (以LLaMA2为例,文末附加对比1 2 3 三个版本的变化)
补充背景:
关于Transformer和Llama架构的演进
一、背景
LLaMA 2 和 LLaMA2-Chat
参数规模:70亿、130亿和700亿
数据和训练规模:
上下文长度
训练资源
性能表现:
二、预训练 pretraining
1. 预训练数据
· 训练语料来自公开课用的数据源,不包括Meta的产品或服务数据
· 在2万亿个数据tokens上进行了训练
· 对真实的数据源进行上采样以提高只是并减少错误
2. 训练细节
2.1 标准的Transformer架构
2.2 RMSNorm归一化
2.3 SwiGLU激活函数
2.4 RoPE 旋转位置编码
import torch
import torch.nn as nn
class LlamaRotaryEmbedding(nn.Module):
"""
计算 RoPE(旋转位置编码)所需的 cos(θ) 和 sin(θ) 值
"""
def __init__(self, dim, base=10000):
"""
初始化 LlamaRotaryEmbedding,计算逆频率 inv_freq
:param dim: 需要旋转的位置编码维度(通常是 head_dim)
:param base: 位置编码的基数(通常是 10000)
"""
super().__init__()
# 计算逆频率 inv_freq,维度为 [dim/2]
# 公式:θ_i = 10000^{-2(i-1)/d}
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
self.register_buffer("inv_freq", inv_freq) # 不作为模型参数,但存入模型权重
def forward(self, x, position_ids):
"""
计算 cos(θ) 和 sin(θ) 值
:param x: 输入 tensor(仅用于获取 batch 形状)
:param position_ids: 位置索引 (batch_size, seq_len)
:return: cos(θ), sin(θ)
"""
# 扩展 inv_freq 以匹配 position_ids 的 batch 维度
position_ids = position_ids.unsqueeze(-1) # 形状变为 (batch_size, seq_len, 1)
freqs = torch.einsum("bi,j->bij", position_ids.float(), self.inv_freq) # 计算 mθ
emb = torch.cat((freqs, freqs), dim=-1) # 复制 freq,使其维度与 embedding 维度匹配
# 计算 cos(θ) 和 sin(θ),形状 (batch_size, seq_len, head_dim)
cos = emb.cos()
sin = emb.sin()
return cos, sin
def rotate_half(x):
"""
实现向量的旋转变换
例如:输入 [x1, x2, x3, x4] -> 输出 [-x2, x1, -x4, x3]
:param x: 输入 tensor 形状 (batch_size, seq_len, num_heads, head_dim)
:return: 旋转后的 tensor
"""
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] # 拆分 tensor
return torch.cat((-x2, x1), dim=-1) # 交换并加负号
def apply_rotary_pos_emb(q, k, cos, sin):
"""
计算旋转位置编码后的 Q 和 K
:param q: Query (batch_size, num_heads, seq_len, head_dim)
:param k: Key (batch_size, num_heads, seq_len, head_dim)
:param cos: cos(θ) 形状 (batch_size, seq_len, head_dim)
:param sin: sin(θ) 形状 (batch_size, seq_len, head_dim)
:return: 旋转编码后的 Q, K
"""
# 进行旋转变换
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class LlamaSdpaAttention(nn.Module):
"""
LLaMA 的注意力机制,包含 RoPE 旋转位置编码
"""
def __init__(self, embed_dim, num_heads):
"""
初始化注意力层
:param embed_dim: 总 embedding 维度
:param num_heads: 头数
"""
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads # 每个头的维度
# 线性变换层
self.q_proj = nn.Linear(embed_dim, embed_dim) # Query 投影
self.k_proj = nn.Linear(embed_dim, embed_dim) # Key 投影
self.v_proj = nn.Linear(embed_dim, embed_dim) # Value 投影
# RoPE 旋转位置编码
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
def forward(self, hidden_states, position_ids):
"""
前向传播,计算注意力并应用 RoPE
:param hidden_states: 输入 tensor (batch_size, seq_len, embed_dim)
:param position_ids: 位置索引 (batch_size, seq_len)
:return: 旋转编码后的 query_states, key_states, value_states
"""
# 计算 Q, K, V
query_states = self.q_proj(hidden_states) # (batch_size, seq_len, embed_dim)
key_states = self.k_proj(hidden_states) # (batch_size, seq_len, embed_dim)
value_states = self.v_proj(hidden_states) # (batch_size, seq_len, embed_dim)
# 重新调整形状以匹配多头注意力
batch_size, seq_len, _ = hidden_states.shape
query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 计算 RoPE 旋转角度
cos, sin = self.rotary_emb(hidden_states, position_ids) # 计算 cos(θ), sin(θ)
# 应用 RoPE 旋转编码
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
return query_states, key_states, value_states
# 测试代码
if __name__ == "__main__":
batch_size = 2
seq_len = 4
embed_dim = 16
num_heads = 4
# 生成输入数据
hidden_states = torch.rand(batch_size, seq_len, embed_dim)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) # (batch_size, seq_len)
# 实例化注意力层
attention_layer = LlamaSdpaAttention(embed_dim, num_heads)
# 计算 RoPE 旋转后的 Q, K, V
query_states, key_states, value_states = attention_layer(hidden_states, position_ids)
# 输出结果
print("Query States:", query_states.shape)
print("Key States:", key_states.shape)
print("Value States:", value_states.shape)
2.5 GQA 分组查询注意力
2.6 Tokenizer分词器
三、微调 fine-tuning
1. 有监督微调 SFT
2. 基于人工反馈的强化学习 RLHF
3. 多轮对话中保持一致性的系统消息
四、LLaMA的前世今生(LLaMA1,2,3)
Llama1
动机:Meta认为推理成本更重要,所以提高数据量而不是模型大小,因为训练只需要一次,而推理是无数次的
具体行动:针对Transformer-decoder架构,做了以下修改:
和GPT-3一样将Normalization从每个子层的输出位置移动到了输入位置
将Layer Norm 改为 RMS Norm
动机:进行Norm时,对特征进行平移并不能改变特征的分布,所以可以去掉平移相关的部分
Note: 平移相关的部分指的是:
a. 输入特征-均值 \(x - E[x]\)
b. 对标准化后进行线性变化偏差的参数 \(\beta\)
LayerNorm(层归一化)
\(
\text{LayerNorm}(x) = \frac{x - E[x]}{\sqrt{\text{Var}[x] + \epsilon}} * \gamma + \beta
\)
RMSNorm(均方根归一化)
\(\
\text{RMSNorm}(x) = \frac{x}{\sqrt{\text{Mean}(x^2) + \epsilon}} * \gamma
\)
Note:
\(\text{Var}(x) = E[x^2] - (E[x])^2\)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True) # 计算均方值
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # 归一化
return self.weight * hidden_states.to(input_dtype) # 乘以可训练参数
采用旋转位置编码
采用silu激活函数
\(
\text{silu}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}
\)
其中,$\sigma(x) = \frac{1}{1 + e^{-x}} $是 Sigmoid 函数。
- 这个函数是输入值 $x $乘以其 Sigmoid 值的结果。
- 猛一看和Relu比较像, 不同的是它不像 ReLU 那样直接截断负值,而是在负数区域仍有非零梯度。
Llama2
70B模型训练了172万GPU小时相当于2048个GPU训练35天
2.引入了GQA(Group Query Attention)
减小模型参数量和kv cache的大小
是左右的折中
Llama2只有70B做了GQA
Llama3
字典从三万2000个Token扩充4倍,提高推理效率,原来一个中文被编码为多个token,现在只需要1一个token,推理次数就减少了。
从仅聊天-->指令跟随
LLaMA (以LLaMA2为例,文末附加对比1 2 3 三个版本的变化)的更多相关文章
- Visual Studio Code-批量在文末添加文本字段
小技巧一例,在vs code或notepad++文末批量添加文本字段信息,便于数据信息的完整,具体操作如下: Visual Studio Code批量添加"@azureyun.com&quo ...
- Angular 2的12个经典面试问题汇总(文末附带Angular測试)
Angular作为眼下最为流行的前端框架,受到了前端开发者的普遍欢迎.不论是初学Angular的新手.还是有一定Angular开发经验的开发者,了解本文中的12个经典面试问题,都将会是一个深入了解和学 ...
- 寻找一把进入 Alibaba Sentinel 的钥匙(文末附流程图)
经过前面几篇文章的铺垫,我们正式来探讨 Sentinel 的 entry 方法的实现流程.即探究进入 Alibaba Sentinel 核心的一把钥匙. @ 目录 1.SphU.entry 流程分析 ...
- Angular 2的12个经典面试问题汇总(文末附带Angular测试)
Angular作为目前最为流行的前端框架,受到了前端开发者的普遍欢迎.不论是初学Angular的新手,还是有一定Angular开发经验的开发者,了解本文中的12个经典面试问题,都将会是一个深入了解和学 ...
- 30分钟玩转Net MVC 基于WebUploader的大文件分片上传、断网续传、秒传(文末附带demo下载)
现在的项目开发基本上都用到了上传文件功能,或图片,或文档,或视频.我们常用的常规上传已经能够满足当前要求了, 然而有时会出现如下问题: 文件过大(比如1G以上),超出服务端的请求大小限制: 请求时间过 ...
- C# 30分钟完成百度人脸识别——进阶篇(文末附源码)
距离上次入门篇时隔两个月才出这进阶篇,小编惭愧,对不住关注我的卡哇伊的小伙伴们,为此小编用这篇博来谢罪. 前面的准备工作我就不说了,注册百度账号api,创建web网站项目,引入动态链接库引入. 不了解 ...
- 文末福利丨i春秋互联网安全校园行第1站精彩回顾
活动背景 为响应国家完善网络安全人才培养体系.推动网络安全教育的号召,i春秋特此发起“互联网安全校园行”系列活动.旨在通过活动和知识普及提升大学生信息安全意识,并通过线下交流.技能分享.安全小活动以及 ...
- i春秋官网4.0上线啦 文末有福利
爱瑞宝地(Everybody)期待了很久的 i春秋官网4.0上线啦 除了产品的功能更加完善 性能和体验也将大幅度提高 清新.舒适的视觉感受 搭配更加便捷的操作流程 只需一秒,扫码立即登录 即刻进入网络 ...
- Angular的12个经典问题,看看你能答对几个?(文末附带Angular测试)
Angular作为目前最为流行的前端框架,受到了前端开发者的普遍欢迎.不论是初学Angular的新手,还是有一定Angular开发经验的开发者,了解本文中的12个经典面试问题,都将会是一个深入了解和学 ...
- 文末有福利 | IT从业者应关注哪些技术热点?
7月14-15日,MPD工作坊北京站即将开幕,目前大会日程已经出炉,来自各大企业的技术专家,按照软件研发中心的岗位职能划分,从产品运营.团队管理.架构技术.自动化运维等领域进行干货分享,点击此[链接] ...
随机推荐
- SQLSTATE[HY000] [2002] Connection refused报错 PHP连接docker容器中的mysql
Laradock 是基于 Docker 提供的完整 PHP 本地开发环境 在框架中连接 MySQL 时 报错 SQLSTATE[HY000] [2002] Connection refused 主要还 ...
- OpenGL与GLSL各版本对应说明
OpenGL 4.6 (API Core Profile) (API Compatibility Profile) OpenGL Shading Language 4.60 Specification ...
- K8S组件详解
K8S的控制平面.和工作节点是集群正常运行的核心,通过这两部分的协同工作,K8S才能够实现高效的容器编排.管理.和自动化运维. K8S Kubernetes(简称K8s),是一个开源的容器编排平台,用 ...
- AIX操作系统基本命令
1,内核 bootinfo -k 2,硬件 bootinfo -r lscfg |grep proc lspv lscfg 3,操作系统 oslevel -r oslevel -s uname ...
- 项目管理知识体系指南(PMBOK 指南)
项目管理知识体系指南(PMBOK 指南) 第6版--笔记项目管理十大知识领域,五大管理过程组,49个过程.如下表格:项目:项目的定义 : (Project Management Institute)项 ...
- 网页P图
此篇文章记录一段比较好玩的网页P图代码 1.在你要修改的网页上Fn + F12或者F12打开控制台,然后在console里输入这样一段代码,回车 document.designMode = 'on' ...
- 【调研】Vision Language Model Safety
Adversarial Attacks White-box Attacks Task-specific Attacks 的目标是针对某个具体的任务(如图像描述生成.指代表达理解等),通过精心设计的对抗 ...
- 简单实现Android的本地文件读写,暨将List数据保存到Json文件中并读出
一.让我们从引入依赖开始 //将这两行代码添加到以上位置,其他的一般不用管 implementation 'com.google.code.gson:gson:2.8.5' implementatio ...
- P3392 涂国旗 题解
题目大意 题目真的是不说人话...... 有一个国家的国旗是由一个 N * M 的方格组成的.如果想要这面国旗合法,就必须满足要求: 国旗从上到下必须是白色.蓝色和红色,顺序不能改变. 每一种颜色都至 ...
- 多模态模型 Grounding DINO 初识
简介 Grounding DINO 是一种先进的零样本目标检测模型,由 IDEA Research 开发.它通过将基于 Transformer 的检测器 DINO 与Grounded Pre-Trai ...