# nano-vllm/nanovllm/kvcache/chunked_attention.py
"""
Chunked attention implementation for CPU KV cache offloading.
This module implements flash attention with LSE (log-sum-exp) output,
enabling proper online softmax merging for chunked prefill.
Key functions:
- flash_attn_with_lse: Flash attention that returns output and LSE
- merge_attention_outputs: Merge outputs from multiple KV chunks
- chunked_prefill_attention: High-level interface for chunked attention
"""
import math
import torch
import triton
import triton.language as tl
from typing import Tuple, List, Optional
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _fwd_kernel_with_lse(
Q,
K,
V,
Out,
Lse,
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_ob,
stride_oh,
stride_om,
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""
Flash attention forward kernel with LSE output.
Implements standard Flash Attention online softmax algorithm:
- m_i: running max of attention scores
- l_i: running sum of exp(scores - m_i)
- acc_o: running sum of softmax(scores) @ V (unnormalized)
Final output: acc_o / l_i
Final LSE: m_i + log(l_i)
"""
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# Pointers
q_ptrs = (
Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
)
k_ptrs = (
K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
)
v_ptrs = (
V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
)
# Initialize running statistics
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # running max
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # running sum of exp
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # running output (unnormalized)
# Load Q (once per block)
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(
q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
)
# Loop over K, V blocks
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# Load K
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# Compute QK^T * scale
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk *= softmax_scale
# Apply masks
if not EVEN_N:
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
if IS_CAUSAL:
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
# Online softmax: compute block max
m_ij = tl.max(qk, 1) # [BLOCK_M]
# New running max
m_new = tl.maximum(m_i, m_ij) # [BLOCK_M]
# Rescale factor for previous accumulator
alpha = tl.exp(m_i - m_new) # [BLOCK_M]
# Compute P = exp(qk - m_new)
p = tl.exp(qk - m_new[:, None]) # [BLOCK_M, BLOCK_N]
# Sum of current block
l_ij = tl.sum(p, 1) # [BLOCK_M]
# Update running sum: l_new = l_i * alpha + l_ij
l_new = l_i * alpha + l_ij
# Rescale previous output and add new contribution
acc_o = acc_o * alpha[:, None]
# Load V
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn)
else:
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# acc_o += P @ V
p = p.to(v.dtype)
acc_o += tl.dot(p, v)
# Update running statistics
m_i = m_new
l_i = l_new
# Final normalization: output = acc_o / l_i
acc_o = acc_o / l_i[:, None]
# Compute LSE = m_i + log(l_i)
lse_i = m_i + tl.log(l_i)
# Store LSE
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
if EVEN_M:
tl.store(lse_ptrs, lse_i)
else:
tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)
# Store output
out_ptrs = (
Out
+ off_b * stride_ob
+ off_h * stride_oh
+ (offs_m[:, None] * stride_om + offs_d[None, :])
)
if EVEN_M:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o)
else:
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
else:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
else:
tl.store(
out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
)
def flash_attn_with_lse(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Flash attention forward pass that returns both output and LSE.
Uses flash_attn library which natively supports GQA without memory overhead.
Args:
q: Query tensor [batch, seqlen_q, nheads_q, headdim]
k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
softmax_scale: Scaling factor (default: 1/sqrt(headdim))
causal: Whether to apply causal masking
Returns:
out: Output tensor [batch, seqlen_q, nheads_q, headdim]
lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
"""
from flash_attn.flash_attn_interface import flash_attn_func
batch, seqlen_q, nheads_q, headdim = q.shape
_, seqlen_k, nheads_kv, _ = k.shape
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(headdim)
    # flash_attn_func natively supports GQA, so no K/V head replication is needed.
    # By default it returns only the output; passing return_attn_probs=True makes it
    # return (out, softmax_lse, S_dmask), which is how we obtain the LSE.
out, lse, _ = flash_attn_func(
q, k, v,
softmax_scale=softmax_scale,
causal=causal,
return_attn_probs=True, # This makes it return (out, softmax_lse, S_dmask)
)
# lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
# Trim to actual seqlen_q
lse = lse[:, :, :seqlen_q]
return out, lse
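# Example usage (a sketch; requires a CUDA device and the flash_attn package):
#     q = torch.randn(1, 1024, 32, 128, device="cuda", dtype=torch.float16)
#     k = torch.randn(1, 1024, 8, 128, device="cuda", dtype=torch.float16)   # GQA: fewer KV heads
#     v = torch.randn(1, 1024, 8, 128, device="cuda", dtype=torch.float16)
#     out, lse = flash_attn_with_lse(q, k, v, causal=True)
#     # out: [1, 1024, 32, 128], lse: [1, 32, 1024]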
@triton.jit
def _merge_lse_kernel(
lse1_ptr, lse2_ptr, lse_out_ptr,
num_elements: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging LSE values.
IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
"""
# Each program handles BLOCK_SIZE elements
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < num_elements
# Load lse values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)
# Compute max for numerical stability (in fp32)
max_lse = tl.maximum(lse1, lse2)
# Compute exp(lse - max_lse) in fp32
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
# Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
lse_merged = max_lse + tl.log(exp1 + exp2)
# Store result (convert back to original dtype)
tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
@triton.jit
def _merge_output_kernel(
o1_ptr, o2_ptr, lse1_ptr, lse2_ptr, o_out_ptr,
batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging attention outputs.
IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
This is critical for numerical accuracy in chunked attention.
"""
# Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
pid_batch = tl.program_id(0)
pid_seq = tl.program_id(1)
pid_head = tl.program_id(2)
# Compute LSE index: [batch, nheads, seqlen_q]
lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq
# Load LSE values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)
# Compute max and scaling factors in fp32
max_lse = tl.maximum(lse1, lse2)
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
sum_exp = exp1 + exp2
# Process headdim in chunks
for d_offset in range(0, headdim, BLOCK_SIZE):
d_idx = d_offset + tl.arange(0, BLOCK_SIZE)
mask = d_idx < headdim
# Compute output index: [batch, seqlen_q, nheads, headdim]
base_idx = (pid_batch * seqlen_q * nheads * headdim +
pid_seq * nheads * headdim +
pid_head * headdim)
o_idx = base_idx + d_idx
# Load o1, o2 and convert to fp32 for weighted sum
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
# Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
# Store result (Triton will convert back to original dtype)
tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
def merge_attention_outputs(
o1: torch.Tensor,
lse1: torch.Tensor,
o2: torch.Tensor,
lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Merge two attention outputs using online softmax (Triton fused kernel).
This implements the online softmax merging formula:
- m_new = max(lse1, lse2)
- o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new))
- lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))
Args:
o1: First output [batch, seqlen_q, nheads, headdim]
lse1: First LSE [batch, nheads, seqlen_q]
o2: Second output [batch, seqlen_q, nheads, headdim]
lse2: Second LSE [batch, nheads, seqlen_q]
Returns:
o_merged: Merged output [batch, seqlen_q, nheads, headdim]
lse_merged: Merged LSE [batch, nheads, seqlen_q]
"""
batch, seqlen_q, nheads, headdim = o1.shape
# Allocate output tensors
o_merged = torch.empty_like(o1)
lse_merged = torch.empty_like(lse1)
# Launch LSE merge kernel
num_lse_elements = batch * nheads * seqlen_q
BLOCK_SIZE_LSE = 256
grid_lse = (triton.cdiv(num_lse_elements, BLOCK_SIZE_LSE),)
_merge_lse_kernel[grid_lse](
lse1, lse2, lse_merged,
num_lse_elements,
BLOCK_SIZE=BLOCK_SIZE_LSE,
)
# Launch output merge kernel
BLOCK_SIZE = 128
grid_output = (batch, seqlen_q, nheads)
_merge_output_kernel[grid_output](
o1, o2, lse1, lse2, o_merged,
batch, seqlen_q, nheads, headdim,
BLOCK_SIZE=BLOCK_SIZE,
)
return o_merged, lse_merged
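# Pure-PyTorch reference for the same merge, useful for validating the Triton kernels
# on small inputs. The helper name is our own; it applies the formula documented above:
#   m = max(lse1, lse2)
#   o = (exp(lse1 - m) * o1 + exp(lse2 - m) * o2) / (exp(lse1 - m) + exp(lse2 - m))
#   lse = m + log(exp(lse1 - m) + exp(lse2 - m))
def _merge_attention_outputs_reference(
    o1: torch.Tensor,    # [batch, seqlen_q, nheads, headdim]
    lse1: torch.Tensor,  # [batch, nheads, seqlen_q]
    o2: torch.Tensor,
    lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    lse1_f, lse2_f = lse1.float(), lse2.float()
    max_lse = torch.maximum(lse1_f, lse2_f)
    exp1 = torch.exp(lse1_f - max_lse)  # [batch, nheads, seqlen_q]
    exp2 = torch.exp(lse2_f - max_lse)
    sum_exp = exp1 + exp2
    # Move the per-(batch, head, query) weights to [batch, seqlen_q, nheads, 1]
    # so they broadcast over the headdim of o1/o2.
    w1 = (exp1 / sum_exp).transpose(1, 2).unsqueeze(-1)
    w2 = (exp2 / sum_exp).transpose(1, 2).unsqueeze(-1)
    o_merged = w1 * o1.float() + w2 * o2.float()
    lse_merged = max_lse + torch.log(sum_exp)
    return o_merged.to(o1.dtype), lse_merged.to(lse1.dtype)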
def chunked_attention_varlen(
q: torch.Tensor,
kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
cu_seqlens_q: torch.Tensor,
cu_seqlens_k_list: List[torch.Tensor],
max_seqlen_q: int,
max_seqlen_k_list: List[int],
softmax_scale: Optional[float] = None,
causal_mask_per_chunk: Optional[List[bool]] = None,
) -> torch.Tensor:
"""
Compute attention with KV split across multiple chunks.
This is the core function for chunked prefill. It computes attention
against each KV chunk and merges results using online softmax.
    For causal attention with chunked KV:
    - Last chunk (the current tokens): apply causal mask
    - Earlier chunks (previously cached tokens): no causal mask, since all of
      those tokens precede the current queries and are valid context
Args:
q: Query tensor [total_q_tokens, nheads, headdim]
kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk
max_seqlen_q: Maximum query sequence length
max_seqlen_k_list: List of maximum key sequence lengths for each chunk
softmax_scale: Scaling factor
causal_mask_per_chunk: Whether to apply causal mask for each chunk
    Returns:
        out: Output tensor [total_q_tokens, nheads, headdim]
    Note:
        The varlen arguments (cu_seqlens_*, max_seqlen_*) are accepted for API
        compatibility but are not used yet; the implementation currently
        assumes a single sequence (batch == 1).
    """
if len(kv_chunks) == 0:
raise ValueError("Need at least one KV chunk")
nheads = q.shape[1]
headdim = q.shape[2]
batch = cu_seqlens_q.shape[0] - 1
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(headdim)
if causal_mask_per_chunk is None:
# Default: causal for last chunk only
causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]
# Initialize accumulated output and LSE
accumulated_o = None
accumulated_lse = None
for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
is_causal = causal_mask_per_chunk[chunk_idx]
# Reshape Q for batch processing
# For varlen, we need to handle each sequence separately
# For simplicity, assume single sequence (batch=1) for now
q_batched = q.unsqueeze(0) # [1, total_q, nheads, headdim]
# Compute attention for this chunk
chunk_o, chunk_lse = flash_attn_with_lse(
q_batched,
k_chunk,
v_chunk,
softmax_scale=softmax_scale,
causal=is_causal,
)
# Merge with accumulated
if accumulated_o is None:
accumulated_o = chunk_o
accumulated_lse = chunk_lse
else:
accumulated_o, accumulated_lse = merge_attention_outputs(
accumulated_o, accumulated_lse,
chunk_o, chunk_lse,
)
# Remove batch dimension
return accumulated_o.squeeze(0)
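# Example of how a prefill step might call chunked_attention_varlen with one chunk of
# previously offloaded KV plus the current chunk. This is a sketch: the variable names
# are placeholders, and the varlen arguments are passed only for API compatibility
# since the batch=1 path above does not consume them yet.
#     out = chunked_attention_varlen(
#         q,                                              # [total_q_tokens, nheads, headdim]
#         kv_chunks=[(k_prev, v_prev), (k_cur, v_cur)],   # each [1, seqlen_k_i, nheads, headdim]
#         cu_seqlens_q=cu_seqlens_q,
#         cu_seqlens_k_list=[cu_prev, cu_cur],
#         max_seqlen_q=max_seqlen_q,
#         max_seqlen_k_list=[k_prev.shape[1], k_cur.shape[1]],
#         causal_mask_per_chunk=[False, True],            # causal only on the current chunk
#     )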
class ChunkedPrefillState:
"""
State for tracking chunked prefill progress.
This class maintains the accumulated attention output and LSE
across multiple prefill chunks.
"""
def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device):
self.num_layers = num_layers
self.dtype = dtype
self.device = device
# Per-layer accumulated outputs
# Each entry: (accumulated_output, accumulated_lse) or None
self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
None for _ in range(num_layers)
]
# Track which chunks have been processed
self.processed_chunks: int = 0
def update_layer(
self,
layer_id: int,
chunk_output: torch.Tensor,
chunk_lse: torch.Tensor,
):
"""Update accumulated state for a layer with a new chunk's output."""
if self.layer_states[layer_id] is None:
self.layer_states[layer_id] = (chunk_output, chunk_lse)
else:
acc_o, acc_lse = self.layer_states[layer_id]
merged_o, merged_lse = merge_attention_outputs(
acc_o, acc_lse,
chunk_output, chunk_lse,
)
self.layer_states[layer_id] = (merged_o, merged_lse)
def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]:
"""Get the final accumulated output for a layer."""
if self.layer_states[layer_id] is None:
return None
return self.layer_states[layer_id][0]
def clear(self):
"""Clear all accumulated state."""
self.layer_states = [None for _ in range(self.num_layers)]
self.processed_chunks = 0
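# Sketch of how a caller might drive ChunkedPrefillState for one layer while KV chunks
# are streamed in from the CPU cache (the loop below is illustrative, not engine code):
#     state = ChunkedPrefillState(num_layers, torch.bfloat16, torch.device("cuda"))
#     for k_chunk, v_chunk in kv_chunks_for_layer:
#         chunk_o, chunk_lse = flash_attn_with_lse(q, k_chunk, v_chunk, causal=False)
#         state.update_layer(layer_id, chunk_o, chunk_lse)
#     out = state.get_layer_output(layer_id)  # merged across all processed chunks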
# Test function
def _test_chunked_attention():
"""Test chunked attention using flash_attn_with_lse and merge_attention_outputs."""
from flash_attn.flash_attn_interface import flash_attn_func
torch.manual_seed(42)
print("=" * 70)
print("Test: Chunked attention vs flash_attn_func (non-causal)")
print("=" * 70)
print("Splitting K,V into chunks, computing attention per chunk, then merging")
print()
for dtype in [torch.float16, torch.bfloat16]:
for num_chunks in [64, 128, 256]:
for batch, seqlen, nheads, headdim in [
(1, 1024, 32, 128),
(1, 2048, 32, 128),
(1, 4096, 32, 128),
(1, 8192, 32, 128),
]:
# Generate random Q, K, V
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
# Reference: full attention (non-causal)
out_ref = flash_attn_func(q, k, v, causal=False)
# Chunked attention: split K, V into chunks
chunk_size = seqlen // num_chunks
accumulated_o = None
accumulated_lse = None
for i in range(num_chunks):
start = i * chunk_size
end = (i + 1) * chunk_size
k_chunk = k[:, start:end, :, :]
v_chunk = v[:, start:end, :, :]
# Q attends to this K,V chunk (non-causal)
chunk_o, chunk_lse = flash_attn_with_lse(
q, k_chunk, v_chunk, causal=False
)
if accumulated_o is None:
accumulated_o = chunk_o
accumulated_lse = chunk_lse
else:
# Merge with previous chunks
accumulated_o, accumulated_lse = merge_attention_outputs(
accumulated_o, accumulated_lse,
chunk_o, chunk_lse
)
# Compare
out_diff = (out_ref - accumulated_o).abs()
out_max_diff = out_diff.max().item()
out_mean_diff = out_diff.mean().item()
status = "PASS" if out_max_diff < 1e-2 else "FAIL"
print(
f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
)
print()
print("=" * 70)
print("Test completed!")
if __name__ == "__main__":
_test_chunked_attention()