diff --git a/nanovllm/ops/__init__.py b/nanovllm/ops/__init__.py
new file mode 100644
index 0000000..bb0839b
--- /dev/null
+++ b/nanovllm/ops/__init__.py
@@ -0,0 +1,38 @@
"""
Operators module for nano-vLLM.

This module contains low-level attention operators and kernels.
"""

from nanovllm.ops.chunked_attention import (
    flash_attn_with_lse,
    merge_attention_outputs,
    chunked_attention_varlen,
    ChunkedPrefillState,
)

from nanovllm.ops.xattn import (
    xattn_estimate,
    xattn_estimate_chunked,
    flat_group_gemm_fuse_reshape,
    softmax_fuse_block_sum,
    find_blocks_chunked,
    create_causal_mask,
    compute_sparsity,
)

# Public re-export surface; keep in sync with the imports above.
__all__ = [
    # chunked_attention
    "flash_attn_with_lse",
    "merge_attention_outputs",
    "chunked_attention_varlen",
    "ChunkedPrefillState",
    # xattn
    "xattn_estimate",
    "xattn_estimate_chunked",
    "flat_group_gemm_fuse_reshape",
    "softmax_fuse_block_sum",
    "find_blocks_chunked",
    "create_causal_mask",
    "compute_sparsity",
]
diff --git a/nanovllm/ops/chunked_attention.py b/nanovllm/ops/chunked_attention.py
new file mode 100644
index 0000000..6f92c33
--- /dev/null
+++ b/nanovllm/ops/chunked_attention.py
@@ -0,0 +1,624 @@
"""
Chunked attention implementation for CPU KV cache offloading.

This module implements flash attention with LSE (log-sum-exp) output,
enabling proper online softmax merging for chunked prefill.

Key functions:
- flash_attn_with_lse: Flash attention that returns output and LSE
- merge_attention_outputs: Merge outputs from multiple KV chunks
- chunked_attention_varlen: High-level interface for chunked attention
"""

import math
import torch
import triton
import triton.language as tl
from typing import Tuple, List, Optional


@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel_with_lse(
    Q,
    K,
    V,
    Out,
    Lse,
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_ob,
    stride_oh,
    stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Flash attention forward kernel with LSE output.

    Implements standard Flash Attention online softmax algorithm:
    - m_i: running max of attention scores
    - l_i: running sum of exp(scores - m_i)
    - acc_o: running sum of softmax(scores) @ V (unnormalized)

    Final output: acc_o / l_i
    Final LSE: m_i + log(l_i)

    NOTE(review): the IS_CAUSAL mask below compares offs_m directly against
    key positions, which assumes Q and K start at the same sequence offset
    (seqlen_q == seqlen_k alignment) — confirm for chunked use.
    NOTE(review): flash_attn_with_lse below delegates to the flash_attn
    library; this kernel is not launched anywhere in the visible code.
    """
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Pointers
    q_ptrs = (
        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    )

    # Initialize running statistics
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # running max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # running sum of exp
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)  # running output (unnormalized)

    # Load Q (once per block); the branchy masking keeps the fast path
    # (fully even shapes) free of mask overhead.
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(
                q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
            )

    # Loop over K, V blocks; for causal attention we can stop at the block
    # containing the diagonal.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # Load K
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )

        # Compute QK^T * scale
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        qk *= softmax_scale

        # Apply masks
        if not EVEN_N:
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))

        # Online softmax: compute block max
        m_ij = tl.max(qk, 1)  # [BLOCK_M]

        # New running max
        m_new = tl.maximum(m_i, m_ij)  # [BLOCK_M]

        # Rescale factor for previous accumulator
        alpha = tl.exp(m_i - m_new)  # [BLOCK_M]

        # Compute P = exp(qk - m_new)
        p = tl.exp(qk - m_new[:, None])  # [BLOCK_M, BLOCK_N]

        # Sum of current block
        l_ij = tl.sum(p, 1)  # [BLOCK_M]

        # Update running sum: l_new = l_i * alpha + l_ij
        l_new = l_i * alpha + l_ij

        # Rescale previous output and add new contribution
        acc_o = acc_o * alpha[:, None]

        # Load V
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )

        # acc_o += P @ V
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # Update running statistics
        m_i = m_new
        l_i = l_new

    # Final normalization: output = acc_o / l_i
    acc_o = acc_o / l_i[:, None]

    # Compute LSE = m_i + log(l_i)
    lse_i = m_i + tl.log(l_i)

    # Store LSE
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    if EVEN_M:
        tl.store(lse_ptrs, lse_i)
    else:
        tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)

    # Store output
    out_ptrs = (
        Out
        + off_b * stride_ob
        + off_h * stride_oh
        + (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )


def flash_attn_with_lse(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flash attention forward pass that returns both output and LSE.

    Uses flash_attn library which natively supports GQA without memory overhead.

    Args:
        q: Query tensor [batch, seqlen_q, nheads_q, headdim]
        k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
        v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
        softmax_scale: Scaling factor (default: 1/sqrt(headdim))
        causal: Whether to apply causal masking

    Returns:
        out: Output tensor [batch, seqlen_q, nheads_q, headdim]
        lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
    """
    # Imported lazily so the module can be loaded without flash_attn installed.
    from flash_attn.flash_attn_interface import flash_attn_func

    batch, seqlen_q, nheads_q, headdim = q.shape
    _, seqlen_k, nheads_kv, _ = k.shape

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    # Use flash_attn_func which natively supports GQA (no memory overhead).
    # return_attn_probs=True makes it return (out, softmax_lse, S_dmask);
    # only the LSE is needed here, S_dmask is discarded.
    out, lse, _ = flash_attn_func(
        q, k, v,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=True,  # This makes it return (out, softmax_lse, S_dmask)
    )

    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
    # Trim to actual seqlen_q (a no-op when flash_attn already returns the
    # exact length).
    lse = lse[:, :, :seqlen_q]

    return out, lse


@triton.jit
def _merge_lse_kernel(
    lse1_ptr, lse2_ptr, lse_out_ptr,
    num_elements: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for merging LSE values.

    Computes lse_out = log(exp(lse1) + exp(lse2)) elementwise over a flat
    buffer of num_elements entries.

    IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
    bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
    """
    # Each program handles BLOCK_SIZE elements
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    # Load lse values and convert to fp32 for precision
    lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
    lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)

    # Compute max for numerical stability (in fp32)
    max_lse = tl.maximum(lse1, lse2)

    # Compute exp(lse - max_lse) in fp32
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)

    # Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
    lse_merged = max_lse + tl.log(exp1 + exp2)

    # Store result (convert back to original dtype)
    tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)


@triton.jit
def _merge_output_kernel(
    o1_ptr, o2_ptr, lse1_ptr, lse2_ptr, o_out_ptr,
    batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for merging attention outputs.

    Computes o_out = (o1 * exp(lse1) + o2 * exp(lse2)) / (exp(lse1) + exp(lse2))
    per (batch, seq, head) position, i.e. the online-softmax weighted average.

    IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
    This is critical for numerical accuracy in chunked attention.
    """
    # Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
    pid_batch = tl.program_id(0)
    pid_seq = tl.program_id(1)
    pid_head = tl.program_id(2)

    # Compute LSE index: [batch, nheads, seqlen_q]
    lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq

    # Load LSE values and convert to fp32 for precision
    lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
    lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)

    # Compute max and scaling factors in fp32
    max_lse = tl.maximum(lse1, lse2)
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)
    sum_exp = exp1 + exp2

    # Process headdim in chunks
    for d_offset in range(0, headdim, BLOCK_SIZE):
        d_idx = d_offset + tl.arange(0, BLOCK_SIZE)
        mask = d_idx < headdim

        # Compute output index: [batch, seqlen_q, nheads, headdim]
        base_idx = (pid_batch * seqlen_q * nheads * headdim +
                    pid_seq * nheads * headdim +
                    pid_head * headdim)
        o_idx = base_idx + d_idx

        # Load o1, o2 and convert to fp32 for weighted sum
        o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
        o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)

        # Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
        o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp

        # Store result (Triton will convert back to original dtype)
        tl.store(o_out_ptr + o_idx, o_merged, mask=mask)


