[WIP] Before add Quest policy.
This commit is contained in:
@@ -134,7 +134,7 @@ class SparsePolicy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_block_offloaded(
|
||||
def on_prefill_offload(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
layer_id: int,
|
||||
@@ -142,15 +142,38 @@ class SparsePolicy(ABC):
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Hook called when a block is offloaded from GPU to CPU.
|
||||
Hook called when a block is offloaded during prefill phase.
|
||||
|
||||
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
|
||||
Override this to collect metadata about blocks (e.g., min/max keys
|
||||
for Quest-style selection). Default implementation does nothing.
|
||||
|
||||
Args:
|
||||
cpu_block_id: The CPU block ID that was written
|
||||
cpu_block_id: The CPU block ID that will be written
|
||||
layer_id: Transformer layer index
|
||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
|
||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
|
||||
num_valid_tokens: Number of valid tokens in this block
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_decode_offload(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
layer_id: int,
|
||||
k_cache: torch.Tensor,
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Hook called when a block is offloaded during decode phase.
|
||||
|
||||
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
|
||||
Override this to update metadata about blocks. Default implementation
|
||||
does nothing.
|
||||
|
||||
Args:
|
||||
cpu_block_id: The CPU block ID that will be written
|
||||
layer_id: Transformer layer index
|
||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
|
||||
num_valid_tokens: Number of valid tokens in this block
|
||||
"""
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user