From 0bd7ba75366d61618d292f837dcaf4d6afc35ceb Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Thu, 11 Dec 2025 22:39:50 +0800
Subject: [PATCH] [fix] Fixed chunked_attention.py implementation.

---
 nanovllm/kvcache/chunked_attention.py | 178 ++++++++++++++++----------
 1 file changed, 109 insertions(+), 69 deletions(-)

diff --git a/nanovllm/kvcache/chunked_attention.py b/nanovllm/kvcache/chunked_attention.py
index cc679d1..ddcb62c 100644
--- a/nanovllm/kvcache/chunked_attention.py
+++ b/nanovllm/kvcache/chunked_attention.py
@@ -31,7 +31,6 @@ def _fwd_kernel_with_lse(
     V,
     Out,
     Lse,
-    TMP,
     softmax_scale,
     stride_qb,
     stride_qh,
@@ -60,7 +59,17 @@ def _fwd_kernel_with_lse(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    """Flash attention forward kernel with LSE output for online softmax."""
+    """
+    Flash attention forward kernel with LSE output.
+
+    Implements standard Flash Attention online softmax algorithm:
+    - m_i: running max of attention scores
+    - l_i: running sum of exp(scores - m_i)
+    - acc_o: running sum of softmax(scores) @ V (unnormalized)
+
+    Final output: acc_o / l_i
+    Final LSE: m_i + log(l_i)
+    """
     start_m = tl.program_id(0)
     off_hb = tl.program_id(1)
     off_b = off_hb // nheads
@@ -69,6 +78,7 @@ def _fwd_kernel_with_lse(
     offs_n = tl.arange(0, BLOCK_N)
     offs_d = tl.arange(0, BLOCK_HEADDIM)

+    # Pointers
     q_ptrs = (
         Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
     )
@@ -79,12 +89,12 @@ def _fwd_kernel_with_lse(
         V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
     )

-    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
-    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
+    # Initialize running statistics
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # running max
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # running sum of exp
+    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)  # running output (unnormalized)

-    # Load Q
+    # Load Q (once per block)
     if EVEN_M & EVEN_N:
         if EVEN_HEADDIM:
             q = tl.load(q_ptrs)
@@ -98,7 +108,7 @@ def _fwd_kernel_with_lse(
                 q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
             )

-    # Loop over k, v and update accumulator
+    # Loop over K, V blocks
     end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
     for start_n in range(0, end_n, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
@@ -123,28 +133,39 @@ def _fwd_kernel_with_lse(
                     other=0.0,
                 )

-        # Compute QK
+        # Compute QK^T * scale
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, tl.trans(k))
+        qk *= softmax_scale

-        # Masking
+        # Apply masks
         if not EVEN_N:
             qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
         if IS_CAUSAL:
             qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))

-        # Online softmax
-        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-        p = tl.exp(qk * softmax_scale - m_ij[:, None])
-        l_ij = tl.sum(p, 1)
+        # Online softmax: compute block max
+        m_ij = tl.max(qk, 1)  # [BLOCK_M]

-        # Scale acc_o
-        acc_o_scale = tl.exp(m_i - m_ij)
-        tl.store(t_ptrs, acc_o_scale)
-        acc_o_scale = tl.load(t_ptrs)
-        acc_o = acc_o * acc_o_scale[:, None]
+        # New running max
+        m_new = tl.maximum(m_i, m_ij)  # [BLOCK_M]

-        # Load V and update output
+        # Rescale factor for previous accumulator
+        alpha = tl.exp(m_i - m_new)  # [BLOCK_M]
+
+        # Compute P = exp(qk - m_new)
+        p = tl.exp(qk - m_new[:, None])  # [BLOCK_M, BLOCK_N]
+
+        # Sum of current block
+        l_ij = tl.sum(p, 1)  # [BLOCK_M]
+
+        # Update running sum: l_new = l_i * alpha + l_ij
+        l_new = l_i * alpha + l_ij
+
+        # Rescale previous output and add new contribution
+        acc_o = acc_o * alpha[:, None]
+
+        # Load V
         if EVEN_N & EVEN_M:
             if EVEN_HEADDIM:
                 v = tl.load(v_ptrs + start_n * stride_vn)
@@ -164,23 +185,26 @@ def _fwd_kernel_with_lse(
                     other=0.0,
                 )

+        # acc_o += P @ V
         p = p.to(v.dtype)
         acc_o += tl.dot(p, v)

-        # Update statistics
-        m_i = m_ij
-        l_i_new = tl.exp(lse_i - m_ij) + l_ij
-        lse_i = m_ij + tl.log(l_i_new)
+        # Update running statistics
+        m_i = m_new
+        l_i = l_new

-    # Final scaling
-    o_scale = tl.exp(m_i - lse_i)
-    tl.store(t_ptrs, o_scale)
-    o_scale = tl.load(t_ptrs)
-    acc_o = acc_o * o_scale[:, None]
+    # Final normalization: output = acc_o / l_i
+    acc_o = acc_o / l_i[:, None]
+
+    # Compute LSE = m_i + log(l_i)
+    lse_i = m_i + tl.log(l_i)

     # Store LSE
     lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
-    tl.store(lse_ptrs, lse_i)
+    if EVEN_M:
+        tl.store(lse_ptrs, lse_i)
+    else:
+        tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)

     # Store output
     out_ptrs = (
@@ -258,7 +282,6 @@ def flash_attn_with_lse(
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128

     lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
-    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
     out = torch.empty_like(q)

     BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
@@ -272,7 +295,6 @@ def flash_attn_with_lse(
         v,
         out,
         lse,
-        tmp,
         softmax_scale,
         q.stride(0),
         q.stride(2),
@@ -498,57 +520,75 @@ class ChunkedPrefillState:

 # Test function
 def _test_chunked_attention():
-    """Test chunked attention correctness against full attention."""
+    """Test chunked attention using flash_attn_with_lse and merge_attention_outputs."""
     from flash_attn.flash_attn_interface import flash_attn_func

     torch.manual_seed(42)

-    batch, seqlen, nheads, headdim = 1, 1024, 32, 128
+    print("=" * 70)
+    print("Test: Chunked attention vs flash_attn_func (non-causal)")
+    print("=" * 70)
+    print("Splitting K,V into chunks, computing attention per chunk, then merging")
+    print()

-    # Generate random Q, K, V
-    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
-    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
-    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
+    for dtype in [torch.float16, torch.bfloat16]:
+        for num_chunks in [64, 128, 256]:
+            for batch, seqlen, nheads, headdim in [
+                (1, 1024, 32, 128),
+                (1, 2048, 32, 128),
+                (1, 4096, 32, 128),
+                (1, 8192, 32, 128),
+            ]:
+                # Generate random Q, K, V
+                q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
+                k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
+                v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)

-    # Full attention (reference)
-    out_ref = flash_attn_func(q, k, v, causal=True)
+                # Reference: full attention (non-causal)
+                out_ref = flash_attn_func(q, k, v, causal=False)

-    # Chunked attention
-    chunk_size = 256
-    num_chunks = seqlen // chunk_size
+                # Chunked attention: split K, V into chunks
+                chunk_size = seqlen // num_chunks
+                accumulated_o = None
+                accumulated_lse = None

-    accumulated_o = None
-    accumulated_lse = None
+                for i in range(num_chunks):
+                    start = i * chunk_size
+                    end = (i + 1) * chunk_size

-    for i in range(num_chunks):
-        start = i * chunk_size
-        end = (i + 1) * chunk_size
+                    k_chunk = k[:, start:end, :, :]
+                    v_chunk = v[:, start:end, :, :]

-        # Q for this chunk
-        q_chunk = q[:, start:end, :, :]
+                    # Q attends to this K,V chunk (non-causal)
+                    chunk_o, chunk_lse = flash_attn_with_lse(
+                        q, k_chunk, v_chunk, causal=False
+                    )

-        # K, V up to current position (for causal)
-        k_context = k[:, :end, :, :]
-        v_context = v[:, :end, :, :]
+                    if accumulated_o is None:
+                        accumulated_o = chunk_o
+                        accumulated_lse = chunk_lse
+                    else:
+                        # Merge with previous chunks
+                        accumulated_o, accumulated_lse = merge_attention_outputs(
+                            accumulated_o, accumulated_lse,
+                            chunk_o, chunk_lse
+                        )

-        # Compute attention
-        chunk_o, chunk_lse = flash_attn_with_lse(
-            q_chunk, k_context, v_context, causal=True
-        )
+                # Compare
+                out_diff = (out_ref - accumulated_o).abs()
+                out_max_diff = out_diff.max().item()
+                out_mean_diff = out_diff.mean().item()

-        if accumulated_o is None:
-            accumulated_o = chunk_o
-            accumulated_lse = chunk_lse
-        else:
-            # For chunked prefill, we need to concatenate outputs, not merge
-            # Because each chunk's Q attends to different K positions
-            accumulated_o = torch.cat([accumulated_o, chunk_o], dim=1)
+                status = "PASS" if out_max_diff < 1e-2 else "FAIL"
+                print(
+                    f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
+                    f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
+                    f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
+                )

-    # Compare
-    max_diff = (out_ref - accumulated_o).abs().max().item()
-    print(f"Max difference: {max_diff}")
-    assert max_diff < 1e-2, f"Chunked attention differs from reference: {max_diff}"
-    print("Test passed!")
+    print()
+    print("=" * 70)
+    print("Test completed!")


 if __name__ == "__main__":
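Note: merge_attention_outputs is exercised by the new test but its body is outside this diff. The sketch below is a minimal PyTorch reference for the LSE-based merge it is expected to perform, assuming the layouts produced by flash_attn_with_lse above (out: [batch, seqlen_q, nheads, headdim]; lse: [batch, nheads, seqlen_q], ignoring the seqlen_q_rounded padding). The function name, signature, and exact layout in nanovllm may differ; this is an illustration of the math, not the repository's implementation.

import torch

def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
    """Merge two partial attention results computed over disjoint K/V chunks (hypothetical reference)."""
    # Numerically stable log-sum-exp of the two partial LSEs.
    m = torch.maximum(lse1, lse2)
    lse = m + torch.log(torch.exp(lse1 - m) + torch.exp(lse2 - m))

    # Each partial output is reweighted by exp(lse_i - lse). Reshape the weights
    # from [batch, nheads, seqlen_q] to [batch, seqlen_q, nheads, 1] so they
    # broadcast over the head dimension.
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1).to(o1.dtype)
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1).to(o2.dtype)

    return w1 * o1 + w2 * o2, lse

Folding each chunk's (chunk_o, chunk_lse) into the running (accumulated_o, accumulated_lse) pair this way, as the test's inner loop does, performs the same online-softmax combination the Triton kernel applies block by block.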