diff --git a/CLAUDE.md b/CLAUDE.md index ef08768..b0a96b2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -61,6 +61,8 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py | [`docs/gpu_only_performance_issue.md`](docs/gpu_only_performance_issue.md) | GPU-only mode slower than offload due to PagedAttention scatter overhead, optimization proposals | | [`docs/offload_accuracy_issue.md`](docs/offload_accuracy_issue.md) | **BUG**: CPU offload mode 66% accuracy vs 100% non-offload on RULER NIAH benchmark | | [`docs/64k_memory_analysis.md`](docs/64k_memory_analysis.md) | 64k inference memory analysis: GPU-only vs offload, OOM root cause (fragmentation), RTX 3090 limitations | +| [`docs/xattention_analysis.md`](docs/xattention_analysis.md) | XAttention algorithm analysis: chunked estimation, block sparse attention, integration design | +| [`docs/development_notes.md`](docs/development_notes.md) | Development notes and scratchpad for ongoing work | ## Configuration diff --git a/DEBUG_SUMMARY.md b/DEBUG_SUMMARY.md deleted file mode 100644 index ac7dfc4..0000000 --- a/DEBUG_SUMMARY.md +++ /dev/null @@ -1,103 +0,0 @@ -# Chunked Prefill Bug Debug Summary - -## Problem -`test_needle.py --enable-offload --input-len 8192` fails with garbage output. - -The model generates completely wrong tokens instead of the expected "7492". - -## Investigation Progress - -### 1. Stream Synchronization Fix (Completed) -- Replaced Triton `store_kvcache` kernel with pure PyTorch operations -- Moved `store_kvcache` to `compute_stream` in chunked prefill mode -- Added sync: `compute_stream.wait_event(offload_done)` after per-layer offload -- Added sync: `default_stream.wait_stream(compute_stream)` before return - -### 2. 
KV Cache Alignment Verification (Completed) -Created alignment tests to compare K/V tensors between torch reference and nanovllm: - -**RoPE Alignment:** -- RoPE implementations match perfectly (max_diff=0.002, cosine ~1.0) -- Confirmed RoPE is NOT the cause of the bug - -**K/V Cache Alignment (Chunk 0):** -- Cosine similarity: ~1.0 for all layers -- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision) -- Mean diff: < 0.001 -- **Conclusion: K/V cache offload is working correctly** - -### 3. Layer Output Divergence Analysis (Completed) -Created per-chunk layer output comparison: - -**Chunk 0 (tokens 0-4096):** -- All layers pass with excellent cosine similarity (0.999+) -- Max diff grows in later layers but within acceptable range - -**Chunk 1 (tokens 4096-8192):** -- Layers 0-19: OK (cosine ~1.0) -- Layers 20-27: Diverge (cosine 0.83-0.96, max_diff up to 114) -- Divergence correlates with later transformer layers - -### 4. Critical Discovery: Single-Chunk Offload Also Fails -**Key finding:** Even with input_len=2048 (single chunk, no chunked attention), the model produces garbage output with CPU offload enabled. - -``` -# Without offload: PASSES -python tests/test_needle.py --input-len 2048 -# Output: "7492" (correct) - -# With offload: FAILS -python tests/test_needle.py --enable-offload --input-len 2048 -# Output: "The Ble White Th G Lopsiswin..." (garbage) -``` - -**This proves the bug is NOT in:** -- Chunked attention logic (merge_attention_outputs) -- Multi-chunk KV loading -- Ring buffer pipeline - -**The bug IS in:** -- The decode path when CPU offload is enabled -- How prefilled KV is loaded/used during decode - -### 5. Decode Path Analysis (In Progress) -The decode path in CPU offload mode: -1. Prefill writes KV to GPU, offloads to CPU -2. Decode loads prefilled KV from CPU via `_decode_ring_buffer_pipeline` -3. Attend to prefilled KV + accumulated decode tokens -4. 
Merge results - -**Observations:** -- `prefilled_blocks` set is empty after decode (should contain block IDs) -- CPU cache has valid data (reasonable mean/std values) -- Decode buffer has zeros (decode tokens not being stored correctly?) - -## Current Status - -### Working -- Stream synchronization fixes -- K/V cache offload to CPU (verified alignment) -- RoPE implementation -- Chunked prefill attention for first chunk - -### Not Working -- Decode with CPU offload (even for single-chunk inputs) -- Multi-chunk attention (divergence in later layers for chunk 1) - -## Next Steps -1. Debug why `prefilled_blocks` is empty after decode -2. Check if decode path correctly loads KV from CPU -3. Verify decode buffer is being written correctly -4. Compare decode attention outputs between offload and non-offload modes - -## Key Files -- `nanovllm/layers/attention.py` - Main attention implementation with chunked paths -- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine -- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks` -- `nanovllm/engine/model_runner.py` - Prefill/decode orchestration - -## Hypothesis -The decode path fails because: -1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return empty -2. OR the decode attention is not correctly loading/using the prefilled KV from CPU -3. 
OR there's a stream synchronization issue specific to decode path diff --git a/notes.md b/docs/development_notes.md similarity index 100% rename from notes.md rename to docs/development_notes.md diff --git a/docs/xattention_analysis.md b/docs/xattention_analysis.md new file mode 100644 index 0000000..165b731 --- /dev/null +++ b/docs/xattention_analysis.md @@ -0,0 +1,597 @@ +# COMPASS XAttention Implementation Analysis + +**Analysis Date**: 2026-01-14 +**Researcher**: Claude Code Agent +**Source**: `/home/zijie/Code/COMPASS/compass/src/` + +--- + +## Executive Summary + +COMPASS XAttention is a **block sparse attention** implementation that uses: +1. **Approximation phase** (`xattn_estimate`) to compute attention importance and select blocks +2. **Computation phase** (`Xattention_prefill`) to compute sparse attention using `block_sparse_attn_func` +3. **Triton kernels** for efficient block-wise GEMM and softmax operations + +**Key Integration Constraint**: Requires `block_sparse_attn_func` from flash-attention library, which is a **C++ CUDA extension** that must be compiled separately. + +--- + +## 1. 
Function: `xattn_estimate()` + +**Purpose**: Estimate attention importance and select which blocks to compute + +### Input Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `query_states` | Tensor | - | Shape: `(batch, num_heads, q_len, head_dim)` | +| `key_states` | Tensor | - | Shape: `(batch, num_kv_heads, k_len, head_dim)` | +| `block_size` | int | - | Size of attention blocks (typically 128) | +| `stride` | int | - | Downsampling stride for approximation | +| `norm` | float | 1 | Normalization factor for attention scaling | +| `softmax` | bool | True | Whether to apply softmax in estimation | +| `threshold` | float | 0.9 | Block selection threshold (0-1) | +| `chunk_size` | int | 16384 | Processing chunk size | +| `select_mode` | str | "inverse" | Pattern selection mode | +| `use_triton` | bool | True | Use Triton kernels (requires SM 80+) | +| `causal` | bool | True | Apply causal masking | +| `kdb` | int | 1 | Key downsampling factor | +| `keep_sink` | bool | False | Always attend to first token | +| `keep_recent` | bool | False | Always attend to recent tokens | + +### Output + +```python +returns: (attn_sums, simple_masks) + attn_sums: Tensor[float32] + Shape: (batch, num_heads, num_q_blocks, num_k_blocks_per_chunk) + Contains aggregated attention weights per block + + simple_masks: Tensor[bool] + Shape: (batch, num_heads, num_q_blocks, num_k_blocks) + Boolean mask indicating which blocks to compute +``` + +### Algorithm + +#### Step 1: Padding and Chunking +```python +# Pad sequences to chunk_size boundaries +k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len +q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len + +# Compute number of blocks and chunks +k_chunk_num = (k_len + k_num_to_pad) // chunk_size +k_block_num = (k_len + k_num_to_pad) // block_size +q_chunk_num = (q_len + q_num_to_pad) // chunk_size +q_block_num = (q_len + q_num_to_pad) // 
block_size +``` + +#### Step 2: Pattern Selection (stride-based downsampling) + +**Purpose**: Reduce computation by `stride` factor using patterned selection + +**Modes**: +1. **`"inverse"`** (default): Inverse stride pattern + ```python + # Key: regular stride [0, stride, 2*stride, ...] + # Query: reverse stride [(stride-1), (stride-1-stride), ...] + reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)]) + reshaped_query = torch.cat([query_states[:, :, (stride-1-q)::stride*kdb, :] for q in range(stride)]) + ``` + +2. **`"slash"`**: Slash pattern (diagonal) + ```python + # Both use regular stride + reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)]) + reshaped_query = torch.cat([query_states[:, :, q::stride, :] for q in range(stride)]) + ``` + +3. **`"random"`**: Random permutation +4. **`"double"`, `"triple"`**: Data augmentation modes + +#### Step 3: Chunk-wise Attention Estimation + +For each query chunk: + +**If `use_triton=True`** (fast path): +```python +# Triton kernel 1: Compute attention scores with fused reshape +attn_weights_slice = flat_group_gemm_fuse_reshape( + query_chunk, key_states, stride, + chunk_start, chunk_end, is_causal=causal +) + +# Triton kernel 2: Softmax + block aggregation +attn_sum = softmax_fuse_block_sum( + attn_weights_slice, reshaped_block_size, segment_size, + chunk_start, chunk_end, real_q_len, scale, is_causal +) +``` + +**If `use_triton=False`** (PyTorch fallback): +```python +# Standard matrix multiplication +attn_weights_slice = torch.matmul(chunked_query, reshaped_key.transpose(2, 3)) + +# Scale and apply causal mask +attn_weights_slice = attn_weights_slice / sqrt(head_dim) / stride / norm +attn_weights_slice = attn_weights_slice + causal_mask + +# Softmax +attn_weights_slice = F.softmax(attn_weights_slice, dim=-1) + +# Aggregate to block level +attn_sum = attn_weights_slice.view( + batch, heads, num_blocks_per_chunk, block_size//kdb, -1, block_size 
+).sum(dim=-1).sum(dim=-2) +``` + +#### Step 4: Block Selection + +```python +# Select blocks based on threshold +simple_mask = find_blocks_chunked( + attn_sum, + current_index, # Starting block index + threshold, # 0.9 = select blocks covering 90% of attention mass + None, # or num_to_choose for top-k selection + decoding=False, + mode="prefill", + causal=True +) +``` + +**Selection Algorithm** (`find_blocks_chunked`): +1. Sort blocks by attention weight (descending) +2. Compute cumulative sum +3. Select blocks until `cumulative_sum >= total_sum * threshold` +4. Enforce causal constraints (no future blocks) +5. Always include sink token (first block) if `keep_sink=True` +6. Always include diagonal blocks if `keep_recent=True` + +--- + +## 2. Function: `Xattention_prefill()` + +**Purpose**: Compute sparse attention using estimated block mask + +### Input Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `query_states` | Tensor | - | `(batch, num_heads, q_len, head_dim)` | +| `key_states` | Tensor | - | `(batch, num_heads, k_len, head_dim)` | +| `value_states` | Tensor | - | `(batch, num_heads, k_len, head_dim)` | +| `stride` | int | - | Downsampling stride for estimation | +| `norm` | float | 1 | Normalization factor | +| `threshold` | float | 0.8 | Block selection threshold | +| `block_size` | int | 128 | **MUST be 128** (hardcoded requirement) | +| `use_triton` | bool | True | Use Triton kernels in estimation | +| `causal` | bool | True | Apply causal masking | +| `kdb` | int | 1 | Key downsampling factor | +| `chunk_size` | int | None | Auto-computed if None | +| `keep_sink` | bool | False | Always attend to first token | +| `keep_recent` | bool | False | Always attend to recent tokens | + +### Output + +```python +returns: attn_output + attn_output: Tensor + Shape: (batch, num_heads, q_len, head_dim) + Sparse attention output +``` + +### Algorithm Flow + +#### Step 1: Auto-compute chunk_size +```python 
+if chunk_size is None:
+    chunk_size = int(max(
+        min(
+            max(2048, 1 << (k_len - 1).bit_length()),  # Round to power of 2
+            128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()),  # Memory constraint
+        ),
+        2048,  # Minimum
+    ))
+```
+
+**Example** (per the formula above, the memory constraint is `2^28 / next_pow2(k_len)`):
+- `k_len=8192` → `chunk_size=8192`
+- `k_len=32768` → `chunk_size=8192`
+- `k_len=65536` → `chunk_size=4096`
+
+#### Step 2: Estimate attention and select blocks
+```python
+attn_sums, approx_simple_mask = xattn_estimate(
+    query_states, key_states,
+    block_size=block_size, stride=stride, norm=norm,
+    threshold=threshold, select_mode="inverse",
+    use_triton=use_triton, causal=causal,
+    chunk_size=chunk_size, kdb=kdb,
+    keep_sink=keep_sink, keep_recent=keep_recent
+)
+```
+
+#### Step 3: Prepare inputs for block_sparse_attn_func
+```python
+# Hard constraints
+assert block_size == 128
+assert batch_size == 1
+
+# Reshape to (seq_len, num_heads, head_dim)
+query_states = query_states.transpose(1, 2).view(q_len, num_heads, head_dim)
+key_states = key_states.transpose(1, 2).view(k_len, num_heads, head_dim)
+value_states = value_states.transpose(1, 2).view(k_len, num_heads, head_dim)
+
+# Cumulative sequence lengths
+q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
+k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)
+
+# Head mask type (all heads use mask)
+head_mask_type = torch.tensor([1 for _ in range(num_heads)], dtype=torch.int32)
+```
+
+#### Step 4: Call block_sparse_attn_func
+```python
+attn_output = block_sparse_attn_func(
+    query_states,    # (q_len, num_heads, head_dim)
+    key_states,      # (k_len, num_heads, head_dim)
+    value_states,    # (k_len, num_heads, head_dim)
+    q_cu_seq_lens,   # [0, q_len]
+    k_cu_seq_lens,   # [0, k_len]
+    head_mask_type,  # [1, 1, ..., 1]
+    None,            # No custom layout
+    approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(),  # Block mask
+    q_len,
+    k_len,
+    p_dropout=0.0,
+    deterministic=True,
+    is_causal=causal
+)
+```
+
+#### Step 5: Reshape 
output +```python +attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2) +# Output shape: (batch, num_heads, q_len, head_dim) +``` + +--- + +## 3. Triton Kernel Dependencies + +### Kernel 1: `flat_group_gemm_fuse_reshape_kernel` + +**Purpose**: Compute QK^T with stride-based reshaping + +**Key Features**: +- Loads `stride` keys and queries at once +- Fused strided access pattern +- Causal masking support +- Block size auto-selection based on GPU memory + +**Block Size Selection**: +```python +# RTX 3090 (<30GB): BLOCK_M=64, BLOCK_N=64 +# A100/H100 (>=30GB): BLOCK_M=128, BLOCK_N=128 +``` + +**Signature**: +```python +flat_group_gemm_fuse_reshape( + query_states, # (batch, heads, q_len, head_dim) + key_states, # (batch, heads, k_len, head_dim) + stride, # Downsampling factor + chunk_start, # Start position in keys + chunk_end, # End position in keys + is_causal=True +) +# Returns: (batch, heads, q_len//stride, k_len//stride) +``` + +### Kernel 2: `softmax_fuse_block_sum_kernel_causal` / `_non_causal` + +**Purpose**: Online softmax with block aggregation + +**Algorithm**: +1. **Forward pass** (compute m_i, l_i): + ``` + m_i = max(m_i, m_local) + alpha = exp(m_i - m_new) + l_i = l_i * alpha + sum(exp(X - m_new)) + ``` +2. 
**Backward pass** (compute softmax with scaling): + ``` + softmax = exp(X - m_i) / l_i + aggregate to blocks: sum(softmax) over block_size + ``` + +**Key Features**: +- Single-pass softmax (no materializing full attention matrix) +- Causal masking integrated +- Outputs block-level sums directly + +**Signature**: +```python +softmax_fuse_block_sum( + attn_weights_slice, # (batch, heads, q_len, k_len) + reshaped_block_size, # Block size (128//stride) + segment_size, # Processing segment (min(4096, block_size)) + chunk_start, # Start position + chunk_end, # End position + real_q_len, # Actual query length (before padding) + scale, # 1.4426950408889634 / sqrt(head_dim) / stride / norm + is_causal=True +) +# Returns: (batch, heads, q_len//block_size, k_len//block_size) +``` + +--- + +## 4. Key Parameters and Their Meanings + +### Critical Parameters + +| Parameter | Meaning | Typical Value | Impact | +|-----------|---------|---------------|--------| +| `block_size` | Block granularity | 128 | **Fixed at 128**, affects mask granularity | +| `stride` | Downsampling factor | 4-16 | Higher = faster but less accurate | +| `threshold` | Sparsity level | 0.8-0.9 | Higher = denser mask, more computation | +| `chunk_size` | Processing chunk | 16384 | Affects memory and efficiency | +| `kdb` | Key downsampling boost | 1 | Experimental, use 1 | +| `norm` | Scaling factor | 1.0 | Attention temperature control | + +### Trade-offs + +**Stride (`stride`)**: +- `stride=1`: No approximation, same as dense attention +- `stride=4`: 4x faster estimation, good accuracy +- `stride=8`: 8x faster, moderate accuracy loss +- `stride=16`: 16x faster, significant accuracy loss + +**Threshold (`threshold`)**: +- `threshold=0.8`: Select blocks covering 80% of attention mass (~20% sparsity) +- `threshold=0.9`: Select blocks covering 90% of attention mass (~10% sparsity) +- `threshold=0.95`: Very dense, only prunes ~5% of blocks + +--- + +## 5. Dependencies + +### Required Libraries + +1. 
**`block_sparse_attn`** (CRITICAL) + - Source: `/home/zijie/Code/COMPASS/3rdparty/flash-attention/` + - Function: `block_sparse_attn_func` + - Type: **C++ CUDA extension** + - Build: Requires compilation with `torch.utils.cpp_extension` + +2. **Triton** (optional but recommended) + - Required for: `use_triton=True` + - GPU requirement: SM 80+ (A100, RTX 3090, H100, etc.) + - Check: `torch.cuda.get_device_properties().major >= 8` + +3. **PyTorch** + - Version: Compatible with flash-attention + - Features: F.pad, matmul, softmax, view, transpose + +### Dependency Tree + +``` +Xattention_prefill +├── xattn_estimate +│ ├── flat_group_gemm_fuse_reshape (Triton) +│ ├── softmax_fuse_block_sum (Triton) +│ └── find_blocks_chunked (PyTorch) +└── block_sparse_attn_func (C++ CUDA) +``` + +--- + +## 6. Integration Issues for nano-vllm + +### Critical Issue 1: `block_sparse_attn_func` Dependency + +**Problem**: `block_sparse_attn_func` is a **C++ CUDA extension** that must be compiled from flash-attention source. + +**Options**: +1. **Compile flash-attention with block sparse support** + ```bash + cd /home/zijie/Code/COMPASS/3rdparty/flash-attention + python setup.py install + ``` + - Risk: May conflict with existing flash-attention installation + - Complexity: High (C++ compilation) + +2. **Replace with FlashInfer block sparse** + - FlashInfer is already a dependency + - Has similar block sparse attention + - Need to adapt interface + +3. 
**Custom CUDA kernel** + - Implement simplified block sparse attention + - High development cost + - Maintenance burden + +### Critical Issue 2: Hard-coded Constraints + +```python +assert block_size == 128 # Line 358 +assert batch_size == 1 # Line 359 +``` + +**Impact**: +- Cannot process multiple sequences in one batch +- Fixed block size limits flexibility +- Must work around these constraints + +### Critical Issue 3: Triton GPU Requirement + +```python +props = torch.cuda.get_device_properties(torch.cuda.current_device()) +if props.major < 8: + use_triton = False +``` + +**Impact**: +- Triton kernels only work on SM 80+ (A100, RTX 3090, H100) +- Older GPUs (V100, T4, RTX 2080) fall back to slow PyTorch implementation +- RTX 3090 works but uses smaller block sizes (64 vs 128) + +### Issue 4: Memory Layout + +**XAttention expects**: +```python +query_states: (batch, num_heads, q_len, head_dim) +``` + +**nano-vllm uses**: +```python +query_states: (num_heads, total_tokens, head_dim) # Flattened batch +``` + +**Required**: Transpose and reshape before/after calling XAttention + +### Issue 5: Chunking Incompatibility + +**XAttention**: Processes in fixed-size chunks (e.g., 16384 tokens) +- Requires padding to chunk boundaries +- Adds overhead for short sequences + +**nano-vllm**: Processes variable-length requests +- No padding requirement +- Dynamic batch sizing + +--- + +## 7. Integration Strategy + +### Recommended Approach: **Wrapper with FlashInfer** + +1. **Keep `xattn_estimate`** (pure PyTorch + Triton) + - No external dependencies + - Computes block mask + +2. **Replace `block_sparse_attn_func` with FlashInfer** + - FlashInfer: `flashinfer.single_prefill_with_kv_cache` + - Similar API, already compiled + - Supports block sparse + +3. **Adapt mask format** + - XAttention: `(batch, heads, q_blocks, k_blocks)` boolean mask + - FlashInfer: `(num_qo, num_kv)` boolean mask or custom format + +4. 
**Handle constraints**
+   - Enforce `batch_size=1` by processing one request at a time
+   - Keep `block_size=128` as requirement
+
+### Alternative: **Pure PyTorch Implementation**
+
+1. Extract estimation algorithm
+2. Implement sparse attention using PyTorch operations
+3. Use FlashInfer for final computation
+4. No Triton dependency
+
+---
+
+## 8. Code Example: Adaptation
+
+```python
+def xattention_prefill_adapted(
+    query_states,   # (num_heads, q_len, head_dim)
+    key_states,     # (num_heads, k_len, head_dim)
+    value_states,   # (num_heads, k_len, head_dim)
+    stride=4,
+    threshold=0.9,
+    block_size=128,
+    causal=True,
+):
+    # Step 1: Add batch dimension
+    q = query_states.unsqueeze(0)  # (1, heads, q_len, dim)
+    k = key_states.unsqueeze(0)
+    v = value_states.unsqueeze(0)
+
+    # Step 2: Estimate mask (no external dependency)
+    _, block_mask = xattn_estimate(
+        q, k,
+        block_size=block_size,
+        stride=stride,
+        threshold=threshold,
+        use_triton=True,
+        causal=causal,
+    )
+    # block_mask: (1, heads, q_blocks, k_blocks)
+
+    # Step 3: Convert block mask to token mask
+    q_blocks, k_blocks = block_mask.shape[-2:]
+    token_mask = block_mask.repeat_interleave(block_size, dim=-2)
+    token_mask = token_mask.repeat_interleave(block_size, dim=-1)
+    token_mask = token_mask[:, :, :q.size(2), :k.size(2)]  # Trim padding
+
+    # Step 4: Use FlashInfer with mask
+    # FlashInfer expects (seq_len, num_heads, head_dim), so transpose from
+    # (num_heads, seq_len, head_dim). Its custom mask is shared across heads,
+    # so take the union over heads (denser than per-head masks, but safe).
+    from flashinfer import single_prefill_with_kv_cache
+    output = single_prefill_with_kv_cache(
+        q.squeeze(0).transpose(0, 1).contiguous(),
+        k.squeeze(0).transpose(0, 1).contiguous(),
+        v.squeeze(0).transpose(0, 1).contiguous(),
+        custom_mask=token_mask.squeeze(0).any(dim=0),
+    )
+
+    return output.transpose(0, 1)  # (num_heads, q_len, head_dim)
+```
+
+---
+
+## 9. Summary of Findings
+
+### Advantages
+
+1. **Accurate approximation**: Pattern-based stride selection preserves attention patterns
+2. **Flexible sparsity**: Threshold-based control over computation
+3. **GPU optimization**: Triton kernels for estimation phase
+4. **Proven in practice**: Used in COMPASS system
+
+### Challenges
+
+1. 
**Hard dependency**: `block_sparse_attn_func` requires C++ compilation +2. **Rigid constraints**: `block_size=128`, `batch_size=1` +3. **GPU-specific**: Triton only on SM 80+ +4. **Memory layout mismatch**: Requires reshape/transpose +5. **Chunking overhead**: Padding to chunk boundaries + +### Integration Complexity + +| Component | Complexity | Risk | +|-----------|------------|------| +| `xattn_estimate` | Medium | Low (PyTorch + Triton) | +| `block_sparse_attn_func` | High | **Critical** (C++ dependency) | +| Interface adaptation | Low | Low (reshape) | +| Constraint handling | Medium | Medium (workarounds) | + +**Overall Integration Risk**: **HIGH** (due to C++ dependency) + +--- + +## 10. Next Steps + +1. **Evaluate FlashInfer compatibility** + - Can FlashInfer replace `block_sparse_attn_func`? + - What mask format does it expect? + +2. **Prototype estimation phase** + - Extract `xattn_estimate` function + - Test with nano-vllm inputs + - Validate mask quality + +3. **Benchmark Triton kernels** + - Compare Triton vs PyTorch estimation + - Measure speedup on RTX 3090 + - Profile memory usage + +4. **Design interface** + - Define nano-vllm sparse attention API + - Specify mask format + - Plan integration points
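
---

## 11. Appendix: Threshold Selection Sketch

The threshold rule described in Section 1, Step 4 (sort blocks by attention mass, accumulate, and keep blocks until the target fraction is covered) can be sketched in a few lines of PyTorch. `select_blocks_by_threshold` below is a hypothetical illustrative helper, not COMPASS's actual `find_blocks_chunked`, which additionally handles chunk offsets, decoding mode, and the `keep_sink`/`keep_recent` options; it also assumes aligned square q/k block grids for the causal mask.

```python
import torch

def select_blocks_by_threshold(attn_sums, threshold, causal=True):
    """Pick the smallest set of key blocks per query block whose summed
    attention mass reaches `threshold` of the row total (Section 1, Step 4).

    attn_sums: (batch, heads, num_q_blocks, num_k_blocks) block-level weights
    returns:   boolean mask of the same shape
    """
    nq, nk = attn_sums.shape[-2:]
    if causal:
        # Zero out blocks past the diagonal (assumes aligned q/k block grids)
        causal_mask = torch.tril(
            torch.ones(nq, nk, dtype=torch.bool, device=attn_sums.device)
        )
        attn_sums = attn_sums.masked_fill(~causal_mask, 0.0)
    # Rank blocks by weight (descending) and accumulate their mass
    sorted_vals, sorted_idx = attn_sums.sort(dim=-1, descending=True)
    cumsum = sorted_vals.cumsum(dim=-1)
    total = attn_sums.sum(dim=-1, keepdim=True)
    # Keep a block if the mass accumulated *before* it is still below the
    # target, so the block that crosses the threshold is included
    needed = (cumsum - sorted_vals) < threshold * total
    mask = torch.zeros_like(attn_sums, dtype=torch.bool).scatter(-1, sorted_idx, needed)
    return mask & causal_mask if causal else mask
```

With `threshold=0.8` and block weights `[0.6, 0.3, 0.1, 0.0]`, the two heaviest blocks (0.9 cumulative mass) are kept and the rest are pruned.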