前置知识:

PyTorch 基础函数操作整理

1. topk 操作

功能: torch.topk 用于返回输入张量中指定维度上的前 k 个最大元素及其对应的索引。

示例代码:

import torch

x = torch.tensor([[3, 1, 4],
[1, 5, 9],
[2, 6, 5]]) values, indices = torch.topk(x, k=2, dim=1) print(values)
print(indices)

输出:

values: tensor([[4, 3],
[9, 5],
[6, 5]]) indices: tensor([[2, 0],
[2, 1],
[1, 2]])

2. scatter_ 操作

功能: torch.scatter_ 是一个原地操作函数,用于根据指定索引 index,将 src 中的元素分散到目标张量的指定位置。

示例代码:

import torch

indices = torch.tensor([[0, 2],
[1, 2],
[1, 2]]) result = torch.zeros([3, 3]).scatter_(1, indices, True)
print(result)

输出:

tensor([[1., 0., 1.],
[0., 1., 1.],
[0., 1., 1.]])

3. unsqueeze 操作

功能: 在指定维度上插入一个大小为 1 的新维度,从而改变张量的形状。

示例代码:

import torch

x = torch.tensor([[1, 2], [3, 4]])
print("原始张量形状:", x.shape) y = torch.unsqueeze(x, dim=0)
print("在第 0 维插入新维度后的张量形状:", y.shape) z = torch.unsqueeze(x, dim=1)
print("在第 1 维插入新维度后的张量形状:", z.shape) w = torch.unsqueeze(x, dim=2)
print("在第 2 维插入新维度后的张量形状:", w.shape)

输出:

原始张量形状: torch.Size([2, 2])
在第 0 维插入新维度后的张量形状: torch.Size([1, 2, 2])
在第 1 维插入新维度后的张量形状: torch.Size([2, 1, 2])
在第 2 维插入新维度后的张量形状: torch.Size([2, 2, 1])

4. gather 操作

功能: torch.gather 根据给定的索引 index 从输入张量中收集元素,构建一个新的张量。(gatherscatter_ 互为反操作)

示例代码:

import torch

input_tensor = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]) index_tensor = torch.tensor([[2, 0],
[1, 2],
[0, 1]]) output = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output)

输出:

tensor([[30, 10],
[50, 60],
[70, 80]])

5. bincount 操作

功能: 统计非负整数张量中每个值出现的次数。

示例代码:

import torch

input_tensor = torch.tensor([1, 1, 2, 2, 10])
output = torch.bincount(input_tensor)
print(output)

输出:

tensor([0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 1])

6. where 操作

功能: 根据给定的条件对张量元素进行选择性操作,类似于 Python 中的三元运算符,返回满足条件的元素索引。

示例代码:

import torch

input_tensor = torch.tensor([[1, 2], [3, 4]])
indices = torch.where(input_tensor == 2)
print(indices)

输出:

(tensor([0]), tensor([1]))

原始MOE 代码实现

