docs: reorganize documentation files

- Move notes.md to docs/development_notes.md
- Move Xattention_analysis.md to docs/xattention_analysis.md
- Delete DEBUG_SUMMARY.md (no longer needed)
- Update CLAUDE.md with documentation index entries

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-14 10:08:41 +08:00
parent ac1ccbceaa
commit 57f4e9c6e6
4 changed files with 599 additions and 103 deletions

324
docs/development_notes.md Normal file
View File

@@ -0,0 +1,324 @@
# Notes: Sparsity Integration into Layerwise Offload
## Current Architecture Analysis
### GPU-Only Path vs Offload Path
| Aspect | GPU-Only | Layerwise Offload |
|--------|----------|-------------------|
| KV Storage | GPU blocks (paged) | CPU pinned + GPU ring buffer |
| Prefill | All layers → then attention | Per-layer: attention → offload |
| Decode | FlashAttn with block table | Ring buffer H2D → FlashAttn |
| Sparse Support | MInference via `attention.py` | Not integrated |
### MInference Flow (GPU-Only)
```
attention.py:101-105:
if context.sparse_prefill_policy is not None:
o = context.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
minference.py:sparse_prefill_attention():
1. estimate_pattern(q, k, layer_id) -> vertical_indices, slash_indices
2. _triton_mixed_sparse_attention(q, k, v, indices)
3. return output
```
### Quest Flow (GPU Block Mode)
```
hybrid_manager.py (if using CPU offload with Quest):
select_blocks(available_blocks, ctx) -> selected block IDs
-> load selected blocks to GPU
-> standard FlashAttn with loaded blocks
```
### Layerwise Offload Prefill Flow
```
model_runner.py:run_layerwise_offload_prefill():
for layer_id in range(num_layers):
# QKV projection
q, k, v = qkv_proj(hidden_ln)
# RoPE
q, k = rotary_emb(positions, q, k)
# FULL attention (no sparsity!)
attn_output = flash_attn_varlen_func(q, k, v, ...)
# MLP
hidden_states = mlp(attn_out + residual)
# Sync offload ALL k, v to CPU
for block_id in cpu_block_ids:
k_cache_cpu[layer_id, block_id].copy_(k[start:end])
v_cache_cpu[layer_id, block_id].copy_(v[start:end])
```
### Layerwise Offload Decode Flow
```
model_runner.py:run_layerwise_offload_decode():
# Preload first N layers to ring buffer
for i in range(num_buffers):
offload_engine.load_layer_kv_to_buffer(i, i, cpu_block_table, valid_tokens)
for layer_id in range(num_layers):
current_buffer = layer_id % num_buffers
# Wait for buffer load
offload_engine.wait_buffer_load(current_buffer)
# Get prefilled KV from ring buffer (ALL blocks loaded)
k_prefill, v_prefill = offload_engine.get_buffer_kv(current_buffer, total_prefill_tokens)
# QKV for new token
q, k_new, v_new = qkv_proj(hidden_ln)
# Concat and full attention
k_full = torch.cat([k_prefill, k_decode_prev, k_new])
attn_output = flash_attn_varlen_func(q, k_full, v_full, ...)
# Start loading next layer
offload_engine.load_layer_kv_to_buffer(current_buffer, layer_id + num_buffers, ...)
```
## Integration Points
### 1. Prefill Sparse Integration Point
**Location:** `model_runner.py:535-543`
**Current:**
```python
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
```
**After Integration:**
```python
if self.sparse_policy and self.sparse_policy.supports_offload_prefill:
attn_output, k_sparse, v_sparse = self.sparse_policy.offload_prefill_attention(
q, k, v, layer_id
)
k_to_offload = k_sparse if k_sparse is not None else k
v_to_offload = v_sparse if v_sparse is not None else v
else:
attn_output = flash_attn_varlen_func(q, k, v, ...)
k_to_offload, v_to_offload = k, v
```
### 2. Decode Sparse Integration Point
**Location:** `model_runner.py:636-637` and `model_runner.py:704-706`
**Current (preload):**
```python
for i in range(num_preload):
offload_engine.load_layer_kv_to_buffer(
i, i, cpu_block_table, valid_tokens_per_block
)
```
**After Integration:**
```python
for i in range(num_preload):
layer_to_load = i
if self.sparse_policy and self.sparse_policy.supports_offload_decode:
# Prepare q for this layer (need to compute ahead)
# OR: use previous layer's pattern as estimate
selected_blocks = self.sparse_policy.select_offload_blocks(
None, # q not available yet at preload
layer_to_load,
cpu_block_table,
valid_tokens_per_block
)
else:
selected_blocks = cpu_block_table
offload_engine.load_sparse_layer_kv_to_buffer(
i, layer_to_load, selected_blocks, valid_tokens_per_block
)
```
**Challenge:** Q is not available during preload phase!
**Solutions:**
1. Skip sparse preload, only sparse for non-preloaded layers
2. Use previous decode step's pattern as estimate
3. Add preload hook to sparse policy
### 3. Offload Engine Extension
**New Method in OffloadEngine:**
```python
def load_sparse_layer_kv_to_buffer(
self,
buffer_idx: int,
layer_id: int,
selected_cpu_block_ids: List[int],
original_valid_tokens: List[int],
) -> int:
"""
Load only selected blocks from CPU to buffer.
Returns:
Total tokens loaded (may be less than full sequence)
"""
stream = self.layer_load_streams[buffer_idx]
with torch.cuda.stream(stream):
stream.wait_event(self.buffer_compute_done_events[buffer_idx])
# Build mapping: original block -> selected position
offset = 0
for i, cpu_block_id in enumerate(selected_cpu_block_ids):
# Find original index to get valid tokens
valid_tokens = original_valid_tokens[i] # Need mapping
self.layer_k_cache[buffer_idx, offset:offset+valid_tokens].copy_(
self.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens],
non_blocking=True
)
# ... v_cache same
offset += valid_tokens
self.buffer_load_events[buffer_idx].record(stream)
return offset # Caller needs to know actual loaded tokens
```
## Metadata Flow for Quest
### During Prefill Offload
**Current:** No metadata collection in offload path
**Required:** Call `on_prefill_offload()` for each block
```python
# In run_layerwise_offload_prefill()
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * block_size
end = min(start + block_size, total_tokens)
actual_size = end - start
# BEFORE offload: update Quest metadata
if self.sparse_policy and hasattr(self.sparse_policy, 'on_prefill_offload'):
self.sparse_policy.on_prefill_offload(
cpu_block_id, layer_id, k[start:end], actual_size
)
# Offload
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
```
### Quest Metadata Shape
```python
# BlockMetadataManager
key_min: [num_blocks, num_layers, num_kv_heads, head_dim] # Min key per block per layer
key_max: [num_blocks, num_layers, num_kv_heads, head_dim] # Max key per block per layer
```
**Memory:** 2 * num_blocks * num_layers * kv_heads * head_dim * 2 bytes
- Example: 1000 blocks * 28 layers * 4 heads * 128 dim * 2 * 2 = ~57 MB
## Performance Considerations
### MInference Prefill Overhead
| Operation | Time (64K seq) |
|-----------|----------------|
| Pattern estimation (last-64) | ~5ms |
| Triton sparse attention | ~80ms |
| Full FlashAttention | ~100ms |
| **Net Speedup** | ~15-20% |
### Quest Decode Overhead
| Operation | Time |
|-----------|------|
| Block scoring (GPU metadata) | ~0.1ms |
| Top-K selection | ~0.05ms |
| Sparse H2D load (8 blocks) | ~2ms |
| Full H2D load (100 blocks) | ~20ms |
| **Net Speedup** | ~10x H2D |
### Memory Trade-offs
| Mode | GPU Memory | CPU Memory | H2D Bandwidth |
|------|------------|------------|---------------|
| Full offload | Ring buffer | Full KV | High |
| Sparse offload | Ring buffer | Full KV | Low (subset) |
| Aggressive sparse | Ring buffer | Sparse KV | Very low |
## Edge Cases
### 1. Short Sequences (< sparse threshold)
```python
if total_tokens < sparse_threshold:
# Fall back to full attention
use_sparse = False
```
### 2. First Decode Step (no previous Q)
Quest can't score blocks without Q. Options:
- Use average embedding as proxy
- Load all blocks for first step
- Use prefill pattern as estimate
### 3. Variable Sequence Lengths in Batch
Layerwise offload currently only supports batch_size=1:
```python
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
```
Sparse integration should maintain this constraint.
### 4. Ring Buffer vs Sparse Load Mismatch
Ring buffer assumes fixed `total_prefill_tokens`:
```python
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, total_prefill_tokens)
```
Sparse load has variable token count. Need:
```python
# Track actual loaded tokens per buffer
loaded_tokens[buffer_idx] = sparse_load_count
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, loaded_tokens[buffer_idx])
```
## Testing Strategy
### Unit Tests
1. `test_sparse_policy_interface.py` - Verify new interface methods
2. `test_minference_offload.py` - MInference in offload mode
3. `test_quest_offload.py` - Quest block selection in offload mode
### Integration Tests
1. `test_offload_sparse_e2e.py` - Full prefill+decode with sparsity
2. `test_accuracy_comparison.py` - Compare outputs: full vs sparse
### Benchmarks
1. `bench_offload_sparse.py` - Compare:
- Full offload (baseline)
- MInference prefill + Quest decode
- Aggressive sparse offload

