本地显存池

数据结构

因为kv cache有MHA,MLA,DoubleSparse 等多种自定义类型,需要进行一步抽象将框架和cache类型做隔离, 所以有了2级内存池的设计. 一级保存和cache类型无关的数据(token位置),跟具体业务隔离,二级给出抽象类接口, 不同的cache类型按需继承实现interface, 就能通过配置来进行管理.

二级显存池

req_to_token_pool
class ReqToTokenPool:
"""A memory pool that maps a request to its token locations.""" def __init__(
self,
size: int,
max_context_len: int,
device: str,
enable_memory_saver: bool,
):
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
) self.size = size #size对应的是server_args.max_running_requests
self.max_context_len = max_context_len #对应的是从模型配置里读出来的支持最大的上下文长度
self.device = device
with memory_saver_adapter.region():
self.req_to_token = torch.zeros( #2维, 第一维偏移代表是第几个req, 第二位偏移记录在req中token在二级池的索引
(size, max_context_len), dtype=torch.int32, device=device
)
self.free_slots = list(range(size)) #1维, 用于记录哪些req被释放掉了, 在后续的请求可以复用
token_to_kv_pool

功能: 将 token 的 KV Cache索引映射到其 KV Cache数据, 实际的实现中这个依然是2大类组合形成的, 包括PagedTokenToKVPoolAllocatorKVCache接口类和其对应的子类 (只看了page_size>1的实现)

PagedTokenToKVPoolAllocator主要负责kv分页后的页表管理, 存储的数据是free_pages, 假设page_size=4. 初始化状态如下:

+---------+---------+---------+---------+
| Page 1 | Page 2 | Page 3 | Page 4 |
| 4~7 | 8~11 | 12~15 | 16~19 |
+---------+---------+---------+---------+
free_pages: [1, 2, 3, 4] alloc:
分配8个token后:
free_pages: [3, 4]
分配到的索引: [4,5,6,7,8,9,10,11]
free:
传入要回收的 token 索引(如 [4,5,6,7,8,9,10,11]),会通过idx / page_size转换为页索引 [1,2],并加回free_pages,变为 [1,2,3,4]

KVCache子类, 以MLA为例, kv_buffer为layer_num个torch.Tensor, 存储了k_buffercache_kv_buffercache_v, 每个tensor的dim分别表示: (最大token数 + page_size, head_dim(MLA里是1),head维度)

在MLA里, head维度是 LoRA相关的KV维度(低秩适配部分) + QK经过RoPE后的维度

第一维要加page_size的原因是: 如果不加, 某些操作(如buffer[start_idx : start_idx + page_size])会越界, 加上page_size后可以避免这些越界情况的判断, 简化逻辑

        with memory_saver_adapter.region():
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype,
device=device,
)
for _ in range(layer_num)
]

HostKVCache

Hierarchical Caching(分层缓存)机制, 支持一部分kvcache通过offload方式放到内存里. 由于会影响推理速度暂没用到. 待有需求的时候再细看.

显存alloc/free

alloc

从二级显存池申请空间逻辑都在forwardBatch.prepare_for_extend/prepare_for_decode里面, 以extend为例, 分为几步:

  1. alloc_req_slots: 根据batch_size, 从req_to_token_pool中申请bs个free_req对应的token索引.
  2. 遍历reqs, 把刚才申请到的req_pool_indices[i]填到对应req的req_pool_idx成员里, 使其能够一一对应
    1. get_last_loc: 获取每个请求前缀最后一个token在req_pool_indices[i]中的索引
    2. token_to_kv_pool_allocator.alloc_extend: 计算逻辑在alloc_extend_kernel里,
      • 第一步: 因为之前算出了最后一个token在显存中的偏移, 根据这个偏移和page_size能拿到最后token所在页和还有多少剩余空间, 先把这页没满的空间填满.
      • 第二步: 从free_page里拿出新页继续填充
      • 第三步: 分配最后一页, 如果填不满, 就把剩下的token填到这一页的前几个里面
free

在Cache复用中决定这些cache何时被回收, 通过调用token_to_kv_pool_allocator.freereq_to_token_pool.free处理. 核心逻辑:

free_page_indices = torch.unique(free_index // self.page_size)  #会把所有 token 索引转换为页号(同一页的 token 都会变成同一个页号)。
self.free_pages = torch.cat((free_page_indices, self.free_pages)) #塞回free_pages

只要页内有任意 token 被回收,这一整页就会被回收

KVCache读写

在attn_backend的forward函数中

P阶段把kv写到kv_buffer里, 即根据cache_loc到KVCache子类中对应的偏移中将torch.Tensor的值复制过去

D阶段根据layer_id读出对应layer的cache.

        if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
) # Call the wrapped function
o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
)