def merge_attention_outputs(
    o1: torch.Tensor,
    lse1: torch.Tensor,
    o2: torch.Tensor,
    lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge two attention outputs using online softmax (Triton fused kernel).

    This implements the online softmax merging formula:
    - m_new = max(lse1, lse2)
    - o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new))
    - lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))

    Args:
        o1: First output [batch, seqlen_q, nheads, headdim]
        lse1: First LSE [batch, nheads, seqlen_q]
        o2: Second output [batch, seqlen_q, nheads, headdim]
        lse2: Second LSE [batch, nheads, seqlen_q]

    Returns:
        o_merged: Merged output [batch, seqlen_q, nheads, headdim]
        lse_merged: Merged LSE [batch, nheads, seqlen_q]
    """
    batch, seqlen_q, nheads, headdim = o1.shape

    # Allocate output tensors
    o_merged = torch.empty_like(o1)
    lse_merged = torch.empty_like(lse1)

    # NOTE(review): both kernels index the tensors with flat offsets, so the
    # inputs are assumed to be contiguous in the documented layouts — confirm
    # callers never pass strided views.

    # Launch LSE merge kernel
    num_lse_elements = batch * nheads * seqlen_q
    BLOCK_SIZE_LSE = 256
    grid_lse = (triton.cdiv(num_lse_elements, BLOCK_SIZE_LSE),)
    _merge_lse_kernel[grid_lse](
        lse1, lse2, lse_merged,
        num_lse_elements,
        BLOCK_SIZE=BLOCK_SIZE_LSE,
    )

    # Launch output merge kernel
    BLOCK_SIZE = 128
    grid_output = (batch, seqlen_q, nheads)
    _merge_output_kernel[grid_output](
        o1, o2, lse1, lse2, o_merged,
        batch, seqlen_q, nheads, headdim,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return o_merged, lse_merged


def chunked_attention_varlen(
    q: torch.Tensor,
    kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k_list: List[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k_list: List[int],
    softmax_scale: Optional[float] = None,
    causal_mask_per_chunk: Optional[List[bool]] = None,
) -> torch.Tensor:
    """
    Compute attention with KV split across multiple chunks.

    This is the core function for chunked prefill. It computes attention
    against each KV chunk and merges results using online softmax.

    For causal attention with chunked KV:
    - The chunk holding the current tokens (by default the LAST chunk):
      apply causal mask
    - Earlier context chunks: no causal mask (all previous tokens are valid
      context)

    Args:
        q: Query tensor [total_q_tokens, nheads, headdim]
        kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk
        max_seqlen_q: Maximum query sequence length
        max_seqlen_k_list: List of maximum key sequence lengths for each chunk
        softmax_scale: Scaling factor
        causal_mask_per_chunk: Whether to apply causal mask for each chunk

    Returns:
        out: Output tensor [total_q_tokens, nheads, headdim]

    Raises:
        ValueError: If kv_chunks is empty.

    NOTE(review): cu_seqlens_k_list, max_seqlen_q and max_seqlen_k_list are
    currently unused; the implementation assumes a single sequence
    (batch == 1) — true varlen support is still TODO.
    """
    if len(kv_chunks) == 0:
        raise ValueError("Need at least one KV chunk")

    nheads = q.shape[1]
    headdim = q.shape[2]
    batch = cu_seqlens_q.shape[0] - 1

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    if causal_mask_per_chunk is None:
        # Default: causal for last chunk only
        causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]

    # Initialize accumulated output and LSE
    accumulated_o = None
    accumulated_lse = None

    for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
        is_causal = causal_mask_per_chunk[chunk_idx]

        # Reshape Q for batch processing
        # For varlen, we need to handle each sequence separately
        # For simplicity, assume single sequence (batch=1) for now
        # NOTE(review): this unsqueeze is loop-invariant and could be hoisted.
        q_batched = q.unsqueeze(0)  # [1, total_q, nheads, headdim]

        # Compute attention for this chunk
        chunk_o, chunk_lse = flash_attn_with_lse(
            q_batched,
            k_chunk,
            v_chunk,
            softmax_scale=softmax_scale,
            causal=is_causal,
        )

        # Merge with accumulated
        if accumulated_o is None:
            accumulated_o = chunk_o
            accumulated_lse = chunk_lse
        else:
            accumulated_o, accumulated_lse = merge_attention_outputs(
                accumulated_o, accumulated_lse,
                chunk_o, chunk_lse,
            )

    # Remove batch dimension
    return accumulated_o.squeeze(0)


class ChunkedPrefillState:
    """
    State for tracking chunked prefill progress.

    This class maintains the accumulated attention output and LSE
    across multiple prefill chunks.
    """

    def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device):
        # NOTE(review): dtype and device are stored but not read by any
        # visible method — confirm they are used by callers.
        self.num_layers = num_layers
        self.dtype = dtype
        self.device = device

        # Per-layer accumulated outputs
        # Each entry: (accumulated_output, accumulated_lse) or None
        self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
            None for _ in range(num_layers)
        ]

        # Track which chunks have been processed
        # NOTE(review): never incremented in the visible code — presumably the
        # caller bumps it per chunk; confirm.
        self.processed_chunks: int = 0

    def update_layer(
        self,
        layer_id: int,
        chunk_output: torch.Tensor,
        chunk_lse: torch.Tensor,
    ):
        """Update accumulated state for a layer with a new chunk's output."""
        if self.layer_states[layer_id] is None:
            # First chunk for this layer: nothing to merge yet.
            self.layer_states[layer_id] = (chunk_output, chunk_lse)
        else:
            acc_o, acc_lse = self.layer_states[layer_id]
            merged_o, merged_lse = merge_attention_outputs(
                acc_o, acc_lse,
                chunk_output, chunk_lse,
            )
            self.layer_states[layer_id] = (merged_o, merged_lse)

    def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]:
        """Get the final accumulated output for a layer (None if no chunk seen)."""
        if self.layer_states[layer_id] is None:
            return None
        return self.layer_states[layer_id][0]

    def clear(self):
        """Clear all accumulated state."""
        self.layer_states = [None for _ in range(self.num_layers)]
        self.processed_chunks = 0


# Test function
def _test_chunked_attention():
    """Test chunked attention using flash_attn_with_lse and merge_attention_outputs."""
    from flash_attn.flash_attn_interface import flash_attn_func

    torch.manual_seed(42)

    print("=" * 70)
    print("Test: Chunked attention vs flash_attn_func (non-causal)")
    print("=" * 70)
    print("Splitting K,V into chunks, computing attention per chunk, then merging")
    print()

    for dtype in [torch.float16, torch.bfloat16]:
        for num_chunks in [64, 128, 256]:
            for batch, seqlen, nheads, headdim in [
                (1, 1024, 32, 128),
                (1, 2048, 32, 128),
                (1, 4096, 32, 128),
                (1, 8192, 32, 128),
            ]:
                # Generate random Q, K, V
                q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
                k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
                v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)

                # Reference: full attention (non-causal)
                out_ref = flash_attn_func(q, k, v, causal=False)

                # Chunked attention: split K, V into chunks
                chunk_size = seqlen // num_chunks
                accumulated_o = None
                accumulated_lse = None

                for i in range(num_chunks):
                    start = i * chunk_size
                    end = (i + 1) * chunk_size

                    k_chunk = k[:, start:end, :, :]
                    v_chunk = v[:, start:end, :, :]

                    # Q attends to this K,V chunk (non-causal)
                    chunk_o, chunk_lse = flash_attn_with_lse(
                        q, k_chunk, v_chunk, causal=False
                    )

                    if accumulated_o is None:
                        accumulated_o = chunk_o
                        accumulated_lse = chunk_lse
                    else:
                        # Merge with previous chunks
                        accumulated_o, accumulated_lse = merge_attention_outputs(
                            accumulated_o, accumulated_lse,
                            chunk_o, chunk_lse
                        )

                # Compare against the unchunked reference.
                out_diff = (out_ref - accumulated_o).abs()
                out_max_diff = out_diff.max().item()
                out_mean_diff = out_diff.mean().item()

                status = "PASS" if out_max_diff < 1e-2 else "FAIL"
                print(
                    f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
                    f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
                    f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
                )

    print()
    print("=" * 70)
    print("Test completed!")


