[WIP] Before add Quest policy.

This commit is contained in:
Zijie Tian
2026-01-07 02:32:30 +08:00
parent f240903013
commit c99a6f3d3f
11 changed files with 166 additions and 191 deletions

View File

@@ -134,7 +134,7 @@ class SparsePolicy(ABC):
"""
pass
def on_block_offloaded(
def on_prefill_offload(
self,
cpu_block_id: int,
layer_id: int,
@@ -142,15 +142,38 @@ class SparsePolicy(ABC):
num_valid_tokens: int,
) -> None:
"""
Hook called when a block is offloaded from GPU to CPU.
Hook called when a block is offloaded during prefill phase.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
Override this to collect metadata about blocks (e.g., min/max keys
for Quest-style selection). Default implementation does nothing.
Args:
cpu_block_id: The CPU block ID that was written
cpu_block_id: The CPU block ID that will be written
layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
"""
pass
def on_decode_offload(
self,
cpu_block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""
Hook called when a block is offloaded during decode phase.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
Override this to update metadata about blocks. Default implementation
does nothing.
Args:
cpu_block_id: The CPU block ID that will be written
layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
"""
pass