FA2与Flash Decoding

FA2与Flash Decoding

在学习LLM Infer的时候,总是会提到FLash_Decoding这个概念,这里对比一下FA2和Flash_Decoding(FD)的一点区别。

FlashAttention2

在学习了Flash Attention V1的过程中,FA的矩阵遍历顺序是外层遍历KV, 内部依次遍历Q tile,那么这样会带来一个问题,在计算Attention的时候,最后其实是要得到一个与Q维度相同的output,那么FA的问题:

在内层遍历Q的时候(此时这个Q会被拆分为[Q_tile1, Q_tile2...]),每个Block需要依次的更新Out[tile1], Out[tile2] ...,这样在写入的时候,每一个block就丧失了一定并行度,因为写入需要按顺序写。

FA2就解决的FA1的这个问题,FA2将循环的步骤变为了外层循环Q,内部循环KV,这样每一个block就不存在写冲突,每一个block只需要更新自身的。

用“循环”来表达似乎不太准确,这里用triton kernel的grid来表达更准确一些

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

# FA的grid形式,一个program负责某个batch的某一个head
grid_fa1 = (batch_size, head_num)

# 在kernel 内部
for start_m in range(0, q_len, BLOCK_M):
...

# FA2的grid形式像是这样的kernel,一个program负责一整块Q矩阵的一小部分
grid_fa2 = (batch_size, cdiv(query_seq_len, BLOCK_SIZE_M), head_num)

@triton.jit
def flash_attn_2_fwd_kernel(
Q, K, V, ...
):
# 1. 获取当前 Block 在 Grid 中的空间位置
pid_m = tl.program_id(0) # Q 块的索引 (Decode 时为 0)
pid_bh = tl.program_id(1) # Batch 索引
pid_h = tl.program_id(2) # Head 索引

# 2. 定位到当前 Head 对应的 Q 向量指针
# 每个 Block 独占一个特定的 Batch 和 Head
q_ptr = Q + pid_bh * stride_qb + pid_h * stride_qh

# 3. 在这个 Block(SM)内部,用指针循环去【串行】遍历所有的 KV 块
for start_n in range(0, KV_LEN, BLOCK_N):
k_ptr = K + start_n * stride_kn + ...
v_ptr = V + start_n * stride_vn + ...

# 加载 KV 块到片上 SRAM
k = tl.load(k_ptr)
v = tl.load(v_ptr)

# 在计算单元内部迭代更新 Online Softmax 局部结果
...

那么看到FA2的这个grid形式,在Prefill阶段并没有什么问题,但是在Decode阶段,这里的query_seq_len === 1,这样子就退化到了普通了FA,也就失去了FA2的并行度。也就是为什么会提出FD了。

FlashDecoding

与FA2对于FA的优化方法类似,在Decode的过程中,哪一个数据可以进一步Split呢?也就只有KV Cache了,在进行split之前,需要注意一个问题,在Online Softmax的时候(看FA2的kernel),我们需要保存每一个BLOCK_N长度块内部的局部最大值以及一个指数结果(可以看这个视频这个解释了online softmax)

为了解决这个不同的program就需要从HBM中读取局部值,所以我们在切分KV的时候,需要将局部最值/指数计算结果写入一次HBM。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Flash Decoding的grid形式
grid = (batch_size, head_num, n_splits)

@triton.jit
def flash_decoding_fwd_kernel(
Q, K, V, Output_Splits, Logsumexp_Splits, ...
):
# 1. 获取当前 Block 的位置
pid_b = tl.program_id(0) # Batch 索引
pid_h = tl.program_id(1) # Head 索引
pid_split = tl.program_id(2) # 当前这个 SM 负责的 KV 切片索引

# 2. 计算当前 split 负责的 KV 范围
kv_per_split = KV_LEN // NUM_SPLITS
start_kv_idx = pid_split * kv_per_split
end_kv_idx = start_kv_idx + kv_per_split

