🔧 feat: add density statistics tracking to sparse policies
Add statistics tracking to compare block selection between policies: - XAttentionBSAPolicy: track available/selected blocks per chunk - FullAttentionPolicy: track total blocks (always 100% density) - Add reset_stats(), get_density_stats(), print_density_stats() methods - Use logger.debug for per-chunk density logging Results on 32K niah_single_1: - Full: 100% density across all chunks - XAttn BSA: 90% -> 73% density (saves ~25-30% blocks in later chunks) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -37,6 +37,11 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
supports_prefill = True
|
||||
supports_decode = True
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize with statistics tracking."""
|
||||
self._stats_total_blocks = 0
|
||||
self._stats_num_chunks = 0
|
||||
|
||||
def select_blocks(
|
||||
self,
|
||||
available_blocks: List[int],
|
||||
@@ -44,8 +49,33 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
ctx: PolicyContext,
|
||||
) -> List[int]:
|
||||
"""Return all blocks - no sparsity."""
|
||||
# Update statistics (only for layer 0 to avoid overcounting)
|
||||
if ctx.layer_id == 0 and available_blocks:
|
||||
self._stats_total_blocks += len(available_blocks)
|
||||
self._stats_num_chunks += 1
|
||||
logger.debug(f"[Full] chunk={ctx.query_chunk_idx}, blocks={len(available_blocks)}, density=100.0%")
|
||||
return available_blocks
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset density statistics."""
|
||||
self._stats_total_blocks = 0
|
||||
self._stats_num_chunks = 0
|
||||
|
||||
def get_density_stats(self) -> dict:
|
||||
"""Get density statistics."""
|
||||
return {
|
||||
"total_available_blocks": self._stats_total_blocks,
|
||||
"total_selected_blocks": self._stats_total_blocks, # Full = all selected
|
||||
"num_chunks": self._stats_num_chunks,
|
||||
"overall_density": 1.0, # Always 100%
|
||||
}
|
||||
|
||||
def print_density_stats(self) -> None:
|
||||
"""Print density statistics summary."""
|
||||
stats = self.get_density_stats()
|
||||
logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
|
||||
f"blocks={stats['total_available_blocks']}, density=100.0%")
|
||||
|
||||
def compute_chunked_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
|
||||
@@ -117,6 +117,11 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
# Dict[layer_id, Tensor[num_q_blocks, num_k_blocks]]
|
||||
self.sparse_metadata: dict = {}
|
||||
|
||||
# Statistics for density tracking
|
||||
self._stats_total_available_blocks = 0
|
||||
self._stats_total_selected_blocks = 0
|
||||
self._stats_num_chunks = 0
|
||||
|
||||
def select_blocks(
|
||||
self,
|
||||
available_blocks: List[int],
|
||||
@@ -298,18 +303,23 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
block_selected = vote_ratio > vote_threshold
|
||||
selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel]
|
||||
|
||||
# Log density for layer 0 only
|
||||
if layer_id == 0:
|
||||
density = len(selected_block_ids) / len(available_blocks) if available_blocks else 1.0
|
||||
logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, blocks={len(available_blocks)}, "
|
||||
f"selected={len(selected_block_ids)}, density={density:.1%}")
|
||||
|
||||
# Always include first block (sink) and last block for safety
|
||||
if available_blocks and available_blocks[0] not in selected_block_ids:
|
||||
selected_block_ids.insert(0, available_blocks[0])
|
||||
if available_blocks and available_blocks[-1] not in selected_block_ids:
|
||||
selected_block_ids.append(available_blocks[-1])
|
||||
|
||||
# Update statistics (only for layer 0 to avoid overcounting)
|
||||
if layer_id == 0 and available_blocks:
|
||||
self._stats_total_available_blocks += len(available_blocks)
|
||||
self._stats_total_selected_blocks += len(selected_block_ids)
|
||||
self._stats_num_chunks += 1
|
||||
|
||||
# Log per-chunk density
|
||||
chunk_density = len(selected_block_ids) / len(available_blocks)
|
||||
logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, available={len(available_blocks)}, "
|
||||
f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")
|
||||
|
||||
return selected_block_ids
|
||||
|
||||
def compute_chunked_prefill(
|
||||
@@ -460,6 +470,37 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
def reset(self) -> None:
|
||||
"""Reset policy state and clear sparse metadata."""
|
||||
self.sparse_metadata.clear()
|
||||
# Don't reset statistics here - they accumulate across the entire prefill
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset density statistics."""
|
||||
self._stats_total_available_blocks = 0
|
||||
self._stats_total_selected_blocks = 0
|
||||
self._stats_num_chunks = 0
|
||||
|
||||
def get_density_stats(self) -> dict:
|
||||
"""Get density statistics."""
|
||||
if self._stats_total_available_blocks == 0:
|
||||
return {
|
||||
"total_available_blocks": 0,
|
||||
"total_selected_blocks": 0,
|
||||
"num_chunks": 0,
|
||||
"overall_density": 0.0,
|
||||
}
|
||||
return {
|
||||
"total_available_blocks": self._stats_total_available_blocks,
|
||||
"total_selected_blocks": self._stats_total_selected_blocks,
|
||||
"num_chunks": self._stats_num_chunks,
|
||||
"overall_density": self._stats_total_selected_blocks / self._stats_total_available_blocks,
|
||||
}
|
||||
|
||||
def print_density_stats(self) -> None:
|
||||
"""Print density statistics summary."""
|
||||
stats = self.get_density_stats()
|
||||
logger.info(f"[XAttn BSA] Density Stats: chunks={stats['num_chunks']}, "
|
||||
f"available={stats['total_available_blocks']}, "
|
||||
f"selected={stats['total_selected_blocks']}, "
|
||||
f"density={stats['overall_density']:.1%}")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"XAttentionBSAPolicy(threshold={self.threshold}, stride={self.stride})"
|
||||
|
||||
Reference in New Issue
Block a user