[refactor] Refactor offload code to multi-chunk.

This commit is contained in:
Zijie Tian
2025-12-15 01:13:58 +08:00
parent 5949537faf
commit 1081ab51ea
7 changed files with 36 additions and 233 deletions

View File

@@ -17,6 +17,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
**ModelRunner** (`nanovllm/engine/model_runner.py`):
- Loads model weights, allocates KV cache, captures CUDA graphs
- Rank 0 is main process; ranks 1+ run via `loop()` with shared memory events
- Chunked offload methods: `run_chunked_offload_prefill()`, `run_chunked_offload_decode()`
**Scheduler** (`nanovllm/engine/scheduler.py`):
- Two-phase scheduling: prefill (waiting queue) then decode (running queue)
@@ -34,7 +35,8 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
**Global Context** (`nanovllm/utils/context.py`):
- Stores attention metadata via `get_context()`/`set_context()`
- Key fields: `cu_seqlens`, `slot_mapping`, `block_tables`, `chunked_seq`
- Key fields: `cu_seqlens`, `slot_mapping`, `block_tables`, `chunked_seq`, `kvcache_manager`
- `kvcache_manager`: Reference to HybridKVCacheManager for chunked attention (set when `is_chunked_prefill=True`)
## CPU Offload System
@@ -118,9 +120,10 @@ Compute: [C0] [C1] [C2]
Manages both GPU and CPU blocks:
- `allocate()`: Allocate GPU block first, fallback to CPU
- `allocate_cpu_only()`: Force CPU allocation (for chunked prefill)
- `allocate_cpu_only()`: Force CPU allocation (for chunked offload mode)
- `get_all_cpu_blocks(seq)`: Get all CPU block IDs for a sequence
- `get_prefilled_cpu_blocks(seq)`: Get CPU blocks from previous chunks
- `get_write_slot_for_chunked_offload(seq)`: Get GPU slot for writing new KV (returns decode_slot)
- `may_offload()`: Offload GPU blocks to CPU when decode slot fills
### Online Softmax Merge

View File

@@ -169,17 +169,17 @@ class ModelRunner:
)
if config.enable_cpu_offload:
ping_size = config.num_gpu_kvcache_blocks // 2
tokens_per_ping = ping_size * self.block_size
compute_size = config.num_gpu_kvcache_blocks // 2
tokens_per_chunk = compute_size * self.block_size
logger.info(
f"KV Cache allocated (Ping-Pong mode): "
f"KV Cache allocated (Chunked Offload mode): "
f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), "
f"Total={total_memory_mb:.1f}MB"
)
logger.info(
f"Ping-Pong config: ping_size={ping_size} blocks, "
f"tokens_per_chunk={tokens_per_ping}, "
f"Chunked Offload config: compute_size={compute_size} blocks, "
f"tokens_per_chunk={tokens_per_chunk}, "
f"block_size={self.block_size}"
)
else:
@@ -374,14 +374,14 @@ class ModelRunner:
return self.model.compute_logits(graph_vars["outputs"][:bs])
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
# Check if Ping-Pong mode should be used (all blocks on CPU)
# Check if Chunked Offload mode should be used (all blocks on CPU)
if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'):
use_pingpong = self._should_use_pingpong(seqs, is_prefill)
if use_pingpong:
use_chunked_offload = self._should_use_chunked_offload(seqs, is_prefill)
if use_chunked_offload:
if is_prefill:
return self.run_pingpong_prefill(seqs)
return self.run_chunked_offload_prefill(seqs)
else:
return self.run_pingpong_decode(seqs)
return self.run_chunked_offload_decode(seqs)
# Check if chunked prefill is needed (legacy path)
if is_prefill and hasattr(self, 'kvcache_manager'):
@@ -410,7 +410,7 @@ class ModelRunner:
reset_context()
return token_ids
def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
def _should_use_chunked_offload(self, seqs: list[Sequence], is_prefill: bool) -> bool:
"""
Check if three-region mode should be used.
@@ -553,7 +553,7 @@ class ModelRunner:
slot_mapping=slot_mapping,
context_lens=context_len,
is_chunked_prefill=True, # Use chunked attention path
offload_engine=self.kvcache_manager,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
)
@@ -622,13 +622,13 @@ class ModelRunner:
max_seqlen_k=seqlen,
slot_mapping=slot_mapping,
is_chunked_prefill=True,
offload_engine=self.kvcache_manager, # Pass manager for loading previous KV
kvcache_manager=self.kvcache_manager, # Pass manager for loading previous KV
chunked_seq=seq, # Pass sequence for loading previous KV
)
return input_ids, positions
def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with three-region GPU buffer (CPU is primary storage).
@@ -679,7 +679,7 @@ class ModelRunner:
gpu_slots = offload_engine.compute_slots[:num_blocks]
# Prepare inputs
input_ids, positions = self._prepare_pingpong_chunk(
input_ids, positions = self._prepare_chunked_offload_chunk(
seq, chunk_start, chunk_end, gpu_slots, start_block_idx
)
@@ -718,7 +718,7 @@ class ModelRunner:
return token_ids
def _prepare_pingpong_chunk(
def _prepare_chunked_offload_chunk(
self,
seq: Sequence,
chunk_start: int,
@@ -726,7 +726,7 @@ class ModelRunner:
gpu_slots: list[int],
start_block_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare inputs for a Ping-Pong prefill chunk."""
"""Prepare inputs for a chunked offload prefill chunk."""
# Input tokens for this chunk
input_ids = seq[chunk_start:chunk_end]
positions = list(range(chunk_start, chunk_end))
@@ -768,13 +768,13 @@ class ModelRunner:
max_seqlen_k=seqlen,
slot_mapping=slot_mapping,
is_chunked_prefill=True,
offload_engine=self.kvcache_manager,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
)
return input_ids, positions
def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with three-region GPU buffer.
@@ -809,7 +809,7 @@ class ModelRunner:
slot_mapping=slot_mapping,
context_lens=context_len,
is_chunked_prefill=True, # Use chunked attention path
offload_engine=self.kvcache_manager,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
decode_pos_in_block=pos_in_block,
decode_start_pos_in_block=decode_start_pos,

