1. 背景

根据本qiang~最新的趋势观察,基于MoE架构的开源大模型越来越多,比如马斯克的Grok-1(314B), Qwen1.5-MoE-A2.7B等,因此想探究一下MoE里面的部分细节。

此文是本qiang~针对大语言模型的MoE的整理,包括原理、流程及部分源码。

2. MoE原理

MoE的流行源于”欧洲的OpenAI” Mistral AI发布的论文及模型《Mixtral of Experts》,评测集上的效果吊打众多开源模型,如Llama 2 70B和GPT3.5。

《Mixtral of Experts》基础模型使用的是Mistral AI自研的Mistral 7B,该模型的特点包括:滑窗注意力(Sliding Window Aattention), 滚动缓冲区缓存(Rolling Buffer Cache)以及预填充-分块(Pre-fill and Chunking),具体细节可以查阅文末的论文地址。

本文以《Mixtral of Experts》为引子,探究MoE的相关细节,MoE的原理如下图所示:

图2.1 MoE的原理

(1) Transformers架构中的每一层中的FFN网络均替换为了8个FFN(专家),且由一个网关路由(gate router)进行控制

(2) 针对每一个token,每一层的网关路由仅选择其中的2个FFN(专家)来处理当前状态并进行加权输出

(3) 结果就是,每一个token访问了47B参数,但是在推理阶段仅仅使用了13B的激活参数(即,只使用2个专家,冻结其他6个专家)。

(4) 与Dropout机制对比,Dropout让部分神经元失活,而MoE是让部分专家失活。

3. 源码

本qiang~研读并尝试执行了Mistral官网的github推理代码,该代码框架非常适合新手,无他,只因其几乎只是在torch上层做的封装,很少引擎其他第三方库,不像transformers,功能强大,但不适合新手研读代码…

为了普适性,下面的代码截取了transformers框架中的代码。

首先看下通用Transformers中FFN中的代码模块,代码位置在transformers.models.mistral.modeling_mistral, 主要流程是:

(1) 先经过gate_proj和up_proj的2个[hidden_size, intermediate_size]的线性转换

(2) 使用激活函数对gate_proj进行激活

(3) 二者的内积再经过down_proj线性转换。

 1 class MistralMLP(nn.Module):
2 def __init__(self, config):
3 super().__init__()
4 self.config = config
5 self.hidden_size = config.hidden_size
6 self.intermediate_size = config.intermediate_size
7 self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
8 self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
9 self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
10 self.act_fn = ACT2FN[config.hidden_act]
11
12 def forward(self, x):
13 return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

再来看下MoE中的专家模块,代码位置在transformers.models.mixtral.modeling_mixtral,主要流程是:

(1) 首先经过网关路由self.gate

(2) 然后选择其中2个专家,并归一化

(3) 之后遍历每个专家网络,并按照expert_mask进行筛选

(4) 如果expert_mask有值,则选择指定部分的隐藏层进行FFN操作,且输出结果进行加权

(5) 最后原地增加先前初始化的最终结果变量final_hidden_states

class MixtralSparseMoeBlock(nn.Module):

    def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok # gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
) # One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx]) if top_x.shape[0] == 0:
continue # in torch it is faster to index using lists than torch tensors
top_x_list = top_x.tolist()
idx_list = idx.tolist() # Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] # However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits

其中MixtralBlockSparseTop2MLP代码如下,可以看到和传统MistralMLP内容完全一致。

class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states

4. MoE微调

由于MoE只是将每一层的FFN改变为了每一层的gate网关路由+8个FFN专家,且gate网关路由和8个专家内部均为线性运算,所以可以无缝地结合LoRA、QLoRA进行指令微调。

可以参考开源项目:https://github.com/yangjianxin1/Firefly

5. 答疑解惑

(1) 问:MoE 8*7B的模型是56B参数?

答:MoE 8*7B的参数量是47B,而不是56B,原因是每一层除了8个专家网络外,其他层均是复用的。

(2) 问:MoE的基础模型是Mistral 7B?

答:不是,MoE的模型架构与Mistral 7B相同,但其中的FFN替换为了8个FFN,且MoE是基于多语言数据集预训练而来的。

(3) MoE的稀疏性(sparse)体现在哪里?

答:在训练和推理时,同时只有两个专家网络会被激活,进行前向计算,其它专家网络处于失活状态。

6. 总结

一句话足矣~

本文主要针对大语言模型的MoE,包括原理及部分源码。

此外,建议大家可以针对源码进行运行,关于源码,欢迎大家一块交流。

7. 参考

(1) Mistral 7B:https://arxiv.org/pdf/2310.06825v1.pdf

(2) MoE: https://arxiv.org/pdf/2401.04088v1.pdf

(3) MoE开源指令微调框架Firefly: https://github.com/yangjianxin1/Firefly

