From a50b4c2ac28385f684f516b1b3c874a3824b9446 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Fri, 23 Jan 2026 05:21:28 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20move=20select?= =?UTF-8?q?=5Fblocks=20from=20policy=20to=20attention=20layer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move block selection logic from compute_chunked_prefill/decode methods to attention.py caller. This improves separation of concerns: - attention.py now calls select_blocks() before compute_chunked_*() - Policy methods receive pre-selected blocks via selected_blocks parameter - Enables sparse policies to implement custom block selection without modifying the compute path Changes: - policy.py: Add selected_blocks parameter to abstract methods - full_policy.py: Remove internal select_blocks calls, use passed blocks - xattn_bsa.py: Sync signatures for prefill/decode methods - attention.py: Add select_blocks calls before policy delegation Co-Authored-By: Claude Opus 4.5 --- nanovllm/kvcache/sparse/full_policy.py | 77 +++++++++++--------------- nanovllm/kvcache/sparse/policy.py | 28 ++++++---- nanovllm/kvcache/sparse/xattn_bsa.py | 14 +++-- nanovllm/layers/attention.py | 44 ++++++++++++++- 4 files changed, 100 insertions(+), 63 deletions(-) diff --git a/nanovllm/kvcache/sparse/full_policy.py b/nanovllm/kvcache/sparse/full_policy.py index ff99133..52b846c 100644 --- a/nanovllm/kvcache/sparse/full_policy.py +++ b/nanovllm/kvcache/sparse/full_policy.py @@ -58,16 +58,17 @@ class FullAttentionPolicy(SparsePolicy): current_chunk_idx: int, seq: "Sequence", num_tokens: int, + selected_blocks: List[int], ) -> torch.Tensor: """ Compute full attention for chunked prefill. - This method handles the complete chunked prefill flow: - 1. Get historical blocks - 2. Select blocks via select_blocks - 3. Load and compute attention to historical chunks - 4. Compute attention to current chunk - 5. Merge all results + This method handles the chunked prefill computation: + 1. Load and compute attention to historical chunks (using selected_blocks) + 2. Compute attention to current chunk + 3. Merge all results + + Note: Block selection is done by the caller before invoking this method. Args: q: Query tensor [seq_len, num_heads, head_dim] @@ -80,6 +81,7 @@ class FullAttentionPolicy(SparsePolicy): current_chunk_idx: Current chunk index seq: Sequence object num_tokens: Number of tokens in current chunk + selected_blocks: List of CPU block IDs to process (already filtered) Returns: Attention output [seq_len, num_heads, head_dim] @@ -87,30 +89,16 @@ class FullAttentionPolicy(SparsePolicy): from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, " - f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}") + f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}, " + f"selected_blocks={len(selected_blocks)}") q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim] o_acc = None lse_acc = None compute_stream = offload_engine.compute_stream - # Step 1: Get historical blocks - cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) - - # Step 2: Apply select_blocks to filter blocks - if cpu_block_table: - num_chunks = current_chunk_idx + 1 - policy_ctx = PolicyContext( - query_chunk_idx=current_chunk_idx, - num_query_chunks=num_chunks, - layer_id=layer_id, - query=None, # Prefill typically doesn't use query for selection - is_prefill=True, - block_size=kvcache_manager.block_size, - total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, - ) - cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx) - logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks") + # Use the pre-selected blocks directly + cpu_block_table = selected_blocks if cpu_block_table: load_slots = list(range(offload_engine.num_ring_slots)) @@ -200,16 +188,17 @@ class FullAttentionPolicy(SparsePolicy): offload_engine: "OffloadEngine", kvcache_manager: "KVCacheManager", seq: "Sequence", + selected_blocks: List[int], ) -> torch.Tensor: """ Compute full attention for chunked decode. - This method handles the complete chunked decode flow: - 1. Get prefilled CPU blocks - 2. Apply select_blocks for block filtering - 3. Load blocks via pipeline (ring buffer or cross-layer) - 4. Read accumulated decode tokens from decode buffer - 5. Merge all results + This method handles the chunked decode computation: + 1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer) + 2. Read accumulated decode tokens from decode buffer + 3. Merge all results + + Note: Block selection is done by the caller before invoking this method. Args: q: Query tensor [batch_size, num_heads, head_dim] @@ -218,6 +207,7 @@ class FullAttentionPolicy(SparsePolicy): offload_engine: OffloadEngine for loading blocks kvcache_manager: KVCacheManager for block management seq: Sequence object + selected_blocks: List of CPU block IDs to process (already filtered) Returns: Attention output [batch_size, 1, num_heads, head_dim] @@ -227,40 +217,35 @@ class FullAttentionPolicy(SparsePolicy): # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence) q_batched = q.unsqueeze(1) # [batch, 1, heads, dim] - # Get only PREFILLED CPU blocks (exclude the current decode block) - cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) + # Use the pre-selected blocks directly + cpu_block_table = selected_blocks if layer_id == 0: - logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}") + logger.debug(f"Decode attention: selected_blocks={len(selected_blocks)}, seq.block_table={list(seq.block_table)}") if not cpu_block_table: raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available") # Calculate valid tokens in the last CPU block # CRITICAL: Use original prefill length, not current seq length! # CPU blocks are fixed after prefill, their content doesn't change during decode. + # Note: We need to get all prefilled blocks to determine last_block_valid_tokens block_size = kvcache_manager.block_size - num_prefill_blocks = len(cpu_block_table) + all_prefilled_blocks = kvcache_manager.get_prefilled_cpu_blocks(seq) total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length last_block_valid_tokens = total_prefill_tokens % block_size if last_block_valid_tokens == 0 and total_prefill_tokens > 0: last_block_valid_tokens = block_size # Last block was exactly full - # Apply sparse policy (self) for block filtering - policy_ctx = PolicyContext( - query_chunk_idx=0, - num_query_chunks=1, - layer_id=layer_id, - query=q_batched, - is_prefill=False, - block_size=kvcache_manager.block_size, - total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, - ) - cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx) + # Determine if selected_blocks contains the last prefilled block + # If not, all selected blocks are full blocks (use block_size as valid tokens) + last_prefilled_block = all_prefilled_blocks[-1] if all_prefilled_blocks else None + selected_contains_last = (cpu_block_table and cpu_block_table[-1] == last_prefilled_block) + effective_last_block_tokens = last_block_valid_tokens if selected_contains_last else block_size # Use ring buffer pipeline for loading prefilled blocks load_slots = offload_engine.decode_load_slots o_acc, lse_acc = self._decode_ring_buffer_pipeline( q_batched, cpu_block_table, load_slots, offload_engine, - block_size, last_block_valid_tokens, layer_id, softmax_scale + block_size, effective_last_block_tokens, layer_id, softmax_scale ) # Now attend to accumulated decode tokens from per-layer decode buffer diff --git a/nanovllm/kvcache/sparse/policy.py b/nanovllm/kvcache/sparse/policy.py index e56d266..b80a723 100644 --- a/nanovllm/kvcache/sparse/policy.py +++ b/nanovllm/kvcache/sparse/policy.py @@ -204,17 +204,20 @@ class SparsePolicy(ABC): current_chunk_idx: int, seq: "Sequence", num_tokens: int, + selected_blocks: List[int], ) -> torch.Tensor: """ Compute chunked prefill attention (complete flow). This is the main entry point for prefill attention computation. It defines the complete prefill flow: - 1. Get historical blocks - 2. Select blocks (call select_blocks) - 3. Load and compute historical blocks via offload_engine - 4. Get current chunk KV from offload_engine, compute attention - 5. Merge all results + 1. Load and compute historical blocks via offload_engine (using selected_blocks) + 2. Get current chunk KV from offload_engine, compute attention + 3. Merge all results + + Note: Block selection (select_blocks) is called by the caller (attention.py) + before invoking this method. The selected_blocks parameter contains the + filtered block IDs to process. Args: q: [seq_len, num_heads, head_dim] query for current chunk @@ -227,6 +230,7 @@ class SparsePolicy(ABC): current_chunk_idx: current chunk index seq: Sequence object num_tokens: number of tokens in current chunk + selected_blocks: list of CPU block IDs to process (already filtered by select_blocks) Returns: [seq_len, num_heads, head_dim] final attention output @@ -242,17 +246,20 @@ class SparsePolicy(ABC): offload_engine: "OffloadEngine", kvcache_manager: "KVCacheManager", seq: "Sequence", + selected_blocks: List[int], ) -> torch.Tensor: """ Compute chunked decode attention (complete flow). This is the main entry point for decode attention computation. It defines the complete decode flow: - 1. Get prefilled blocks from CPU - 2. Select blocks (call select_blocks) - 3. Load blocks via pipeline (ring buffer or cross-layer) - 4. Read accumulated decode tokens from decode buffer - 5. Merge all results + 1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer) + 2. Read accumulated decode tokens from decode buffer + 3. Merge all results + + Note: Block selection (select_blocks) is called by the caller (attention.py) + before invoking this method. The selected_blocks parameter contains the + filtered block IDs to process. The decode position information can be computed internally: - decode_start_pos = kvcache_manager.get_decode_start_pos(seq) @@ -265,6 +272,7 @@ class SparsePolicy(ABC): offload_engine: OffloadEngine for loading blocks kvcache_manager: KVCacheManager for block management seq: Sequence object + selected_blocks: list of CPU block IDs to process (already filtered by select_blocks) Returns: [batch_size, 1, num_heads, head_dim] final attention output diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index ad1fa2e..daa7ff3 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -136,6 +136,7 @@ class XAttentionBSAPolicy(SparsePolicy): current_chunk_idx: int, seq: "Sequence", num_tokens: int, + selected_blocks: List[int], ) -> torch.Tensor: """ Compute attention for chunked prefill. @@ -169,7 +170,7 @@ class XAttentionBSAPolicy(SparsePolicy): # This is temporary until proper sparse implementation is ready return self._compute_dense_fallback( q, k, v, layer_id, softmax_scale, offload_engine, - kvcache_manager, current_chunk_idx, seq, num_tokens + kvcache_manager, current_chunk_idx, seq, num_tokens, selected_blocks ) def _compute_dense_fallback( @@ -184,22 +185,24 @@ class XAttentionBSAPolicy(SparsePolicy): current_chunk_idx: int, seq: "Sequence", num_tokens: int, + selected_blocks: List[int], ) -> torch.Tensor: """ Fallback to dense attention when BSA/XAttn not available. - Uses FullAttentionPolicy's proven pipeline. + Uses FullAttentionPolicy's proven pipeline with pre-selected blocks. """ from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs - logger.debug(f"[XAttn] FALLBACK to dense: layer={layer_id}, chunk={current_chunk_idx}") + logger.debug(f"[XAttn] FALLBACK to dense: layer={layer_id}, chunk={current_chunk_idx}, " + f"selected_blocks={len(selected_blocks)}") q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim] o_acc = None lse_acc = None compute_stream = offload_engine.compute_stream - # Get historical CPU blocks - cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) + # Use the pre-selected blocks directly + cpu_block_table = selected_blocks # Process historical blocks using pipeline if cpu_block_table: @@ -282,6 +285,7 @@ class XAttentionBSAPolicy(SparsePolicy): offload_engine: "OffloadEngine", kvcache_manager: "KVCacheManager", seq: "Sequence", + selected_blocks: List[int], ) -> torch.Tensor: """ XAttention does not support decode phase. diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 5a22416..515bd10 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -5,6 +5,7 @@ from torch import nn from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache from nanovllm.utils.context import get_context +from nanovllm.kvcache.sparse.policy import PolicyContext logger = logging.getLogger(__name__) @@ -197,11 +198,30 @@ class Attention(nn.Module): if sparse_policy is None: raise RuntimeError("sparse_policy is required for chunked prefill") + # Step 1: Get historical CPU blocks + cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) + + # Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill) + selected_blocks = [] + if cpu_block_table: + num_chunks = current_chunk_idx + 1 + policy_ctx = PolicyContext( + query_chunk_idx=current_chunk_idx, + num_query_chunks=num_chunks, + layer_id=self.layer_id, + query=q, # Pass query for sparse policies that need it + is_prefill=True, + block_size=kvcache_manager.block_size, + total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, + ) + selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx) + logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks") + # [DEBUG] Verify execution path logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, " f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}") - # Delegate all computation to policy (no flash_attn or merge calls here!) + # Delegate computation to policy with pre-selected blocks final_o = sparse_policy.compute_chunked_prefill( q, k, v, self.layer_id, @@ -211,6 +231,7 @@ class Attention(nn.Module): current_chunk_idx, seq, num_tokens, + selected_blocks, ) torch.cuda.nvtx.range_pop() # ChunkedPrefill @@ -265,11 +286,29 @@ class Attention(nn.Module): logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, " f"falling back to FullAttentionPolicy") + # Step 1: Get prefilled CPU blocks + cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) + + # Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_decode) + selected_blocks = [] + if cpu_block_table: + policy_ctx = PolicyContext( + query_chunk_idx=0, + num_query_chunks=1, + layer_id=self.layer_id, + query=q, # Pass query for sparse policies that need it + is_prefill=False, + block_size=kvcache_manager.block_size, + total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, + ) + selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx) + logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks") + # [DEBUG] Verify execution path logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, " f"policy={sparse_policy}, layer={self.layer_id}") - # Delegate all computation to policy (no flash_attn or merge calls here!) + # Delegate computation to policy with pre-selected blocks return sparse_policy.compute_chunked_decode( q, self.layer_id, @@ -277,4 +316,5 @@ class Attention(nn.Module): offload_engine, kvcache_manager, seq, + selected_blocks, )