♻️ refactor: move select_blocks from policy to attention layer

Move the block selection logic out of the compute_chunked_prefill/decode
methods and into the caller in attention.py. This improves separation of concerns:

- attention.py now calls select_blocks() before compute_chunked_*()
- Policy methods receive pre-selected blocks via the selected_blocks parameter
- Enables sparse policies to implement custom block selection without
  modifying the compute path

Changes:
- policy.py: Add selected_blocks parameter to abstract methods
- full_policy.py: Remove internal select_blocks calls, use passed blocks
- xattn_bsa.py: Sync signatures for prefill/decode methods
- attention.py: Add select_blocks calls before policy delegation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-23 05:21:28 +08:00
parent ca32ea6f93
commit a50b4c2ac2
4 changed files with 100 additions and 63 deletions

View File

@@ -58,16 +58,17 @@ class FullAttentionPolicy(SparsePolicy):
current_chunk_idx: int, current_chunk_idx: int,
seq: "Sequence", seq: "Sequence",
num_tokens: int, num_tokens: int,
selected_blocks: List[int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Compute full attention for chunked prefill. Compute full attention for chunked prefill.
This method handles the complete chunked prefill flow: This method handles the chunked prefill computation:
1. Get historical blocks 1. Load and compute attention to historical chunks (using selected_blocks)
2. Select blocks via select_blocks 2. Compute attention to current chunk
3. Load and compute attention to historical chunks 3. Merge all results
4. Compute attention to current chunk
5. Merge all results Note: Block selection is done by the caller before invoking this method.
Args: Args:
q: Query tensor [seq_len, num_heads, head_dim] q: Query tensor [seq_len, num_heads, head_dim]
@@ -80,6 +81,7 @@ class FullAttentionPolicy(SparsePolicy):
current_chunk_idx: Current chunk index current_chunk_idx: Current chunk index
seq: Sequence object seq: Sequence object
num_tokens: Number of tokens in current chunk num_tokens: Number of tokens in current chunk
selected_blocks: List of CPU block IDs to process (already filtered)
Returns: Returns:
Attention output [seq_len, num_heads, head_dim] Attention output [seq_len, num_heads, head_dim]
@@ -87,30 +89,16 @@ class FullAttentionPolicy(SparsePolicy):
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, " logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}") f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}, "
f"selected_blocks={len(selected_blocks)}")
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim] q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
o_acc = None o_acc = None
lse_acc = None lse_acc = None
compute_stream = offload_engine.compute_stream compute_stream = offload_engine.compute_stream
# Step 1: Get historical blocks # Use the pre-selected blocks directly
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) cpu_block_table = selected_blocks
# Step 2: Apply select_blocks to filter blocks
if cpu_block_table:
num_chunks = current_chunk_idx + 1
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=layer_id,
query=None, # Prefill typically doesn't use query for selection
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks")
if cpu_block_table: if cpu_block_table:
load_slots = list(range(offload_engine.num_ring_slots)) load_slots = list(range(offload_engine.num_ring_slots))
@@ -200,16 +188,17 @@ class FullAttentionPolicy(SparsePolicy):
offload_engine: "OffloadEngine", offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager", kvcache_manager: "KVCacheManager",
seq: "Sequence", seq: "Sequence",
selected_blocks: List[int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Compute full attention for chunked decode. Compute full attention for chunked decode.
This method handles the complete chunked decode flow: This method handles the chunked decode computation:
1. Get prefilled CPU blocks 1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
2. Apply select_blocks for block filtering 2. Read accumulated decode tokens from decode buffer
3. Load blocks via pipeline (ring buffer or cross-layer) 3. Merge all results
4. Read accumulated decode tokens from decode buffer
5. Merge all results Note: Block selection is done by the caller before invoking this method.
Args: Args:
q: Query tensor [batch_size, num_heads, head_dim] q: Query tensor [batch_size, num_heads, head_dim]
@@ -218,6 +207,7 @@ class FullAttentionPolicy(SparsePolicy):
offload_engine: OffloadEngine for loading blocks offload_engine: OffloadEngine for loading blocks
kvcache_manager: KVCacheManager for block management kvcache_manager: KVCacheManager for block management
seq: Sequence object seq: Sequence object
selected_blocks: List of CPU block IDs to process (already filtered)
Returns: Returns:
Attention output [batch_size, 1, num_heads, head_dim] Attention output [batch_size, 1, num_heads, head_dim]
@@ -227,40 +217,35 @@ class FullAttentionPolicy(SparsePolicy):
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence) # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim] q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
# Get only PREFILLED CPU blocks (exclude the current decode block) # Use the pre-selected blocks directly
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) cpu_block_table = selected_blocks
if layer_id == 0: if layer_id == 0:
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}") logger.debug(f"Decode attention: selected_blocks={len(selected_blocks)}, seq.block_table={list(seq.block_table)}")
if not cpu_block_table: if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available") raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Calculate valid tokens in the last CPU block # Calculate valid tokens in the last CPU block
# CRITICAL: Use original prefill length, not current seq length! # CRITICAL: Use original prefill length, not current seq length!
# CPU blocks are fixed after prefill, their content doesn't change during decode. # CPU blocks are fixed after prefill, their content doesn't change during decode.
# Note: We need to get all prefilled blocks to determine last_block_valid_tokens
block_size = kvcache_manager.block_size block_size = kvcache_manager.block_size
num_prefill_blocks = len(cpu_block_table) all_prefilled_blocks = kvcache_manager.get_prefilled_cpu_blocks(seq)
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
last_block_valid_tokens = total_prefill_tokens % block_size last_block_valid_tokens = total_prefill_tokens % block_size
if last_block_valid_tokens == 0 and total_prefill_tokens > 0: if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
last_block_valid_tokens = block_size # Last block was exactly full last_block_valid_tokens = block_size # Last block was exactly full
# Apply sparse policy (self) for block filtering # Determine if selected_blocks contains the last prefilled block
policy_ctx = PolicyContext( # If not, all selected blocks are full blocks (use block_size as valid tokens)
query_chunk_idx=0, last_prefilled_block = all_prefilled_blocks[-1] if all_prefilled_blocks else None
num_query_chunks=1, selected_contains_last = (cpu_block_table and cpu_block_table[-1] == last_prefilled_block)
layer_id=layer_id, effective_last_block_tokens = last_block_valid_tokens if selected_contains_last else block_size
query=q_batched,
is_prefill=False,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
# Use ring buffer pipeline for loading prefilled blocks # Use ring buffer pipeline for loading prefilled blocks
load_slots = offload_engine.decode_load_slots load_slots = offload_engine.decode_load_slots
o_acc, lse_acc = self._decode_ring_buffer_pipeline( o_acc, lse_acc = self._decode_ring_buffer_pipeline(
q_batched, cpu_block_table, load_slots, offload_engine, q_batched, cpu_block_table, load_slots, offload_engine,
block_size, last_block_valid_tokens, layer_id, softmax_scale block_size, effective_last_block_tokens, layer_id, softmax_scale
) )
# Now attend to accumulated decode tokens from per-layer decode buffer # Now attend to accumulated decode tokens from per-layer decode buffer

View File

@@ -204,17 +204,20 @@ class SparsePolicy(ABC):
current_chunk_idx: int, current_chunk_idx: int,
seq: "Sequence", seq: "Sequence",
num_tokens: int, num_tokens: int,
selected_blocks: List[int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Compute chunked prefill attention (complete flow). Compute chunked prefill attention (complete flow).
This is the main entry point for prefill attention computation. This is the main entry point for prefill attention computation.
It defines the complete prefill flow: It defines the complete prefill flow:
1. Get historical blocks 1. Load and compute historical blocks via offload_engine (using selected_blocks)
2. Select blocks (call select_blocks) 2. Get current chunk KV from offload_engine, compute attention
3. Load and compute historical blocks via offload_engine 3. Merge all results
4. Get current chunk KV from offload_engine, compute attention
5. Merge all results Note: Block selection (select_blocks) is called by the caller (attention.py)
before invoking this method. The selected_blocks parameter contains the
filtered block IDs to process.
Args: Args:
q: [seq_len, num_heads, head_dim] query for current chunk q: [seq_len, num_heads, head_dim] query for current chunk
@@ -227,6 +230,7 @@ class SparsePolicy(ABC):
current_chunk_idx: current chunk index current_chunk_idx: current chunk index
seq: Sequence object seq: Sequence object
num_tokens: number of tokens in current chunk num_tokens: number of tokens in current chunk
selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
Returns: Returns:
[seq_len, num_heads, head_dim] final attention output [seq_len, num_heads, head_dim] final attention output
@@ -242,17 +246,20 @@ class SparsePolicy(ABC):
offload_engine: "OffloadEngine", offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager", kvcache_manager: "KVCacheManager",
seq: "Sequence", seq: "Sequence",
selected_blocks: List[int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Compute chunked decode attention (complete flow). Compute chunked decode attention (complete flow).
This is the main entry point for decode attention computation. This is the main entry point for decode attention computation.
It defines the complete decode flow: It defines the complete decode flow:
1. Get prefilled blocks from CPU 1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
2. Select blocks (call select_blocks) 2. Read accumulated decode tokens from decode buffer
3. Load blocks via pipeline (ring buffer or cross-layer) 3. Merge all results
4. Read accumulated decode tokens from decode buffer
5. Merge all results Note: Block selection (select_blocks) is called by the caller (attention.py)
before invoking this method. The selected_blocks parameter contains the
filtered block IDs to process.
The decode position information can be computed internally: The decode position information can be computed internally:
- decode_start_pos = kvcache_manager.get_decode_start_pos(seq) - decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
@@ -265,6 +272,7 @@ class SparsePolicy(ABC):
offload_engine: OffloadEngine for loading blocks offload_engine: OffloadEngine for loading blocks
kvcache_manager: KVCacheManager for block management kvcache_manager: KVCacheManager for block management
seq: Sequence object seq: Sequence object
selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
Returns: Returns:
[batch_size, 1, num_heads, head_dim] final attention output [batch_size, 1, num_heads, head_dim] final attention output

View File

@@ -136,6 +136,7 @@ class XAttentionBSAPolicy(SparsePolicy):
current_chunk_idx: int, current_chunk_idx: int,
seq: "Sequence", seq: "Sequence",
num_tokens: int, num_tokens: int,
selected_blocks: List[int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Compute attention for chunked prefill. Compute attention for chunked prefill.
@@ -169,7 +170,7 @@ class XAttentionBSAPolicy(SparsePolicy):
# This is temporary until proper sparse implementation is ready # This is temporary until proper sparse implementation is ready
return self._compute_dense_fallback( return self._compute_dense_fallback(
q, k, v, layer_id, softmax_scale, offload_engine, q, k, v, layer_id, softmax_scale, offload_engine,
kvcache_manager, current_chunk_idx, seq, num_tokens kvcache_manager, current_chunk_idx, seq, num_tokens, selected_blocks
) )
def _compute_dense_fallback( def _compute_dense_fallback(
@@ -184,22 +185,24 @@ class XAttentionBSAPolicy(SparsePolicy):
current_chunk_idx: int, current_chunk_idx: int,
seq: "Sequence", seq: "Sequence",
num_tokens: int, num_tokens: int,
selected_blocks: List[int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Fallback to dense attention when BSA/XAttn not available. Fallback to dense attention when BSA/XAttn not available.
Uses FullAttentionPolicy's proven pipeline. Uses FullAttentionPolicy's proven pipeline with pre-selected blocks.
""" """
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
logger.debug(f"[XAttn] FALLBACK to dense: layer={layer_id}, chunk={current_chunk_idx}") logger.debug(f"[XAttn] FALLBACK to dense: layer={layer_id}, chunk={current_chunk_idx}, "
f"selected_blocks={len(selected_blocks)}")
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim] q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
o_acc = None o_acc = None
lse_acc = None lse_acc = None
compute_stream = offload_engine.compute_stream compute_stream = offload_engine.compute_stream
# Get historical CPU blocks # Use the pre-selected blocks directly
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) cpu_block_table = selected_blocks
# Process historical blocks using pipeline # Process historical blocks using pipeline
if cpu_block_table: if cpu_block_table:
@@ -282,6 +285,7 @@ class XAttentionBSAPolicy(SparsePolicy):
offload_engine: "OffloadEngine", offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager", kvcache_manager: "KVCacheManager",
seq: "Sequence", seq: "Sequence",
selected_blocks: List[int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
XAttention does not support decode phase. XAttention does not support decode phase.

View File

@@ -5,6 +5,7 @@ from torch import nn
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context from nanovllm.utils.context import get_context
from nanovllm.kvcache.sparse.policy import PolicyContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -197,11 +198,30 @@ class Attention(nn.Module):
if sparse_policy is None: if sparse_policy is None:
raise RuntimeError("sparse_policy is required for chunked prefill") raise RuntimeError("sparse_policy is required for chunked prefill")
# Step 1: Get historical CPU blocks
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
selected_blocks = []
if cpu_block_table:
num_chunks = current_chunk_idx + 1
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=self.layer_id,
query=q, # Pass query for sparse policies that need it
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
# [DEBUG] Verify execution path # [DEBUG] Verify execution path
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, " logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}") f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
# Delegate all computation to policy (no flash_attn or merge calls here!) # Delegate computation to policy with pre-selected blocks
final_o = sparse_policy.compute_chunked_prefill( final_o = sparse_policy.compute_chunked_prefill(
q, k, v, q, k, v,
self.layer_id, self.layer_id,
@@ -211,6 +231,7 @@ class Attention(nn.Module):
current_chunk_idx, current_chunk_idx,
seq, seq,
num_tokens, num_tokens,
selected_blocks,
) )
torch.cuda.nvtx.range_pop() # ChunkedPrefill torch.cuda.nvtx.range_pop() # ChunkedPrefill
@@ -265,11 +286,29 @@ class Attention(nn.Module):
logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, " logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
f"falling back to FullAttentionPolicy") f"falling back to FullAttentionPolicy")
# Step 1: Get prefilled CPU blocks
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_decode)
selected_blocks = []
if cpu_block_table:
policy_ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=self.layer_id,
query=q, # Pass query for sparse policies that need it
is_prefill=False,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
# [DEBUG] Verify execution path # [DEBUG] Verify execution path
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, " logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
f"policy={sparse_policy}, layer={self.layer_id}") f"policy={sparse_policy}, layer={self.layer_id}")
# Delegate all computation to policy (no flash_attn or merge calls here!) # Delegate computation to policy with pre-selected blocks
return sparse_policy.compute_chunked_decode( return sparse_policy.compute_chunked_decode(
q, q,
self.layer_id, self.layer_id,
@@ -277,4 +316,5 @@ class Attention(nn.Module):
offload_engine, offload_engine,
kvcache_manager, kvcache_manager,
seq, seq,
selected_blocks,
) )