♻️ refactor: move select_blocks from policy to attention layer

Move block selection logic from the compute_chunked_prefill/decode methods
to the attention.py caller. This improves separation of concerns:

- attention.py now calls select_blocks() before compute_chunked_*()
- Policy methods receive pre-selected blocks via selected_blocks parameter
- Enables sparse policies to implement custom block selection without
  modifying the compute path

Changes:
- policy.py: Add selected_blocks parameter to abstract methods
- full_policy.py: Remove internal select_blocks calls, use passed blocks
- xattn_bsa.py: Sync signatures for prefill/decode methods
- attention.py: Add select_blocks calls before policy delegation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-23 05:21:28 +08:00
parent ca32ea6f93
commit a50b4c2ac2
4 changed files with 100 additions and 63 deletions

View File

@@ -136,6 +136,7 @@ class XAttentionBSAPolicy(SparsePolicy):
current_chunk_idx: int,
seq: "Sequence",
num_tokens: int,
selected_blocks: List[int],
) -> torch.Tensor:
"""
Compute attention for chunked prefill.
@@ -169,7 +170,7 @@ class XAttentionBSAPolicy(SparsePolicy):
# This is temporary until proper sparse implementation is ready
return self._compute_dense_fallback(
q, k, v, layer_id, softmax_scale, offload_engine,
kvcache_manager, current_chunk_idx, seq, num_tokens
kvcache_manager, current_chunk_idx, seq, num_tokens, selected_blocks
)
def _compute_dense_fallback(
@@ -184,22 +185,24 @@ class XAttentionBSAPolicy(SparsePolicy):
current_chunk_idx: int,
seq: "Sequence",
num_tokens: int,
selected_blocks: List[int],
) -> torch.Tensor:
"""
Fallback to dense attention when BSA/XAttn not available.
Uses FullAttentionPolicy's proven pipeline.
Uses FullAttentionPolicy's proven pipeline with pre-selected blocks.
"""
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
logger.debug(f"[XAttn] FALLBACK to dense: layer={layer_id}, chunk={current_chunk_idx}")
logger.debug(f"[XAttn] FALLBACK to dense: layer={layer_id}, chunk={current_chunk_idx}, "
f"selected_blocks={len(selected_blocks)}")
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
o_acc = None
lse_acc = None
compute_stream = offload_engine.compute_stream
# Get historical CPU blocks
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Use the pre-selected blocks directly
cpu_block_table = selected_blocks
# Process historical blocks using pipeline
if cpu_block_table:
@@ -282,6 +285,7 @@ class XAttentionBSAPolicy(SparsePolicy):
offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
seq: "Sequence",
selected_blocks: List[int],
) -> torch.Tensor:
"""
XAttention does not support decode phase.