阅读之前你可以前往 RWKV wiki 了解一些关于 RWKV 的基本知识,不过他们的 wiki 似乎没有对模型架构的详细介绍,于是便有了这篇文章。

RWKV-7 的核心:动态状态演化机制

RWKV-V7 的动态状态演化机制可以通俗理解为 “在线学习上下文关系的动态记忆更新” 。它的核心思想是:通过实时计算和更新一个内部状态(state)来动态捕捉上下文中 key 和 value 的关联关系,并利用这个状态处理当前输入的 query(在 RWKV 中是 r)以生成输出

1. 状态(state)的本质

在 RWKV 中,state 是一个三维张量(B, H, N, N),其中:

  • B:批量大小(batch size)
  • H:注意力头数量(head count)
  • N:每个头的维度(head size)

state 的作用是 维护一个动态的“知识库” ,记录历史输入的 key(k)和 value(v)之间的关系。它类似于传统 RNN 的隐藏状态,但更复杂,因为它显式建模了 key-value 的交互。

2. 状态更新公式与代码实现

官方公式:

St​=St−1​⋅diag(wt​)+(St−1​at⊤​bt​+vt⊤​kt​)

在代码中(RWKV7_OP 函数),这个公式被拆解为:

state = state * w + state @ a @ b + v @ k

该公式对应 RWKV-V7 的 动态状态演化机制 ,其核心思想是通过 在线学习 维护一个低秩状态矩阵 S_t,以捕捉上下文中的 key-value 关系。公式分为三部分:

  1. 权重衰减项(遗忘旧信息)

     
    state * w
    • 作用 :控制历史信息的遗忘速率。
    • 数学形式 S_t = S_{t-1} ⋅ diag(w_t),其中 w_t 是负值(通过 exp(-exp(...)) 生成)。
    • 思想 :类似 RNN 的遗忘门,w 越小(更负),遗忘越多,确保模型专注于当前上下文。
  2. 学习率调整项(动态修正知识)

     
    state @ aa @ bb
    • 作用 :根据当前输入调整已有知识的权重。
    • 数学形式 S_t = S_{t-1} ⋅ a_t^⊤ b_t,其中 ab 是动态学习率参数(a = -k, b = k ⋅ η)。
    • 思想 :模拟梯度下降中的学习率乘以梯度方向,使状态更新更适应当前输入的特征分布。
  3. 新信息注入项(更新知识库)

     
    vv @ kk
    • 作用 :将当前 token 的 key-value 对(k_t, v_t)外积加入状态。
    • 数学形式 S_t += v_t^⊤ ⋅ k_t
    • 思想 :直接注入新信息,确保模型能实时捕捉上下文中的新关联(如实体关系、语义依赖)

GPT 与 RNN 模式

在进一步了解之前需要知道的是,RWKV 分为两种模式,一种是 GPT 模式,一种是 RNN 模式,区别在于:

GPT 模式 rwkv_v7_demo.py

  • 计算方式 :采用 Transformer 的 GPT 模式 ,一次性处理整个输入序列(类似标准 Transformer 的前向传播)。
  • 状态管理 :无需显式维护隐藏状态(state),直接通过注意力机制处理上下文。
  • 适用场景 :适合 预填充(prefill) 阶段(即处理初始提示词),但自回归生成时效率较低。
  • 性能特点
    • 预填充阶段可以并行计算所有 token 的表示。
    • 自回归生成时需重新计算所有 token 的表示(无状态缓存),导致速度较慢。

RNN 模式 rwkv_v7_demo_rnn.py

  • 计算方式 :采用 RNN 模式 ,逐个 token 处理,显式维护隐藏状态(state)。
  • 状态管理 :通过 state 变量保存每个时间步的中间状态(如 att_kvffn_x_prev 等),避免重复计算。
  • 适用场景 :适合 自回归生成 (逐 token 生成),效率更高。
  • 性能特点
    • 预填充阶段效率低(需逐 token 处理)。
    • 生成阶段只需计算当前 token 的表示,速度显著优于 GPT 模式。

此外,官方还提供了结合两种模式的 fast 模式: rwkv_v7_demo_fast.py

GPT 模式

Time mix(RWKV_Tmix_x070)

Time mix 是整个 RWKV-7 的核心,也是最复杂的部分。

