[WIP] Before plan execute.

This commit is contained in:
Zijie Tian
2026-01-19 03:30:44 +08:00
parent e6e0dc5d7d
commit 9e6fdc0650
5 changed files with 377 additions and 10 deletions

View File

@@ -7,8 +7,8 @@ import torch
class SparsePolicyType(Enum):
"""Sparse attention policy types."""
FULL = auto() # No sparse attention (load all blocks)
QUEST = auto() # Query-aware Top-K block selection (decode only)
FULL = auto() # No sparse attention (load all blocks)
QUEST = auto() # Query-aware Top-K block selection (decode only)
@dataclass

View File

@@ -64,11 +64,16 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
# Create sparse policy from config enum
# Quest is decode-only: prefill returns all blocks (query=None), decode does Top-K
sparse_policy_type = getattr(config, 'sparse_policy', SparsePolicyType.FULL)
sparse_policy = create_sparse_policy(
sparse_policy_type,
topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
)
# Build policy kwargs based on policy type
policy_kwargs = {}
if sparse_policy_type == SparsePolicyType.QUEST:
policy_kwargs = {
'topk_blocks': getattr(config, 'sparse_topk_blocks', 8),
'threshold_blocks': getattr(config, 'sparse_threshold_blocks', 4),
}
sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)
return HybridKVCacheManager(
num_gpu_slots=num_gpu_blocks,

View File

@@ -35,8 +35,8 @@ class PolicyContext:
query: Optional[torch.Tensor]
"""
Query tensor for current chunk.
Shape: [1, num_heads, head_dim] for decode, [1, seq_len, num_heads, head_dim] for prefill.
May be None if not available (e.g., some prefill scenarios).
Shape: [1, num_heads, head_dim] for decode, [seq_len, num_heads, head_dim] for prefill.
Available for both prefill and decode phases.
"""
is_prefill: bool

View File

@@ -207,8 +207,10 @@ class Attention(nn.Module):
# Get prefilled CPU blocks (blocks from previous chunks)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Apply sparse policy if enabled (Quest returns all blocks for prefill since query=None)
# Apply sparse policy if enabled
sparse_policy = kvcache_manager.sparse_policy
# === Standard sparse policy (Quest, etc.) ===
if cpu_block_table and sparse_policy is not None:
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
policy_ctx = PolicyContext(