♻️ refactor: move select_blocks from policy to attention layer
Move block selection logic from compute_chunked_prefill/decode methods to the
attention.py caller. This improves separation of concerns:

- attention.py now calls select_blocks() before compute_chunked_*()
- Policy methods receive pre-selected blocks via the selected_blocks parameter
- Enables sparse policies to implement custom block selection without
  modifying the compute path

Changes:

- policy.py: Add selected_blocks parameter to abstract methods
- full_policy.py: Remove internal select_blocks calls, use passed blocks
- xattn_bsa.py: Sync signatures for prefill/decode methods
- attention.py: Add select_blocks calls before policy delegation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
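The refactored call pattern described above can be sketched as follows. The class and method bodies here are toy stand-ins, not the project's actual API: the real methods take tensors, an OffloadEngine, and a Sequence, while this sketch only illustrates where block selection now happens relative to the compute call.

```python
from abc import ABC, abstractmethod
from typing import List


class SparsePolicy(ABC):
    @abstractmethod
    def select_blocks(self, block_ids: List[int]) -> List[int]:
        """Filter the historical CPU blocks to attend over."""

    @abstractmethod
    def compute_chunked_prefill(self, selected_blocks: List[int]) -> str:
        """Compute attention over pre-selected blocks (selection already done)."""


class FullPolicy(SparsePolicy):
    def select_blocks(self, block_ids):
        return list(block_ids)  # dense attention: keep every block

    def compute_chunked_prefill(self, selected_blocks):
        return f"attended over {selected_blocks}"


class RecentKPolicy(SparsePolicy):
    """Hypothetical sparse policy: attend only to the most recent k blocks."""

    def __init__(self, k: int):
        self.k = k

    def select_blocks(self, block_ids):
        return list(block_ids)[-self.k:]

    def compute_chunked_prefill(self, selected_blocks):
        return f"attended over {selected_blocks}"


def attention_forward(policy: SparsePolicy, historical_blocks: List[int]) -> str:
    # attention.py is now the single place where selection happens, so a new
    # sparse policy only overrides select_blocks; the compute path is untouched.
    selected = policy.select_blocks(historical_blocks)
    return policy.compute_chunked_prefill(selected)
```

Because selection is hoisted to the caller, swapping `FullPolicy` for `RecentKPolicy` changes which blocks reach `compute_chunked_prefill` without either policy re-implementing the compute flow.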
@@ -204,17 +204,20 @@ class SparsePolicy(ABC):
         current_chunk_idx: int,
         seq: "Sequence",
         num_tokens: int,
+        selected_blocks: List[int],
     ) -> torch.Tensor:
         """
         Compute chunked prefill attention (complete flow).
 
         This is the main entry point for prefill attention computation.
         It defines the complete prefill flow:
-        1. Get historical blocks
-        2. Select blocks (call select_blocks)
-        3. Load and compute historical blocks via offload_engine
-        4. Get current chunk KV from offload_engine, compute attention
-        5. Merge all results
+        1. Load and compute historical blocks via offload_engine (using selected_blocks)
+        2. Get current chunk KV from offload_engine, compute attention
+        3. Merge all results
 
+        Note: Block selection (select_blocks) is called by the caller (attention.py)
+        before invoking this method. The selected_blocks parameter contains the
+        filtered block IDs to process.
+
         Args:
             q: [seq_len, num_heads, head_dim] query for current chunk
@@ -227,6 +230,7 @@ class SparsePolicy(ABC):
             current_chunk_idx: current chunk index
             seq: Sequence object
             num_tokens: number of tokens in current chunk
+            selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
 
         Returns:
             [seq_len, num_heads, head_dim] final attention output
@@ -242,17 +246,20 @@ class SparsePolicy(ABC):
         offload_engine: "OffloadEngine",
         kvcache_manager: "KVCacheManager",
         seq: "Sequence",
+        selected_blocks: List[int],
     ) -> torch.Tensor:
         """
         Compute chunked decode attention (complete flow).
 
         This is the main entry point for decode attention computation.
         It defines the complete decode flow:
-        1. Get prefilled blocks from CPU
-        2. Select blocks (call select_blocks)
-        3. Load blocks via pipeline (ring buffer or cross-layer)
-        4. Read accumulated decode tokens from decode buffer
-        5. Merge all results
+        1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
+        2. Read accumulated decode tokens from decode buffer
+        3. Merge all results
 
+        Note: Block selection (select_blocks) is called by the caller (attention.py)
+        before invoking this method. The selected_blocks parameter contains the
+        filtered block IDs to process.
+
         The decode position information can be computed internally:
         - decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
@@ -265,6 +272,7 @@ class SparsePolicy(ABC):
             offload_engine: OffloadEngine for loading blocks
             kvcache_manager: KVCacheManager for block management
             seq: Sequence object
+            selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
 
         Returns:
             [batch_size, 1, num_heads, head_dim] final attention output
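Both docstrings above end with a "Merge all results" step. A common way to merge attention computed over disjoint KV blocks is log-sum-exp rescaling, the trick used by FlashAttention-style chunked attention; the NumPy sketch below illustrates that technique under the assumption it is what the flow relies on, and is not the project's actual merge code.

```python
import numpy as np


def partial_attention(q, k, v):
    """Attention over a single KV block, plus the per-query log-sum-exp."""
    scores = q @ k.T                        # [q_len, kv_len]
    m = scores.max(axis=-1, keepdims=True)  # stabilize the exponentials
    lse = m[:, 0] + np.log(np.exp(scores - m).sum(axis=-1))
    out = np.exp(scores - lse[:, None]) @ v
    return out, lse


def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial results as if attention ran over both blocks at once."""
    lse = np.logaddexp(lse_a, lse_b)
    return (out_a * np.exp(lse_a - lse)[:, None]
            + out_b * np.exp(lse_b - lse)[:, None])
```

Splitting the keys/values into blocks, attending to each block separately, and merging reproduces full attention exactly; that exactness is the property a chunked prefill/decode flow depends on when it processes selected_blocks incrementally.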