1. 背景

根据本qiang~最新的趋势观察,基于MoE架构的开源大模型越来越多,比如马斯克的Grok-1(314B), Qwen1.5-MoE-A2.7B等,因此想探究一下MoE里面的部分细节。

此文是本qiang~针对大语言模型的MoE的整理,包括原理、流程及部分源码。

2. MoE原理

MoE的流行源于”欧洲的OpenAI” Mistral AI发布的论文及模型《Mixtral of Experts》,评测集上的效果吊打众多开源模型,如Llama 2 70B和GPT3.5。

《Mixtral of Experts》基础模型使用的是Mistral AI自研的Mistral 7B,该模型的特点包括:滑窗注意力(Sliding Window Aattention), 滚动缓冲区缓存(Rolling Buffer Cache)以及预填充-分块(Pre-fill and Chunking),具体细节可以查阅文末的论文地址。

本文以《Mixtral of Experts》为引子,探究MoE的相关细节,MoE的原理如下图所示:

图2.1 MoE的原理

(1) Transformers架构中的每一层中的FFN网络均替换为了8个FFN(专家),且由一个网关路由(gate router)进行控制

(2) 针对每一个token,每一层的网关路由仅选择其中的2个FFN(专家)来处理当前状态并进行加权输出

(3) 结果就是,每一个token访问了47B参数,但是在推理阶段仅仅使用了13B的激活参数(即,只使用2个专家,冻结其他6个专家)。

(4) 与Dropout机制对比,Dropout让部分神经元失活,而MoE是让部分专家失活。

3. 源码

本qiang~研读并尝试执行了Mistral官网的github推理代码,该代码框架非常适合新手,无他,只因其几乎只是在torch上层做的封装,很少引擎其他第三方库,不像transformers,功能强大,但不适合新手研读代码…

为了普适性,下面的代码截取了transformers框架中的代码。

首先看下通用Transformers中FFN中的代码模块,代码位置在transformers.models.mistral.modeling_mistral, 主要流程是:

(1) 先经过gate_proj和up_proj的2个[hidden_size, intermediate_size]的线性转换

(2) 使用激活函数对gate_proj进行激活

(3) 二者的内积再经过down_proj线性转换。

 1 class MistralMLP(nn.Module):
2 def __init__(self, config):
3 super().__init__()
4 self.config = config
5 self.hidden_size = config.hidden_size
6 self.intermediate_size = config.intermediate_size
7 self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
8 self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
9 self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
10 self.act_fn = ACT2FN[config.hidden_act]
11
12 def forward(self, x):
13 return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

再来看下MoE中的专家模块,代码位置在transformers.models.mixtral.modeling_mixtral,主要流程是:

(1) 首先经过网关路由self.gate

(2) 然后选择其中2个专家,并归一化

(3) 之后遍历每个专家网络,并按照expert_mask进行筛选

(4) 如果expert_mask有值,则选择指定部分的隐藏层进行FFN操作,且输出结果进行加权

(5) 最后原地增加先前初始化的最终结果变量final_hidden_states

class MixtralSparseMoeBlock(nn.Module):

    def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok # gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
) # One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx]) if top_x.shape[0] == 0:
continue # in torch it is faster to index using lists than torch tensors
top_x_list = top_x.tolist()
idx_list = idx.tolist() # Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] # However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits

其中MixtralBlockSparseTop2MLP代码如下,可以看到和传统MistralMLP内容完全一致。

class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states

4. MoE微调

由于MoE只是将每一层的FFN改变为了每一层的gate网关路由+8个FFN专家,且gate网关路由和8个专家内部均为线性运算,所以可以无缝地结合LoRA、QLoRA进行指令微调。

可以参考开源项目:https://github.com/yangjianxin1/Firefly

5. 答疑解惑

(1) 问:MoE 8*7B的模型是56B参数?

答:MoE 8*7B的参数量是47B,而不是56B,原因是每一层除了8个专家网络外,其他层均是复用的。

(2) 问:MoE的基础模型是Mistral 7B?

答:不是,MoE的模型架构与Mistral 7B相同,但其中的FFN替换为了8个FFN,且MoE是基于多语言数据集预训练而来的。

(3) MoE的稀疏性(sparse)体现在哪里?

答:在训练和推理时,同时只有两个专家网络会被激活,进行前向计算,其它专家网络处于失活状态。

6. 总结

一句话足矣~

本文主要针对大语言模型的MoE,包括原理及部分源码。

此外,建议大家可以针对源码进行运行,关于源码,欢迎大家一块交流。

7. 参考

(1) Mistral 7B:https://arxiv.org/pdf/2310.06825v1.pdf

(2) MoE: https://arxiv.org/pdf/2401.04088v1.pdf

(3) MoE开源指令微调框架Firefly: https://github.com/yangjianxin1/Firefly

