[WIP] Before refactor policies.

This commit is contained in:
Zijie Tian
2026-01-06 20:47:55 +08:00
parent 7cc8a394a5
commit 690492e074
6 changed files with 112 additions and 237 deletions

View File

@@ -147,22 +147,40 @@ class QuestPolicy(SparsePolicy):
This upper bound is derived from the fact that for any key k in
the block: min_k <= k <= max_k (element-wise), so the actual
attention score is bounded by the maximum of the two extremes.
Note: This is a decode-only policy. For prefill, use FullAttentionPolicy.
"""
def __init__(
self,
config: QuestConfig,
metadata_manager: BlockMetadataManager,
):
# Quest is decode-only
supports_prefill = False
supports_decode = True
def __init__(self, config: QuestConfig):
"""
Initialize Quest policy.
Args:
config: QuestConfig with selection parameters
metadata_manager: BlockMetadataManager for min/max key storage
"""
self.config = config
self.metadata = metadata_manager
self.metadata: Optional[BlockMetadataManager] = None
def initialize(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
num_cpu_blocks: int,
dtype: torch.dtype,
) -> None:
"""Create BlockMetadataManager for storing min/max keys."""
self.metadata = BlockMetadataManager(
num_blocks=num_cpu_blocks,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
)
def select_blocks(
self,
@@ -175,6 +193,12 @@ class QuestPolicy(SparsePolicy):
If query is not available (some prefill scenarios), falls back
to loading all blocks.
"""
if self.metadata is None:
raise RuntimeError(
"QuestPolicy not initialized. Call initialize() first or "
"let the framework call it during KV cache allocation."
)
n = len(available_blocks)
# If below threshold or no query, load all
@@ -269,11 +293,13 @@ class QuestPolicy(SparsePolicy):
num_valid_tokens: int,
) -> None:
"""Update min/max key metadata when block is offloaded."""
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
if self.metadata is not None:
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
def reset(self) -> None:
"""Reset metadata."""
self.metadata.reset()
if self.metadata is not None:
self.metadata.reset()
def __repr__(self) -> str:
return (