Sparse Attention Computation Flow Guide
This document explains the computation flow for block sparse attention methods used in long-context LLM inference. The implementations are from the x-attention project.
Overview
All sparse attention methods follow a common pattern:
1. Estimate important blocks (low-cost estimation)
2. Create block mask (True = compute, False = skip)
3. Execute block sparse attention kernel (only compute selected blocks)
The key difference between methods is how they estimate which blocks are important.
Common Concepts
Block Tiling
All methods divide Q, K, V into blocks:
- Block Size: Typically 64 or 128 tokens per block
- Num Q Blocks: `ceil(seq_len / block_size)`
- Num K Blocks: `ceil(seq_len / block_size)`
- Block Mask: Boolean tensor `[batch, heads, q_blocks, k_blocks]`
Causal Masking
For autoregressive models, block (i, j) is only valid if j <= i (causal constraint). Diagonal blocks (j == i) are computed with token-level causal masking inside the block.
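A short NumPy sketch of the tiling arithmetic and the causal constraint, assuming self-attention prefill (Q and K have the same length). `causal_block_mask` is a hypothetical helper, not part of x-attention; an estimator's selection can be combined with it by element-wise AND.

```python
import math

import numpy as np


def causal_block_mask(seq_len, block_size, batch=1, heads=1):
    """Block mask of shape [batch, heads, q_blocks, k_blocks]; True = compute."""
    # Q and K are tiled the same way during self-attention prefill
    num_blocks = math.ceil(seq_len / block_size)
    # Block (i, j) is valid only if j <= i: keep the lower triangle, diagonal included
    causal = np.tril(np.ones((num_blocks, num_blocks), dtype=bool))
    # Copy the same pattern to every batch element and head
    return np.broadcast_to(causal, (batch, heads, num_blocks, num_blocks)).copy()


mask = causal_block_mask(seq_len=300, block_size=128, heads=2)
print(mask.shape)  # (1, 2, 3, 3): ceil(300 / 128) = 3, the last block holds 44 tokens
```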
Block Sparse Attention Kernel
All methods ultimately call `block_sparse_attn_func` from the MIT-HAN-LAB `block_sparse_attn` package. The call below is schematic; check the library for its exact argument list:

```python
from block_sparse_attn import block_sparse_attn_func

output = block_sparse_attn_func(
    query_states,  # [batch, seq, heads, head_dim]
    key_states,    # [batch, seq, heads, head_dim]
    value_states,  # [batch, seq, heads, head_dim]
    block_mask,    # [batch, heads, q_blocks, k_blocks] bool
    block_size=64, # tokens per block
    causal=True,
)
```
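To make the kernel's contract concrete, here is a dense NumPy reference of what a block-sparse call computes. It assumes a single sequence (batch dimension dropped), equal Q/K lengths, and that skipped blocks contribute nothing to the softmax. `block_sparse_attention_reference` is a hypothetical name and this is not the MIT-HAN-LAB API; it is slow and only meant for checking semantics.

```python
import numpy as np


def block_sparse_attention_reference(q, k, v, block_mask, block_size=64, causal=True):
    """Dense reference for one sequence.

    q, k, v: [seq, heads, head_dim]; block_mask: [heads, q_blocks, k_blocks] bool.
    """
    seq, heads, head_dim = q.shape
    # Expand block-level decisions to token level and crop the padded tail: [heads, seq, seq]
    token_mask = np.repeat(np.repeat(block_mask, block_size, axis=1), block_size, axis=2)
    token_mask = token_mask[:, :seq, :seq]
    if causal:
        token_mask = token_mask & np.tril(np.ones((seq, seq), dtype=bool))
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(head_dim)
    scores = np.where(token_mask, scores, -np.inf)  # skipped blocks get zero weight
    # Guard rows whose every key is masked so they yield zeros instead of NaN
    row_max = scores.max(axis=-1, keepdims=True)
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)
    weights = np.exp(scores - row_max)
    denom = weights.sum(axis=-1, keepdims=True)
    weights = np.divide(weights, denom, out=np.zeros_like(weights), where=denom > 0)
    return np.einsum("hqk,khd->qhd", weights, v)
```

With an all-True mask this reduces to dense causal attention, which makes it a convenient correctness check for any estimator's mask.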
Method 1: XAttention (xattn_estimate)
Source: xattn/src/Xattention.py
Core Idea
Use strided Q/K reshaping to create coarse-grained representations, compute block-level attention scores, and select blocks above a threshold.
Algorithm
```python
def xattn_estimate(query, key, block_size=64, stride=16):
    """
    Estimate block importance using strided attention.

    1. Reshape Q: [batch, seq, heads, dim] -> [batch, num_blocks, stride, heads, dim]
       Then take mean over stride dimension to get block-level Q
    2. Reshape K: Same process to get block-level K
    3. Compute block attention: softmax(block_Q @ block_K.T / sqrt(d))
       Result shape: [batch, heads, q_blocks, k_blocks]
    4. Apply causal mask (upper triangle = 0)
    5. Threshold: blocks with score > threshold are selected
    """
```
Key Parameters
| Parameter | Default | Description |
|---|---|---|
| `block_size` | 64 | Tokens per block |
| `stride` | 16 | Stride for coarse Q/K computation |
| `threshold` | 0.9 | Selection threshold (cumulative or direct) |
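A minimal single-head NumPy sketch of the estimation, under stated assumptions: each run of `stride` consecutive tokens is mean-pooled (steps 1-2), the stride-level causal softmax scores are summed into `block_size // stride` groups to obtain block scores, and `threshold` is read in its cumulative sense (each query-block row keeps its highest-scoring key blocks until their share of the row's mass reaches `threshold`). The function name is hypothetical and the real `xattn_estimate` may differ, for example in how it samples strided tokens.

```python
import numpy as np


def xattn_estimate_sketch(q, k, block_size=64, stride=16, threshold=0.9):
    """q, k: [seq, dim] for one head; seq must be a multiple of block_size."""
    seq, dim = q.shape
    assert seq % block_size == 0 and block_size % stride == 0
    # Steps 1-2: mean-pool every `stride` consecutive tokens -> [seq // stride, dim]
    q_s = q.reshape(seq // stride, stride, dim).mean(axis=1)
    k_s = k.reshape(seq // stride, stride, dim).mean(axis=1)
    # Step 3: coarse scaled dot-product scores at stride granularity
    scores = q_s @ k_s.T / np.sqrt(dim)
    # Step 4: causal mask at stride granularity, then row-wise softmax
    n = scores.shape[0]
    scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
    probs = np.exp(scores - scores.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    # Aggregate stride-level probabilities into [q_blocks, k_blocks] and renormalize rows
    g = block_size // stride
    nb = seq // block_size
    block_scores = probs.reshape(nb, g, nb, g).sum(axis=(1, 3))
    block_scores /= block_scores.sum(axis=1, keepdims=True)
    # Step 5 (cumulative reading): keep top blocks while the mass before them < threshold
    order = np.argsort(-block_scores, axis=1)
    sorted_scores = np.take_along_axis(block_scores, order, axis=1)
    mass_before = np.cumsum(sorted_scores, axis=1) - sorted_scores
    mask = np.zeros((nb, nb), dtype=bool)
    np.put_along_axis(mask, order, mass_before < threshold, axis=1)
    # Enforce block causality and always keep the diagonal block
    causal = np.tril(np.ones((nb, nb), dtype=bool))
    return (mask & causal) | np.eye(nb, dtype=bool)
```

Raising `threshold` can only add blocks in this reading; at `threshold=1.0` every causal block is kept.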
Computation Flow
```
query [B, S, H, D]
    |
    v
Reshape to [B, num_blocks, stride, H, D]
    |
    v
Mean over stride -> block_q [B, num_blocks, H, D]
    |
    v
Compute block attention scores [B, H, q_blocks, k_blocks]
    |
    v
Apply threshold -> block_mask [B, H, q_blocks, k_blocks]
    |
    v
block_sparse_attn_func(q, k, v, block_mask)
    |
    v
output [B, S, H, D]
```
Usage
```python
from xattn.src.Xattention import Xattention_prefill

output = Xattention_prefill(
    query_states, key_states, value_states,
    threshold=0.9,
    stride=16,
)
```
Method 2: FlexPrefill
Source: xattn/src/Flexprefill.py
Core Idea
Use the attention pattern of the last queries (last-q) to detect vertical (column-wise important) and slash (diagonal) patterns, then adaptively adjust each head's budget based on the Jensen-Shannon (JS) divergence between the estimated distribution and a uniform one.
Algorithm
```python
def Flexprefill_prefill(query, key, value, gamma=0.9, tau=0.1):
    """
    1. Compute attention using only last 64 queries (last_q_attn)
       This reveals which K positions are globally important
    2. Detect vertical pattern: columns with high attention across all last-q
       These are "sink tokens" that all queries attend to
    3. Detect slash pattern: diagonal bands that capture local attention
    4. Compute JS divergence between estimated pattern and uniform
       - High divergence = sparse pattern detected, use fewer blocks
       - Low divergence = dense pattern needed, use more blocks
    5. Adjust budget per head based on divergence
    6. Select top-scoring blocks up to budget
    """
```
Key Parameters
| Parameter | Default | Description |
|---|---|---|
| `gamma` | 0.9 | Base coverage ratio (fraction of blocks to keep) |
| `tau` | 0.1 | JS divergence threshold for adaptive budget |
| `min_budget` | 0.5 | Minimum coverage even for sparse patterns |
| `block_size` | 64 | Tokens per block |
Patterns Detected
-
Vertical Pattern: Columns where many queries attend heavily
- Detected by summing attention across query dimension
- Captures "attention sinks" (e.g., BOS token, punctuation)
-
Slash Pattern: Diagonal bands
- Captures local context attention
- Width determined by
slash_sizeparameter
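Both patterns can be scored from the last-q attention with a few array operations. This sketch covers one head at token granularity, uses the last 64 queries mentioned in this section, and measures a slash by its diagonal offset (query position minus key position). The function names are hypothetical; the x-attention implementation uses Triton kernels instead.

```python
import numpy as np


def last_q_attention(q, k, last_q=64):
    """Causal softmax rows for the final `last_q` queries of one head: [last_q, seq]."""
    seq, dim = q.shape
    scores = q[-last_q:] @ k.T / np.sqrt(dim)
    # The query at absolute position p may only attend to keys 0..p
    q_pos = np.arange(seq - last_q, seq)[:, None]
    scores = np.where(np.arange(seq)[None, :] <= q_pos, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    return weights / weights.sum(axis=1, keepdims=True)


def vertical_and_slash_scores(attn):
    """Vertical score per key column and slash score per diagonal offset."""
    last_q, seq = attn.shape
    # Vertical: sum over the query dimension; attention sinks show up as large columns
    vertical = attn.sum(axis=0)
    # Slash: offset = query position - key position (0 is the main diagonal)
    q_pos = np.arange(seq - last_q, seq)[:, None]
    offsets = q_pos - np.arange(seq)[None, :]
    valid = offsets >= 0
    slash = np.bincount(offsets[valid], weights=attn[valid], minlength=seq)
    return vertical, slash
```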
Computation Flow
```
query [B, S, H, D]
    |
    v
Take last 64 queries -> last_q [B, 64, H, D]
    |
    v
Compute last_q attention to all K -> attn [B, H, 64, S]
    |
    v
Analyze pattern:
    - Vertical: sum over query dim, find high columns
    - Slash: diagonal bands
    |
    v
Compute JS divergence per head
    |
    v
Adaptive budget = gamma * (1 - divergence/tau)
    |
    v
Select blocks up to budget -> block_mask
    |
    v
block_sparse_attn_func(q, k, v, block_mask)
    |
    v
output [B, S, H, D]
```
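A NumPy sketch of the adaptive-budget step as this section describes it: each head's distribution over key blocks is compared to a uniform distribution with the JS divergence, plugged into `gamma * (1 - divergence / tau)`, and floored at `min_budget`. Taking a per-head block distribution as input, the flooring rule, and the function names are assumptions; this follows the document's description rather than the x-attention source.

```python
import numpy as np


def js_divergence(p, q, eps=1e-12):
    """Jensen-Shannon divergence (natural log, range [0, ln 2]) along the last axis."""
    p = p / p.sum(axis=-1, keepdims=True)
    q = q / q.sum(axis=-1, keepdims=True)
    m = 0.5 * (p + q)

    def kl(a, b):
        # eps keeps the log finite where a is zero; those terms contribute 0 anyway
        return np.sum(a * np.log((a + eps) / (b + eps)), axis=-1)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def adaptive_budget(block_attn, gamma=0.9, tau=0.1, min_budget=0.5):
    """block_attn: [heads, k_blocks] attention mass per key block -> (budget, divergence)."""
    divergence = js_divergence(block_attn, np.ones_like(block_attn))
    # Budget formula from the computation flow, floored at min_budget
    budget = np.maximum(min_budget, gamma * (1.0 - divergence / tau))
    return budget, divergence


heads = np.array([
    [0.25, 0.25, 0.25, 0.25],  # flat: matches uniform, so the full gamma budget applies
    [0.97, 0.01, 0.01, 0.01],  # peaked: sparse, so the budget drops to min_budget
])
budget, divergence = adaptive_budget(heads)
```

With the default `tau=0.1`, a divergence above roughly 0.044 already pushes the formula below 0.5, so in this sketch most clearly sparse heads land on `min_budget`.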
Triton Kernels
FlexPrefill uses custom Triton kernels for efficiency:
- `flex_prefill_attention_kernel`: Block-wise attention with pattern masking
- `flex_vertical_slash_kernel`: Combined vertical + slash pattern attention
Method 3: MInference
Source: xattn/src/Minference.py
Core Idea
Simple and direct: call the `vertical_slash_sparse_attention` kernel with precomputed vertical and slash indices.
Algorithm
```python
def Minference_prefill(query, key, value, vertical_topk=100):
    """
    1. Compute attention from last 64 queries to all K
    2. For each head, identify:
       - vertical_indices: top-k columns with highest total attention
       - slash_indices: diagonal band positions
    3. Call vertical_slash_sparse_attention kernel
       - Computes attention only at selected positions
       - Returns output with zeros elsewhere
    """
```
Key Parameters
| Parameter | Default | Description |
|---|---|---|
| `vertical_topk` | 100 | Number of important columns to select |
| `slash_n` | 64 | Size of diagonal band |
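A NumPy sketch of the index selection for one head, given the last-q attention rows. `vertical_topk` and `slash_n` mirror the parameters above; treating `slash_indices` as the `slash_n` diagonal offsets with the highest summed attention (offset 0 = main diagonal) is an assumption, as are the function names. The second helper expands the indices into a dense token mask for illustration only.

```python
import numpy as np


def select_vertical_slash(attn, seq, vertical_topk=100, slash_n=64):
    """attn: [last_q, seq] softmax rows of the last queries for one head."""
    last_q = attn.shape[0]
    # Vertical: columns with the highest total attention over the last queries
    col_importance = attn.sum(axis=0)
    vertical_topk = min(vertical_topk, seq)
    vertical_indices = np.sort(np.argsort(-col_importance)[:vertical_topk])
    # Slash: diagonal offsets (query position - key position) with the highest total attention
    q_pos = np.arange(seq - last_q, seq)[:, None]
    offsets = q_pos - np.arange(seq)[None, :]
    valid = offsets >= 0
    diag_importance = np.bincount(offsets[valid], weights=attn[valid], minlength=seq)
    slash_n = min(slash_n, seq)
    slash_indices = np.sort(np.argsort(-diag_importance)[:slash_n])
    return vertical_indices, slash_indices


def vertical_slash_token_mask(seq, vertical_indices, slash_indices):
    """Dense [seq, seq] mask: selected columns plus selected diagonals, then causality."""
    mask = np.zeros((seq, seq), dtype=bool)
    mask[:, vertical_indices] = True
    rows = np.arange(seq)
    for off in slash_indices:
        r = rows[rows >= off]
        mask[r, r - off] = True
    return mask & np.tril(np.ones((seq, seq), dtype=bool))
```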
Computation Flow
```
query [B, S, H, D]
    |
    v
Take last 64 queries
    |
    v
Compute attention scores to all K
    |
    v
Sum across queries -> column importance [H, S]
    |
    v
Select top-k columns per head -> vertical_indices [H, topk]
    |
    v
Generate slash indices (diagonal positions) -> slash_indices [H, slash_n]
    |
    v
vertical_slash_sparse_attention(q, k, v, vertical_indices, slash_indices)
    |
    v
output [B, S, H, D]
```
CUDA Kernel
MInference uses a specialized CUDA kernel:
```python
from minference.ops.block_sparse_attn import vertical_slash_sparse_attention

output = vertical_slash_sparse_attention(
    query, key, value,
    vertical_indices,  # [heads, topk]
    slash_indices,     # [heads, slash_n]
)
```
Method 4: AvgPool
Source: xattn/src/AvgPool.py
Core Idea
Pool Q and K within each block using average pooling, compute block-level softmax attention, then select blocks using top-k or top-p (nucleus sampling).
Algorithm
```python
def AvgPool_prefill(query, key, value, block_size=128, top_k=64, top_p=None):
    """
    1. Divide Q into blocks, apply average pooling -> block_q [B, q_blocks, H, D]
    2. Divide K into blocks, apply average pooling -> block_k [B, k_blocks, H, D]
    3. Compute block attention: softmax(block_q @ block_k.T / sqrt(d))
       Result: [B, H, q_blocks, k_blocks]
    4. Select blocks:
       - top_k: Select k highest scoring blocks per row
       - top_p: Sort descending, accumulate until sum > p, select all accumulated
    5. Create block_mask from selection
    6. Execute block_sparse_attn_func
    """
```
Key Parameters
| Parameter | Default | Description |
|---|---|---|
| `block_size` | 128 | Tokens per block |
| `chunk_size` | 16384 | Processing chunk size |
| `top_k` | 64 | Fixed number of blocks per row |
| `top_p` | None | Cumulative probability threshold (0.0-1.0) |
| `pool_method` | "avg" | Pooling method (avg, max) |
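AvgPool's estimation step, sketched in NumPy for one head. It assumes the last partial block is pooled over its real tokens only (NaN padding with `nanmean`/`nanmax`), supports the `avg` and `max` options listed in the table, and omits `chunk_size` chunking; the function name is hypothetical.

```python
import numpy as np


def avgpool_block_scores(q, k, block_size=128, pool="avg"):
    """q, k: [seq, dim] for one head -> causal block-level softmax scores [q_blocks, k_blocks]."""
    seq, dim = q.shape
    nb = -(-seq // block_size)  # ceil division
    pad = nb * block_size - seq

    def pool_blocks(x):
        # Pad the last partial block with NaN so the nan-aware reductions ignore it
        x = np.concatenate([x, np.full((pad, dim), np.nan)], axis=0)
        x = x.reshape(nb, block_size, dim)
        return np.nanmean(x, axis=1) if pool == "avg" else np.nanmax(x, axis=1)

    block_q, block_k = pool_blocks(q), pool_blocks(k)
    scores = block_q @ block_k.T / np.sqrt(dim)
    # Causal mask at block level, then row-wise softmax
    scores = np.where(np.tril(np.ones((nb, nb), dtype=bool)), scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    return weights / weights.sum(axis=1, keepdims=True)
```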
Top-K vs Top-P Selection
Top-K: Select exactly k blocks with highest scores per row.

```python
scores = [0.4, 0.3, 0.2, 0.1]
k = 2
selected = [0.4, 0.3]  # indices 0, 1
```

Top-P (Nucleus): Sort descending, accumulate until exceeding threshold.

```python
scores = [0.4, 0.3, 0.2, 0.1]
p = 0.8
cumsum = [0.4, 0.7, 0.9, 1.0]
selected = [0.4, 0.3, 0.2]  # first 3 (cumsum exceeds 0.8 at index 2)
```
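Both selection rules can be written as row-wise NumPy operations that reproduce the two worked examples above. The function names are hypothetical. Top-p keeps a block when the mass accumulated before it is still below `p`, which includes the block whose score pushes the sum past `p`.

```python
import numpy as np


def select_top_k(scores, k):
    """Boolean mask keeping the k highest scores in each row."""
    idx = np.argsort(-scores, axis=-1)[..., :k]
    mask = np.zeros(scores.shape, dtype=bool)
    np.put_along_axis(mask, idx, True, axis=-1)
    return mask


def select_top_p(scores, p):
    """Keep the highest scores in each row until their cumulative sum exceeds p."""
    order = np.argsort(-scores, axis=-1)
    sorted_scores = np.take_along_axis(scores, order, axis=-1)
    # A block is kept if the mass accumulated before it is still below p
    mass_before = np.cumsum(sorted_scores, axis=-1) - sorted_scores
    mask = np.zeros(scores.shape, dtype=bool)
    np.put_along_axis(mask, order, mass_before < p, axis=-1)
    return mask


scores = np.array([[0.4, 0.3, 0.2, 0.1]])
print(select_top_k(scores, 2))    # [[ True  True False False]]
print(select_top_p(scores, 0.8))  # [[ True  True  True False]]
```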
Computation Flow
```
query [B, S, H, D]
    |
    v
Reshape to blocks [B, num_blocks, block_size, H, D]
    |
    v
Average pool -> block_q [B, num_blocks, H, D]
    |
    v
Same for K -> block_k [B, num_blocks, H, D]
    |
    v
Compute block scores [B, H, q_blocks, k_blocks]
    |
    v
Apply causal mask
    |
    v
top_k or top_p selection -> block_mask
    |
    v
block_sparse_attn_func(q, k, v, block_mask)
    |
    v
output [B, S, H, D]
```
Comparison
| Method | Estimation Cost | Adaptivity | Typical Sparsity |
|---|---|---|---|
| XAttention | Medium (strided attn) | Threshold-based | 60-80% |
| FlexPrefill | Medium (last-q attn) | JS divergence | 50-70% |
| MInference | Low (last-q attn) | Fixed vertical+slash | 70-90% |
| AvgPool | Medium (pooled attn) | top-k/top-p | 50-80% |
When to Use Each
- XAttention: General purpose, good balance of accuracy and speed
- FlexPrefill: When pattern varies significantly across heads
- MInference: When vertical (sink) and local (slash) patterns dominate
- AvgPool: When you want simple, interpretable block selection
Integration with HuggingFace Models
All methods integrate via `FastPrefillConfig` in `xattn/src/load_llama.py`:

```python
from xattn.src.load_llama import FastPrefillConfig, load_model_and_apply_fastprefill

config = FastPrefillConfig(
    metric="xattn",  # "xattn", "flexprefill", "minference", "avgpool"
    threshold=0.9,   # for xattn
    stride=16,       # for xattn
    top_k=64,        # for avgpool
    top_p=None,      # for avgpool nucleus sampling
)

model, tokenizer = load_model_and_apply_fastprefill(
    model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
    fastprefillconfig=config,
)
```
The forward method of attention layers is monkey-patched to use the selected sparse attention method during prefill.
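The monkey-patching pattern itself is plain Python: keep each layer's original `forward`, then bind a wrapper that takes the sparse path during prefill and defers to the original otherwise. Everything below (`ToyAttention`, the `is_prefill` flag, the helper names) is hypothetical and only illustrates the pattern; it is not the code in `load_llama.py`.

```python
import types


class ToyAttention:
    """Hypothetical stand-in for an attention layer; real HF layers take more arguments."""

    def forward(self, hidden_states, is_prefill):
        return f"dense attention, {len(hidden_states)} tokens"


def make_sparse_forward(original_forward, metric):
    def forward(self, hidden_states, is_prefill):
        if is_prefill:
            # A real patch would estimate a block mask and call a block-sparse kernel here
            return f"{metric} sparse attention, {len(hidden_states)} tokens"
        # Decoding steps fall through to the original (bound) forward
        return original_forward(hidden_states, is_prefill)
    return forward


def apply_sparse_prefill(layers, metric):
    for layer in layers:
        # Bind the wrapper to this instance so only this layer's forward is replaced
        layer.forward = types.MethodType(make_sparse_forward(layer.forward, metric), layer)


layers = [ToyAttention(), ToyAttention()]
apply_sparse_prefill(layers, metric="xattn")
print(layers[0].forward([1, 2, 3], is_prefill=True))  # xattn sparse attention, 3 tokens
print(layers[1].forward([4, 5], is_prefill=False))    # dense attention, 2 tokens
```

Patching instances rather than the class leaves other models that share the class untouched.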
Key Files
| File | Purpose |
|---|---|
| `xattn/src/Xattention.py` | XAttention implementation |
| `xattn/src/Flexprefill.py` | FlexPrefill implementation |
| `xattn/src/Minference.py` | MInference implementation |
| `xattn/src/AvgPool.py` | AvgPool implementation |
| `xattn/src/load_llama.py` | Model loading and method dispatch |
| `xattn/src/Compass.py` | Another sparse method (gradient-based) |
Dependencies
Required libraries:
- `torch`: PyTorch
- `triton`: For FlexPrefill Triton kernels
- `flash_attn`: Flash Attention for baseline
- `block_sparse_attn`: MIT-HAN-LAB block sparse kernel
- `minference`: For MInference vertical_slash kernel
Docker image tzj/xattn:v0.5 has all dependencies pre-installed.