[fix] Fix chunked_attention.py implementation.
@@ -31,7 +31,6 @@ def _fwd_kernel_with_lse(
     V,
     Out,
     Lse,
-    TMP,
     softmax_scale,
     stride_qb,
     stride_qh,
@@ -60,7 +59,17 @@ def _fwd_kernel_with_lse(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    """Flash attention forward kernel with LSE output for online softmax."""
+    """
+    Flash attention forward kernel with LSE output.
+
+    Implements standard Flash Attention online softmax algorithm:
+    - m_i: running max of attention scores
+    - l_i: running sum of exp(scores - m_i)
+    - acc_o: running sum of softmax(scores) @ V (unnormalized)
+
+    Final output: acc_o / l_i
+    Final LSE: m_i + log(l_i)
+    """
     start_m = tl.program_id(0)
     off_hb = tl.program_id(1)
     off_b = off_hb // nheads
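
Note (reviewer aside, not part of this commit): the recurrence described in the new docstring can be checked in a few lines of plain PyTorch. The sketch below runs the same (m, l, acc) update over blocks of a single score row; `scores`, `vblk`, and the block size are illustrative names, not identifiers from this file.

import torch

torch.manual_seed(0)
scores = torch.randn(1024)       # one row of (q @ k^T) * softmax_scale
values = torch.randn(1024, 8)    # matching value vectors

m = torch.tensor(float("-inf"))  # running max
l = torch.tensor(0.0)            # running sum of exp(scores - m)
acc = torch.zeros(8)             # running unnormalized output

for blk, vblk in zip(scores.split(128), values.split(128)):
    m_new = torch.maximum(m, blk.max())
    alpha = torch.exp(m - m_new)     # rescale factor for previous statistics
    p = torch.exp(blk - m_new)
    l = l * alpha + p.sum()
    acc = acc * alpha + p @ vblk
    m = m_new

out = acc / l                # final output
lse = m + torch.log(l)       # final log-sum-exp

assert torch.allclose(out, torch.softmax(scores, 0) @ values, atol=1e-4)
assert torch.allclose(lse, torch.logsumexp(scores, 0), atol=1e-4)
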
@@ -69,6 +78,7 @@ def _fwd_kernel_with_lse(
     offs_n = tl.arange(0, BLOCK_N)
     offs_d = tl.arange(0, BLOCK_HEADDIM)

+    # Pointers
     q_ptrs = (
         Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
     )
@@ -79,12 +89,12 @@ def _fwd_kernel_with_lse(
         V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
     )

-    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
-    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
+    # Initialize running statistics
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # running max
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # running sum of exp
+    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)  # running output (unnormalized)

-    # Load Q
+    # Load Q (once per block)
     if EVEN_M & EVEN_N:
         if EVEN_HEADDIM:
             q = tl.load(q_ptrs)
@@ -98,7 +108,7 @@ def _fwd_kernel_with_lse(
                 q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
             )

-    # Loop over k, v and update accumulator
+    # Loop over K, V blocks
     end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
     for start_n in range(0, end_n, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
@@ -123,28 +133,39 @@ def _fwd_kernel_with_lse(
                     other=0.0,
                 )

-        # Compute QK
+        # Compute QK^T * scale
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, tl.trans(k))
+        qk *= softmax_scale

-        # Masking
+        # Apply masks
        if not EVEN_N:
             qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
         if IS_CAUSAL:
             qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))

-        # Online softmax
-        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-        p = tl.exp(qk * softmax_scale - m_ij[:, None])
-        l_ij = tl.sum(p, 1)
+        # Online softmax: compute block max
+        m_ij = tl.max(qk, 1)  # [BLOCK_M]

-        # Scale acc_o
-        acc_o_scale = tl.exp(m_i - m_ij)
-        tl.store(t_ptrs, acc_o_scale)
-        acc_o_scale = tl.load(t_ptrs)
-        acc_o = acc_o * acc_o_scale[:, None]
+        # New running max
+        m_new = tl.maximum(m_i, m_ij)  # [BLOCK_M]

-        # Load V and update output
+        # Rescale factor for previous accumulator
+        alpha = tl.exp(m_i - m_new)  # [BLOCK_M]
+
+        # Compute P = exp(qk - m_new)
+        p = tl.exp(qk - m_new[:, None])  # [BLOCK_M, BLOCK_N]
+
+        # Sum of current block
+        l_ij = tl.sum(p, 1)  # [BLOCK_M]
+
+        # Update running sum: l_new = l_i * alpha + l_ij
+        l_new = l_i * alpha + l_ij
+
+        # Rescale previous output and add new contribution
+        acc_o = acc_o * alpha[:, None]
+
+        # Load V
         if EVEN_N & EVEN_M:
             if EVEN_HEADDIM:
                 v = tl.load(v_ptrs + start_n * stride_vn)
@@ -164,23 +185,26 @@ def _fwd_kernel_with_lse(
                     other=0.0,
                 )

+        # acc_o += P @ V
         p = p.to(v.dtype)
         acc_o += tl.dot(p, v)

-        # Update statistics
-        m_i = m_ij
-        l_i_new = tl.exp(lse_i - m_ij) + l_ij
-        lse_i = m_ij + tl.log(l_i_new)
+        # Update running statistics
+        m_i = m_new
+        l_i = l_new

-    # Final scaling
-    o_scale = tl.exp(m_i - lse_i)
-    tl.store(t_ptrs, o_scale)
-    o_scale = tl.load(t_ptrs)
-    acc_o = acc_o * o_scale[:, None]
+    # Final normalization: output = acc_o / l_i
+    acc_o = acc_o / l_i[:, None]
+
+    # Compute LSE = m_i + log(l_i)
+    lse_i = m_i + tl.log(l_i)

     # Store LSE
     lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
-    tl.store(lse_ptrs, lse_i)
+    if EVEN_M:
+        tl.store(lse_ptrs, lse_i)
+    else:
+        tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)

     # Store output
     out_ptrs = (
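
Note (reviewer aside, not part of this commit): with TMP gone and m_i/l_i tracked explicitly, the kernel's contract is simply out = softmax(q @ k^T * scale) @ v plus the per-row log-sum-exp. A naive PyTorch reference for that contract is sketched below; it assumes inputs shaped (batch, seqlen, nheads, headdim), scale 1/sqrt(headdim), and an unpadded LSE of shape (batch, nheads, seqlen_q), whereas the kernel writes LSE into a buffer padded to seqlen_q_rounded.

import math
import torch

def attention_with_lse_ref(q, k, v, causal=False):
    """Naive reference for (out, lse); assumes seqlen_q == seqlen_k when causal."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    # scores: (batch, nheads, seqlen_q, seqlen_k)
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * scale
    if causal:
        keep = torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~keep, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1)  # (batch, nheads, seqlen_q)
    out = torch.einsum("bhqk,bkhd->bqhd", torch.softmax(scores, dim=-1), v.float())
    return out.to(q.dtype), lse

Comparing this against flash_attn_with_lse (with the padded LSE sliced back to seqlen_q, if the wrapper does not already do that) is a quick sanity check that is independent of the chunked test further down.
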
@@ -258,7 +282,6 @@ def flash_attn_with_lse(

     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
-    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
     out = torch.empty_like(q)

     BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
@@ -272,7 +295,6 @@ def flash_attn_with_lse(
         v,
         out,
         lse,
-        tmp,
         softmax_scale,
         q.stride(0),
         q.stride(2),
@@ -498,57 +520,75 @@ class ChunkedPrefillState:

 # Test function
 def _test_chunked_attention():
-    """Test chunked attention correctness against full attention."""
+    """Test chunked attention using flash_attn_with_lse and merge_attention_outputs."""
     from flash_attn.flash_attn_interface import flash_attn_func

     torch.manual_seed(42)

-    batch, seqlen, nheads, headdim = 1, 1024, 32, 128
+    print("=" * 70)
+    print("Test: Chunked attention vs flash_attn_func (non-causal)")
+    print("=" * 70)
+    print("Splitting K,V into chunks, computing attention per chunk, then merging")
+    print()

-    # Generate random Q, K, V
-    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
-    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
-    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
+    for dtype in [torch.float16, torch.bfloat16]:
+        for num_chunks in [64, 128, 256]:
+            for batch, seqlen, nheads, headdim in [
+                (1, 1024, 32, 128),
+                (1, 2048, 32, 128),
+                (1, 4096, 32, 128),
+                (1, 8192, 32, 128),
+            ]:
+                # Generate random Q, K, V
+                q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
+                k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
+                v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)

-    # Full attention (reference)
-    out_ref = flash_attn_func(q, k, v, causal=True)
+                # Reference: full attention (non-causal)
+                out_ref = flash_attn_func(q, k, v, causal=False)

-    # Chunked attention
-    chunk_size = 256
-    num_chunks = seqlen // chunk_size
+                # Chunked attention: split K, V into chunks
+                chunk_size = seqlen // num_chunks
+                accumulated_o = None
+                accumulated_lse = None

-    accumulated_o = None
-    accumulated_lse = None
+                for i in range(num_chunks):
+                    start = i * chunk_size
+                    end = (i + 1) * chunk_size

-    for i in range(num_chunks):
-        start = i * chunk_size
-        end = (i + 1) * chunk_size
+                    k_chunk = k[:, start:end, :, :]
+                    v_chunk = v[:, start:end, :, :]

-        # Q for this chunk
-        q_chunk = q[:, start:end, :, :]
+                    # Q attends to this K,V chunk (non-causal)
+                    chunk_o, chunk_lse = flash_attn_with_lse(
+                        q, k_chunk, v_chunk, causal=False
+                    )

-        # K, V up to current position (for causal)
-        k_context = k[:, :end, :, :]
-        v_context = v[:, :end, :, :]
+                    if accumulated_o is None:
+                        accumulated_o = chunk_o
+                        accumulated_lse = chunk_lse
+                    else:
+                        # Merge with previous chunks
+                        accumulated_o, accumulated_lse = merge_attention_outputs(
+                            accumulated_o, accumulated_lse,
+                            chunk_o, chunk_lse
+                        )

-        # Compute attention
-        chunk_o, chunk_lse = flash_attn_with_lse(
-            q_chunk, k_context, v_context, causal=True
-        )
+                # Compare
+                out_diff = (out_ref - accumulated_o).abs()
+                out_max_diff = out_diff.max().item()
+                out_mean_diff = out_diff.mean().item()

-        if accumulated_o is None:
-            accumulated_o = chunk_o
-            accumulated_lse = chunk_lse
-        else:
-            # For chunked prefill, we need to concatenate outputs, not merge
-            # Because each chunk's Q attends to different K positions
-            accumulated_o = torch.cat([accumulated_o, chunk_o], dim=1)
+                status = "PASS" if out_max_diff < 1e-2 else "FAIL"
+                print(
+                    f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
+                    f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
+                    f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
+                )

-    # Compare
-    max_diff = (out_ref - accumulated_o).abs().max().item()
-    print(f"Max difference: {max_diff}")
-    assert max_diff < 1e-2, f"Chunked attention differs from reference: {max_diff}"
-    print("Test passed!")
+    print()
+    print("=" * 70)
+    print("Test completed!")


 if __name__ == "__main__":
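
Note (reviewer aside, not part of this commit): the rewritten test leans on merge_attention_outputs, whose body is outside these hunks. For context, the standard LSE-based merge of two partial results over disjoint K/V chunks looks like the sketch below; the shapes follow the convention of (batch, seqlen_q, nheads, headdim) outputs and (batch, nheads, seqlen_q) LSE, and the helper actually defined in this file may differ in details.

import torch

def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
    """Combine two softmax-normalized partial outputs using their log-sum-exp."""
    lse = torch.logaddexp(lse1, lse2)                          # merged log-sum-exp
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # (batch, seqlen_q, nheads, 1)
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    merged = o1 * w1 + o2 * w2
    return merged.to(o1.dtype), lse

Because each chunk's output is already normalized by its own denominator, reweighting by exp(lse_i - lse) recovers the softmax over the concatenated chunks exactly, which is why the non-causal chunked result is expected to match flash_attn_func up to floating-point error.
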