♻️ refactor: move select_blocks from policy to attention layer

Move block selection logic from compute_chunked_prefill/decode methods
to attention.py caller. This improves separation of concerns:

- attention.py now calls select_blocks() before compute_chunked_*()
- Policy methods receive pre-selected blocks via selected_blocks parameter
- Enables sparse policies to implement custom block selection without
  modifying the compute path

Changes:
- policy.py: Add selected_blocks parameter to abstract methods
- full_policy.py: Remove internal select_blocks calls, use passed blocks
- xattn_bsa.py: Sync signatures for prefill/decode methods
- attention.py: Add select_blocks calls before policy delegation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in: (branch not shown)
Author: Zijie Tian
Date: 2026-01-23 05:21:28 +08:00
Parent commit: ca32ea6f93
Commit: a50b4c2ac2
4 changed files with 100 additions and 63 deletions

View File

@@ -58,16 +58,17 @@ class FullAttentionPolicy(SparsePolicy):
current_chunk_idx: int,
seq: "Sequence",
num_tokens: int,
selected_blocks: List[int],
) -> torch.Tensor:
"""
Compute full attention for chunked prefill.
This method handles the complete chunked prefill flow:
1. Get historical blocks
2. Select blocks via select_blocks
3. Load and compute attention to historical chunks
4. Compute attention to current chunk
5. Merge all results
This method handles the chunked prefill computation:
1. Load and compute attention to historical chunks (using selected_blocks)
2. Compute attention to current chunk
3. Merge all results
Note: Block selection is done by the caller before invoking this method.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
@@ -80,6 +81,7 @@ class FullAttentionPolicy(SparsePolicy):
current_chunk_idx: Current chunk index
seq: Sequence object
num_tokens: Number of tokens in current chunk
selected_blocks: List of CPU block IDs to process (already filtered)
Returns:
Attention output [seq_len, num_heads, head_dim]
@@ -87,30 +89,16 @@ class FullAttentionPolicy(SparsePolicy):
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}, "
f"selected_blocks={len(selected_blocks)}")
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
o_acc = None
lse_acc = None
compute_stream = offload_engine.compute_stream
# Step 1: Get historical blocks
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 2: Apply select_blocks to filter blocks
if cpu_block_table:
num_chunks = current_chunk_idx + 1
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=layer_id,
query=None, # Prefill typically doesn't use query for selection
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks")
# Use the pre-selected blocks directly
cpu_block_table = selected_blocks
if cpu_block_table:
load_slots = list(range(offload_engine.num_ring_slots))
@@ -200,16 +188,17 @@ class FullAttentionPolicy(SparsePolicy):
offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
seq: "Sequence",
selected_blocks: List[int],
) -> torch.Tensor:
"""
Compute full attention for chunked decode.
This method handles the complete chunked decode flow:
1. Get prefilled CPU blocks
2. Apply select_blocks for block filtering
3. Load blocks via pipeline (ring buffer or cross-layer)
4. Read accumulated decode tokens from decode buffer
5. Merge all results
This method handles the chunked decode computation:
1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
2. Read accumulated decode tokens from decode buffer
3. Merge all results
Note: Block selection is done by the caller before invoking this method.
Args:
q: Query tensor [batch_size, num_heads, head_dim]
@@ -218,6 +207,7 @@ class FullAttentionPolicy(SparsePolicy):
offload_engine: OffloadEngine for loading blocks
kvcache_manager: KVCacheManager for block management
seq: Sequence object
selected_blocks: List of CPU block IDs to process (already filtered)
Returns:
Attention output [batch_size, 1, num_heads, head_dim]
@@ -227,40 +217,35 @@ class FullAttentionPolicy(SparsePolicy):
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
# Get only PREFILLED CPU blocks (exclude the current decode block)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Use the pre-selected blocks directly
cpu_block_table = selected_blocks
if layer_id == 0:
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
logger.debug(f"Decode attention: selected_blocks={len(selected_blocks)}, seq.block_table={list(seq.block_table)}")
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Calculate valid tokens in the last CPU block
# CRITICAL: Use original prefill length, not current seq length!
# CPU blocks are fixed after prefill, their content doesn't change during decode.
# Note: We need to get all prefilled blocks to determine last_block_valid_tokens
block_size = kvcache_manager.block_size
num_prefill_blocks = len(cpu_block_table)
all_prefilled_blocks = kvcache_manager.get_prefilled_cpu_blocks(seq)
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
last_block_valid_tokens = total_prefill_tokens % block_size
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
last_block_valid_tokens = block_size # Last block was exactly full
# Apply sparse policy (self) for block filtering
policy_ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=layer_id,
query=q_batched,
is_prefill=False,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
# Determine if selected_blocks contains the last prefilled block
# If not, all selected blocks are full blocks (use block_size as valid tokens)
last_prefilled_block = all_prefilled_blocks[-1] if all_prefilled_blocks else None
selected_contains_last = (cpu_block_table and cpu_block_table[-1] == last_prefilled_block)
effective_last_block_tokens = last_block_valid_tokens if selected_contains_last else block_size
# Use ring buffer pipeline for loading prefilled blocks
load_slots = offload_engine.decode_load_slots
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
q_batched, cpu_block_table, load_slots, offload_engine,
block_size, last_block_valid_tokens, layer_id, softmax_scale
block_size, effective_last_block_tokens, layer_id, softmax_scale
)
# Now attend to accumulated decode tokens from per-layer decode buffer