RWKV_Tmix_x070 是 RWKV 模型中实现 Time Mixing 的核心模块,用于处理序列数据中的时序依赖关系。它是 RWKV-V7 架构中负责模拟注意力机制和动态状态更新的组件,类似于 Transformer 中的自注意力模块,但采用了更高效的线性复杂度设计。

@MyFunction
def forward(self, x, v_first):
# 获取输入张量的维度:批量大小 B、序列长度 T、特征维度 C
B, T, C = x.size() # 获取注意力头数量 H 和每个头的维度 N(固定为 64)
H = self.n_head
N = self.head_size # 使用 time_shift 操作获取前一个时间步的信息,并与当前输入做差值
xx = self.time_shift(x) - x # 得到时间差分项,用于生成不同方向的输入 # 根据时间差分项和可学习参数,生成不同方向的输入
xr = x + xx * self.x_r # receptance 输入方向
xw = x + xx * self.x_w # decay weight 输入方向
xk = x + xx * self.x_k # key 输入方向
xv = x + xx * self.x_v # value 输入方向
xa = x + xx * self.x_a # learning rate 输入方向
xg = x + xx * self.x_g # gate 输入方向 # 通过线性层生成 r(receptance),类似于传统 attention 中的 query
r = self.receptance(xr) # 构造 w(权重衰减项):控制 state 的遗忘速率
# 使用 tanh 和 softplus 确保 w 值在 (-inf, -0.5) 范围内
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # 通过线性层生成 key 向量
k = self.key(xk) # 通过线性层生成 value 向量
v = self.value(xv) # 如果是第一层,则保存当前 v 作为 v_first(初始 value)
if self.layer_id == 0:
v_first = v
else:
# 否则使用门控机制对 v 进行残差连接,保留一部分初始信息
v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # 生成动态学习率 a,控制 state 更新的学习率(类似在线梯度下降中的学习率)
a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # 生成门控参数 g,控制最终输出的激活强度
g = torch.sigmoid(xg @ self.g1) @ self.g2 # 对 key 进行缩放和 L2 归一化,增强数值稳定性
kk = k * self.k_k
kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C) # 调整 key 的尺度,结合学习率 a 和可学习参数 self.k_a
k = k * (1 + (a - 1) * self.k_a) # 调用 RWKV7_OP 函数进行状态更新和 attention 输出计算
x = RWKV7_OP(r, w, k, v, -kk, kk * a) # 使用 GroupNorm 对输出进行归一化
x = self.ln_x(x.view(B * T, C)).view(B, T, C) # 添加一个额外的残差连接,融合 r, k, v 的乘积项
x = x + ((r.view(B, T, H, -1) *
k.view(B, T, H, -1) *
self.r_k).sum(dim=-1, keepdim=True) *
v.view(B, T, H, -1)).view(B, T, C) # 最终输出经过门控 g 控制,并通过线性层输出
x = self.output(x * g) # 返回输出和 v_first,供下一层使用
return x, v_first

模块基于以下两个核心机制:

一. 时间差分与方向调整

首先,Time Shift 通过 time_shift(x) - x 得到当前 token 与前一 token 的差异,然后结合可学习参数生成多个“方向”的输入(如 xr, xw, xk, xv, xa, xg),分别用于构建不同的注意力相关参数。

具体来说,Time Shift 的作用是将输入张量在时间维度上向前移动一个单位,这样在处理当前时间步 xt​ 时,模型可以获取到到上一个时间步 xt−1​ 的信息。

举个例子:

# 假设原始输入是:
x = [
[x_0_0, x_0_1, ..., x_0_C], # 时间步 t=0
[x_1_0, x_1_1, ..., x_1_C], # 时间步 t=1
...
[x_T_0, x_T_1, ..., x_T_C] # 时间步 t=T
] # Time Shift 后得到的是:
x_shifted = [
[0, 0, ..., 0], # 第一个位置用 0 填充(表示没有前一步)
[x_0_0, x_0_1, ..., x_0_C], # 第二个位置是 t=0 的值(相当于延迟了一个时间步)
...
[x_{T-1}_0, ..., x_{T-1}_C] # 最后一个位置是 t=T-1 的值
]

接着利用 Time Shift 的结果 xx 我们就可以计算不同的注意力相关参数。

xr = x + xx * self.x_r
xw = x + xx * self.x_w
xk = x + xx * self.x_k
xv = x + xx * self.x_v
xa = x + xx * self.x_a
xg = x + xx * self.x_g

