From bf4c63c7ec93372933e0fcc7fd71f8ea31c55f81 Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Mon, 29 Dec 2025 19:56:54 +0800
Subject: [PATCH] [docs] Added Sparse Attn.
---
 CLAUDE.md                      |   4 +
 docs/sparse_attention_guide.md | 442 +++++++++++++++++++++++++++++++++
 2 files changed, 446 insertions(+)
 create mode 100644 docs/sparse_attention_guide.md

diff --git a/CLAUDE.md b/CLAUDE.md
index fffd286..bdbb4d1 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -6,6 +6,10 @@ This file provides guidance to Claude Code when working with this repository.
 
 Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports Qwen3 models with CPU offload for long-context inference.
 
+## Sparse Attention
+
+For sparse attention related content (block sparse attention, MInference, FlexPrefill, XAttention, AvgPool, etc.), refer to [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md).
+
 ## Architecture
 
 ### Core Components

diff --git a/docs/sparse_attention_guide.md b/docs/sparse_attention_guide.md
new file mode 100644
index 0000000..5d441a6
--- /dev/null
+++ b/docs/sparse_attention_guide.md
@@ -0,0 +1,442 @@

# Sparse Attention Computation Flow Guide

This document explains the computation flow of the block sparse attention methods used for long-context LLM inference. The implementations come from the x-attention project.

## Overview

All sparse attention methods follow a common pattern:

```
1. Estimate important blocks (low-cost estimation)
2. Create block mask (True = compute, False = skip)
3. Execute block sparse attention kernel (only compute selected blocks)
```

The key difference between methods is **how they estimate which blocks are important**.

---

## Common Concepts

### Block Tiling

All methods divide Q, K, V into blocks:
- **Block Size**: Typically 64 or 128 tokens per block
- **Num Q Blocks**: `ceil(seq_len / block_size)`
- **Num K Blocks**: `ceil(seq_len / block_size)`
- **Block Mask**: Boolean tensor `[batch, heads, q_blocks, k_blocks]`

### Causal Masking

For autoregressive models, block `(i, j)` is only valid if `j <= i` (causal constraint).

### Block Sparse Attention Kernel

All methods ultimately call `block_sparse_attn_func` from the MIT-HAN-LAB `block_sparse_attn` package; conceptually the call looks like this:

```python
from block_sparse_attn import block_sparse_attn_func

output = block_sparse_attn_func(
    query_states,   # [batch, seq, heads, head_dim]
    key_states,     # [batch, seq, heads, head_dim]
    value_states,   # [batch, seq, heads, head_dim]
    block_mask,     # [batch, heads, q_blocks, k_blocks] bool
    block_size=64,  # tokens per block
    causal=True,
)
```

---

## Method 1: XAttention (xattn_estimate)

**Source**: `xattn/src/Xattention.py`

### Core Idea

Use **strided Q/K reshaping** to build coarse-grained representations, compute block-level attention scores, and select the blocks whose scores exceed a threshold.

### Algorithm

```python
def xattn_estimate(query, key, block_size=64, stride=16):
    """
    Estimate block importance using strided attention.

    1. Reshape Q: [batch, seq, heads, dim] -> [batch, num_blocks, stride, heads, dim]
       Then take the mean over the stride dimension to get block-level Q

    2. Reshape K: Same process to get block-level K

    3. Compute block attention: softmax(block_Q @ block_K.T / sqrt(d))
       Result shape: [batch, heads, q_blocks, k_blocks]

    4. Apply causal mask (upper triangle = 0)

    5. Threshold: blocks with score > threshold are selected
    """
```
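
To make the estimation step concrete, here is a minimal PyTorch sketch of steps 1-5, assuming the sequence length is a multiple of `block_size` and that Q and K have the same number of heads. The helper name `estimate_block_mask_xattn` is made up for illustration, it replaces the strided reshaping with a plain per-block mean, and it implements the cumulative variant of the threshold (the parameter table below lists both options); the fused implementation lives in `xattn/src/Xattention.py`.

```python
import math
import torch

def estimate_block_mask_xattn(query, key, block_size=64, threshold=0.9):
    """Hypothetical reference for the estimation step (not the fused kernel).

    query, key: [batch, seq, heads, head_dim]
    returns:    bool block mask [batch, heads, q_blocks, k_blocks]
    """
    B, S, H, D = query.shape
    n_blocks = S // block_size

    # Steps 1-2: coarse block-level Q/K (block mean as a simplified stand-in
    # for the strided reshaping described above).
    block_q = query.view(B, n_blocks, block_size, H, D).mean(dim=2)  # [B, nQ, H, D]
    block_k = key.view(B, n_blocks, block_size, H, D).mean(dim=2)    # [B, nK, H, D]

    # Step 3: block-level attention scores [B, H, q_blocks, k_blocks].
    scores = torch.einsum("bqhd,bkhd->bhqk", block_q, block_k) / math.sqrt(D)

    # Step 4: causal mask over blocks (block (i, j) valid only if j <= i).
    causal = torch.tril(torch.ones(n_blocks, n_blocks, dtype=torch.bool, device=query.device))
    probs = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1)

    # Step 5: keep the highest-scoring blocks of each row until their
    # cumulative probability mass reaches the threshold.
    sorted_probs, order = probs.sort(dim=-1, descending=True)
    keep_sorted = sorted_probs.cumsum(dim=-1) - sorted_probs < threshold  # top block always kept
    block_mask = torch.zeros_like(keep_sorted)
    block_mask.scatter_(-1, order, keep_sorted)
    return block_mask & causal
```

The resulting boolean mask is then passed, together with the unpooled Q/K/V, to `block_sparse_attn_func` as in the kernel call shown earlier.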

### Key Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `block_size` | 64 | Tokens per block |
| `stride` | 16 | Stride for coarse Q/K computation |
| `threshold` | 0.9 | Selection threshold (cumulative or direct) |

### Computation Flow

```
query [B, S, H, D]
  |
  v
Reshape to [B, num_blocks, stride, H, D]
  |
  v
Mean over stride -> block_q [B, num_blocks, H, D]
  |
  v
Compute block attention scores [B, H, q_blocks, k_blocks]
  |
  v
Apply threshold -> block_mask [B, H, q_blocks, k_blocks]
  |
  v
block_sparse_attn_func(q, k, v, block_mask)
  |
  v
output [B, S, H, D]
```

### Usage

```python
from xattn.src.Xattention import Xattention_prefill

output = Xattention_prefill(
    query_states, key_states, value_states,
    threshold=0.9,
    stride=16,
)
```

---

## Method 2: FlexPrefill

**Source**: `xattn/src/Flexprefill.py`

### Core Idea

Use the **last-q attention pattern** (attention computed from only the final queries) to detect vertical (column-wise important) and slash (diagonal-like) patterns. Adaptively adjust the per-head budget based on the **JS divergence** between the estimated attention distribution and a uniform distribution.

### Algorithm

```python
def Flexprefill_prefill(query, key, value, gamma=0.9, tau=0.1):
    """
    1. Compute attention using only the last 64 queries (last_q_attn)
       This reveals which K positions are globally important

    2. Detect vertical pattern: columns with high attention across all last-q
       These are "sink tokens" that all queries attend to

    3. Detect slash pattern: diagonal bands that capture local attention

    4. Compute JS divergence between the estimated pattern and uniform
       - High divergence = sparse pattern detected, use fewer blocks
       - Low divergence = dense pattern needed, use more blocks

    5. Adjust the budget per head based on the divergence

    6. Select top-scoring blocks up to the budget
    """
```

### Key Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `gamma` | 0.9 | Base coverage ratio (fraction of blocks to keep) |
| `tau` | 0.1 | JS divergence threshold for adaptive budget |
| `min_budget` | 0.5 | Minimum coverage even for sparse patterns |
| `block_size` | 64 | Tokens per block |

### Patterns Detected

1. **Vertical Pattern**: Columns where many queries attend heavily
   - Detected by summing attention across the query dimension
   - Captures "attention sinks" (e.g., BOS token, punctuation)

2. **Slash Pattern**: Diagonal bands
   - Captures local context attention
   - Width determined by the `slash_size` parameter
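
The adaptive budget is the distinctive part of FlexPrefill, so here is a rough PyTorch sketch of steps 1, 4, and 5, assuming equal numbers of query and key heads and a sequence length divisible by `block_size`. The function name `flexprefill_budget_sketch` and the block-level pooling are illustrative simplifications; the actual Triton implementation in `xattn/src/Flexprefill.py` also constructs the vertical/slash index sets, which are omitted here.

```python
import torch

def flexprefill_budget_sketch(query, key, block_size=64, last_q=64,
                              gamma=0.9, tau=0.1, min_budget=0.5):
    """Hypothetical sketch of steps 1, 4, and 5 above.

    query, key: [batch, seq, heads, head_dim]
    returns:    per-head block budgets in [min_budget, gamma], shape [batch, heads]
    """
    B, S, H, D = query.shape
    n_blocks = S // block_size

    # Step 1: attention of the last `last_q` queries to every key position.
    # Causality is ignored for the estimate; the tail queries see (almost) the whole sequence.
    q_tail = query[:, -last_q:].permute(0, 2, 1, 3)   # [B, H, last_q, D]
    k_all = key.permute(0, 2, 1, 3)                   # [B, H, S, D]
    attn = torch.softmax(q_tail @ k_all.transpose(-1, -2) / D**0.5, dim=-1)  # [B, H, last_q, S]

    # Pool column importance into a block-level probability distribution.
    col_importance = attn.mean(dim=2)                 # [B, H, S]
    block_scores = col_importance.view(B, H, n_blocks, block_size).sum(dim=-1)
    block_probs = block_scores / block_scores.sum(dim=-1, keepdim=True)

    # Step 4: JS divergence between the estimated block distribution and uniform.
    uniform = torch.full_like(block_probs, 1.0 / n_blocks)
    mid = 0.5 * (block_probs + uniform)

    def _kl(p, q):
        return (p * (p.clamp_min(1e-12) / q).log()).sum(dim=-1)

    js = 0.5 * _kl(block_probs, mid) + 0.5 * _kl(uniform, mid)  # [B, H]

    # Step 5: higher divergence (sparser pattern) => smaller budget,
    # following the budget rule shown in the computation flow below.
    return (gamma * (1 - js / tau)).clamp(min=min_budget, max=gamma)
```

Heads with a sharply peaked block distribution (high divergence) end up near `min_budget`, while heads with a flat distribution keep the full `gamma` coverage, as described in step 4 above.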

### Computation Flow

```
query [B, S, H, D]
  |
  v
Take last 64 queries -> last_q [B, 64, H, D]
  |
  v
Compute last_q attention to all K -> attn [B, H, 64, S]
  |
  v
Analyze pattern:
  - Vertical: sum over query dim, find high columns
  - Slash: diagonal bands
  |
  v
Compute JS divergence per head
  |
  v
Adaptive budget = gamma * (1 - divergence/tau)
  |
  v
Select blocks up to budget -> block_mask
  |
  v
block_sparse_attn_func(q, k, v, block_mask)
  |
  v
output [B, S, H, D]
```

### Triton Kernels

FlexPrefill uses custom Triton kernels for efficiency:
- `flex_prefill_attention_kernel`: Block-wise attention with pattern masking
- `flex_vertical_slash_kernel`: Combined vertical + slash pattern attention

---

## Method 3: MInference

**Source**: `xattn/src/Minference.py`

### Core Idea

Simple and direct: use the **vertical_slash_sparse_attention** kernel with pre-computed vertical and slash indices.

### Algorithm

```python
def Minference_prefill(query, key, value, vertical_topk=100):
    """
    1. Compute attention from last 64 queries to all K

    2. For each head, identify:
       - vertical_indices: top-k columns with highest total attention
       - slash_indices: diagonal band positions

    3. Call vertical_slash_sparse_attention kernel
       - Computes attention only at selected positions
       - Returns output with zeros elsewhere
    """
```

### Key Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `vertical_topk` | 100 | Number of important columns to select |
| `slash_n` | 64 | Size of diagonal band |

### Computation Flow

```
query [B, S, H, D]
  |
  v
Take last 64 queries
  |
  v
Compute attention scores to all K
  |
  v
Sum across queries -> column importance [H, S]
  |
  v
Select top-k columns per head -> vertical_indices [H, topk]
  |
  v
Generate slash indices (diagonal positions) -> slash_indices [H, slash_n]
  |
  v
vertical_slash_sparse_attention(q, k, v, vertical_indices, slash_indices)
  |
  v
output [B, S, H, D]
```

### CUDA Kernel

MInference uses a specialized CUDA kernel:

```python
from minference.ops.block_sparse_attn import vertical_slash_sparse_attention

output = vertical_slash_sparse_attention(
    query, key, value,
    vertical_indices,  # [heads, topk]
    slash_indices,     # [heads, slash_n]
)
```

---

## Method 4: AvgPool

**Source**: `xattn/src/AvgPool.py`

### Core Idea

Pool Q and K within each block using average pooling, compute block-level softmax attention, then select blocks using **top-k** or **top-p (nucleus sampling)**.

### Algorithm

```python
def AvgPool_prefill(query, key, value, block_size=128, top_k=64, top_p=None):
    """
    1. Divide Q into blocks, apply average pooling -> block_q [B, q_blocks, H, D]

    2. Divide K into blocks, apply average pooling -> block_k [B, k_blocks, H, D]

    3. Compute block attention: softmax(block_q @ block_k.T / sqrt(d))
       Result: [B, H, q_blocks, k_blocks]

    4. Select blocks:
       - top_k: Select k highest scoring blocks per row
       - top_p: Sort descending, accumulate until sum > p, select all accumulated

    5. Create block_mask from selection

    6. Execute block_sparse_attn_func
    """
```
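
Step 4 is where AvgPool differs from the other methods, so here is a small PyTorch sketch of just that step, assuming `block_probs` already holds the causally masked, softmax-normalized block scores from step 3. The helper name `select_blocks` is illustrative rather than the actual function in `xattn/src/AvgPool.py`, and the chunked processing controlled by `chunk_size` is omitted.

```python
import torch

def select_blocks(block_probs, top_k=64, top_p=None):
    """Hypothetical helper for step 4: block probabilities
    [batch, heads, q_blocks, k_blocks] -> boolean block mask."""
    if top_p is None:
        # Top-K: keep exactly k blocks per query-block row.
        k = min(top_k, block_probs.shape[-1])
        idx = block_probs.topk(k, dim=-1).indices
        mask = torch.zeros_like(block_probs, dtype=torch.bool)
        mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))
    else:
        # Top-P: keep blocks until the cumulative probability exceeds top_p
        # (the block that crosses the threshold is still included).
        sorted_probs, order = block_probs.sort(dim=-1, descending=True)
        keep_sorted = sorted_probs.cumsum(dim=-1) - sorted_probs <= top_p
        mask = torch.zeros_like(keep_sorted)
        mask.scatter_(-1, order, keep_sorted)
    return mask
```

Applied to the scores `[0.4, 0.3, 0.2, 0.1]` with `p=0.8`, this keeps the first three blocks, matching the worked example under "Top-K vs Top-P Selection" below; the resulting mask is combined with the causal block mask and handed to `block_sparse_attn_func` in step 6.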

### Key Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `block_size` | 128 | Tokens per block |
| `chunk_size` | 16384 | Processing chunk size |
| `top_k` | 64 | Fixed number of blocks per row |
| `top_p` | None | Cumulative probability threshold (0.0-1.0) |
| `pool_method` | "avg" | Pooling method (avg, max) |

### Top-K vs Top-P Selection

**Top-K**: Select exactly k blocks with the highest scores per row.

```
scores = [0.4, 0.3, 0.2, 0.1], k=2
selected = [0.4, 0.3]  # indices 0, 1
```

**Top-P (Nucleus)**: Sort descending, accumulate until exceeding the threshold.

```
scores = [0.4, 0.3, 0.2, 0.1], p=0.8
cumsum = [0.4, 0.7, 0.9, 1.0]
selected = [0.4, 0.3, 0.2]  # first 3 (cumsum exceeds 0.8 at index 2)
```

### Computation Flow

```
query [B, S, H, D]
  |
  v
Reshape to blocks [B, num_blocks, block_size, H, D]
  |
  v
Average pool -> block_q [B, num_blocks, H, D]
  |
  v
Same for K -> block_k [B, num_blocks, H, D]
  |
  v
Compute block scores [B, H, q_blocks, k_blocks]
  |
  v
Apply causal mask
  |
  v
top_k or top_p selection -> block_mask
  |
  v
block_sparse_attn_func(q, k, v, block_mask)
  |
  v
output [B, S, H, D]
```

---

## Comparison

| Method | Estimation Cost | Adaptivity | Typical Sparsity |
|--------|-----------------|------------|------------------|
| XAttention | Medium (strided attn) | Threshold-based | 60-80% |
| FlexPrefill | Medium (last-q attn) | JS divergence | 50-70% |
| MInference | Low (last-q attn) | Fixed vertical+slash | 70-90% |
| AvgPool | Medium (pooled attn) | top-k/top-p | 50-80% |

### When to Use Each

- **XAttention**: General purpose, good balance of accuracy and speed
- **FlexPrefill**: When the pattern varies significantly across heads
- **MInference**: When vertical (sink) and local (slash) patterns dominate
- **AvgPool**: When you want simple, interpretable block selection

---

## Integration with HuggingFace Models

All methods integrate via `FastPrefillConfig` in `xattn/src/load_llama.py`:

```python
from xattn.src.load_llama import FastPrefillConfig, load_model_and_apply_fastprefill

config = FastPrefillConfig(
    metric="xattn",   # "xattn", "flexprefill", "minference", "avgpool"
    threshold=0.9,    # for xattn
    stride=16,        # for xattn
    top_k=64,         # for avgpool
    top_p=None,       # for avgpool nucleus sampling
)

model, tokenizer = load_model_and_apply_fastprefill(
    model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
    fastprefillconfig=config,
)
```

The `forward` method of the attention layers is monkey-patched to use the selected sparse attention method during prefill.

---

## Key Files

| File | Purpose |
|------|---------|
| `xattn/src/Xattention.py` | XAttention implementation |
| `xattn/src/Flexprefill.py` | FlexPrefill implementation |
| `xattn/src/Minference.py` | MInference implementation |
| `xattn/src/AvgPool.py` | AvgPool implementation |
| `xattn/src/load_llama.py` | Model loading and method dispatch |
| `xattn/src/Compass.py` | Another sparse method (gradient-based) |

---

## Dependencies

Required libraries:
- `torch`: PyTorch
- `triton`: For FlexPrefill Triton kernels
- `flash_attn`: Flash Attention for the dense baseline
- `block_sparse_attn`: MIT-HAN-LAB block sparse kernel
- `minference`: For the MInference vertical_slash kernel

The Docker image `tzj/xattn:v0.5` has all dependencies pre-installed.
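
Several of these are optional compiled extensions, so a quick availability check (a minimal sketch that only uses the module names listed above) can be useful before picking a metric:

```python
import importlib.util

# Probe the optional sparse-attention dependencies listed above; a missing
# kernel simply means the corresponding metric cannot be used on this machine.
for name in ("triton", "flash_attn", "block_sparse_attn", "minference"):
    status = "available" if importlib.util.find_spec(name) else "missing"
    print(f"{name}: {status}")
```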