Sparse Attention Computation Flow Guide

This document explains the computation flow for block sparse attention methods used in long-context LLM inference. The implementations are from the x-attention project.

Overview

All sparse attention methods follow a common pattern:

1. Estimate important blocks (low-cost estimation)
2. Create block mask (True = compute, False = skip)
3. Execute block sparse attention kernel (only compute selected blocks)

The key difference between methods is how they estimate which blocks are important.
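
As a rough sketch of this shared skeleton (estimate_block_mask and sparse_kernel are placeholders for a method's estimator and the kernel it dispatches to, not names from the project):

def sparse_prefill_skeleton(q, k, v, estimate_block_mask, sparse_kernel, block_size=64):
    # q, k, v: [batch, seq, heads, head_dim]
    # Steps 1 + 2: low-cost estimation producing a boolean block mask
    #   [batch, heads, q_blocks, k_blocks]; True = compute, False = skip
    block_mask = estimate_block_mask(q, k, block_size)
    # Step 3: the kernel computes attention only for the selected (q_block, k_block) tiles
    return sparse_kernel(q, k, v, block_mask, block_size=block_size, causal=True)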


Common Concepts

Block Tiling

All methods divide Q, K, V into blocks:

  • Block Size: Typically 64 or 128 tokens per block
  • Num Q Blocks: ceil(seq_len / block_size)
  • Num K Blocks: ceil(seq_len / block_size)
  • Block Mask: Boolean tensor [batch, heads, q_blocks, k_blocks]

Causal Masking

For autoregressive models, block (i, j) is only valid if j <= i (causal constraint).
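
For concreteness, a minimal sketch of the tiling arithmetic and the causal block mask described above (illustrative only; shapes follow the conventions listed here):

import math
import torch

def causal_block_mask(seq_len, block_size=64, batch=1, heads=32):
    # Number of query/key blocks after tiling the sequence
    num_blocks = math.ceil(seq_len / block_size)
    # Block (i, j) is valid only when j <= i (lower triangle)
    mask = torch.tril(torch.ones(num_blocks, num_blocks, dtype=torch.bool))
    # Broadcast to the [batch, heads, q_blocks, k_blocks] layout used by the kernels
    return mask.expand(batch, heads, num_blocks, num_blocks)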

Block Sparse Attention Kernel

All methods ultimately call block_sparse_attn_func from the MIT-HAN-LAB block_sparse_attn library:

from block_sparse_attn import block_sparse_attn_func

output = block_sparse_attn_func(
    query_states,      # [batch, seq, heads, head_dim]
    key_states,        # [batch, seq, heads, head_dim]
    value_states,      # [batch, seq, heads, head_dim]
    block_mask,        # [batch, heads, q_blocks, k_blocks] bool
    block_size=64,     # tokens per block
    causal=True,
)

Method 1: XAttention (xattn_estimate)

Source: xattn/src/Xattention.py

Core Idea

Use strided Q/K reshaping to create coarse-grained representations, compute block-level attention scores, and select blocks above a threshold.

Algorithm

def xattn_estimate(query, key, block_size=64, stride=16):
    """
    Estimate block importance using strided attention.

    1. Reshape Q: [batch, seq, heads, dim] -> [batch, num_blocks, stride, heads, dim]
       Then take mean over stride dimension to get block-level Q

    2. Reshape K: Same process to get block-level K

    3. Compute block attention: softmax(block_Q @ block_K.T / sqrt(d))
       Result shape: [batch, heads, q_blocks, k_blocks]

    4. Apply causal mask (upper triangle = 0)

    5. Threshold: blocks with score > threshold are selected
    """

Key Parameters

Parameter   Default  Description
block_size  64       Tokens per block
stride      16       Stride for coarse Q/K computation
threshold   0.9      Selection threshold (cumulative or direct)

Computation Flow

query [B, S, H, D]
    |
    v
Reshape to [B, num_blocks, stride, H, D]
    |
    v
Mean over stride -> block_q [B, num_blocks, H, D]
    |
    v
Compute block attention scores [B, H, q_blocks, k_blocks]
    |
    v
Apply threshold -> block_mask [B, H, q_blocks, k_blocks]
    |
    v
block_sparse_attn_func(q, k, v, block_mask)
    |
    v
output [B, S, H, D]

Usage

from xattn.src.Xattention import Xattention_prefill

output = Xattention_prefill(
    query_states, key_states, value_states,
    threshold=0.9,
    stride=16,
)

Method 2: FlexPrefill

Source: xattn/src/Flexprefill.py

Core Idea

Use the attention pattern of the last queries (last-q) to detect vertical (column-wise important) and slash (diagonal-like) patterns, then adaptively adjust each head's block budget based on the JS divergence between the estimated distribution and a uniform distribution.

Algorithm

def Flexprefill_prefill(query, key, value, gamma=0.9, tau=0.1):
    """
    1. Compute attention using only last 64 queries (last_q_attn)
       This reveals which K positions are globally important

    2. Detect vertical pattern: columns with high attention across all last-q
       These are "sink tokens" that all queries attend to

    3. Detect slash pattern: diagonal bands that capture local attention

    4. Compute JS divergence between estimated pattern and uniform
       - High divergence = sparse pattern detected, use fewer blocks
       - Low divergence = dense pattern needed, use more blocks

    5. Adjust budget per head based on divergence

    6. Select top-scoring blocks up to budget
    """

Key Parameters

Parameter   Default  Description
gamma       0.9      Base coverage ratio (fraction of blocks to keep)
tau         0.1      JS divergence threshold for adaptive budget
min_budget  0.5      Minimum coverage even for sparse patterns
block_size  64       Tokens per block

Patterns Detected

  1. Vertical Pattern: Columns where many queries attend heavily

    • Detected by summing attention across query dimension
    • Captures "attention sinks" (e.g., BOS token, punctuation)
  2. Slash Pattern: Diagonal bands

    • Captures local context attention
    • Width determined by slash_size parameter

Computation Flow

query [B, S, H, D]
    |
    v
Take last 64 queries -> last_q [B, 64, H, D]
    |
    v
Compute last_q attention to all K -> attn [B, H, 64, S]
    |
    v
Analyze pattern:
  - Vertical: sum over query dim, find high columns
  - Slash: diagonal bands
    |
    v
Compute JS divergence per head
    |
    v
Adaptive budget = gamma * (1 - divergence/tau)
    |
    v
Select blocks up to budget -> block_mask
    |
    v
block_sparse_attn_func(q, k, v, block_mask)
    |
    v
output [B, S, H, D]
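
A hedged sketch of the divergence-driven budget step. The helper names are illustrative; the budget formula follows the flow above, and the clamp to min_budget mirrors the parameter table:

import torch

def js_divergence(p, q, eps=1e-8):
    # Jensen-Shannon divergence between distributions over the last dimension
    m = 0.5 * (p + q)
    kl = lambda a, b: (a * ((a + eps).log() - (b + eps).log())).sum(dim=-1)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def adaptive_budget(block_scores, gamma=0.9, tau=0.1, min_budget=0.5):
    # block_scores: [batch, heads, k_blocks] importance estimated from last-q attention
    p = block_scores.softmax(dim=-1)
    uniform = torch.full_like(p, 1.0 / p.shape[-1])
    divergence = js_divergence(p, uniform)                 # [batch, heads]
    # Sparser pattern (higher divergence) -> smaller coverage budget
    budget = gamma * (1 - divergence / tau)
    return budget.clamp(min=min_budget, max=gamma)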

Triton Kernels

FlexPrefill uses custom Triton kernels for efficiency:

  • flex_prefill_attention_kernel: Block-wise attention with pattern masking
  • flex_vertical_slash_kernel: Combined vertical + slash pattern attention

Method 3: MInference

Source: xattn/src/Minference.py

Core Idea

Simple and direct: call the vertical_slash_sparse_attention kernel with pre-computed vertical and slash indices.

Algorithm

def Minference_prefill(query, key, value, vertical_topk=100):
    """
    1. Compute attention from last 64 queries to all K

    2. For each head, identify:
       - vertical_indices: top-k columns with highest total attention
       - slash_indices: diagonal band positions

    3. Call vertical_slash_sparse_attention kernel
       - Computes attention only at selected positions
       - Returns output with zeros elsewhere
    """

Key Parameters

Parameter      Default  Description
vertical_topk  100      Number of important columns to select
slash_n        64       Size of diagonal band

Computation Flow

query [B, S, H, D]
    |
    v
Take last 64 queries
    |
    v
Compute attention scores to all K
    |
    v
Sum across queries -> column importance [H, S]
    |
    v
Select top-k columns per head -> vertical_indices [H, topk]
    |
    v
Generate slash indices (diagonal positions) -> slash_indices [H, slash_n]
    |
    v
vertical_slash_sparse_attention(q, k, v, vertical_indices, slash_indices)
    |
    v
output [B, S, H, D]

CUDA Kernel

MInference uses a specialized CUDA kernel:

from minference.ops.block_sparse_attn import vertical_slash_sparse_attention

output = vertical_slash_sparse_attention(
    query, key, value,
    vertical_indices,  # [heads, topk]
    slash_indices,     # [heads, slash_n]
)

Method 4: AvgPool

Source: xattn/src/AvgPool.py

Core Idea

Pool Q and K within each block using average pooling, compute block-level softmax attention, then select blocks using top-k or top-p (nucleus sampling).

Algorithm

def AvgPool_prefill(query, key, value, block_size=128, top_k=64, top_p=None):
    """
    1. Divide Q into blocks, apply average pooling -> block_q [B, q_blocks, H, D]

    2. Divide K into blocks, apply average pooling -> block_k [B, k_blocks, H, D]

    3. Compute block attention: softmax(block_q @ block_k.T / sqrt(d))
       Result: [B, H, q_blocks, k_blocks]

    4. Select blocks:
       - top_k: Select k highest scoring blocks per row
       - top_p: Sort descending, accumulate until sum > p, select all accumulated

    5. Create block_mask from selection

    6. Execute block_sparse_attn_func
    """

Key Parameters

Parameter    Default  Description
block_size   128      Tokens per block
chunk_size   16384    Processing chunk size
top_k        64       Fixed number of blocks per row
top_p        None     Cumulative probability threshold (0.0-1.0)
pool_method  "avg"    Pooling method (avg, max)

Top-K vs Top-P Selection

Top-K: Select exactly k blocks with highest scores per row.

scores = [0.4, 0.3, 0.2, 0.1], k=2
selected = [0.4, 0.3]  # indices 0, 1

Top-P (Nucleus): Sort descending, accumulate until exceeding threshold.

scores = [0.4, 0.3, 0.2, 0.1], p=0.8
cumsum = [0.4, 0.7, 0.9, 1.0]
selected = [0.4, 0.3, 0.2]  # first 3 (cumsum exceeds 0.8 at index 2)
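
A hedged helper sketch implementing both rules (it reproduces the worked examples above; it is not the project's code):

import torch

def select_blocks(block_scores, top_k=None, top_p=None):
    # block_scores: [batch, heads, q_blocks, k_blocks] block attention probabilities
    if top_k is not None:
        k = min(top_k, block_scores.shape[-1])
        idx = block_scores.topk(k, dim=-1).indices
        return torch.zeros_like(block_scores, dtype=torch.bool).scatter(
            -1, idx, torch.ones_like(idx, dtype=torch.bool))
    # Top-p: keep the highest-scoring blocks until cumulative mass exceeds top_p
    sorted_s, idx = block_scores.sort(dim=-1, descending=True)
    keep = sorted_s.cumsum(dim=-1) - sorted_s < top_p
    keep[..., 0] = True                     # always keep at least the best block
    return torch.zeros_like(block_scores, dtype=torch.bool).scatter(-1, idx, keep)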

Computation Flow

query [B, S, H, D]
    |
    v
Reshape to blocks [B, num_blocks, block_size, H, D]
    |
    v
Average pool -> block_q [B, num_blocks, H, D]
    |
    v
Same for K -> block_k [B, num_blocks, H, D]
    |
    v
Compute block scores [B, H, q_blocks, k_blocks]
    |
    v
Apply causal mask
    |
    v
top_k or top_p selection -> block_mask
    |
    v
block_sparse_attn_func(q, k, v, block_mask)
    |
    v
output [B, S, H, D]

Comparison

Method       Estimation Cost         Adaptivity            Typical Sparsity
XAttention   Medium (strided attn)   Threshold-based       60-80%
FlexPrefill  Medium (last-q attn)    JS divergence         50-70%
MInference   Low (last-q attn)       Fixed vertical+slash  70-90%
AvgPool      Medium (pooled attn)    top-k/top-p           50-80%

When to Use Each

  • XAttention: General purpose, good balance of accuracy and speed
  • FlexPrefill: When pattern varies significantly across heads
  • MInference: When vertical (sink) and local (slash) patterns dominate
  • AvgPool: When you want simple, interpretable block selection

Integration with HuggingFace Models

All methods integrate via FastPrefillConfig in xattn/src/load_llama.py:

from xattn.src.load_llama import FastPrefillConfig, load_model_and_apply_fastprefill

config = FastPrefillConfig(
    metric="xattn",      # "xattn", "flexprefill", "minference", "avgpool"
    threshold=0.9,       # for xattn
    stride=16,           # for xattn
    top_k=64,            # for avgpool
    top_p=None,          # for avgpool nucleus sampling
)

model, tokenizer = load_model_and_apply_fastprefill(
    model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
    fastprefillconfig=config,
)

The forward method of attention layers is monkey-patched to use the selected sparse attention method during prefill.
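
After patching, the returned model behaves like a standard HuggingFace causal LM; a usage sketch (the prompt is a placeholder, and generate/decode are the usual transformers calls):

prompt = "Summarize the following document: ..."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Prefill over the long prompt runs through the patched sparse attention path;
# decoding then proceeds as usual
output_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))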


Key Files

File                      Purpose
xattn/src/Xattention.py   XAttention implementation
xattn/src/Flexprefill.py  FlexPrefill implementation
xattn/src/Minference.py   MInference implementation
xattn/src/AvgPool.py      AvgPool implementation
xattn/src/load_llama.py   Model loading and method dispatch
xattn/src/Compass.py      Another sparse method (gradient-based)

Dependencies

Required libraries:

  • torch: PyTorch
  • triton: For FlexPrefill Triton kernels
  • flash_attn: Flash Attention for baseline
  • block_sparse_attn: MIT-HAN-LAB block sparse kernel
  • minference: For MInference vertical_slash kernel

Docker image tzj/xattn:v0.5 has all dependencies pre-installed.