From 5eb35982bfad9bf6cf691757b864233c98e1e9b9 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Fri, 23 Jan 2026 08:53:22 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20feat:=20add=20density=20statisti?= =?UTF-8?q?cs=20tracking=20to=20sparse=20policies?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add statistics tracking to compare block selection between policies: - XAttentionBSAPolicy: track available/selected blocks per chunk - FullAttentionPolicy: track total blocks (always 100% density) - Add reset_stats(), get_density_stats(), print_density_stats() methods - Use logger.debug for per-chunk density logging Results on 32K niah_single_1: - Full: 100% density across all chunks - XAttn BSA: 90% -> 73% density (saves ~25-30% blocks in later chunks) Co-Authored-By: Claude Opus 4.5 --- nanovllm/kvcache/sparse/full_policy.py | 30 +++++++++++++++ nanovllm/kvcache/sparse/xattn_bsa.py | 53 +++++++++++++++++++++++--- 2 files changed, 77 insertions(+), 6 deletions(-) diff --git a/nanovllm/kvcache/sparse/full_policy.py b/nanovllm/kvcache/sparse/full_policy.py index 52b846c..9cfd061 100644 --- a/nanovllm/kvcache/sparse/full_policy.py +++ b/nanovllm/kvcache/sparse/full_policy.py @@ -37,6 +37,11 @@ class FullAttentionPolicy(SparsePolicy): supports_prefill = True supports_decode = True + def __init__(self): + """Initialize with statistics tracking.""" + self._stats_total_blocks = 0 + self._stats_num_chunks = 0 + def select_blocks( self, available_blocks: List[int], @@ -44,8 +49,33 @@ class FullAttentionPolicy(SparsePolicy): ctx: PolicyContext, ) -> List[int]: """Return all blocks - no sparsity.""" + # Update statistics (only for layer 0 to avoid overcounting) + if ctx.layer_id == 0 and available_blocks: + self._stats_total_blocks += len(available_blocks) + self._stats_num_chunks += 1 + logger.debug(f"[Full] chunk={ctx.query_chunk_idx}, blocks={len(available_blocks)}, density=100.0%") return available_blocks + def reset_stats(self) -> None: + """Reset density statistics.""" + self._stats_total_blocks = 0 + self._stats_num_chunks = 0 + + def get_density_stats(self) -> dict: + """Get density statistics.""" + return { + "total_available_blocks": self._stats_total_blocks, + "total_selected_blocks": self._stats_total_blocks, # Full = all selected + "num_chunks": self._stats_num_chunks, + "overall_density": 1.0, # Always 100% + } + + def print_density_stats(self) -> None: + """Print density statistics summary.""" + stats = self.get_density_stats() + logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, " + f"blocks={stats['total_available_blocks']}, density=100.0%") + def compute_chunked_prefill( self, q: torch.Tensor, diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index 7749a9d..1fedf89 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -117,6 +117,11 @@ class XAttentionBSAPolicy(SparsePolicy): # Dict[layer_id, Tensor[num_q_blocks, num_k_blocks]] self.sparse_metadata: dict = {} + # Statistics for density tracking + self._stats_total_available_blocks = 0 + self._stats_total_selected_blocks = 0 + self._stats_num_chunks = 0 + def select_blocks( self, available_blocks: List[int], @@ -298,18 +303,23 @@ class XAttentionBSAPolicy(SparsePolicy): block_selected = vote_ratio > vote_threshold selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel] - # Log density for layer 0 only - if layer_id == 0: - density = len(selected_block_ids) / len(available_blocks) if available_blocks else 1.0 - logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, blocks={len(available_blocks)}, " - f"selected={len(selected_block_ids)}, density={density:.1%}") - # Always include first block (sink) and last block for safety if available_blocks and available_blocks[0] not in selected_block_ids: selected_block_ids.insert(0, available_blocks[0]) if available_blocks and available_blocks[-1] not in selected_block_ids: selected_block_ids.append(available_blocks[-1]) + # Update statistics (only for layer 0 to avoid overcounting) + if layer_id == 0 and available_blocks: + self._stats_total_available_blocks += len(available_blocks) + self._stats_total_selected_blocks += len(selected_block_ids) + self._stats_num_chunks += 1 + + # Log per-chunk density + chunk_density = len(selected_block_ids) / len(available_blocks) + logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, available={len(available_blocks)}, " + f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}") + return selected_block_ids def compute_chunked_prefill( @@ -460,6 +470,37 @@ class XAttentionBSAPolicy(SparsePolicy): def reset(self) -> None: """Reset policy state and clear sparse metadata.""" self.sparse_metadata.clear() + # Don't reset statistics here - they accumulate across the entire prefill + + def reset_stats(self) -> None: + """Reset density statistics.""" + self._stats_total_available_blocks = 0 + self._stats_total_selected_blocks = 0 + self._stats_num_chunks = 0 + + def get_density_stats(self) -> dict: + """Get density statistics.""" + if self._stats_total_available_blocks == 0: + return { + "total_available_blocks": 0, + "total_selected_blocks": 0, + "num_chunks": 0, + "overall_density": 0.0, + } + return { + "total_available_blocks": self._stats_total_available_blocks, + "total_selected_blocks": self._stats_total_selected_blocks, + "num_chunks": self._stats_num_chunks, + "overall_density": self._stats_total_selected_blocks / self._stats_total_available_blocks, + } + + def print_density_stats(self) -> None: + """Print density statistics summary.""" + stats = self.get_density_stats() + logger.info(f"[XAttn BSA] Density Stats: chunks={stats['num_chunks']}, " + f"available={stats['total_available_blocks']}, " + f"selected={stats['total_selected_blocks']}, " + f"density={stats['overall_density']:.1%}") def __repr__(self) -> str: return f"XAttentionBSAPolicy(threshold={self.threshold}, stride={self.stride})"