[refactor] Refactor the current GPU and CPU block allocation strategy.
@@ -10,8 +10,11 @@ from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import Sampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model
from nanovllm.utils.logger import get_logger
from nanovllm.kvcache import create_kvcache_manager, KVCacheManager

logger = get_logger("model_runner")


class ModelRunner:
@@ -120,9 +123,11 @@ class ModelRunner:
        num_gpu_blocks = max_gpu_blocks

        if config.enable_cpu_offload:
            # Calculate CPU blocks based on cpu_memory_gb
            cpu_bytes = int(config.cpu_memory_gb * 1024**3)
            num_cpu_blocks = cpu_bytes // block_bytes
            # Ping-Pong design: CPU is the primary storage, GPU is the working buffer
            # CPU blocks = all blocks needed for max_model_len (the full KV of one maximum-length sequence)
            # GPU blocks = the Ping-Pong working buffer (user-specified or automatic)
            num_cpu_blocks = (config.max_model_len + self.block_size - 1) // self.block_size

        config.num_gpu_kvcache_blocks = num_gpu_blocks
        config.num_cpu_kvcache_blocks = num_cpu_blocks
        # For backward compatibility
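
For intuition, a minimal standalone sketch of the block math above; the concrete values (block_size=256, max_model_len=32768, 16 GPU blocks) are assumptions for illustration, not values taken from this commit:

# Illustrative only: derive the CPU/GPU block counts under the Ping-Pong design.
block_size = 256            # tokens per KV block (assumed)
max_model_len = 32768       # longest supported sequence (assumed)
num_gpu_blocks = 16         # Ping-Pong working buffer on the GPU (assumed)

# CPU holds the full KV of one maximum-length sequence.
num_cpu_blocks = (max_model_len + block_size - 1) // block_size   # -> 128 blocks

# Each half of the GPU buffer (Ping or Pong) covers this many tokens per chunk.
ping_size = num_gpu_blocks // 2                                   # -> 8 blocks
tokens_per_chunk = ping_size * block_size                         # -> 2048 tokens
print(num_cpu_blocks, ping_size, tokens_per_chunk)
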
@@ -143,6 +148,27 @@ class ModelRunner:
            dtype=hf_config.torch_dtype,
        )

        # Log KV cache allocation info
        gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
        cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
        total_memory_mb = gpu_memory_mb + cpu_memory_mb

        if config.enable_cpu_offload:
            logger.info(
                f"KV Cache allocated (Ping-Pong mode): "
                f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
                f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), "
                f"Total={total_memory_mb:.1f}MB, "
                f"block_size={self.block_size}, "
                f"ping_size={config.num_gpu_kvcache_blocks // 2}"
            )
        else:
            logger.info(
                f"KV Cache allocated: "
                f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
                f"block_size={self.block_size}"
            )

        # Bind layer caches to attention modules and set layer_id
        layer_id = 0
        for module in self.model.modules():
@@ -328,7 +354,16 @@ class ModelRunner:
        return self.model.compute_logits(graph_vars["outputs"][:bs])

    def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
        # Check if chunked prefill is needed
        # Check if Ping-Pong mode should be used (all blocks on CPU)
        if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'):
            use_pingpong = self._should_use_pingpong(seqs, is_prefill)
            if use_pingpong:
                if is_prefill:
                    return self.run_pingpong_prefill(seqs)
                else:
                    return self.run_pingpong_decode(seqs)

        # Check if chunked prefill is needed (legacy path)
        if is_prefill and hasattr(self, 'kvcache_manager'):
            needs_chunked = any(
                hasattr(self.kvcache_manager, 'needs_chunked_prefill') and
@@ -338,7 +373,7 @@ class ModelRunner:
            if needs_chunked:
                return self.run_chunked_prefill(seqs)

        # Check if chunked decode is needed
        # Check if chunked decode is needed (legacy path)
        if not is_prefill and hasattr(self, 'kvcache_manager'):
            needs_chunked = any(
                hasattr(self.kvcache_manager, 'needs_chunked_decode') and
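
As a reading aid, the dispatch order in run() condenses to the sketch below; the wrapper function is hypothetical and the legacy checks are elided, only the method names mirror the diff:

def dispatch(runner, seqs, is_prefill):
    # Hypothetical condensation of run(): the Ping-Pong path takes priority,
    # then the legacy chunked-prefill / chunked-decode paths, then normal execution.
    if runner._should_use_pingpong(seqs, is_prefill):
        if is_prefill:
            return runner.run_pingpong_prefill(seqs)
        return runner.run_pingpong_decode(seqs)
    # Legacy chunked paths and the normal path follow (their conditions are elided in the diff).
    return None
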
@@ -355,6 +390,36 @@ class ModelRunner:
        reset_context()
        return token_ids

    def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
        """
        Check if Ping-Pong mode should be used.

        Use Ping-Pong when:
        - CPU offload is enabled
        - There are blocks on CPU (either allocated there or offloaded)
        - Sequence exceeds GPU capacity
        """
        if not hasattr(self.kvcache_manager, 'offload_engine'):
            return False

        for seq in seqs:
            if not seq.block_table:
                continue  # Skip warmup sequences

            # Check if any blocks are on CPU
            cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq)
            if cpu_blocks:
                # Has CPU blocks - use Ping-Pong
                return True

            # Check if sequence needs more blocks than GPU can hold
            ping_size = self.kvcache_manager.offload_engine.ping_size
            if seq.num_blocks > ping_size:
                # Needs chunked processing
                return True

        return False
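
A small worked example of the capacity check above; the numbers (a 5000-token sequence, block_size=256, ping_size=8) are assumed for illustration:

# Illustrative only: why a long sequence falls into the Ping-Pong path.
block_size = 256
ping_size = 8                                            # GPU blocks per Ping/Pong half (assumed)
seq_len = 5000                                           # tokens in the sequence (assumed)
num_blocks = (seq_len + block_size - 1) // block_size    # -> 20 blocks
use_pingpong = num_blocks > ping_size                    # 20 > 8 -> True
print(num_blocks, use_pingpong)
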
    def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]:
        """
        Run prefill in chunks when sequences exceed GPU capacity.
@@ -543,6 +608,210 @@ class ModelRunner:

        return input_ids, positions
    def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
        """
        Run prefill with Ping-Pong dual buffer (CPU is primary storage).

        Flow:
        1. All blocks are allocated to CPU (primary storage)
        2. Process tokens in chunks using Ping/Pong GPU buffers
        3. After each chunk, offload from GPU to CPU
        4. Alternate between Ping and Pong buffers
        """
        import sys

        assert len(seqs) == 1, "Ping-Pong prefill only supports single sequence"
        seq = seqs[0]

        offload_engine = self.kvcache_manager.offload_engine
        ping_size = offload_engine.ping_size
        tokens_per_chunk = ping_size * self.block_size

        total_tokens = len(seq)
        print(f"[Ping-Pong Prefill] Starting: {total_tokens} tokens, "
              f"ping_size={ping_size} blocks, chunk={tokens_per_chunk} tokens",
              file=sys.stderr)

        current_buffer = "ping"
        chunk_num = 0
        logits = None
        processed_tokens = 0

        # Get CPU block table for offload targets
        cpu_block_ids, logical_ids = self.kvcache_manager.get_all_cpu_blocks(seq)

        while processed_tokens < total_tokens:
            chunk_num += 1
            chunk_start = processed_tokens
            chunk_end = min(processed_tokens + tokens_per_chunk, total_tokens)
            chunk_tokens = chunk_end - chunk_start

            # Calculate which CPU blocks this chunk covers
            start_block_idx = chunk_start // self.block_size
            end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
            num_blocks = end_block_idx - start_block_idx

            print(f"[Ping-Pong Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
                  f"blocks {start_block_idx}-{end_block_idx-1}, buffer={current_buffer}",
                  file=sys.stderr)

            # Get GPU slots for this chunk (Ping or Pong buffer)
            if current_buffer == "ping":
                gpu_slots = offload_engine.ping_slots[:num_blocks]
            else:
                gpu_slots = offload_engine.pong_slots[:num_blocks]

            # Prepare inputs
            input_ids, positions = self._prepare_pingpong_chunk(
                seq, chunk_start, chunk_end, gpu_slots, start_block_idx
            )

            if input_ids.numel() == 0:
                break

            # Run model forward
            logits = self.run_model(input_ids, positions, is_prefill=True)
            reset_context()

            # Mark blocks as prefilled
            for i in range(start_block_idx, min(end_block_idx, len(seq.block_table))):
                logical_id = seq.block_table[i]
                self.kvcache_manager.prefilled_blocks.add(logical_id)

            # Offload this chunk from GPU to CPU (async)
            chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
            offload_engine.offload_buffer_to_cpu(current_buffer, chunk_cpu_blocks)

            # Switch buffer for next chunk
            if current_buffer == "ping":
                offload_engine.wait_ping_offload_done()
                current_buffer = "pong"
            else:
                offload_engine.wait_pong_offload_done()
                current_buffer = "ping"

            processed_tokens = chunk_end

        # Wait for all offloads to complete
        offload_engine.wait_all_offload_done()

        print(f"[Ping-Pong Prefill] Complete: {chunk_num} chunks", file=sys.stderr)

        # Sample from last logits
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        if logits is not None:
            token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        else:
            token_ids = [0] if self.rank == 0 else None

        return token_ids
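
A worked example of the chunking loop above, showing how an assumed 5000-token prompt is split across alternating Ping/Pong chunks (block_size=256 and ping_size=8 are assumptions):

# Illustrative only: chunk boundaries produced by run_pingpong_prefill.
block_size = 256
ping_size = 8
tokens_per_chunk = ping_size * block_size        # 2048 tokens per chunk
total_tokens = 5000                              # assumed prompt length

processed, buffer, chunks = 0, "ping", []
while processed < total_tokens:
    chunk_start = processed
    chunk_end = min(processed + tokens_per_chunk, total_tokens)
    start_block = chunk_start // block_size
    end_block = (chunk_end + block_size - 1) // block_size
    chunks.append((chunk_start, chunk_end, start_block, end_block, buffer))
    buffer = "pong" if buffer == "ping" else "ping"
    processed = chunk_end

# -> [(0, 2048, 0, 8, 'ping'), (2048, 4096, 8, 16, 'pong'), (4096, 5000, 16, 20, 'ping')]
print(chunks)
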
    def _prepare_pingpong_chunk(
        self,
        seq: Sequence,
        chunk_start: int,
        chunk_end: int,
        gpu_slots: list[int],
        start_block_idx: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare inputs for a Ping-Pong prefill chunk."""
        # Input tokens for this chunk
        input_ids = seq[chunk_start:chunk_end]
        positions = list(range(chunk_start, chunk_end))

        # Create slot mapping pointing to GPU slots
        slot_mapping = []
        num_tokens = chunk_end - chunk_start

        token_idx = 0
        for i, gpu_slot in enumerate(gpu_slots):
            block_idx = start_block_idx + i
            block_start = block_idx * self.block_size
            block_end = min(block_start + self.block_size, len(seq))

            # How many tokens in this block for this chunk
            overlap_start = max(chunk_start, block_start)
            overlap_end = min(chunk_end, block_end)

            for pos in range(overlap_start, overlap_end):
                pos_in_block = pos % self.block_size
                slot = gpu_slot * self.block_size + pos_in_block
                slot_mapping.append(slot)

        # Convert to tensors
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

        # Set up context for chunked prefill
        seqlen = num_tokens
        cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

        set_context(
            is_prefill=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=seqlen,
            max_seqlen_k=seqlen,
            slot_mapping=slot_mapping,
            is_chunked_prefill=True,
            offload_engine=self.kvcache_manager,
            chunked_seq=seq,
        )

        return input_ids, positions
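
A worked example of the slot mapping built above, with a deliberately tiny block_size=4 and hypothetical GPU slots so the numbers are easy to follow:

# Illustrative only: slot mapping as computed by _prepare_pingpong_chunk.
block_size = 4                                   # tiny block size for readability (assumed)
chunk_start, chunk_end = 8, 14                   # this chunk covers tokens 8..13 (assumed)
start_block_idx = chunk_start // block_size      # logical block 2
gpu_slots = [5, 6]                               # GPU blocks of the active Ping/Pong buffer (assumed)

slot_mapping = []
for i, gpu_slot in enumerate(gpu_slots):
    block_start = (start_block_idx + i) * block_size
    block_end = block_start + block_size
    for pos in range(max(chunk_start, block_start), min(chunk_end, block_end)):
        slot_mapping.append(gpu_slot * block_size + pos % block_size)

# Tokens 8-11 land in GPU block 5 (slots 20-23), tokens 12-13 in GPU block 6 (slots 24-25).
print(slot_mapping)   # [20, 21, 22, 23, 24, 25]
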
    def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
        """
        Run decode with Ping-Pong dual buffer.

        All KV is on CPU. Uses Ping-Pong to load KV chunks and compute attention.
        The new token's KV is written to GPU and then offloaded to CPU.
        """
        assert len(seqs) == 1, "Ping-Pong decode only supports single sequence"
        seq = seqs[0]

        # Prepare inputs
        input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)

        # Get write slot for new KV (will use last slot of the buffer used for final chunk)
        write_slot = self.kvcache_manager.get_write_slot_for_pingpong(seq)

        # Calculate position in block for slot mapping
        last_block_idx = seq.num_blocks - 1
        pos_in_block = (len(seq) - 1) % self.block_size
        slot = write_slot * self.block_size + pos_in_block
        slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

        # Set up context for chunked decode
        set_context(
            is_prefill=False,
            slot_mapping=slot_mapping,
            context_lens=context_len,
            is_chunked_prefill=True,  # Use chunked attention path
            offload_engine=self.kvcache_manager,
            chunked_seq=seq,
        )

        # Run model forward pass
        logits = self.run_model(input_ids, positions, is_prefill=False)
        reset_context()

        # Offload new KV from write_slot to CPU
        last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
        if last_cpu_block >= 0:
            self.kvcache_manager.offload_engine.offload_slot_to_cpu(write_slot, last_cpu_block)
            self.kvcache_manager.offload_engine.wait_all_offload_done()

        # Sample
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None

        return token_ids
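
A short worked example of the decode-time slot arithmetic above; the sequence length and write slot are assumed values:

# Illustrative only: where the new token's KV is written during Ping-Pong decode.
block_size = 256
seq_len = 5001                                   # length after appending the new token (assumed)
write_slot = 15                                  # GPU slot chosen by the KV cache manager (assumed)

pos_in_block = (seq_len - 1) % block_size        # 5000 % 256 = 136
slot = write_slot * block_size + pos_in_block    # 15 * 256 + 136 = 3976
print(pos_in_block, slot)
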
    @torch.inference_mode()
    def capture_cudagraph(self):
        config = self.config