[refactor] Refactor the current GPU and CPU block allocation strategy.

Zijie Tian
2025-12-10 21:23:31 +08:00
parent 0a247ccb1b
commit 190df5f70d
7 changed files with 906 additions and 162 deletions


@@ -10,8 +10,11 @@ from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import Sampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model
from nanovllm.utils.logger import get_logger
from nanovllm.kvcache import create_kvcache_manager, KVCacheManager
logger = get_logger("model_runner")
class ModelRunner:
@@ -120,9 +123,11 @@ class ModelRunner:
num_gpu_blocks = max_gpu_blocks
if config.enable_cpu_offload:
# Calculate CPU blocks based on cpu_memory_gb
cpu_bytes = int(config.cpu_memory_gb * 1024**3)
num_cpu_blocks = cpu_bytes // block_bytes
# Ping-Pong design: CPU is the primary storage, GPU is the working buffer
# CPU blocks = all blocks needed to support max_model_len (the full KV of one maximum-length sequence)
# GPU blocks = Ping-Pong working buffer (user-specified or auto-sized)
num_cpu_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
config.num_gpu_kvcache_blocks = num_gpu_blocks
config.num_cpu_kvcache_blocks = num_cpu_blocks
# For backward compatibility
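For illustration, a standalone sketch of the block-count arithmetic above, with assumed example values (in practice block_size and max_model_len come from the config):

# Assumed example values, only to illustrate the ceiling division above.
max_model_len = 4096
block_size = 16
num_cpu_blocks = (max_model_len + block_size - 1) // block_size   # 256 blocks
# CPU holds the full KV of one maximum-length sequence; the GPU keeps only a small
# Ping-Pong working buffer, e.g. num_gpu_blocks = 32 -> ping_size = 32 // 2 = 16.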
@@ -143,6 +148,27 @@ class ModelRunner:
dtype=hf_config.torch_dtype,
)
# Log KV cache allocation info
gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
total_memory_mb = gpu_memory_mb + cpu_memory_mb
if config.enable_cpu_offload:
logger.info(
f"KV Cache allocated (Ping-Pong mode): "
f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), "
f"Total={total_memory_mb:.1f}MB, "
f"block_size={self.block_size}, "
f"ping_size={config.num_gpu_kvcache_blocks // 2}"
)
else:
logger.info(
f"KV Cache allocated: "
f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
f"block_size={self.block_size}"
)
# Bind layer caches to attention modules and set layer_id
layer_id = 0
for module in self.model.modules():
@@ -328,7 +354,16 @@ class ModelRunner:
return self.model.compute_logits(graph_vars["outputs"][:bs])
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
# Check if chunked prefill is needed
# Check if Ping-Pong mode should be used (all blocks on CPU)
if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'):
use_pingpong = self._should_use_pingpong(seqs, is_prefill)
if use_pingpong:
if is_prefill:
return self.run_pingpong_prefill(seqs)
else:
return self.run_pingpong_decode(seqs)
# Check if chunked prefill is needed (legacy path)
if is_prefill and hasattr(self, 'kvcache_manager'):
needs_chunked = any(
hasattr(self.kvcache_manager, 'needs_chunked_prefill') and
@@ -338,7 +373,7 @@ class ModelRunner:
if needs_chunked:
return self.run_chunked_prefill(seqs)
# Check if chunked decode is needed
# Check if chunked decode is needed (legacy path)
if not is_prefill and hasattr(self, 'kvcache_manager'):
needs_chunked = any(
hasattr(self.kvcache_manager, 'needs_chunked_decode') and
@@ -355,6 +390,36 @@ class ModelRunner:
reset_context()
return token_ids
def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
"""
Check if Ping-Pong mode should be used.
Use Ping-Pong when:
- CPU offload is enabled
- There are blocks on CPU (either allocated there or offloaded)
- Sequence exceeds GPU capacity
"""
if not hasattr(self.kvcache_manager, 'offload_engine'):
return False
for seq in seqs:
if not seq.block_table:
continue # Skip warmup sequences
# Check if any blocks are on CPU
cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq)
if cpu_blocks:
# Has CPU blocks - use Ping-Pong
return True
# Check if sequence needs more blocks than GPU can hold
ping_size = self.kvcache_manager.offload_engine.ping_size
if seq.num_blocks > ping_size:
# Needs chunked processing
return True
return False
def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill in chunks when sequences exceed GPU capacity.
@@ -543,6 +608,210 @@ class ModelRunner:
return input_ids, positions
def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with Ping-Pong dual buffer (CPU is primary storage).
Flow:
1. All blocks are allocated to CPU (primary storage)
2. Process tokens in chunks using Ping/Pong GPU buffers
3. After each chunk, offload from GPU to CPU
4. Alternate between Ping and Pong buffers
"""
import sys
assert len(seqs) == 1, "Ping-Pong prefill only supports single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
ping_size = offload_engine.ping_size
tokens_per_chunk = ping_size * self.block_size
total_tokens = len(seq)
print(f"[Ping-Pong Prefill] Starting: {total_tokens} tokens, "
f"ping_size={ping_size} blocks, chunk={tokens_per_chunk} tokens",
file=sys.stderr)
current_buffer = "ping"
chunk_num = 0
logits = None
processed_tokens = 0
# Get CPU block table for offload targets
cpu_block_ids, logical_ids = self.kvcache_manager.get_all_cpu_blocks(seq)
while processed_tokens < total_tokens:
chunk_num += 1
chunk_start = processed_tokens
chunk_end = min(processed_tokens + tokens_per_chunk, total_tokens)
chunk_tokens = chunk_end - chunk_start
# Calculate which CPU blocks this chunk covers
start_block_idx = chunk_start // self.block_size
end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
num_blocks = end_block_idx - start_block_idx
print(f"[Ping-Pong Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
f"blocks {start_block_idx}-{end_block_idx-1}, buffer={current_buffer}",
file=sys.stderr)
# Get GPU slots for this chunk (Ping or Pong buffer)
if current_buffer == "ping":
gpu_slots = offload_engine.ping_slots[:num_blocks]
else:
gpu_slots = offload_engine.pong_slots[:num_blocks]
# Prepare inputs
input_ids, positions = self._prepare_pingpong_chunk(
seq, chunk_start, chunk_end, gpu_slots, start_block_idx
)
if input_ids.numel() == 0:
break
# Run model forward
logits = self.run_model(input_ids, positions, is_prefill=True)
reset_context()
# Mark blocks as prefilled
for i in range(start_block_idx, min(end_block_idx, len(seq.block_table))):
logical_id = seq.block_table[i]
self.kvcache_manager.prefilled_blocks.add(logical_id)
# Offload this chunk from GPU to CPU (async)
chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
offload_engine.offload_buffer_to_cpu(current_buffer, chunk_cpu_blocks)
# Switch buffer for next chunk
if current_buffer == "ping":
offload_engine.wait_ping_offload_done()
current_buffer = "pong"
else:
offload_engine.wait_pong_offload_done()
current_buffer = "ping"
processed_tokens = chunk_end
# Wait for all offloads to complete
offload_engine.wait_all_offload_done()
print(f"[Ping-Pong Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
# Sample from last logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
if logits is not None:
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
else:
token_ids = [0] if self.rank == 0 else None
return token_ids
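A minimal standalone sketch of the chunk/buffer schedule the prefill loop above walks through; ping_size, block_size, and the token count are assumed example values, not taken from the config:

def pingpong_chunk_schedule(total_tokens: int, ping_size: int, block_size: int):
    """Yield (chunk_start, chunk_end, buffer) in the order the prefill loop visits them."""
    tokens_per_chunk = ping_size * block_size
    buffer, start = "ping", 0
    while start < total_tokens:
        end = min(start + tokens_per_chunk, total_tokens)
        yield start, end, buffer
        buffer = "pong" if buffer == "ping" else "ping"
        start = end

# Example: 1000 tokens, ping_size=16, block_size=16 -> 256 tokens per chunk:
# (0, 256, 'ping'), (256, 512, 'pong'), (512, 768, 'ping'), (768, 1000, 'pong')
print(list(pingpong_chunk_schedule(1000, 16, 16)))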
def _prepare_pingpong_chunk(
self,
seq: Sequence,
chunk_start: int,
chunk_end: int,
gpu_slots: list[int],
start_block_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare inputs for a Ping-Pong prefill chunk."""
# Input tokens for this chunk
input_ids = seq[chunk_start:chunk_end]
positions = list(range(chunk_start, chunk_end))
# Create slot mapping pointing to GPU slots
slot_mapping = []
num_tokens = chunk_end - chunk_start
token_idx = 0
for i, gpu_slot in enumerate(gpu_slots):
block_idx = start_block_idx + i
block_start = block_idx * self.block_size
block_end = min(block_start + self.block_size, len(seq))
# Token positions of this block that fall within this chunk
overlap_start = max(chunk_start, block_start)
overlap_end = min(chunk_end, block_end)
for pos in range(overlap_start, overlap_end):
pos_in_block = pos % self.block_size
slot = gpu_slot * self.block_size + pos_in_block
slot_mapping.append(slot)
# Convert to tensors
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Set up context for chunked prefill
seqlen = num_tokens
cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=seqlen,
max_seqlen_k=seqlen,
slot_mapping=slot_mapping,
is_chunked_prefill=True,
offload_engine=self.kvcache_manager,
chunked_seq=seq,
)
return input_ids, positions
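For reference, a self-contained sketch of the slot-mapping arithmetic used in _prepare_pingpong_chunk, with an invented block size and GPU slot indices (the real gpu_slots come from the Ping/Pong buffer):

def build_slot_mapping(chunk_start, chunk_end, gpu_slots, start_block_idx, block_size, seq_len):
    """Map each token position in the chunk to a GPU cache slot, as in the loop above."""
    slot_mapping = []
    for i, gpu_slot in enumerate(gpu_slots):
        block_start = (start_block_idx + i) * block_size
        block_end = min(block_start + block_size, seq_len)
        for pos in range(max(chunk_start, block_start), min(chunk_end, block_end)):
            slot_mapping.append(gpu_slot * block_size + pos % block_size)
    return slot_mapping

# Example: block_size=4, chunk covers tokens 4..11 (logical blocks 1 and 2), GPU slots [7, 9]:
print(build_slot_mapping(4, 12, [7, 9], 1, 4, 16))
# -> [28, 29, 30, 31, 36, 37, 38, 39]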
def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with Ping-Pong dual buffer.
All KV is on CPU. Uses Ping-Pong to load KV chunks and compute attention.
New token's KV is written to GPU then offloaded to CPU.
"""
assert len(seqs) == 1, "Ping-Pong decode only supports single sequence"
seq = seqs[0]
# Prepare inputs
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
# Get write slot for new KV (will use last slot of the buffer used for final chunk)
write_slot = self.kvcache_manager.get_write_slot_for_pingpong(seq)
# Calculate position in block for slot mapping
last_block_idx = seq.num_blocks - 1
pos_in_block = (len(seq) - 1) % self.block_size
slot = write_slot * self.block_size + pos_in_block
slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Set up context for chunked decode
set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_len,
is_chunked_prefill=True, # Use chunked attention path
offload_engine=self.kvcache_manager,
chunked_seq=seq,
)
# Run model forward pass
logits = self.run_model(input_ids, positions, is_prefill=False)
reset_context()
# Offload new KV from write_slot to CPU
last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
if last_cpu_block >= 0:
self.kvcache_manager.offload_engine.offload_slot_to_cpu(write_slot, last_cpu_block)
self.kvcache_manager.offload_engine.wait_all_offload_done()
# Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
return token_ids
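As a worked example of the decode-side slot arithmetic above (block size, sequence length, and the write slot are invented values; the real write slot comes from get_write_slot_for_pingpong):

# Suppose block_size = 16 and the sequence has 35 tokens so far (positions 0..34).
block_size = 16
seq_len = 35
pos_in_block = (seq_len - 1) % block_size       # 34 % 16 = 2
write_slot = 5                                  # assumed slot index in the GPU buffer
slot = write_slot * block_size + pos_in_block   # 5 * 16 + 2 = 82
# The new token's KV is written to GPU slot 82 during the forward pass, then offloaded
# to the sequence's last CPU block so that CPU remains the primary KV store.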
@torch.inference_mode()
def capture_cudagraph(self):
config = self.config