✨ feat: add DensityObserver for XAttention sparse attention density tracking
- Add DensityObserver class to track per-layer density statistics
- Integrate DensityObserver into compute_prefill for GPU-only mode
- Fix stride parameter not being passed to xattn_estimate
- Add density statistics output to test_ruler.py for XATTN_BSA
- Add comprehensive density benchmark documentation

Key changes:
- nanovllm/utils/density_observer.py: New Observer for density tracking
- xattn_bsa.py: Add stride param to xattn_estimate, integrate DensityObserver
- test_ruler.py: Enable DensityObserver and print summary for XATTN_BSA
- docs/xattn_density_benchmark.md: Benchmark results for 4K-32K contexts

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -17,6 +17,7 @@ import torch.cuda.nvtx as nvtx
|
||||
from typing import List, Tuple, TYPE_CHECKING
|
||||
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
from nanovllm.utils.density_observer import DensityObserver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanovllm.kvcache.offload_engine import OffloadEngine
|
||||
@@ -258,6 +259,10 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
|
||||
from nanovllm.ops.xattn import xattn_estimate
|
||||
|
||||
# Set DensityObserver mode on first layer
|
||||
if layer_id == 0:
|
||||
DensityObserver.set_mode("gpu_only")
|
||||
|
||||
# Get dimensions
|
||||
total_q, num_heads, head_dim = q.shape
|
||||
total_kv, num_kv_heads, _ = k.shape
|
||||
@@ -315,6 +320,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
Q, K_exp,
|
||||
chunk_size=self.chunk_size,
|
||||
block_size=self.BSA_BLOCK_SIZE,
|
||||
stride=self.stride,
|
||||
threshold=self.threshold,
|
||||
use_triton=self.use_triton,
|
||||
causal=True,
|
||||
@@ -360,13 +366,8 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
is_causal=True,
|
||||
)
|
||||
|
||||
# Update statistics (layer 0 only to avoid overcounting)
|
||||
if layer_id == 0:
|
||||
selected_blocks = mask_trimmed.sum().item()
|
||||
total_blocks = q_block_num * k_block_num * num_heads
|
||||
density = selected_blocks / total_blocks if total_blocks > 0 else 1.0
|
||||
logger.debug(f"[XAttn GPU-only] layer={layer_id}, q_blocks={q_block_num}, "
|
||||
f"k_blocks={k_block_num}, density={density:.1%}")
|
||||
# Record density for all layers via DensityObserver
|
||||
DensityObserver.record(layer_id, mask_trimmed, causal=True)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user