这些新变量 xr, xw, xk, xv, xa, xg 分别用于生成注意力机制中的不同角色:

变量
对应角色
作用
xr
receptance (r)
类似 query,在状态更新中起调节作用
xw
decayweight (w)
控制遗忘速率
xk
key (k)
用于构建上下文关联
xv
value (v)
上下文信息载体
xa
learning rate (a)
动态调整学习率
xg
gate (g)
非线性激活门控

二.线性变换生成核心参数

r = self.receptance(xr)
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5)
k = self.key(xk)
v = self.value(xv)
  • receptance(xr) : 生成 r(类似传统注意力中的 query,但用于控制当前状态如何与新的输入进行交互,决定哪些信息会被“接收”并用于输出计算)。
  • w : 权重衰减参数(控制遗忘速率),类似于 RNN 中的遗忘门。
    • 公式 w = -softplus(-(...)) - 0.5,确保 w 为负值。
    • 作用 :模拟指数衰减,遗忘旧信息。
  • key(xk) : 生成 k(key 向量,决定新信息如何与当前 state 进行交互)。
  • value(xv) : 生成 v(value 向量,表示输入的“值”,包含实际的内容信息)。

三. 残差连接机制

if self.layer_id == 0:
v_first = v # store the v of the first layer
else:
v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual

一、什么是残差连接?

简单定义:

残差连接(Residual Connection) 是指将某一层的输入直接加到该层的输出上。
数学表达为:

output=input+F(input)

其中 F(input) 是网络中某个函数或模块对输入的变换。


二、为什么需要残差连接?

残差连接最早出现在 ResNet 中,用于解决深度神经网络中的两个关键问题:

1. 梯度消失/爆炸
  • 在深层网络中,反向传播时梯度可能会指数级衰减或爆炸。
  • 残差连接可以缓解这个问题,使得梯度更容易流动。
2. 信息丢失 / 语义漂移
  • 随着层数加深,原始输入的信息可能被逐渐稀释。
  • 残差连接可以让模型在处理过程中始终保留原始输入的信息。
3. 训练稳定性提升

实验表明,加入残差连接后,深层网络更容易优化和收敛。

在长序列建模中,随着序列变长,中间 token 可能会逐渐丢失最初的上下文信息,通过 v_first 引入初始 token 的 value,可以让后续层仍然保有最初的信息。(v_first - v) 是一个残差项,使得当前值 v 可以根据初始值进行调整,这种设计有助于缓解梯度消失问题,提升训练稳定性。

主要公式可以拆解为三个主要部分:

v = v +
(v_first - v) *
torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2)

1. v: 当前时间步的 value 向量

  • 是通过 self.value(xv) 得到的当前 token 的 value 表示。
  • 它表示模型对当前 token 的理解或编码。

2. v_first: 第一个时间步的 value

  • 来自第一个层的第一个 token 的 value(在 layer_id == 0 时被赋值)。
  • 目的是在整个序列处理过程中保留初始信息(类似 RNN 中的“初始状态”或 Transformer 中的 [CLS] token)。

3. (v_first - v): 初始值与当前值之间的差异

  • 这是一个残差项,表示当前 token 与初始 token 的语义差距。
  • 如果当前 token 与初始 token 差异较大,则这个值会更大,可能意味着需要更多的初始信息来修正当前值。

4. torch.sigmoid(...):门控函数

  • 将线性变换的结果映射到 [0, 1] 区间。
  • 控制 (v_first - v) 对当前 v 的影响程度。
  • 类似于 LSTM 中的遗忘门或 GRU 中的更新门。

5. self.v0 + (xv @ self.v1) @ self.v2:可学习的门控参数

  • xv: 经过 time_shift 处理后的输入,用于生成 value 的原始输入。
  • self.v1, self.v2: 两个低秩矩阵,构成 LoRA 结构(Low-Rank Adaptation),用于减少参数量。
  • 整个表达式是一个门控信号,根据当前输入动态决定是否引入初始信息。

四. 动态学习率参数 a 和门控参数 g

a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate"
g = torch.sigmoid(xg @ self.g1) @ self.g2

a 是动态学习率参数 ,用于控制状态(state)更新的强度,使模型能够根据上下文自适应地调整记忆更新的幅度;