597
docs/xattention_analysis.md Normal file
View File

@@ -0,0 +1,597 @@
# COMPASS XAttention Implementation Analysis
**Analysis Date**: 2026-01-14
**Researcher**: Claude Code Agent
**Source**: `/home/zijie/Code/COMPASS/compass/src/`
---
## Executive Summary
COMPASS XAttention is a **block sparse attention** implementation that uses:
1. **Approximation phase** (`xattn_estimate`) to compute attention importance and select blocks
2. **Computation phase** (`Xattention_prefill`) to compute sparse attention using `block_sparse_attn_func`
3. **Triton kernels** for efficient block-wise GEMM and softmax operations
**Key Integration Constraint**: Requires `block_sparse_attn_func` from flash-attention library, which is a **C++ CUDA extension** that must be compiled separately.
---
## 1. Function: `xattn_estimate()`
**Purpose**: Estimate attention importance and select which blocks to compute
### Input Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `query_states` | Tensor | - | Shape: `(batch, num_heads, q_len, head_dim)` |
| `key_states` | Tensor | - | Shape: `(batch, num_kv_heads, k_len, head_dim)` |
| `block_size` | int | - | Size of attention blocks (typically 128) |
| `stride` | int | - | Downsampling stride for approximation |
| `norm` | float | 1 | Normalization factor for attention scaling |
| `softmax` | bool | True | Whether to apply softmax in estimation |
| `threshold` | float | 0.9 | Block selection threshold (0-1) |
| `chunk_size` | int | 16384 | Processing chunk size |
| `select_mode` | str | "inverse" | Pattern selection mode |
| `use_triton` | bool | True | Use Triton kernels (requires SM 80+) |
| `causal` | bool | True | Apply causal masking |
| `kdb` | int | 1 | Key downsampling factor |
| `keep_sink` | bool | False | Always attend to first token |
| `keep_recent` | bool | False | Always attend to recent tokens |
### Output
```python
returns: (attn_sums, simple_masks)
attn_sums: Tensor[float32]
Shape: (batch, num_heads, num_q_blocks, num_k_blocks_per_chunk)
Contains aggregated attention weights per block
simple_masks: Tensor[bool]
Shape: (batch, num_heads, num_q_blocks, num_k_blocks)
Boolean mask indicating which blocks to compute
```
### Algorithm
#### Step 1: Padding and Chunking
```python
# Pad sequences to chunk_size boundaries
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
# Compute number of blocks and chunks
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
k_block_num = (k_len + k_num_to_pad) // block_size
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
q_block_num = (q_len + q_num_to_pad) // block_size
```
#### Step 2: Pattern Selection (stride-based downsampling)
**Purpose**: Reduce computation by `stride` factor using patterned selection
**Modes**:
1. **`"inverse"`** (default): Inverse stride pattern
```python
# Key: regular stride [0, stride, 2*stride, ...]
# Query: reverse stride [(stride-1), (stride-1-stride), ...]
reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)])
reshaped_query = torch.cat([query_states[:, :, (stride-1-q)::stride*kdb, :] for q in range(stride)])
```
2. **`"slash"`**: Slash pattern (diagonal)
```python
# Both use regular stride
reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)])
reshaped_query = torch.cat([query_states[:, :, q::stride, :] for q in range(stride)])
```
3. **`"random"`**: Random permutation
4. **`"double"`, `"triple"`**: Data augmentation modes
#### Step 3: Chunk-wise Attention Estimation
For each query chunk:
**If `use_triton=True`** (fast path):
```python
# Triton kernel 1: Compute attention scores with fused reshape
attn_weights_slice = flat_group_gemm_fuse_reshape(
query_chunk, key_states, stride,
chunk_start, chunk_end, is_causal=causal
)
# Triton kernel 2: Softmax + block aggregation
attn_sum = softmax_fuse_block_sum(
attn_weights_slice, reshaped_block_size, segment_size,
chunk_start, chunk_end, real_q_len, scale, is_causal
)
```
**If `use_triton=False`** (PyTorch fallback):
```python
# Standard matrix multiplication
attn_weights_slice = torch.matmul(chunked_query, reshaped_key.transpose(2, 3))
# Scale and apply causal mask
attn_weights_slice = attn_weights_slice / sqrt(head_dim) / stride / norm
attn_weights_slice = attn_weights_slice + causal_mask
# Softmax
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1)
# Aggregate to block level
attn_sum = attn_weights_slice.view(
batch, heads, num_blocks_per_chunk, block_size//kdb, -1, block_size
).sum(dim=-1).sum(dim=-2)
```
#### Step 4: Block Selection
```python
# Select blocks based on threshold
simple_mask = find_blocks_chunked(
attn_sum,
current_index, # Starting block index
threshold, # 0.9 = select blocks covering 90% of attention mass
None, # or num_to_choose for top-k selection
decoding=False,
mode="prefill",
causal=True
)
```
**Selection Algorithm** (`find_blocks_chunked`):
1. Sort blocks by attention weight (descending)
2. Compute cumulative sum
3. Select blocks until `cumulative_sum >= total_sum * threshold`
4. Enforce causal constraints (no future blocks)
5. Always include sink token (first block) if `keep_sink=True`
6. Always include diagonal blocks if `keep_recent=True`
---
## 2. Function: `Xattention_prefill()`
**Purpose**: Compute sparse attention using estimated block mask
### Input Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `query_states` | Tensor | - | `(batch, num_heads, q_len, head_dim)` |
| `key_states` | Tensor | - | `(batch, num_heads, k_len, head_dim)` |
| `value_states` | Tensor | - | `(batch, num_heads, k_len, head_dim)` |
| `stride` | int | - | Downsampling stride for estimation |
| `norm` | float | 1 | Normalization factor |
| `threshold` | float | 0.8 | Block selection threshold |
| `block_size` | int | 128 | **MUST be 128** (hardcoded requirement) |
| `use_triton` | bool | True | Use Triton kernels in estimation |
| `causal` | bool | True | Apply causal masking |
| `kdb` | int | 1 | Key downsampling factor |
| `chunk_size` | int | None | Auto-computed if None |
| `keep_sink` | bool | False | Always attend to first token |
| `keep_recent` | bool | False | Always attend to recent tokens |
### Output
```python
returns: attn_output
attn_output: Tensor
Shape: (batch, num_heads, q_len, head_dim)
Sparse attention output
```
### Algorithm Flow
#### Step 1: Auto-compute chunk_size
```python
if chunk_size is None:
chunk_size = int(max(
min(
max(2048, 1 << (k_len - 1).bit_length()), # Round to power of 2
128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()), # Memory constraint
),
2048, # Minimum
))
```
**Example**:
- `k_len=8192` → `chunk_size=8192`
- `k_len=32768` → `chunk_size=16384`
- `k_len=65536` → `chunk_size=16384`
#### Step 2: Estimate attention and select blocks
```python
attn_sums, approx_simple_mask = xattn_estimate(
query_states, key_states,
block_size=block_size, stride=stride, norm=norm,
threshold=threshold, select_mode="inverse",
use_triton=use_triton, causal=causal,
chunk_size=chunk_size, kdb=kdb,
keep_sink=keep_sink, keep_recent=keep_recent
)
```
#### Step 3: Prepare inputs for block_sparse_attn_func
```python
# Hard constraints
assert block_size == 128
assert batch_size == 1
# Reshape to (seq_len, num_heads, head_dim)
query_states = query_states.transpose(1, 2).view(q_len, num_heads, head_dim)
key_states = key_states.transpose(1, 2).view(k_len, num_heads, head_dim)
value_states = value_states.transpose(1, 2).view(k_len, num_heads, head_dim)
# Cumulative sequence lengths
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)
# Head mask type (all heads use mask)
head_mask_type = torch.tensor([1 for _ in range(num_heads)], dtype=torch.int32)
```
#### Step 4: Call block_sparse_attn_func
```python
attn_output = block_sparse_attn_func(
query_states, # (q_len, num_heads, head_dim)
key_states, # (k_len, num_heads, head_dim)
value_states, # (k_len, num_heads, head_dim)
q_cu_seq_lens, # [0, q_len]
k_cu_seq_lens, # [0, k_len]
head_mask_type, # [1, 1, ..., 1]
None, # No custom layout
approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(), # Block mask
q_len,
k_len,
p_dropout=0.0,
deterministic=True,
is_causal=causal
)
```
#### Step 5: Reshape output
```python
attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
# Output shape: (batch, num_heads, q_len, head_dim)
```
---
## 3. Triton Kernel Dependencies
### Kernel 1: `flat_group_gemm_fuse_reshape_kernel`
**Purpose**: Compute QK^T with stride-based reshaping
**Key Features**:
- Loads `stride` keys and queries at once
- Fused strided access pattern
- Causal masking support
- Block size auto-selection based on GPU memory
**Block Size Selection**:
```python
# RTX 3090 (<30GB): BLOCK_M=64, BLOCK_N=64
# A100/H100 (>=30GB): BLOCK_M=128, BLOCK_N=128
```
**Signature**:
```python
flat_group_gemm_fuse_reshape(
query_states, # (batch, heads, q_len, head_dim)
key_states, # (batch, heads, k_len, head_dim)
stride, # Downsampling factor
chunk_start, # Start position in keys
chunk_end, # End position in keys
is_causal=True
)
# Returns: (batch, heads, q_len//stride, k_len//stride)
```
### Kernel 2: `softmax_fuse_block_sum_kernel_causal` / `_non_causal`
**Purpose**: Online softmax with block aggregation
**Algorithm**:
1. **Forward pass** (compute m_i, l_i):
```
m_i = max(m_i, m_local)
alpha = exp(m_i - m_new)
l_i = l_i * alpha + sum(exp(X - m_new))
```
2. **Backward pass** (compute softmax with scaling):
```
softmax = exp(X - m_i) / l_i
aggregate to blocks: sum(softmax) over block_size
```
**Key Features**:
- Single-pass softmax (no materializing full attention matrix)
- Causal masking integrated
- Outputs block-level sums directly
**Signature**:
```python
softmax_fuse_block_sum(
attn_weights_slice, # (batch, heads, q_len, k_len)
reshaped_block_size, # Block size (128//stride)
segment_size, # Processing segment (min(4096, block_size))
chunk_start, # Start position
chunk_end, # End position
real_q_len, # Actual query length (before padding)
scale, # 1.4426950408889634 / sqrt(head_dim) / stride / norm
is_causal=True
)
# Returns: (batch, heads, q_len//block_size, k_len//block_size)
```
---
## 4. Key Parameters and Their Meanings
### Critical Parameters
| Parameter | Meaning | Typical Value | Impact |
|-----------|---------|---------------|--------|
| `block_size` | Block granularity | 128 | **Fixed at 128**, affects mask granularity |
| `stride` | Downsampling factor | 4-16 | Higher = faster but less accurate |
| `threshold` | Sparsity level | 0.8-0.9 | Higher = denser mask, more computation |
| `chunk_size` | Processing chunk | 16384 | Affects memory and efficiency |
| `kdb` | Key downsampling boost | 1 | Experimental, use 1 |
| `norm` | Scaling factor | 1.0 | Attention temperature control |
### Trade-offs
**Stride (`stride`)**:
- `stride=1`: No approximation, same as dense attention
- `stride=4`: 4x faster estimation, good accuracy
- `stride=8`: 8x faster, moderate accuracy loss
- `stride=16`: 16x faster, significant accuracy loss
**Threshold (`threshold`)**:
- `threshold=0.8`: Select blocks covering 80% of attention mass (~20% sparsity)
- `threshold=0.9`: Select blocks covering 90% of attention mass (~10% sparsity)
- `threshold=0.95`: Very dense, only prunes ~5% of blocks
---
## 5. Dependencies
### Required Libraries
1. **`block_sparse_attn`** (CRITICAL)
- Source: `/home/zijie/Code/COMPASS/3rdparty/flash-attention/`
- Function: `block_sparse_attn_func`
- Type: **C++ CUDA extension**
- Build: Requires compilation with `torch.utils.cpp_extension`
2. **Triton** (optional but recommended)
- Required for: `use_triton=True`
- GPU requirement: SM 80+ (A100, RTX 3090, H100, etc.)
- Check: `torch.cuda.get_device_properties().major >= 8`
3. **PyTorch**
- Version: Compatible with flash-attention
- Features: F.pad, matmul, softmax, view, transpose
### Dependency Tree
```
Xattention_prefill
├── xattn_estimate
│ ├── flat_group_gemm_fuse_reshape (Triton)
│ ├── softmax_fuse_block_sum (Triton)
│ └── find_blocks_chunked (PyTorch)
└── block_sparse_attn_func (C++ CUDA)
```
---
## 6. Integration Issues for nano-vllm
### Critical Issue 1: `block_sparse_attn_func` Dependency
**Problem**: `block_sparse_attn_func` is a **C++ CUDA extension** that must be compiled from flash-attention source.
**Options**:
1. **Compile flash-attention with block sparse support**
```bash
cd /home/zijie/Code/COMPASS/3rdparty/flash-attention
python setup.py install
```
- Risk: May conflict with existing flash-attention installation
- Complexity: High (C++ compilation)
2. **Replace with FlashInfer block sparse**
- FlashInfer is already a dependency
- Has similar block sparse attention
- Need to adapt interface
3. **Custom CUDA kernel**
- Implement simplified block sparse attention
- High development cost
- Maintenance burden
### Critical Issue 2: Hard-coded Constraints
```python
assert block_size == 128 # Line 358
assert batch_size == 1 # Line 359
```
**Impact**:
- Cannot process multiple sequences in one batch
- Fixed block size limits flexibility
- Must work around these constraints
### Critical Issue 3: Triton GPU Requirement
```python
props = torch.cuda.get_device_properties(torch.cuda.current_device())
if props.major < 8:
use_triton = False
```
**Impact**:
- Triton kernels only work on SM 80+ (A100, RTX 3090, H100)
- Older GPUs (V100, T4, RTX 2080) fall back to slow PyTorch implementation
- RTX 3090 works but uses smaller block sizes (64 vs 128)
### Issue 4: Memory Layout
**XAttention expects**:
```python
query_states: (batch, num_heads, q_len, head_dim)
```
**nano-vllm uses**:
```python
query_states: (num_heads, total_tokens, head_dim) # Flattened batch
```
**Required**: Transpose and reshape before/after calling XAttention
### Issue 5: Chunking Incompatibility
**XAttention**: Processes in fixed-size chunks (e.g., 16384 tokens)
- Requires padding to chunk boundaries
- Adds overhead for short sequences
**nano-vllm**: Processes variable-length requests
- No padding requirement
- Dynamic batch sizing
---
## 7. Integration Strategy
### Recommended Approach: **Wrapper with FlashInfer**
1. **Keep `xattn_estimate`** (pure PyTorch + Triton)
- No external dependencies
- Computes block mask
2. **Replace `block_sparse_attn_func` with FlashInfer**
- FlashInfer: `flashinfer.single_prefill_with_kv_cache`
- Similar API, already compiled
- Supports block sparse
3. **Adapt mask format**
- XAttention: `(batch, heads, q_blocks, k_blocks)` boolean mask
- FlashInfer: `(num_qo, num_kv)` boolean mask or custom format
4. **Handle constraints**
- Enforce `batch_size=1` by processing one request at a time
- Keep `block_size=128` as requirement
### Alternative: **Pure PyTorch Implementation**
1. Extract estimation algorithm
2. Implement sparse attention using PyTorch operations
3. Use FlashInfer for final computation
4. No Triton dependency
---
## 8. Code Example: Adaptation
```python
def xattention_prefill_adapted(
query_states, # (num_heads, q_len, head_dim)
key_states, # (num_heads, k_len, head_dim)
value_states, # (num_heads, k_len, head_dim)
stride=4,
threshold=0.9,
block_size=128,
causal=True,
):
# Step 1: Add batch dimension
q = query_states.unsqueeze(0) # (1, heads, q_len, dim)
k = key_states.unsqueeze(0)
v = value_states.unsqueeze(0)
# Step 2: Estimate mask (no external dependency)
_, block_mask = xattn_estimate(
q, k,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
causal=causal,
)
# block_mask: (1, heads, q_blocks, k_blocks)
# Step 3: Convert block mask to token mask
q_blocks, k_blocks = block_mask.shape[-2:]
token_mask = block_mask.repeat_interleave(block_size, dim=-2)
token_mask = token_mask.repeat_interleave(block_size, dim=-1)
token_mask = token_mask[:, :, :q.size(2), :k.size(2)] # Trim padding
# Step 4: Use FlashInfer with mask
from flashinfer import single_prefill_with_kv_cache
output = single_prefill_with_kv_cache(
q.squeeze(0),
k.squeeze(0),
v.squeeze(0),
custom_mask=token_mask.squeeze(0),
)
return output # (num_heads, q_len, head_dim)
```
---
## 9. Summary of Findings
### Advantages
1. **Accurate approximation**: Pattern-based stride selection preserves attention patterns
2. **Flexible sparsity**: Threshold-based control over computation
3. **GPU optimization**: Triton kernels for estimation phase
4. **Proven in practice**: Used in COMPASS system
### Challenges
1. **Hard dependency**: `block_sparse_attn_func` requires C++ compilation
2. **Rigid constraints**: `block_size=128`, `batch_size=1`
3. **GPU-specific**: Triton only on SM 80+
4. **Memory layout mismatch**: Requires reshape/transpose
5. **Chunking overhead**: Padding to chunk boundaries
### Integration Complexity
| Component | Complexity | Risk |
|-----------|------------|------|
| `xattn_estimate` | Medium | Low (PyTorch + Triton) |
| `block_sparse_attn_func` | High | **Critical** (C++ dependency) |
| Interface adaptation | Low | Low (reshape) |
| Constraint handling | Medium | Medium (workarounds) |
**Overall Integration Risk**: **HIGH** (due to C++ dependency)
---
## 10. Next Steps
1. **Evaluate FlashInfer compatibility**
- Can FlashInfer replace `block_sparse_attn_func`?
- What mask format does it expect?
2. **Prototype estimation phase**
- Extract `xattn_estimate` function
- Test with nano-vllm inputs
- Validate mask quality
3. **Benchmark Triton kernels**
- Compare Triton vs PyTorch estimation
- Measure speedup on RTX 3090
- Profile memory usage
4. **Design interface**
- Define nano-vllm sparse attention API
- Specify mask format
- Plan integration points