[fix] Fix chunked_attention.py implementation.

Zijie Tian
2025-12-11 22:39:50 +08:00
parent b9ed77cbbb
commit 0bd7ba7536


@@ -31,7 +31,6 @@ def _fwd_kernel_with_lse(
     V,
     Out,
     Lse,
-    TMP,
     softmax_scale,
     stride_qb,
     stride_qh,
@@ -60,7 +59,17 @@ def _fwd_kernel_with_lse(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    """Flash attention forward kernel with LSE output for online softmax."""
+    """
+    Flash attention forward kernel with LSE output.
+
+    Implements standard Flash Attention online softmax algorithm:
+    - m_i: running max of attention scores
+    - l_i: running sum of exp(scores - m_i)
+    - acc_o: running sum of softmax(scores) @ V (unnormalized)
+
+    Final output: acc_o / l_i
+    Final LSE: m_i + log(l_i)
+    """
     start_m = tl.program_id(0)
     off_hb = tl.program_id(1)
     off_b = off_hb // nheads
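
The docstring above spells out the online-softmax recurrence that the rewritten loop below implements. As a reference only (not part of this commit), a minimal single-head PyTorch sketch of the same recurrence, with illustrative names and an arbitrary block size, can be checked against a direct softmax:

import torch

def online_attention_reference(q, k, v, block_n=64):
    # q: [M, d], k: [N, d], v: [N, d]; single head, no batch, for clarity
    scale = q.shape[-1] ** -0.5
    m_i = torch.full((q.shape[0],), float("-inf"))  # running max
    l_i = torch.zeros(q.shape[0])                   # running sum of exp
    acc_o = torch.zeros(q.shape[0], v.shape[-1])    # unnormalized output
    for start in range(0, k.shape[0], block_n):
        qk = (q @ k[start:start + block_n].T) * scale
        m_new = torch.maximum(m_i, qk.max(dim=1).values)
        alpha = torch.exp(m_i - m_new)              # rescale factor for previous state
        p = torch.exp(qk - m_new[:, None])
        l_i = l_i * alpha + p.sum(dim=1)
        acc_o = acc_o * alpha[:, None] + p @ v[start:start + block_n]
        m_i = m_new
    return acc_o / l_i[:, None], m_i + torch.log(l_i)  # output, LSE

q, k, v = torch.randn(16, 64), torch.randn(256, 64), torch.randn(256, 64)
out, lse = online_attention_reference(q, k, v)
scores = (q @ k.T) * 64 ** -0.5
assert torch.allclose(out, torch.softmax(scores, dim=-1) @ v, atol=1e-5)
assert torch.allclose(lse, torch.logsumexp(scores, dim=-1), atol=1e-5)

The first iteration is well defined because alpha = exp(-inf - m_new) = 0, so the empty accumulator and running sum are simply zeroed rather than propagating -inf.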
@@ -69,6 +78,7 @@ def _fwd_kernel_with_lse(
     offs_n = tl.arange(0, BLOCK_N)
     offs_d = tl.arange(0, BLOCK_HEADDIM)
+    # Pointers
     q_ptrs = (
         Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
     )
@@ -79,12 +89,12 @@ def _fwd_kernel_with_lse(
         V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
     )
-    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
-    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
-    # Load Q
+    # Initialize running statistics
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # running max
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # running sum of exp
+    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)  # running output (unnormalized)
+
+    # Load Q (once per block)
     if EVEN_M & EVEN_N:
         if EVEN_HEADDIM:
             q = tl.load(q_ptrs)
@@ -98,7 +108,7 @@ def _fwd_kernel_with_lse(
             q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
         )
-    # Loop over k, v and update accumulator
+    # Loop over K, V blocks
     end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
     for start_n in range(0, end_n, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
@@ -123,28 +133,39 @@ def _fwd_kernel_with_lse(
                     other=0.0,
                 )
-        # Compute QK
+        # Compute QK^T * scale
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, tl.trans(k))
+        qk *= softmax_scale
-        # Masking
+        # Apply masks
         if not EVEN_N:
             qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
         if IS_CAUSAL:
             qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
-        # Online softmax
-        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-        p = tl.exp(qk * softmax_scale - m_ij[:, None])
-        l_ij = tl.sum(p, 1)
-        # Scale acc_o
-        acc_o_scale = tl.exp(m_i - m_ij)
-        tl.store(t_ptrs, acc_o_scale)
-        acc_o_scale = tl.load(t_ptrs)
-        acc_o = acc_o * acc_o_scale[:, None]
-        # Load V and update output
+        # Online softmax: compute block max
+        m_ij = tl.max(qk, 1)  # [BLOCK_M]
+
+        # New running max
+        m_new = tl.maximum(m_i, m_ij)  # [BLOCK_M]
+
+        # Rescale factor for previous accumulator
+        alpha = tl.exp(m_i - m_new)  # [BLOCK_M]
+
+        # Compute P = exp(qk - m_new)
+        p = tl.exp(qk - m_new[:, None])  # [BLOCK_M, BLOCK_N]
+
+        # Sum of current block
+        l_ij = tl.sum(p, 1)  # [BLOCK_M]
+
+        # Update running sum: l_new = l_i * alpha + l_ij
+        l_new = l_i * alpha + l_ij
+
+        # Rescale previous output and add new contribution
+        acc_o = acc_o * alpha[:, None]
+
+        # Load V
         if EVEN_N & EVEN_M:
             if EVEN_HEADDIM:
                 v = tl.load(v_ptrs + start_n * stride_vn)
@@ -164,23 +185,26 @@ def _fwd_kernel_with_lse(
                     other=0.0,
                 )
+        # acc_o += P @ V
         p = p.to(v.dtype)
         acc_o += tl.dot(p, v)
-        # Update statistics
-        m_i = m_ij
-        l_i_new = tl.exp(lse_i - m_ij) + l_ij
-        lse_i = m_ij + tl.log(l_i_new)
+        # Update running statistics
+        m_i = m_new
+        l_i = l_new
-    # Final scaling
-    o_scale = tl.exp(m_i - lse_i)
-    tl.store(t_ptrs, o_scale)
-    o_scale = tl.load(t_ptrs)
-    acc_o = acc_o * o_scale[:, None]
+    # Final normalization: output = acc_o / l_i
+    acc_o = acc_o / l_i[:, None]
+
+    # Compute LSE = m_i + log(l_i)
+    lse_i = m_i + tl.log(l_i)
     # Store LSE
     lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
-    tl.store(lse_ptrs, lse_i)
+    if EVEN_M:
+        tl.store(lse_ptrs, lse_i)
+    else:
+        tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)
     # Store output
     out_ptrs = (
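
With the TMP buffer gone, the epilogue normalizes by l_i directly and stores lse = m_i + log(l_i). A quick standalone check in plain PyTorch (outside the kernel, using random stand-in scores) that this quantity equals the row-wise log-sum-exp of the scaled attention scores, which is what the LSE consumers below rely on:

import torch

scores = torch.randn(8, 128) * 0.3             # stand-in for qk * softmax_scale
m = scores.max(dim=1).values                   # running max after the last block
l = torch.exp(scores - m[:, None]).sum(dim=1)  # running sum after the last block
assert torch.allclose(m + torch.log(l), torch.logsumexp(scores, dim=1), atol=1e-6)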
@@ -258,7 +282,6 @@ def flash_attn_with_lse(
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
-    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
     out = torch.empty_like(q)
     BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
@@ -272,7 +295,6 @@ def flash_attn_with_lse(
         v,
         out,
         lse,
-        tmp,
         softmax_scale,
         q.stride(0),
         q.stride(2),
@@ -498,25 +520,35 @@ class ChunkedPrefillState:
 # Test function
 def _test_chunked_attention():
-    """Test chunked attention correctness against full attention."""
+    """Test chunked attention using flash_attn_with_lse and merge_attention_outputs."""
     from flash_attn.flash_attn_interface import flash_attn_func

     torch.manual_seed(42)

-    batch, seqlen, nheads, headdim = 1, 1024, 32, 128
-
-    # Generate random Q, K, V
-    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
-    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
-    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
-
-    # Full attention (reference)
-    out_ref = flash_attn_func(q, k, v, causal=True)
-
-    # Chunked attention
-    chunk_size = 256
-    num_chunks = seqlen // chunk_size
-
-    accumulated_o = None
-    accumulated_lse = None
+    print("=" * 70)
+    print("Test: Chunked attention vs flash_attn_func (non-causal)")
+    print("=" * 70)
+    print("Splitting K,V into chunks, computing attention per chunk, then merging")
+    print()
+
+    for dtype in [torch.float16, torch.bfloat16]:
+        for num_chunks in [64, 128, 256]:
+            for batch, seqlen, nheads, headdim in [
+                (1, 1024, 32, 128),
+                (1, 2048, 32, 128),
+                (1, 4096, 32, 128),
+                (1, 8192, 32, 128),
+            ]:
+                # Generate random Q, K, V
+                q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
+                k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
+                v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
+
+                # Reference: full attention (non-causal)
+                out_ref = flash_attn_func(q, k, v, causal=False)
+
+                # Chunked attention: split K, V into chunks
+                chunk_size = seqlen // num_chunks
+
+                accumulated_o = None
+                accumulated_lse = None
@@ -524,31 +556,39 @@ def _test_chunked_attention():
-        start = i * chunk_size
-        end = (i + 1) * chunk_size
-
-        # Q for this chunk
-        q_chunk = q[:, start:end, :, :]
-
-        # K, V up to current position (for causal)
-        k_context = k[:, :end, :, :]
-        v_context = v[:, :end, :, :]
-
-        # Compute attention
-        chunk_o, chunk_lse = flash_attn_with_lse(
-            q_chunk, k_context, v_context, causal=True
-        )
-
-        if accumulated_o is None:
-            accumulated_o = chunk_o
-            accumulated_lse = chunk_lse
-        else:
-            # For chunked prefill, we need to concatenate outputs, not merge
-            # Because each chunk's Q attends to different K positions
-            accumulated_o = torch.cat([accumulated_o, chunk_o], dim=1)
-
-    # Compare
-    max_diff = (out_ref - accumulated_o).abs().max().item()
-    print(f"Max difference: {max_diff}")
-    assert max_diff < 1e-2, f"Chunked attention differs from reference: {max_diff}"
-    print("Test passed!")
+                    start = i * chunk_size
+                    end = (i + 1) * chunk_size
+
+                    k_chunk = k[:, start:end, :, :]
+                    v_chunk = v[:, start:end, :, :]
+
+                    # Q attends to this K,V chunk (non-causal)
+                    chunk_o, chunk_lse = flash_attn_with_lse(
+                        q, k_chunk, v_chunk, causal=False
+                    )
+
+                    if accumulated_o is None:
+                        accumulated_o = chunk_o
+                        accumulated_lse = chunk_lse
+                    else:
+                        # Merge with previous chunks
+                        accumulated_o, accumulated_lse = merge_attention_outputs(
+                            accumulated_o, accumulated_lse,
+                            chunk_o, chunk_lse
+                        )
+
+                # Compare
+                out_diff = (out_ref - accumulated_o).abs()
+                out_max_diff = out_diff.max().item()
+                out_mean_diff = out_diff.mean().item()
+
+                status = "PASS" if out_max_diff < 1e-2 else "FAIL"
+                print(
+                    f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
+                    f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
+                    f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
+                )
+
+    print()
+    print("=" * 70)
+    print("Test completed!")

 if __name__ == "__main__":
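
merge_attention_outputs is defined elsewhere in this file and does not appear in this diff. For reference only, a merge of two partial results computed over disjoint K,V chunks presumably takes the following form, assuming o of shape [batch, seqlen_q, nheads, headdim] and lse of shape [batch, nheads, seqlen_q] as allocated by flash_attn_with_lse above; the function name and shapes here are illustrative:

import torch

def merge_two_partials(o1, lse1, o2, lse2):
    # Combined log-sum-exp over both chunks
    lse = torch.logaddexp(lse1, lse2)
    # Per-row weights l_1 / (l_1 + l_2) and l_2 / (l_1 + l_2), reshaped to broadcast over headdim
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    # Cast back in case the outputs are fp16/bf16 while the LSE is fp32
    return (o1 * w1 + o2 * w2).to(o1.dtype), lse

Each partial output is already normalized within its own chunk, so the weights exp(lse_i - lse) recover the chunk's share of the total softmax denominator. The merge is associative, which is why the test can fold chunks in one at a time instead of merging them all at once.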