g 是门控参数 ,用于调节最终输出的激活程度,通过可学习的权重决定哪些信息应该被强调或抑制,从而增强模型的表达能力和非线性建模能力。两者共同提升了模型对长序列的建模效率和准确性。

五. 调用 RWKV7_OP 状态更新函数

x = RWKV7_OP(r, w, k, v, -kk, kk*a)

六. 组合输出

x = self.ln_x(x.view(B * T, C)).view(B, T, C)
x = x + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
x = self.output(x * g)
  • self.ln_x : 对 x 进行分组归一化(GroupNorm),提升训练稳定性。
  • 残差连接 :将 r, k, v 的乘积加入输出。
    • 公式 x += (r * k * r_k) · v,增强非线性表达。
  • 门控输出 x = self.output(x * g),通过门控参数 g 调制输出。

通道混合 ChannelMix (RWKV_CMix_x070

class RWKV_CMix_x070(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) # 时间维度的位移操作
with torch.no_grad():
self.x_k = nn.Parameter(torch.empty(1, 1, args.n_embd)) # 可学习的权重参数
self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False) # 第一阶段线性变换
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) # 第二阶段线性变换 @MyFunction
def forward(self, x):
xx = self.time_shift(x) - x # 计算时间差分
k = x + xx * self.x_k # 调整输入特征
k = torch.relu(self.key(k)) ** 2 # 非线性激活
return self.value(k) # 输出变换后的特征

关键组件解析

(1) 时间位移(Time Shift)

  • 使用 nn.ZeroPad2d((0, 0, 1, -1)) 对输入张量进行时间维度的位移(前移一位),生成 time_shift(x)
  • 计算差值 xx = time_shift(x) - x,捕捉序列中相邻位置的局部变化特征。

(2) 输入调整(Input Gating)

  • 引入可学习参数 self.x_k,通过 k = x + xx * self.x_k 调整输入特征,类似门控机制,控制历史信息的保留比例。

(3) 非线性变换

  • 第一阶段 :通过 key 线性层将输入从 n_embd 维度映射到更高维度 dim_ffn(通常为 4 * n_embd)。
  • 激活函数 :使用 ReLU 激活后平方(torch.relu(...) ** 2),增强非线性表达能力。
  • 第二阶段 :通过 value 线性层将特征映射回 n_embd 维度,形成最终输出。

与传统 FFN 的对比

