前置知识:

PyTorch 基础函数操作整理

1. topk 操作

功能: torch.topk 用于返回输入张量中指定维度上的前 k 个最大元素及其对应的索引。

示例代码:

import torch

x = torch.tensor([[3, 1, 4],
[1, 5, 9],
[2, 6, 5]]) values, indices = torch.topk(x, k=2, dim=1) print(values)
print(indices)

输出:

values: tensor([[4, 3],
[9, 5],
[6, 5]]) indices: tensor([[2, 0],
[2, 1],
[1, 2]])

2. scatter_ 操作

功能: torch.scatter_ 是一个原地操作函数,用于根据指定索引 index,将 src 中的元素分散到目标张量的指定位置。

示例代码:

import torch

indices = torch.tensor([[0, 2],
[1, 2],
[1, 2]]) result = torch.zeros([3, 3]).scatter_(1, indices, True)
print(result)

输出:

tensor([[1., 0., 1.],
[0., 1., 1.],
[0., 1., 1.]])

3. unsqueeze 操作

功能: 在指定维度上插入一个大小为 1 的新维度,从而改变张量的形状。

示例代码:

import torch

x = torch.tensor([[1, 2], [3, 4]])
print("原始张量形状:", x.shape) y = torch.unsqueeze(x, dim=0)
print("在第 0 维插入新维度后的张量形状:", y.shape) z = torch.unsqueeze(x, dim=1)
print("在第 1 维插入新维度后的张量形状:", z.shape) w = torch.unsqueeze(x, dim=2)
print("在第 2 维插入新维度后的张量形状:", w.shape)

输出:

原始张量形状: torch.Size([2, 2])
在第 0 维插入新维度后的张量形状: torch.Size([1, 2, 2])
在第 1 维插入新维度后的张量形状: torch.Size([2, 1, 2])
在第 2 维插入新维度后的张量形状: torch.Size([2, 2, 1])

4. gather 操作

功能: torch.gather 根据给定的索引 index 从输入张量中收集元素,构建一个新的张量。(gatherscatter_ 互为反操作)

示例代码:

import torch

input_tensor = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]) index_tensor = torch.tensor([[2, 0],
[1, 2],
[0, 1]]) output = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output)

输出:

tensor([[30, 10],
[50, 60],
[70, 80]])

5. bincount 操作

功能: 统计非负整数张量中每个值出现的次数。

示例代码:

import torch

input_tensor = torch.tensor([1, 1, 2, 2, 10])
output = torch.bincount(input_tensor)
print(output)

输出:

tensor([0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 1])

6. where 操作

功能: 根据给定的条件对张量元素进行选择性操作,类似于 Python 中的三元运算符,返回满足条件的元素索引。

示例代码:

import torch

input_tensor = torch.tensor([[1, 2], [3, 4]])
indices = torch.where(input_tensor == 2)
print(indices)

输出:

(tensor([0]), tensor([1]))

原始MOE 代码实现

