vllm-GDN-forward
简单记录一下vllm对于GDN模型的优化,下边这里是来自Qwen3-Next模型的GDN部分的结构:

QKV-Proj
和nano-vllm中qkv的行为一直,可以看到图中的hidden_state分别经过了6个linear,这里也采用了一个增加吞吐量的手段,即这几个linear层被合并成了一个,去看模型safetensors文件发现他是这样保存的。
1 2
| model.layers.0.linear_attn.in_proj_qkvz.weight model.layers.0.linear_attn.in_proj_ba.weight
|
直接把hidden_state乘这个复合大矩阵,分别得到卷积前的$ q,k,\alpha,\beta,z $
执行gate(alpha)
1 2 3 4
| x = a + dt_bias softplus(x) = log(1 + exp(x)) g = -exp(A_log) * softplus(x) beta = sigmoid(b)
|
得到了prepare后的q/k/gate/beta后,根据prefill阶段与decode阶段的不同:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| core_attn_out, last_state = self.chunk_gated_delta_rule( q, k, v, g, beta, initial_state=ssm_state, output_final_state=True, cu_seqlens=non_spec_query_start_loc, chunk_indices=chunk_indices, ... ) ssm_state[...] = last_state
core_attn_out, last_state = fused_sigmoid_gating_delta_rule_update( A_log=self.A_log, a=a, b=b, dt_bias=self.dt_bias, q=q, k=k, v=v, initial_state=ssm_state, inplace_final_state=True, cu_seqlens=..., ssm_state_indices=..., )
|
最后就是
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| core_attn_out (SSM 输出, [L, HV, head_v_dim]) │ ▼ RMSNormGated(core_attn_out, z) = RMSNorm(core_attn_out) * SiLU(z) │ ▼ reshape │ out_proj ────→ hidden_size
与标准 Mamba 输出公式的对应
标准 Mamba 的输出公式是:
output = RMSNorm(x) * SiLU(z) ← SSM 贡献(受 z 门控)
|