if __name__ == "__main__":
    _test_chunked_attention()
diff --git a/nanovllm/ops/xattn.py b/nanovllm/ops/xattn.py
new file mode 100644
index 0000000..7c34e93
--- /dev/null
+++ b/nanovllm/ops/xattn.py
@@ -0,0 +1,1167 @@
"""
XAttention block importance estimation with Triton kernels.

Ported from COMPASS project (compass/src/Xattention.py, kernels.py, utils.py).

This module implements the ESTIMATE phase of XAttention, which identifies
important blocks using stride-interleaved Q/K reshaping and Triton kernels.

Architecture:
    XAttention = Estimate (Triton) + Compute (BSA)
    This module: Estimate only
    BSA library: block_sparse_attn (external dependency for compute)

Key functions:
- xattn_estimate: Estimate block importance and generate sparse mask
- flat_group_gemm_fuse_reshape: Fused stride reshape + GEMM kernel
- softmax_fuse_block_sum: Online softmax + block-wise sum kernel
- find_blocks_chunked: Block selection based on cumulative threshold
"""

import math
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from typing import Tuple, Optional


# ============================================================
# Triton Kernels
# ============================================================

@triton.jit
def softmax_fuse_block_sum_kernel_causal(
    In,
    Out,
    scale,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_stride_0,
    output_stride_1,
    output_stride_2,
    real_q_len,
    k_len,  # we assume k_len is divisible by segment_size
    chunk_start,
    chunk_end,
    segment_size: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Fused softmax + block sum kernel with causal masking.

    This kernel performs online softmax on attention weights and sums
    within each block, producing block-level attention scores.

    Algorithm:
    1. Two-pass online softmax (compute max, then normalize)
    2. Apply causal mask (future positions get -inf)
    3. Reshape to blocks and sum within each block

    Args (via grid):
        block_id: Current Q block index
        head_id: Attention head index
        batch_id: Batch index

    Input shape: [batch, heads, q_len, k_len]
    Output shape: [batch, heads, q_blocks, k_blocks]

    NOTE(review): uses exp2 rather than exp, so `scale` is expected to fold
    in the log2(e) factor in addition to 1/sqrt(d) — confirm against callers.
    """
    block_id = tl.program_id(0)
    head_id = tl.program_id(1)
    batch_id = tl.program_id(2)

    offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
    offs_k = tl.arange(0, segment_size)

    num_iters = k_len // segment_size
    # Index of the segment containing the causal diagonal for this Q block.
    num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1) // segment_size

    # Online softmax state
    # NOTE(review): l_i starts at 1.0, but the first update multiplies it by
    # alpha = exp2(-inf - m) == 0, so the extra 1 only survives (and prevents
    # a 0-division) for fully masked rows.
    m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")  # running max
    l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0  # running sum

    # Input pointer setup
    input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
    input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2

    # Output pointer setup
    output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
    output_ptr = output_ptr + tl.arange(0, segment_size // block_size)

    # Pass 1: Compute global max and sum (before causal boundary)
    for iter in range(0, num_iters_before_causal):
        X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
        m_local = tl.max(X, 1)
        m_new = tl.maximum(m_i, m_local)
        alpha = tl.math.exp2(m_i - m_new)

        X = X - m_new[:, None]
        l_local = tl.sum(tl.math.exp2(X), 1)
        l_i = l_i * alpha + l_local

        m_i = m_new

    # Pass 1 continued: Handle causal boundary (single diagonal segment)
    for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
        X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
        mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size)
        X = tl.where(mask, X, -1.0e6)
        m_local = tl.max(X, 1)
        m_new = tl.maximum(m_i, m_local)
        alpha = tl.math.exp2(m_i - m_new)

        X = X - m_new[:, None]
        l_local = tl.sum(tl.math.exp2(X), 1)
        l_i = l_i * alpha + l_local

        m_i = m_new

    l_i_inv = 1.0 / l_i

    # Rows past the real query length contribute zero to the block sums.
    sum_mask = offs_q[:, None] < real_q_len

    # Pass 2: Normalize and compute block sums (before causal boundary)
    for iter in range(0, num_iters_before_causal):
        X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
        X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
        X = tl.where(sum_mask, X, 0)
        X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
        X = tl.sum(X, 2)
        X = tl.sum(X, 0)
        tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))

    # Pass 2 continued: Handle causal boundary
    for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
        X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
        mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size)
        X = tl.where(mask, X, -1.0e6)
        X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
        X = tl.where(sum_mask, X, 0)
        X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
        X = tl.sum(X, 2)
        X = tl.sum(X, 0)
        tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))

    # Pass 2 continued: Zero out future blocks
    for iter in range(num_iters_before_causal + 1, num_iters):
        X = tl.zeros([segment_size // block_size], dtype=tl.float32)
        tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))


@triton.jit
def softmax_fuse_block_sum_kernel_non_causal(
    In,
    Out,
    scale,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_stride_0,
    output_stride_1,
    output_stride_2,
    real_q_len,
    k_len,  # we assume k_len is divisible by segment_size
    chunk_start,
    chunk_end,
    segment_size: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Fused softmax + block sum kernel without causal masking.

    Same as causal version but without causal mask application.
    """
    block_id = tl.program_id(0)
    head_id = tl.program_id(1)
    batch_id = tl.program_id(2)

    offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
    offs_k = tl.arange(0, segment_size)

    num_iters = k_len // segment_size

    # Online softmax state (see causal kernel for the l_i = 1.0 rationale).
    m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0

    input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
    input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2

    output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
    output_ptr = output_ptr + tl.arange(0, segment_size // block_size)

    # Pass 1: Compute global max and sum
    for iter in range(0, num_iters):
        X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
        m_local = tl.max(X, 1)
        m_new = tl.maximum(m_i, m_local)
        alpha = tl.math.exp2(m_i - m_new)

        X = X - m_new[:, None]
        l_local = tl.sum(tl.math.exp2(X), 1)
        l_i = l_i * alpha + l_local

        m_i = m_new

    l_i_inv = 1.0 / l_i

    sum_mask = offs_q[:, None] < real_q_len

    # Pass 2: Normalize and compute block sums
    for iter in range(0, num_iters):
        X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
        X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
        X = tl.where(sum_mask, X, 0)
        X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
        X = tl.sum(X, 2)
        X = tl.sum(X, 0)
        tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))


