nano-vllm/docs/xattention_analysis.md
Zijie Tian 57f4e9c6e6 docs: reorganize documentation files
- Move notes.md to docs/development_notes.md
- Move Xattention_analysis.md to docs/xattention_analysis.md
- Delete DEBUG_SUMMARY.md (no longer needed)
- Update CLAUDE.md with documentation index entries

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-14 10:08:41 +08:00


COMPASS XAttention Implementation Analysis

Analysis Date: 2026-01-14 Researcher: Claude Code Agent Source: /home/zijie/Code/COMPASS/compass/src/


Executive Summary

COMPASS XAttention is a block sparse attention implementation that uses:

  1. Approximation phase (xattn_estimate) to compute attention importance and select blocks
  2. Computation phase (Xattention_prefill) to compute sparse attention using block_sparse_attn_func
  3. Triton kernels for efficient block-wise GEMM and softmax operations

Key Integration Constraint: Requires block_sparse_attn_func from flash-attention library, which is a C++ CUDA extension that must be compiled separately.


1. Function: xattn_estimate()

Purpose: Estimate attention importance and select which blocks to compute

Input Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| query_states | Tensor | - | Shape: (batch, num_heads, q_len, head_dim) |
| key_states | Tensor | - | Shape: (batch, num_kv_heads, k_len, head_dim) |
| block_size | int | - | Size of attention blocks (typically 128) |
| stride | int | - | Downsampling stride for approximation |
| norm | float | 1 | Normalization factor for attention scaling |
| softmax | bool | True | Whether to apply softmax in estimation |
| threshold | float | 0.9 | Block selection threshold (0-1) |
| chunk_size | int | 16384 | Processing chunk size |
| select_mode | str | "inverse" | Pattern selection mode |
| use_triton | bool | True | Use Triton kernels (requires SM 80+) |
| causal | bool | True | Apply causal masking |
| kdb | int | 1 | Key downsampling factor |
| keep_sink | bool | False | Always attend to first token |
| keep_recent | bool | False | Always attend to recent tokens |

Output

returns: (attn_sums, simple_masks)
  attn_sums: Tensor[float32]
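    Shape: (batch, num_heads, num_q_blocks, num_k_blocks)
    Contains aggregated attention weights per block (per-chunk results of shape (batch, num_heads, num_blocks_per_chunk, num_k_blocks), concatenated along the query-block axis)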

  simple_masks: Tensor[bool]
    Shape: (batch, num_heads, num_q_blocks, num_k_blocks)
    Boolean mask indicating which blocks to compute

Algorithm

Step 1: Padding and Chunking

# Pad sequences to chunk_size boundaries
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len

# Compute number of blocks and chunks
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
k_block_num = (k_len + k_num_to_pad) // block_size
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
q_block_num = (q_len + q_num_to_pad) // block_size
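As a quick check of this arithmetic, the sketch below evaluates it for a single 10,000-token prompt. The helper name `chunk_layout` is hypothetical. The formulas are copied from the snippet above, with chunk_size=16384 (the xattn_estimate default) and block_size=128 assumed as inputs.

```python
def chunk_layout(q_len: int, k_len: int, chunk_size: int = 16384, block_size: int = 128) -> dict:
    """Padding, chunk, and block counts from Step 1 for one (q_len, k_len) pair."""
    # Round each length up to the next multiple of chunk_size
    k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
    q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
    return {
        "k_num_to_pad": k_num_to_pad,
        "q_num_to_pad": q_num_to_pad,
        "k_chunk_num": (k_len + k_num_to_pad) // chunk_size,
        "k_block_num": (k_len + k_num_to_pad) // block_size,
        "q_chunk_num": (q_len + q_num_to_pad) // chunk_size,
        "q_block_num": (q_len + q_num_to_pad) // block_size,
    }

layout = chunk_layout(q_len=10000, k_len=10000)
print(layout)
```

With these inputs, 6,384 of the 16,384 padded positions (about 39%) are padding. That is the kind of overhead Issue 5 below refers to.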

Step 2: Pattern Selection (stride-based downsampling)

Purpose: Reduce computation by stride factor using patterned selection

Modes:

  1. "inverse" (default): Inverse stride pattern

    # Key: slices at offsets 0, 1, ..., stride-1 (each with step stride), concatenated along head_dim
    # Query: slices at offsets stride-1, stride-2, ..., 0 (step stride*kdb), i.e. in reverse order
    reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)], dim=-1)
    reshaped_query = torch.cat([query_states[:, :, (stride-1-q)::stride*kdb, :] for q in range(stride)], dim=-1)
    
  2. "slash": Slash pattern (diagonal)

    # Both use regular stride
    reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)], dim=-1)
    reshaped_query = torch.cat([query_states[:, :, q::stride, :] for q in range(stride)], dim=-1)
    
  3. "random": Random permutation

  4. "double", "triple": Data augmentation modes

Step 3: Chunk-wise Attention Estimation

For each query chunk:

If use_triton=True (fast path):

# Triton kernel 1: Compute attention scores with fused reshape
attn_weights_slice = flat_group_gemm_fuse_reshape(
    query_chunk, key_states, stride,
    chunk_start, chunk_end, is_causal=causal
)

# Triton kernel 2: Softmax + block aggregation
attn_sum = softmax_fuse_block_sum(
    attn_weights_slice, reshaped_block_size, segment_size,
    chunk_start, chunk_end, real_q_len, scale, is_causal
)

If use_triton=False (PyTorch fallback):

# Standard matrix multiplication
attn_weights_slice = torch.matmul(chunked_query, reshaped_key.transpose(2, 3))

# Scale and apply the additive causal mask (-inf where a key must not be attended)
attn_weights_slice = attn_weights_slice / math.sqrt(head_dim) / stride / norm
attn_weights_slice = attn_weights_slice + causal_mask

# Softmax
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1)

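# Aggregate to block level; the scores live in the stride-downsampled index space,
# so each block spans reshaped_block_size = block_size // stride rows and columns
attn_sum = attn_weights_slice.view(
    batch, heads, num_blocks_per_chunk, reshaped_block_size // kdb, -1, reshaped_block_size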
).sum(dim=-1).sum(dim=-2)
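The view/sum above collapses the softmax matrix into per-block masses. Rows are grouped into query blocks and columns into key blocks, each spanning block_size // stride downsampled positions. Below is a minimal sketch of the same reduction for kdb=1; the function name `block_sums` is hypothetical.

```python
import torch

def block_sums(probs: torch.Tensor, reshaped_block_size: int) -> torch.Tensor:
    """Sum a (batch, heads, rows, cols) probability matrix over square tiles (kdb = 1)."""
    b, h, rows, cols = probs.shape
    nq = rows // reshaped_block_size
    # Split rows into (nq, rbs) and cols into (nk, rbs), then reduce both within-block axes
    tiles = probs.reshape(b, h, nq, reshaped_block_size, -1, reshaped_block_size)
    return tiles.sum(dim=-1).sum(dim=-2)  # (batch, heads, q_blocks, k_blocks)

# Example: block_size=32 with stride=4 gives 8x8 tiles in the downsampled space
torch.manual_seed(0)
probs = torch.softmax(torch.randn(1, 1, 16, 24), dim=-1)
sums = block_sums(probs, reshaped_block_size=8)
```

Each query-block row of `sums` adds up to the number of softmax rows it contains (8 here), because every softmax row sums to 1.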

Step 4: Block Selection

# Select blocks based on threshold
simple_mask = find_blocks_chunked(
    attn_sum,
    current_index,  # Starting block index
    threshold,      # 0.9 = select blocks covering 90% of attention mass
    None,           # or num_to_choose for top-k selection
    decoding=False,
    mode="prefill",
    causal=True
)

Selection Algorithm (find_blocks_chunked):

  1. Sort blocks by attention weight (descending)
  2. Compute cumulative sum
  3. Select blocks until cumulative_sum >= total_sum * threshold
  4. Enforce causal constraints (no future blocks)
  5. Always include sink token (first block) if keep_sink=True
  6. Always include diagonal blocks if keep_recent=True
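Steps 1 to 3 can be sketched in a few lines of PyTorch. This is a simplified stand-in for find_blocks_chunked, not its actual code. It omits the causal, sink, and diagonal handling (steps 4 to 6), and the helper name `select_blocks` is hypothetical.

```python
import torch

def select_blocks(attn_sum: torch.Tensor, threshold: float) -> torch.Tensor:
    """Per query-block row, keep the fewest key blocks whose mass reaches `threshold` of the row total."""
    total = attn_sum.sum(dim=-1, keepdim=True)
    # Step 1: sort blocks by mass, descending
    sorted_vals, order = attn_sum.sort(dim=-1, descending=True)
    # Step 2: cumulative mass accumulated *before* each sorted block
    cum_before = sorted_vals.cumsum(dim=-1) - sorted_vals
    # Step 3: keep a block while the row is still short of total * threshold
    keep_sorted = cum_before < total * threshold
    # Scatter the decisions back to the original block order
    return torch.zeros_like(attn_sum, dtype=torch.bool).scatter(-1, order, keep_sorted)

# One batch, one head, one query-block row, four key blocks
row = torch.tensor([[[[0.60, 0.25, 0.10, 0.05]]]])
print(select_blocks(row, 0.8).sum().item())  # 2 of 4 blocks kept
print(select_blocks(row, 0.9).sum().item())  # 3 of 4 blocks kept
```

Here threshold=0.8 keeps half of the key blocks, not 80% of them. The kept fraction depends on how concentrated each row's attention is.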

2. Function: Xattention_prefill()

Purpose: Compute sparse attention using estimated block mask

Input Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| query_states | Tensor | - | (batch, num_heads, q_len, head_dim) |
| key_states | Tensor | - | (batch, num_heads, k_len, head_dim) |
| value_states | Tensor | - | (batch, num_heads, k_len, head_dim) |
| stride | int | - | Downsampling stride for estimation |
| norm | float | 1 | Normalization factor |
| threshold | float | 0.8 | Block selection threshold |
| block_size | int | 128 | MUST be 128 (hardcoded requirement) |
| use_triton | bool | True | Use Triton kernels in estimation |
| causal | bool | True | Apply causal masking |
| kdb | int | 1 | Key downsampling factor |
| chunk_size | int | None | Auto-computed if None |
| keep_sink | bool | False | Always attend to first token |
| keep_recent | bool | False | Always attend to recent tokens |

Output

returns: attn_output
  attn_output: Tensor
    Shape: (batch, num_heads, q_len, head_dim)
    Sparse attention output

Algorithm Flow

Step 1: Auto-compute chunk_size

if chunk_size is None:
    chunk_size = int(max(
        min(
            max(2048, 1 << (k_len - 1).bit_length()),  # Round to power of 2
            128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()),  # Memory constraint
        ),
        2048,  # Minimum
    ))
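The rule can be evaluated on its own to see which chunk_size a given k_len gets. The sketch below restates the expression above as a function; the name `auto_chunk_size` is hypothetical. The memory term equals 2**28 // pow2. Up to k_len=16384, chunk_size is k_len rounded up to a power of two (minimum 2048). Beyond that it halves each time k_len doubles, until the 2048 floor applies.

```python
def auto_chunk_size(k_len: int) -> int:
    """Restates the Xattention_prefill default for chunk_size."""
    pow2 = 1 << (k_len - 1).bit_length()    # k_len rounded up to a power of two
    memory_cap = 128 * 1024 * 2048 // pow2  # 2**28 // pow2
    return int(max(min(max(2048, pow2), memory_cap), 2048))

for k_len in (4096, 8192, 16384, 32768, 65536, 131072):
    print(k_len, auto_chunk_size(k_len))
```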

Example:

  • k_len=8192 → chunk_size=8192
  • k_len=32768 → chunk_size=8192 (memory term: 2**28 // 32768)
  • k_len=65536 → chunk_size=4096 (memory term: 2**28 // 65536)

Step 2: Estimate attention and select blocks

attn_sums, approx_simple_mask = xattn_estimate(
    query_states, key_states,
    block_size=block_size, stride=stride, norm=norm,
    threshold=threshold, select_mode="inverse",
    use_triton=use_triton, causal=causal,
    chunk_size=chunk_size, kdb=kdb,
    keep_sink=keep_sink, keep_recent=keep_recent
)

Step 3: Prepare inputs for block_sparse_attn_func

# Hard constraints
assert block_size == 128
assert batch_size == 1

# Reshape to (seq_len, num_heads, head_dim)
query_states = query_states.transpose(1, 2).view(q_len, num_heads, head_dim)
key_states = key_states.transpose(1, 2).view(k_len, num_heads, head_dim)
value_states = value_states.transpose(1, 2).view(k_len, num_heads, head_dim)

# Cumulative sequence lengths
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)

# Head mask type: 1 for every head, i.e. all heads use the block mask (same device as the other inputs)
head_mask_type = torch.tensor([1] * num_heads, dtype=torch.int32, device=device)

Step 4: Call block_sparse_attn_func

attn_output = block_sparse_attn_func(
    query_states,      # (q_len, num_heads, head_dim)
    key_states,        # (k_len, num_heads, head_dim)
    value_states,      # (k_len, num_heads, head_dim)
    q_cu_seq_lens,     # [0, q_len]
    k_cu_seq_lens,     # [0, k_len]
    head_mask_type,    # [1, 1, ..., 1]
    None,              # No custom layout
    approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(),  # Block mask
    q_len,
    k_len,
    p_dropout=0.0,
    deterministic=True,
    is_causal=causal
)

Step 5: Reshape output

attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
# Output shape: (batch, num_heads, q_len, head_dim)

3. Triton Kernel Dependencies

Kernel 1: flat_group_gemm_fuse_reshape_kernel

Purpose: Compute QK^T with stride-based reshaping

Key Features:

  • Loads stride keys and queries at once
  • Fused strided access pattern
  • Causal masking support
  • Block size auto-selection based on GPU memory

Block Size Selection:

# RTX 3090 (<30GB): BLOCK_M=64, BLOCK_N=64
# A100/H100 (>=30GB): BLOCK_M=128, BLOCK_N=128

Signature:

flat_group_gemm_fuse_reshape(
    query_states,  # (batch, heads, q_len, head_dim)
    key_states,    # (batch, heads, k_len, head_dim)
    stride,        # Downsampling factor
    chunk_start,   # Start of the query chunk, in stride-downsampled key positions (for causal masking)
    chunk_end,     # End of the query chunk, in stride-downsampled key positions
    is_causal=True
)
# Returns: (batch, heads, q_len//stride, k_len//stride)

Kernel 2: softmax_fuse_block_sum_kernel_causal / _non_causal

Purpose: Online softmax with block aggregation

Algorithm:

  1. Pass 1 (running max m_i and normalizer l_i, one segment of X at a time):
    m_new = max(m_i, max(X))
    alpha = exp(m_i - m_new)
    l_i = l_i * alpha + sum(exp(X - m_new))
    m_i = m_new
    
  2. Pass 2 (normalize and aggregate):
    softmax = exp(X - m_i) / l_i
    aggregate to blocks: sum(softmax) over block_size
    

Key Features:

  • Online softmax: the normalized probability matrix is never written to memory; only block sums are stored
  • Causal masking integrated
  • Outputs block-level sums directly
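To make the two passes concrete, here is a plain PyTorch sketch for a 2-D score matrix, checked against torch.softmax. It illustrates the algorithm rather than reproducing the Triton kernel, and the function name is hypothetical. It differs from the kernel in two ways. First, it uses natural exp; the log2(e) factor in the kernel's scale suggests the kernel uses exp2, which gives the same result. Second, it reduces only over column blocks, whereas the kernel also sums the rows of each query block.

```python
import torch

def softmax_block_sums(scores: torch.Tensor, segment: int, block: int) -> torch.Tensor:
    """Row-wise softmax of a (rows, cols) matrix, reduced to sums over column blocks.

    Assumes cols % segment == 0 and segment % block == 0.
    """
    rows, cols = scores.shape
    m = torch.full((rows,), float("-inf"))  # running max m_i
    l = torch.zeros(rows)                   # running normalizer l_i
    # Pass 1: update m_i and l_i one segment at a time
    for start in range(0, cols, segment):
        x = scores[:, start:start + segment]
        m_new = torch.maximum(m, x.max(dim=-1).values)
        alpha = torch.exp(m - m_new)        # rescales the old l_i (0 on the first segment)
        l = l * alpha + torch.exp(x - m_new[:, None]).sum(dim=-1)
        m = m_new
    # Pass 2: normalize each segment and reduce it straight to block sums
    out = []
    for start in range(0, cols, segment):
        p = torch.exp(scores[:, start:start + segment] - m[:, None]) / l[:, None]
        out.append(p.reshape(rows, -1, block).sum(dim=-1))
    return torch.cat(out, dim=-1)           # (rows, cols // block)

torch.manual_seed(0)
scores = torch.randn(4, 64)
got = softmax_block_sums(scores, segment=16, block=8)
ref = torch.softmax(scores, dim=-1).reshape(4, 8, 8).sum(dim=-1)
```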

Signature:

softmax_fuse_block_sum(
    attn_weights_slice,  # (batch, heads, q_len, k_len)
    reshaped_block_size, # Block size (128//stride)
    segment_size,        # Processing segment (min(4096, block_size))
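    chunk_start,         # Start position (stride-downsampled coordinates)
    chunk_end,           # End position (stride-downsampled coordinates)
    real_q_len,          # Actual query length (before padding)
    scale,               # 1.4426950408889634 / sqrt(head_dim) / stride / norm; 1.4426950408889634 = log2(e), so exp2 can replace exp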
    is_causal=True
)
# Returns: (batch, heads, q_len//reshaped_block_size, k_len//reshaped_block_size), where q_len and k_len are the downsampled lengths of attn_weights_slice

4. Key Parameters and Their Meanings

Critical Parameters

| Parameter | Meaning | Typical Value | Impact |
|---|---|---|---|
| block_size | Block granularity | 128 | Fixed at 128; sets mask granularity |
| stride | Downsampling factor | 4-16 | Higher = faster but less accurate |
| threshold | Fraction of attention mass to cover | 0.8-0.9 | Higher = denser mask, more computation |
| chunk_size | Processing chunk | 16384 | Affects memory and efficiency |
| kdb | Key downsampling boost | 1 | Experimental, use 1 |
| norm | Scaling factor | 1.0 | Attention temperature control |

Trade-offs

Stride (stride):

  • stride=1: No downsampling; the estimate uses the full QK^T, so estimation costs as much as dense attention (the final attention is still block-sparse)
  • stride=4: 4x faster estimation, good accuracy
  • stride=8: 8x faster, moderate accuracy loss
  • stride=16: 16x faster, significant accuracy loss
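The speed factors follow from the shapes of the reshaped tensors. Concatenating `stride` slices along head_dim turns an (L, d) matrix into an (L/s, s·d) one. A back-of-envelope count, ignoring padding and kdb, gives:

```latex
% L: padded sequence length, d: head_dim, s: stride (kdb = 1)
\tilde{Q}, \tilde{K} \in \mathbb{R}^{(L/s) \times (s d)}
\qquad
\mathrm{FLOPs}\big(\tilde{Q}\tilde{K}^{\top}\big) \approx 2 \cdot \frac{L}{s} \cdot \frac{L}{s} \cdot s d = \frac{2 L^{2} d}{s}
\qquad
\text{score entries} = \frac{L^{2}}{s^{2}}
```

Dense QK^T needs 2L²d FLOPs and has L² entries. The estimation GEMM therefore shrinks by a factor of s, and the score matrix the softmax kernel scans shrinks by s². These are operation counts, not measured speedups.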

Threshold (threshold):

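  • threshold=0.8: For each query block, keep the fewest key blocks whose scores add up to 80% of that row's attention mass
  • threshold=0.9: Same rule at 90% (denser than 0.8)
  • threshold=0.95: Very dense; only the blocks holding the last 5% of mass are pruned

The fraction of blocks skipped is not 1 - threshold. It depends on how concentrated the attention is: when a few blocks dominate a row, even threshold=0.9 can prune most of that row's blocks.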

5. Dependencies

Required Libraries

  1. block_sparse_attn (CRITICAL)

    • Source: /home/zijie/Code/COMPASS/3rdparty/flash-attention/
    • Function: block_sparse_attn_func
    • Type: C++ CUDA extension
    • Build: Requires compilation with torch.utils.cpp_extension
  2. Triton (optional but recommended)

    • Required for: use_triton=True
    • GPU requirement: SM 80+ (A100, RTX 3090, H100, etc.)
    • Check: torch.cuda.get_device_capability()[0] >= 8
  3. PyTorch

    • Version: Compatible with flash-attention
    • Features: F.pad, matmul, softmax, view, transpose

Dependency Tree

Xattention_prefill
├── xattn_estimate
│   ├── flat_group_gemm_fuse_reshape (Triton)
│   ├── softmax_fuse_block_sum (Triton)
│   └── find_blocks_chunked (PyTorch)
└── block_sparse_attn_func (C++ CUDA)

6. Integration Issues for nano-vllm

Critical Issue 1: block_sparse_attn_func Dependency

Problem: block_sparse_attn_func is a C++ CUDA extension that must be compiled from flash-attention source.

Options:

  1. Compile flash-attention with block sparse support

    cd /home/zijie/Code/COMPASS/3rdparty/flash-attention
    python setup.py install
    
    • Risk: May conflict with existing flash-attention installation
    • Complexity: High (C++ compilation)
  2. Replace with FlashInfer block sparse

    • FlashInfer is already a dependency
    • Has similar block sparse attention
    • Need to adapt interface
  3. Custom CUDA kernel

    • Implement simplified block sparse attention
    • High development cost
    • Maintenance burden

Critical Issue 2: Hard-coded Constraints

assert block_size == 128  # Line 358
assert batch_size == 1    # Line 359

Impact:

  • Cannot process multiple sequences in one batch
  • Fixed block size limits flexibility
  • Must work around these constraints

Critical Issue 3: Triton GPU Requirement

props = torch.cuda.get_device_properties(torch.cuda.current_device())
if props.major < 8:
    use_triton = False

Impact:

  • Triton kernels only work on SM 80+ (A100, RTX 3090, H100)
  • Older GPUs (V100, T4, RTX 2080) fall back to slow PyTorch implementation
  • RTX 3090 works but uses smaller block sizes (64 vs 128)

Issue 4: Memory Layout

XAttention expects:

query_states: (batch, num_heads, q_len, head_dim)

nano-vllm uses:

query_states: (total_tokens, num_heads, head_dim)  # Packed (varlen) batch

Required: Transpose and reshape before/after calling XAttention
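A minimal sketch of that conversion follows. It assumes the packed (total_tokens, num_heads, head_dim) layout with cumulative sequence boundaries, as in flash-attention's varlen API. Both helper names are hypothetical. Producing one (1, num_heads, seq_len, head_dim) tensor per sequence also satisfies XAttention's batch_size == 1 constraint.

```python
import torch

def varlen_to_xattn(x: torch.Tensor, cu_seqlens: list[int]) -> list[torch.Tensor]:
    """Split packed (total_tokens, num_heads, head_dim) into per-sequence
    (1, num_heads, seq_len, head_dim) tensors, one XAttention call each."""
    return [
        x[start:end].transpose(0, 1).unsqueeze(0).contiguous()
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])
    ]

def xattn_to_varlen(outputs: list[torch.Tensor]) -> torch.Tensor:
    """Inverse: concatenate (1, num_heads, seq_len, head_dim) outputs back into the packed layout."""
    return torch.cat([o.squeeze(0).transpose(0, 1) for o in outputs], dim=0)

# Two prefill sequences of 6 and 4 tokens, 4 heads, head_dim 8
packed = torch.randn(10, 4, 8)
cu_seqlens = [0, 6, 10]
pieces = varlen_to_xattn(packed, cu_seqlens)
restored = xattn_to_varlen(pieces)
```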

Issue 5: Chunking Incompatibility

XAttention: Processes in fixed-size chunks (e.g., 16384 tokens)

  • Requires padding to chunk boundaries
  • Adds overhead for short sequences

nano-vllm: Processes variable-length requests

  • No padding requirement
  • Dynamic batch sizing

7. Integration Strategy

  1. Keep xattn_estimate (pure PyTorch + Triton)

    • No external dependencies
    • Computes block mask
  2. Replace block_sparse_attn_func with FlashInfer

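    • FlashInfer: flashinfer.single_prefill_with_kv_cache (dense custom_mask) or flashinfer.sparse.BlockSparseAttentionWrapper (block-sparse layout)
    • Similar API, already compiled
    • Block-sparse support through BlockSparseAttentionWrapper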
  3. Adapt mask format

    • XAttention: (batch, heads, q_blocks, k_blocks) boolean mask
    • FlashInfer: custom_mask of shape (qo_len, kv_len) shared by all heads, or a block-sparse-row (BSR) block layout
  4. Handle constraints

    • Enforce batch_size=1 by processing one request at a time
    • Keep block_size=128 as requirement
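For FlashInfer's block-sparse path, each dense (q_blocks, k_blocks) boolean mask must become block-sparse-row (BSR) arrays: a row-pointer array plus the column index of every kept block. The sketch below does that for one head; the helper name is hypothetical. It assumes the target accepts BSR indptr/indices, as the plan step of flashinfer.sparse.BlockSparseAttentionWrapper does. Check the installed FlashInfer version for the exact signature and for whether per-head layouts are supported. If only one layout per call is allowed, there are two options. You can call once per head, or take the union `block_mask.any(dim=0)`, which keeps every selected block but gives up some sparsity.

```python
import torch

def block_mask_to_bsr(block_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert a (q_blocks, k_blocks) bool mask into BSR (indptr, indices) int32 arrays."""
    indptr = torch.zeros(block_mask.shape[0] + 1, dtype=torch.int32)
    indptr[1:] = block_mask.sum(dim=-1).cumsum(dim=0)  # running count of kept blocks per row
    # nonzero() lists (row, col) pairs in row-major order, so columns come grouped by row
    indices = block_mask.nonzero()[:, 1].to(torch.int32)
    return indptr, indices

# Per-head masks for 2 heads, 3 query blocks x 3 key blocks (causal, lower-triangular)
per_head = torch.tensor([
    [[1, 0, 0], [1, 1, 0], [1, 0, 1]],
    [[1, 0, 0], [0, 1, 0], [1, 1, 1]],
], dtype=torch.bool)
indptr, indices = block_mask_to_bsr(per_head[0])
union_indptr, union_indices = block_mask_to_bsr(per_head.any(dim=0))
```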

Alternative: Pure PyTorch Implementation

  1. Extract estimation algorithm
  2. Implement sparse attention using PyTorch operations
  3. Use FlashInfer for final computation
  4. No Triton dependency

8. Code Example: Adaptation

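import torch
from flashinfer import single_prefill_with_kv_cache
# xattn_estimate is assumed to be importable from the COMPASS sources (path not shown here)

def xattention_prefill_adapted(
    query_states,  # (q_len, num_heads, head_dim): the q layout FlashInfer expects
    key_states,    # (k_len, num_heads, head_dim): FlashInfer default "NHD" kv_layout
    value_states,  # (k_len, num_heads, head_dim)
    stride=4,
    threshold=0.9,
    block_size=128,
    causal=True,
):
    q_len, num_heads, _ = query_states.shape
    k_len = key_states.shape[0]

    # Step 1: XAttention expects (batch=1, heads, seq_len, head_dim)
    q = query_states.transpose(0, 1).unsqueeze(0).contiguous()
    k = key_states.transpose(0, 1).unsqueeze(0).contiguous()

    # Step 2: Estimate mask (PyTorch + Triton only, no block_sparse_attn_func)
    _, block_mask = xattn_estimate(
        q, k,
        block_size=block_size,
        stride=stride,
        threshold=threshold,
        use_triton=True,
        causal=causal,
    )
    # block_mask: (1, heads, q_blocks, k_blocks), including chunk padding

    # Step 3: Convert block mask to token mask and trim padding.
    # This is a dense (heads, q_len, k_len) bool tensor: fine for validating
    # mask quality, too large to be a fast path for long contexts.
    token_mask = block_mask[0].repeat_interleave(block_size, dim=-2)
    token_mask = token_mask.repeat_interleave(block_size, dim=-1)
    token_mask = token_mask[:, :q_len, :k_len]
    if causal:
        # Block-level causality still exposes later tokens inside each diagonal
        # block, so AND with a token-level causal mask (assumes q_len == k_len).
        token_mask = token_mask & torch.ones(
            q_len, k_len, dtype=torch.bool, device=token_mask.device
        ).tril()

    # Step 4: FlashInfer's custom_mask is (qo_len, kv_len) and shared by all
    # heads, so per-head XAttention masks need one call per head.
    outputs = [
        single_prefill_with_kv_cache(
            query_states[:, h : h + 1].contiguous(),
            key_states[:, h : h + 1].contiguous(),
            value_states[:, h : h + 1].contiguous(),
            custom_mask=token_mask[h],
        )
        for h in range(num_heads)
    ]
    return torch.cat(outputs, dim=1)  # (q_len, num_heads, head_dim)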

9. Summary of Findings

Advantages

  1. Accurate approximation: Pattern-based stride selection preserves attention patterns
  2. Flexible sparsity: Threshold-based control over computation
  3. GPU optimization: Triton kernels for estimation phase
  4. Proven in practice: Used in COMPASS system

Challenges

  1. Hard dependency: block_sparse_attn_func requires C++ compilation
  2. Rigid constraints: block_size=128, batch_size=1
  3. GPU-specific: Triton only on SM 80+
  4. Memory layout mismatch: Requires reshape/transpose
  5. Chunking overhead: Padding to chunk boundaries

Integration Complexity

| Component | Complexity | Risk |
|---|---|---|
| xattn_estimate | Medium | Low (PyTorch + Triton) |
| block_sparse_attn_func | High | Critical (C++ dependency) |
| Interface adaptation | Low | Low (reshape) |
| Constraint handling | Medium | Medium (workarounds) |

Overall Integration Risk: HIGH (due to C++ dependency)


10. Next Steps

  1. Evaluate FlashInfer compatibility

    • Can FlashInfer replace block_sparse_attn_func?
    • What mask format does it expect?
  2. Prototype estimation phase

    • Extract xattn_estimate function
    • Test with nano-vllm inputs
    • Validate mask quality
  3. Benchmark Triton kernels

    • Compare Triton vs PyTorch estimation
    • Measure speedup on RTX 3090
    • Profile memory usage
  4. Design interface

    • Define nano-vllm sparse attention API
    • Specify mask format
    • Plan integration points