♻️ refactor: move select_blocks from policy to attention layer

Move block selection logic from the compute_chunked_prefill/decode methods
to the caller in attention.py. This improves separation of concerns:

- attention.py now calls select_blocks() before compute_chunked_*()
- Policy methods receive pre-selected blocks via the selected_blocks parameter
- Enables sparse policies to implement custom block selection without
  modifying the compute path

Changes:
- policy.py: Add selected_blocks parameter to abstract methods
- full_policy.py: Remove internal select_blocks calls, use passed blocks
- xattn_bsa.py: Sync signatures for prefill/decode methods
- attention.py: Add select_blocks calls before policy delegation (see the sketch after this list)
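
For orientation, the new caller pattern in attention.py looks roughly like the sketch below. This is a reconstruction from the commit message, not the diff itself: get_historical_blocks and the self.policy / self.offload_engine / self.kvcache_manager attribute names are assumptions.

    def _chunked_prefill_attention(self, q, k, v, seq, chunk_idx, num_tokens):
        # Selection now happens at the attention layer, before delegating.
        historical = self.kvcache_manager.get_historical_blocks(seq)  # assumed helper
        selected = self.policy.select_blocks(q, historical, seq)      # assumed signature
        # The policy computes only over the pre-selected blocks.
        return self.policy.compute_chunked_prefill(
            q, k, v,
            self.offload_engine, self.kvcache_manager,
            chunk_idx, seq, num_tokens,
            selected_blocks=selected,
        )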

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: Zijie Tian
Date:   2026-01-23 05:21:28 +08:00
Parent: ca32ea6f93
Commit: a50b4c2ac2

4 changed files with 100 additions and 63 deletions


@@ -204,17 +204,20 @@ class SparsePolicy(ABC):
         current_chunk_idx: int,
         seq: "Sequence",
         num_tokens: int,
+        selected_blocks: List[int],
     ) -> torch.Tensor:
         """
         Compute chunked prefill attention (complete flow).
 
         This is the main entry point for prefill attention computation.
         It defines the complete prefill flow:
-        1. Get historical blocks
-        2. Select blocks (call select_blocks)
-        3. Load and compute historical blocks via offload_engine
-        4. Get current chunk KV from offload_engine, compute attention
-        5. Merge all results
+        1. Load and compute historical blocks via offload_engine (using selected_blocks)
+        2. Get current chunk KV from offload_engine, compute attention
+        3. Merge all results
+
+        Note: Block selection (select_blocks) is called by the caller (attention.py)
+        before invoking this method. The selected_blocks parameter contains the
+        filtered block IDs to process.
 
         Args:
             q: [seq_len, num_heads, head_dim] query for current chunk
@@ -227,6 +230,7 @@ class SparsePolicy(ABC):
             current_chunk_idx: current chunk index
             seq: Sequence object
             num_tokens: number of tokens in current chunk
+            selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
 
         Returns:
             [seq_len, num_heads, head_dim] final attention output
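
The payoff noted in the commit message is that a sparse policy can now customize selection without touching the compute path. A minimal sketch, assuming FullPolicy supplies the compute_chunked_* implementations and that select_blocks receives the query and candidate block IDs; the _score_block helper is invented for illustration:

    class TopKBlockPolicy(FullPolicy):
        # Inherits compute_chunked_prefill/decode unchanged; only selection differs.
        def __init__(self, k: int):
            super().__init__()
            self.k = k

        def select_blocks(self, q, block_ids, seq):
            # Rank candidate CPU blocks by a (hypothetical) relevance score
            # and keep the k best; the compute path is untouched.
            ranked = sorted(block_ids, key=lambda b: self._score_block(q, b), reverse=True)
            return ranked[: self.k]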
@@ -242,17 +246,20 @@ class SparsePolicy(ABC):
         offload_engine: "OffloadEngine",
         kvcache_manager: "KVCacheManager",
         seq: "Sequence",
+        selected_blocks: List[int],
     ) -> torch.Tensor:
         """
         Compute chunked decode attention (complete flow).
 
         This is the main entry point for decode attention computation.
         It defines the complete decode flow:
-        1. Get prefilled blocks from CPU
-        2. Select blocks (call select_blocks)
-        3. Load blocks via pipeline (ring buffer or cross-layer)
-        4. Read accumulated decode tokens from decode buffer
-        5. Merge all results
+        1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
+        2. Read accumulated decode tokens from decode buffer
+        3. Merge all results
+
+        Note: Block selection (select_blocks) is called by the caller (attention.py)
+        before invoking this method. The selected_blocks parameter contains the
+        filtered block IDs to process.
 
         The decode position information can be computed internally:
         - decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
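
The decode-side caller in attention.py presumably mirrors the prefill one; a hedged sketch with the same assumed attribute names (get_prefilled_blocks is likewise an assumption):

    def _chunked_decode_attention(self, q, seq):
        prefilled = self.kvcache_manager.get_prefilled_blocks(seq)  # assumed helper
        selected = self.policy.select_blocks(q, prefilled, seq)
        return self.policy.compute_chunked_decode(
            q, self.offload_engine, self.kvcache_manager, seq,
            selected_blocks=selected,
        )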
@@ -265,6 +272,7 @@ class SparsePolicy(ABC):
             offload_engine: OffloadEngine for loading blocks
             kvcache_manager: KVCacheManager for block management
             seq: Sequence object
+            selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
 
         Returns:
             [batch_size, 1, num_heads, head_dim] final attention output
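
Both flows end with a "Merge all results" step. That is typically a log-sum-exp (online-softmax) combine of partial attention outputs; the snippet below is a generic sketch of merging two partials, not code from this repo:

    import torch

    def merge_partials(o1, lse1, o2, lse2):
        # o*: [seq_len, num_heads, head_dim]; lse*: [seq_len, num_heads]
        m = torch.maximum(lse1, lse2)
        w1 = torch.exp(lse1 - m).unsqueeze(-1)  # renormalized weight of partial 1
        w2 = torch.exp(lse2 - m).unsqueeze(-1)
        out = (o1 * w1 + o2 * w2) / (w1 + w2)
        lse = m + torch.log((w1 + w2).squeeze(-1))
        return out, lse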