♻️ refactor: move select_blocks from policy to attention layer
Move block selection logic from compute_chunked_prefill/decode methods to the attention.py caller. This improves separation of concerns:

- attention.py now calls select_blocks() before compute_chunked_*()
- Policy methods receive pre-selected blocks via the selected_blocks parameter
- Enables sparse policies to implement custom block selection without modifying the compute path

Changes:
- policy.py: Add selected_blocks parameter to abstract methods
- full_policy.py: Remove internal select_blocks calls; use the passed blocks
- xattn_bsa.py: Sync signatures for prefill/decode methods
- attention.py: Add select_blocks calls before policy delegation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -136,6 +136,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
current_chunk_idx: int,
|
||||
seq: "Sequence",
|
||||
num_tokens: int,
|
||||
selected_blocks: List[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute attention for chunked prefill.
|
||||
@@ -169,7 +170,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
# This is temporary until proper sparse implementation is ready
|
||||
return self._compute_dense_fallback(
|
||||
q, k, v, layer_id, softmax_scale, offload_engine,
|
||||
kvcache_manager, current_chunk_idx, seq, num_tokens
|
||||
kvcache_manager, current_chunk_idx, seq, num_tokens, selected_blocks
|
||||
)
|
||||
|
||||
def _compute_dense_fallback(
|
||||
@@ -184,22 +185,24 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
current_chunk_idx: int,
|
||||
seq: "Sequence",
|
||||
num_tokens: int,
|
||||
selected_blocks: List[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fallback to dense attention when BSA/XAttn not available.
|
||||
Uses FullAttentionPolicy's proven pipeline.
|
||||
Uses FullAttentionPolicy's proven pipeline with pre-selected blocks.
|
||||
"""
|
||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
logger.debug(f"[XAttn] FALLBACK to dense: layer={layer_id}, chunk={current_chunk_idx}")
|
||||
logger.debug(f"[XAttn] FALLBACK to dense: layer={layer_id}, chunk={current_chunk_idx}, "
|
||||
f"selected_blocks={len(selected_blocks)}")
|
||||
|
||||
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
|
||||
o_acc = None
|
||||
lse_acc = None
|
||||
compute_stream = offload_engine.compute_stream
|
||||
|
||||
# Get historical CPU blocks
|
||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
# Use the pre-selected blocks directly
|
||||
cpu_block_table = selected_blocks
|
||||
|
||||
# Process historical blocks using pipeline
|
||||
if cpu_block_table:
|
||||
@@ -282,6 +285,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
offload_engine: "OffloadEngine",
|
||||
kvcache_manager: "KVCacheManager",
|
||||
seq: "Sequence",
|
||||
selected_blocks: List[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
XAttention does not support decode phase.
|
||||
|
||||
Reference in New Issue
Block a user