# Sparse Attention Computation Flow Guide
This document explains the computation flow for block sparse attention methods used in long-context LLM inference. The implementations are from the x-attention project.
## Overview
All sparse attention methods follow a common pattern:
1. Estimate important blocks (low-cost estimation)
2. Create block mask (True = compute, False = skip)
3. Execute block sparse attention kernel (only compute selected blocks)
The key difference between methods is how they estimate which blocks are important.
## Common Concepts
### Block Tiling
All methods divide Q, K, V into blocks:
- **Block Size**: typically 64 or 128 tokens per block
- **Num Q Blocks**: `ceil(seq_len / block_size)`
- **Num K Blocks**: `ceil(seq_len / block_size)`
- **Block Mask**: boolean tensor `[batch, heads, q_blocks, k_blocks]`
### Causal Masking
For autoregressive models, block (i, j) is only valid if j <= i (causal constraint).
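Here is a minimal NumPy sketch of the block counts and the dense causal block mask that a sparse method starts from. The helper name `causal_block_mask` is hypothetical and not part of any library discussed in this guide.

```python
import math

import numpy as np


def causal_block_mask(seq_len, block_size, batch=1, heads=1):
    """Return a causal block mask of shape [batch, heads, q_blocks, k_blocks].

    Block (i, j) is True only when j <= i. A partially filled last block
    still counts as a block, hence the ceil.
    """
    num_blocks = math.ceil(seq_len / block_size)
    tri = np.tril(np.ones((num_blocks, num_blocks), dtype=bool))
    # Same pattern for every batch element and head; copy to get a writable array.
    return np.broadcast_to(tri, (batch, heads, num_blocks, num_blocks)).copy()


mask = causal_block_mask(seq_len=300, block_size=128, batch=2, heads=4)
print(mask.shape)  # (2, 4, 3, 3)
```

A sparse method then switches off some of the True entries. Blocks above the diagonal are never needed.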
### Block Sparse Attention Kernel
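The block-mask methods (XAttention, FlexPrefill, AvgPool) ultimately call `block_sparse_attn_func` from MIT-HAN-LAB. MInference uses its own vertical-slash kernel instead. The call below is schematic, using the argument names and shapes from this guide:
```python
from block_sparse_attn import block_sparse_attn_func

output = block_sparse_attn_func(
    query_states,    # [batch, seq, heads, head_dim]
    key_states,      # [batch, seq, heads, head_dim]
    value_states,    # [batch, seq, heads, head_dim]
    block_mask,      # [batch, heads, q_blocks, k_blocks] bool
    block_size=128,  # tokens per block (BSA uses a fixed 128)
    causal=True,
)
```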
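To pin down what the kernel is assumed to compute, here is a dense NumPy reference. The helper is hypothetical and not part of BSA. It runs ordinary scaled-dot-product attention in which every key outside a selected block is masked out. The real kernel skips masked blocks instead of materializing the full `S x S` score matrix, so this sketch is only useful for checking semantics on small inputs.

```python
import numpy as np


def block_sparse_attention_reference(q, k, v, block_mask, block_size, causal=True):
    """Dense reference for block sparse attention (small inputs only).

    q, k, v:    [batch, seq, heads, head_dim]
    block_mask: [batch, heads, q_blocks, k_blocks] bool, True = compute
    returns:    [batch, seq, heads, head_dim]
    """
    B, S, H, D = q.shape
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) / np.sqrt(D)  # [B, H, S, S]

    # Expand the block mask to token resolution and crop any padded tail.
    token_mask = np.repeat(np.repeat(block_mask, block_size, axis=2), block_size, axis=3)
    token_mask = token_mask[:, :, :S, :S]
    if causal:
        token_mask = token_mask & np.tril(np.ones((S, S), dtype=bool))

    # Softmax over selected keys only; a row with no selected key outputs zeros.
    scores = np.where(token_mask, scores, -np.inf)
    row_max = scores.max(axis=-1, keepdims=True)
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)
    weights = np.exp(scores - row_max)
    denom = weights.sum(axis=-1, keepdims=True)
    weights = np.divide(weights, denom, out=np.zeros_like(weights), where=denom > 0)
    return np.einsum("bhqk,bkhd->bqhd", weights, v)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((1, 8, 2, 4)) for _ in range(3))
diag_only = np.ones((1, 2, 2, 2), dtype=bool) & np.eye(2, dtype=bool)  # diagonal blocks only
out = block_sparse_attention_reference(q, k, v, diag_only, block_size=4)
print(out.shape)  # (1, 8, 2, 4)
```

With an all-True mask this reduces to dense causal attention. Any sparse mask can only remove keys from each row's softmax.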
## Method 1: XAttention (xattn_estimate)
**Source**: `compass/src/Xattention.py`
**Detailed documentation**: `docs/xattention_algorithm_guide.md`
### Core Idea
Use a stride-interleaved reshape (inverse mode) to cheaply estimate block-level attention importance, then use the BSA (Block Sparse Attention) library for the sparse computation.
### Algorithm
```python
def xattn_estimate(query, key, block_size=128, stride=8):
    """
    Estimate block importance using stride-interleaved attention.

    1. K reshape (forward interleave):
           concat([K[:, :, k::stride, :] for k in range(stride)], dim=-1)
       Q reshape (reverse interleave):
           concat([Q[:, :, (stride-1-q)::stride, :] for q in range(stride)], dim=-1)
       Result: sequence length seq_len -> seq_len/stride, head_dim -> head_dim*stride
    2. Triton kernel (flat_group_gemm_fuse_reshape):
       fuses the reshape with the GEMM, computing Q_reshaped @ K_reshaped^T
    3. Triton kernel (softmax_fuse_block_sum):
       online softmax + summation in groups of block_size/stride
       Output: attn_sum [batch, heads, q_blocks, k_blocks]
    4. find_blocks_chunked:
       sort attn_sum in descending order; blocks accumulated up to threshold are marked True.
       Diagonal and sink blocks are always kept.
    """
```
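### Key Parameters
| Parameter | Default | Description |
|---|---|---|
| `block_size` | 128 | Tokens per block (BSA requires a fixed 128) |
| `stride` | 8 | Q/K interleaved sampling stride; larger values estimate faster but more coarsely |
| `threshold` | 0.9 | Cumulative attention threshold; selects blocks until their cumulative weight reaches this fraction |
| `chunk_size` | 16384 | Chunk size used during estimation |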
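Steps 1 to 3 of the algorithm can be sketched densely in NumPy to show what the two fused Triton kernels compute. This sketch materializes the full reshaped score matrix instead of using an online softmax. Its `1/sqrt(head_dim)` scaling is an assumption (the real kernels may scale differently), and the helper names are hypothetical. Each reshaped dot product sums `q[i*stride + stride-1-m] . k[j*stride + m]` over `m`, i.e. the antidiagonal of a `stride x stride` tile. That is why the Q interleave runs in reverse.

```python
import numpy as np


def stride_reshape(query, key, stride):
    """query, key: [B, H, S, D] -> q_r, k_r: [B, H, S // stride, D * stride]."""
    # K: forward interleave (offsets 0, 1, ..., stride-1) along the feature dim.
    k_r = np.concatenate([key[:, :, k::stride, :] for k in range(stride)], axis=-1)
    # Q: reverse interleave (offsets stride-1, ..., 0), pairing with K antidiagonally.
    q_r = np.concatenate(
        [query[:, :, (stride - 1 - q)::stride, :] for q in range(stride)], axis=-1
    )
    return q_r, k_r


def estimate_block_scores(query, key, block_size=128, stride=8):
    """Dense stand-in for flat_group_gemm_fuse_reshape + softmax_fuse_block_sum.

    Returns attn_sum: [B, H, q_blocks, k_blocks].
    """
    B, H, S, D = query.shape
    assert S % block_size == 0 and block_size % stride == 0
    q_r, k_r = stride_reshape(query, key, stride)
    n = S // stride                                        # reshaped sequence length
    scores = q_r @ np.swapaxes(k_r, -1, -2) / np.sqrt(D)   # [B, H, n, n]

    # Causal mask at reshaped resolution, then a numerically stable softmax.
    scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)

    # Sum groups of block_size // stride reshaped rows/cols into one block.
    g = block_size // stride
    nb = n // g
    return probs.reshape(B, H, nb, g, nb, g).sum(axis=(3, 5))


rng = np.random.default_rng(0)
q = rng.standard_normal((1, 2, 256, 16))
k = rng.standard_normal((1, 2, 256, 16))
attn_sum = estimate_block_scores(q, k, block_size=64, stride=8)
print(attn_sum.shape)  # (1, 2, 4, 4)
```

Each row of reshaped probabilities sums to 1, so every block row of `attn_sum` sums to `block_size // stride`. The cumulative threshold is applied relative to that row total.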
### Computation Flow
```
query [B, H, S, D]
|
v
Stride interleaved reshape (Triton fused)
|
v
flat_group_gemm_fuse_reshape: Q_r @ K_r^T
|
v
softmax_fuse_block_sum: online softmax + block sum
|
v
attn_sum [B, H, q_blocks, k_blocks]
|
v
find_blocks_chunked: cumulative-threshold selection
|
v
simple_mask [B, H, q_blocks, k_blocks] (bool)
|
v
block_sparse_attn_func(q, k, v, simple_mask) ← BSA library
|
v
output [B, H, S, D]
```
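The `find_blocks_chunked` selection step can be sketched in NumPy. It is not the chunked implementation, and the helper name and the exact tie and boundary behavior are assumptions. For each query-block row, blocks are taken in descending score order while the mass collected before them is still below `threshold` of the row total. The sink column and the diagonal are then forced on. The sketch assumes `q_blocks == k_blocks` (full prefill).

```python
import numpy as np


def select_blocks_by_threshold(attn_sum, threshold=0.9):
    """attn_sum: [..., n_blocks, n_blocks] -> bool mask of the same shape."""
    n = attn_sum.shape[-1]
    order = np.argsort(-attn_sum, axis=-1)                  # descending per row
    sorted_vals = np.take_along_axis(attn_sum, order, axis=-1)
    mass_before = np.cumsum(sorted_vals, axis=-1) - sorted_vals
    keep_sorted = mass_before < threshold * attn_sum.sum(axis=-1, keepdims=True)

    # Scatter the decisions back to the original column order.
    mask = np.zeros(attn_sum.shape, dtype=bool)
    np.put_along_axis(mask, order, keep_sorted, axis=-1)

    mask[..., :, 0] = True                                  # sink block (first column)
    diag = np.arange(n)
    mask[..., diag, diag] = True                            # diagonal blocks
    return mask & np.tril(np.ones((n, n), dtype=bool))      # never keep future blocks
```

Lowering `threshold` keeps fewer blocks per row, which raises sparsity at the cost of dropping more attention mass.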
### Dependencies
```python
from block_sparse_attn import block_sparse_attn_func  # MIT-HAN-LAB BSA library
import triton  # Triton kernels for estimation
```
### Usage
```python
from compass.src.Xattention import Xattention_prefill

output = Xattention_prefill(
    query_states, key_states, value_states,
    threshold=0.9,
    stride=8,
    block_size=128,
    use_triton=True,
)
```
---
## Method 2: FlexPrefill
**Source**: `xattn/src/Flexprefill.py`
### Core Idea
Use **last-q attention pattern** to detect vertical (column-wise important) and slash (diagonal-like) patterns. Adaptively adjust budget based on **JS divergence** between estimated and uniform distribution.
### Algorithm
```python
def Flexprefill_prefill(query, key, value, gamma=0.9, tau=0.1):
    """
    1. Compute attention using only the last 64 queries (last_q_attn)
       This reveals which K positions are globally important
    2. Detect vertical pattern: columns with high attention across all last-q queries
       These are "sink tokens" that all queries attend to
    3. Detect slash pattern: diagonal bands that capture local attention
    4. Compute JS divergence between the estimated pattern and uniform
       - High divergence = sparse pattern detected, use fewer blocks
       - Low divergence = dense pattern needed, use more blocks
    5. Adjust budget per head based on divergence
    6. Select top-scoring blocks up to budget
    """
```
### Key Parameters
| Parameter | Default | Description |
|---|---|---|
| `gamma` | 0.9 | Base coverage ratio (fraction of blocks to keep) |
| `tau` | 0.1 | JS divergence threshold for adaptive budget |
| `min_budget` | 0.5 | Minimum coverage even for sparse patterns |
| `block_size` | 64 | Tokens per block |
### Patterns Detected
- **Vertical Pattern**: columns where many queries attend heavily
  - Detected by summing attention across the query dimension
  - Captures "attention sinks" (e.g., BOS token, punctuation)
- **Slash Pattern**: diagonal bands
  - Captures local context attention
  - Width determined by the `slash_size` parameter
### Computation Flow
```
query [B, S, H, D]
|
v
Take last 64 queries -> last_q [B, 64, H, D]
|
v
Compute last_q attention to all K -> attn [B, H, 64, S]
|
v
Analyze pattern:
- Vertical: sum over query dim, find high columns
- Slash: diagonal bands
|
v
Compute JS divergence per head
|
v
Adaptive budget = gamma * (1 - divergence/tau)
|
v
Select blocks up to budget -> block_mask
|
v
block_sparse_attn_func(q, k, v, block_mask)
|
v
output [B, S, H, D]
```
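The divergence and budget steps can be sketched in NumPy exactly as this guide states them. The sketch computes the JS divergence of the column-importance distribution against a uniform one, then `gamma * (1 - divergence / tau)`. Clamping the result to `[min_budget, gamma]` is an assumption made here. The sketch has not been checked against `Flexprefill.py`, and the helper names are hypothetical.

```python
import numpy as np


def js_divergence(p, q, eps=1e-12):
    """Jensen-Shannon divergence (natural log) along the last axis."""
    p = p / p.sum(axis=-1, keepdims=True)
    q = q / q.sum(axis=-1, keepdims=True)
    m = 0.5 * (p + q)
    kl_pm = np.sum(p * np.log((p + eps) / (m + eps)), axis=-1)
    kl_qm = np.sum(q * np.log((q + eps) / (m + eps)), axis=-1)
    return 0.5 * (kl_pm + kl_qm)


def adaptive_budget(last_q_attn, gamma=0.9, tau=0.1, min_budget=0.5):
    """last_q_attn: [heads, n_last_queries, seq] attention probabilities.

    Returns (budget, divergence), each of shape [heads].
    """
    column_importance = last_q_attn.sum(axis=1)              # vertical pattern, [H, S]
    uniform = np.ones_like(column_importance)
    divergence = js_divergence(column_importance, uniform)   # 0 for a flat pattern
    budget = gamma * (1.0 - divergence / tau)
    return np.clip(budget, min_budget, gamma), divergence
```

With these defaults a head with flat attention keeps the full `gamma = 0.9` budget. Any head whose divergence exceeds `tau` drops to `min_budget`, which matches step 4 (high divergence means fewer blocks).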
### Triton Kernels
FlexPrefill uses custom Triton kernels for efficiency:
- `flex_prefill_attention_kernel`: block-wise attention with pattern masking
- `flex_vertical_slash_kernel`: combined vertical + slash pattern attention
## Method 3: MInference
**Source**: `xattn/src/Minference.py`
### Core Idea
Simple and direct: call the `vertical_slash_sparse_attention` kernel with precomputed vertical and slash indices.
### Algorithm
```python
def Minference_prefill(query, key, value, vertical_topk=100):
    """
    1. Compute attention from the last 64 queries to all K
    2. For each head, identify:
       - vertical_indices: top-k columns with the highest total attention
       - slash_indices: diagonal band positions
    3. Call the vertical_slash_sparse_attention kernel
       - Computes attention only at the selected positions
       - Keys outside the selected positions contribute nothing (treated as masked)
    """
```
### Key Parameters
| Parameter | Default | Description |
|---|---|---|
| `vertical_topk` | 100 | Number of important columns to select |
| `slash_n` | 64 | Size of diagonal band |
### Computation Flow
```
query [B, S, H, D]
|
v
Take last 64 queries
|
v
Compute attention scores to all K
|
v
Sum across queries -> column importance [H, S]
|
v
Select top-k columns per head -> vertical_indices [H, topk]
|
v
Generate slash indices (diagonal positions) -> slash_indices [H, slash_n]
|
v
vertical_slash_sparse_attention(q, k, v, vertical_indices, slash_indices)
|
v
output [B, S, H, D]
```
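The index-estimation steps above can be sketched in NumPy for a single batch element `[H, S, D]`. The sketch assumes slash lines are ranked by summing the last-query attention along each diagonal offset `q_pos - k_pos`, and it returns those offsets (0 = main diagonal). The helper names are hypothetical, and the index layout that `vertical_slash_sparse_attention` actually expects may differ.

```python
import numpy as np


def last_q_attention(query, key, last_q=64):
    """query, key: [H, S, D] -> causal softmax of the last `last_q` queries, [H, last_q, S]."""
    H, S, D = query.shape
    scores = query[:, -last_q:, :] @ key.transpose(0, 2, 1) / np.sqrt(D)
    q_pos = np.arange(S - last_q, S)[:, None]
    k_pos = np.arange(S)[None, :]
    scores = np.where(k_pos <= q_pos, scores, -np.inf)            # causal mask
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return probs / probs.sum(axis=-1, keepdims=True)


def vertical_slash_indices(attn, vertical_topk=100, slash_n=64):
    """attn: [H, n, S] -> (vertical_indices [H, topk], slash_offsets [H, slash_n])."""
    H, n, S = attn.shape
    # Vertical: columns with the largest total attention across the last queries.
    vertical = np.argsort(-attn.sum(axis=1), axis=-1)[:, :vertical_topk]
    # Slash: diagonals ranked by total attention at each offset q_pos - k_pos.
    offset = np.arange(S - n, S)[:, None] - np.arange(S)[None, :]  # [n, S]
    causal = offset >= 0
    slash_scores = np.zeros((H, S))
    for h in range(H):
        np.add.at(slash_scores[h], offset[causal], attn[h][causal])
    slash = np.argsort(-slash_scores, axis=-1)[:, :slash_n]
    return np.sort(vertical, axis=-1), np.sort(slash, axis=-1)
```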
### CUDA Kernel
MInference uses a specialized CUDA kernel:
```python
from minference.ops.block_sparse_attn import vertical_slash_sparse_attention

output = vertical_slash_sparse_attention(
    query, key, value,
    vertical_indices,  # [heads, topk]
    slash_indices,     # [heads, slash_n]
)
```
## Method 4: AvgPool
**Source**: `xattn/src/AvgPool.py`
### Core Idea
Pool Q and K within each block using average pooling, compute block-level softmax attention, then select blocks using top-k or top-p (nucleus sampling).
### Algorithm
```python
def AvgPool_prefill(query, key, value, block_size=128, top_k=64, top_p=None):
    """
    1. Divide Q into blocks, apply average pooling -> block_q [B, q_blocks, H, D]
    2. Divide K into blocks, apply average pooling -> block_k [B, k_blocks, H, D]
    3. Compute block attention: softmax(block_q @ block_k.T / sqrt(d))
       Result: [B, H, q_blocks, k_blocks]
    4. Select blocks:
       - top_k: Select k highest scoring blocks per row
       - top_p: Sort descending, accumulate until sum > p, select all accumulated
    5. Create block_mask from selection
    6. Execute block_sparse_attn_func
    """
```
### Key Parameters
| Parameter | Default | Description |
|---|---|---|
| `block_size` | 128 | Tokens per block |
| `chunk_size` | 16384 | Processing chunk size |
| `top_k` | 64 | Fixed number of blocks per row |
| `top_p` | None | Cumulative probability threshold (0.0-1.0) |
| `pool_method` | "avg" | Pooling method (`avg`, `max`) |
### Top-K vs Top-P Selection
**Top-K**: select exactly k blocks with the highest scores per row.
```python
scores = [0.4, 0.3, 0.2, 0.1]
k = 2
selected = [0.4, 0.3]  # indices 0, 1
```
**Top-P (Nucleus)**: sort descending and accumulate until the running sum exceeds the threshold.
```python
scores = [0.4, 0.3, 0.2, 0.1]
p = 0.8
cumsum = [0.4, 0.7, 0.9, 1.0]
selected = [0.4, 0.3, 0.2]  # first 3 (cumsum exceeds 0.8 at index 2)
```
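Here is a small plain-Python sketch of both rules. The helper names are hypothetical. The top-p rule keeps adding blocks while the accumulated score is still below `p`, which reproduces the example above. The third block is included because the running sum only reaches 0.8 once that block is added.

```python
def top_k_blocks(scores, k):
    """Indices of the k highest-scoring blocks, in descending score order."""
    return sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)[:k]


def top_p_blocks(scores, p):
    """Take blocks in descending order; stop once the accumulated score reaches p."""
    selected, total = [], 0.0
    for i in sorted(range(len(scores)), key=lambda j: scores[j], reverse=True):
        if total >= p:
            break
        selected.append(i)
        total += scores[i]
    return selected


scores = [0.4, 0.3, 0.2, 0.1]
print(top_k_blocks(scores, k=2))    # [0, 1]
print(top_p_blocks(scores, p=0.8))  # [0, 1, 2]
```

Top-k gives every row the same cost. Top-p lets rows with concentrated attention keep fewer blocks.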
### Computation Flow
```
query [B, S, H, D]
|
v
Reshape to blocks [B, num_blocks, block_size, H, D]
|
v
Average pool -> block_q [B, num_blocks, H, D]
|
v
Same for K -> block_k [B, num_blocks, H, D]
|
v
Compute block scores [B, H, q_blocks, k_blocks]
|
v
Apply causal mask
|
v
top_k or top_p selection -> block_mask
|
v
block_sparse_attn_func(q, k, v, block_mask)
|
v
output [B, S, H, D]
```
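The pooling and block-score steps can be sketched in NumPy. The sketch assumes `seq_len` is a multiple of `block_size`. It applies the causal mask before the softmax so that each row stays normalized over valid blocks, whereas the flow above lists masking after score computation. `avgpool_block_scores` is a hypothetical name.

```python
import numpy as np


def avgpool_block_scores(query, key, block_size=128):
    """query, key: [B, S, H, D] -> block attention [B, H, q_blocks, k_blocks]."""
    B, S, H, D = query.shape
    assert S % block_size == 0, "sketch assumes whole blocks"
    nb = S // block_size
    # Average-pool each block of tokens: [B, nb, block_size, H, D] -> [B, nb, H, D].
    block_q = query.reshape(B, nb, block_size, H, D).mean(axis=2)
    block_k = key.reshape(B, nb, block_size, H, D).mean(axis=2)
    scores = np.einsum("bqhd,bkhd->bhqk", block_q, block_k) / np.sqrt(D)
    # Causal at block level: block (i, j) is valid only if j <= i.
    scores = np.where(np.tril(np.ones((nb, nb), dtype=bool)), scores, -np.inf)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return probs / probs.sum(axis=-1, keepdims=True)
```

The resulting rows can be fed directly to a top-k or top-p rule to build `block_mask`.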
## Comparison
| Method | Estimation Cost | Adaptivity | Typical Sparsity |
|---|---|---|---|
| XAttention | Medium (strided attn) | Threshold-based | 60-80% |
| FlexPrefill | Medium (last-q attn) | JS divergence | 50-70% |
| MInference | Low (last-q attn) | Fixed vertical+slash | 70-90% |
| AvgPool | Medium (pooled attn) | top-k/top-p | 50-80% |
### When to Use Each
- XAttention: General purpose, good balance of accuracy and speed
- FlexPrefill: When pattern varies significantly across heads
- MInference: When vertical (sink) and local (slash) patterns dominate
- AvgPool: When you want simple, interpretable block selection
## Integration with HuggingFace Models
The four prefill methods above (XAttention, FlexPrefill, MInference, AvgPool) integrate via `FastPrefillConfig` in `xattn/src/load_llama.py`:
```python
from xattn.src.load_llama import FastPrefillConfig, load_model_and_apply_fastprefill

config = FastPrefillConfig(
    metric="xattn",   # "xattn", "flexprefill", "minference", "avgpool"
    threshold=0.9,    # for xattn
    stride=16,        # for xattn
    top_k=64,         # for avgpool
    top_p=None,       # for avgpool nucleus sampling
)

model, tokenizer = load_model_and_apply_fastprefill(
    model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
    fastprefillconfig=config,
)
```
The forward method of attention layers is monkey-patched to use the selected sparse attention method during prefill.
## Key Files
| File | Purpose |
|---|---|
| `xattn/src/Xattention.py` | XAttention implementation |
| `xattn/src/Flexprefill.py` | FlexPrefill implementation |
| `xattn/src/Minference.py` | MInference implementation |
| `xattn/src/AvgPool.py` | AvgPool implementation |
| `xattn/src/load_llama.py` | Model loading and method dispatch |
| `xattn/src/Compass.py` | Another sparse method (gradient-based) |
## Dependencies
Required libraries:
- `torch`: PyTorch
- `triton`: Triton kernels for XAttention estimation and FlexPrefill
- `flash_attn`: Flash Attention for the baseline
- `block_sparse_attn`: MIT-HAN-LAB block sparse kernel
- `minference`: MInference `vertical_slash` kernel

Docker image `tzj/xattn:v0.5` has all dependencies pre-installed.
## Quest Sparse Policy
**Files**: `nanovllm/kvcache/sparse/quest.py`, `nanovllm/kvcache/sparse/policy.py`
### Core Idea
Quest policy selects Top-K blocks based on query-key similarity bounds using min/max key metadata. This enables efficient block selection for CPU offload scenarios.
### Scoring Mechanism
```python
import torch

# q: [kv_heads, head_dim]; key_min / key_max: [num_blocks, kv_heads, head_dim]
# Compute scores using key metadata bounds
score_min = torch.einsum('hd,bhd->bh', q, key_min)  # [num_blocks, kv_heads]
score_max = torch.einsum('hd,bhd->bh', q, key_max)  # [num_blocks, kv_heads]
scores = torch.maximum(score_min, score_max).mean(dim=-1)  # [num_blocks] ← averaged!
```
### Critical Limitation - No Per-Head Scheduling
The `.mean(dim=-1)` averages scores across all heads, producing a single block selection shared by every head:
```
Block A: head0 needs (+4), head1 doesn't (-4) → avg = 0 → NOT selected
Block B: head0 doesn't (-4), head1 needs (+4) → avg = 0 → NOT selected
Block C: both heads moderately need (+2, +2) → avg = +2 → selected
```
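The example can be reproduced with the scoring formula, using NumPy `einsum` in place of `torch.einsum`. The one-dimensional `head_dim` and the key-metadata values are made up so that the per-head scores come out as +4/-4/+2. The per-head top-1 choice is shown only for contrast; the policy itself does not compute it.

```python
import numpy as np

# Hypothetical metadata: 3 blocks (A, B, C), 2 KV heads, head_dim = 1.
q = np.array([[1.0], [1.0]])                       # [kv_heads, head_dim]
key_min = np.array([[[4.0], [-4.0]],               # block A
                    [[-4.0], [4.0]],               # block B
                    [[2.0], [2.0]]])               # block C
key_max = key_min.copy()                           # [num_blocks, kv_heads, head_dim]

score_min = np.einsum("hd,bhd->bh", q, key_min)    # [num_blocks, kv_heads]
score_max = np.einsum("hd,bhd->bh", q, key_max)
per_head = np.maximum(score_min, score_max)        # [num_blocks, kv_heads]
scores = per_head.mean(axis=-1)                    # [num_blocks], averaged over heads

print(scores)                      # [0. 0. 2.]
print(int(np.argmax(scores)))      # 2 -> block C is the shared top-1
print(per_head.argmax(axis=0))     # [0 1] -> per-head top-1 would be A and B
```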
### Why Per-Head Scheduling is Infeasible
- **Memory Layout**: the GPU cache stores all heads together as `[block_size, kv_heads, head_dim]`
- **FlashAttention**: requires complete heads; partial heads cause a dimension mismatch
- **Block Granularity**: if any head needs a block, the entire block (all heads) must be loaded
### Policy Types
| Policy | `supports_prefill` | `supports_decode` | Description |
|---|---|---|---|
| `FullAttentionPolicy` | True | True | Loads all blocks (no sparsity) |
| `QuestPolicy` | False | True | Decode-only Top-K selection |
### Usage Example
```python
from nanovllm.kvcache.sparse.policy import QuestPolicy

# Create Quest policy for decode-only sparse attention
policy = QuestPolicy(topk=8, threshold=4.0)

# Select blocks based on query and key metadata
selected_blocks = policy.select_blocks(
    query,    # [num_tokens, num_heads, head_dim]
    key_min,  # [num_blocks, num_heads, head_dim]
    key_max,  # [num_blocks, num_heads, head_dim]
)
```
### Key Parameters
| Parameter | Default | Description |
|---|---|---|
| `topk` | 8 | Number of blocks to select |
| `threshold` | 4.0 | Minimum score threshold for selection |
### Integration with CPU Offload
The Quest policy is used in conjunction with CPU offload to reduce the number of blocks transferred from CPU to GPU during decode:
- During prefill, all blocks are loaded (full attention)
- During decode, Quest selects only top-K important blocks
- Only selected blocks are transferred from CPU to GPU
- This reduces memory bandwidth requirements for long sequences
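To put the bandwidth saving in numbers, here is a back-of-the-envelope helper. It is hypothetical: it assumes K and V are each stored per block as `[block_size, kv_heads, head_dim]` in a 2-byte dtype, and it ignores the min/max metadata. The configuration below (8 KV heads, head_dim 128, 1024-token blocks, a 128K-token context) is illustrative, not a measured setup.

```python
def kv_bytes_per_layer(num_blocks, block_size, kv_heads, head_dim, dtype_bytes=2):
    """Bytes of K plus V transferred for `num_blocks` cache blocks of one layer."""
    return num_blocks * 2 * block_size * kv_heads * head_dim * dtype_bytes


MiB = 2**20
total_blocks = 128 * 1024 // 1024                              # 128K tokens / 1024-token blocks
full = kv_bytes_per_layer(total_blocks, 1024, kv_heads=8, head_dim=128)
quest = kv_bytes_per_layer(8, 1024, kv_heads=8, head_dim=128)  # topk=8

print(full // MiB, quest // MiB, full // quest)  # 512 32 16
```

Under these assumptions each decode step moves 32 MiB per layer instead of 512 MiB. The ratio is simply `total_blocks / topk`.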