[WIP] Before refactor policies.
This commit is contained in:
@@ -147,22 +147,40 @@ class QuestPolicy(SparsePolicy):
|
||||
This upper bound is derived from the fact that for any key k in
|
||||
the block: min_k <= k <= max_k (element-wise), so the actual
|
||||
attention score is bounded by the maximum of the two extremes.
|
||||
|
||||
Note: This is a decode-only policy. For prefill, use FullAttentionPolicy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: QuestConfig,
|
||||
metadata_manager: BlockMetadataManager,
|
||||
):
|
||||
# Quest is decode-only
|
||||
supports_prefill = False
|
||||
supports_decode = True
|
||||
|
||||
def __init__(self, config: QuestConfig):
|
||||
"""
|
||||
Initialize Quest policy.
|
||||
|
||||
Args:
|
||||
config: QuestConfig with selection parameters
|
||||
metadata_manager: BlockMetadataManager for min/max key storage
|
||||
"""
|
||||
self.config = config
|
||||
self.metadata = metadata_manager
|
||||
self.metadata: Optional[BlockMetadataManager] = None
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
num_layers: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
num_cpu_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
"""Create BlockMetadataManager for storing min/max keys."""
|
||||
self.metadata = BlockMetadataManager(
|
||||
num_blocks=num_cpu_blocks,
|
||||
num_layers=num_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def select_blocks(
|
||||
self,
|
||||
@@ -175,6 +193,12 @@ class QuestPolicy(SparsePolicy):
|
||||
If query is not available (some prefill scenarios), falls back
|
||||
to loading all blocks.
|
||||
"""
|
||||
if self.metadata is None:
|
||||
raise RuntimeError(
|
||||
"QuestPolicy not initialized. Call initialize() first or "
|
||||
"let the framework call it during KV cache allocation."
|
||||
)
|
||||
|
||||
n = len(available_blocks)
|
||||
|
||||
# If below threshold or no query, load all
|
||||
@@ -269,11 +293,13 @@ class QuestPolicy(SparsePolicy):
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""Update min/max key metadata when block is offloaded."""
|
||||
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||
if self.metadata is not None:
|
||||
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset metadata."""
|
||||
self.metadata.reset()
|
||||
if self.metadata is not None:
|
||||
self.metadata.reset()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user