Context Cache的使用几乎已经是行业共识,目标是优化大模型首Token的推理延时,在多轮对话,超长System Prompt,超长结构化JSON和Few-shot等应用场景,是不可或缺的。这一章我们主要从原理、一些论文提出的优化项和VLLM开源项目入手,分析下context Cache的实现和适合场景。

重温KV Cache

Context Cache的本质其实是KV Cache在多次请求之间的复用,所以我们先重温下KV Cache的原理。

以下是Transformer的基础模型结构由多层layer串联构成,而KV缓存的是每一层用于Self-Attention计算的历史序列的Key和Value取值,所以缓存向量维度是batch_size * seq_len * num_head * head_dim,(以下可视化来自LLM Visualizationtransformer-explainer )

只所以对self-attention的KV进行缓存,因为在Transformer的众多计算单元中只有self-attention是上下文依赖的,也就是在计算第k个token的输出时,需要使用第k个token(Query)和前面K-1个token的(K&V)进行内机计算,使得self-attention的计算复杂度随序列长度平方增长。而如果对历史序列中的KV进行缓存后,每次生成新token只需要计算当前token,这样时间复杂度就可以降为线性(O(n)),显著降低计算量。

而其他Linear, FFN层的计算都是针对dim层,每个token的计算独立,和历史token无关和序列长度无关,因此也没有缓存的必要。

当然KV Cache也不全是优点,虽然能显著降低推理延时,但是会带来较大的显存占用,占用显存和序列长度、模型层数成正比。

那现在众多大模型API厂商都支持的Context Cache能力和KV cache有哪些区别呢?

首先从cache共享上,因为KV cache只用于单一序列的推理过程中,因此没有任何共享问题,本次推理完成即释放,而context cache会在多次推理请求之间共享,因此对于如何命中Cache,管理Cache, 提升Cache的存储和使用效率,就需要更多的考虑。

其次从使用时机上,KV cache是用于首Token之后的增量预测(auto-regressive Phrase),而Context Cache是用于首Token之前的存量计算(Prompt Phrase)。在context cache出现之前这两个阶段其实有明确的划分,计算prompt的阶段需要对全部序列进行attention计算属于数据计算密集的任务,而解码阶段因为kV Cache的存在更多是存储密集型任务。因此Context Cache是面向首Token延时的优化方案。

KV Cache只是Context Cache的基础使用形式,下面我们会分别就Contxt Cache的几个核心问题包括命中率低,等讨论一些优化方案

并行效率更高: Chunk Attention

微软的Chunk Attention通过对KV Cache分段存储,实现self-attention的并发计算,提升长序列self-attention计算效率,经评估自注意力内核的速度可以提高3.2-4.8倍。下面我们结合源码简单说下,源码chunk attention是C语言写的,这里简单转换成python pseudo

  1. Prefix Aware KV Cache

传统的KV缓存以密集tensor形式存储,大小是batch * head * seq length * dim,当多个序列共享相同的前缀(prompt)时这些kv缓存就是相同的,因此可以进行共享存储,共享存储的大小则是head * chunk length * dim,这里论文选择使用前缀树进行存储。用前缀不用其他存储方式,主要是为了保证位置编码的可用性。这部分就不细说了,就是Trie树,child节点存储的是对应chunk 序列的kv tensor。而这里的分块存储,为后面推理阶段并行self-attention的计算提供了可能性。

\[\begin{align*} \text{Phase 1:} & \forall C_i \in \text{共享块} \\ W_i &= Q_{C_i}K_{C_i}^\top / \sqrt{d} \\ M_i &= \text{rowmax}(W_i) \\ E_i &= \exp(W_i - M_i) \\ N_i &= \text{rowsum}(E_i) \\ O_i &= E_i V_{C_i} \\ \text{Phase 2:} & \forall \text{序列} s \\ M^{\text{total}} &= \text{max}(M_1, M_2, ..., M_k) \\ N^{\text{total}} &= \sum_{i=1}^k \exp(M_i - M^{\text{total}})N_i \\ O^{\text{total}} &= \sum_{i=1}^k \frac{\exp(M_i - M^{\text{total}})O_i}{N^{\text{total}}} \end{align*}
\]
  1. Two-Phase Partition(TPP)

两阶段的推理算法则是如何更高效的使用以上共享前缀的KV Cache。第一步是共享前缀的部分,读取前缀树种存储的chunk KV Cache,进行chunk粒度的并行self-attention计算并存储多个partial attention。

