[feat] Added num_gpu_blocks limit gpu blocks.

This commit is contained in:
Zijie Tian
2025-12-10 20:17:42 +08:00
parent 01f19ee4a6
commit 0a247ccb1b
7 changed files with 150 additions and 9 deletions

View File

@@ -22,6 +22,7 @@ class Config:
cpu_memory_gb: float = 16.0 # CPU memory limit for KV cache
offload_policy: str = "lru" # "lru", "fifo", or full class path
num_transfer_streams: int = 4 # Number of CUDA streams for async transfers
num_gpu_blocks: int = -1 # User-specified GPU blocks count, -1 = auto (use max available)
# Computed fields for offload (set in __post_init__ or by ModelRunner)
num_gpu_kvcache_blocks: int = -1

View File

@@ -109,9 +109,15 @@ class ModelRunner:
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
# Calculate GPU block count
num_gpu_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
assert num_gpu_blocks > 0
# Calculate max GPU blocks based on available memory
max_gpu_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
assert max_gpu_blocks > 0
# Determine final GPU blocks: user-specified or auto (max available)
if config.num_gpu_blocks > 0:
num_gpu_blocks = min(config.num_gpu_blocks, max_gpu_blocks)
else:
num_gpu_blocks = max_gpu_blocks
if config.enable_cpu_offload:
# Calculate CPU blocks based on cpu_memory_gb
@@ -300,7 +306,11 @@ class ModelRunner:
@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
context = get_context()
# Use eager mode for: prefill, enforce_eager, large batch, or chunked attention
# Chunked attention requires dynamic KV loading that can't be captured in CUDA Graph
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
if use_eager:
return self.model.compute_logits(self.model(input_ids, positions))
else:
bs = input_ids.size(0)

View File

@@ -13,6 +13,9 @@ from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from nanovllm.kvcache.kernels import gathered_copy_kv
from nanovllm.utils.logger import get_logger
logger = get_logger("offload_engine")
@dataclass
@@ -216,6 +219,8 @@ class OffloadEngine:
stream = self._get_next_stream()
event = torch.cuda.Event()
logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")
with torch.cuda.stream(stream):
# K cache
self.k_cache_gpu[layer_id, gpu_block_id].copy_(
@@ -271,6 +276,8 @@ class OffloadEngine:
stream = self._get_next_stream()
event = torch.cuda.Event()
logger.debug(f"D2H offload: layer={layer_id}, GPU[{gpu_block_id}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(stream):
# Wait for any compute using this block
stream.wait_stream(self.compute_stream)
@@ -329,6 +336,9 @@ class OffloadEngine:
"""
assert len(cpu_block_ids) == len(gpu_slot_ids)
if cpu_block_ids:
logger.debug(f"H2D chunked load: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
stream = self._get_next_stream()
with torch.cuda.stream(stream):
@@ -365,6 +375,9 @@ class OffloadEngine:
"""
assert len(cpu_block_ids) == len(gpu_slot_ids)
if cpu_block_ids:
logger.debug(f"H2D chunked load async: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
stream = self._get_next_stream()
event = torch.cuda.Event()
@@ -398,6 +411,9 @@ class OffloadEngine:
"""
assert len(cpu_block_ids) == len(gpu_slot_ids)
if cpu_block_ids:
logger.debug(f"H2D all layers: CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
stream = self._get_next_stream()
with torch.cuda.stream(stream):

View File

@@ -245,5 +245,5 @@ class Attention(nn.Module):
if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available")
# Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
return o_acc.squeeze(1)
# Output shape: [batch, 1, heads, dim] (same as normal decode)
return o_acc

View File

@@ -83,9 +83,9 @@ def _setup_logger() -> logging.Logger:
# Check if terminal supports colors
use_colors = hasattr(sys.stderr, "isatty") and sys.stderr.isatty()
# Format: [TIME] [LEVEL] [MODULE] message
# Format: [TIME] [LEVEL] [FILE:LINE] message
formatter = ColoredFormatter(
fmt="[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
fmt="[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
datefmt="%H:%M:%S",
use_colors=use_colors,
)