[WIP] Before add Quest policy.
This commit is contained in:
@@ -56,14 +56,36 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
||||
# Need CPU offload: use hybrid manager
|
||||
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
|
||||
from nanovllm.kvcache.policies import get_policy
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
|
||||
|
||||
policy = get_policy(getattr(config, 'offload_policy', 'lru'))
|
||||
eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru'))
|
||||
|
||||
# Create sparse policies from config
|
||||
prefill_policy_type = getattr(config, 'prefill_policy', 'full')
|
||||
decode_policy_type = getattr(config, 'decode_policy', 'full')
|
||||
|
||||
def create_policy(policy_type_str):
|
||||
"""Create a sparse policy from config string."""
|
||||
if policy_type_str.lower() == 'full':
|
||||
return create_sparse_policy(SparsePolicyType.FULL)
|
||||
policy_type = SparsePolicyType[policy_type_str.upper()]
|
||||
return create_sparse_policy(
|
||||
policy_type,
|
||||
topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
|
||||
threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
|
||||
include_sink_blocks=getattr(config, 'sparse_num_sink_blocks', 1),
|
||||
)
|
||||
|
||||
prefill_policy = create_policy(prefill_policy_type)
|
||||
decode_policy = create_policy(decode_policy_type)
|
||||
|
||||
return HybridKVCacheManager(
|
||||
num_gpu_slots=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=config.kvcache_block_size,
|
||||
policy=policy,
|
||||
policy=eviction_policy,
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -90,6 +90,8 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_cpu_blocks: int,
|
||||
block_size: int,
|
||||
policy: Optional[EvictionPolicy] = None,
|
||||
prefill_policy: "SparsePolicy" = None,
|
||||
decode_policy: "SparsePolicy" = None,
|
||||
):
|
||||
"""
|
||||
Initialize hybrid manager with CPU-primary ring buffer design.
|
||||
@@ -102,6 +104,8 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_cpu_blocks: Number of CPU pool blocks (primary storage)
|
||||
block_size: Tokens per block
|
||||
policy: Eviction policy (default: LRU, used for prefix cache management)
|
||||
prefill_policy: Sparse attention policy for prefill phase
|
||||
decode_policy: Sparse attention policy for decode phase
|
||||
"""
|
||||
self._block_size = block_size
|
||||
self.num_gpu_slots = num_gpu_slots
|
||||
@@ -113,6 +117,10 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
# Eviction policy
|
||||
self.policy = policy or LRUPolicy()
|
||||
|
||||
# Sparse attention policies (set at construction time, immutable)
|
||||
self.prefill_policy = prefill_policy
|
||||
self.decode_policy = decode_policy
|
||||
|
||||
# Logical blocks (what sequences reference) - one per CPU block
|
||||
self.logical_blocks: List[LogicalBlock] = [
|
||||
LogicalBlock(i) for i in range(self.total_blocks)
|
||||
@@ -153,9 +161,6 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
# Key: sequence id, Value: number of tokens from prefill (before decode started)
|
||||
self._prefill_len: Dict[int, int] = {}
|
||||
|
||||
# Sparse attention policy (optional)
|
||||
self.sparse_policy: Optional["SparsePolicy"] = None
|
||||
|
||||
@property
|
||||
def block_size(self) -> int:
|
||||
return self._block_size
|
||||
@@ -180,6 +185,8 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
dtype=dtype,
|
||||
prefill_policy=self.prefill_policy,
|
||||
decode_policy=self.decode_policy,
|
||||
)
|
||||
|
||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
@@ -187,23 +194,17 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
assert self.offload_engine is not None
|
||||
return self.offload_engine.get_layer_cache(layer_id)
|
||||
|
||||
def set_sparse_policy(self, policy: "SparsePolicy") -> None:
|
||||
def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]:
|
||||
"""
|
||||
Set sparse attention policy for block selection.
|
||||
|
||||
The sparse policy determines which KV blocks to load from CPU
|
||||
for each query chunk during chunked attention computation.
|
||||
Get sparse policy for the specified phase.
|
||||
|
||||
Args:
|
||||
policy: SparsePolicy instance (e.g., VerticalSlashPolicy, QuestPolicy)
|
||||
is_prefill: True for prefill phase, False for decode phase
|
||||
|
||||
Example:
|
||||
from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig
|
||||
policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=2))
|
||||
manager.set_sparse_policy(policy)
|
||||
Returns:
|
||||
SparsePolicy for the phase, or None if not set
|
||||
"""
|
||||
self.sparse_policy = policy
|
||||
logger.info(f"Sparse attention policy set: {policy}")
|
||||
return self.prefill_policy if is_prefill else self.decode_policy
|
||||
|
||||
def can_allocate(self, seq: Sequence) -> bool:
|
||||
"""Check if we can allocate blocks for a new sequence."""
|
||||
|
||||
@@ -17,6 +17,11 @@ from nanovllm.kvcache.kernels import gathered_copy_kv
|
||||
from nanovllm.comm import memcpy_2d_async
|
||||
from nanovllm.utils.logger import get_logger
|
||||
|
||||
# Import for type hints only (avoid circular import)
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from nanovllm.kvcache.sparse import SparsePolicy
|
||||
|
||||
logger = get_logger("offload_engine")
|
||||
|
||||
|
||||
@@ -55,6 +60,8 @@ class OffloadEngine:
|
||||
head_dim: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_streams: int = 4,
|
||||
prefill_policy: "SparsePolicy" = None,
|
||||
decode_policy: "SparsePolicy" = None,
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
@@ -210,6 +217,10 @@ class OffloadEngine:
|
||||
self._debug_mode = False
|
||||
self._debug_hooks: List = [] # External hooks for debug events
|
||||
|
||||
# ========== Sparse attention policies (set at construction time) ==========
|
||||
self.prefill_policy = prefill_policy
|
||||
self.decode_policy = decode_policy
|
||||
|
||||
def _get_next_stream(self) -> torch.cuda.Stream:
|
||||
"""Round-robin stream selection for parallel transfers."""
|
||||
stream = self.transfer_streams[self._stream_idx]
|
||||
@@ -730,7 +741,14 @@ class OffloadEngine:
|
||||
"""Wait for slot offload to complete."""
|
||||
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
|
||||
|
||||
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
||||
def offload_slot_layer_to_cpu(
|
||||
self,
|
||||
slot_idx: int,
|
||||
layer_id: int,
|
||||
cpu_block_id: int,
|
||||
num_valid_tokens: int = -1,
|
||||
is_prefill: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Async offload a ring buffer slot to CPU for one layer.
|
||||
|
||||
@@ -741,9 +759,27 @@ class OffloadEngine:
|
||||
slot_idx: Source GPU slot index
|
||||
layer_id: Target layer in CPU cache
|
||||
cpu_block_id: Target CPU block ID
|
||||
num_valid_tokens: Number of valid tokens in this block (-1 = use block_size)
|
||||
is_prefill: True if in prefill phase, False if in decode phase
|
||||
"""
|
||||
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
|
||||
|
||||
# Collect metadata BEFORE offload (while k_cache is still on GPU)
|
||||
# Both policies' callbacks are called - each decides whether to respond
|
||||
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
|
||||
k_cache = self.k_cache_gpu[slot_idx]
|
||||
|
||||
if is_prefill:
|
||||
if self.prefill_policy is not None:
|
||||
self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
if self.decode_policy is not None:
|
||||
self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
else:
|
||||
if self.prefill_policy is not None:
|
||||
self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
if self.decode_policy is not None:
|
||||
self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
|
||||
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
# Wait for both compute_stream and default stream
|
||||
|
||||
@@ -22,7 +22,6 @@ Usage:
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
|
||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
|
||||
|
||||
|
||||
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
||||
@@ -67,6 +66,5 @@ __all__ = [
|
||||
"QuestPolicy",
|
||||
"QuestConfig",
|
||||
"BlockMetadataManager",
|
||||
"HybridPolicy",
|
||||
"create_sparse_policy",
|
||||
]
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
"""
|
||||
Hybrid sparse attention policy.
|
||||
|
||||
Allows using different policies for prefill vs decode phases.
|
||||
This is useful because optimal sparsity patterns often differ:
|
||||
- Prefill: fixed patterns work well (e.g., VerticalSlash)
|
||||
- Decode: query-aware selection helps (e.g., Quest)
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import torch
|
||||
from .policy import SparsePolicy, PolicyContext
|
||||
|
||||
|
||||
class HybridPolicy(SparsePolicy):
|
||||
"""
|
||||
Hybrid policy that uses different policies for prefill and decode.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
from nanovllm.kvcache.sparse import (
|
||||
HybridPolicy, VerticalSlashPolicy, QuestPolicy,
|
||||
VerticalSlashConfig, QuestConfig, BlockMetadataManager
|
||||
)
|
||||
|
||||
# Prefill: use fast fixed pattern
|
||||
prefill_policy = VerticalSlashPolicy(VerticalSlashConfig(
|
||||
num_sink_blocks=1,
|
||||
local_window_blocks=3,
|
||||
))
|
||||
|
||||
# Decode: use query-aware selection
|
||||
metadata = BlockMetadataManager(num_blocks, num_layers, num_heads, head_dim)
|
||||
decode_policy = QuestPolicy(QuestConfig(topk_blocks=8), metadata)
|
||||
|
||||
# Combine
|
||||
policy = HybridPolicy(prefill_policy, decode_policy)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prefill_policy: SparsePolicy,
|
||||
decode_policy: SparsePolicy,
|
||||
):
|
||||
"""
|
||||
Initialize hybrid policy.
|
||||
|
||||
Args:
|
||||
prefill_policy: Policy to use during prefill phase
|
||||
decode_policy: Policy to use during decode phase
|
||||
"""
|
||||
self.prefill_policy = prefill_policy
|
||||
self.decode_policy = decode_policy
|
||||
|
||||
def select_blocks(
|
||||
self,
|
||||
available_blocks: List[int],
|
||||
ctx: PolicyContext,
|
||||
) -> List[int]:
|
||||
"""Delegate to appropriate policy based on phase."""
|
||||
if ctx.is_prefill:
|
||||
return self.prefill_policy.select_blocks(available_blocks, ctx)
|
||||
else:
|
||||
return self.decode_policy.select_blocks(available_blocks, ctx)
|
||||
|
||||
def on_block_offloaded(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
layer_id: int,
|
||||
k_cache: torch.Tensor,
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""Forward to both policies (both may need metadata updates)."""
|
||||
self.prefill_policy.on_block_offloaded(
|
||||
cpu_block_id, layer_id, k_cache, num_valid_tokens
|
||||
)
|
||||
self.decode_policy.on_block_offloaded(
|
||||
cpu_block_id, layer_id, k_cache, num_valid_tokens
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset both policies."""
|
||||
self.prefill_policy.reset()
|
||||
self.decode_policy.reset()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"HybridPolicy(\n"
|
||||
f" prefill={self.prefill_policy},\n"
|
||||
f" decode={self.decode_policy}\n"
|
||||
f")"
|
||||
)
|
||||
@@ -134,7 +134,7 @@ class SparsePolicy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_block_offloaded(
|
||||
def on_prefill_offload(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
layer_id: int,
|
||||
@@ -142,15 +142,38 @@ class SparsePolicy(ABC):
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Hook called when a block is offloaded from GPU to CPU.
|
||||
Hook called when a block is offloaded during prefill phase.
|
||||
|
||||
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
|
||||
Override this to collect metadata about blocks (e.g., min/max keys
|
||||
for Quest-style selection). Default implementation does nothing.
|
||||
|
||||
Args:
|
||||
cpu_block_id: The CPU block ID that was written
|
||||
cpu_block_id: The CPU block ID that will be written
|
||||
layer_id: Transformer layer index
|
||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
|
||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
|
||||
num_valid_tokens: Number of valid tokens in this block
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_decode_offload(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
layer_id: int,
|
||||
k_cache: torch.Tensor,
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Hook called when a block is offloaded during decode phase.
|
||||
|
||||
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
|
||||
Override this to update metadata about blocks. Default implementation
|
||||
does nothing.
|
||||
|
||||
Args:
|
||||
cpu_block_id: The CPU block ID that will be written
|
||||
layer_id: Transformer layer index
|
||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
|
||||
num_valid_tokens: Number of valid tokens in this block
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -289,14 +289,25 @@ class QuestPolicy(SparsePolicy):
|
||||
|
||||
return result
|
||||
|
||||
def on_block_offloaded(
|
||||
def on_prefill_offload(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
layer_id: int,
|
||||
k_cache: torch.Tensor,
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""Update min/max key metadata when block is offloaded."""
|
||||
"""Update min/max key metadata during prefill offload."""
|
||||
if self.metadata is not None:
|
||||
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||
|
||||
def on_decode_offload(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
layer_id: int,
|
||||
k_cache: torch.Tensor,
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""Update min/max key metadata during decode offload (for new blocks)."""
|
||||
if self.metadata is not None:
|
||||
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user