import torch
from torch import nn # ExpertNetwork 类:定义每个专家的网络
class ExpertNetwork(nn.Module):
def __init__(self, hidden_size, intermediate_size):
super().__init__()
self.hidden_size = hidden_size # 输入和输出的特征维度
self.intermediate_size = intermediate_size # 中间层的大小 # 定义两个线性层
self.linear1 = nn.Linear(hidden_size, intermediate_size) # (batch_size, hidden_size) -> (batch_size, intermediate_size)
self.linear2 = nn.Linear(intermediate_size, hidden_size) # (batch_size, intermediate_size) -> (batch_size, hidden_size) def forward(self, x):
x = self.linear1(x) # 经过第一个线性层
x = nn.functional.relu(x) # ReLU 激活函数
output = self.linear2(x) # 经过第二个线性层
return output # 返回输出,尺寸为 (batch_size, hidden_size) # Router 类:用于选择每个输入数据的专家
class Router(nn.Module):
def __init__(self, hidden_size, expert_num, top_k):
super().__init__()
self.router = nn.Linear(hidden_size, expert_num) # (batch_size, hidden_size) -> (batch_size, expert_num)
self.top_k = top_k # 每次选择 top_k 个专家
self.hidden_size = hidden_size # 输入的特征维度 def forward(self, x):
x = x.view(-1, self.hidden_size) # 展平输入,尺寸变为 (batch_size * seq_len, hidden_size)
x = self.router(x) # 通过 router 得到每个专家的选择权重,尺寸为 (batch_size * seq_len, expert_num)
x = nn.functional.softmax(x, dim=-1) # 使用 softmax 转换为概率分布,尺寸为 (batch_size * seq_len, expert_num)
topk_weight, topk_idx = torch.topk(x, k=self.top_k, dim=-1, sorted=False) # 选择 top_k 个专家,尺寸为 (batch_size * seq_len, top_k) # 权重归一化,使得它们的和为 1,尺寸为 (batch_size * seq_len, top_k)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) return topk_weight, topk_idx # 返回选择的 top_k 权重和专家索引 # MOELayer 类:实现混合专家层
class MOELayer(nn.Module):
def __init__(self, hidden_size, intermediate_size, expert_num, top_k):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.expert_num = expert_num
self.top_k = top_k # 定义多个专家网络
self.experts = nn.ModuleList(
[ExpertNetwork(self.hidden_size, self.intermediate_size) for _ in range(self.expert_num)]
) # 定义路由器
self.router = Router(self.hidden_size, self.expert_num, self.top_k) def forward(self, x):
batch_size, seq_len, _ = x.size() # 获取输入的尺寸,(batch_size, seq_len, hidden_size)
token_num = batch_size * seq_len # 计算总的 token 数量,(batch_size * seq_len)
x_flat = x.view(token_num, self.hidden_size) # 展平输入,尺寸为 (batch_size * seq_len, hidden_size) # 通过路由器获取 top_k 权重和索引
topk_weight, topk_idx = self.router(x) # 初始化输出为零张量,尺寸为 (batch_size * seq_len, hidden_size)
output = torch.zeros_like(x_flat) # 对于每个 token,选择 top_k 个专家进行计算
for token_idx in range(token_num): # 遍历所有 token
for expert_idx in range(self.top_k): # 遍历每个 token 的 top_k 个专家
# 选择相应的专家,并计算其输出
expert = self.experts[topk_idx[token_idx][expert_idx]]
output[token_idx] += topk_weight[token_idx][expert_idx] * expert(x_flat[token_idx]) # 加权输出 # 将输出恢复为原始形状 (batch_size, seq_len, hidden_size)
output = output.view(batch_size, seq_len, self.hidden_size)
return output # 设置超参数
HIDDEN_SIZE = 4096
INTERMEDIATE_SIZE = 2048
EXPERT_NUM = 8
TOP_K = 2 # 输入张量,尺寸为 (batch_size, seq_len, hidden_size)
inputs = torch.randn((2, 11, 4096)) # 实例化 MOELayer
moe_layer = MOELayer(HIDDEN_SIZE, INTERMEDIATE_SIZE, EXPERT_NUM, TOP_K) # 计算输出
outputs = moe_layer(inputs) # 输出结果的尺寸
print(outputs.size()) # 输出尺寸: (batch_size, seq_len, hidden_size)

DeepSeek MoE

源代码请参考:https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py

此处只保留有用的部分

1. Transformer 结构

class Transformer(nn.Module):
def __init__(self, args):
self.embed = ...
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(Block(layer_id, args))
self.norm = RMSNorm(args.dim)
self.head = ColumnParallelLinear(...) def forward(self, tokens):
h = self.embed(tokens)
for layer in self.layers:
h = layer(h, ...)
h = self.norm(h)
logits = self.head(h)
return logits
  • 结构

    • self.embed:嵌入层,转换输入 token。
    • self.layers:使用 ModuleList 存储多个 Block 层。
    • self.norm:最终的 RMSNorm 归一化层。
    • self.head:输出层,使用 ColumnParallelLinear 进行并行计算。
  • 前向传播
    • 先通过 embed 进行 token 转换。
    • 依次经过多个 Block 层。
    • 经过 RMSNorm 归一化。
    • 通过 head 进行最终计算,返回 logits

2. Block 结构

class Block(nn.Module):
def __init__(self, layer_id, args):
self.attn = MLA(args)
self.ffn = MoE(args)
self.attn_norm = RMSNorm(args.dim)
self.ffn_norm = RMSNorm(args.dim) def forward(self, x):
x = x + self.attn(self.attn_norm(x))
x = x + self.ffn(self.ffn_norm(x))
return x
  • 主要包含:

    • MLA(多头注意力层)。
    • MoE(专家混合机制)。
    • 两个 RMSNorm 层。
  • 前向传播
    • 归一化后,输入 attn 并进行残差连接。
    • 归一化后,输入 MoE 并进行残差连接。

