🔧 feat: add density statistics tracking to sparse policies
Add statistics tracking to compare block selection between policies:

- XAttentionBSAPolicy: track available/selected blocks per chunk
- FullAttentionPolicy: track total blocks (always 100% density)
- Add reset_stats(), get_density_stats(), print_density_stats() methods
- Use logger.debug for per-chunk density logging

Results on 32K niah_single_1:

- Full: 100% density across all chunks
- XAttn BSA: 90% -> 73% density (saves ~25-30% blocks in later chunks)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -37,6 +37,11 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
supports_prefill = True
|
supports_prefill = True
|
||||||
supports_decode = True
|
supports_decode = True
|
||||||
|
|
||||||
|
def __init__(self):
    """Create the policy with both density counters starting at zero.

    ``_stats_total_blocks`` accumulates block counts and
    ``_stats_num_chunks`` counts how many chunks contributed to it.
    """
    # NOTE(review): the base-class initialiser is not invoked here; confirm
    # that SparsePolicy needs no initialisation of its own.
    self._stats_total_blocks, self._stats_num_chunks = 0, 0
|
||||||
|
|
||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
available_blocks: List[int],
|
available_blocks: List[int],
|
||||||
@@ -44,8 +49,33 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
ctx: PolicyContext,
|
ctx: PolicyContext,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""Return all blocks - no sparsity."""
|
"""Return all blocks - no sparsity."""
|
||||||
|
# Update statistics (only for layer 0 to avoid overcounting)
|
||||||
|
if ctx.layer_id == 0 and available_blocks:
|
||||||
|
self._stats_total_blocks += len(available_blocks)
|
||||||
|
self._stats_num_chunks += 1
|
||||||
|
logger.debug(f"[Full] chunk={ctx.query_chunk_idx}, blocks={len(available_blocks)}, density=100.0%")
|
||||||
return available_blocks
|
return available_blocks
|
||||||
|
|
||||||
|
def reset_stats(self) -> None:
    """Zero the accumulated block and chunk counters."""
    self._stats_total_blocks, self._stats_num_chunks = 0, 0
|
||||||
|
|
||||||
|
def get_density_stats(self) -> dict:
    """Return a summary of the density counters.

    Returns:
        A dict with the keys ``total_available_blocks``,
        ``total_selected_blocks``, ``num_chunks`` and ``overall_density``.
        The selected count is reported as equal to the available count and
        the density is the constant ``1.0``.
    """
    total_blocks = self._stats_total_blocks
    return {
        "total_available_blocks": total_blocks,
        # Full attention keeps every block, so selected == available.
        "total_selected_blocks": total_blocks,
        "num_chunks": self._stats_num_chunks,
        # Reported as a constant rather than computed from the counters.
        "overall_density": 1.0,
    }
|
||||||
|
|
||||||
|
def print_density_stats(self) -> None:
    """Log a one-line summary of the density counters at INFO level."""
    stats = self.get_density_stats()
    # The density is written as a literal 100.0% rather than read from stats.
    summary = (
        f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
        f"blocks={stats['total_available_blocks']}, density=100.0%"
    )
    logger.info(summary)
|
||||||
|
|
||||||
def compute_chunked_prefill(
|
def compute_chunked_prefill(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
|
|||||||
@@ -117,6 +117,11 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
# Dict[layer_id, Tensor[num_q_blocks, num_k_blocks]]
|
# Dict[layer_id, Tensor[num_q_blocks, num_k_blocks]]
|
||||||
self.sparse_metadata: dict = {}
|
self.sparse_metadata: dict = {}
|
||||||
|
|
||||||
|
# Statistics for density tracking
|
||||||
|
self._stats_total_available_blocks = 0
|
||||||
|
self._stats_total_selected_blocks = 0
|
||||||
|
self._stats_num_chunks = 0
|
||||||
|
|
||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
available_blocks: List[int],
|
available_blocks: List[int],
|
||||||
@@ -298,18 +303,23 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
block_selected = vote_ratio > vote_threshold
|
block_selected = vote_ratio > vote_threshold
|
||||||
selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel]
|
selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel]
|
||||||
|
|
||||||
# Log density for layer 0 only
|
|
||||||
if layer_id == 0:
|
|
||||||
density = len(selected_block_ids) / len(available_blocks) if available_blocks else 1.0
|
|
||||||
logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, blocks={len(available_blocks)}, "
|
|
||||||
f"selected={len(selected_block_ids)}, density={density:.1%}")
|
|
||||||
|
|
||||||
# Always include first block (sink) and last block for safety
|
# Always include first block (sink) and last block for safety
|
||||||
if available_blocks and available_blocks[0] not in selected_block_ids:
|
if available_blocks and available_blocks[0] not in selected_block_ids:
|
||||||
selected_block_ids.insert(0, available_blocks[0])
|
selected_block_ids.insert(0, available_blocks[0])
|
||||||
if available_blocks and available_blocks[-1] not in selected_block_ids:
|
if available_blocks and available_blocks[-1] not in selected_block_ids:
|
||||||
selected_block_ids.append(available_blocks[-1])
|
selected_block_ids.append(available_blocks[-1])
|
||||||
|
|
||||||
|
# Update statistics (only for layer 0 to avoid overcounting)
|
||||||
|
if layer_id == 0 and available_blocks:
|
||||||
|
self._stats_total_available_blocks += len(available_blocks)
|
||||||
|
self._stats_total_selected_blocks += len(selected_block_ids)
|
||||||
|
self._stats_num_chunks += 1
|
||||||
|
|
||||||
|
# Log per-chunk density
|
||||||
|
chunk_density = len(selected_block_ids) / len(available_blocks)
|
||||||
|
logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, available={len(available_blocks)}, "
|
||||||
|
f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")
|
||||||
|
|
||||||
return selected_block_ids
|
return selected_block_ids
|
||||||
|
|
||||||
def compute_chunked_prefill(
|
def compute_chunked_prefill(
|
||||||
@@ -460,6 +470,37 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
def reset(self) -> None:
    """Clear the cached sparse metadata.

    The density counters are deliberately left alone: they accumulate
    across the entire prefill and are only cleared by ``reset_stats``.
    """
    metadata = self.sparse_metadata
    metadata.clear()
|
||||||
|
|
||||||
|
def reset_stats(self) -> None:
    """Zero the available-block, selected-block and chunk counters."""
    for counter_name in (
        "_stats_total_available_blocks",
        "_stats_total_selected_blocks",
        "_stats_num_chunks",
    ):
        setattr(self, counter_name, 0)
|
||||||
|
|
||||||
|
def get_density_stats(self) -> dict:
    """Return a summary of the density counters.

    Returns:
        A dict with the keys ``total_available_blocks``,
        ``total_selected_blocks``, ``num_chunks`` and ``overall_density``
        (selected / available). When no available blocks have been
        recorded, every count is 0 and the density is ``0.0``.
    """
    available = self._stats_total_available_blocks
    if available != 0:
        selected = self._stats_total_selected_blocks
        return {
            "total_available_blocks": available,
            "total_selected_blocks": selected,
            "num_chunks": self._stats_num_chunks,
            "overall_density": selected / available,
        }

    # Nothing recorded yet: report zeros instead of dividing by zero.
    return {
        "total_available_blocks": 0,
        "total_selected_blocks": 0,
        "num_chunks": 0,
        "overall_density": 0.0,
    }
|
||||||
|
|
||||||
|
def print_density_stats(self) -> None:
    """Log a one-line summary of the density counters at INFO level."""
    stats = self.get_density_stats()
    summary = (
        f"[XAttn BSA] Density Stats: chunks={stats['num_chunks']}, "
        f"available={stats['total_available_blocks']}, "
        f"selected={stats['total_selected_blocks']}, "
        f"density={stats['overall_density']:.1%}"
    )
    logger.info(summary)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"XAttentionBSAPolicy(threshold={self.threshold}, stride={self.stride})"
|
return f"XAttentionBSAPolicy(threshold={self.threshold}, stride={self.stride})"
|
||||||
|
|||||||
Reference in New Issue
Block a user