特性
RWKV_CMix_x070
传统 Transformer FFN
激活函数
ReLU(x)^2
GELUSwish
线性层
两层(key+value
两层(linear + activation
时间依赖性
显式引入时间差分(time_shift
参数量
较低(依赖dim_ffn
较高(通常为

RWKV-7 架构理解的更多相关文章

  1. TAF /tars必修课(一):整体架构理解

    来自零点智能社区 一.前言 TAF,一个后台逻辑层的高性能RPC框架,目前支持C++,Java, node 三种语言, 往后可能会考虑提供更多主流语言的支持如 go等,自定义协议JCE,同时也支持HT ...

  2. 沉淀再出发:Spring的架构理解

    沉淀再出发:Spring的架构理解 一.前言 在Spring之前使用的EJB框架太庞大和重量级了,开发成本很高,由此spring应运而生.关于Spring,学过java的人基本上都会慢慢接触到,并且在 ...

  3. ARM CORTEX-M3 内核架构理解归纳

    ARM CORTEX-M3 内核架构理解归纳 来源:网络 个人觉得对CM3架构归纳的非常不错,因此转载 基于<ARM-CORTEX M3 权威指南>做学习总结: 在我看来,Cotex-M3 ...

  4. RESTful 架构理解

    REST中的关键词: 1.资源 2.资源的表述 3.状态转移 资源: "资源",可以是一段文本.一张图片.一首歌曲.一种操作.你可以用一个URI(统一资源定位符)指向它,每种资源对 ...

  5. SOA面向服务的架构理解

    Ø  单一应用架构 ·当网站流量很小时,只需一个应用,将所有功能都部署在一起,以减少部署节点和成本. Ø  垂直应用架构 当访问量逐渐增大,单一应用增加机器带来的加速度越来越小,将应用拆成互不相干的几 ...

  6. 【K8s】Kubernetes架构理解

    抽空学习了一下Kubernetes,感觉和大数据领域内集群的资源管理.任务调度等有异曲同工之处,简单总结一下备忘. [概念] Kubernetes是一个工业级的容器编排平台,单词有点长,常用K8s代称 ...

  7. mybatis源码解析之架构理解

    mybatis是一个非常优秀的开源orm框架,在大型的互联网公司,基本上都会用到,而像程序员的圣地-阿里虽然用的是自己开发的一套框架,但其核心思想也无外乎这些,因此,去一些大型互联网公司面试的时候,总 ...

  8. MYSQL架构理解

    目录 一.MYSQL架构 1. 架构图 2.分层实现 3.查询组件 二.并发控制 三. 事务 四.引擎 摘自 通过对MYSQL重要的几个属性的理解,建立一个基本的MYSQL的知识框架 一.MYSQL架 ...

  9. 从数据分析系统总架构理解BI工具的价值所在

    ​现如今,应用商业智能BI工具的企业是越来越多了,由此也可见企业对数据分析的重视.因此,掌握一定的数据分析知识对"打工人"来说是非常重要的.现在小编就来跟大家一起来了解一下商业智能 ...

  10. Comware 架构理解

    网络操作系统 首先什么是网络操作系统: 一种说法是:运行在路由器,网络交换机,防火墙上的特别的操作系统 另一种说法是:部署在局域网或者私有网络,允许网络中的多个计算机共享文件和打印机,因为现在的单机系 ...

随机推荐

  1. 【由技及道】镜像圣殿建造指南:Harbor私有仓库的量子封装艺术【人工智障AI2077的开发日志009】

    摘要:当容器镜像需要同时存在于8个平行宇宙时,就像在量子计算机里管理72个维度的镜像分身.本文记录一个未来AI如何通过Harbor搭建量子镜像圣殿,让容器分发成为跨越时空的瞬间传送. 动机:镜像管理的 ...

  2. TSP问题的不可近似性

    \(\S\) 结论 TSP问题:n阶带权无向完全图中,找权值最小的哈密顿回路(无向图中遍历所有顶点的回路) 优化问题,记最优解为OPT 对于一般的n顶点TSP问题(非Metric),任意 多项式时间内 ...

  3. 线上测试木舟物联网平台之如何通过HTTP网络组件接入设备

    一.概述 木舟 (Kayak) 是什么? 木舟(Kayak)是基于.NET6.0软件环境下的surging微服务引擎进行开发的, 平台包含了微服务和物联网平台.支持异步和响应式编程开发,功能包含了物模 ...

  4. MUX-VLAN

    MUX VLAN(Multiplex VLAN)是一种高级的VLAN技术,它通过在交换机上实现二层流量隔离和灵活的网络资源控制,提供了一种更为细致的网络管理方式. 一.基本概念 MUX VLAN分为主 ...

  5. CAD通过XCLIP命令插入DWG参照裁剪图形,引用局部图像效果(CAD裁剪任意区域)

    CAD通过XCLIP命令插入DWG参照裁剪图形,实现引用局部图像效果,裁剪任意区域! 1.首先在你要引用局部图的文件内,插入参照! 2. 然后再空白区域指定插入点,输入比例因子,默认输入1,然后缩小视 ...

  6. Unity开发Hololens2—环境配置

    博客地址:https://www.cnblogs.com/zylyehuo/ 配置如下: win11 专业版 Unity2018.4.26f1 / 2019.4.11f1 Hololens2 VS20 ...

  7. Tomcat之Jconsole监控

    JConsole的图形用户界面是一个符合Java管理扩展(JMX)规范的监测工具,JConsole使用Java虚拟机(Java VM),提供在Java平台上运行的应用程序的性能和资源消耗的信息.在Ja ...

  8. RANSAC---从直线拟合到特征匹配去噪

    Ransac全称为Random Sample Consensus,随机一致性采样.该方法是一种十分高效的数据拟合方法.我们通过最简单的拟合直线任务来了解这种方法思路,继而扩展到特征点匹配中的误点剔除问 ...

  9. MD5加密BASE64加解密

    MD5需要引入system.Hash,BASE64需要引入System.NetEncoding,这两个单元应该只有高版本的DELPHI IDE才有 (貌似XE5以上版本才有).如果是D7的话,找第三方 ...

  10. MYSQL架构介绍

    专栏持续更新中- 本专栏针对的是掌握MySQL基本操作后想要对其有深入了解并且有高性能追求的读者. 第一篇文章主要是对MySQL架构的主要概括,让读者脑海中有个对MySQL大体轮廓,很多地方没有展开细 ...