3. MoE(专家混合)

class MoE(nn.Module):
def __init__(self, args):
self.n_routed_experts = ... # 路由专家个数
self.n_activated_experts = ... # 激活专家个数
self.gate = Gate(args) # 路由选择 Router
self.experts = nn.ModuleList([Expert(...) for i in range(self.n_routed_experts)])
self.shared_experts = MLP(...) # 共享专家网络列表 def forward(self, x):
weights, indices = self.gate(x) # 选择专家及权重
y = torch.zeros_like(x) # 初始化
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).to(x.device)
for i in range(self.n_routed_experts): # 轮询专家
if counts[i] == 0:
continue
expert = self.experts[i]
idx, top = torch.where(indices == i)
y[idx] += expert(x[idx]) * weights[idx, None] # 加权累加
z = self.shared_experts(x)
return (y + z)
  • Gate 选择 top-k 个专家,并给出权重。
  • torch.bincount 统计每个专家的使用次数。
  • 依次对 n_routed_experts 轮询:
    • 找到被选中的 token。
    • 经过 Expert 计算并加权累加。
    • shared_experts 计算并加上结果。

4. Expert(专家网络)

class Expert(nn.Module):
def __init__(self, dim, inter_dim):
self.w1 = Linear(dim, inter_dim)
self.w2 = Linear(inter_dim, dim)
self.w3 = Linear(dim, inter_dim) def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
  • 结构:

    • w1, w3:分别将 dim → inter_dim
    • w2:将 inter_dim → dim
  • 前向计算:
    • SILU 激活后,与 w3(x) 相乘,再通过 w2

5. Gate(路由选择)

class Gate(nn.Module):
def __init__(self, args):
self.topk = args.n_activated_experts
self.n_groups = args.n_expert_groups
self.topk_groups = args.n_limited_groups
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) def forward(self, x):
scores = linear(x, self.weight) # 计算专家得分
scores = scores.sigmoid()
scores = scores.view(x.size(0), self.n_groups, -1)
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) # 选择最高的 2 个分数
indices = group_scores.topk(self.topk_groups, dim=-1)[1] # 选择 top 4 组
mask = torch.zeros_like(scores[..., 0]).scatter_(-1, indices, True)
scores = (scores * mask.unsqueeze(-1)).flatten(1)
indices = torch.topk(scores, self.topk, dim=-1)[1]
weights = original_scores.gather(1, indices)
return weights, indices
  • 步骤

    • linear(x, weight) 计算 x 对所有专家的分数。
    • 经过 sigmoid 归一化分数。
    • 计算 top-k,选出最优的 4 组。
    • scatter_ 生成 mask 进行筛选。
    • topk 选取最终 k 个专家,并获取权重。

6. MLP(共享专家网络)

class MLP(nn.Module):
def __init__(self, dim, inter_dim):
self.w1 = ColumnParallelLinear(dim, inter_dim)
self.w2 = RowParallelLinear(inter_dim, dim)
self.w3 = ColumnParallelLinear(dim, inter_dim) def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
  • ColumnParallelLinear & RowParallelLinear 进行并行计算。
  • SILU 激活后,与 w3(x) 相乘,再通过 w2 计算。

