🔧 feat: add density statistics tracking to sparse policies
Add statistics tracking to compare block selection between policies: - XAttentionBSAPolicy: track available/selected blocks per chunk - FullAttentionPolicy: track total blocks (always 100% density) - Add reset_stats(), get_density_stats(), print_density_stats() methods - Use logger.debug for per-chunk density logging Results on 32K niah_single_1: - Full: 100% density across all chunks - XAttn BSA: 90% -> 73% density (saves ~25-30% blocks in later chunks) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -37,6 +37,11 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
supports_prefill = True
|
||||
supports_decode = True
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize with statistics tracking."""
|
||||
self._stats_total_blocks = 0
|
||||
self._stats_num_chunks = 0
|
||||
|
||||
def select_blocks(
|
||||
self,
|
||||
available_blocks: List[int],
|
||||
@@ -44,8 +49,33 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
ctx: PolicyContext,
|
||||
) -> List[int]:
|
||||
"""Return all blocks - no sparsity."""
|
||||
# Update statistics (only for layer 0 to avoid overcounting)
|
||||
if ctx.layer_id == 0 and available_blocks:
|
||||
self._stats_total_blocks += len(available_blocks)
|
||||
self._stats_num_chunks += 1
|
||||
logger.debug(f"[Full] chunk={ctx.query_chunk_idx}, blocks={len(available_blocks)}, density=100.0%")
|
||||
return available_blocks
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset density statistics."""
|
||||
self._stats_total_blocks = 0
|
||||
self._stats_num_chunks = 0
|
||||
|
||||
def get_density_stats(self) -> dict:
|
||||
"""Get density statistics."""
|
||||
return {
|
||||
"total_available_blocks": self._stats_total_blocks,
|
||||
"total_selected_blocks": self._stats_total_blocks, # Full = all selected
|
||||
"num_chunks": self._stats_num_chunks,
|
||||
"overall_density": 1.0, # Always 100%
|
||||
}
|
||||
|
||||
def print_density_stats(self) -> None:
|
||||
"""Print density statistics summary."""
|
||||
stats = self.get_density_stats()
|
||||
logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
|
||||
f"blocks={stats['total_available_blocks']}, density=100.0%")
|
||||
|
||||
def compute_chunked_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user