然后剩余未命中前缀的部分,按序列计算剩余attention,之后通过attention合并把把一个序列的多段attention进行合并得到最终的attention.

以下是C转换成python的伪代码

# 伪代码体现两阶段推理核心逻辑
class PartialAttention:
def __init__(self, chunk_size=64):
self.chunk_size = chunk_size # 分块大小
self.scale = 1 / sqrt(d_head) # 缩放因子 def chunk_first_phase(self, Q_shared, K_shared, V_shared):
"""
分块优先阶段计算(对应论文Algorithm 1)
Q_shared: [n_shared_seqs, n_heads, d_head]
K_shared: [n_shared_seqs, chunk_size, d_head]
V_shared: [n_shared_seqs, chunk_size, d_head]
"""
# 步骤1: 计算分块注意力得分
attn_scores = torch.einsum('bhd,bsd->bhs', Q_shared, K_shared) * self.scale # 步骤2: 在线计算局部softmax
max_values = attn_scores.max(dim=-1, keepdim=True).values # [b, h, 1]
exp_scores = torch.exp(attn_scores - max_values) # 数值稳定 # 步骤3: 保存中间结果(对应论文公式1)
partial_attn = {
'exp_scores': exp_scores, # [b, h, s]
'max_values': max_values, # [b, h, 1]
'sum_exp': exp_scores.sum(-1), # [b, h]
'partial_v': torch.einsum('bhs,bsd->bhd', exp_scores, V_shared)
}
return partial_attn def sequence_first_phase(self, Q_private, partial_results):
"""
序列优先阶段(对应论文Algorithm 2)
Q_private: [1, n_heads, d_head] (单个序列的查询)
"""
# 步骤1: 初始化累计变量
global_max = -float('inf')
global_sum = 0
merged_output = 0 # 步骤2: 合并所有分块结果(对应论文公式2)
for partial in partial_results:
local_max = partial['max_values']
local_exp = partial['exp_scores']
local_sum = partial['sum_exp']
local_v = partial['partial_v'] # 调整指数差值
adjust_factor = torch.exp(local_max - global_max) # 更新全局统计量
new_global_max = torch.max(global_max, local_max)
new_global_sum = global_sum * adjust_factor + local_sum * torch.exp(local_max - new_global_max) # 更新输出
merged_output = merged_output * adjust_factor + local_v * torch.exp(local_max - new_global_max) # 保存新全局值
global_max = new_global_max
global_sum = new_global_sum # 步骤3: 最终归一化
return merged_output / global_sum.unsqueeze(-1)

而之所以attention可以通过局部计算再合并后效果和原始attention一致,不过是因为指数计算的变换,我们抛开论文里面不太好理解的pseudo code,看一个具体的case

  1. 步骤1:分块计算中间结果
  • 块1(s₁, s₂):

    • 计算局部最大值:m₁ = max(s₁, s₂)
    • 调整后指数:e^{s₁ - m₁}, e^
    • 局部分母:sum₁ = e^{s₁ - m₁} + e^
  • 块2(s₃, s₄):
    • 计算局部最大值:m₂ = max(s₃, s₄)
    • 调整后指数:e^{s₃ - m₂}, e^
    • 局部分母:sum₂ = e^{s₃ - m₂} + e^
  1. 步骤2:合并中间结果
  • 全局最大值:M = max(m₁, m₂)

    • 调整因子:
    • 块1调整因子:α = e^
    • 块2调整因子:β = e^
  • 合并的分母
total_sum = α * sum₁ + β * sum₂
= e^{m₁ - M}*(e^{s₁ - m₁} + e^{s₂ - m₁}) + e^{m₂ - M}*(e^{s₃ - m₂} + e^{s₄ - m₂})
= e^{s₁ - M} + e^{s₂ - M} + e^{s₃ - M} + e^{s₄ - M}
  • 合并后的attention计算结果
total_sum = α * sum₁ + β * sum₂
= e^{m₁ - M}*(e^{s₁ - m₁} + e^{s₂ - m₁}) + e^{m₂ - M}*(e^{s₃ - m₂} + e^{s₄ - m₂})
= e^{s₁ - M} + e^{s₂ - M} + e^{s₃ - M} + e^{s₄ - M}

空间利用更高:Radix Attention

同样是使用树形存储,SGLang使用了Radix Tree结合LRU的存储策略(论文还有更多提高cache命中率之类的策略这里不予赘述)。所以其实核心就在Radix Tree和Prefix Tree的对比了。