LLM面面观之MoE的更多相关文章

  1. 2013成都网络赛 C We Love MOE Girls(水题)

    We Love MOE Girls Time Limit: 1000/500 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others) ...

  2. 与你相遇好幸运,The Moe Node.js Code Style Guide

    The Moe Node.js Code Style Guide  By 一个最萌的开发者 @2016.9.21 >>代码是人来阅读的,格式规范的代码是对编程人员最好的礼物 :) > ...

  3. 比Redis更快:Berkeley DB面面观

    比Redis更快:Berkeley DB面面观 Redis很火,最近大家用的多.从两年前开始,Memcached转向Redis逐渐成为潮流:而Berkeley DB可能很多朋友还很陌生,首先,我们简单 ...

  4. 《火球——UML大战需求分析》(第2章 耗尽脑汁的需求分析工作)——2.1 需求分析面面观

    说明: <火球——UML大战需求分析>是我撰写的一本关于需求分析及UML方面的书,我将会在CSDN上为大家分享前面几章的内容,总字数在几万以上,图片有数十张.欢迎你按文章的序号顺序阅读,谢 ...

  5. HDU 4730 We Love MOE Girls (2013成都网络赛,签到水题)

    We Love MOE Girls Time Limit: 1000/500 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others) ...

  6. 使用Python进行多线程检查.moe三位剩余有效域名

    翻看博客看到一段不错的代码 虽然近期没有购买域名的需求 不过日后有购买域名的需求的话 稍作修改直接使用还是很方便的 import threading import requests import js ...

  7. Hugging Face 每周速递: Chatbot Hackathon;FLAN-T5 XL 微调;构建更安全的 LLM

    每一周,我们的同事都会向社区的成员们发布一些关于 Hugging Face 相关的更新,包括我们的产品和平台更新.社区活动.学习资源和内容更新.开源库和模型更新等,我们将其称之为「Hugging Ne ...

  8. 微软开源了一个 助力开发LLM 加持的应用的 工具包 semantic-kernel

    在首席执行官萨蒂亚·纳德拉(Satya Nadella)的支持下,微软似乎正在迅速转变为一家以人工智能为中心的公司.最近微软的众多产品线都采用GPT-4加持,从Microsoft 365等商业产品到& ...

  9. Semantic Kernel 入门系列:🛸LLM降临的时代

    不论你是否关心,不可否认,AGI的时代即将到来了. 在这个突如其来的时代中,OpenAI的ChatGPT无疑处于浪潮之巅.而在ChatGPT背后,我们不能忽视的是LLM(Large Language ...

  10. Schillace法则:使用LLM创建软件的最佳实践

    LLM(大语言模型)的发展正在改变软件开发的方式. 以前,开发人员需要编写大量的代码来实现其意图,但现在,随着语言模型的发展,开发人员可以使用自然语言来表达他们的意图,而无需编写大量的代码.这使得软件 ...

随机推荐

  1. 【Android 逆向】【攻防世界】Ph0en1x-100

    1. apk 安装到手机,老套路需要输入flag 2. jadx 打开apk,没有加壳 ...... public void onGoClick(View v) { String sInput = t ...

  2. 文心一言 VS 讯飞星火 VS chatgpt (202)-- 算法导论15.3 1题

    一.对于矩阵链乘法问题,下面两种确定最优代价的方法哪种更高效?第一种方法是穷举所有可能的括号化方案,对每种方案计算乘法运算次数,第二种方法是运行RECURSIVE-MATRIX-CHAIN.证明你的结 ...

  3. Arrays.asList的坑

    Arrays.asList 方法的坑 此方法接受可变个数的参数 构建一个ArrayList 可此ArrayList 非彼ArrayList ,他返回的是 Arrays 的一个内部类,实现了Abstra ...

  4. IDEA关联Tomcat(详细教程+安装包)

    IDEA关联Tomcat 下载Tomcat安装包并解压到全英文目录 第一步:打开IDEA--Settings 第二步:搜索application--进入Application Services--点击 ...

  5. zookeeper源码(10)node增删改查及监听

    本文将从leader处理器入手,详细分析node的增删改查流程及监听器原理. 回顾数据读写流程 leader ZookeeperServer.processPacket封装Request并提交给业务处 ...

  6. 重新定义 vscode 命令行工具 code命令 code $profile

    vscode 默认命令行有问题 他那个每次都打开cli.js 目录名里面有空格 要 &开头后面跟双引号 所以从新定义后 变量是 $变量名 前面再加上& 就能调用那个exe了 后面再跟上 ...

  7. vscode 格式化 vue 和 js代码 vetur prettier beautify

    这个文档 不涉及eslint 只专注自动格式化 格式化个性化需求: js中 自动去分号 js中 双引号变单引号 最大空换行数 是2 vue template中 属性自动折行 vue 的自动格式化 需要 ...

  8. 【LLM】大模型落地-从理论到实践

    简述 按个人偏好和目标总结了学习目标和路径(可按需学习),后续将陆续整理出相应学习资料和资源. 学习目标 熟悉主流LLM(Llama, ChatGLM, Qwen)的技术架构和技术细节:有实际应用RA ...

  9. python中记录打印的log模块logging的用法实例

    日志基础教程   日志是对软件执行时所发生事件的一种追踪方式.软件开发人员对他们的代码添加日志调用,借此来指示某事件的发生.一个事件通过一些包含变量数据的描述信息来描述(比如:每个事件发生时的数据都是 ...

  10. Spring Boot命令指定环境启动jar包

    原文地址:Spring Boot命令指定环境启动jar包 - Stars-One的杂货小窝 记下通过命令行的方式去改变spring boot项目中的环境配置信息 命令 项目中有以下配置 applica ...