WIP: Enhance sparse attention with density tracking and block selection improvements
- Added analysis documentation for xattn density alignment. - Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration. - Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection. - Enhanced QuestPolicy to utilize query tensor for block selection and improved handling of selected blocks. - Expanded XAttentionBSAPolicy to support chunked prefill and improved attention score computation with historical and current chunk handling. - Introduced DensityObserver to track compute and communication density for sparse attention layers. - Updated attention layer to ensure block selection is always called, improving robustness in first chunk scenarios. - Added tests for attention kernel behavior with enhanced input patterns.
This commit is contained in:
@@ -191,13 +191,26 @@ class QuestPolicy(SparsePolicy):
|
||||
def select_blocks(
|
||||
self,
|
||||
available_blocks: List[int],
|
||||
offload_engine: "OffloadEngine",
|
||||
ctx: PolicyContext,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
) -> List[int]:
|
||||
"""
|
||||
Select Top-K blocks based on query-key similarity bounds.
|
||||
|
||||
If query is not available (some prefill scenarios), falls back
|
||||
to loading all blocks.
|
||||
|
||||
Args:
|
||||
available_blocks: List of CPU block IDs
|
||||
offload_engine: OffloadEngine for loading KV (unused in Quest)
|
||||
ctx: PolicyContext with metadata
|
||||
q: Query tensor [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim] (unused in Quest, uses metadata instead)
|
||||
|
||||
Returns:
|
||||
Selected block IDs
|
||||
"""
|
||||
if self.metadata is None:
|
||||
raise RuntimeError(
|
||||
@@ -211,7 +224,7 @@ class QuestPolicy(SparsePolicy):
|
||||
if n <= self.config.threshold_blocks:
|
||||
return available_blocks
|
||||
|
||||
if ctx.query is None:
|
||||
if q is None:
|
||||
# No query available - cannot compute scores
|
||||
return available_blocks
|
||||
|
||||
@@ -221,11 +234,10 @@ class QuestPolicy(SparsePolicy):
|
||||
)
|
||||
|
||||
# Metadata is already on GPU, same device as query
|
||||
device = ctx.query.device
|
||||
device = q.device
|
||||
|
||||
# Compute upper bound scores
|
||||
# query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
|
||||
q = ctx.query
|
||||
# query shape: [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
|
||||
if q.dim() == 4:
|
||||
# Prefill: use mean over sequence length
|
||||
q = q.mean(dim=1) # [1, num_heads, head_dim]
|
||||
|
||||
Reference in New Issue
Block a user