♻️ refactor: move select_blocks from policy to attention layer
Move block selection logic from the compute_chunked_prefill/decode methods to the attention.py caller. This improves separation of concerns:

- attention.py now calls select_blocks() before compute_chunked_*()
- Policy methods receive pre-selected blocks via the selected_blocks parameter
- Enables sparse policies to implement custom block selection without modifying the compute path

Changes:

- policy.py: Add selected_blocks parameter to abstract methods
- full_policy.py: Remove internal select_blocks calls; use the passed blocks
- xattn_bsa.py: Sync signatures for prefill/decode methods
- attention.py: Add select_blocks calls before policy delegation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -58,16 +58,17 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
current_chunk_idx: int,
|
||||
seq: "Sequence",
|
||||
num_tokens: int,
|
||||
selected_blocks: List[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute full attention for chunked prefill.
|
||||
|
||||
This method handles the complete chunked prefill flow:
|
||||
1. Get historical blocks
|
||||
2. Select blocks via select_blocks
|
||||
3. Load and compute attention to historical chunks
|
||||
4. Compute attention to current chunk
|
||||
5. Merge all results
|
||||
This method handles the chunked prefill computation:
|
||||
1. Load and compute attention to historical chunks (using selected_blocks)
|
||||
2. Compute attention to current chunk
|
||||
3. Merge all results
|
||||
|
||||
Note: Block selection is done by the caller before invoking this method.
|
||||
|
||||
Args:
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
@@ -80,6 +81,7 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
current_chunk_idx: Current chunk index
|
||||
seq: Sequence object
|
||||
num_tokens: Number of tokens in current chunk
|
||||
selected_blocks: List of CPU block IDs to process (already filtered)
|
||||
|
||||
Returns:
|
||||
Attention output [seq_len, num_heads, head_dim]
|
||||
@@ -87,30 +89,16 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
|
||||
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
|
||||
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}, "
|
||||
f"selected_blocks={len(selected_blocks)}")
|
||||
|
||||
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
|
||||
o_acc = None
|
||||
lse_acc = None
|
||||
compute_stream = offload_engine.compute_stream
|
||||
|
||||
# Step 1: Get historical blocks
|
||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
|
||||
# Step 2: Apply select_blocks to filter blocks
|
||||
if cpu_block_table:
|
||||
num_chunks = current_chunk_idx + 1
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=current_chunk_idx,
|
||||
num_query_chunks=num_chunks,
|
||||
layer_id=layer_id,
|
||||
query=None, # Prefill typically doesn't use query for selection
|
||||
is_prefill=True,
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
|
||||
logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks")
|
||||
# Use the pre-selected blocks directly
|
||||
cpu_block_table = selected_blocks
|
||||
|
||||
if cpu_block_table:
|
||||
load_slots = list(range(offload_engine.num_ring_slots))
|
||||
@@ -200,16 +188,17 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
offload_engine: "OffloadEngine",
|
||||
kvcache_manager: "KVCacheManager",
|
||||
seq: "Sequence",
|
||||
selected_blocks: List[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute full attention for chunked decode.
|
||||
|
||||
This method handles the complete chunked decode flow:
|
||||
1. Get prefilled CPU blocks
|
||||
2. Apply select_blocks for block filtering
|
||||
3. Load blocks via pipeline (ring buffer or cross-layer)
|
||||
4. Read accumulated decode tokens from decode buffer
|
||||
5. Merge all results
|
||||
This method handles the chunked decode computation:
|
||||
1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
|
||||
2. Read accumulated decode tokens from decode buffer
|
||||
3. Merge all results
|
||||
|
||||
Note: Block selection is done by the caller before invoking this method.
|
||||
|
||||
Args:
|
||||
q: Query tensor [batch_size, num_heads, head_dim]
|
||||
@@ -218,6 +207,7 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
offload_engine: OffloadEngine for loading blocks
|
||||
kvcache_manager: KVCacheManager for block management
|
||||
seq: Sequence object
|
||||
selected_blocks: List of CPU block IDs to process (already filtered)
|
||||
|
||||
Returns:
|
||||
Attention output [batch_size, 1, num_heads, head_dim]
|
||||
@@ -227,40 +217,35 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
||||
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||
|
||||
# Get only PREFILLED CPU blocks (exclude the current decode block)
|
||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
# Use the pre-selected blocks directly
|
||||
cpu_block_table = selected_blocks
|
||||
if layer_id == 0:
|
||||
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
|
||||
logger.debug(f"Decode attention: selected_blocks={len(selected_blocks)}, seq.block_table={list(seq.block_table)}")
|
||||
if not cpu_block_table:
|
||||
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
||||
|
||||
# Calculate valid tokens in the last CPU block
|
||||
# CRITICAL: Use original prefill length, not current seq length!
|
||||
# CPU blocks are fixed after prefill, their content doesn't change during decode.
|
||||
# Note: We need to get all prefilled blocks to determine last_block_valid_tokens
|
||||
block_size = kvcache_manager.block_size
|
||||
num_prefill_blocks = len(cpu_block_table)
|
||||
all_prefilled_blocks = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
|
||||
last_block_valid_tokens = total_prefill_tokens % block_size
|
||||
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
||||
last_block_valid_tokens = block_size # Last block was exactly full
|
||||
|
||||
# Apply sparse policy (self) for block filtering
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=0,
|
||||
num_query_chunks=1,
|
||||
layer_id=layer_id,
|
||||
query=q_batched,
|
||||
is_prefill=False,
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
|
||||
# Determine if selected_blocks contains the last prefilled block
|
||||
# If not, all selected blocks are full blocks (use block_size as valid tokens)
|
||||
last_prefilled_block = all_prefilled_blocks[-1] if all_prefilled_blocks else None
|
||||
selected_contains_last = (cpu_block_table and cpu_block_table[-1] == last_prefilled_block)
|
||||
effective_last_block_tokens = last_block_valid_tokens if selected_contains_last else block_size
|
||||
|
||||
# Use ring buffer pipeline for loading prefilled blocks
|
||||
load_slots = offload_engine.decode_load_slots
|
||||
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
||||
q_batched, cpu_block_table, load_slots, offload_engine,
|
||||
block_size, last_block_valid_tokens, layer_id, softmax_scale
|
||||
block_size, effective_last_block_tokens, layer_id, softmax_scale
|
||||
)
|
||||
|
||||
# Now attend to accumulated decode tokens from per-layer decode buffer
|
||||
|
||||
Reference in New Issue
Block a user