WIP: Enhance sparse attention with density tracking and block selection improvements

- Added analysis documentation for xattn density alignment.
- Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration.
- Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection.
- Enhanced QuestPolicy to utilize query tensor for block selection and improved handling of selected blocks.
- Expanded XAttentionBSAPolicy to support chunked prefill and improved attention score computation with historical and current chunk handling.
- Introduced DensityObserver to track compute and communication density for sparse attention layers.
- Updated attention layer to ensure block selection is always called, improving robustness in first chunk scenarios.
- Added tests for attention kernel behavior with enhanced input patterns.
This commit is contained in:
Zijie Tian
2026-01-31 14:48:23 +08:00
parent f6ac4ccdde
commit 2e96d1d97d
9 changed files with 490 additions and 152 deletions

View File

@@ -221,20 +221,19 @@ class Attention(nn.Module):
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
selected_blocks = []
if cpu_block_table:
num_chunks = current_chunk_idx + 1
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=self.layer_id,
query=q, # Pass query for sparse policies that need it
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
# Always call select_blocks even for first chunk (cpu_block_table may be empty)
num_chunks = current_chunk_idx + 1
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=self.layer_id,
query=q, # Pass query for sparse policies that need it
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size if cpu_block_table else 0,
)
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
# [DEBUG] Verify execution path
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
@@ -320,7 +319,7 @@ class Attention(nn.Module):
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
# [DEBUG] Verify execution path