DeepSeek MOE 代码实现的更多相关文章

  1. Javascript 语言精粹 代码片段合集

    Javascript 语言精粹 代码片段合集 标签:Douglas-Crockford Javascript 最佳实践 原文链接 更好的阅读体验 使用一个method 方法定义新方法 Function ...

  2. 与你相遇好幸运,The Moe Node.js Code Style Guide

    The Moe Node.js Code Style Guide  By 一个最萌的开发者 @2016.9.21 >>代码是人来阅读的,格式规范的代码是对编程人员最好的礼物 :) > ...

  3. Google HTML/CSS代码风格指南(中文版)

    原文链接:http://wncbl.cn/posts/c8e10815/ 看一下没什么印象,那就写一遍吧. 背景 本文档定义了HTML/CSS的编写格式和风格规则.它旨在提高合作和代码质量,并使其支持 ...

  4. [改善Java代码]asList方法产生的List对象不可更改

    上一个建议之处了asList方法在转换基本类型数组时候存在的问题,在看下asList方法返回的列表有何特殊的地方.看代码: import java.util.Arrays; import java.u ...

  5. [改善Java代码]避开基本类型数组转换列表陷阱

    开发中经常用到Arrays和Collections这两个工具类. 在数组和列表之间进行切换.非常方便.但是也会遇到一些问题. 看代码: import java.util.Arrays; import ...

  6. Google HTML/CSS/JS代码风格指南

    JS版本参见:http://www.zhangxinxu.com/wordpress/2012/07/google-html-css-javascript-style-guides/ HTML/CSS ...

  7. Eslint 能自动格式化代码,为什么还要用 Prettier?

    ESLint 与 Prettier 区别: ESLint:代码检测工具:可以检测出你代码中潜在的问题,比如使用了某个变量却忘记了定义: Prettier:代码格式化工具:作为代码格式化工具,能够统一你 ...

  8. 建议53:用状态模式美化代码,关于python-state工具包的理解

        在<编写高质量代码:改善python程序的91个建议>的建议53:用状态模式美化代码小节中,介绍了状态模式例如以下:就是当一个对象的内在状态改变时,同意改变其行为,但这个对象看起来 ...

  9. 使用Python进行多线程检查.moe三位剩余有效域名

    翻看博客看到一段不错的代码 虽然近期没有购买域名的需求 不过日后有购买域名的需求的话 稍作修改直接使用还是很方便的 import threading import requests import js ...

  10. 日期格式代码出现两次的错误 ORA-01810

    错误的原因是使用了两次MM . 一.Oracle中使用to_date()时格式化日期需要注意格式码 如:select to_date('2005-01-01 11:11:21','yyyy-MM-dd ...

随机推荐

  1. linux测试url的访问速度

    在Linux中,你可以使用curl命令来测试URL的访问速度.curl是一个强大的命令行工具,可以用于文件传输和测试网络连接. 以下是使用curl测试URL访问速度的步骤: 打开终端或命令行界面. 输 ...

  2. nginx 配置go服务反向代理

    nginx 配置 详细请看Nginx 极简教程 server { listen 80; server_name localhost; #charset koi8-r; # nginx访问活动日志 ac ...

  3. selenium自动化测试入门

    Selenium是一个基于浏览器的自动化测试工具,它提供了一种跨平台.跨浏览器的端到端的web自动化解决方案. Selenium是用于自动化控制浏览器做各种操作,打开网页,点击按钮,输入表单等等,可以 ...

  4. Cython二进制逆向系列(三)运算符

    Cython二进制逆向系列(三)运算符 在开始前,先给出本文用到的py源代码 def test1(x, y): # 数学运算符 a = x + y b = x - y c = x * y d = x ...

  5. RocketMq安装踩坑:docker0网桥冲突

    前言 最近项目用到了RocketMq,需要在Cento7系统上搭建一套集群环境用于测试.整个的环境搭建过程中,我遇到了一个比较初级的问题:启动RocketMq的broker失败.   问题经过 首先我 ...

  6. 多态的转型和案例--java进阶day02

    1.多态的转型 1.向上转型 我们之前学的多态创建对象,使用的都是向上转型,父类引用指向子类(赋值方式则是从子到父),f拿到子类的地址,就能访问子类的堆内存 2.向下转型 和向上转型相反,子类引用指向 ...

  7. 什么是MIME类型-基础知识补全

    MIME类型(Multipurpose Internet Mail Extensions,多用途互联网邮件扩展)是一种标准,用于标识互联网上传输的文件类型.它最初是为电子邮件设计的,后来被广泛应用于W ...

  8. 如何确定dbgrid选择的是记录而不是分组

    with cxgrdbtblvwGrid1DBTableView1.Controller do if FocusedRecord is TcxGridDataRow then begin i := c ...

  9. AI可解释性 II | Saliency Maps-based 归因方法(Attribution)论文导读(持续更新)

    AI可解释性 II | Saliency Maps-based 归因方法(Attribution)论文导读(持续更新) 导言 本文作为AI可解释性系列的第二部分,旨在以汉语整理并阅读归因方法(Attr ...

  10. vscode安装离线插件autopep8

    商店 从上面的链接进去,在visual studio code一栏开始搜索,我要的是autopep8,所以搜索得到的是这样的: 点进去后,是这个界面,然后我是离线下载,要的是拓展包,所以是下面操作 下 ...