像前缀树每个节点只能是单个Token,而Radix支持可变长的Token列表,因此可以节省大量节点指针是空间效率更高的存储模式,这里我们直接举个例子

存储以下键:"test", "team", "slow", "slowly"。

  • Trie Tree
root
├─ t
│ ├─ e
│ │ ├─ s → t
│ │ └─ a → m
└─ s
└─ l → o → w
└─ l → y
  • Radix Tree
root
├─ te
│ ├─ "st" → [leaf: "test"]
│ └─ "am" → [leaf: "team"]
└─ "slow"
└─ "ly" → [leaf: "slowly"]

整体上对比下ChunkAttention是固定长度的分块,树结构是静态的,对于类似超长systemt prompt、multiple few-shot、超长contxt的场景更加合适,因为能通过多chunk实现self-attention的并发计算。而Radix Attention基于Radix Tree,支持动态可变长的前缀,自然更适合类似多轮对话,multi-step思维链推理等动态前缀场景。

显式管理Cache:PML

  • PROMPT CACHE: MODULAR ATTENTION REUSE FOR LOW-LATENCY INFERENCE

Prompt Cache则是提出了标记语言PML,支持用户显示标记输入内容中,哪些内容是需要被cache的。这里包含3个主要机制

  1. 通过XML语言标记的cache模块

如上图所示,每个module本身就是一个存储块,被XML包裹的内容会分别进行KV cache存储。下次在推理时会根据XML的tag进行对应Key,Value Tensor Cache的获取。

  1. 不连续的位置编码ID

但这就和前面的chunk Attention,Radix Attention有了显著的差异,就是PML支持非连续位置的缓存。这也引出了一个问题,只要不是从头开始的Prompt Cache,都会存在缓存cache的位置编码和使用该cache所在位置不同的问题。

论文给出的解释是他们经过实际测试,发现模型对非连续位置编码有包容性,只要cache的的模块内容部相对位置编码正确即可,即便下一次Cache出现的位置和它被缓存的绝对位置不同,也不会影响推理效果。但个人认为这依赖用户PML定义的极端合理化,也就是每个XML包裹的内容都是一个完整的语义内容,且彼此之间相对独立。例如多个few-shot使用这个模式就是可以的,一个system prompt里面<requirement>和<output foramt>用这种方式应该也可以。但要是放在多轮对话场景,或者multi-step思考推理场景,我感觉会出现问题。

  1. 缓存模块的拼接与新内容的计算

当一个用户提示由多个模块(有些是缓存的,有些是新的)组成时,Prompt Cache 会:

  • 从缓存中检索已缓存模块的KV状态
  • 对于提示中未被缓存的新文本段(比如插入到两个缓存模块之间的文本,或者在缓存模块之后的新文本),系统会根据它们在最终完整提示中的实际位置,为这些新文本段计算新的KV状态和相应的位置ID。例如,如果一段新文本插入在起始位置为0的模块A和起始位置为110的模块C之间,且模块A长度为50,那么这段新文本的位置ID将从50开始。
  • 最后,系统将这些来自缓存的、带有原始(可能不从0开始)位置ID的KV状态,与新计算的、带有新分配位置ID的KV状态,按照它们在完整提示中的正确顺序拼接起来 。

VLLM源码分析

最后我们来直接看一个生产级别的源码实现。以下是VLLM中KV Cache的整个调用链路。整个调用链路如下

  1. 初始化阶段
  • LLMEngine 初始化时调用 _initialize_kv_caches()
  • Worker 通过 determine_num_available_blocks() 计算可用的 GPU 和 CPU 块数
  • ModelRunner 初始化 KV 缓存配置
  • BlockPool 创建并管理 KVCacheBlock 对象
  1. 推理阶段
  • 客户端发送推理请求
  • LLMEngine 接收请求并转发给 Worker
  • Worker 调用 ModelRunner 执行模型推理
  • ModelRunner 通过 KVCacheManager 获取或分配 KV 缓存块
  • KVCacheManager 从 BlockPool 获取新的缓存块
  • Attention 层使用 KV 缓存进行注意力计算

核心的KV Cache存储在KVCacheBlock类中,使用了和前面Chunk Attention相同的块存储机制,这里通过hash所有前缀token+当前块token得到block_hash

以下为cache存储、生成cache的哈希ID部分的核心代码。

