♻️ refactor: move select_blocks from policy to attention layer

Move block selection logic from compute_chunked_prefill/decode methods
to attention.py caller. This improves separation of concerns:

- attention.py now calls select_blocks() before compute_chunked_*()
- Policy methods receive pre-selected blocks via selected_blocks parameter
- Enables sparse policies to implement custom block selection without
  modifying the compute path

Changes:
- policy.py: Add selected_blocks parameter to abstract methods
- full_policy.py: Remove internal select_blocks calls, use passed blocks
- xattn_bsa.py: Sync signatures for prefill/decode methods
- attention.py: Add select_blocks calls before policy delegation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-23 05:21:28 +08:00
parent ca32ea6f93
commit a50b4c2ac2
4 changed files with 100 additions and 63 deletions

View File

@@ -5,6 +5,7 @@ from torch import nn
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context
from nanovllm.kvcache.sparse.policy import PolicyContext
logger = logging.getLogger(__name__)
@@ -197,11 +198,30 @@ class Attention(nn.Module):
if sparse_policy is None:
raise RuntimeError("sparse_policy is required for chunked prefill")
# Step 1: Get historical CPU blocks
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
selected_blocks = []
if cpu_block_table:
num_chunks = current_chunk_idx + 1
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=self.layer_id,
query=q, # Pass query for sparse policies that need it
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
# [DEBUG] Verify execution path
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
# Delegate all computation to policy (no flash_attn or merge calls here!)
# Delegate computation to policy with pre-selected blocks
final_o = sparse_policy.compute_chunked_prefill(
q, k, v,
self.layer_id,
@@ -211,6 +231,7 @@ class Attention(nn.Module):
current_chunk_idx,
seq,
num_tokens,
selected_blocks,
)
torch.cuda.nvtx.range_pop() # ChunkedPrefill
@@ -265,11 +286,29 @@ class Attention(nn.Module):
logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
f"falling back to FullAttentionPolicy")
# Step 1: Get prefilled CPU blocks
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_decode)
selected_blocks = []
if cpu_block_table:
policy_ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=self.layer_id,
query=q, # Pass query for sparse policies that need it
is_prefill=False,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
# [DEBUG] Verify execution path
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
f"policy={sparse_policy}, layer={self.layer_id}")
# Delegate all computation to policy (no flash_attn or merge calls here!)
# Delegate computation to policy with pre-selected blocks
return sparse_policy.compute_chunked_decode(
q,
self.layer_id,
@@ -277,4 +316,5 @@ class Attention(nn.Module):
offload_engine,
kvcache_manager,
seq,
selected_blocks,
)