feat: add DensityObserver for XAttention sparse attention density tracking

- Add DensityObserver class to track per-layer density statistics
- Integrate DensityObserver into compute_prefill for GPU-only mode
- Fix stride parameter not being passed to xattn_estimate
- Add density statistics output to test_ruler.py for XATTN_BSA
- Add comprehensive density benchmark documentation

Key changes:
- nanovllm/utils/density_observer.py: New Observer for density tracking
- xattn_bsa.py: Add stride param to xattn_estimate, integrate DensityObserver
- test_ruler.py: Enable DensityObserver and print summary for XATTN_BSA
- docs/xattn_density_benchmark.md: Benchmark results for 4K-32K contexts

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-30 16:26:56 +08:00
parent 4484a1482c
commit f6ac4ccdde
5 changed files with 387 additions and 7 deletions

View File

@@ -41,6 +41,7 @@ from pathlib import Path
from typing import List, Dict, Tuple, Optional
from nanovllm import LLM, SamplingParams
from nanovllm.utils.density_observer import DensityObserver
# ============================================================
@@ -381,6 +382,13 @@ def run_ruler_benchmark(
print(f"Fresh LLM mode: {fresh_llm}")
print(f"{'='*60}")
# Enable DensityObserver for XAttention BSA
if sparse_policy and sparse_policy.upper() == "XATTN_BSA":
DensityObserver.enable()
DensityObserver.complete_reset()
if not json_output:
print("[DensityObserver] Enabled for XAttention BSA")
# LLM initialization kwargs
llm_kwargs = {
"max_model_len": max_model_len,
@@ -471,6 +479,14 @@ def run_ruler_benchmark(
print(f"{'-'*54}")
print(f"{'TOTAL':<20} {total_correct}/{total_samples:<7} {overall_accuracy*100:>6.1f}% {avg_score:.3f}")
print(f"\nTime: {total_time:.1f}s")
# Print DensityObserver summary if enabled
if sparse_policy and sparse_policy.upper() == "XATTN_BSA" and DensityObserver.is_enabled():
print(f"\n{'='*60}")
print("Density Statistics (XAttention BSA)")
print(f"{'='*60}")
DensityObserver.print_summary()
print(f"{'='*60}\n")
results = {