[feat] Added Quest Sparsity Policy.

This commit is contained in:
Zijie Tian
2026-01-07 03:29:21 +08:00
parent c99a6f3d3f
commit 2a6e0a2c02
9 changed files with 92 additions and 92 deletions

View File

@@ -60,8 +60,7 @@ class OffloadEngine:
head_dim: int,
dtype: torch.dtype = torch.float16,
num_streams: int = 4,
prefill_policy: "SparsePolicy" = None,
decode_policy: "SparsePolicy" = None,
sparse_policy: "SparsePolicy" = None,
):
self.num_layers = num_layers
self.num_gpu_blocks = num_gpu_blocks
@@ -217,9 +216,8 @@ class OffloadEngine:
self._debug_mode = False
self._debug_hooks: List = [] # External hooks for debug events
# ========== Sparse attention policies (set at construction time) ==========
self.prefill_policy = prefill_policy
self.decode_policy = decode_policy
# ========== Sparse attention policy (set at construction time) ==========
self.sparse_policy = sparse_policy
def _get_next_stream(self) -> torch.cuda.Stream:
"""Round-robin stream selection for parallel transfers."""
@@ -765,20 +763,14 @@ class OffloadEngine:
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
# Collect metadata BEFORE offload (while k_cache is still on GPU)
# Both policies' callbacks are called - each decides whether to respond
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
k_cache = self.k_cache_gpu[slot_idx]
if is_prefill:
if self.prefill_policy is not None:
self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
if self.decode_policy is not None:
self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
else:
if self.prefill_policy is not None:
self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
if self.decode_policy is not None:
self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
if self.sparse_policy is not None:
if is_prefill:
self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
else:
self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):