@triton.jit
def flat_group_gemm_fuse_reshape_kernel(
    Q, K, Out,
    stride_qz, stride_qh, stride_qn,
    stride_kz, stride_kh, stride_kn,
    stride_oz, stride_oh, stride_on,
    chunk_start, chunk_end,
    H: tl.constexpr,
    STRIDE: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    is_causal: tl.constexpr,
):
    """
    Fused stride reshape + GEMM kernel.

    This kernel computes Q_reshaped @ K_reshaped^T without explicitly
    creating the reshaped tensors, saving memory and bandwidth.

    Stride reshape (inverse mode):
    - K: concat([K[:,:,k::stride,:] for k in range(stride)])
    - Q: concat([Q[:,:,(stride-1-q)::stride,:] for q in range(stride)])

    The kernel simulates this by adjusting pointer arithmetic:
    - Q samples backwards: Q_ptrs starts at (stride-1), steps by -1
    - K samples forwards: K_ptrs starts at 0, steps by +1
    - Both accumulate across stride iterations

    Args (via grid):
        block_m: Q block index (in reshaped space)
        block_n: K block index (in reshaped space)
        batch_id * H + head_id: Combined batch and head index

    Input shapes:
        Q: [batch, heads, q_len, head_dim]
        K: [batch, heads, k_len, head_dim]

    Output shape: [batch, heads, q_len/stride, k_len/stride]
    """
    # int64 indices to avoid overflow for long sequences / large batches.
    block_m = tl.program_id(0).to(tl.int64)
    block_n = tl.program_id(1).to(tl.int64)
    batch_id = tl.program_id(2).to(tl.int64) // H
    head_id = tl.program_id(2).to(tl.int64) % H

    # Early exit for causal: skip blocks where K is entirely in the future
    if is_causal:
        if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N:
            return

    # Q pointer: sample from (stride-1) position, step backwards
    Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn
    Q_ptrs = Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + tl.arange(0, HEAD_DIM)[None, :] + stride_qn * (STRIDE - 1)

    # K pointer: sample from 0 position, step forwards
    K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn
    K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + tl.arange(0, HEAD_DIM)[:, None]

    o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    # Accumulate Q @ K^T across stride positions
    for iter in range(STRIDE):
        q = tl.load(Q_ptrs - iter * stride_qn)  # Q steps backwards
        k = tl.load(K_ptrs + iter * stride_kn)  # K steps forwards
        o += tl.dot(q, k)

    # Store output
    O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N
    O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :]

    tl.store(O_ptrs, o.to(Out.type.element_ty))


# ============================================================
# Triton Kernel Wrappers
# ============================================================