Cache复用

RadixCache

数据结构

基于基数树RadixTree数据结构实现的Cache, 其实就是压缩版的前缀树. 一看图就能弄清楚:

查询: 一直DFS到没有公共前缀为止.

插入:以root为起点遍历, 对当前节点做前缀匹配, 长度>0就进入子树否则进入兄弟节点. 一直DFS到没有公共前缀为止, 把不相同的str插入到新叶节点上.

插入与驱逐

前缀匹配代码解析:

    def _match_prefix_helper(self, node: TreeNode, key: List):           #传入的node就是root
node.last_access_time = time.time() child_key = self.get_child_key_fn(key) value = []
while len(key) > 0 and child_key in node.children.keys(): #非递归版dfs, 非page时当key中的第一个不在node的child中退出.(即完全不匹配)
child = node.children[child_key]
child.last_access_time = time.time()
prefix_len = self.key_match_fn(child.key, key) #树节点和当前token_id list进行前缀匹配
if prefix_len < len(child.key): #部分匹配
new_node = self._split_node(child.key, child, prefix_len) #分裂不匹配的那部分, 挂到当前节点下面作为child
value.append(new_node.value)
node = new_node
break
else:
value.append(child.value) #完全匹配, 进入子节点继续遍历, 把已经匹配成功的节点加到结果里
node = child
key = key[prefix_len:] #去掉已经匹配过的前缀 if len(key):
child_key = self.get_child_key_fn(key) return value, node

驱逐使用了引用计数(lock_ref)用于记录当前cache有没有在使用, 当叶子的引用计数为0时可以驱逐释放. 参考函数dec_lock_ref, 注意这里lock_ref在减这个node时, 会把他的所有父节点路径全都减1. 驱逐代码解析:

    def evict(self, num_tokens: int):
if self.disable:
return
leaves = self._collect_leaves() #通过BFS方式获取到树上的所有节点
heapq.heapify(leaves) #把树list转成堆, 通过TreeNode中的__lt__进行比较排序, 其实就是比last_access_time
num_evicted = 0
while num_evicted < num_tokens and len(leaves): #循环pop heap
x = heapq.heappop(leaves)
if x == self.root_node:
break
if x.lock_ref > 0: #引用计数>0的叶子跳过
continue
self.token_to_kv_pool_allocator.free(x.value) #释放ref_count=0的kvcache
num_evicted += len(x.value)
self._delete_leaf(x) #在树上删掉这个叶节点
if len(x.parent.children) == 0: #如果这个叶节点的父节点, 被删除这个child后也变成了叶节点, 把他push进heap
heapq.heappush(leaves, x.parent)

cache_request(req)

  1. req_to_token_pool.req_to_token获取kv_indices
  2. 把当前这条请求更新到Radix Cache (insert())
  3. finished: 释放这条请求的KV Cache, unfinished: 更新这条请求在req_to_token_pool中的偏移
  4. finished: 把这条请求的last_node引用计数-1, 标识可以evict, unfinished: 如果开了page, 把req里的last_node 引用计数-1, 把页对齐的last_node 引用计数+1

ChunkCache

对于过长的token请求, 如果在一个batch内处理除了会及其占用显存资源导致显存超限外, 还有可能因为单请求无法并行处理严重影响其他请求的TTFT, 所以有了chunked_prefill这个功能, 主要作用就是将过长请求切分成多个chunk分别进行处理

sglang的实现在同一时期只能有一个请求在chunk, 而chunk请求在处理时和其他请求的不同点在于: 当前chunk在进行attention计算时, 需要依赖此前的chunk计算的kvcache. 如下图:

因此就有ChunkCache这么个东西, 专门用来处理图中绿色部分的kvcache.

cache_unfinished_req: 把没处理完的req当前chunk的显存池的偏移量取出来, 塞到prefix_indices里用于下一个chunkReq的构建.

cache_finished_req: 把当前chunk和之前的prefix chunk kvcache直接全部free掉

PD分离,KVCache通信

