diff --git a/bench_offload.py b/bench_offload.py
index 17c138e..98e0745 100644
--- a/bench_offload.py
+++ b/bench_offload.py
@@ -36,7 +36,7 @@ def main():
     path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
     llm = LLM(
         path,
-        enforce_eager=True,
+        enforce_eager=False,
         max_model_len=128 * 1024,
         max_num_batched_tokens=128 * 1024,
         enable_cpu_offload=True,
diff --git a/nanovllm/config.py b/nanovllm/config.py
index 9e01048..124dc1e 100644
--- a/nanovllm/config.py
+++ b/nanovllm/config.py
@@ -22,6 +22,7 @@ class Config:
     cpu_memory_gb: float = 16.0  # CPU memory limit for KV cache
     offload_policy: str = "lru"  # "lru", "fifo", or full class path
     num_transfer_streams: int = 4  # Number of CUDA streams for async transfers
+    num_gpu_blocks: int = -1  # User-specified GPU blocks count, -1 = auto (use max available)
 
     # Computed fields for offload (set in __post_init__ or by ModelRunner)
     num_gpu_kvcache_blocks: int = -1
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 30371ea..ef64cfc 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -109,9 +109,15 @@ class ModelRunner:
         head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
 
-        # Calculate GPU block count
-        num_gpu_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
-        assert num_gpu_blocks > 0
+        # Calculate max GPU blocks based on available memory
+        max_gpu_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
+        assert max_gpu_blocks > 0
+
+        # Determine final GPU blocks: user-specified or auto (max available)
+        if config.num_gpu_blocks > 0:
+            num_gpu_blocks = min(config.num_gpu_blocks, max_gpu_blocks)
+        else:
+            num_gpu_blocks = max_gpu_blocks
 
         if config.enable_cpu_offload:
             # Calculate CPU blocks based on cpu_memory_gb
@@ -300,7 +306,11 @@ class ModelRunner:
 
     @torch.inference_mode()
     def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
-        if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
+        context = get_context()
+        # Use eager mode for: prefill, enforce_eager, large batch, or chunked attention
+        # Chunked attention requires dynamic KV loading that can't be captured in CUDA Graph
+        use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
+        if use_eager:
             return self.model.compute_logits(self.model(input_ids, positions))
         else:
             bs = input_ids.size(0)
diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py
index 8f8c2d5..152d64f 100644
--- a/nanovllm/kvcache/offload_engine.py
+++ b/nanovllm/kvcache/offload_engine.py
@@ -13,6 +13,9 @@ from typing import Dict, List, Tuple, Optional
 from dataclasses import dataclass
 
 from nanovllm.kvcache.kernels import gathered_copy_kv
+from nanovllm.utils.logger import get_logger
+
+logger = get_logger("offload_engine")
 
 
 @dataclass
@@ -216,6 +219,8 @@ class OffloadEngine:
         stream = self._get_next_stream()
         event = torch.cuda.Event()
 
+        logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")
+
         with torch.cuda.stream(stream):
             # K cache
             self.k_cache_gpu[layer_id, gpu_block_id].copy_(
@@ -271,6 +276,8 @@ class OffloadEngine:
         stream = self._get_next_stream()
         event = torch.cuda.Event()
 
+        logger.debug(f"D2H offload: layer={layer_id}, GPU[{gpu_block_id}] -> CPU[{cpu_block_id}]")
+
         with torch.cuda.stream(stream):
             # Wait for any compute using this block
             stream.wait_stream(self.compute_stream)
@@ -329,6 +336,9 @@ class OffloadEngine:
         """
         assert len(cpu_block_ids) == len(gpu_slot_ids)
 
+        if cpu_block_ids:
+            logger.debug(f"H2D chunked load: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
+
         stream = self._get_next_stream()
 
         with torch.cuda.stream(stream):
@@ -365,6 +375,9 @@ class OffloadEngine:
         """
        assert len(cpu_block_ids) == len(gpu_slot_ids)
 
+        if cpu_block_ids:
+            logger.debug(f"H2D chunked load async: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
+
         stream = self._get_next_stream()
         event = torch.cuda.Event()
 
@@ -398,6 +411,9 @@ class OffloadEngine:
         """
         assert len(cpu_block_ids) == len(gpu_slot_ids)
 
+        if cpu_block_ids:
+            logger.debug(f"H2D all layers: CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
+
         stream = self._get_next_stream()
 
         with torch.cuda.stream(stream):
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index 1816181..bb00fb7 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -245,5 +245,5 @@ class Attention(nn.Module):
         if o_acc is None:
             raise RuntimeError("Chunked decode attention failed: no KV available")
 
-        # Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
-        return o_acc.squeeze(1)
+        # Output shape: [batch, 1, heads, dim] (same as normal decode)
+        return o_acc
diff --git a/nanovllm/utils/logger.py b/nanovllm/utils/logger.py
index 15700c1..3f13c27 100644
--- a/nanovllm/utils/logger.py
+++ b/nanovllm/utils/logger.py
@@ -83,9 +83,9 @@ def _setup_logger() -> logging.Logger:
     # Check if terminal supports colors
     use_colors = hasattr(sys.stderr, "isatty") and sys.stderr.isatty()
 
-    # Format: [TIME] [LEVEL] [MODULE] message
+    # Format: [TIME] [LEVEL] [FILE:LINE] message
     formatter = ColoredFormatter(
-        fmt="[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
+        fmt="[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
         datefmt="%H:%M:%S",
         use_colors=use_colors,
     )
diff --git a/tests/test_chunked_attention.py b/tests/test_chunked_attention.py
new file mode 100644
index 0000000..e9922c1
--- /dev/null
+++ b/tests/test_chunked_attention.py
@@ -0,0 +1,114 @@
+"""
+Test chunked attention with small num_gpu_blocks to trigger CPU offload.
+
+For 8K tokens with block_size=256:
+- Total blocks needed: 8192 / 256 = 32 blocks
+- With num_gpu_blocks=10, 22 blocks go to CPU -> triggers chunked attention
+"""
+import os
+import sys
+
+# Enable debug logging before importing nanovllm
+os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
+
+from nanovllm import LLM, SamplingParams
+
+
+def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=16):
+    """Test chunked prefill with limited GPU blocks."""
+    path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
+
+    total_blocks = (input_len + 255) // 256
+    print(f"=" * 60)
+    print(f"Chunked Prefill Test")
+    print(f"=" * 60)
+    print(f"  input_len: {input_len} tokens")
+    print(f"  total_blocks: {total_blocks}")
+    print(f"  num_gpu_blocks: {num_gpu_blocks}")
+    print(f"  blocks_on_cpu: {max(0, total_blocks - num_gpu_blocks)}")
+    print()
+
+    llm = LLM(
+        path,
+        enforce_eager=False,
+        max_model_len=16 * 1024,  # 16K is enough for 8K test
+        max_num_batched_tokens=16 * 1024,
+        enable_cpu_offload=True,
+        cpu_memory_gb=4.0,
+        num_gpu_blocks=num_gpu_blocks,
+    )
+
+    print(f"LLM initialized:")
+    print(f"  num_gpu_kvcache_blocks: {llm.model_runner.config.num_gpu_kvcache_blocks}")
+    print(f"  num_cpu_kvcache_blocks: {llm.model_runner.config.num_cpu_kvcache_blocks}")
+    print()
+
+    # Create prompt with approximate token count
+    prompt = "Hello " * (input_len // 2)
+
+    print(f"Running generation...")
+    outputs = llm.generate(
+        [prompt],
+        SamplingParams(temperature=0.6, max_tokens=output_len),
+        use_tqdm=True,
+    )
+
+    print()
+    print(f"Output tokens: {len(outputs[0]['token_ids'])}")
+    print(f"Output text (first 100 chars): {outputs[0]['text'][:100]}")
+    print()
+    return outputs
+
+
+def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=64):
+    """Test chunked decode with limited GPU blocks."""
+    path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
+
+    total_blocks = (input_len + 255) // 256
+    print(f"=" * 60)
+    print(f"Chunked Decode Test")
+    print(f"=" * 60)
+    print(f"  input_len: {input_len} tokens")
+    print(f"  output_len: {output_len} tokens")
+    print(f"  total_blocks: {total_blocks}")
+    print(f"  num_gpu_blocks: {num_gpu_blocks}")
+    print()
+
+    llm = LLM(
+        path,
+        enforce_eager=False,
+        max_model_len=16 * 1024,
+        max_num_batched_tokens=16 * 1024,
+        enable_cpu_offload=True,
+        cpu_memory_gb=4.0,
+        num_gpu_blocks=num_gpu_blocks,
+    )
+
+    print(f"LLM initialized:")
+    print(f"  num_gpu_kvcache_blocks: {llm.model_runner.config.num_gpu_kvcache_blocks}")
+    print(f"  num_cpu_kvcache_blocks: {llm.model_runner.config.num_cpu_kvcache_blocks}")
+    print()
+
+    prompt = "Hello " * (input_len // 2)
+
+    print(f"Running generation...")
+    outputs = llm.generate(
+        [prompt],
+        SamplingParams(temperature=0.6, max_tokens=output_len),
+        use_tqdm=True,
+    )
+
+    print()
+    print(f"Output tokens: {len(outputs[0]['token_ids'])}")
+    print(f"Output text (first 100 chars): {outputs[0]['text'][:100]}")
+    print()
+    return outputs
+
+
+if __name__ == "__main__":
+    # Parse arguments
+    num_gpu_blocks = int(sys.argv[1]) if len(sys.argv) > 1 else 10
+    input_len = int(sys.argv[2]) if len(sys.argv) > 2 else 8192
+    output_len = int(sys.argv[3]) if len(sys.argv) > 3 else 32
+
+    test_chunked_prefill(num_gpu_blocks, input_len, output_len)
\ No newline at end of file