♻️ refactor: move select_blocks from policy to attention layer
Move block selection logic out of the compute_chunked_prefill/compute_chunked_decode policy methods and into their caller in attention.py.

This improves separation of concerns:
- attention.py now calls select_blocks() before compute_chunked_*()
- Policy methods receive pre-selected blocks via the selected_blocks parameter
- Sparse policies can implement custom block selection without modifying the compute path

Changes:
- policy.py: Add selected_blocks parameter to abstract methods
- full_policy.py: Remove internal select_blocks calls, use passed blocks
- xattn_bsa.py: Sync signatures for prefill/decode methods
- attention.py: Add select_blocks calls before policy delegation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -58,16 +58,17 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
current_chunk_idx: int,
|
current_chunk_idx: int,
|
||||||
seq: "Sequence",
|
seq: "Sequence",
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
|
selected_blocks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute full attention for chunked prefill.
|
Compute full attention for chunked prefill.
|
||||||
|
|
||||||
This method handles the complete chunked prefill flow:
|
This method handles the chunked prefill computation:
|
||||||
1. Get historical blocks
|
1. Load and compute attention to historical chunks (using selected_blocks)
|
||||||
2. Select blocks via select_blocks
|
2. Compute attention to current chunk
|
||||||
3. Load and compute attention to historical chunks
|
3. Merge all results
|
||||||
4. Compute attention to current chunk
|
|
||||||
5. Merge all results
|
Note: Block selection is done by the caller before invoking this method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
@@ -80,6 +81,7 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
current_chunk_idx: Current chunk index
|
current_chunk_idx: Current chunk index
|
||||||
seq: Sequence object
|
seq: Sequence object
|
||||||
num_tokens: Number of tokens in current chunk
|
num_tokens: Number of tokens in current chunk
|
||||||
|
selected_blocks: List of CPU block IDs to process (already filtered)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
@@ -87,30 +89,16 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
|
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
|
||||||
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
|
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}, "
|
||||||
|
f"selected_blocks={len(selected_blocks)}")
|
||||||
|
|
||||||
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
|
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
|
||||||
o_acc = None
|
o_acc = None
|
||||||
lse_acc = None
|
lse_acc = None
|
||||||
compute_stream = offload_engine.compute_stream
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
# Step 1: Get historical blocks
|
# Use the pre-selected blocks directly
|
||||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
cpu_block_table = selected_blocks
|
||||||
|
|
||||||
# Step 2: Apply select_blocks to filter blocks
|
|
||||||
if cpu_block_table:
|
|
||||||
num_chunks = current_chunk_idx + 1
|
|
||||||
policy_ctx = PolicyContext(
|
|
||||||
query_chunk_idx=current_chunk_idx,
|
|
||||||
num_query_chunks=num_chunks,
|
|
||||||
layer_id=layer_id,
|
|
||||||
query=None, # Prefill typically doesn't use query for selection
|
|
||||||
is_prefill=True,
|
|
||||||
block_size=kvcache_manager.block_size,
|
|
||||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
|
||||||
)
|
|
||||||
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
|
|
||||||
logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks")
|
|
||||||
|
|
||||||
if cpu_block_table:
|
if cpu_block_table:
|
||||||
load_slots = list(range(offload_engine.num_ring_slots))
|
load_slots = list(range(offload_engine.num_ring_slots))
|
||||||
@@ -200,16 +188,17 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
offload_engine: "OffloadEngine",
|
offload_engine: "OffloadEngine",
|
||||||
kvcache_manager: "KVCacheManager",
|
kvcache_manager: "KVCacheManager",
|
||||||
seq: "Sequence",
|
seq: "Sequence",
|
||||||
|
selected_blocks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute full attention for chunked decode.
|
Compute full attention for chunked decode.
|
||||||
|
|
||||||
This method handles the complete chunked decode flow:
|
This method handles the chunked decode computation:
|
||||||
1. Get prefilled CPU blocks
|
1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
|
||||||
2. Apply select_blocks for block filtering
|
2. Read accumulated decode tokens from decode buffer
|
||||||
3. Load blocks via pipeline (ring buffer or cross-layer)
|
3. Merge all results
|
||||||
4. Read accumulated decode tokens from decode buffer
|
|
||||||
5. Merge all results
|
Note: Block selection is done by the caller before invoking this method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q: Query tensor [batch_size, num_heads, head_dim]
|
q: Query tensor [batch_size, num_heads, head_dim]
|
||||||
@@ -218,6 +207,7 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
offload_engine: OffloadEngine for loading blocks
|
offload_engine: OffloadEngine for loading blocks
|
||||||
kvcache_manager: KVCacheManager for block management
|
kvcache_manager: KVCacheManager for block management
|
||||||
seq: Sequence object
|
seq: Sequence object
|
||||||
|
selected_blocks: List of CPU block IDs to process (already filtered)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Attention output [batch_size, 1, num_heads, head_dim]
|
Attention output [batch_size, 1, num_heads, head_dim]
|
||||||
@@ -227,40 +217,35 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
||||||
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||||
|
|
||||||
# Get only PREFILLED CPU blocks (exclude the current decode block)
|
# Use the pre-selected blocks directly
|
||||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
cpu_block_table = selected_blocks
|
||||||
if layer_id == 0:
|
if layer_id == 0:
|
||||||
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
|
logger.debug(f"Decode attention: selected_blocks={len(selected_blocks)}, seq.block_table={list(seq.block_table)}")
|
||||||
if not cpu_block_table:
|
if not cpu_block_table:
|
||||||
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
||||||
|
|
||||||
# Calculate valid tokens in the last CPU block
|
# Calculate valid tokens in the last CPU block
|
||||||
# CRITICAL: Use original prefill length, not current seq length!
|
# CRITICAL: Use original prefill length, not current seq length!
|
||||||
# CPU blocks are fixed after prefill, their content doesn't change during decode.
|
# CPU blocks are fixed after prefill, their content doesn't change during decode.
|
||||||
|
# Note: We need to get all prefilled blocks to determine last_block_valid_tokens
|
||||||
block_size = kvcache_manager.block_size
|
block_size = kvcache_manager.block_size
|
||||||
num_prefill_blocks = len(cpu_block_table)
|
all_prefilled_blocks = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||||
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
|
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
|
||||||
last_block_valid_tokens = total_prefill_tokens % block_size
|
last_block_valid_tokens = total_prefill_tokens % block_size
|
||||||
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
||||||
last_block_valid_tokens = block_size # Last block was exactly full
|
last_block_valid_tokens = block_size # Last block was exactly full
|
||||||
|
|
||||||
# Apply sparse policy (self) for block filtering
|
# Determine if selected_blocks contains the last prefilled block
|
||||||
policy_ctx = PolicyContext(
|
# If not, all selected blocks are full blocks (use block_size as valid tokens)
|
||||||
query_chunk_idx=0,
|
last_prefilled_block = all_prefilled_blocks[-1] if all_prefilled_blocks else None
|
||||||
num_query_chunks=1,
|
selected_contains_last = (cpu_block_table and cpu_block_table[-1] == last_prefilled_block)
|
||||||
layer_id=layer_id,
|
effective_last_block_tokens = last_block_valid_tokens if selected_contains_last else block_size
|
||||||
query=q_batched,
|
|
||||||
is_prefill=False,
|
|
||||||
block_size=kvcache_manager.block_size,
|
|
||||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
|
||||||
)
|
|
||||||
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
|
|
||||||
|
|
||||||
# Use ring buffer pipeline for loading prefilled blocks
|
# Use ring buffer pipeline for loading prefilled blocks
|
||||||
load_slots = offload_engine.decode_load_slots
|
load_slots = offload_engine.decode_load_slots
|
||||||
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
||||||
q_batched, cpu_block_table, load_slots, offload_engine,
|
q_batched, cpu_block_table, load_slots, offload_engine,
|
||||||
block_size, last_block_valid_tokens, layer_id, softmax_scale
|
block_size, effective_last_block_tokens, layer_id, softmax_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now attend to accumulated decode tokens from per-layer decode buffer
|
# Now attend to accumulated decode tokens from per-layer decode buffer
|
||||||
|
|||||||
@@ -204,17 +204,20 @@ class SparsePolicy(ABC):
|
|||||||
current_chunk_idx: int,
|
current_chunk_idx: int,
|
||||||
seq: "Sequence",
|
seq: "Sequence",
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
|
selected_blocks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute chunked prefill attention (complete flow).
|
Compute chunked prefill attention (complete flow).
|
||||||
|
|
||||||
This is the main entry point for prefill attention computation.
|
This is the main entry point for prefill attention computation.
|
||||||
It defines the complete prefill flow:
|
It defines the complete prefill flow:
|
||||||
1. Get historical blocks
|
1. Load and compute historical blocks via offload_engine (using selected_blocks)
|
||||||
2. Select blocks (call select_blocks)
|
2. Get current chunk KV from offload_engine, compute attention
|
||||||
3. Load and compute historical blocks via offload_engine
|
3. Merge all results
|
||||||
4. Get current chunk KV from offload_engine, compute attention
|
|
||||||
5. Merge all results
|
Note: Block selection (select_blocks) is called by the caller (attention.py)
|
||||||
|
before invoking this method. The selected_blocks parameter contains the
|
||||||
|
filtered block IDs to process.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q: [seq_len, num_heads, head_dim] query for current chunk
|
q: [seq_len, num_heads, head_dim] query for current chunk
|
||||||
@@ -227,6 +230,7 @@ class SparsePolicy(ABC):
|
|||||||
current_chunk_idx: current chunk index
|
current_chunk_idx: current chunk index
|
||||||
seq: Sequence object
|
seq: Sequence object
|
||||||
num_tokens: number of tokens in current chunk
|
num_tokens: number of tokens in current chunk
|
||||||
|
selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[seq_len, num_heads, head_dim] final attention output
|
[seq_len, num_heads, head_dim] final attention output
|
||||||
@@ -242,17 +246,20 @@ class SparsePolicy(ABC):
|
|||||||
offload_engine: "OffloadEngine",
|
offload_engine: "OffloadEngine",
|
||||||
kvcache_manager: "KVCacheManager",
|
kvcache_manager: "KVCacheManager",
|
||||||
seq: "Sequence",
|
seq: "Sequence",
|
||||||
|
selected_blocks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute chunked decode attention (complete flow).
|
Compute chunked decode attention (complete flow).
|
||||||
|
|
||||||
This is the main entry point for decode attention computation.
|
This is the main entry point for decode attention computation.
|
||||||
It defines the complete decode flow:
|
It defines the complete decode flow:
|
||||||
1. Get prefilled blocks from CPU
|
1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
|
||||||
2. Select blocks (call select_blocks)
|
2. Read accumulated decode tokens from decode buffer
|
||||||
3. Load blocks via pipeline (ring buffer or cross-layer)
|
3. Merge all results
|
||||||
4. Read accumulated decode tokens from decode buffer
|
|
||||||
5. Merge all results
|
Note: Block selection (select_blocks) is called by the caller (attention.py)
|
||||||
|
before invoking this method. The selected_blocks parameter contains the
|
||||||
|
filtered block IDs to process.
|
||||||
|
|
||||||
The decode position information can be computed internally:
|
The decode position information can be computed internally:
|
||||||
- decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
|
- decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
|
||||||
@@ -265,6 +272,7 @@ class SparsePolicy(ABC):
|
|||||||
offload_engine: OffloadEngine for loading blocks
|
offload_engine: OffloadEngine for loading blocks
|
||||||
kvcache_manager: KVCacheManager for block management
|
kvcache_manager: KVCacheManager for block management
|
||||||
seq: Sequence object
|
seq: Sequence object
|
||||||
|
selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[batch_size, 1, num_heads, head_dim] final attention output
|
[batch_size, 1, num_heads, head_dim] final attention output
|
||||||
|
|||||||
@@ -136,6 +136,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
current_chunk_idx: int,
|
current_chunk_idx: int,
|
||||||
seq: "Sequence",
|
seq: "Sequence",
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
|
selected_blocks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute attention for chunked prefill.
|
Compute attention for chunked prefill.
|
||||||
@@ -169,7 +170,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
# This is temporary until proper sparse implementation is ready
|
# This is temporary until proper sparse implementation is ready
|
||||||
return self._compute_dense_fallback(
|
return self._compute_dense_fallback(
|
||||||
q, k, v, layer_id, softmax_scale, offload_engine,
|
q, k, v, layer_id, softmax_scale, offload_engine,
|
||||||
kvcache_manager, current_chunk_idx, seq, num_tokens
|
kvcache_manager, current_chunk_idx, seq, num_tokens, selected_blocks
|
||||||
)
|
)
|
||||||
|
|
||||||
def _compute_dense_fallback(
|
def _compute_dense_fallback(
|
||||||
@@ -184,22 +185,24 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
current_chunk_idx: int,
|
current_chunk_idx: int,
|
||||||
seq: "Sequence",
|
seq: "Sequence",
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
|
selected_blocks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Fallback to dense attention when BSA/XAttn not available.
|
Fallback to dense attention when BSA/XAttn not available.
|
||||||
Uses FullAttentionPolicy's proven pipeline.
|
Uses FullAttentionPolicy's proven pipeline with pre-selected blocks.
|
||||||
"""
|
"""
|
||||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
logger.debug(f"[XAttn] FALLBACK to dense: layer={layer_id}, chunk={current_chunk_idx}")
|
logger.debug(f"[XAttn] FALLBACK to dense: layer={layer_id}, chunk={current_chunk_idx}, "
|
||||||
|
f"selected_blocks={len(selected_blocks)}")
|
||||||
|
|
||||||
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
|
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
|
||||||
o_acc = None
|
o_acc = None
|
||||||
lse_acc = None
|
lse_acc = None
|
||||||
compute_stream = offload_engine.compute_stream
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
# Get historical CPU blocks
|
# Use the pre-selected blocks directly
|
||||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
cpu_block_table = selected_blocks
|
||||||
|
|
||||||
# Process historical blocks using pipeline
|
# Process historical blocks using pipeline
|
||||||
if cpu_block_table:
|
if cpu_block_table:
|
||||||
@@ -282,6 +285,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
offload_engine: "OffloadEngine",
|
offload_engine: "OffloadEngine",
|
||||||
kvcache_manager: "KVCacheManager",
|
kvcache_manager: "KVCacheManager",
|
||||||
seq: "Sequence",
|
seq: "Sequence",
|
||||||
|
selected_blocks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
XAttention does not support decode phase.
|
XAttention does not support decode phase.
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from torch import nn
|
|||||||
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
from nanovllm.utils.context import get_context
|
from nanovllm.utils.context import get_context
|
||||||
|
from nanovllm.kvcache.sparse.policy import PolicyContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -197,11 +198,30 @@ class Attention(nn.Module):
|
|||||||
if sparse_policy is None:
|
if sparse_policy is None:
|
||||||
raise RuntimeError("sparse_policy is required for chunked prefill")
|
raise RuntimeError("sparse_policy is required for chunked prefill")
|
||||||
|
|
||||||
|
# Step 1: Get historical CPU blocks
|
||||||
|
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||||
|
|
||||||
|
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
|
||||||
|
selected_blocks = []
|
||||||
|
if cpu_block_table:
|
||||||
|
num_chunks = current_chunk_idx + 1
|
||||||
|
policy_ctx = PolicyContext(
|
||||||
|
query_chunk_idx=current_chunk_idx,
|
||||||
|
num_query_chunks=num_chunks,
|
||||||
|
layer_id=self.layer_id,
|
||||||
|
query=q, # Pass query for sparse policies that need it
|
||||||
|
is_prefill=True,
|
||||||
|
block_size=kvcache_manager.block_size,
|
||||||
|
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||||
|
)
|
||||||
|
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
|
||||||
|
logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
|
||||||
|
|
||||||
# [DEBUG] Verify execution path
|
# [DEBUG] Verify execution path
|
||||||
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
|
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
|
||||||
f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
|
f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
|
||||||
|
|
||||||
# Delegate all computation to policy (no flash_attn or merge calls here!)
|
# Delegate computation to policy with pre-selected blocks
|
||||||
final_o = sparse_policy.compute_chunked_prefill(
|
final_o = sparse_policy.compute_chunked_prefill(
|
||||||
q, k, v,
|
q, k, v,
|
||||||
self.layer_id,
|
self.layer_id,
|
||||||
@@ -211,6 +231,7 @@ class Attention(nn.Module):
|
|||||||
current_chunk_idx,
|
current_chunk_idx,
|
||||||
seq,
|
seq,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
|
selected_blocks,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
||||||
@@ -265,11 +286,29 @@ class Attention(nn.Module):
|
|||||||
logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
|
logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
|
||||||
f"falling back to FullAttentionPolicy")
|
f"falling back to FullAttentionPolicy")
|
||||||
|
|
||||||
|
# Step 1: Get prefilled CPU blocks
|
||||||
|
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||||
|
|
||||||
|
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_decode)
|
||||||
|
selected_blocks = []
|
||||||
|
if cpu_block_table:
|
||||||
|
policy_ctx = PolicyContext(
|
||||||
|
query_chunk_idx=0,
|
||||||
|
num_query_chunks=1,
|
||||||
|
layer_id=self.layer_id,
|
||||||
|
query=q, # Pass query for sparse policies that need it
|
||||||
|
is_prefill=False,
|
||||||
|
block_size=kvcache_manager.block_size,
|
||||||
|
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||||
|
)
|
||||||
|
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
|
||||||
|
logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
|
||||||
|
|
||||||
# [DEBUG] Verify execution path
|
# [DEBUG] Verify execution path
|
||||||
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
|
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
|
||||||
f"policy={sparse_policy}, layer={self.layer_id}")
|
f"policy={sparse_policy}, layer={self.layer_id}")
|
||||||
|
|
||||||
# Delegate all computation to policy (no flash_attn or merge calls here!)
|
# Delegate computation to policy with pre-selected blocks
|
||||||
return sparse_policy.compute_chunked_decode(
|
return sparse_policy.compute_chunked_decode(
|
||||||
q,
|
q,
|
||||||
self.layer_id,
|
self.layer_id,
|
||||||
@@ -277,4 +316,5 @@ class Attention(nn.Module):
|
|||||||
offload_engine,
|
offload_engine,
|
||||||
kvcache_manager,
|
kvcache_manager,
|
||||||
seq,
|
seq,
|
||||||
|
selected_blocks,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user