FA2与Flash Decoding
FA2与Flash Decoding
在学习LLM Infer的时候,总是会提到FLash_Decoding这个概念,这里对比一下FA2和Flash_Decoding(FD)的一点区别。
FlashAttention2
在学习了Flash Attention V1的过程中,FA的矩阵遍历顺序是外层遍历KV, 内部依次遍历Q tile,那么这样会带来一个问题,在计算Attention的时候,最后其实是要得到一个与Q维度相同的output,那么FA的问题:
在内层遍历Q的时候(此时这个Q会被拆分为
[Q_tile1, Q_tile2...]),每个Block需要依次的更新Out[tile1], Out[tile2] ...,这样在写入的时候,每一个block就丧失了一定并行度,因为写入需要按顺序写。
FA2就解决的FA1的这个问题,FA2将循环的步骤变为了外层循环Q,内部循环KV,这样每一个block就不存在写冲突,每一个block只需要更新自身的。
用“循环”来表达似乎不太准确,这里用triton kernel的grid来表达更准确一些
1 | |
那么看到FA2的这个grid形式,在Prefill阶段并没有什么问题,但是在Decode阶段,这里的query_seq_len === 1,这样子就退化到了普通了FA,也就失去了FA2的并行度。也就是为什么会提出FD了。
FlashDecoding
与FA2对于FA的优化方法类似,在Decode的过程中,哪一个数据可以进一步Split呢?也就只有KV Cache了,在进行split之前,需要注意一个问题,在Online Softmax的时候(看FA2的kernel),我们需要保存每一个BLOCK_N长度块内部的局部最大值以及一个指数结果(可以看这个视频这个解释了online softmax)
为了解决这个不同的program就需要从HBM中读取局部值,所以我们在切分KV的时候,需要将局部最值/指数计算结果写入一次HBM。
1 | |
Flash Attention源代码解读
前面讲了那么多,其实对于V1版本与V2版本的交换QKV遍历顺序并没有特别的明白,于是乎决定看一看FA的源代码对比一下区别,看懂了才能更好理解~
Flash Attention V1
FA源代码仓库 https://github.com/Dao-AILab/flash-attention,
由于对于Cuda并不是那么的了解,因此这里直接看Triton的代码,更加通俗易懂一点,V1版本的Triton代码链接:
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton_og.py
推理的过程中主要涉及到是fwd的代码,我们这里只速度的浏览一下fwd kernel,结合手写的online softmax,我们把FA源代码和公式一一对应起来,能够更好理解FA,首先这里假设了QKV三个矩阵的shape是相同的,都为[batch_size, head_num, seq_len, d_model]。
1 | |
这里kernel_grid输入为grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]),说明了tl.program_id(0)对应了 seq这个维度,tl.program_id(1)则将bs与head_num打包为了一整个一维索引。BLOCK_N在内部循环中使用,这个遍历的是d_model维度。
计算offset对应了triton_kernel,将QKV大矩阵切分为了什么样子的小矩阵。
Q -> [BLOCK_M, d_model]
K -> [BLOCK_N, d_model]
V -> [BLOCK_N, d_model]
输出O -> [BLOCK_M, BLOCK_N]
直接进入遍历BLOCK_N这个循环for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N),首先根据偏移量读取kk = tl.load(k_ptrs + start_n * stride_kn),计算得到qk的矩阵,sm_scale对应了稳定的Attention计算中分母的根号d。
m_ij = tl.max(qk, 1)这里计算得到当前这个BLOCK_M,BLOCK_N大小矩阵中的每一行的最大值,得到的形状是[BLOCK_M]的矩阵。
1 | |
这一行对应于手写图片中当前block计算softmax的分母,也就是手写图片中的$d_i$,接下来注释中update则对应手写图片的递推公式计算。不过这里代码中的m_ij对应了当前这个block的最大值,而m_i则记录的是所有已经遍历block的最大值。
1 | |
alpha对应于$e^{m^{i-1} - m^i}$这一项,beta是一个中间结果,l_i_new = alpha * l_i + beta * l_ij中beta * l_ij对应手写公式的$exp(s_i - m_i)$。
经过上面的一一对照关系acc_scale = l_i / l_i_new * alpha这里的公式就很简单了$\frac{d_{i-1}}{d_{i}} * e^{m^{i-1} - m^i}$,至此我们就构造出了我们需要的部分,后续就是简单的element-wise乘积计算再加上load-store到输出tensor。
通过这里的Kernel我们就看出来了FA V1中SM将Q的seq_len维度进行拆分,而内部则是串行的遍历KV得到attention。