LLM面面观之MoE的更多相关文章

  1. 2013成都网络赛 C We Love MOE Girls(水题)

    We Love MOE Girls Time Limit: 1000/500 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others) ...

  2. 与你相遇好幸运,The Moe Node.js Code Style Guide

    The Moe Node.js Code Style Guide  By 一个最萌的开发者 @2016.9.21 >>代码是人来阅读的,格式规范的代码是对编程人员最好的礼物 :) > ...

  3. 比Redis更快:Berkeley DB面面观

    比Redis更快:Berkeley DB面面观 Redis很火,最近大家用的多.从两年前开始,Memcached转向Redis逐渐成为潮流:而Berkeley DB可能很多朋友还很陌生,首先,我们简单 ...

  4. 《火球——UML大战需求分析》(第2章 耗尽脑汁的需求分析工作)——2.1 需求分析面面观

    说明: <火球——UML大战需求分析>是我撰写的一本关于需求分析及UML方面的书,我将会在CSDN上为大家分享前面几章的内容,总字数在几万以上,图片有数十张.欢迎你按文章的序号顺序阅读,谢 ...

  5. HDU 4730 We Love MOE Girls (2013成都网络赛,签到水题)

    We Love MOE Girls Time Limit: 1000/500 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others) ...

  6. 使用Python进行多线程检查.moe三位剩余有效域名

    翻看博客看到一段不错的代码 虽然近期没有购买域名的需求 不过日后有购买域名的需求的话 稍作修改直接使用还是很方便的 import threading import requests import js ...

  7. Hugging Face 每周速递: Chatbot Hackathon;FLAN-T5 XL 微调;构建更安全的 LLM

    每一周,我们的同事都会向社区的成员们发布一些关于 Hugging Face 相关的更新,包括我们的产品和平台更新.社区活动.学习资源和内容更新.开源库和模型更新等,我们将其称之为「Hugging Ne ...

  8. 微软开源了一个 助力开发LLM 加持的应用的 工具包 semantic-kernel

    在首席执行官萨蒂亚·纳德拉(Satya Nadella)的支持下,微软似乎正在迅速转变为一家以人工智能为中心的公司.最近微软的众多产品线都采用GPT-4加持,从Microsoft 365等商业产品到& ...

  9. Semantic Kernel 入门系列:🛸LLM降临的时代

    不论你是否关心,不可否认,AGI的时代即将到来了. 在这个突如其来的时代中,OpenAI的ChatGPT无疑处于浪潮之巅.而在ChatGPT背后,我们不能忽视的是LLM(Large Language ...

  10. Schillace法则:使用LLM创建软件的最佳实践

    LLM(大语言模型)的发展正在改变软件开发的方式. 以前,开发人员需要编写大量的代码来实现其意图,但现在,随着语言模型的发展,开发人员可以使用自然语言来表达他们的意图,而无需编写大量的代码.这使得软件 ...

随机推荐

  1. 《系列二》-- 9、bean属性填充

    目录 一.概述: populateBean 在什么时候执行? 二.populateBean 的重要操作 三.重点操作一 propertyValue 的注入 3.1 根据 Bean名称注入 3.2 浅看 ...

  2. Redis原理再学习03:数据结构-链表 list

    链表list介绍 1. 链表list简介 链表(linked list)是一种基础数据结构,是一种线性表,但是不会按照线性表的顺序存储数据,而是在每一个节点里存到下一个节点的指针. 链表插入节点时是 ...

  3. macOS使用CodeRunner快速配置fortran环境

    个人网站:xzajyjs.cn 由于一些项目的缘故,需要有fortran的需求,但由于是M1 mac的缘故,不能像windows那样直接使用vs+ivf这种经典配置.搜了一下网上主流的跨平台方案,主要 ...

  4. 【系统选型】OA需求分析,OA系统选型及各供应商对比。

    去年公司内部做OA信息化升级,需要更新换代一下OA系统,当时OA选型整理下来的资料分享一下. 需求调研整理后如下: 一共四个模块需要更新&升级 :  OA模块(包括行政) + 合同模块 + 费 ...

  5. ABP开发需要用到的命令

    0.命令行在哪里执行? 在Visual Studio的"解决方案资源管理器"的解决方案或者项目上点鼠标右键,选择"在终端中打开". 1.安装abp的命令行 官网 ...

  6. 【Azure Function App】在ADF(Azure Data Factory)中调用 Azure Function 时候遇见 Failed to get MI access token

    问题描述 在ADF(Azure Data Factory)中,调用Azure Function App中的Function,遇见了 Failed to get MI access token Ther ...

  7. opencv库图像基础3直方图-python

    opencv库图像基础3直方图-python 直方图是什么 OpenCV 中的直方图是图像中像素值分布情况的统计表示.它是图像空间域内像素值分布的图形表示,以便更好地理解颜色分布. 灰度直方图是图像中 ...

  8. OpenCV开发笔记(七十七):相机标定(二):通过棋盘标定计算相机内参矩阵矫正畸变摄像头图像

    前言   通过相机图片可以识别出棋盘角点了,这时候我们需要通过角点去计算相机内参矩阵,通过上篇得知畸变的原理,所以我们尽可能要全方位都能获取标定图片,全方位意思是提供的多张图综合起来基本覆盖了相机所有 ...

  9. 常见字符的ASCII码值

    ASCII值就是字符对应的十进制数值,字符就是可以表示的字符.

  10. Django:Nginx 启动,无法加载样式,无法加载静态文件

    一般是由于 Nginx 配置文件的问题 # 编辑 Nginx 配置文件 vim /etc/nginx/nginx.conf # 如果出现下面这个 use nginx 就需要改成 use root 保存 ...