♻️ refactor: move select_blocks from policy to attention layer
Move the block selection logic out of the compute_chunked_prefill/compute_chunked_decode policy methods and into their caller in attention.py.

This improves separation of concerns:
- attention.py now calls select_blocks() before compute_chunked_*()
- Policy methods receive pre-selected blocks via the selected_blocks parameter
- Sparse policies can implement custom block selection without modifying the compute path

Changes:
- policy.py: Add selected_blocks parameter to abstract methods
- full_policy.py: Remove internal select_blocks calls, use passed blocks
- xattn_bsa.py: Sync signatures for prefill/decode methods
- attention.py: Add select_blocks calls before policy delegation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -5,6 +5,7 @@ from torch import nn
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
from nanovllm.utils.context import get_context
|
||||
from nanovllm.kvcache.sparse.policy import PolicyContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -197,11 +198,30 @@ class Attention(nn.Module):
|
||||
if sparse_policy is None:
|
||||
raise RuntimeError("sparse_policy is required for chunked prefill")
|
||||
|
||||
# Step 1: Get historical CPU blocks
|
||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
|
||||
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
|
||||
selected_blocks = []
|
||||
if cpu_block_table:
|
||||
num_chunks = current_chunk_idx + 1
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=current_chunk_idx,
|
||||
num_query_chunks=num_chunks,
|
||||
layer_id=self.layer_id,
|
||||
query=q, # Pass query for sparse policies that need it
|
||||
is_prefill=True,
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
|
||||
logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
|
||||
|
||||
# [DEBUG] Verify execution path
|
||||
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
|
||||
f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
|
||||
|
||||
# Delegate all computation to policy (no flash_attn or merge calls here!)
|
||||
# Delegate computation to policy with pre-selected blocks
|
||||
final_o = sparse_policy.compute_chunked_prefill(
|
||||
q, k, v,
|
||||
self.layer_id,
|
||||
@@ -211,6 +231,7 @@ class Attention(nn.Module):
|
||||
current_chunk_idx,
|
||||
seq,
|
||||
num_tokens,
|
||||
selected_blocks,
|
||||
)
|
||||
|
||||
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
||||
@@ -265,11 +286,29 @@ class Attention(nn.Module):
|
||||
logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
|
||||
f"falling back to FullAttentionPolicy")
|
||||
|
||||
# Step 1: Get prefilled CPU blocks
|
||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
|
||||
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_decode)
|
||||
selected_blocks = []
|
||||
if cpu_block_table:
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=0,
|
||||
num_query_chunks=1,
|
||||
layer_id=self.layer_id,
|
||||
query=q, # Pass query for sparse policies that need it
|
||||
is_prefill=False,
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
|
||||
logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
|
||||
|
||||
# [DEBUG] Verify execution path
|
||||
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
|
||||
f"policy={sparse_policy}, layer={self.layer_id}")
|
||||
|
||||
# Delegate all computation to policy (no flash_attn or merge calls here!)
|
||||
# Delegate computation to policy with pre-selected blocks
|
||||
return sparse_policy.compute_chunked_decode(
|
||||
q,
|
||||
self.layer_id,
|
||||
@@ -277,4 +316,5 @@ class Attention(nn.Module):
|
||||
offload_engine,
|
||||
kvcache_manager,
|
||||
seq,
|
||||
selected_blocks,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user