def softmax_fuse_block_sum(
    attn_weights_slice: torch.Tensor,
    reshaped_block_size: int,
    segment_size: int,
    chunk_start: int,
    chunk_end: int,
    real_q_len: int,
    scale: float,
    is_causal: bool = True,
) -> torch.Tensor:
    """
    Compute softmax and block-wise sum of attention weights.

    This function takes raw QK^T scores (after stride reshape),
    applies softmax, and sums within each block to produce
    block-level attention scores.

    Args:
        attn_weights_slice: Raw attention scores [batch, heads, q_len, k_len]
        reshaped_block_size: Block size in reshaped space (block_size / stride)
        segment_size: Processing segment size
        chunk_start: Start position for this chunk
        chunk_end: End position for this chunk
        real_q_len: Actual Q length (before padding)
        scale: Softmax scale factor (includes 1/sqrt(d) and stride normalization)
        is_causal: Whether to apply causal masking

    Returns:
        Block-level attention sums [batch, heads, q_blocks, k_blocks]
    """
    batch_size, num_heads, q_len, k_len = attn_weights_slice.shape

    # Shape preconditions required by the kernels' pointer arithmetic.
    assert q_len % reshaped_block_size == 0, f"q_len {q_len} must be divisible by reshaped_block_size {reshaped_block_size}"
    assert k_len % segment_size == 0, f"k_len {k_len} must be divisible by segment_size {segment_size}"
    assert segment_size % reshaped_block_size == 0, f"segment_size {segment_size} must be divisible by reshaped_block_size {reshaped_block_size}"
    assert attn_weights_slice.stride(-1) == 1, "Last dimension must be contiguous"

    output = torch.empty(
        (batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size),
        dtype=attn_weights_slice.dtype,
        device=attn_weights_slice.device
    )

    # One program per (Q block, head, batch).
    grid = (q_len // reshaped_block_size, num_heads, batch_size)

    if is_causal:
        softmax_fuse_block_sum_kernel_causal[grid](
            attn_weights_slice,
            output,
            scale,
            attn_weights_slice.stride(0),
            attn_weights_slice.stride(1),
            attn_weights_slice.stride(2),
            output.stride(0),
            output.stride(1),
            output.stride(2),
            real_q_len,
            k_len,
            chunk_start,
            chunk_end,
            segment_size,
            reshaped_block_size,
        )
    else:
        softmax_fuse_block_sum_kernel_non_causal[grid](
            attn_weights_slice,
            output,
            scale,
            attn_weights_slice.stride(0),
            attn_weights_slice.stride(1),
            attn_weights_slice.stride(2),
            output.stride(0),
            output.stride(1),
            output.stride(2),
            real_q_len,
            k_len,
            chunk_start,
            chunk_end,
            segment_size,
            reshaped_block_size,
        )

    return output


def flat_group_gemm_fuse_reshape(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    stride: int,
    chunk_start: int,
    chunk_end: int,
    is_causal: bool = True,
) -> torch.Tensor:
    """
    Compute fused stride reshape + GEMM for Q @ K^T.

    This is the core estimation kernel of XAttention. It computes
    attention scores between strided Q and K without explicitly
    creating the reshaped tensors.

    The stride reshape (inverse mode) works as:
    - K_reshaped: concat([K[:,:,k::stride,:] for k in range(stride)])
    - Q_reshaped: concat([Q[:,:,(stride-1-q)::stride,:] for q in range(stride)])

    Result: Q_reshaped @ K_reshaped^T with shape [batch, heads, q_len/stride, k_len/stride]

    Args:
        query_states: Q tensor [batch, heads, q_len, head_dim]
        key_states: K tensor [batch, heads, k_len, head_dim]
        stride: Stride for reshape (typically 8)
        chunk_start: Start position (in reshaped space) for causal masking
        chunk_end: End position (in reshaped space) for causal masking
        is_causal: Whether to apply causal masking (skip future blocks)

    Returns:
        Attention scores [batch, heads, q_len/stride, k_len/stride]
    """
    batch_size, num_heads, q_len, head_dim = query_states.shape
    kv_len = key_states.shape[2]

    # Q and K must agree on batch, heads, and head_dim (no GQA here).
    assert key_states.shape[0] == batch_size
    assert key_states.shape[1] == num_heads
    assert key_states.shape[3] == head_dim

    output = torch.empty(
        (batch_size, num_heads, q_len // stride, kv_len // stride),
        dtype=query_states.dtype,
        device=query_states.device
    )

    # Adjust block size based on GPU shared memory
    # RTX 3090 has ~100KB, A100/H100 have ~160KB+
    # NOTE(review): total device memory is used here as a *proxy* for shared-memory
    # capacity (small-memory card => smaller tiles) — confirm this heuristic holds
    # for other SKUs (e.g. 16GB A4000-class parts with small shared memory).
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    if props.total_memory < 30 * 1024**3:  # Less than 30GB (e.g., RTX 3090 24GB)
        BLOCK_M = 64
        BLOCK_N = 64
    else:
        BLOCK_M = 128
        BLOCK_N = 128

    assert q_len % (stride * BLOCK_M) == 0, f"q_len {q_len} must be divisible by stride*BLOCK_M {stride * BLOCK_M}"
    assert kv_len % (stride * BLOCK_N) == 0, f"kv_len {kv_len} must be divisible by stride*BLOCK_N {stride * BLOCK_N}"

    # Grid z-axis packs (batch, head) pairs into one dimension.
    grid = (q_len // stride // BLOCK_M, kv_len // stride // BLOCK_N, batch_size * num_heads)

    # Kernel defined earlier in this module.
    flat_group_gemm_fuse_reshape_kernel[grid](
        query_states,
        key_states,
        output,
        query_states.stride(0),
        query_states.stride(1),
        query_states.stride(2),
        key_states.stride(0),
        key_states.stride(1),
        key_states.stride(2),
        output.stride(0),
        output.stride(1),
        output.stride(2),
        chunk_start,
        chunk_end,
        num_heads,
        stride,
        head_dim,
        BLOCK_M,
        BLOCK_N,
        is_causal,
    )

    return output


# ============================================================
# Block Selection Utilities
# ============================================================

def find_blocks_chunked(
    input_tensor: torch.Tensor,
    current_index: int,
    threshold: float,
    num_to_choose: Optional[int],
    decoding: bool,
    mode: str = "both",
    causal: bool = True,
) -> torch.Tensor:
    """
    Select important blocks based on cumulative attention threshold.

    This function takes block-level attention scores and selects blocks
    that cumulatively account for a specified fraction of total attention.

    Algorithm:
    1. Compute total attention per query block
    2. Sort blocks by attention score (descending)
    3. Accumulate until reaching threshold * total
    4. Mark accumulated blocks as selected
    5. Always keep diagonal blocks (for causal) and sink block

    Args:
        input_tensor: Block attention scores [batch, heads, q_blocks, k_blocks]
        current_index: Current chunk's starting block index
        threshold: Cumulative attention threshold (e.g., 0.9 = keep 90% attention mass)
        num_to_choose: Alternative to threshold - select fixed number of blocks
        decoding: Whether in decode mode (vs prefill)
        mode: "prefill", "decode", or "both"
        causal: Whether to apply causal masking

    Returns:
        Boolean mask [batch, heads, q_blocks, k_blocks] indicating selected blocks
    """
    assert threshold is None or num_to_choose is None, "Only one of threshold or num_to_choose can be specified"

    batch_size, head_num, chunk_num, block_num = input_tensor.shape

    # Special case: prefill mode during decoding - return all True
    if mode == "prefill" and decoding:
        return torch.ones_like(input_tensor, dtype=torch.bool)

    # Special case: decode mode during prefill
    if mode == "decode" and not decoding:
        mask = \
torch.ones_like(input_tensor, dtype=torch.bool) + if causal: + mask[:, :, :, current_index : current_index + chunk_num] = torch.tril( + torch.ones(1, head_num, chunk_num, chunk_num, device=input_tensor.device) + ) + mask[:, :, current_index + chunk_num :, :] = 0 + return torch.cat( + [ + torch.ones_like(input_tensor, dtype=torch.bool)[:, :, 0 : current_index + 1], + torch.zeros_like(input_tensor, dtype=torch.bool)[:, :, current_index + 1 :], + ], + dim=-1, + ) + else: + return mask + + # Convert to float for numerical operations + input_tensor = input_tensor.to(torch.float32) + + if threshold is not None: + # Compute required cumulative sum + total_sum = input_tensor.sum(dim=-1, keepdim=True) + if isinstance(threshold, torch.Tensor): + threshold = threshold.to(torch.float32) + required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand( + (batch_size, head_num, chunk_num, 1) + ).to(input_tensor.device) + else: + required_sum = total_sum * threshold + + if causal: + # Initialize mask with mandatory blocks + mask = torch.zeros_like(input_tensor, dtype=torch.bool) + mask[:, :, :, 0] = True # Sink block always selected + + # Diagonal blocks (current chunk's causal positions) + mask[:, :, :, current_index : current_index + chunk_num] = ( + torch.eye(chunk_num, device=mask.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(1, head_num, chunk_num, chunk_num) + ) + + # Mask out mandatory blocks for sorting + other_values = input_tensor.masked_fill(mask, 0) + sorted_values, _ = torch.sort(other_values, dim=-1, descending=True) + sorted_values = sorted_values.to(input_tensor.device) + + # Prepend mandatory blocks' contribution + sorted_values = torch.cat( + [ + torch.zeros((batch_size, head_num, chunk_num, 1), device=input_tensor.device), + torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True), + sorted_values[:, :, :, :-2], + ], + dim=-1, + ) + + # Get sorted indices (mandatory blocks get high priority) + _, index = torch.sort( + 
torch.where(mask, 100000 * (1 + input_tensor), input_tensor), + dim=-1, + descending=True, + ) + + # Compute cumulative sum (excluding current block) + cumulative_sum_without_self = torch.cat( + [ + torch.zeros((batch_size, head_num, chunk_num, 1), device=input_tensor.device), + sorted_values[:, :, :, 0:-1], + ], + dim=-1, + ).cumsum(dim=-1) + + # Select blocks until threshold is reached + index_mask = cumulative_sum_without_self < required_sum + index = torch.where(index_mask, index, 0) + + # Flatten for scatter operation + mask = mask.view(batch_size, head_num * chunk_num, block_num) + index = index.view(batch_size, head_num * chunk_num, block_num) + + # Mark selected blocks + mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True + mask = mask.view(batch_size, head_num, chunk_num, block_num) + else: + # Non-causal: simple threshold-based selection + mask = torch.zeros_like(input_tensor, dtype=torch.bool) + sorted_values, index = torch.sort(input_tensor, dim=-1, descending=True) + sorted_values = sorted_values.to(input_tensor.device) + + cumulative_sum_without_self = torch.cat( + [ + torch.zeros((batch_size, head_num, chunk_num, 1), device=input_tensor.device), + sorted_values[:, :, :, 0:-1], + ], + dim=-1, + ).cumsum(dim=-1) + + index_mask = cumulative_sum_without_self < required_sum + index = torch.where(index_mask, index, 0) + + mask = mask.view(batch_size, head_num * chunk_num, block_num) + index = index.view(batch_size, head_num * chunk_num, block_num) + mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True + mask = mask.view(batch_size, head_num, chunk_num, block_num) + else: + raise NotImplementedError("Block num selection (num_to_choose) not implemented") + + # Enforce causal: zero out future blocks + try: + if causal: + assert (~mask[:, :, :, current_index + chunk_num :]).all() + except: + mask[:, :, :, current_index + chunk_num :] = False + + # Validation + if causal: + if decoding: + 
assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all() + else: + lambda_mask = torch.zeros_like(input_tensor, dtype=bool, device=input_tensor.device) + lambda_mask[:, :, :, 0] = True + lambda_mask[:, :, :, current_index : current_index + chunk_num] = ( + torch.eye(chunk_num, device=lambda_mask.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(1, head_num, chunk_num, chunk_num) + ) + assert torch.where(lambda_mask, mask, True).all() + + return mask + + +def create_causal_mask( + batch_size: int, + head_num: int, + block_size: int, + block_num: int, + divide_block_num: int, +) -> torch.Tensor: + """ + Create a causal attention mask for block-level attention. + + Args: + batch_size: Batch size + head_num: Number of attention heads + block_size: Tokens per block + block_num: Total number of blocks + divide_block_num: Block index at which causality boundary is applied + + Returns: + Causal mask [batch, heads, block_size, block_size * block_num] + """ + divide_block_num += 1 + if divide_block_num < 1 or divide_block_num > block_num: + raise ValueError( + f"divide_block_num ({divide_block_num}) must be between 1 and block_num ({block_num})." 
        )

    total_size = block_size * block_num
    # NOTE(review): device is hard-coded here — presumably fine because the whole
    # module is CUDA-only, but consider deriving it from an input tensor/caller.
    device = "cuda"
    mask = torch.zeros(block_size, total_size, device=device)

    # Mask future blocks
    if divide_block_num < block_num:
        mask[:, divide_block_num * block_size :] = float("-inf")

    # Apply triangular mask at causality boundary
    if divide_block_num - 1 < block_num:
        start_col = (divide_block_num - 1) * block_size
        end_col = start_col + block_size
        upper_tri_mask = torch.triu(
            torch.full((block_size, block_size), float("-inf"), device=device),
            diagonal=1,
        )
        mask[:, start_col:end_col] = upper_tri_mask

    # Broadcast the single [block_size, total_size] mask over batch and heads.
    mask = mask.unsqueeze(0).unsqueeze(0)
    mask = mask.expand(batch_size, head_num, block_size, total_size)
    return mask


# ============================================================
# Main Estimation Function
# ============================================================

def xattn_estimate(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    block_size: int = 128,
    stride: int = 8,
    norm: float = 1.0,
    threshold: float = 0.9,
    chunk_size: int = 16384,
    use_triton: bool = True,
    causal: bool = True,
    keep_sink: bool = False,
    keep_recent: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Estimate block importance for XAttention sparse selection.

    This function implements the estimation phase of XAttention:
    1. Stride-interleaved reshape of Q and K (inverse mode)
    2. Compute block-level attention scores via Triton kernels
    3. Select important blocks based on cumulative threshold

    The result is a boolean mask indicating which K blocks each Q block
    should attend to. This mask can be used with BSA (block_sparse_attn)
    for efficient sparse attention computation.

    Args:
        query_states: Q tensor [batch, heads, q_len, head_dim]
        key_states: K tensor [batch, heads, k_len, head_dim]
        block_size: Block size in tokens (must be 128 for BSA compatibility)
        stride: Stride for Q/K reshape (typically 8)
        norm: Normalization factor for attention scores
        threshold: Cumulative attention threshold (0.0-1.0)
        chunk_size: Processing chunk size for memory efficiency
        use_triton: Whether to use Triton kernels (requires SM 80+)
        causal: Whether to apply causal masking
        keep_sink: Always keep first block (sink tokens)
        keep_recent: Always keep diagonal blocks (recent context)

    Returns:
        attn_sums: Block-level attention scores [batch, heads, q_blocks, k_blocks]
        simple_masks: Boolean mask for sparse attention [batch, heads, q_blocks, k_blocks]

    Example:
        >>> q = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.bfloat16)
        >>> k = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.bfloat16)
        >>> attn_sums, mask = xattn_estimate(q, k, block_size=128, stride=8, threshold=0.9)
        >>> # mask can be used with block_sparse_attn_func for sparse computation
    """
    batch_size, num_kv_head, k_len, head_dim = key_states.shape
    batch_size, num_q_head, q_len, head_dim = query_states.shape
    assert num_q_head == num_kv_head, "GQA not supported in estimation (heads must match)"

    # Compute padding to align with chunk_size
    k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
    q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
    k_chunk_num = (k_len + k_num_to_pad) // chunk_size
    k_block_num = (k_len + k_num_to_pad) // block_size
    q_chunk_num = (q_len + q_num_to_pad) // chunk_size
    q_block_num = (q_len + q_num_to_pad) // block_size
    assert k_chunk_num >= q_chunk_num

    # Pad K and Q if needed
    # NOTE(review): `F` is presumably torch.nn.functional imported at module top
    # (outside this view) — verify; `.to("cuda")` also hard-codes the device.
    if k_num_to_pad > 0:
        pad_key_states = F.pad(key_states, (0, 0, 0, k_num_to_pad), value=0).to("cuda")
    else:
        pad_key_states = key_states
    if q_num_to_pad > 0:
        pad_query_states = F.pad(query_states, (0, 0, 0, q_num_to_pad), value=0).to("cuda")
    else:
        pad_query_states = query_states

    # Check GPU capability for Triton
    if use_triton:
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        if props.major < 8:
            use_triton = False
            print(f"Triton kernel requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.")

    # Compute reshaped dimensions (all lengths shrink by `stride`)
    reshaped_chunk_size = chunk_size // stride
    reshaped_block_size = block_size // stride
    k_reshaped_num_to_pad = k_num_to_pad // stride
    k_reshaped_seq_len = (k_len + k_num_to_pad) // stride
    num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size

    # Non-Triton fallback: explicit reshape
    if not use_triton:
        # K reshape: concat([K[:,:,k::stride,:] for k in range(stride)])
        reshaped_key = torch.cat(
            [(pad_key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
        )
        # Q reshape (inverse): concat([Q[:,:,(stride-1-q)::stride,:] for q in range(stride)])
        reshaped_query = torch.cat(
            [(pad_query_states[:, :, (stride - 1 - q)::stride, :]) for q in range(stride)],
            dim=-1,
        )

    attn_sum_list = []
    simple_mask_list = []

    # Process each Q chunk
    for chunk_idx in range(q_chunk_num):
        if use_triton:
            # Triton path: fused reshape + GEMM
            attn_weights_slice = flat_group_gemm_fuse_reshape(
                pad_query_states[
                    :,
                    :,
                    (chunk_idx * reshaped_chunk_size) * stride : (chunk_idx * reshaped_chunk_size + reshaped_chunk_size) * stride,
                    :,
                ],
                pad_key_states,
                stride,
                (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
                (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
                is_causal=causal,
            )

            # Fused softmax + block sum
            # Scale factor: log2(e) / sqrt(head_dim) / stride / norm
            # log2(e) ≈ 1.4426950408889634
            attn_sum = softmax_fuse_block_sum(
                attn_weights_slice,
                reshaped_block_size,
                min(4096, reshaped_block_size),
                (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
                (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
                k_reshaped_seq_len - k_reshaped_num_to_pad,
                1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
                is_causal=causal,
            )
        else:
            # PyTorch fallback path
            chunked_query = reshaped_query[
                :, :,
                chunk_idx * reshaped_chunk_size : (chunk_idx * reshaped_chunk_size + reshaped_chunk_size),
                :,
            ]

            # Compute attention scores
            attn_weights_slice = torch.matmul(
                chunked_query, reshaped_key.transpose(2, 3)
            ).to("cuda")
            attn_weights_slice = attn_weights_slice / math.sqrt(head_dim) / stride / norm

            # Apply causal mask
            if causal:
                offset_token_chunk_num = k_chunk_num - q_chunk_num
                causal_mask = torch.zeros(
                    (batch_size, num_q_head, reshaped_chunk_size, reshaped_chunk_size * k_chunk_num),
                    device=key_states.device,
                )
                # Mask K padding at the tail of the sequence
                causal_mask[:, :, :, (-k_reshaped_num_to_pad):] = float("-inf")
                chunk_start = (chunk_idx + offset_token_chunk_num) * reshaped_chunk_size
                chunk_end = chunk_start + reshaped_chunk_size
                # Triangular mask inside the diagonal chunk
                causal_mask[:, :, :, chunk_start:chunk_end] = torch.triu(
                    torch.ones(1, num_q_head, reshaped_chunk_size, reshaped_chunk_size, device=key_states.device) * float("-inf"),
                    diagonal=1,
                )
                # Last chunk: mask padded Q rows entirely
                if chunk_idx == q_chunk_num - 1 and q_num_to_pad // stride != 0:
                    causal_mask[:, :, (-(q_num_to_pad // stride)):, :] = float("-inf")
                # Everything after the diagonal chunk is future context
                causal_mask[:, :, :, chunk_end:] = float("-inf")
                attn_weights_slice = attn_weights_slice + causal_mask

            # Softmax (in float32 for stability, cast back afterwards)
            attn_weights_slice = F.softmax(attn_weights_slice, dim=-1, dtype=torch.float32).to(pad_query_states.dtype)

            # Padded Q rows would softmax to uniform; zero them so they
            # contribute nothing to the block sums.
            if chunk_idx == q_chunk_num - 1 and q_num_to_pad // stride != 0:
                attn_weights_slice[:, :, (-(q_num_to_pad // stride)):, :] = 0

            # Block sum
            attn_sum = (
                attn_weights_slice.view(
                    batch_size, num_kv_head, num_blocks_per_chunk, reshaped_block_size, -1,
                    reshaped_block_size
                )
                .sum(dim=-1)
                .sum(dim=-2)
                .to("cuda")
            )

        # Select blocks based on threshold
        simple_mask = find_blocks_chunked(
            attn_sum,
            k_block_num - q_block_num + chunk_idx * num_blocks_per_chunk,
            threshold,
            None,
            decoding=False,
            mode="prefill",
            causal=causal,
        )

        attn_sum_list.append(attn_sum)
        simple_mask_list.append(simple_mask)

        # Free the (large) score slice before the next iteration
        del attn_weights_slice

    if not use_triton:
        del reshaped_query, reshaped_key

    # Concatenate results from all chunks
    attn_sums = torch.cat(attn_sum_list, dim=-2)
    simple_masks = torch.cat(simple_mask_list, dim=-2)

    # Apply causal mask to final output
    if causal:
        simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
            torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=key_states.device), diagonal=0),
            simple_masks[:, :, -q_block_num:, -q_block_num:],
            False,
        )

    # Always keep sink block
    if keep_sink:
        simple_masks[:, :, :, 0] = True

    # Always keep diagonal (recent) blocks
    if keep_recent:
        eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool)
        eye_matrix_expanded = eye_matrix.unsqueeze(0).unsqueeze(0).expand(1, num_kv_head, q_block_num, q_block_num)
        simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
            eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:]
        )

    return attn_sums, simple_masks


def compute_sparsity(mask: torch.Tensor, causal: bool = True) -> float:
    """
    Compute the sparsity ratio of a block mask.

    Args:
        mask: Boolean mask [batch, heads, q_blocks, k_blocks]
        causal: Whether mask is causal (only lower triangle counts)

    Returns:
        Sparsity ratio (0.0 = dense, 1.0 = fully sparse)
    """
    batch, heads, q_blocks, k_blocks = mask.shape

    if causal:
        # Only count lower triangle
        # NOTE(review): tril with default diagonal assumes q_blocks and k_blocks
        # are aligned at the start of the sequence — confirm for q_blocks < k_blocks.
        causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=mask.device, dtype=torch.bool))
        total_blocks = causal_mask.sum().item() * batch * heads
        selected_blocks = (mask & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
    else:
        total_blocks = mask.numel()
        selected_blocks = mask.sum().item()

    return 1.0 - (selected_blocks / total_blocks)


# ============================================================
# Chunked Estimation Function (for Chunked Prefill)
# ============================================================

def xattn_estimate_chunked(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    q_start_pos: int,
    block_size: int = 128,
    stride: int = 8,
    norm: float = 1.0,
    threshold: float = 0.9,
    chunk_size: int = 16384,
    use_triton: bool = True,
    causal: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Estimate block importance for XAttention in chunked prefill mode.

    This function is designed for chunked prefill scenarios where:
    - Q is processed in chunks while K accumulates across chunks
    - q_start_pos indicates the position of the current Q chunk in the full sequence
    - K length can be >= Q length (accumulated KV cache)

    Ported from COMPASS project (compass/src/Xattn_chunked.py).

    Args:
        query_states: Q tensor [batch, heads, q_chunk_len, head_dim] - current Q chunk
        key_states: K tensor [batch, heads, k_len, head_dim] - accumulated K (k_len >= q_chunk_len)
        q_start_pos: Start position of this Q chunk in the full sequence
        block_size: Block size in tokens (typically 128 for BSA compatibility)
        stride: Stride for Q/K reshape (typically 8)
        norm: Normalization factor for attention scores
        threshold: Cumulative attention threshold (0.0-1.0)
        chunk_size: Processing chunk size for Triton kernel alignment
        use_triton: Whether to use Triton kernels (requires SM 80+)
        causal: Whether to apply causal masking

    Returns:
        attn_sums: Block-level attention scores [batch, heads, q_blocks, k_blocks]
        simple_masks: Boolean mask for sparse attention [batch, heads, q_blocks, k_blocks]

    Example:
        >>> # Chunk 0: Q[0:C] attends to K[0:C]
        >>> attn_sums, mask = xattn_estimate_chunked(q_chunk0, k_chunk0, q_start_pos=0)
        >>>
        >>> # Chunk 1: Q[C:2C] attends to K[0:2C]
        >>> attn_sums, mask = xattn_estimate_chunked(q_chunk1, k_accum, q_start_pos=C)
    """
    batch_size, num_heads, q_len, head_dim = query_states.shape
    _, _, k_len, _ = key_states.shape

    # Store original lengths for valid region tracking
    original_q_len = q_len
    original_k_len = k_len

    # Validate inputs
    assert k_len >= q_len, f"K length ({k_len}) must be >= Q length ({q_len})"
    assert q_start_pos + q_len <= k_len, f"Q end position ({q_start_pos + q_len}) exceeds K length ({k_len})"

    # Calculate block counts
    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size
    q_start_block = q_start_pos // block_size

    # Check GPU capability for Triton
    if use_triton:
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        if props.major < 8:
            use_triton = False

    # Pad Q and K for alignment
    if use_triton:
        # For Triton: pad to chunk_size alignment
        padded_q_len = ((q_len + chunk_size - 1) // chunk_size) * chunk_size
        padded_k_len = ((k_len + chunk_size - 1) // chunk_size) * chunk_size
    else:
        # For PyTorch fallback: pad to block_size alignment
        padded_q_len = q_block_num * block_size
        padded_k_len = k_block_num * block_size

    q_pad = padded_q_len - q_len
    k_pad = padded_k_len - k_len

    if q_pad > 0:
        query_states = F.pad(query_states, (0, 0, 0, q_pad), value=0)
    if k_pad > 0:
        key_states = F.pad(key_states, (0, 0, 0, k_pad), value=0)

    # Reshape dimensions
    reshaped_block_size = block_size // stride
    reshaped_q_len = padded_q_len // stride
    reshaped_k_len = padded_k_len // stride

    # Calculate valid lengths in reshaped space (for masking padding)
    valid_q_reshaped = (original_q_len + stride - 1) // stride
    valid_k_reshaped = (original_k_len + stride - 1) // stride

    if use_triton:
        # Compute chunk boundaries in reshaped space
        chunk_start = q_start_block * reshaped_block_size
        chunk_end = chunk_start + reshaped_q_len  # Padded end for computation
        real_q_len = chunk_start + valid_q_reshaped  # Valid end for masking padding

        # Use Triton kernel for efficient computation
        attn_weights = flat_group_gemm_fuse_reshape(
            query_states,
            key_states,
            stride,
            chunk_start,  # q_start in reshaped space
            chunk_end,  # q_end in reshaped space (padded)
            is_causal=causal,
        )

        # Softmax + block sum
        attn_sum = softmax_fuse_block_sum(
            attn_weights,
            reshaped_block_size,
            min(4096, reshaped_block_size),
            chunk_start,
            chunk_end,
            real_q_len,
            1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
            is_causal=causal,
        )

        # Extract only the valid block region
        attn_sum = attn_sum[:, :, :q_block_num, :k_block_num]
    else:
        # PyTorch fallback implementation
        # Reshape K: interleave positions and concatenate head dims
        reshaped_key = torch.cat(
            [(key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
        )  # (B, H, k_len/stride, D*stride)

        # Reshape Q (inverse mode)
        reshaped_query = torch.cat(
            [(query_states[:, :, (stride - 1 - q)::stride, :]) for q in range(stride)],
            dim=-1,
        )

        # Compute attention weights: (B, H, q_len/stride, k_len/stride)
        attn_weights = torch.matmul(
            reshaped_query, reshaped_key.transpose(2, 3)
        ) / math.sqrt(head_dim) / stride / norm

        # Apply causal mask
        if causal:
            reshaped_q_positions = reshaped_q_len
            causal_mask = torch.zeros(
                (batch_size, num_heads, reshaped_q_positions, reshaped_k_len),
                device=key_states.device,
                dtype=attn_weights.dtype,
            )

            # Mask out padding in K
            if k_pad > 0:
                causal_mask[:, :, :, -(k_pad // stride):] = float("-inf")

            # Mask out future positions (row-by-row; each Q row may attend
            # only up to its own global reshaped position)
            q_start_reshaped = q_start_pos // stride
            for q_idx in range(reshaped_q_positions):
                q_pos_reshaped = q_start_reshaped + q_idx
                if q_pos_reshaped + 1 < reshaped_k_len:
                    causal_mask[:, :, q_idx, q_pos_reshaped + 1:] = float("-inf")

            # Handle padding in Q
            if q_pad > 0:
                q_pad_reshaped = q_pad // stride
                if q_pad_reshaped > 0:
                    causal_mask[:, :, -q_pad_reshaped:, :] = float("-inf")

            attn_weights = attn_weights + causal_mask

        # Apply softmax
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # Zero out padded Q positions
        if q_pad > 0:
            q_pad_reshaped = q_pad // stride
            if q_pad_reshaped > 0:
                attn_weights[:, :, -q_pad_reshaped:, :] = 0

        # Aggregate to block level
        attn_sum = attn_weights.view(
            batch_size,
            num_heads,
            q_block_num,
            reshaped_block_size,
            k_block_num,
            reshaped_block_size,
        ).sum(dim=-1).sum(dim=-2)

    # Find blocks that exceed threshold
    simple_mask = find_blocks_chunked(
        attn_sum,
        q_start_block,  # offset for causal mask in find_blocks_chunked
        threshold,
        None,
        decoding=False,
        mode="prefill",
        causal=causal,
    )

    # Apply causal constraint on block level
    if causal:
        # For block-level causal: Q block i can only attend to K blocks j where j <= q_start_block + i
        for q_blk_idx in range(q_block_num):
            q_blk_global = q_start_block + q_blk_idx
            if q_blk_global + 1 < k_block_num:
                simple_mask[:, :, q_blk_idx, q_blk_global + 1:] = False

    return attn_sum, simple_mask
diff --git a/tests/test_xattn_estimate_chunked.py b/tests/test_xattn_estimate_chunked.py
new file mode 100644
index 0000000..76cb664
--- /dev/null
+++ b/tests/test_xattn_estimate_chunked.py
@@ -0,0 +1,244 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked

Verify that chunked estimation with EXTERNAL chunking produces the same mask
as standard estimation. This ensures the chunked version can be used in
chunked prefill scenarios without accuracy loss.

Usage:
    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
        python tests/test_xattn_estimate_chunked.py
"""

import sys
import traceback
import torch
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked

# ============================================================
# Configuration
# ============================================================

# Configuration for xattn_estimate_chunked consistency test.
# Key requirements for 100% match:
# 1. Use matching chunk_size for both standard and chunked versions
# 2. Use same random seed for reproducibility
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
# floating point precision in cumulative sum calculations.
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096  # External chunking size

# Test sequence lengths
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]

# ============================================================
# Utility Functions
# ============================================================

def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Compare two masks and report differences."""
    if mask1.shape != mask2.shape:
        print(f"  Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False

    diff = (mask1 != mask2).sum().item()
    total = mask1.numel()
    match_rate = (total - diff) / total * 100

    print(f"  Match rate: {match_rate:.4f}% ({total - diff}/{total})")

    if diff > 0:
        # Report a few differing coordinates to aid debugging
        diff_indices = torch.where(mask1 != mask2)
        print(f"  First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")

    return diff == 0


def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
    """
    Run xattn_estimate_chunked with EXTERNAL chunking.
    This simulates how chunked prefill should be used in practice.
    """
    batch_size, num_heads, q_len, head_dim = query.shape
    _, _, k_len, _ = key.shape

    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size

    # If Q fits in one chunk, call directly
    if q_len <= chunk_size:
        return xattn_estimate_chunked(
            query, key,
            q_start_pos=0,
            block_size=block_size,
            stride=stride,
            threshold=threshold,
            use_triton=True,
            chunk_size=chunk_size,
        )

    # External chunking: split Q and call for each chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f"  External chunking: {num_q_chunks} chunks")

    # Accumulators covering the full (q_blocks, k_blocks) grid; each chunk
    # writes its own rows (and the K prefix it can see) below.
    combined_attn_sum = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=query.dtype, device=query.device
    )
    combined_mask = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=torch.bool, device=query.device
    )

    q_block_offset = 0
    for q_chunk_idx in range(num_q_chunks):
        q_chunk_start = q_chunk_idx * chunk_size
        q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)

        q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]

        # For causal attention, K accumulates up to current Q position
        # q_start_pos=0 means Q starts at position 0 in the full sequence
        # K is [0, q_chunk_end) for causal attention
        k_end = q_chunk_end
        k_chunk = key[:, :, :k_end, :]

        attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
            q_chunk, k_chunk,
            q_start_pos=q_chunk_start,
            block_size=block_size,
            stride=stride,
            threshold=threshold,
            use_triton=True,
            chunk_size=chunk_size,
        )

        # Place chunk results into combined output
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
        combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
        q_block_offset += chunk_q_blocks

    return combined_attn_sum, combined_mask


def test_single_seq_len(seq_len,
                        num_heads=32, head_dim=128):
    """Test a single sequence length."""
    print(f"\nTesting seq_len={seq_len}")
    print("=" * 60)

    # Generate random Q/K
    # NOTE(review): no manual seed is set here despite the header comment about
    # reproducibility — results vary run-to-run; confirm whether that is intended.
    query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)

    # Run standard xattn_estimate
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
            causal=True,
        )
        density_std = mask_std.float().mean().item()
        print(f"  mask shape: {mask_std.shape}, density: {density_std:.4f}")
    except Exception as e:
        print(f"  ERROR: {e}")
        traceback.print_exc()
        return False

    # Run chunked xattn_estimate with EXTERNAL chunking
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        density_chunked = mask_chunked.float().mean().item()
        print(f"  mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
    except Exception as e:
        print(f"  ERROR: {e}")
        traceback.print_exc()
        return False

    # Compare results
    print("[3] Comparing results...")
    chunked_q_blocks = mask_chunked.shape[2]
    chunked_k_blocks = mask_chunked.shape[3]

    # Extract comparable region from standard mask (the standard estimate may
    # carry extra padded blocks beyond the chunked result's extent)
    mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]

    # Compare masks
    masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")

    # Compare attn_sums (informational only — does not affect pass/fail)
    attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f"  Attn sum max diff: {attn_diff:.6f}")
    else:
        print(f"  Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")

    # Clean up GPU memory
    del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
    torch.cuda.empty_cache()

    return masks_match


# ============================================================
# Main Test
# ============================================================

if __name__ == "__main__":
    print("XAttention Chunked vs Standard Test")
    print("=" * 60)
    print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
    print(f"External chunk_size={CHUNK_SIZE}")
    print()

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA not available!")
        sys.exit(1)

    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print("✓ xattn_estimate imported")
    print("✓ xattn_estimate_chunked imported")

    # Run tests
    all_passed = True
    results = []

    for seq_len in TEST_SEQ_LENS:
        passed = test_single_seq_len(seq_len)
        chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
        results.append((seq_len, chunks, passed))
        if not passed:
            all_passed = False

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for seq_len, chunks, passed in results:
        status = "PASSED" if passed else "FAILED"
        print(f"  seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")

    print("=" * 60)
    if all_passed:
        print("ALL TESTS PASSED!")
        sys.exit(0)
    else:
        print("SOME TESTS FAILED!")
        sys.exit(1)