WIP: Enhance sparse attention with density tracking and block selection improvements

- Added analysis documentation for xattn density alignment.
- Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration.
- Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection.
- Enhanced QuestPolicy to utilize query tensor for block selection and improved handling of selected blocks.
- Expanded XAttentionBSAPolicy to support chunked prefill and improved attention score computation with historical and current chunk handling.
- Introduced DensityObserver to track compute and communication density for sparse attention layers.
- Updated attention layer to ensure block selection is always called, improving robustness in first chunk scenarios.
- Added tests for attention kernel behavior with enhanced input patterns.
This commit is contained in:
Zijie Tian
2026-01-31 14:48:23 +08:00
parent f6ac4ccdde
commit 2e96d1d97d
9 changed files with 490 additions and 152 deletions

View File

@@ -191,13 +191,26 @@ class QuestPolicy(SparsePolicy):
def select_blocks(
self,
available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]:
"""
Select Top-K blocks based on query-key similarity bounds.
If query is not available (some prefill scenarios), falls back
to loading all blocks.
Args:
available_blocks: List of CPU block IDs
offload_engine: OffloadEngine for loading KV (unused in Quest)
ctx: PolicyContext with metadata
q: Query tensor [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim] (unused in Quest, uses metadata instead)
Returns:
Selected block IDs
"""
if self.metadata is None:
raise RuntimeError(
@@ -211,7 +224,7 @@ class QuestPolicy(SparsePolicy):
if n <= self.config.threshold_blocks:
return available_blocks
if ctx.query is None:
if q is None:
# No query available - cannot compute scores
return available_blocks
@@ -221,11 +234,10 @@ class QuestPolicy(SparsePolicy):
)
# Metadata is already on GPU, same device as query
device = ctx.query.device
device = q.device
# Compute upper bound scores
# query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
q = ctx.query
# query shape: [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
if q.dim() == 4:
# Prefill: use mean over sequence length
q = q.mean(dim=1) # [1, num_heads, head_dim]