# 3. 现在的循环变短了,只遍历自己分到的那一小段 KV
for start_n in range(start_kv_idx, end_kv_idx, BLOCK_N):
k = tl.load(K + start_n * ...)
v = tl.load(V + start_n * ...)
# 计算局部 Attention
...

# 4. 因为是并行计算,算完后不能直接当作最终结果
# 必须把每个 split 的局部输出和 Softmax 分母(Logsumexp)写回全局显存(HBM)
tl.store(Output_Splits + pid_b*... + pid_h*... + pid_split*..., local_out)
tl.store(Logsumexp_Splits + pid_b*... + pid_h*... + pid_split*..., local_lse)

Flash Attention源代码解读

前面讲了那么多,其实对于V1版本与V2版本的交换QKV遍历顺序并没有特别的明白,于是乎决定看一看FA的源代码对比一下区别,看懂了才能更好理解~

Flash Attention V1

FA源代码仓库 https://github.com/Dao-AILab/flash-attention,
由于对于Cuda并不是那么的了解,因此这里直接看Triton的代码,更加通俗易懂一点,V1版本的Triton代码链接:
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton_og.py

推理的过程中主要涉及到是fwd的代码,我们这里只速度的浏览一下fwd kernel,结合手写的online softmax,我们把FA源代码和公式一一对应起来,能够更好理解FA,首先这里假设了QKV三个矩阵的shape是相同的,都为[batch_size, head_num, seq_len, d_model]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
TMP,
L,
M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
Z,
H,
N_CTX,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
# seq_len维度
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# d_model维度
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
t_ptrs = TMP + off_hz * N_CTX + offs_m

m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
qk *= sm_scale
# casual mask
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
# l_ij [BLOCK_M]
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vk)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_i)
tl.store(m_ptrs, m_i)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)

这里kernel_grid输入为grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]),说明了tl.program_id(0)对应了 seq这个维度,tl.program_id(1)则将bs与head_num打包为了一整个一维索引。BLOCK_N在内部循环中使用,这个遍历的是d_model维度。

计算offset对应了triton_kernel,将QKV大矩阵切分为了什么样子的小矩阵。

Q -> [BLOCK_M, d_model]
K -> [BLOCK_N, d_model]
V -> [BLOCK_N, d_model]
输出O -> [BLOCK_M, BLOCK_N]

直接进入遍历BLOCK_N这个循环for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N),首先根据偏移量读取kk = tl.load(k_ptrs + start_n * stride_kn),计算得到qk的矩阵,sm_scale对应了稳定的Attention计算中分母的根号d。

m_ij = tl.max(qk, 1)这里计算得到当前这个BLOCK_M,BLOCK_N大小矩阵中的每一行的最大值,得到的形状是[BLOCK_M]的矩阵。

1
2
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)

这一行对应于手写图片中当前block计算softmax的分母,也就是手写图片中的$d_i$,接下来注释中update则对应手写图片的递推公式计算。不过这里代码中的m_ij对应了当前这个block的最大值,而m_i则记录的是所有已经遍历block的最大值。

1
2
m_i   = max(block_0, block_1, ..., block_{start_n - 1})
m_ij = max(block_{start_n})

alpha对应于$e^{m^{i-1} - m^i}$这一项,beta是一个中间结果,l_i_new = alpha * l_i + beta * l_ijbeta * l_ij对应手写公式的$exp(s_i - m_i)$。

经过上面的一一对照关系acc_scale = l_i / l_i_new * alpha这里的公式就很简单了$\frac{d_{i-1}}{d_{i}} * e^{m^{i-1} - m^i}$,至此我们就构造出了我们需要的部分,后续就是简单的element-wise乘积计算再加上load-store到输出tensor。

通过这里的Kernel我们就看出来了FA V1中SM将Q的seq_len维度进行拆分,而内部则是串行的遍历KV得到attention。


FA2与Flash Decoding
http://example.com/2026/06/01/FA2与Flash_Decoding/
作者
Soya
发布于
2026年6月1日
许可协议