feat: add DensityObserver for XAttention sparse attention density tracking

- Add DensityObserver class to track per-layer density statistics
- Integrate DensityObserver into compute_prefill for GPU-only mode
- Fix stride parameter not being passed to xattn_estimate
- Add density statistics output to test_ruler.py for XATTN_BSA
- Add comprehensive density benchmark documentation

Key changes:
- nanovllm/utils/density_observer.py: New Observer for density tracking
- xattn_bsa.py: Pass self.stride to the xattn_estimate call (previously omitted), integrate DensityObserver
- test_ruler.py: Enable DensityObserver and print summary for XATTN_BSA
- docs/xattn_density_benchmark.md: Benchmark results for 4K-32K contexts

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-30 16:26:56 +08:00
parent 4484a1482c
commit f6ac4ccdde
5 changed files with 387 additions and 7 deletions

View File

@@ -17,6 +17,7 @@ import torch.cuda.nvtx as nvtx
from typing import List, Tuple, TYPE_CHECKING
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.utils.density_observer import DensityObserver
if TYPE_CHECKING:
from nanovllm.kvcache.offload_engine import OffloadEngine
@@ -258,6 +259,10 @@ class XAttentionBSAPolicy(SparsePolicy):
from nanovllm.ops.xattn import xattn_estimate
# Set DensityObserver mode on first layer
if layer_id == 0:
DensityObserver.set_mode("gpu_only")
# Get dimensions
total_q, num_heads, head_dim = q.shape
total_kv, num_kv_heads, _ = k.shape
@@ -315,6 +320,7 @@ class XAttentionBSAPolicy(SparsePolicy):
Q, K_exp,
chunk_size=self.chunk_size,
block_size=self.BSA_BLOCK_SIZE,
stride=self.stride,
threshold=self.threshold,
use_triton=self.use_triton,
causal=True,
@@ -360,13 +366,8 @@ class XAttentionBSAPolicy(SparsePolicy):
is_causal=True,
)
# Update statistics (layer 0 only to avoid overcounting)
if layer_id == 0:
selected_blocks = mask_trimmed.sum().item()
total_blocks = q_block_num * k_block_num * num_heads
density = selected_blocks / total_blocks if total_blocks > 0 else 1.0
logger.debug(f"[XAttn GPU-only] layer={layer_id}, q_blocks={q_block_num}, "
f"k_blocks={k_block_num}, density={density:.1%}")
# Record density for all layers via DensityObserver
DensityObserver.record(layer_id, mask_trimmed, causal=True)
return output