# nano-vllm/nanovllm/kvcache/chunked_attention.py
"""
Chunked attention implementation for CPU KV cache offloading.
This module implements flash attention with LSE (log-sum-exp) output,
enabling proper online softmax merging for chunked prefill.
Key functions:
- flash_attn_with_lse: Flash attention that returns output and LSE
- merge_attention_outputs: Merge outputs from multiple KV chunks
- chunked_prefill_attention: High-level interface for chunked attention
"""
import math
import torch
import triton
import triton.language as tl
from typing import Tuple, List, Optional
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _fwd_kernel_with_lse(
Q,
K,
V,
Out,
Lse,
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_ob,
stride_oh,
stride_om,
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""
Flash attention forward kernel with LSE output.
Implements standard Flash Attention online softmax algorithm:
- m_i: running max of attention scores
- l_i: running sum of exp(scores - m_i)
- acc_o: running sum of softmax(scores) @ V (unnormalized)
Final output: acc_o / l_i
Final LSE: m_i + log(l_i)
"""
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# Pointers
q_ptrs = (
Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
)
k_ptrs = (
K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
)
v_ptrs = (
V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
)
# Initialize running statistics
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # running max
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # running sum of exp
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # running output (unnormalized)
# Load Q (once per block)
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(
q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
)
# Loop over K, V blocks
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# Load K
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# Compute QK^T * scale
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk *= softmax_scale
# Apply masks
if not EVEN_N:
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
if IS_CAUSAL:
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
# Online softmax: compute block max
m_ij = tl.max(qk, 1) # [BLOCK_M]
# New running max
m_new = tl.maximum(m_i, m_ij) # [BLOCK_M]
# Rescale factor for previous accumulator
alpha = tl.exp(m_i - m_new) # [BLOCK_M]
# Compute P = exp(qk - m_new)
p = tl.exp(qk - m_new[:, None]) # [BLOCK_M, BLOCK_N]
# Sum of current block
l_ij = tl.sum(p, 1) # [BLOCK_M]
# Update running sum: l_new = l_i * alpha + l_ij
l_new = l_i * alpha + l_ij
# Rescale previous output and add new contribution
acc_o = acc_o * alpha[:, None]
# Load V
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn)
else:
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# acc_o += P @ V
p = p.to(v.dtype)
acc_o += tl.dot(p, v)
# Update running statistics
m_i = m_new
l_i = l_new
# Final normalization: output = acc_o / l_i
acc_o = acc_o / l_i[:, None]
# Compute LSE = m_i + log(l_i)
lse_i = m_i + tl.log(l_i)
# Store LSE
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
if EVEN_M:
tl.store(lse_ptrs, lse_i)
else:
tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)
# Store output
out_ptrs = (
Out
+ off_b * stride_ob
+ off_h * stride_oh
+ (offs_m[:, None] * stride_om + offs_d[None, :])
)
if EVEN_M:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o)
else:
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
else:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
else:
tl.store(
out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
)
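
# ---------------------------------------------------------------------------
# Reference sketch (illustrative, not used by the kernel above): a plain
# PyTorch implementation of what _fwd_kernel_with_lse computes, handy for
# unit tests. It assumes matching Q/KV head counts (no GQA) and top-left
# aligned causal masking (query i attends to keys j <= i), like the kernel.
# ---------------------------------------------------------------------------
def _naive_attention_with_lse(
    q: torch.Tensor,  # [batch, seqlen_q, nheads, headdim]
    k: torch.Tensor,  # [batch, seqlen_k, nheads, headdim]
    v: torch.Tensor,  # [batch, seqlen_k, nheads, headdim]
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Naive O(seqlen_q * seqlen_k)-memory attention returning (out, lse)."""
    headdim = q.shape[-1]
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)
    # scores: [batch, nheads, seqlen_q, seqlen_k]
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * softmax_scale
    if causal:
        mask = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device).triu(1)
        scores = scores.masked_fill(mask, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1)  # [batch, nheads, seqlen_q]
    p = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhqk,bkhd->bqhd", p, v.float()).to(q.dtype)
    return out, lse
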
def flash_attn_with_lse(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Flash attention forward pass that returns both output and LSE.
Uses flash_attn library which natively supports GQA without memory overhead.
Args:
q: Query tensor [batch, seqlen_q, nheads_q, headdim]
k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
softmax_scale: Scaling factor (default: 1/sqrt(headdim))
causal: Whether to apply causal masking
Returns:
out: Output tensor [batch, seqlen_q, nheads_q, headdim]
lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
"""
from flash_attn.flash_attn_interface import flash_attn_func
batch, seqlen_q, nheads_q, headdim = q.shape
_, seqlen_k, nheads_kv, _ = k.shape
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(headdim)
    # flash_attn_func natively supports GQA (nheads_kv may differ from nheads_q)
    # without materializing an expanded KV. By default it returns only the
    # output; with return_attn_probs=True it also returns the softmax LSE,
    # which we need for online softmax merging.
out, lse, _ = flash_attn_func(
q, k, v,
softmax_scale=softmax_scale,
causal=causal,
return_attn_probs=True, # This makes it return (out, softmax_lse, S_dmask)
)
# lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
# Trim to actual seqlen_q
lse = lse[:, :, :seqlen_q]
return out, lse
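
# ---------------------------------------------------------------------------
# Shape example (illustrative, with made-up sizes): GQA with 32 query heads
# sharing 8 KV heads. flash_attn_func broadcasts across head groups
# internally, so K/V are passed with their smaller head count.
#
#   q = torch.randn(1, 1024, 32, 128, device="cuda", dtype=torch.bfloat16)
#   k = torch.randn(1, 4096, 8, 128, device="cuda", dtype=torch.bfloat16)
#   v = torch.randn_like(k)
#   out, lse = flash_attn_with_lse(q, k, v, causal=False)
#   # out: [1, 1024, 32, 128], lse: [1, 32, 1024]
# ---------------------------------------------------------------------------
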
def merge_attention_outputs(
o1: torch.Tensor,
lse1: torch.Tensor,
o2: torch.Tensor,
lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Merge two attention outputs using online softmax.
This implements the online softmax merging formula:
- m_new = max(lse1, lse2)
- o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new))
- lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))
Args:
o1: First output [batch, seqlen_q, nheads, headdim]
lse1: First LSE [batch, nheads, seqlen_q]
o2: Second output [batch, seqlen_q, nheads, headdim]
lse2: Second LSE [batch, nheads, seqlen_q]
Returns:
o_merged: Merged output [batch, seqlen_q, nheads, headdim]
lse_merged: Merged LSE [batch, nheads, seqlen_q]
"""
# lse shape: [batch, nheads, seqlen_q]
# o shape: [batch, seqlen_q, nheads, headdim]
# Compute max for numerical stability
max_lse = torch.maximum(lse1, lse2)
# Compute scaling factors
# exp1, exp2 shape: [batch, nheads, seqlen_q]
exp1 = torch.exp(lse1 - max_lse)
exp2 = torch.exp(lse2 - max_lse)
# Reshape for broadcasting with output
# [batch, nheads, seqlen_q] -> [batch, seqlen_q, nheads, 1]
exp1_broad = exp1.transpose(1, 2).unsqueeze(-1)
exp2_broad = exp2.transpose(1, 2).unsqueeze(-1)
# Merge outputs
sum_exp = exp1_broad + exp2_broad
o_merged = (o1 * exp1_broad + o2 * exp2_broad) / sum_exp
# Compute merged LSE
lse_merged = max_lse + torch.log(exp1 + exp2)
# Ensure output has same dtype as input
o_merged = o_merged.to(o1.dtype)
return o_merged, lse_merged
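
# ---------------------------------------------------------------------------
# Illustrative check (a sketch, not part of the module API): attention over a
# KV cache split into two disjoint halves, merged with merge_attention_outputs,
# matches attention over the full KV. Uses the _naive_attention_with_lse
# reference above so it runs on CPU without flash_attn.
# ---------------------------------------------------------------------------
def _check_merge_identity(seqlen_q: int = 32, seqlen_k: int = 64) -> float:
    """Return the max abs difference between merged-halves and full attention."""
    torch.manual_seed(0)
    q = torch.randn(1, seqlen_q, 4, 16)
    k = torch.randn(1, seqlen_k, 4, 16)
    v = torch.randn(1, seqlen_k, 4, 16)
    half = seqlen_k // 2
    o_full, _ = _naive_attention_with_lse(q, k, v, causal=False)
    o1, lse1 = _naive_attention_with_lse(q, k[:, :half], v[:, :half], causal=False)
    o2, lse2 = _naive_attention_with_lse(q, k[:, half:], v[:, half:], causal=False)
    o_merged, _ = merge_attention_outputs(o1, lse1, o2, lse2)
    return (o_merged - o_full).abs().max().item()
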
def chunked_attention_varlen(
q: torch.Tensor,
kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
cu_seqlens_q: torch.Tensor,
cu_seqlens_k_list: List[torch.Tensor],
max_seqlen_q: int,
max_seqlen_k_list: List[int],
softmax_scale: Optional[float] = None,
causal_mask_per_chunk: Optional[List[bool]] = None,
) -> torch.Tensor:
"""
Compute attention with KV split across multiple chunks.
This is the core function for chunked prefill. It computes attention
against each KV chunk and merges results using online softmax.
For causal attention with chunked KV:
- First chunk (current tokens): Apply causal mask
- Previous chunks: No causal mask (all previous tokens are valid context)
Args:
q: Query tensor [total_q_tokens, nheads, headdim]
kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk
max_seqlen_q: Maximum query sequence length
max_seqlen_k_list: List of maximum key sequence lengths for each chunk
softmax_scale: Scaling factor
causal_mask_per_chunk: Whether to apply causal mask for each chunk
Returns:
out: Output tensor [total_q_tokens, nheads, headdim]
"""
if len(kv_chunks) == 0:
raise ValueError("Need at least one KV chunk")
nheads = q.shape[1]
headdim = q.shape[2]
batch = cu_seqlens_q.shape[0] - 1
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(headdim)
if causal_mask_per_chunk is None:
# Default: causal for last chunk only
causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]
# Initialize accumulated output and LSE
accumulated_o = None
accumulated_lse = None
for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
is_causal = causal_mask_per_chunk[chunk_idx]
        # Add a batch dimension for flash_attn_with_lse. True varlen support
        # would handle each sequence separately via cu_seqlens; for now we
        # assume a single sequence (batch=1), so those arguments are unused.
q_batched = q.unsqueeze(0) # [1, total_q, nheads, headdim]
# Compute attention for this chunk
chunk_o, chunk_lse = flash_attn_with_lse(
q_batched,
k_chunk,
v_chunk,
softmax_scale=softmax_scale,
causal=is_causal,
)
# Merge with accumulated
if accumulated_o is None:
accumulated_o = chunk_o
accumulated_lse = chunk_lse
else:
accumulated_o, accumulated_lse = merge_attention_outputs(
accumulated_o, accumulated_lse,
chunk_o, chunk_lse,
)
# Remove batch dimension
return accumulated_o.squeeze(0)
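
# ---------------------------------------------------------------------------
# Usage sketch (illustrative; the helper name and arguments are placeholders,
# not part of the module API). The newest prefill chunk's queries attend to
# all previously offloaded KV chunks without a causal mask and to their own
# KV chunk with a causal mask, matching the default causal_mask_per_chunk.
# Assumes a single sequence (batch=1), which is what the current
# implementation supports.
# ---------------------------------------------------------------------------
def _example_prefill_chunk_attention(q_new, prev_kv_chunks, k_cur, v_cur):
    """q_new: [num_new_tokens, nheads, headdim]; each KV chunk: [1, len_i, nheads, headdim]."""
    num_new = q_new.shape[0]
    device = q_new.device
    kv_chunks = list(prev_kv_chunks) + [(k_cur, v_cur)]
    return chunked_attention_varlen(
        q_new,
        kv_chunks,
        cu_seqlens_q=torch.tensor([0, num_new], dtype=torch.int32, device=device),
        cu_seqlens_k_list=[
            torch.tensor([0, k.shape[1]], dtype=torch.int32, device=device)
            for k, _ in kv_chunks
        ],
        max_seqlen_q=num_new,
        max_seqlen_k_list=[k.shape[1] for k, _ in kv_chunks],
        causal_mask_per_chunk=[False] * len(prev_kv_chunks) + [True],
    )
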
class ChunkedPrefillState:
"""
State for tracking chunked prefill progress.
This class maintains the accumulated attention output and LSE
across multiple prefill chunks.
"""
def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device):
self.num_layers = num_layers
self.dtype = dtype
self.device = device
# Per-layer accumulated outputs
# Each entry: (accumulated_output, accumulated_lse) or None
self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
None for _ in range(num_layers)
]
# Track which chunks have been processed
self.processed_chunks: int = 0
def update_layer(
self,
layer_id: int,
chunk_output: torch.Tensor,
chunk_lse: torch.Tensor,
):
"""Update accumulated state for a layer with a new chunk's output."""
if self.layer_states[layer_id] is None:
self.layer_states[layer_id] = (chunk_output, chunk_lse)
else:
acc_o, acc_lse = self.layer_states[layer_id]
merged_o, merged_lse = merge_attention_outputs(
acc_o, acc_lse,
chunk_output, chunk_lse,
)
self.layer_states[layer_id] = (merged_o, merged_lse)
def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]:
"""Get the final accumulated output for a layer."""
if self.layer_states[layer_id] is None:
return None
return self.layer_states[layer_id][0]
def clear(self):
"""Clear all accumulated state."""
self.layer_states = [None for _ in range(self.num_layers)]
self.processed_chunks = 0
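
# ---------------------------------------------------------------------------
# Usage sketch for ChunkedPrefillState (illustrative; the shapes and the
# two-chunk loop are made up for the example and use the CPU reference
# attention above rather than real per-layer model outputs).
# ---------------------------------------------------------------------------
def _example_chunked_prefill_state() -> torch.Tensor:
    state = ChunkedPrefillState(num_layers=2, dtype=torch.float32, device=torch.device("cpu"))
    q = torch.randn(1, 16, 4, 32)
    k = torch.randn(1, 64, 4, 32)
    v = torch.randn(1, 64, 4, 32)
    for start in (0, 32):  # two KV chunks of 32 tokens each
        k_chunk, v_chunk = k[:, start:start + 32], v[:, start:start + 32]
        for layer_id in range(state.num_layers):
            chunk_o, chunk_lse = _naive_attention_with_lse(q, k_chunk, v_chunk, causal=False)
            state.update_layer(layer_id, chunk_o, chunk_lse)
        state.processed_chunks += 1
    # Merged (full-KV) attention output for layer 0
    return state.get_layer_output(0)
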
# Test function
def _test_chunked_attention():
"""Test chunked attention using flash_attn_with_lse and merge_attention_outputs."""
from flash_attn.flash_attn_interface import flash_attn_func
torch.manual_seed(42)
print("=" * 70)
print("Test: Chunked attention vs flash_attn_func (non-causal)")
print("=" * 70)
print("Splitting K,V into chunks, computing attention per chunk, then merging")
print()
for dtype in [torch.float16, torch.bfloat16]:
for num_chunks in [64, 128, 256]:
for batch, seqlen, nheads, headdim in [
(1, 1024, 32, 128),
(1, 2048, 32, 128),
(1, 4096, 32, 128),
(1, 8192, 32, 128),
]:
# Generate random Q, K, V
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
# Reference: full attention (non-causal)
out_ref = flash_attn_func(q, k, v, causal=False)
# Chunked attention: split K, V into chunks
chunk_size = seqlen // num_chunks
accumulated_o = None
accumulated_lse = None
for i in range(num_chunks):
start = i * chunk_size
end = (i + 1) * chunk_size
k_chunk = k[:, start:end, :, :]
v_chunk = v[:, start:end, :, :]
# Q attends to this K,V chunk (non-causal)
chunk_o, chunk_lse = flash_attn_with_lse(
q, k_chunk, v_chunk, causal=False
)
if accumulated_o is None:
accumulated_o = chunk_o
accumulated_lse = chunk_lse
else:
# Merge with previous chunks
accumulated_o, accumulated_lse = merge_attention_outputs(
accumulated_o, accumulated_lse,
chunk_o, chunk_lse
)
# Compare
out_diff = (out_ref - accumulated_o).abs()
out_max_diff = out_diff.max().item()
out_mean_diff = out_diff.mean().item()
status = "PASS" if out_max_diff < 1e-2 else "FAIL"
print(
f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
)
print()
print("=" * 70)
print("Test completed!")
if __name__ == "__main__":
_test_chunked_attention()