[feat] Add num_gpu_blocks option to limit GPU KV cache blocks
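The new num_gpu_blocks option caps how many GPU KV-cache blocks the runner allocates; blocks beyond the cap are served from CPU memory through the offload path. A minimal usage sketch, assembled from the test added in this commit (the model path and numbers are the test's, not requirements):

    import os
    from nanovllm import LLM, SamplingParams

    llm = LLM(
        os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
        enforce_eager=False,
        max_model_len=16 * 1024,
        max_num_batched_tokens=16 * 1024,
        enable_cpu_offload=True,
        cpu_memory_gb=4.0,          # CPU-side KV cache budget
        num_gpu_blocks=10,          # cap GPU KV cache at 10 blocks; -1 (default) = auto
    )
    outputs = llm.generate(["Hello " * 4096], SamplingParams(temperature=0.6, max_tokens=32))
    print(outputs[0]["text"][:100])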
@@ -36,7 +36,7 @@ def main():
     path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
     llm = LLM(
         path,
-        enforce_eager=True,
+        enforce_eager=False,
         max_model_len=128 * 1024,
         max_num_batched_tokens=128 * 1024,
         enable_cpu_offload=True,
@@ -22,6 +22,7 @@ class Config:
     cpu_memory_gb: float = 16.0  # CPU memory limit for KV cache
     offload_policy: str = "lru"  # "lru", "fifo", or full class path
     num_transfer_streams: int = 4  # Number of CUDA streams for async transfers
+    num_gpu_blocks: int = -1  # User-specified GPU blocks count, -1 = auto (use max available)

     # Computed fields for offload (set in __post_init__ or by ModelRunner)
     num_gpu_kvcache_blocks: int = -1
@@ -109,9 +109,15 @@ class ModelRunner:
         head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize

-        # Calculate GPU block count
-        num_gpu_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
-        assert num_gpu_blocks > 0
+        # Calculate max GPU blocks based on available memory
+        max_gpu_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
+        assert max_gpu_blocks > 0
+
+        # Determine final GPU blocks: user-specified or auto (max available)
+        if config.num_gpu_blocks > 0:
+            num_gpu_blocks = min(config.num_gpu_blocks, max_gpu_blocks)
+        else:
+            num_gpu_blocks = max_gpu_blocks

         if config.enable_cpu_offload:
             # Calculate CPU blocks based on cpu_memory_gb
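For a sense of scale, a back-of-envelope sketch of block_bytes. The layer/head/dtype numbers below are illustrative assumptions, not values read from any real model config:

    # Illustrative dims only; substitute your model's actual config values.
    num_hidden_layers = 36
    block_size = 256
    num_kv_heads = 8
    head_dim = 128
    dtype_itemsize = 2  # e.g. bf16

    # Factor of 2 covers the separate K and V caches.
    block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * dtype_itemsize
    print(block_bytes / 2**20)  # 36.0 MiB per block -> num_gpu_blocks=10 pins ~360 MiB on the GPU

Under these assumptions, capping num_gpu_blocks directly bounds the GPU-resident KV cache, while max_gpu_blocks remains the hard ceiling derived from gpu_memory_utilization.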
@@ -300,7 +306,11 @@ class ModelRunner:

     @torch.inference_mode()
     def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
-        if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
+        context = get_context()
+        # Use eager mode for: prefill, enforce_eager, large batch, or chunked attention
+        # Chunked attention requires dynamic KV loading that can't be captured in CUDA Graph
+        use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
+        if use_eager:
             return self.model.compute_logits(self.model(input_ids, positions))
         else:
             bs = input_ids.size(0)
@@ -13,6 +13,9 @@ from typing import Dict, List, Tuple, Optional
 from dataclasses import dataclass

 from nanovllm.kvcache.kernels import gathered_copy_kv
+from nanovllm.utils.logger import get_logger
+
+logger = get_logger("offload_engine")


 @dataclass
@@ -216,6 +219,8 @@ class OffloadEngine:
         stream = self._get_next_stream()
         event = torch.cuda.Event()

+        logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")
+
         with torch.cuda.stream(stream):
             # K cache
             self.k_cache_gpu[layer_id, gpu_block_id].copy_(
@@ -271,6 +276,8 @@ class OffloadEngine:
         stream = self._get_next_stream()
         event = torch.cuda.Event()

+        logger.debug(f"D2H offload: layer={layer_id}, GPU[{gpu_block_id}] -> CPU[{cpu_block_id}]")
+
         with torch.cuda.stream(stream):
             # Wait for any compute using this block
             stream.wait_stream(self.compute_stream)
@@ -329,6 +336,9 @@ class OffloadEngine:
         """
         assert len(cpu_block_ids) == len(gpu_slot_ids)

+        if cpu_block_ids:
+            logger.debug(f"H2D chunked load: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
+
         stream = self._get_next_stream()

         with torch.cuda.stream(stream):
@@ -365,6 +375,9 @@ class OffloadEngine:
         """
         assert len(cpu_block_ids) == len(gpu_slot_ids)

+        if cpu_block_ids:
+            logger.debug(f"H2D chunked load async: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
+
         stream = self._get_next_stream()
         event = torch.cuda.Event()

@@ -398,6 +411,9 @@ class OffloadEngine:
         """
         assert len(cpu_block_ids) == len(gpu_slot_ids)

+        if cpu_block_ids:
+            logger.debug(f"H2D all layers: CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
+
         stream = self._get_next_stream()

         with torch.cuda.stream(stream):
@@ -245,5 +245,5 @@ class Attention(nn.Module):
         if o_acc is None:
             raise RuntimeError("Chunked decode attention failed: no KV available")

-        # Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
-        return o_acc.squeeze(1)
+        # Output shape: [batch, 1, heads, dim] (same as normal decode)
+        return o_acc
@@ -83,9 +83,9 @@ def _setup_logger() -> logging.Logger:
     # Check if terminal supports colors
    use_colors = hasattr(sys.stderr, "isatty") and sys.stderr.isatty()

-    # Format: [TIME] [LEVEL] [MODULE] message
+    # Format: [TIME] [LEVEL] [FILE:LINE] message
     formatter = ColoredFormatter(
-        fmt="[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
+        fmt="[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
         datefmt="%H:%M:%S",
         use_colors=use_colors,
     )
tests/test_chunked_attention.py (new file, 114 lines)
@@ -0,0 +1,114 @@
+"""
+Test chunked attention with small num_gpu_blocks to trigger CPU offload.
+
+For 8K tokens with block_size=256:
+- Total blocks needed: 8192 / 256 = 32 blocks
+- With num_gpu_blocks=10, 22 blocks go to CPU -> triggers chunked attention
+"""
+import os
+import sys
+
+# Enable debug logging before importing nanovllm
+os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
+
+from nanovllm import LLM, SamplingParams
+
+
+def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=16):
+    """Test chunked prefill with limited GPU blocks."""
+    path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
+
+    total_blocks = (input_len + 255) // 256
+    print(f"=" * 60)
+    print(f"Chunked Prefill Test")
+    print(f"=" * 60)
+    print(f" input_len: {input_len} tokens")
+    print(f" total_blocks: {total_blocks}")
+    print(f" num_gpu_blocks: {num_gpu_blocks}")
+    print(f" blocks_on_cpu: {max(0, total_blocks - num_gpu_blocks)}")
+    print()
+
+    llm = LLM(
+        path,
+        enforce_eager=False,
+        max_model_len=16 * 1024,  # 16K is enough for 8K test
+        max_num_batched_tokens=16 * 1024,
+        enable_cpu_offload=True,
+        cpu_memory_gb=4.0,
+        num_gpu_blocks=num_gpu_blocks,
+    )
+
+    print(f"LLM initialized:")
+    print(f" num_gpu_kvcache_blocks: {llm.model_runner.config.num_gpu_kvcache_blocks}")
+    print(f" num_cpu_kvcache_blocks: {llm.model_runner.config.num_cpu_kvcache_blocks}")
+    print()
+
+    # Create prompt with approximate token count
+    prompt = "Hello " * (input_len // 2)
+
+    print(f"Running generation...")
+    outputs = llm.generate(
+        [prompt],
+        SamplingParams(temperature=0.6, max_tokens=output_len),
+        use_tqdm=True,
+    )
+
+    print()
+    print(f"Output tokens: {len(outputs[0]['token_ids'])}")
+    print(f"Output text (first 100 chars): {outputs[0]['text'][:100]}")
+    print()
+    return outputs
+
+
+def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=64):
+    """Test chunked decode with limited GPU blocks."""
+    path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
+
+    total_blocks = (input_len + 255) // 256
+    print(f"=" * 60)
+    print(f"Chunked Decode Test")
+    print(f"=" * 60)
+    print(f" input_len: {input_len} tokens")
+    print(f" output_len: {output_len} tokens")
+    print(f" total_blocks: {total_blocks}")
+    print(f" num_gpu_blocks: {num_gpu_blocks}")
+    print()
+
+    llm = LLM(
+        path,
+        enforce_eager=False,
+        max_model_len=16 * 1024,
+        max_num_batched_tokens=16 * 1024,
+        enable_cpu_offload=True,
+        cpu_memory_gb=4.0,
+        num_gpu_blocks=num_gpu_blocks,
+    )
+
+    print(f"LLM initialized:")
+    print(f" num_gpu_kvcache_blocks: {llm.model_runner.config.num_gpu_kvcache_blocks}")
+    print(f" num_cpu_kvcache_blocks: {llm.model_runner.config.num_cpu_kvcache_blocks}")
+    print()
+
+    prompt = "Hello " * (input_len // 2)
+
+    print(f"Running generation...")
+    outputs = llm.generate(
+        [prompt],
+        SamplingParams(temperature=0.6, max_tokens=output_len),
+        use_tqdm=True,
+    )
+
+    print()
+    print(f"Output tokens: {len(outputs[0]['token_ids'])}")
+    print(f"Output text (first 100 chars): {outputs[0]['text'][:100]}")
+    print()
+    return outputs
+
+
+if __name__ == "__main__":
+    # Parse arguments
+    num_gpu_blocks = int(sys.argv[1]) if len(sys.argv) > 1 else 10
+    input_len = int(sys.argv[2]) if len(sys.argv) > 2 else 8192
+    output_len = int(sys.argv[3]) if len(sys.argv) > 3 else 32
+
+    test_chunked_prefill(num_gpu_blocks, input_len, output_len)
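The test script can be run directly; its three optional positional arguments map to num_gpu_blocks, input_len and output_len (defaults 10, 8192, 32), for example:

    python tests/test_chunked_attention.py 10 8192 32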