""" Chunked attention implementation for CPU KV cache offloading. This module implements flash attention with LSE (log-sum-exp) output, enabling proper online softmax merging for chunked prefill. Key functions: - flash_attn_with_lse: Flash attention that returns output and LSE - merge_attention_outputs: Merge outputs from multiple KV chunks - chunked_prefill_attention: High-level interface for chunked attention """ import math import torch import triton import triton.language as tl from typing import Tuple, List, Optional @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], } ) @triton.jit def _fwd_kernel_with_lse( Q, K, V, Out, Lse, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): """ Flash attention forward kernel with LSE output. Implements standard Flash Attention online softmax algorithm: - m_i: running max of attention scores - l_i: running sum of exp(scores - m_i) - acc_o: running sum of softmax(scores) @ V (unnormalized) Final output: acc_o / l_i Final LSE: m_i + log(l_i) """ start_m = tl.program_id(0) off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_HEADDIM) # Pointers q_ptrs = ( Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) ) k_ptrs = ( K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) ) v_ptrs = ( V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) ) # Initialize running statistics m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # running max l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # running sum of exp acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # running output (unnormalized) # Load Q (once per block) if EVEN_M & EVEN_N: if EVEN_HEADDIM: q = tl.load(q_ptrs) else: q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) else: if EVEN_HEADDIM: q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) else: q = tl.load( q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 ) # Loop over K, V blocks end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) for start_n in range(0, end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # Load K if EVEN_N & EVEN_M: if EVEN_HEADDIM: k = tl.load(k_ptrs + start_n * stride_kn) else: k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) else: if EVEN_HEADDIM: k = tl.load( k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0, ) else: k = tl.load( k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0, ) # Compute QK^T * scale qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, tl.trans(k)) qk *= softmax_scale # Apply masks if not EVEN_N: qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) if 
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))

        # Online softmax: compute block max
        m_ij = tl.max(qk, 1)  # [BLOCK_M]
        # New running max
        m_new = tl.maximum(m_i, m_ij)  # [BLOCK_M]
        # Rescale factor for previous accumulator
        alpha = tl.exp(m_i - m_new)  # [BLOCK_M]
        # Compute P = exp(qk - m_new)
        p = tl.exp(qk - m_new[:, None])  # [BLOCK_M, BLOCK_N]
        # Sum of current block
        l_ij = tl.sum(p, 1)  # [BLOCK_M]
        # Update running sum: l_new = l_i * alpha + l_ij
        l_new = l_i * alpha + l_ij

        # Rescale previous output and add new contribution
        acc_o = acc_o * alpha[:, None]

        # Load V
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0
                )
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )

        # acc_o += P @ V
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # Update running statistics
        m_i = m_new
        l_i = l_new

    # Final normalization: output = acc_o / l_i
    acc_o = acc_o / l_i[:, None]

    # Compute LSE = m_i + log(l_i)
    lse_i = m_i + tl.log(l_i)

    # Store LSE
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    if EVEN_M:
        tl.store(lse_ptrs, lse_i)
    else:
        tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)

    # Store output
    out_ptrs = (
        Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )


def flash_attn_with_lse(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flash attention forward pass that returns both output and LSE.

    Uses the flash_attn library, which natively supports GQA without memory overhead.

    Args:
        q: Query tensor [batch, seqlen_q, nheads_q, headdim]
        k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
        v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
        softmax_scale: Scaling factor (default: 1/sqrt(headdim))
        causal: Whether to apply causal masking

    Returns:
        out: Output tensor [batch, seqlen_q, nheads_q, headdim]
        lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
    """
    from flash_attn.flash_attn_interface import flash_attn_func

    batch, seqlen_q, nheads_q, headdim = q.shape
    _, seqlen_k, nheads_kv, _ = k.shape

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    # flash_attn_func natively supports GQA (no memory overhead). Passing
    # return_attn_probs=True makes it return (out, softmax_lse, S_dmask),
    # which is how we obtain the LSE alongside the output.
    out, lse, _ = flash_attn_func(
        q,
        k,
        v,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=True,  # This makes it return (out, softmax_lse, S_dmask)
    )

    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
    # Trim to the actual seqlen_q
    lse = lse[:, :, :seqlen_q]

    return out, lse
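

# Illustrative usage sketch (not part of the original module's API): calling
# flash_attn_with_lse and checking the returned shapes. The function name
# _example_flash_attn_with_lse is ours; it assumes a CUDA device and the
# flash_attn package are available.
def _example_flash_attn_with_lse():
    batch, seqlen, nheads, headdim = 1, 128, 8, 64
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

    out, lse = flash_attn_with_lse(q, k, v, causal=True)

    # out: [batch, seqlen_q, nheads, headdim], lse: [batch, nheads, seqlen_q]
    assert out.shape == (batch, seqlen, nheads, headdim)
    assert lse.shape == (batch, nheads, seqlen)
    return out, lse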


def merge_attention_outputs(
    o1: torch.Tensor,
    lse1: torch.Tensor,
    o2: torch.Tensor,
    lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge two attention outputs using online softmax.

    This implements the online softmax merging formula:
    - m_new = max(lse1, lse2)
    - o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2)
              / (exp(lse1 - m_new) + exp(lse2 - m_new))
    - lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))

    Args:
        o1: First output [batch, seqlen_q, nheads, headdim]
        lse1: First LSE [batch, nheads, seqlen_q]
        o2: Second output [batch, seqlen_q, nheads, headdim]
        lse2: Second LSE [batch, nheads, seqlen_q]

    Returns:
        o_merged: Merged output [batch, seqlen_q, nheads, headdim]
        lse_merged: Merged LSE [batch, nheads, seqlen_q]
    """
    # lse shape: [batch, nheads, seqlen_q]
    # o shape: [batch, seqlen_q, nheads, headdim]

    # Compute max for numerical stability
    max_lse = torch.maximum(lse1, lse2)

    # Compute scaling factors
    # exp1, exp2 shape: [batch, nheads, seqlen_q]
    exp1 = torch.exp(lse1 - max_lse)
    exp2 = torch.exp(lse2 - max_lse)

    # Reshape for broadcasting with output
    # [batch, nheads, seqlen_q] -> [batch, seqlen_q, nheads, 1]
    exp1_broad = exp1.transpose(1, 2).unsqueeze(-1)
    exp2_broad = exp2.transpose(1, 2).unsqueeze(-1)

    # Merge outputs
    sum_exp = exp1_broad + exp2_broad
    o_merged = (o1 * exp1_broad + o2 * exp2_broad) / sum_exp

    # Compute merged LSE
    lse_merged = max_lse + torch.log(exp1 + exp2)

    # Ensure output has same dtype as input
    o_merged = o_merged.to(o1.dtype)

    return o_merged, lse_merged
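

# Minimal, CPU-only sanity sketch of the merge identity documented above: attention
# computed over a split K/V and merged with merge_attention_outputs equals attention
# over the full K/V. The naive reference helper below is illustrative only (not part
# of the original module) and is not meant to be fast.
def _naive_attention_with_lse(q, k, v, scale):
    # q: [b, sq, h, d], k/v: [b, sk, h, d] -> out: [b, sq, h, d], lse: [b, h, sq]
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * scale
    lse = torch.logsumexp(scores, dim=-1)          # [b, h, sq]
    probs = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhqk,bkhd->bqhd", probs, v.float())
    return out, lse


def _check_merge_identity():
    torch.manual_seed(0)
    b, sq, sk, h, d = 1, 4, 6, 2, 8
    scale = 1.0 / math.sqrt(d)
    q = torch.randn(b, sq, h, d)
    k = torch.randn(b, sk, h, d)
    v = torch.randn(b, sk, h, d)

    # Reference: full (non-causal) attention over all keys
    o_ref, lse_ref = _naive_attention_with_lse(q, k, v, scale)

    # Split K/V into two chunks, attend to each, then merge
    o1, lse1 = _naive_attention_with_lse(q, k[:, :3], v[:, :3], scale)
    o2, lse2 = _naive_attention_with_lse(q, k[:, 3:], v[:, 3:], scale)
    o_merged, lse_merged = merge_attention_outputs(o1, lse1, o2, lse2)

    assert torch.allclose(o_merged, o_ref, atol=1e-5)
    assert torch.allclose(lse_merged, lse_ref, atol=1e-5)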


def chunked_attention_varlen(
    q: torch.Tensor,
    kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k_list: List[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k_list: List[int],
    softmax_scale: Optional[float] = None,
    causal_mask_per_chunk: Optional[List[bool]] = None,
) -> torch.Tensor:
    """
    Compute attention with KV split across multiple chunks.

    This is the core function for chunked prefill. It computes attention
    against each KV chunk and merges results using online softmax.

    For causal attention with chunked KV:
    - Chunk holding the current tokens (by default the last chunk): apply causal mask
    - Previous-context chunks: no causal mask (all previous tokens are valid context)

    Args:
        q: Query tensor [total_q_tokens, nheads, headdim]
        kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk
        max_seqlen_q: Maximum query sequence length
        max_seqlen_k_list: List of maximum key sequence lengths for each chunk
        softmax_scale: Scaling factor
        causal_mask_per_chunk: Whether to apply causal mask for each chunk

    Returns:
        out: Output tensor [total_q_tokens, nheads, headdim]
    """
    if len(kv_chunks) == 0:
        raise ValueError("Need at least one KV chunk")

    nheads = q.shape[1]
    headdim = q.shape[2]
    batch = cu_seqlens_q.shape[0] - 1

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    if causal_mask_per_chunk is None:
        # Default: causal for last chunk only
        causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]

    # Initialize accumulated output and LSE
    accumulated_o = None
    accumulated_lse = None

    for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
        is_causal = causal_mask_per_chunk[chunk_idx]

        # Reshape Q for batch processing
        # For varlen, we need to handle each sequence separately
        # For simplicity, assume single sequence (batch=1) for now
        q_batched = q.unsqueeze(0)  # [1, total_q, nheads, headdim]

        # Compute attention for this chunk
        chunk_o, chunk_lse = flash_attn_with_lse(
            q_batched,
            k_chunk,
            v_chunk,
            softmax_scale=softmax_scale,
            causal=is_causal,
        )

        # Merge with accumulated
        if accumulated_o is None:
            accumulated_o = chunk_o
            accumulated_lse = chunk_lse
        else:
            accumulated_o, accumulated_lse = merge_attention_outputs(
                accumulated_o, accumulated_lse, chunk_o, chunk_lse,
            )

    # Remove batch dimension
    return accumulated_o.squeeze(0)
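

# Illustrative single-sequence call of chunked_attention_varlen (batch = 1), assuming
# CUDA and flash_attn are available. The tensor names (k_prev/v_prev for offloaded
# context, k_cur/v_cur for the current chunk) are ours, chosen only for this sketch.
def _example_chunked_attention_varlen():
    seqlen_q, seqlen_prev, nheads, headdim = 128, 256, 8, 64
    device, dtype = "cuda", torch.float16

    q = torch.randn(seqlen_q, nheads, headdim, device=device, dtype=dtype)
    k_prev = torch.randn(1, seqlen_prev, nheads, headdim, device=device, dtype=dtype)
    v_prev = torch.randn(1, seqlen_prev, nheads, headdim, device=device, dtype=dtype)
    k_cur = torch.randn(1, seqlen_q, nheads, headdim, device=device, dtype=dtype)
    v_cur = torch.randn(1, seqlen_q, nheads, headdim, device=device, dtype=dtype)

    out = chunked_attention_varlen(
        q,
        kv_chunks=[(k_prev, v_prev), (k_cur, v_cur)],
        cu_seqlens_q=torch.tensor([0, seqlen_q], dtype=torch.int32, device=device),
        cu_seqlens_k_list=[
            torch.tensor([0, seqlen_prev], dtype=torch.int32, device=device),
            torch.tensor([0, seqlen_q], dtype=torch.int32, device=device),
        ],
        max_seqlen_q=seqlen_q,
        max_seqlen_k_list=[seqlen_prev, seqlen_q],
        # No causal mask against previous context, causal mask within the current chunk
        causal_mask_per_chunk=[False, True],
    )
    assert out.shape == (seqlen_q, nheads, headdim)
    return out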
""" def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device): self.num_layers = num_layers self.dtype = dtype self.device = device # Per-layer accumulated outputs # Each entry: (accumulated_output, accumulated_lse) or None self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [ None for _ in range(num_layers) ] # Track which chunks have been processed self.processed_chunks: int = 0 def update_layer( self, layer_id: int, chunk_output: torch.Tensor, chunk_lse: torch.Tensor, ): """Update accumulated state for a layer with a new chunk's output.""" if self.layer_states[layer_id] is None: self.layer_states[layer_id] = (chunk_output, chunk_lse) else: acc_o, acc_lse = self.layer_states[layer_id] merged_o, merged_lse = merge_attention_outputs( acc_o, acc_lse, chunk_output, chunk_lse, ) self.layer_states[layer_id] = (merged_o, merged_lse) def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]: """Get the final accumulated output for a layer.""" if self.layer_states[layer_id] is None: return None return self.layer_states[layer_id][0] def clear(self): """Clear all accumulated state.""" self.layer_states = [None for _ in range(self.num_layers)] self.processed_chunks = 0 # Test function def _test_chunked_attention(): """Test chunked attention using flash_attn_with_lse and merge_attention_outputs.""" from flash_attn.flash_attn_interface import flash_attn_func torch.manual_seed(42) print("=" * 70) print("Test: Chunked attention vs flash_attn_func (non-causal)") print("=" * 70) print("Splitting K,V into chunks, computing attention per chunk, then merging") print() for dtype in [torch.float16, torch.bfloat16]: for num_chunks in [64, 128, 256]: for batch, seqlen, nheads, headdim in [ (1, 1024, 32, 128), (1, 2048, 32, 128), (1, 4096, 32, 128), (1, 8192, 32, 128), ]: # Generate random Q, K, V q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype) k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype) v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype) # Reference: full attention (non-causal) out_ref = flash_attn_func(q, k, v, causal=False) # Chunked attention: split K, V into chunks chunk_size = seqlen // num_chunks accumulated_o = None accumulated_lse = None for i in range(num_chunks): start = i * chunk_size end = (i + 1) * chunk_size k_chunk = k[:, start:end, :, :] v_chunk = v[:, start:end, :, :] # Q attends to this K,V chunk (non-causal) chunk_o, chunk_lse = flash_attn_with_lse( q, k_chunk, v_chunk, causal=False ) if accumulated_o is None: accumulated_o = chunk_o accumulated_lse = chunk_lse else: # Merge with previous chunks accumulated_o, accumulated_lse = merge_attention_outputs( accumulated_o, accumulated_lse, chunk_o, chunk_lse ) # Compare out_diff = (out_ref - accumulated_o).abs() out_max_diff = out_diff.max().item() out_mean_diff = out_diff.mean().item() status = "PASS" if out_max_diff < 1e-2 else "FAIL" print( f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} " f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) " f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}" ) print() print("=" * 70) print("Test completed!") if __name__ == "__main__": _test_chunked_attention()