[feat] Added Quest Sparsity Policy.
This commit is contained in:
@@ -56,36 +56,26 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
||||
# Need CPU offload: use hybrid manager
|
||||
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
|
||||
from nanovllm.kvcache.policies import get_policy
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy
|
||||
from nanovllm.config import SparsePolicyType
|
||||
|
||||
eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru'))
|
||||
|
||||
# Create sparse policies from config
|
||||
prefill_policy_type = getattr(config, 'prefill_policy', 'full')
|
||||
decode_policy_type = getattr(config, 'decode_policy', 'full')
|
||||
|
||||
def create_policy(policy_type_str):
|
||||
"""Create a sparse policy from config string."""
|
||||
if policy_type_str.lower() == 'full':
|
||||
return create_sparse_policy(SparsePolicyType.FULL)
|
||||
policy_type = SparsePolicyType[policy_type_str.upper()]
|
||||
return create_sparse_policy(
|
||||
policy_type,
|
||||
topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
|
||||
threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
|
||||
include_sink_blocks=getattr(config, 'sparse_num_sink_blocks', 1),
|
||||
)
|
||||
|
||||
prefill_policy = create_policy(prefill_policy_type)
|
||||
decode_policy = create_policy(decode_policy_type)
|
||||
# Create sparse policy from config enum
|
||||
# Quest is decode-only: prefill returns all blocks (query=None), decode does Top-K
|
||||
sparse_policy_type = getattr(config, 'sparse_policy', SparsePolicyType.FULL)
|
||||
sparse_policy = create_sparse_policy(
|
||||
sparse_policy_type,
|
||||
topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
|
||||
threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
|
||||
)
|
||||
|
||||
return HybridKVCacheManager(
|
||||
num_gpu_slots=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=config.kvcache_block_size,
|
||||
policy=eviction_policy,
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
sparse_policy=sparse_policy,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -90,8 +90,7 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_cpu_blocks: int,
|
||||
block_size: int,
|
||||
policy: Optional[EvictionPolicy] = None,
|
||||
prefill_policy: "SparsePolicy" = None,
|
||||
decode_policy: "SparsePolicy" = None,
|
||||
sparse_policy: "SparsePolicy" = None,
|
||||
):
|
||||
"""
|
||||
Initialize hybrid manager with CPU-primary ring buffer design.
|
||||
@@ -104,8 +103,7 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_cpu_blocks: Number of CPU pool blocks (primary storage)
|
||||
block_size: Tokens per block
|
||||
policy: Eviction policy (default: LRU, used for prefix cache management)
|
||||
prefill_policy: Sparse attention policy for prefill phase
|
||||
decode_policy: Sparse attention policy for decode phase
|
||||
sparse_policy: Sparse attention policy (Quest for decode-only sparse)
|
||||
"""
|
||||
self._block_size = block_size
|
||||
self.num_gpu_slots = num_gpu_slots
|
||||
@@ -117,9 +115,8 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
# Eviction policy
|
||||
self.policy = policy or LRUPolicy()
|
||||
|
||||
# Sparse attention policies (set at construction time, immutable)
|
||||
self.prefill_policy = prefill_policy
|
||||
self.decode_policy = decode_policy
|
||||
# Sparse attention policy (set at construction time, immutable)
|
||||
self.sparse_policy = sparse_policy
|
||||
|
||||
# Logical blocks (what sequences reference) - one per CPU block
|
||||
self.logical_blocks: List[LogicalBlock] = [
|
||||
@@ -185,8 +182,7 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
dtype=dtype,
|
||||
prefill_policy=self.prefill_policy,
|
||||
decode_policy=self.decode_policy,
|
||||
sparse_policy=self.sparse_policy,
|
||||
)
|
||||
|
||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
@@ -194,18 +190,6 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
assert self.offload_engine is not None
|
||||
return self.offload_engine.get_layer_cache(layer_id)
|
||||
|
||||
def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]:
|
||||
"""
|
||||
Get sparse policy for the specified phase.
|
||||
|
||||
Args:
|
||||
is_prefill: True for prefill phase, False for decode phase
|
||||
|
||||
Returns:
|
||||
SparsePolicy for the phase, or None if not set
|
||||
"""
|
||||
return self.prefill_policy if is_prefill else self.decode_policy
|
||||
|
||||
def can_allocate(self, seq: Sequence) -> bool:
|
||||
"""Check if we can allocate blocks for a new sequence."""
|
||||
return len(self.free_logical_ids) >= seq.num_blocks
|
||||
|
||||
@@ -60,8 +60,7 @@ class OffloadEngine:
|
||||
head_dim: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_streams: int = 4,
|
||||
prefill_policy: "SparsePolicy" = None,
|
||||
decode_policy: "SparsePolicy" = None,
|
||||
sparse_policy: "SparsePolicy" = None,
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
@@ -217,9 +216,8 @@ class OffloadEngine:
|
||||
self._debug_mode = False
|
||||
self._debug_hooks: List = [] # External hooks for debug events
|
||||
|
||||
# ========== Sparse attention policies (set at construction time) ==========
|
||||
self.prefill_policy = prefill_policy
|
||||
self.decode_policy = decode_policy
|
||||
# ========== Sparse attention policy (set at construction time) ==========
|
||||
self.sparse_policy = sparse_policy
|
||||
|
||||
def _get_next_stream(self) -> torch.cuda.Stream:
|
||||
"""Round-robin stream selection for parallel transfers."""
|
||||
@@ -765,20 +763,14 @@ class OffloadEngine:
|
||||
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
|
||||
|
||||
# Collect metadata BEFORE offload (while k_cache is still on GPU)
|
||||
# Both policies' callbacks are called - each decides whether to respond
|
||||
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
|
||||
k_cache = self.k_cache_gpu[slot_idx]
|
||||
|
||||
if is_prefill:
|
||||
if self.prefill_policy is not None:
|
||||
self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
if self.decode_policy is not None:
|
||||
self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
else:
|
||||
if self.prefill_policy is not None:
|
||||
self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
if self.decode_policy is not None:
|
||||
self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
if self.sparse_policy is not None:
|
||||
if is_prefill:
|
||||
self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
else:
|
||||
self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
|
||||
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
|
||||
@@ -19,7 +19,8 @@ Usage:
|
||||
return available_blocks[:5] # Just first 5 blocks
|
||||
"""
|
||||
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
|
||||
from nanovllm.config import SparsePolicyType
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||
|
||||
|
||||
@@ -7,15 +7,11 @@ from CPU for each query chunk during chunked attention computation.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import List, Optional, Any
|
||||
import torch
|
||||
|
||||
|
||||
class SparsePolicyType(Enum):
|
||||
"""Built-in sparse attention policy types."""
|
||||
FULL = auto() # prefill + decode
|
||||
QUEST = auto() # decode only
|
||||
# Import SparsePolicyType from config to avoid circular imports
|
||||
from nanovllm.config import SparsePolicyType
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user