import torch
from torch import nn # ExpertNetwork 类:定义每个专家的网络
class ExpertNetwork(nn.Module):
def __init__(self, hidden_size, intermediate_size):
super().__init__()
self.hidden_size = hidden_size # 输入和输出的特征维度
self.intermediate_size = intermediate_size # 中间层的大小 # 定义两个线性层
self.linear1 = nn.Linear(hidden_size, intermediate_size) # (batch_size, hidden_size) -> (batch_size, intermediate_size)
self.linear2 = nn.Linear(intermediate_size, hidden_size) # (batch_size, intermediate_size) -> (batch_size, hidden_size) def forward(self, x):
x = self.linear1(x) # 经过第一个线性层
x = nn.functional.relu(x) # ReLU 激活函数
output = self.linear2(x) # 经过第二个线性层
return output # 返回输出,尺寸为 (batch_size, hidden_size) # Router 类:用于选择每个输入数据的专家
class Router(nn.Module):
def __init__(self, hidden_size, expert_num, top_k):
super().__init__()
self.router = nn.Linear(hidden_size, expert_num) # (batch_size, hidden_size) -> (batch_size, expert_num)
self.top_k = top_k # 每次选择 top_k 个专家
self.hidden_size = hidden_size # 输入的特征维度 def forward(self, x):
x = x.view(-1, self.hidden_size) # 展平输入,尺寸变为 (batch_size * seq_len, hidden_size)
x = self.router(x) # 通过 router 得到每个专家的选择权重,尺寸为 (batch_size * seq_len, expert_num)
x = nn.functional.softmax(x, dim=-1) # 使用 softmax 转换为概率分布,尺寸为 (batch_size * seq_len, expert_num)
topk_weight, topk_idx = torch.topk(x, k=self.top_k, dim=-1, sorted=False) # 选择 top_k 个专家,尺寸为 (batch_size * seq_len, top_k) # 权重归一化,使得它们的和为 1,尺寸为 (batch_size * seq_len, top_k)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) return topk_weight, topk_idx # 返回选择的 top_k 权重和专家索引 # MOELayer 类:实现混合专家层
class MOELayer(nn.Module):
def __init__(self, hidden_size, intermediate_size, expert_num, top_k):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.expert_num = expert_num
self.top_k = top_k # 定义多个专家网络
self.experts = nn.ModuleList(
[ExpertNetwork(self.hidden_size, self.intermediate_size) for _ in range(self.expert_num)]
) # 定义路由器
self.router = Router(self.hidden_size, self.expert_num, self.top_k) def forward(self, x):
batch_size, seq_len, _ = x.size() # 获取输入的尺寸,(batch_size, seq_len, hidden_size)
token_num = batch_size * seq_len # 计算总的 token 数量,(batch_size * seq_len)
x_flat = x.view(token_num, self.hidden_size) # 展平输入,尺寸为 (batch_size * seq_len, hidden_size) # 通过路由器获取 top_k 权重和索引
topk_weight, topk_idx = self.router(x) # 初始化输出为零张量,尺寸为 (batch_size * seq_len, hidden_size)
output = torch.zeros_like(x_flat) # 对于每个 token,选择 top_k 个专家进行计算
for token_idx in range(token_num): # 遍历所有 token
for expert_idx in range(self.top_k): # 遍历每个 token 的 top_k 个专家
# 选择相应的专家,并计算其输出
expert = self.experts[topk_idx[token_idx][expert_idx]]
output[token_idx] += topk_weight[token_idx][expert_idx] * expert(x_flat[token_idx]) # 加权输出 # 将输出恢复为原始形状 (batch_size, seq_len, hidden_size)
output = output.view(batch_size, seq_len, self.hidden_size)
return output # 设置超参数
HIDDEN_SIZE = 4096
INTERMEDIATE_SIZE = 2048
EXPERT_NUM = 8
TOP_K = 2 # 输入张量,尺寸为 (batch_size, seq_len, hidden_size)
inputs = torch.randn((2, 11, 4096)) # 实例化 MOELayer
moe_layer = MOELayer(HIDDEN_SIZE, INTERMEDIATE_SIZE, EXPERT_NUM, TOP_K) # 计算输出
outputs = moe_layer(inputs) # 输出结果的尺寸
print(outputs.size()) # 输出尺寸: (batch_size, seq_len, hidden_size)

DeepSeek MoE

源代码请参考:https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py

此处只保留有用的部分

1. Transformer 结构

class Transformer(nn.Module):
def __init__(self, args):
self.embed = ...
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(Block(layer_id, args))
self.norm = RMSNorm(args.dim)
self.head = ColumnParallelLinear(...) def forward(self, tokens):
h = self.embed(tokens)
for layer in self.layers:
h = layer(h, ...)
h = self.norm(h)
logits = self.head(h)
return logits
  • 结构

    • self.embed:嵌入层,转换输入 token。
    • self.layers:使用 ModuleList 存储多个 Block 层。
    • self.norm:最终的 RMSNorm 归一化层。
    • self.head:输出层,使用 ColumnParallelLinear 进行并行计算。
  • 前向传播
    • 先通过 embed 进行 token 转换。
    • 依次经过多个 Block 层。
    • 经过 RMSNorm 归一化。
    • 通过 head 进行最终计算,返回 logits

2. Block 结构

class Block(nn.Module):
def __init__(self, layer_id, args):
self.attn = MLA(args)
self.ffn = MoE(args)
self.attn_norm = RMSNorm(args.dim)
self.ffn_norm = RMSNorm(args.dim) def forward(self, x):
x = x + self.attn(self.attn_norm(x))
x = x + self.ffn(self.ffn_norm(x))
return x
  • 主要包含:

    • MLA(多头注意力层)。
    • MoE(专家混合机制)。
    • 两个 RMSNorm 层。
  • 前向传播
    • 归一化后,输入 attn 并进行残差连接。
    • 归一化后,输入 MoE 并进行残差连接。

