[feat] Added Quest Sparsity Policy.

This commit is contained in:
Zijie Tian
2026-01-07 03:29:21 +08:00
parent c99a6f3d3f
commit 2a6e0a2c02
9 changed files with 92 additions and 92 deletions

View File

@@ -56,36 +56,26 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
# Need CPU offload: use hybrid manager
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.policies import get_policy
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
from nanovllm.kvcache.sparse import create_sparse_policy
from nanovllm.config import SparsePolicyType
eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru'))
# Create sparse policies from config
prefill_policy_type = getattr(config, 'prefill_policy', 'full')
decode_policy_type = getattr(config, 'decode_policy', 'full')
def create_policy(policy_type_str):
    """Build a sparse policy from its config-string name.

    ``'full'`` (matched case-insensitively) yields the FULL policy with no
    extra arguments. Any other name is upper-cased and looked up as a
    ``SparsePolicyType`` member, then built with block tunables read from
    the enclosing ``config`` (each with a fallback default).
    """
    if policy_type_str.lower() != 'full':
        # Resolve the enum member first; an unknown name fails here,
        # before any tunable is read from the config.
        selected_type = SparsePolicyType[policy_type_str.upper()]
        tunables = {
            'topk_blocks': getattr(config, 'sparse_topk_blocks', 8),
            'threshold_blocks': getattr(config, 'sparse_threshold_blocks', 4),
            'include_sink_blocks': getattr(config, 'sparse_num_sink_blocks', 1),
        }
        return create_sparse_policy(selected_type, **tunables)
    return create_sparse_policy(SparsePolicyType.FULL)
prefill_policy = create_policy(prefill_policy_type)
decode_policy = create_policy(decode_policy_type)
# Create sparse policy from config enum
# Quest is decode-only: prefill returns all blocks (query=None), decode does Top-K
sparse_policy_type = getattr(config, 'sparse_policy', SparsePolicyType.FULL)
sparse_policy = create_sparse_policy(
sparse_policy_type,
topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
)
return HybridKVCacheManager(
num_gpu_slots=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=config.kvcache_block_size,
policy=eviction_policy,
prefill_policy=prefill_policy,
decode_policy=decode_policy,
sparse_policy=sparse_policy,
)

View File

@@ -90,8 +90,7 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: int,
block_size: int,
policy: Optional[EvictionPolicy] = None,
prefill_policy: "SparsePolicy" = None,
decode_policy: "SparsePolicy" = None,
sparse_policy: "SparsePolicy" = None,
):
"""
Initialize hybrid manager with CPU-primary ring buffer design.
@@ -104,8 +103,7 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: Number of CPU pool blocks (primary storage)
block_size: Tokens per block
policy: Eviction policy (default: LRU, used for prefix cache management)
prefill_policy: Sparse attention policy for prefill phase
decode_policy: Sparse attention policy for decode phase
sparse_policy: Sparse attention policy (Quest for decode-only sparse)
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
@@ -117,9 +115,8 @@ class HybridKVCacheManager(KVCacheManager):
# Eviction policy
self.policy = policy or LRUPolicy()
# Sparse attention policies (set at construction time, immutable)
self.prefill_policy = prefill_policy
self.decode_policy = decode_policy
# Sparse attention policy (set at construction time, immutable)
self.sparse_policy = sparse_policy
# Logical blocks (what sequences reference) - one per CPU block
self.logical_blocks: List[LogicalBlock] = [
@@ -185,8 +182,7 @@ class HybridKVCacheManager(KVCacheManager):
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
prefill_policy=self.prefill_policy,
decode_policy=self.decode_policy,
sparse_policy=self.sparse_policy,
)
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -194,18 +190,6 @@ class HybridKVCacheManager(KVCacheManager):
assert self.offload_engine is not None
return self.offload_engine.get_layer_cache(layer_id)
def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]:
    """
    Select the sparse policy that applies to the given phase.

    Args:
        is_prefill: Truthy selects the prefill policy, falsy the decode policy.

    Returns:
        The stored policy for that phase; None when none was configured.
    """
    if is_prefill:
        return self.prefill_policy
    return self.decode_policy
def can_allocate(self, seq: Sequence) -> bool:
    """Report whether enough free logical blocks exist to hold ``seq``."""
    free_count = len(self.free_logical_ids)
    blocks_needed = seq.num_blocks
    return free_count >= blocks_needed

View File

@@ -60,8 +60,7 @@ class OffloadEngine:
head_dim: int,
dtype: torch.dtype = torch.float16,
num_streams: int = 4,
prefill_policy: "SparsePolicy" = None,
decode_policy: "SparsePolicy" = None,
sparse_policy: "SparsePolicy" = None,
):
self.num_layers = num_layers
self.num_gpu_blocks = num_gpu_blocks
@@ -217,9 +216,8 @@ class OffloadEngine:
self._debug_mode = False
self._debug_hooks: List = [] # External hooks for debug events
# ========== Sparse attention policies (set at construction time) ==========
self.prefill_policy = prefill_policy
self.decode_policy = decode_policy
# ========== Sparse attention policy (set at construction time) ==========
self.sparse_policy = sparse_policy
def _get_next_stream(self) -> torch.cuda.Stream:
"""Round-robin stream selection for parallel transfers."""
@@ -765,20 +763,14 @@ class OffloadEngine:
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
# Collect metadata BEFORE offload (while k_cache is still on GPU)
# Both policies' callbacks are called - each decides whether to respond
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
k_cache = self.k_cache_gpu[slot_idx]
if is_prefill:
if self.prefill_policy is not None:
self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
if self.decode_policy is not None:
self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
else:
if self.prefill_policy is not None:
self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
if self.decode_policy is not None:
self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
if self.sparse_policy is not None:
if is_prefill:
self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
else:
self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):

View File

@@ -19,7 +19,8 @@ Usage:
return available_blocks[:5] # Just first 5 blocks
"""
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
from nanovllm.config import SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager

View File

@@ -7,15 +7,11 @@ from CPU for each query chunk during chunked attention computation.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional, Any
import torch
class SparsePolicyType(Enum):
    """Built-in sparse attention policy types.

    Members:
        FULL: used for both prefill and decode.
        QUEST: used for decode only.
    """

    # prefill + decode
    FULL = auto()
    # decode only
    QUEST = auto()
# Import SparsePolicyType from config to avoid circular imports
from nanovllm.config import SparsePolicyType
@dataclass