vllm-GDN-forward

vllm-GDN-forward

简单记录一下vllm对于GDN模型的优化,下边这里是来自Qwen3-Next模型的GDN部分的结构:

QKV-Proj

和nano-vllm中qkv的行为一直,可以看到图中的hidden_state分别经过了6个linear,这里也采用了一个增加吞吐量的手段,即这几个linear层被合并成了一个,去看模型safetensors文件发现他是这样保存的。

1
2
model.layers.0.linear_attn.in_proj_qkvz.weight
model.layers.0.linear_attn.in_proj_ba.weight

直接把hidden_state乘这个复合大矩阵,分别得到卷积前的$ q,k,\alpha,\beta,z $

执行gate(alpha)

1
2
3
4
x = a + dt_bias
softplus(x) = log(1 + exp(x))
g = -exp(A_log) * softplus(x)
beta = sigmoid(b)

得到了prepare后的q/k/gate/beta后,根据prefill阶段与decode阶段的不同:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# prefill阶段
core_attn_out, last_state = self.chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=ssm_state,
output_final_state=True,
cu_seqlens=non_spec_query_start_loc,
chunk_indices=chunk_indices,
...
)
ssm_state[...] = last_state # 写回 KV cache

# docode阶段
core_attn_out, last_state = fused_sigmoid_gating_delta_rule_update(
A_log=self.A_log, a=a, b=b, dt_bias=self.dt_bias,
q=q, k=k, v=v,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=...,
ssm_state_indices=...,
)

最后就是

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
core_attn_out  (SSM 输出, [L, HV, head_v_dim])


RMSNormGated(core_attn_out, z)
= RMSNorm(core_attn_out) * SiLU(z)

reshape

out_proj ────→ hidden_size

与标准 Mamba 输出公式的对应

标准 Mamba 的输出公式是:

output = RMSNorm(x) * SiLU(z) ← SSM 贡献(受 z 门控)

vllm-GDN-forward
http://example.com/2026/06/01/5_vllm项目GDN_forward/
作者
Soya
发布于
2026年6月1日
许可协议