3. MoE(专家混合)

class MoE(nn.Module):
def __init__(self, args):
self.n_routed_experts = ... # 路由专家个数
self.n_activated_experts = ... # 激活专家个数
self.gate = Gate(args) # 路由选择 Router
self.experts = nn.ModuleList([Expert(...) for i in range(self.n_routed_experts)])
self.shared_experts = MLP(...) # 共享专家网络列表 def forward(self, x):
weights, indices = self.gate(x) # 选择专家及权重
y = torch.zeros_like(x) # 初始化
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).to(x.device)
for i in range(self.n_routed_experts): # 轮询专家
if counts[i] == 0:
continue
expert = self.experts[i]
idx, top = torch.where(indices == i)
y[idx] += expert(x[idx]) * weights[idx, None] # 加权累加
z = self.shared_experts(x)
return (y + z)
  • Gate 选择 top-k 个专家,并给出权重。
  • torch.bincount 统计每个专家的使用次数。
  • 依次对 n_routed_experts 轮询:
    • 找到被选中的 token。
    • 经过 Expert 计算并加权累加。
    • shared_experts 计算并加上结果。

4. Expert(专家网络)

class Expert(nn.Module):
def __init__(self, dim, inter_dim):
self.w1 = Linear(dim, inter_dim)
self.w2 = Linear(inter_dim, dim)
self.w3 = Linear(dim, inter_dim) def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
  • 结构:

    • w1, w3:分别将 dim → inter_dim
    • w2:将 inter_dim → dim
  • 前向计算:
    • SILU 激活后,与 w3(x) 相乘,再通过 w2

5. Gate(路由选择)

class Gate(nn.Module):
def __init__(self, args):
self.topk = args.n_activated_experts
self.n_groups = args.n_expert_groups
self.topk_groups = args.n_limited_groups
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) def forward(self, x):
scores = linear(x, self.weight) # 计算专家得分
scores = scores.sigmoid()
scores = scores.view(x.size(0), self.n_groups, -1)
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) # 选择最高的 2 个分数
indices = group_scores.topk(self.topk_groups, dim=-1)[1] # 选择 top 4 组
mask = torch.zeros_like(scores[..., 0]).scatter_(-1, indices, True)
scores = (scores * mask.unsqueeze(-1)).flatten(1)
indices = torch.topk(scores, self.topk, dim=-1)[1]
weights = original_scores.gather(1, indices)
return weights, indices
  • 步骤

    • linear(x, weight) 计算 x 对所有专家的分数。
    • 经过 sigmoid 归一化分数。
    • 计算 top-k,选出最优的 4 组。
    • scatter_ 生成 mask 进行筛选。
    • topk 选取最终 k 个专家,并获取权重。

6. MLP(共享专家网络)

class MLP(nn.Module):
def __init__(self, dim, inter_dim):
self.w1 = ColumnParallelLinear(dim, inter_dim)
self.w2 = RowParallelLinear(inter_dim, dim)
self.w3 = ColumnParallelLinear(dim, inter_dim) def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
  • ColumnParallelLinear & RowParallelLinear 进行并行计算。
  • SILU 激活后,与 w3(x) 相乘,再通过 w2 计算。

