feat: add XAttention sparse policy integration

Integrate the COMPASS XAttention algorithm into nano-vllm's CPU-offload
execution path. In offload mode, attention is computed with FlashAttention,
which handles GQA natively (no K/V head replication required).
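
As a minimal illustration of the native-GQA call path (shapes below are
assumptions for the sketch, not nano-vllm's actual tensors):

    import torch
    from flash_attn import flash_attn_func

    # 32 query heads share 8 KV heads (4:1 GQA); flash_attn broadcasts
    # K/V across head groups internally, so no K/V replication is needed.
    q = torch.randn(1, 4096, 32, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 4096, 8, 128, dtype=torch.float16, device="cuda")
    v = torch.randn_like(k)
    out = flash_attn_func(q, k, v, causal=True)  # (1, 4096, 32, 128)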

New files:
- nanovllm/kvcache/sparse/utils.py: find_blocks_chunked() utility (selection rule sketched after this list)
- nanovllm/kvcache/sparse/kernels.py: Triton kernels for XAttention
- nanovllm/kvcache/sparse/xattn.py: XAttentionPolicy implementation
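
A rough sketch of the selection rule behind find_blocks_chunked(), assuming
it maps pooled per-block attention estimates to a boolean load mask via a
cumulative-threshold rule; the function and argument names below are
hypothetical, not the actual API:

    import torch

    def select_blocks(block_scores, threshold=0.9, keep_sink=False, keep_recent=False):
        # block_scores: (num_q_blocks, num_kv_blocks) pooled attention estimates.
        probs = torch.softmax(block_scores.float(), dim=-1)
        sorted_p, order = probs.sort(dim=-1, descending=True)
        # Keep the smallest prefix of blocks whose cumulative mass reaches threshold.
        keep = sorted_p.cumsum(dim=-1) - sorted_p < threshold
        mask = torch.zeros_like(probs, dtype=torch.bool).scatter(-1, order, keep)
        if keep_sink:
            mask[:, 0] = True  # always load the first (sink) block
        if keep_recent:
            n, m = mask.shape  # always load the diagonal (recent) blocks
            mask |= torch.eye(n, m, dtype=torch.bool, device=mask.device)
        return mask

Under this reading, xattn_threshold = 0.9 means each query block loads only
the KV blocks covering roughly 90% of its estimated attention mass.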

Modified:
- nanovllm/config.py: Add XATTN configuration parameters
- nanovllm/engine/model_runner.py: Support XATTN policy
- nanovllm/kvcache/sparse/__init__.py: Register XAttentionPolicy
- tests/test_ruler.py: Add --sparse-policy parameter

Test results (RULER, 32K context):
- NIAH tasks: 12/12 (100%)
- QA/Recall tasks: 11/15 (73%)
- Overall: 23/27 (85%)

Co-Authored-By: Claude <noreply@anthropic.com>

@@ -10,6 +10,7 @@ class SparsePolicyType(Enum):
     FULL = auto()        # No sparse attention (load all blocks)
     QUEST = auto()       # Query-aware Top-K block selection (decode only)
     MINFERENCE = auto()  # MInference vertical + slash sparse prefill (GPU-only)
+    XATTN = auto()       # XAttention chunked estimation + block-sparse attention


 @dataclass
@@ -53,6 +54,15 @@ class Config:
     minference_num_sink_tokens: int = 30    # Sink tokens to always keep
     minference_num_recent_diags: int = 100  # Recent diagonals to always keep

+    # XAttention configuration (used when sparse_policy == XATTN)
+    xattn_stride: int = 8                 # Stride for reorganizing Q/K
+    xattn_threshold: float = 0.9          # Block selection threshold (0-1)
+    xattn_chunk_size: int | None = 16384  # Chunk size for estimation (auto if None)
+    xattn_use_triton: bool = True         # Use Triton kernels (requires SM 80+)
+    xattn_keep_sink: bool = False         # Always keep first block (sink tokens)
+    xattn_keep_recent: bool = False       # Always keep recent diagonal blocks
+    xattn_norm: float = 1.0               # Normalization factor for attention scores
+
     def __post_init__(self):
         assert os.path.isdir(self.model)
         assert self.kvcache_block_size % 256 == 0
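
A hypothetical usage sketch, assuming Config exposes a sparse_policy field
(implied by the "used when sparse_policy == XATTN" comment above); arguments
other than the fields shown in this diff are guesses:

    from nanovllm.config import Config, SparsePolicyType

    config = Config(
        model="/path/to/model",  # must be an existing directory (see __post_init__)
        sparse_policy=SparsePolicyType.XATTN,
        xattn_stride=8,          # stride for reorganizing Q/K during estimation
        xattn_threshold=0.9,     # load blocks covering 90% of estimated mass
        xattn_keep_sink=True,    # pin the first (sink) block regardless of score
    )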