@dataclass
class KVCacheBlock:
block_id: int
ref_cnt: int = 0
_block_hash: Optional[BlockHashType] = None
prev_free_block: Optional["KVCacheBlock"] = None
next_free_block: Optional["KVCacheBlock"] = None def hash_block_tokens(
hash_function: Callable,
parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int],
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType:
"""计算块的哈希值,用于前缀缓存"""
if not parent_block_hash:
parent_block_hash = NONE_HASH curr_block_token_ids_tuple = tuple(curr_block_token_ids)
return BlockHashType(
hash_function(
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
curr_block_token_ids_tuple, extra_keys)

VLLM还使用了Radix Attention的LRU最少使用驱逐策略,通过ref_cnt的引用计数追踪不使用的cache。同时当计数归0,cache不会直接被释放,而是会被添加到evictor,当内存压力大时再进行释放,保证更大程度的cache复用,避免频繁地内存分配和释放,相关驱逐代码如下

def _decr_refcount_cached_block(self, block: Block) -> None:
block_id = block.block_id
assert block_id is not None refcount = self._refcounter.decr(block_id)
if refcount > 0:
block.block_id = None
return
else:
assert refcount == 0 # 将块添加到 evictor 而不是直接释放
self.evictor.add(block_id, block.content_hash,
block.num_tokens_total,
self._block_tracker[block_id].last_accessed) class LRUEvictor(Evictor):
def evict(self) -> Tuple[int, int]:
while self.priority_queue:
last_accessed, _, block_id, content_hash = heapq.heappop(
self.priority_queue)
if (block_id in self.free_table and
self.free_table[block_id].last_accessed == last_accessed):
self.free_table.pop(block_id)
return block_id, content_hash

当用户的Query进来,会按照固定chunk大小进行切分,然后从左到右进行去寻找cache,实现最长前缀命中。不过这里命中都是整个Block命中,也就是不会像Radix Tree一样支持Token级别的命中。命中前缀部分的代码如下

def find_longest_cache_hit(self, block_hashes: list[BlockHashType], max_length: int) -> list[KVCacheBlock]:
computed_blocks: list[KVCacheBlock] = []
max_num_blocks = max_length // self.block_size
for i in range(max_num_blocks):
block_hash = block_hashes[i]
if cached_block := self.block_pool.get_cached_block(block_hash):
computed_blocks.append(cached_block)
else:
break
if self.use_eagle and len(computed_blocks) > 0:
computed_blocks.pop()
return computed_blocks

闭源 context Cache文档

  1. https://ai.google.dev/gemini-api/docs/caching?hl=zh-cn&lang=python
  2. https://platform.openai.com/docs/guides/prompt-caching
  3. https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching

解密prompt系列54.Context Cache代码示例和原理分析的更多相关文章

  1. 解密Prompt系列6. lora指令微调扣细节-请冷静,1个小时真不够~

    上一章介绍了如何基于APE+SELF自动化构建指令微调样本.这一章咱就把微调跑起来,主要介绍以Lora为首的低参数微调原理,环境配置,微调代码,以及大模型训练中显存和耗时优化的相关技术细节 标题这样写 ...

  2. RocketMQ延迟消息的代码实战及原理分析

    RocketMQ简介 RocketMQ是一款开源的分布式消息系统,基于高可用分布式集群技术,提供低延时的.高可靠.万亿级容量.灵活可伸缩的消息发布与订阅服务. 它前身是MetaQ,是阿里基于Kafka ...

  3. 解密prompt系列5. APE+SELF=自动化指令集构建代码实现

    上一章我们介绍了不同的指令微调方案, 这一章我们介绍如何降低指令数据集的人工标注成本!这样每个人都可以构建自己的专属指令集, 哈哈当然我也在造数据集进行时~ 介绍两种方案SELF Instruct和A ...

  4. 解密Prompt系列2. 冻结Prompt微调LM: T5 & PET & LM-BFF

    这一章我们介绍固定prompt微调LM的相关模型,他们的特点都是针对不同的下游任务设计不同的prompt模板,在微调过程中固定模板对预训练模型进行微调.以下按时间顺序介绍,支持任意NLP任务的T5,针 ...

  5. 解密Prompt系列4. 升级Instruction Tuning:Flan/T0/InstructGPT/TKInstruct

    这一章我们聊聊指令微调,指令微调和前3章介绍的prompt有什么关系呢?哈哈只要你细品,你就会发现大家对prompt和instruction的定义存在些出入,部分认为instruction是promp ...

  6. 解密Prompt系列3. 冻结LM微调Prompt: Prefix-Tuning & Prompt-Tuning & P-Tuning

    这一章我们介绍在下游任务微调中固定LM参数,只微调Prompt的相关模型.这类模型的优势很直观就是微调的参数量小,能大幅降低LLM的微调参数量,是轻量级的微调替代品.和前两章微调LM和全部冻结的pro ...

  7. Java并发包源码学习系列:阻塞队列BlockingQueue及实现原理分析

    目录 本篇要点 什么是阻塞队列 阻塞队列提供的方法 阻塞队列的七种实现 TransferQueue和BlockingQueue的区别 1.ArrayBlockingQueue 2.LinkedBloc ...

  8. 解密SVM系列(二):SVM的理论基础(转载)

    解密SVM系列(二):SVM的理论基础     原文博主讲解地太好了  收藏下 解密SVM系列(三):SMO算法原理与实战求解 支持向量机通俗导论(理解SVM的三层境界) 上节我们探讨了关于拉格朗日乘 ...

  9. 实战SpringCloud响应式微服务系列教程(第十章)响应式RESTful服务完整代码示例

    本文为实战SpringCloud响应式微服务系列教程第十章,本章给出响应式RESTful服务完整代码示例.建议没有之前基础的童鞋,先看之前的章节,章节目录放在文末. 1.搭建响应式RESTful服务. ...

  10. 微软代码示例:ASP.NET 2.0 三层架构应用程序教程系列

    本文转自:http://www.codeusing.com/hi/uephee.wen/resource/view/170.aspx 资源分类:微软代码示例               更新日期:20 ...

随机推荐

  1. 2024CSP-S邮寄

    前言 去年被沉重打击到了,不过从此以后心态就好很多了,不会因为什么考试动不动就崩溃了. 考前 一直在认真复习,也停了课,甚至差点错过运动会.从国庆开始听了几天课,消化课件,然后考试.考试的稳定性不高, ...

  2. 实验一:Tableau数据可视化入门

    实验目的: 1.熟悉TableauDesktop使用方法. 2.通过Tableau软件来实现Excel中数据的基本可视化. 实验原理: Tableau是新一代商业智能工具软件,它将数据连接.运算.分析 ...

  3. 使用df命令

    1.使用df命令,查看整体的磁盘使用情况 df命令是用来查看硬盘的挂载点,以及对应的硬盘容量信息.包括硬盘的总大小,已经使用的大小,剩余大小.以及使用的空间占有的百分比等. 最常用的命令格式就是: 1 ...

  4. pagehelper的失效问题

    pagehelper是常用的分页插件,代码中常用到,使用简便且对代码侵入性较小,很多人都喜欢使用.不过有时会遇到分页失败问题,输出结果没有分页,日志输出sql语句没有分页关键字及分页参数,目测是pag ...

  5. 生命游戏Delphi实现

    生命游戏,康威生命游戏(Game of Life),剑桥大学约翰·何顿·康威设计的计算机程序. 生命游戏没有游戏玩家各方之间的竞争,也谈不上输赢,可以把它归类为仿真游戏.事实上,也是因为它模拟和显示的 ...

  6. 3. RabbitMQ 的(Hello World) 和 RabbitMQ 的(Work Queues)工作队列

    3. RabbitMQ 的(Hello World) 和 RabbitMQ 的(Work Queues)工作队列 @ 目录 3. RabbitMQ 的(Hello World) 和 RabbitMQ ...

  7. KGDB调试Linux内核与模块

    前言 内核 5.10 版本 openEuler 使用 yum install 下载了源码,并且通过两个 VMware 虚拟机进行调试 ubuntu 直接使用 git 拉取了https://kernel ...

  8. Jmeter压测数据库总结

    1.添加MySQL驱动 在测试计划里面添加,操作如下: MySQL驱动自行下载 2.添加线程组 选中测试计划,右键添加->线程(用户)->线程组,页面如下: 3.添加JDBC连接配置 选中 ...

  9. 神级辅助工具,解决GPT-SoVITS配音发音纠正和逐句优化

    即使地表最强AI配音也无法自动识别360应配音成三百六十还是三六零,在长文配音中很难一次满意,总会因为个别几句配音不理想而毁掉整个配音成果. 在GPT-SoVITS配音中,自动把长文章拆分成段落或长句 ...

  10. mybatis的模糊查询的实现方式

    一.比较灵活 1:xml的配置 <select id="selectUserByUsername1" parameterType="string" res ...