DeepSeek MOE 代码实现的更多相关文章

  1. Javascript 语言精粹 代码片段合集

    Javascript 语言精粹 代码片段合集 标签:Douglas-Crockford Javascript 最佳实践 原文链接 更好的阅读体验 使用一个method 方法定义新方法 Function ...

  2. 与你相遇好幸运,The Moe Node.js Code Style Guide

    The Moe Node.js Code Style Guide  By 一个最萌的开发者 @2016.9.21 >>代码是人来阅读的,格式规范的代码是对编程人员最好的礼物 :) > ...

  3. Google HTML/CSS代码风格指南(中文版)

    原文链接:http://wncbl.cn/posts/c8e10815/ 看一下没什么印象,那就写一遍吧. 背景 本文档定义了HTML/CSS的编写格式和风格规则.它旨在提高合作和代码质量,并使其支持 ...

  4. [改善Java代码]asList方法产生的List对象不可更改

    上一个建议之处了asList方法在转换基本类型数组时候存在的问题,在看下asList方法返回的列表有何特殊的地方.看代码: import java.util.Arrays; import java.u ...

  5. [改善Java代码]避开基本类型数组转换列表陷阱

    开发中经常用到Arrays和Collections这两个工具类. 在数组和列表之间进行切换.非常方便.但是也会遇到一些问题. 看代码: import java.util.Arrays; import ...

  6. Google HTML/CSS/JS代码风格指南

    JS版本参见:http://www.zhangxinxu.com/wordpress/2012/07/google-html-css-javascript-style-guides/ HTML/CSS ...

  7. Eslint 能自动格式化代码,为什么还要用 Prettier?

    ESLint 与 Prettier 区别: ESLint:代码检测工具:可以检测出你代码中潜在的问题,比如使用了某个变量却忘记了定义: Prettier:代码格式化工具:作为代码格式化工具,能够统一你 ...

  8. 建议53:用状态模式美化代码,关于python-state工具包的理解

        在<编写高质量代码:改善python程序的91个建议>的建议53:用状态模式美化代码小节中,介绍了状态模式例如以下:就是当一个对象的内在状态改变时,同意改变其行为,但这个对象看起来 ...

  9. 使用Python进行多线程检查.moe三位剩余有效域名

    翻看博客看到一段不错的代码 虽然近期没有购买域名的需求 不过日后有购买域名的需求的话 稍作修改直接使用还是很方便的 import threading import requests import js ...

  10. 日期格式代码出现两次的错误 ORA-01810

    错误的原因是使用了两次MM . 一.Oracle中使用to_date()时格式化日期需要注意格式码 如:select to_date('2005-01-01 11:11:21','yyyy-MM-dd ...

随机推荐

  1. 漫画赏析:Linux 内核到底长啥样

    今天,我来为大家解读一幅来自 TurnOff.us 的漫画 "InSide The Linux Kernel" . TurnOff.us 是一个极客漫画网站,作者Daniel St ...

  2. WIN2012域用户添加和批量添加工具

    WIN2012域用户添加和批量添加,不需要进行复杂的进电脑管理去添加 直接在软件上就可单个用户添加,可批量添加,并把指定的用户加入组 可以自定义组织单位,使用起来比较简单方便. 链接:https:// ...

  3. Cursor 老改坏代码?六哥这几招超管用!

    大家好,我是六哥!最近不少小伙伴和我吐槽,在使用Cursor时,AI老是把代码改坏,让人头疼不已.我自己也用了大几十个小时Cursor,今天就来给大家分享一些实用小窍门,教大家如何巧妙规避这类问题. ...

  4. 当Kafka化身抽水马桶:论组件并发提升与系统可用性的量子纠缠关系

    <当Kafka化身抽水马桶:论组件并发提升与系统可用性的量子纠缠关系> 引言:一场OOM引发的血案 某个月黑风高的夜晚,监控系统突然发出刺耳的警报--我们的数据发现流水线集体扑街.事后复盘 ...

  5. Zotero设置

    1. 说明 Zotero 中文社区 | 百度网盘 使用 zotero 仅同步题录信息,使用其他云同步程序同步文献的附件,此处以坚果云为例进行演示,前期的坚果云同步设置参考文章:Zotero坚果云同步. ...

  6. Global.asax 转

    备忘: 项目中的Global.asax文件里通常包含这5个方法: Application_Start – web 应用程序最初启动时执行 Application_End – 应用程序关闭时运行 App ...

  7. C#之值类型与引用类型--out参数--ref参数-"=="、Equals和ReferenceEquals之间的区别

    一.值类型和引用类型 1.值类型 (1)值类型的大小是固定的 (2)值类型都派生自ValueType (3)值类型不能继承,只能实现接口 2.值类型:int.char.double.float.lon ...

  8. 什么是 Java 的 PLAB(Promotion Local Allocation Buffer)?

    什么是 Java 的 PLAB(Promotion Local Allocation Buffer)? PLAB 全称是 Promotion Local Allocation Buffer,是 Jav ...

  9. JS 对象(Object)和字符串(String)互转方法、JS遍历对象

    原文:https://www.cnblogs.com/fps2tao/p/8723164.html 1.对象(Object)和字符串(String)互转 利用原生JSON对象,将对象转为字符串 var ...

  10. 探秘Transformer系列之(30)--- 投机解码

    探秘Transformer系列之(30)--- 投机解码 目录 探秘Transformer系列之(30)--- 投机解码 0x00 概述 0x01 背景 1.1 问题 1.2 自回归解码 0x02 定 ...