相关代码在python/sglang/srt/disaggregation中. 包括5种class:

  1. TransferBackend: 枚举类, 用于记录server_args指定的kvTransfer后端, 把不同的KVCache通信后端封装到相同的接口内方便框架兼容(mooncake/nixl)
  2. BaseKVManager: 抽象类接口, 每个后端自己实现. 管理KV通信线程, 以及P和D的连接关系. 绑定ZMQ用于D节点和P节点的TCP通信. P节点有两个ZMQ监听线程(Bootstrap和transfer), D节点只有一个decode线程.
  3. BaseKVBootstrapServer: 抽象类接口, 每个后端自己实现. 用于P节点接收 D节点alloc完成后发送的Notify请求. 在这个类中起一个新线程, 通过event_loop监听一个端口接收请求.
  4. KVSender: 抽象类接口, 用于P节点的请求发送(send接口)和状态查询(poll)
  5. KVReceiver: 抽象类接口, 用于D节点的请求接收(recv接口)和状态查询(poll)

注意在代码中会看到除了KV本身还有一类叫aux data, 是 auxiliary 的缩写,表示“辅助数据”, 比如位置编码、mask、attention map、LoRA 相关参数等.

通信步骤

建立连接

  1. P节点注册自身信息到MooncakeKVBootstrapServer

    • 每个P 节点启动时,会通过 HTTP PUT 请求,把自己的 rank_ip、rank_port等信息注册到 bootstrap server。
    • bootstrap server 会把这些信息按 DP/TP 分组存入prefill_port_table,以便后续 decode 查询
  2. D节点查询需要连接的P节点

    • D节点初始化时,会根据自己的 engine_rank、dp_group 等参数,通过 HTTP GET 请求向 bootstrap server 查询自己应该连接的 P 节点信息 _get_bootstrap_info_from_server
    • 查询参数为 engine_rank 和 target_dp_group,bootstrap server 返回对应 P节点的 IP 和端口
  3. D节点拿到P节点的 IP/端口后,通过 ZeroMQ 建立 socket 连接, 然后把自己的KVCache相关信息通过 ZeroMQ 发送给P节点(KVReceiver里init方法),完成后续的数据同步和传输

发送数据

def _init_kv_manager(self) -> BaseKVManager:
kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank
kv_data_ptrs, kv_data_lens, kv_item_lens = ( #从显存池里拿到token KV Value的起始地址
self.token_to_kv_pool.get_contiguous_buf_infos()
) kv_args.kv_data_ptrs = kv_data_ptrs #把这个显存地址传到BaseKVManager里用于初始化
kv_args.kv_data_lens = kv_data_lens #从而在send的时候只传kv_indices, TranferEngine也能知道从哪里取出kv value
kv_args.kv_item_lens = kv_item_lens
#...
  1. P节点从bootstrap队列拿出已经变成WaitForInput状态的请求, 即到达下图的Notify收到的状态. 初始化req.disagg_kv_sender
  2. 完成forward后, 通过send_kv_chunk->disagg_kv_sender.send()函数把kv_indices copy到内存添加到通信队列里.
  3. tranfer异步线程从队列里取KVCache索引, send_kvcache, group_concurrent_contiguous先把连续的内存块进行切分. 根据每个layer从线程池里取一个独立的线程进行并发通信, 最后等所有layer完成通信. (engine.transfer_sync 调用的是mooncake的内部方法, 需要后续细看mooncake代码, 记个TODO)

接收数据

之前在notify的时候把D节点要接收的显存地址也传过去了, 通过TransferEngine直接完成从显存到显存的copy.

P节点在通信完成后调用sync_status_to_decode_endpoint, 通过ZMQ告知D节点完成传输

D节点的start_decode_thread在收到传输完成通知后, 更新status. 这样就完成了KVCache的整体传输

参考: https://zhuanlan.zhihu.com/p/31160183506

Sglang kvcache code walkThrough: https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/sglang/kvcache-code-walk-through/readme-CN.md

MLA细节解析: https://zhuanlan.zhihu.com/p/19585986234

RadixAttenion解析: https://zhuanlan.zhihu.com/p/693556044

