feat: add XAttention sparse policy integration

Integrate the COMPASS XAttention algorithm into nano-vllm's CPU offload
execution path. The offload mode uses FlashAttention with native GQA
support.

New files:
- nanovllm/kvcache/sparse/utils.py: find_blocks_chunked() utility
- nanovllm/kvcache/sparse/kernels.py: Triton kernels for XAttention
- nanovllm/kvcache/sparse/xattn.py: XAttentionPolicy implementation

Modified:
- nanovllm/config.py: Add XATTN configuration parameters
- nanovllm/engine/model_runner.py: Support XATTN policy
- nanovllm/kvcache/sparse/__init__.py: Register XAttentionPolicy
- tests/test_ruler.py: Add --sparse-policy parameter

Test results (RULER benchmark, 32k):
- NIAH tasks: 12/12 (100%)
- QA/Recall tasks: 11/15 (73%)
- Overall: 23/27 (85%)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-14 10:04:46 +08:00
parent 029894118d
commit ac1ccbceaa
10 changed files with 1001 additions and 813 deletions

View File

@@ -178,19 +178,34 @@ class ModelRunner:
# Create KV cache manager using factory
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
# Create sparse prefill policy for GPU-only path
# This is separate from CPU offload sparse policy (which uses select_blocks)
# Create sparse prefill policy
# This is used for both GPU-only and CPU offload modes when policy supports prefill
self.sparse_prefill_policy = None
if not config.enable_cpu_offload and config.sparse_policy != SparsePolicyType.FULL:
if config.sparse_policy != SparsePolicyType.FULL:
from nanovllm.kvcache.sparse import create_sparse_policy
policy = create_sparse_policy(
config.sparse_policy,
vertical_size=config.minference_vertical_size,
slash_size=config.minference_slash_size,
adaptive_budget=config.minference_adaptive_budget,
num_sink_tokens=config.minference_num_sink_tokens,
num_recent_diags=config.minference_num_recent_diags,
)
# Get policy-specific parameters based on type
if config.sparse_policy == SparsePolicyType.XATTN:
policy_kwargs = {
"stride": config.xattn_stride,
"threshold": config.xattn_threshold,
"chunk_size": config.xattn_chunk_size,
"use_triton": config.xattn_use_triton,
"keep_sink": config.xattn_keep_sink,
"keep_recent": config.xattn_keep_recent,
"norm": config.xattn_norm,
}
else: # MINFERENCE or others
policy_kwargs = {
"vertical_size": config.minference_vertical_size,
"slash_size": config.minference_slash_size,
"adaptive_budget": config.minference_adaptive_budget,
"num_sink_tokens": config.minference_num_sink_tokens,
"num_recent_diags": config.minference_num_recent_diags,
}
policy = create_sparse_policy(config.sparse_policy, **policy_kwargs)
# Only use if policy supports sparse prefill
if policy.supports_prefill:
self.sparse_prefill_policy = policy