View File

@@ -336,7 +336,7 @@ class HybridKVCacheManager(KVCacheManager):
"""
Allocate logical blocks for prefill.
In cpu_primary mode (Ping-Pong): All blocks are allocated to CPU.
In cpu_primary mode (Chunked Offload): All blocks are allocated to CPU.
In legacy mode: Blocks are allocated to GPU when possible, overflow to CPU.
"""
assert not seq.block_table, "Sequence already has blocks"
@@ -1167,9 +1167,9 @@ class HybridKVCacheManager(KVCacheManager):
return block.cpu_block_id
return -1
def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
def get_write_slot_for_chunked_offload(self, seq: Sequence) -> int:
"""
Get GPU slot for writing new KV during three-region decode.
Get GPU slot for writing new KV during chunked offload decode.
In three-region design, always use Decode region (slot 0) to write new KV.
This avoids conflicts with Compute/Prefetch region loading operations.

View File

@@ -91,12 +91,6 @@ class OffloadEngine:
self.num_gpu_slots = num_gpu_blocks # alias
# Keep old ping/pong attributes for compatibility (will be removed later)
self.ping_size = self.num_compute_blocks
self.pong_size = self.num_prefetch_blocks
self.ping_slots = self.compute_slots.copy()
self.pong_slots = self.prefetch_slots.copy()
logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, "
f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")
@@ -148,13 +142,6 @@ class OffloadEngine:
self.prefetch_ready = torch.cuda.Event()
self.decode_offload_done = torch.cuda.Event()
# Keep old ping/pong events for compatibility (will be removed later)
self.pingpong_stream = self.transfer_stream_main
self.ping_ready = self.compute_ready
self.pong_ready = self.prefetch_ready
self.ping_offload_done = torch.cuda.Event()
self.pong_offload_done = torch.cuda.Event()
# ========== Per-layer events for chunked attention ==========
# Each layer has its own event for synchronization
self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
@@ -579,185 +566,9 @@ class OffloadEngine:
f")"
)
# ========== Ping-Pong double buffering methods ==========
def load_to_ping(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Ping buffer.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.ping_ready.record(self.pingpong_stream)
return
num_to_load = min(len(cpu_block_ids), self.ping_size)
logger.debug(f"Ping load: CPU{cpu_block_ids[:num_to_load]} -> GPU ping slots {self.ping_slots[:num_to_load]}")
with torch.cuda.stream(self.pingpong_stream):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.ping_slots[i]
# Copy all layers together
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.ping_ready.record(self.pingpong_stream)
def load_to_pong(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Pong buffer.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.pong_ready.record(self.pingpong_stream)
return
num_to_load = min(len(cpu_block_ids), self.pong_size)
logger.debug(f"Pong load: CPU{cpu_block_ids[:num_to_load]} -> GPU pong slots {self.pong_slots[:num_to_load]}")
with torch.cuda.stream(self.pingpong_stream):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.pong_slots[i]
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.pong_ready.record(self.pingpong_stream)
def wait_ping(self) -> None:
"""Wait for Ping buffer loading to complete."""
self.compute_stream.wait_event(self.ping_ready)
def wait_pong(self) -> None:
"""Wait for Pong buffer loading to complete."""
self.compute_stream.wait_event(self.pong_ready)
def offload_buffer_to_cpu(
self,
buffer: str,
cpu_block_ids: List[int],
) -> None:
"""
Async offload KV from buffer to CPU.
Args:
buffer: "ping" or "pong"
cpu_block_ids: Target CPU block IDs list
"""
slots = self.ping_slots if buffer == "ping" else self.pong_slots
event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done
if not cpu_block_ids:
event.record(self.pingpong_stream)
return
num_to_offload = min(len(cpu_block_ids), len(slots))
logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
with torch.cuda.stream(self.pingpong_stream):
# Wait for compute to complete
self.pingpong_stream.wait_stream(self.compute_stream)
for i in range(num_to_offload):
gpu_slot = slots[i]
cpu_id = cpu_block_ids[i]
self.k_cache_cpu[:, cpu_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)
event.record(self.pingpong_stream)
def offload_slot_to_cpu(
self,
gpu_slot: int,
cpu_block_id: int,
) -> None:
"""
Async offload a single GPU slot's KV to CPU.
Args:
gpu_slot: GPU slot ID
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.pingpong_stream):
self.pingpong_stream.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_block_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)
def wait_ping_offload_done(self) -> None:
"""Wait for Ping buffer offload to complete."""
self.compute_stream.wait_event(self.ping_offload_done)
def wait_pong_offload_done(self) -> None:
"""Wait for Pong buffer offload to complete."""
self.compute_stream.wait_event(self.pong_offload_done)
def wait_all_offload_done(self) -> None:
"""Wait for all offload operations to complete."""
self.pingpong_stream.synchronize()
def get_kv_for_ping_slots(
self,
layer_id: int,
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of slots in Ping buffer.
Args:
layer_id: Layer ID
num_slots: Number of slots needed
Returns:
(k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.ping_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots] # [num_slots, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slots]
# Reshape: [num_slots, block_size, heads, dim] -> [1, num_slots*block_size, heads, dim]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_pong_slots(
self,
layer_id: int,
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of slots in Pong buffer.
Args:
layer_id: Layer ID
num_slots: Number of slots needed
Returns:
(k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.pong_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots]
v = self.v_cache_gpu[layer_id, slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
self.transfer_stream_main.synchronize()
def get_kv_for_slots(
self,
@@ -918,8 +729,6 @@ class OffloadEngine:
def swap_compute_prefetch(self) -> None:
"""Swap roles of Compute region and Prefetch region."""
self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
# Also update old ping/pong slots for compatibility
self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots
def offload_decode_slot(self, cpu_block_id: int) -> None:
"""

View File

@@ -123,8 +123,7 @@ class Attention(nn.Module):
lse_acc = None
# Load previous KV from CPU using Compute/Prefetch region
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
@@ -215,7 +214,7 @@ class Attention(nn.Module):
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
kvcache_manager = context.offload_engine
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq
# Get all CPU blocks for this sequence

View File

@@ -21,7 +21,7 @@ class Context:
# Current chunk's position offset (for causal mask)
chunk_offset: int = 0
# Reference to kvcache manager for loading previous KV (HybridKVCacheManager)
offload_engine: Any = None
kvcache_manager: Any = None
# Current layer's previous K/V chunks (loaded from CPU)
# Set by model_runner before each layer's forward
prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
@@ -33,14 +33,6 @@ class Context:
# Used when batching decode offloads - we need to attend to all accumulated tokens
decode_start_pos_in_block: int = 0
# ========== Per-layer chunked attention state ==========
# Whether chunked decode/prefill is currently active (for hooks to check)
chunked_decode_active: bool = False
# CPU block IDs for the current chunk being processed
chunked_decode_chunk_ids: List[int] = field(default_factory=list)
# Current chunk index being processed
chunked_decode_current_chunk: int = 0
_CONTEXT = Context()
@@ -61,7 +53,7 @@ def set_context(
is_chunked_prefill=False,
prev_kv_ranges=None,
chunk_offset=0,
offload_engine=None,
kvcache_manager=None,
chunked_seq=None,
decode_pos_in_block=0,
decode_start_pos_in_block=0,
@@ -79,7 +71,7 @@ def set_context(
is_chunked_prefill=is_chunked_prefill,
prev_kv_ranges=prev_kv_ranges or [],
chunk_offset=chunk_offset,
offload_engine=offload_engine,
kvcache_manager=kvcache_manager,
chunked_seq=chunked_seq,
decode_pos_in_block=decode_pos_in_block,
decode_start_pos_in_block=decode_start_pos_in_block,

View File

@@ -71,7 +71,7 @@ def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64, num_p
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
print(f"=" * 60)
print(f"Chunked Prefill Test (Ping-Pong)")
print(f"Chunked Prefill Test (Chunked Offload)")
print(f"=" * 60)
print(f" target_input_len: ~{input_len} tokens")
print(f" num_gpu_blocks: {num_gpu_blocks}")
@@ -111,7 +111,7 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128, num_p
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
print(f"=" * 60)
print(f"Chunked Decode Test (Ping-Pong)")
print(f"Chunked Decode Test (Chunked Offload)")
print(f"=" * 60)
print(f" target_input_len: ~{input_len} tokens")
print(f" output_len: {output_len} tokens")