COMPASS XAttention Implementation Analysis
Analysis Date: 2026-01-14
Researcher: Claude Code Agent
Source: /home/zijie/Code/COMPASS/compass/src/
Executive Summary
COMPASS XAttention is a block-sparse attention implementation that uses:
- An approximation phase (`xattn_estimate`) to compute attention importance and select blocks
- A computation phase (`Xattention_prefill`) to compute sparse attention using `block_sparse_attn_func`
- Triton kernels for efficient block-wise GEMM and softmax operations
Key Integration Constraint: Requires `block_sparse_attn_func` from the flash-attention library, a C++ CUDA extension that must be compiled separately.
1. Function: xattn_estimate()
Purpose: Estimate attention importance and select which blocks to compute
Input Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `query_states` | Tensor | - | Shape: (batch, num_heads, q_len, head_dim) |
| `key_states` | Tensor | - | Shape: (batch, num_kv_heads, k_len, head_dim) |
| `block_size` | int | - | Size of attention blocks (typically 128) |
| `stride` | int | - | Downsampling stride for approximation |
| `norm` | float | 1 | Normalization factor for attention scaling |
| `softmax` | bool | True | Whether to apply softmax in estimation |
| `threshold` | float | 0.9 | Block selection threshold (0-1) |
| `chunk_size` | int | 16384 | Processing chunk size |
| `select_mode` | str | "inverse" | Pattern selection mode |
| `use_triton` | bool | True | Use Triton kernels (requires SM 80+) |
| `causal` | bool | True | Apply causal masking |
| `kdb` | int | 1 | Key downsampling factor |
| `keep_sink` | bool | False | Always attend to first token |
| `keep_recent` | bool | False | Always attend to recent tokens |
Output
Returns: `(attn_sums, simple_masks)`
- `attn_sums`: Tensor[float32], shape (batch, num_heads, num_q_blocks, num_k_blocks_per_chunk); aggregated attention weights per block
- `simple_masks`: Tensor[bool], shape (batch, num_heads, num_q_blocks, num_k_blocks); boolean mask indicating which blocks to compute
Algorithm
Step 1: Padding and Chunking
```python
# Pad sequences to chunk_size boundaries
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len

# Compute number of blocks and chunks
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
k_block_num = (k_len + k_num_to_pad) // block_size
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
q_block_num = (q_len + q_num_to_pad) // block_size
```
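The ceil-division arithmetic above can be checked with a tiny standalone sketch (the `pad_and_count` helper and the example values are illustrative, not COMPASS code):

```python
def pad_and_count(seq_len, chunk_size, block_size):
    """Round seq_len up to a chunk_size boundary; return padding, chunk and block counts."""
    num_to_pad = ((seq_len + chunk_size - 1) // chunk_size) * chunk_size - seq_len
    padded = seq_len + num_to_pad
    return num_to_pad, padded // chunk_size, padded // block_size

# e.g. a 10,000-token key sequence with chunk_size=16384, block_size=128:
pad, chunks, blocks = pad_and_count(10_000, 16_384, 128)
print(pad, chunks, blocks)  # 6384 tokens of padding, 1 chunk, 128 blocks
```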
Step 2: Pattern Selection (stride-based downsampling)
Purpose: Reduce computation by stride factor using patterned selection
Modes:
- `"inverse"` (default): inverse stride pattern
  ```python
  # Key: regular stride offsets [0, stride, 2*stride, ...]
  # Query: reversed offsets within each group [stride-1, stride-2, ..., 0]
  reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)])
  reshaped_query = torch.cat([query_states[:, :, (stride - 1 - q)::stride * kdb, :] for q in range(stride)])
  ```
- `"slash"`: slash (diagonal) pattern
  ```python
  # Both key and query use the regular stride
  reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)])
  reshaped_query = torch.cat([query_states[:, :, q::stride, :] for q in range(stride)])
  ```
- `"random"`: random permutation
- `"double"`, `"triple"`: data augmentation modes
Step 3: Chunk-wise Attention Estimation
For each query chunk:
If `use_triton=True` (fast path):

```python
# Triton kernel 1: compute attention scores with fused reshape
attn_weights_slice = flat_group_gemm_fuse_reshape(
    query_chunk, key_states, stride,
    chunk_start, chunk_end, is_causal=causal,
)

# Triton kernel 2: softmax + block aggregation
attn_sum = softmax_fuse_block_sum(
    attn_weights_slice, reshaped_block_size, segment_size,
    chunk_start, chunk_end, real_q_len, scale, is_causal,
)
```

If `use_triton=False` (PyTorch fallback):

```python
# Standard matrix multiplication
attn_weights_slice = torch.matmul(chunked_query, reshaped_key.transpose(2, 3))

# Scale and apply causal mask
attn_weights_slice = attn_weights_slice / sqrt(head_dim) / stride / norm
attn_weights_slice = attn_weights_slice + causal_mask

# Softmax
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1)

# Aggregate to block level
attn_sum = attn_weights_slice.view(
    batch, heads, num_blocks_per_chunk, block_size // kdb, -1, block_size
).sum(dim=-1).sum(dim=-2)
```
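The final view-and-sum aggregation can be illustrated with numpy in place of torch (shapes simplified to a single head with no batch dimension; values chosen only for illustration):

```python
# Illustrative block aggregation: sum a token-level attention matrix into
# per-block totals, mirroring the view(...).sum(...) step above.
import numpy as np

q_len, k_len, block = 8, 8, 4
attn = np.full((q_len, k_len), 1.0 / k_len)  # uniform attention; each row sums to 1

# (q_blocks, block, k_blocks, block) -> sum over both intra-block axes
block_sums = attn.reshape(q_len // block, block, k_len // block, block).sum(axis=(1, 3))
print(block_sums.shape)  # (2, 2)
print(block_sums)        # each entry = 4 rows x 4 cols x 1/8 = 2.0
```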
Step 4: Block Selection
```python
# Select blocks based on threshold
simple_mask = find_blocks_chunked(
    attn_sum,
    current_index,  # starting block index
    threshold,      # 0.9 = select blocks covering 90% of attention mass
    None,           # or num_to_choose for top-k selection
    decoding=False,
    mode="prefill",
    causal=True,
)
```
Selection Algorithm (`find_blocks_chunked`):
- Sort blocks by attention weight (descending)
- Compute the cumulative sum
- Select blocks until `cumulative_sum >= total_sum * threshold`
- Enforce causal constraints (no future blocks)
- Always include the sink token (first block) if `keep_sink=True`
- Always include diagonal blocks if `keep_recent=True`
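A minimal sketch of the thresholding step, with causal, sink, and recent handling stripped away (`select_blocks` is an illustrative helper, not the COMPASS function; integer weights keep the arithmetic exact):

```python
# Illustrative threshold-based block selection: keep the highest-weight
# blocks until they cover `threshold` of the total attention mass.
def select_blocks(block_weights, threshold):
    total = sum(block_weights)
    order = sorted(range(len(block_weights)), key=lambda i: block_weights[i], reverse=True)
    selected, covered = set(), 0
    for i in order:
        if covered >= total * threshold:
            break
        selected.add(i)
        covered += block_weights[i]
    return selected

weights = [50, 5, 30, 10, 5]  # per-block attention mass (percent)
print(sorted(select_blocks(weights, threshold=0.9)))  # [0, 2, 3] covers 90% of the mass
```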
2. Function: Xattention_prefill()
Purpose: Compute sparse attention using estimated block mask
Input Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `query_states` | Tensor | - | (batch, num_heads, q_len, head_dim) |
| `key_states` | Tensor | - | (batch, num_heads, k_len, head_dim) |
| `value_states` | Tensor | - | (batch, num_heads, k_len, head_dim) |
| `stride` | int | - | Downsampling stride for estimation |
| `norm` | float | 1 | Normalization factor |
| `threshold` | float | 0.8 | Block selection threshold |
| `block_size` | int | 128 | MUST be 128 (hardcoded requirement) |
| `use_triton` | bool | True | Use Triton kernels in estimation |
| `causal` | bool | True | Apply causal masking |
| `kdb` | int | 1 | Key downsampling factor |
| `chunk_size` | int | None | Auto-computed if None |
| `keep_sink` | bool | False | Always attend to first token |
| `keep_recent` | bool | False | Always attend to recent tokens |
Output
Returns: `attn_output`
- `attn_output`: Tensor, shape (batch, num_heads, q_len, head_dim); sparse attention output
Algorithm Flow
Step 1: Auto-compute chunk_size
```python
if chunk_size is None:
    chunk_size = int(max(
        min(
            max(2048, 1 << (k_len - 1).bit_length()),             # round up to a power of 2
            128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()), # memory constraint
        ),
        2048,  # minimum
    ))
```
Example:
- `k_len=8192` → `chunk_size=8192`
- `k_len=32768` → `chunk_size=16384`
- `k_len=65536` → `chunk_size=16384`
Step 2: Estimate attention and select blocks
```python
attn_sums, approx_simple_mask = xattn_estimate(
    query_states, key_states,
    block_size=block_size, stride=stride, norm=norm,
    threshold=threshold, select_mode="inverse",
    use_triton=use_triton, causal=causal,
    chunk_size=chunk_size, kdb=kdb,
    keep_sink=keep_sink, keep_recent=keep_recent,
)
```
Step 3: Prepare inputs for block_sparse_attn_func
```python
# Hard constraints
assert block_size == 128
assert batch_size == 1

# Reshape to (seq_len, num_heads, head_dim)
query_states = query_states.transpose(1, 2).view(q_len, num_heads, head_dim)
key_states = key_states.transpose(1, 2).view(k_len, num_heads, head_dim)
value_states = value_states.transpose(1, 2).view(k_len, num_heads, head_dim)

# Cumulative sequence lengths
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)

# Head mask type (all heads use the block mask)
head_mask_type = torch.tensor([1] * num_heads, dtype=torch.int32)
```
Step 4: Call block_sparse_attn_func
```python
attn_output = block_sparse_attn_func(
    query_states,   # (q_len, num_heads, head_dim)
    key_states,     # (k_len, num_heads, head_dim)
    value_states,   # (k_len, num_heads, head_dim)
    q_cu_seq_lens,  # [0, q_len]
    k_cu_seq_lens,  # [0, k_len]
    head_mask_type, # [1, 1, ..., 1]
    None,           # no custom layout
    approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(),  # block mask
    q_len,
    k_len,
    p_dropout=0.0,
    deterministic=True,
    is_causal=causal,
)
```
Step 5: Reshape output
```python
attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
# Output shape: (batch, num_heads, q_len, head_dim)
```
3. Triton Kernel Dependencies
Kernel 1: flat_group_gemm_fuse_reshape_kernel
Purpose: Compute QK^T with stride-based reshaping
Key Features:
- Loads `stride` keys and queries at once
- Fused strided access pattern
- Causal masking support
- Block size auto-selection based on GPU memory
Block Size Selection:
- RTX 3090 (<30 GB): BLOCK_M=64, BLOCK_N=64
- A100/H100 (>=30 GB): BLOCK_M=128, BLOCK_N=128
Signature:
```python
flat_group_gemm_fuse_reshape(
    query_states,  # (batch, heads, q_len, head_dim)
    key_states,    # (batch, heads, k_len, head_dim)
    stride,        # downsampling factor
    chunk_start,   # start position in keys
    chunk_end,     # end position in keys
    is_causal=True,
)
# Returns: (batch, heads, q_len//stride, k_len//stride)
```
Kernel 2: softmax_fuse_block_sum_kernel_causal / _non_causal
Purpose: Online softmax with block aggregation
Algorithm:
1. Forward pass (maintain running max `m_i` and normalizer `l_i`):
   ```
   m_new = max(m_i, m_local)
   alpha = exp(m_i - m_new)
   l_i = l_i * alpha + sum(exp(X - m_new))
   m_i = m_new
   ```
2. Backward pass (compute scaled softmax and aggregate to blocks):
   ```
   softmax = exp(X - m_i) / l_i
   block_sum = sum(softmax) over each block of block_size entries
   ```
Key Features:
- Single-pass softmax (no materializing full attention matrix)
- Causal masking integrated
- Outputs block-level sums directly
Signature:
```python
softmax_fuse_block_sum(
    attn_weights_slice,   # (batch, heads, q_len, k_len)
    reshaped_block_size,  # block size (128 // stride)
    segment_size,         # processing segment (min(4096, block_size))
    chunk_start,          # start position
    chunk_end,            # end position
    real_q_len,           # actual query length (before padding)
    scale,                # 1.4426950408889634 / sqrt(head_dim) / stride / norm
    is_causal=True,
)
# Returns: (batch, heads, q_len//block_size, k_len//block_size)
```
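The running-max recurrence above can be sketched in plain Python (the real kernel runs in Triton over segments of the attention matrix and fuses the second pass; `online_softmax_block_sums` is an illustrative helper, not the kernel):

```python
import math

# Illustrative single-pass ("online") softmax with block aggregation:
# process scores segment by segment, keeping a running max m and
# normalizer l, then emit per-block probability sums.
def online_softmax_block_sums(scores, block_size, segment_size):
    m, l = float("-inf"), 0.0
    for s in range(0, len(scores), segment_size):   # forward pass over segments
        seg = scores[s:s + segment_size]
        m_new = max(m, max(seg))
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in seg)
        m = m_new
    probs = [math.exp(x - m) / l for x in scores]   # normalize with final m, l
    return [sum(probs[b:b + block_size]) for b in range(0, len(scores), block_size)]

sums = online_softmax_block_sums([1.0, 2.0, 3.0, 4.0], block_size=2, segment_size=2)
print(sums)  # two block sums that total 1.0, matching a direct softmax
```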
4. Key Parameters and Their Meanings
Critical Parameters
| Parameter | Meaning | Typical Value | Impact |
|---|---|---|---|
| `block_size` | Block granularity | 128 | Fixed at 128; sets mask granularity |
| `stride` | Downsampling factor | 4-16 | Higher = faster but less accurate |
| `threshold` | Sparsity level | 0.8-0.9 | Higher = denser mask, more computation |
| `chunk_size` | Processing chunk | 16384 | Affects memory and efficiency |
| `kdb` | Key downsampling boost | 1 | Experimental; use 1 |
| `norm` | Scaling factor | 1.0 | Attention temperature control |
Trade-offs
Stride (`stride`):
- `stride=1`: no approximation, same as dense attention
- `stride=4`: 4x faster estimation, good accuracy
- `stride=8`: 8x faster, moderate accuracy loss
- `stride=16`: 16x faster, significant accuracy loss

Threshold (`threshold`):
- `threshold=0.8`: select blocks covering 80% of attention mass (~20% sparsity)
- `threshold=0.9`: select blocks covering 90% of attention mass (~10% sparsity)
- `threshold=0.95`: very dense, prunes only ~5% of blocks
5. Dependencies
Required Libraries
1. `block_sparse_attn` (CRITICAL)
   - Source: /home/zijie/Code/COMPASS/3rdparty/flash-attention/
   - Function: `block_sparse_attn_func`
   - Type: C++ CUDA extension
   - Build: requires compilation with `torch.utils.cpp_extension`
2. Triton (optional but recommended)
   - Required for `use_triton=True`
   - GPU requirement: SM 80+ (A100, RTX 3090, H100, etc.)
   - Check: `torch.cuda.get_device_properties().major >= 8`
3. PyTorch
   - Version: compatible with flash-attention
   - Features used: `F.pad`, `matmul`, `softmax`, `view`, `transpose`
Dependency Tree
```
Xattention_prefill
├── xattn_estimate
│   ├── flat_group_gemm_fuse_reshape (Triton)
│   ├── softmax_fuse_block_sum (Triton)
│   └── find_blocks_chunked (PyTorch)
└── block_sparse_attn_func (C++ CUDA)
```
6. Integration Issues for nano-vllm
Critical Issue 1: block_sparse_attn_func Dependency
Problem: block_sparse_attn_func is a C++ CUDA extension that must be compiled from flash-attention source.
Options:
1. Compile flash-attention with block-sparse support
   ```bash
   cd /home/zijie/Code/COMPASS/3rdparty/flash-attention
   python setup.py install
   ```
   - Risk: may conflict with an existing flash-attention installation
   - Complexity: high (C++ compilation)
2. Replace with FlashInfer block sparse
   - FlashInfer is already a dependency
   - Has similar block-sparse attention
   - Needs interface adaptation
3. Custom CUDA kernel
   - Implement simplified block-sparse attention
   - High development cost
   - Maintenance burden
Critical Issue 2: Hard-coded Constraints
```python
assert block_size == 128  # line 358
assert batch_size == 1    # line 359
```
Impact:
- Cannot process multiple sequences in one batch
- Fixed block size limits flexibility
- Must work around these constraints
Critical Issue 3: Triton GPU Requirement
```python
props = torch.cuda.get_device_properties(torch.cuda.current_device())
if props.major < 8:
    use_triton = False
```
Impact:
- Triton kernels only work on SM 80+ (A100, RTX 3090, H100)
- Older GPUs (V100, T4, RTX 2080) fall back to slow PyTorch implementation
- RTX 3090 works but uses smaller block sizes (64 vs 128)
Issue 4: Memory Layout
XAttention expects:
- `query_states: (batch, num_heads, q_len, head_dim)`

nano-vllm uses:
- `query_states: (num_heads, total_tokens, head_dim)` (flattened batch)

Required: transpose and reshape before/after calling XAttention.
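A shape-only sketch of the required conversion, using numpy in place of torch (the variable names are illustrative):

```python
# For a single request, nano-vllm's (num_heads, total_tokens, head_dim)
# maps to XAttention's (batch=1, num_heads, q_len, head_dim), and Step 3's
# transpose(1, 2).view(...) then flattens it to (q_len, num_heads, head_dim).
import numpy as np

num_heads, q_len, head_dim = 8, 256, 64
q_nano = np.zeros((num_heads, q_len, head_dim), dtype=np.float16)

q_xattn = q_nano[None, ...]  # add batch dim -> (1, heads, q_len, dim)

# equivalent of transpose(1, 2).view(q_len, num_heads, head_dim)
q_flat = np.transpose(q_xattn, (0, 2, 1, 3)).reshape(q_len, num_heads, head_dim)

out = q_xattn[0]  # drop batch dim again after attention
print(q_flat.shape)  # (256, 8, 64)
```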
Issue 5: Chunking Incompatibility
XAttention: Processes in fixed-size chunks (e.g., 16384 tokens)
- Requires padding to chunk boundaries
- Adds overhead for short sequences
nano-vllm: Processes variable-length requests
- No padding requirement
- Dynamic batch sizing
7. Integration Strategy
Recommended Approach: Wrapper with FlashInfer
1. Keep `xattn_estimate` (pure PyTorch + Triton)
   - No external dependencies
   - Computes the block mask
2. Replace `block_sparse_attn_func` with FlashInfer
   - FlashInfer: `flashinfer.single_prefill_with_kv_cache`
   - Similar API, already compiled
   - Supports block sparsity
3. Adapt the mask format
   - XAttention: `(batch, heads, q_blocks, k_blocks)` boolean mask
   - FlashInfer: `(num_qo, num_kv)` boolean mask or custom format
4. Handle constraints
   - Enforce `batch_size=1` by processing one request at a time
   - Keep `block_size=128` as a requirement
Alternative: Pure PyTorch Implementation
- Extract estimation algorithm
- Implement sparse attention using PyTorch operations
- Use FlashInfer for final computation
- No Triton dependency
8. Code Example: Adaptation
```python
def xattention_prefill_adapted(
    query_states,  # (num_heads, q_len, head_dim)
    key_states,    # (num_heads, k_len, head_dim)
    value_states,  # (num_heads, k_len, head_dim)
    stride=4,
    threshold=0.9,
    block_size=128,
    causal=True,
):
    # Step 1: add batch dimension
    q = query_states.unsqueeze(0)  # (1, heads, q_len, dim)
    k = key_states.unsqueeze(0)
    v = value_states.unsqueeze(0)

    # Step 2: estimate mask (no external dependency)
    _, block_mask = xattn_estimate(
        q, k,
        block_size=block_size,
        stride=stride,
        threshold=threshold,
        use_triton=True,
        causal=causal,
    )
    # block_mask: (1, heads, q_blocks, k_blocks)

    # Step 3: convert block mask to token mask
    token_mask = block_mask.repeat_interleave(block_size, dim=-2)
    token_mask = token_mask.repeat_interleave(block_size, dim=-1)
    token_mask = token_mask[:, :, :q.size(2), :k.size(2)]  # trim padding

    # Step 4: use FlashInfer with the mask
    from flashinfer import single_prefill_with_kv_cache
    output = single_prefill_with_kv_cache(
        q.squeeze(0),
        k.squeeze(0),
        v.squeeze(0),
        custom_mask=token_mask.squeeze(0),
    )
    return output  # (num_heads, q_len, head_dim)
```
9. Summary of Findings
Advantages
- Accurate approximation: Pattern-based stride selection preserves attention patterns
- Flexible sparsity: Threshold-based control over computation
- GPU optimization: Triton kernels for estimation phase
- Proven in practice: Used in COMPASS system
Challenges
- Hard dependency: `block_sparse_attn_func` requires C++ compilation
- Rigid constraints: `block_size=128`, `batch_size=1`
- GPU-specific: Triton only on SM 80+
- Memory layout mismatch: requires reshape/transpose
- Chunking overhead: padding to chunk boundaries
Integration Complexity
| Component | Complexity | Risk |
|---|---|---|
| `xattn_estimate` | Medium | Low (PyTorch + Triton) |
| `block_sparse_attn_func` | High | Critical (C++ dependency) |
| Interface adaptation | Low | Low (reshape) |
| Constraint handling | Medium | Medium (workarounds) |
Overall Integration Risk: HIGH (due to C++ dependency)
10. Next Steps
1. Evaluate FlashInfer compatibility
   - Can FlashInfer replace `block_sparse_attn_func`?
   - What mask format does it expect?
2. Prototype the estimation phase
   - Extract the `xattn_estimate` function
   - Test with nano-vllm inputs
   - Validate mask quality
3. Benchmark Triton kernels
   - Compare Triton vs PyTorch estimation
   - Measure speedup on RTX 3090
   - Profile memory usage
4. Design the interface
   - Define the nano-vllm sparse attention API
   - Specify the mask format
   - Plan integration points