SgLang代码细读-3. Cache的更多相关文章

  1. Spark 代码走读之 Cache

    Spark是基于内存的计算模型,但是当compute chain非常长或者某个计算代价非常大时,能将某些计算的结果进行缓存就显得很方便了.Spark提供了两种缓存的方法 Cache 和 checkPo ...

  2. MyBatis源码分析(4)—— Cache构建以及应用

    @(MyBatis)[Cache] MyBatis源码分析--Cache构建以及应用 SqlSession使用缓存流程 如果开启了二级缓存,而Executor会使用CachingExecutor来装饰 ...

  3. MyBatis源码分析(3)—— Cache接口以及实现

    @(MyBatis)[Cache] MyBatis源码分析--Cache接口以及实现 Cache接口 MyBatis中的Cache以SPI实现,给需要集成其它Cache或者自定义Cache提供了接口. ...

  4. Cache的使用

    公共方法Add 将指定项添加到 Cache 对象,该对象具有依赖项.过期和优先级策略以及一个委托(可用于在从 Cache 移除插入项时通知应用程序). Equals(从 Object 继承) 已重载. ...

  5. jQuery的XX如何实现?——3.data与cache机制

    往期回顾: jQuery的XX如何实现?——1.框架 jQuery的XX如何实现?——2.show与链式调用 -------------------------- 源码链接:内附实例代码 jQuery ...

  6. Asp.net 服务器Application,Session,Cookie,ViewState和Cache区别

    2.8 Context 的使用Context 对象包含与当前页面相关的信息,提供对整个上下文的访问,包括请求.响应.以及上文中的Session 和Application 等信息.可以使用此对象在网页之 ...

  7. php header()函数设置页面Cache缓存

    header()函数在php的使用很大,下面我来介绍利用它实现页面缓存的一些方法,但使用header前必须注意,在它之前不能任何输出,包括空格. 手册上,我们对于cache都是写着如何设置,以便让代码 ...

  8. ASP.NET缓存 Cache之数据缓存

    添加 Cache[Key]=object  or Cache.Insert 移除 Cache.Remove(key) 1.将值直接写入Cache 代码如下 复制代码 HttpContext.Curre ...

  9. Linux内核启动代码分析二之开发板相关驱动程序加载分析

    Linux内核启动代码分析二之开发板相关驱动程序加载分析 1 从linux开始启动的函数start_kernel开始分析,该函数位于linux-2.6.22/init/main.c  start_ke ...

  10. Cache基础知识OR1200在ICache一个简短的引论

    以下摘录<步骤吓得核心--软-core处理器的室内设计与分析>一本书 12.1 Cache基本知识 12.1.1 Cache的作用 处理器的设计者通常会声称其设计的处理器一秒钟能做多少次乘 ...

随机推荐

  1. Flink学习(十六) ProcessFunctionAPI(底层API)

    我们之前学习的转换算子是无法访问时间的时间戳信息和水位线信息的.而这些在一些应用场景下,极为重要,例如MapFunction这样的map转换算子就无法访问时间戳或者当前事件的事件时间. 基于此,Dat ...

  2. LangChain大模型框架& Dify低代码 AI 开发平台

    目录 1. LangChain介绍 1.1 架构 1.2 概念 1.3 术语 1.4 LangChain实战 2. LLM 应用开发平台dify 2.1 dify安装 2.2 设置知识库 3. dif ...

  3. 【数值计算方法】数值积分&微分-python实现

    目录 数值积分 1. 引言 2. 几个常用积分公式及其复合公式 2.1 求积公式 2.2 代数精度 2.3 复合积分 2.4 常用积分公式的python实现 3. 变步长方法与外推加速技术 4. 牛顿 ...

  4. PPT 技巧&网站

    样机生成网站 https://mockuphone.com/device?type=computer CTROL+L 演示生成荧光笔 3.如何内嵌字体 文件->选项->保存->勾选潜 ...

  5. Bash Shell 30min 过家家

    带你捅破窗户纸 - 备注 : @博客园 : 1. 为什么不支持 pdf 上传了呀 2. 网站分类不好用 3. 排版OA工具升级下, 例如 markdown 写出来好丑. 尝试升级下呢 ? 后记: 学如 ...

  6. coco数据集详解

    什么是COCO数据集? MS COCO的全称是Microsoft Common Objects in Context,起源于微软于2014年出资标注的Microsoft COCO数据集,与ImageN ...

  7. go declared and not used

    Go语言在代码规范中定义未使用的变量会报"declared and not used"错误 package main import "fmt" func mai ...

  8. Oracle 强行断开用户连接的方法

    1.查找目标用户的当前进程 select sid,serial# from v$session where username='test'; 2.使用上述语句会返回一个进程列表,每行有两个数字,用数字 ...

  9. 学习 Docker 如何查看镜像信息?

    学习 Docker 如何查看镜像信息? 一.images 命令列出镜像 通过使用如下两个命令,列出本机已有的镜像: docker images 或: docker image ls 如下图所示: 对上 ...

  10. 阿里云平台OSS对象存储

    OSS即"OpenStorageService",概念上没啥新意,就是本地存储搬到阿里云平台上了,单个存储对象大小可以达到5G,看了下阿里的OSS教程java版本, 使用原生js和 ...