♻️ refactor: migrate chunked decode attention to SparsePolicy
Move decode attention computation from attention.py to SparsePolicy:

- Add compute_chunked_decode abstract method to SparsePolicy base class
- Implement compute_chunked_decode in FullAttentionPolicy with:
  - Ring buffer pipeline (_decode_ring_buffer_pipeline)
  - Cross-layer pipeline (_decode_with_layer_pipeline)
  - Decode buffer handling
- Simplify _chunked_decode_attention to only validate and delegate
- Remove _decode_ring_buffer_pipeline and _decode_with_layer_pipeline from attention.py
- Add supports_decode check for policy validation

This completes the SparsePolicy v5 refactoring: both the prefill and decode paths now delegate all computation to the sparse policy.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
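The "validate and delegate" shape described above might look like the following minimal sketch. The surrounding `Attention` class, its attributes, and the error type are illustrative stand-ins; only `_chunked_decode_attention`, `supports_decode`, and `compute_chunked_decode` come from the commit message.

```python
# Hedged sketch: _chunked_decode_attention only validates the policy and
# delegates the whole decode computation to it. Names other than those in
# the commit message are hypothetical, not the real attention.py code.

class UnsupportedPolicyError(RuntimeError):
    """Raised when the configured policy cannot handle chunked decode."""


class Attention:
    def __init__(self, layer_id, softmax_scale, policy,
                 offload_engine, kvcache_manager):
        self.layer_id = layer_id
        self.softmax_scale = softmax_scale
        self.policy = policy
        self.offload_engine = offload_engine
        self.kvcache_manager = kvcache_manager

    def _chunked_decode_attention(self, q, seq):
        # Validate: the policy must opt in to handling decode.
        if not getattr(self.policy, "supports_decode", False):
            raise UnsupportedPolicyError(
                f"{self.policy!r} does not support chunked decode")
        # Delegate: all computation happens inside the sparse policy.
        return self.policy.compute_chunked_decode(
            q,
            layer_id=self.layer_id,
            softmax_scale=self.softmax_scale,
            offload_engine=self.offload_engine,
            kvcache_manager=self.kvcache_manager,
            seq=seq,
        )
```

Keeping attention.py down to this thin wrapper is what lets both pipelines live entirely inside the policy implementation.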
@@ -233,5 +233,43 @@ class SparsePolicy(ABC):
        """
        pass

    @abstractmethod
    def compute_chunked_decode(
        self,
        q: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
        offload_engine: "OffloadEngine",
        kvcache_manager: "KVCacheManager",
        seq: "Sequence",
    ) -> torch.Tensor:
        """
        Compute chunked decode attention (complete flow).

        This is the main entry point for decode attention computation.
        It defines the complete decode flow:
        1. Get prefilled blocks from CPU
        2. Select blocks (call select_blocks)
        3. Load blocks via pipeline (ring buffer or cross-layer)
        4. Read accumulated decode tokens from decode buffer
        5. Merge all results

        The decode position information can be computed internally:
        - decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
        - decode_pos_in_block = (len(seq) - 1) % kvcache_manager.block_size

        Args:
            q: [batch_size, num_heads, head_dim] query for decode token
            layer_id: transformer layer index
            softmax_scale: softmax scaling factor
            offload_engine: OffloadEngine for loading blocks
            kvcache_manager: KVCacheManager for block management
            seq: Sequence object

        Returns:
            [batch_size, 1, num_heads, head_dim] final attention output
        """
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
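A concrete policy following the five documented steps could be sketched as below. Every helper called on `offload_engine` and `kvcache_manager`, and the merge function, are hypothetical stand-ins for this sketch; only the step order and the abstract method signature come from the diff.

```python
# Hedged sketch of a full-attention policy subclass. The engine/manager
# helper methods (get_prefilled_blocks, attend_block, attend_decode_buffer)
# and merge_partials are illustrative, not the real OffloadEngine API.
from abc import ABC, abstractmethod


class SparsePolicy(ABC):
    @abstractmethod
    def compute_chunked_decode(self, q, layer_id, softmax_scale,
                               offload_engine, kvcache_manager, seq):
        ...


class FullAttentionPolicySketch(SparsePolicy):
    """Attends over every prefilled block (no sparsity)."""

    supports_decode = True

    def select_blocks(self, block_ids, seq):
        # Full attention keeps every block.
        return block_ids

    def compute_chunked_decode(self, q, layer_id, softmax_scale,
                               offload_engine, kvcache_manager, seq):
        # 1. Get prefilled blocks from CPU.
        blocks = kvcache_manager.get_prefilled_blocks(seq)
        # 2. Select blocks (full attention selects all of them).
        selected = self.select_blocks(blocks, seq)
        # 3. Load blocks via a pipeline and attend chunk by chunk.
        partials = [
            offload_engine.attend_block(q, layer_id, b, softmax_scale)
            for b in selected
        ]
        # 4. Read accumulated decode tokens from the decode buffer.
        partials.append(
            offload_engine.attend_decode_buffer(q, layer_id, seq,
                                                softmax_scale))
        # 5. Merge all partial results.
        return merge_partials(partials)


def merge_partials(partials):
    # Placeholder merge: the real code would do a numerically stable
    # log-sum-exp merge of (output, lse) pairs from each chunk.
    out = partials[0]
    for p in partials[1:]:
        out = out + p
    return out
```

Under this shape, the ring-buffer and cross-layer pipelines become private helpers of the policy (step 3), keeping the block-selection and merge logic reusable across sparse variants.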