[feat] Added Quest Sparsity Policy.

This commit is contained in:
Zijie Tian
2026-01-07 03:29:21 +08:00
parent c99a6f3d3f
commit 2a6e0a2c02
9 changed files with 92 additions and 92 deletions

View File

@@ -90,8 +90,7 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: int,
block_size: int,
policy: Optional[EvictionPolicy] = None,
prefill_policy: "SparsePolicy" = None,
decode_policy: "SparsePolicy" = None,
sparse_policy: "SparsePolicy" = None,
):
"""
Initialize hybrid manager with CPU-primary ring buffer design.
@@ -104,8 +103,7 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: Number of CPU pool blocks (primary storage)
block_size: Tokens per block
policy: Eviction policy (default: LRU, used for prefix cache management)
prefill_policy: Sparse attention policy for prefill phase
decode_policy: Sparse attention policy for decode phase
sparse_policy: Sparse attention policy (Quest for decode-only sparse)
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
@@ -117,9 +115,8 @@ class HybridKVCacheManager(KVCacheManager):
# Eviction policy
self.policy = policy or LRUPolicy()
# Sparse attention policies (set at construction time, immutable)
self.prefill_policy = prefill_policy
self.decode_policy = decode_policy
# Sparse attention policy (set at construction time, immutable)
self.sparse_policy = sparse_policy
# Logical blocks (what sequences reference) - one per CPU block
self.logical_blocks: List[LogicalBlock] = [
@@ -185,8 +182,7 @@ class HybridKVCacheManager(KVCacheManager):
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
prefill_policy=self.prefill_policy,
decode_policy=self.decode_policy,
sparse_policy=self.sparse_policy,
)
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -194,18 +190,6 @@ class HybridKVCacheManager(KVCacheManager):
assert self.offload_engine is not None
return self.offload_engine.get_layer_cache(layer_id)
def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]:
"""
Get sparse policy for the specified phase.
Args:
is_prefill: True for prefill phase, False for decode phase
Returns:
SparsePolicy for the phase, or None if not set
"""
return self.prefill_policy if is_prefill else self.decode_policy
def can_allocate(self, seq: Sequence) -> bool:
"""Check if we can allocate blocks for a new sequence."""
return len(self.free_logical_ids) >= seq.num_blocks