WIP: Enhance sparse attention with density tracking and block selection improvements
- Added analysis documentation for xattn density alignment.
- Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration.
- Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection.
- Enhanced QuestPolicy to use the query tensor for block selection and improved handling of selected blocks.
- Expanded XAttentionBSAPolicy to support chunked prefill and improved attention score computation with historical and current chunk handling.
- Introduced DensityObserver to track compute and communication density for sparse attention layers.
- Updated the attention layer so block selection is always called, improving robustness in first-chunk scenarios.
- Added tests for attention kernel behavior with enhanced input patterns.
@@ -142,6 +142,8 @@ class SparsePolicy(ABC):
         available_blocks: List[int],
         offload_engine: "OffloadEngine",
         ctx: PolicyContext,
+        q: torch.Tensor,
+        k: torch.Tensor,
     ) -> List[int]:
         """
         Select which KV blocks to load for the current query chunk.
@@ -158,6 +160,8 @@ class SparsePolicy(ABC):
                 to load KV to make selection decisions).
             ctx: PolicyContext with information about the current query
                 chunk, layer, phase (prefill/decode), etc.
+            q: Query tensor [seq_len, num_heads, head_dim] for current chunk
+            k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk

         Returns:
             List of block IDs to load (must be a subset of available_blocks).
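To illustrate how a policy might use the newly added query tensor, here is a minimal sketch of a Quest-style selector that ranks blocks by similarity between the chunk's mean query and a per-block key summary. The `block_key_means` mapping is a hypothetical stand-in for whatever per-block metadata the repository's `OffloadEngine` actually exposes; this is not the project's QuestPolicy implementation, only the general idea.

```python
from typing import Dict, List

import torch


def select_top_k_blocks(
    available_blocks: List[int],
    block_key_means: Dict[int, torch.Tensor],  # hypothetical per-block key summaries
    q: torch.Tensor,  # [seq_len, num_heads, head_dim] for the current chunk
    top_k: int,
) -> List[int]:
    """Rank available KV blocks by query/key-summary similarity, keep top_k."""
    # Collapse the chunk's queries across positions and heads to one
    # direction vector of shape [head_dim].
    q_mean = q.mean(dim=(0, 1))
    # Higher dot product with the mean query means the block's keys are
    # more likely to receive attention mass, so load it first.
    ranked = sorted(
        available_blocks,
        key=lambda bid: torch.dot(q_mean, block_key_means[bid]).item(),
        reverse=True,
    )
    # Per the docstring contract, the result is a subset of available_blocks.
    return ranked[:top_k]
```

In the real API the policy also receives `k`, `ctx`, and the `offload_engine` (see the signature in the diff above), which a selector can use for phase-dependent behavior such as always selecting everything on the first prefill chunk.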