[WIP] Before add Quest policy.
This commit is contained in:
@@ -17,6 +17,11 @@ from nanovllm.kvcache.kernels import gathered_copy_kv
|
||||
from nanovllm.comm import memcpy_2d_async
|
||||
from nanovllm.utils.logger import get_logger
|
||||
|
||||
# Import for type hints only (avoid circular import)
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from nanovllm.kvcache.sparse import SparsePolicy
|
||||
|
||||
logger = get_logger("offload_engine")
|
||||
|
||||
|
||||
@@ -55,6 +60,8 @@ class OffloadEngine:
|
||||
head_dim: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_streams: int = 4,
|
||||
prefill_policy: "SparsePolicy" = None,
|
||||
decode_policy: "SparsePolicy" = None,
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
@@ -210,6 +217,10 @@ class OffloadEngine:
|
||||
self._debug_mode = False
|
||||
self._debug_hooks: List = [] # External hooks for debug events
|
||||
|
||||
# ========== Sparse attention policies (set at construction time) ==========
|
||||
self.prefill_policy = prefill_policy
|
||||
self.decode_policy = decode_policy
|
||||
|
||||
def _get_next_stream(self) -> torch.cuda.Stream:
|
||||
"""Round-robin stream selection for parallel transfers."""
|
||||
stream = self.transfer_streams[self._stream_idx]
|
||||
@@ -730,7 +741,14 @@ class OffloadEngine:
|
||||
"""Wait for slot offload to complete."""
|
||||
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
|
||||
|
||||
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
||||
def offload_slot_layer_to_cpu(
|
||||
self,
|
||||
slot_idx: int,
|
||||
layer_id: int,
|
||||
cpu_block_id: int,
|
||||
num_valid_tokens: int = -1,
|
||||
is_prefill: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Async offload a ring buffer slot to CPU for one layer.
|
||||
|
||||
@@ -741,9 +759,27 @@ class OffloadEngine:
|
||||
slot_idx: Source GPU slot index
|
||||
layer_id: Target layer in CPU cache
|
||||
cpu_block_id: Target CPU block ID
|
||||
num_valid_tokens: Number of valid tokens in this block (-1 = use block_size)
|
||||
is_prefill: True if in prefill phase, False if in decode phase
|
||||
"""
|
||||
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
|
||||
|
||||
# Collect metadata BEFORE offload (while k_cache is still on GPU)
|
||||
# Both policies' callbacks are called - each decides whether to respond
|
||||
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
|
||||
k_cache = self.k_cache_gpu[slot_idx]
|
||||
|
||||
if is_prefill:
|
||||
if self.prefill_policy is not None:
|
||||
self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
if self.decode_policy is not None:
|
||||
self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
else:
|
||||
if self.prefill_policy is not None:
|
||||
self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
if self.decode_policy is not None:
|
||||
self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
|
||||
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
# Wait for both compute_stream and default stream
|
||||
|
||||
Reference in New Issue
Block a user