[refactor] Refactor offload code to multi-chunk.

This commit is contained in:
Zijie Tian
2025-12-15 01:13:58 +08:00
parent 5949537faf
commit 1081ab51ea
7 changed files with 36 additions and 233 deletions

View File

@@ -169,17 +169,17 @@ class ModelRunner:
)
if config.enable_cpu_offload:
ping_size = config.num_gpu_kvcache_blocks // 2
tokens_per_ping = ping_size * self.block_size
compute_size = config.num_gpu_kvcache_blocks // 2
tokens_per_chunk = compute_size * self.block_size
logger.info(
f"KV Cache allocated (Ping-Pong mode): "
f"KV Cache allocated (Chunked Offload mode): "
f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), "
f"Total={total_memory_mb:.1f}MB"
)
logger.info(
f"Ping-Pong config: ping_size={ping_size} blocks, "
f"tokens_per_chunk={tokens_per_ping}, "
f"Chunked Offload config: compute_size={compute_size} blocks, "
f"tokens_per_chunk={tokens_per_chunk}, "
f"block_size={self.block_size}"
)
else:
@@ -374,14 +374,14 @@ class ModelRunner:
return self.model.compute_logits(graph_vars["outputs"][:bs])
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
# Check if Ping-Pong mode should be used (all blocks on CPU)
# Check if Chunked Offload mode should be used (all blocks on CPU)
if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'):
use_pingpong = self._should_use_pingpong(seqs, is_prefill)
if use_pingpong:
use_chunked_offload = self._should_use_chunked_offload(seqs, is_prefill)
if use_chunked_offload:
if is_prefill:
return self.run_pingpong_prefill(seqs)
return self.run_chunked_offload_prefill(seqs)
else:
return self.run_pingpong_decode(seqs)
return self.run_chunked_offload_decode(seqs)
# Check if chunked prefill is needed (legacy path)
if is_prefill and hasattr(self, 'kvcache_manager'):
@@ -410,7 +410,7 @@ class ModelRunner:
reset_context()
return token_ids
def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
def _should_use_chunked_offload(self, seqs: list[Sequence], is_prefill: bool) -> bool:
"""
Check if three-region mode should be used.
@@ -553,7 +553,7 @@ class ModelRunner:
slot_mapping=slot_mapping,
context_lens=context_len,
is_chunked_prefill=True, # Use chunked attention path
offload_engine=self.kvcache_manager,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
)
@@ -622,13 +622,13 @@ class ModelRunner:
max_seqlen_k=seqlen,
slot_mapping=slot_mapping,
is_chunked_prefill=True,
offload_engine=self.kvcache_manager, # Pass manager for loading previous KV
kvcache_manager=self.kvcache_manager, # Pass manager for loading previous KV
chunked_seq=seq, # Pass sequence for loading previous KV
)
return input_ids, positions
def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with three-region GPU buffer (CPU is primary storage).
@@ -679,7 +679,7 @@ class ModelRunner:
gpu_slots = offload_engine.compute_slots[:num_blocks]
# Prepare inputs
input_ids, positions = self._prepare_pingpong_chunk(
input_ids, positions = self._prepare_chunked_offload_chunk(
seq, chunk_start, chunk_end, gpu_slots, start_block_idx
)
@@ -718,7 +718,7 @@ class ModelRunner:
return token_ids
def _prepare_pingpong_chunk(
def _prepare_chunked_offload_chunk(
self,
seq: Sequence,
chunk_start: int,
@@ -726,7 +726,7 @@ class ModelRunner:
gpu_slots: list[int],
start_block_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare inputs for a Ping-Pong prefill chunk."""
"""Prepare inputs for a chunked offload prefill chunk."""
# Input tokens for this chunk
input_ids = seq[chunk_start:chunk_end]
positions = list(range(chunk_start, chunk_end))
@@ -768,13 +768,13 @@ class ModelRunner:
max_seqlen_k=seqlen,
slot_mapping=slot_mapping,
is_chunked_prefill=True,
offload_engine=self.kvcache_manager,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
)
return input_ids, positions
def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with three-region GPU buffer.
@@ -809,7 +809,7 @@ class ModelRunner:
slot_mapping=slot_mapping,
context_lens=context_len,
is_chunked_prefill=True, # Use chunked attention path
offload_engine=self.kvcache_manager,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
decode_pos_in_block=pos_in_block,
decode_start_pos_in_block=decode_start_pos,