diff --git a/.claude/rules/code-analysis.md b/.claude/rules/code-analysis.md new file mode 100644 index 0000000..53e038c --- /dev/null +++ b/.claude/rules/code-analysis.md @@ -0,0 +1,38 @@ +# Code Analysis + +## Use cclsp MCP for Code Navigation + +When analyzing code, understanding call chains, or exploring the codebase, **prefer using the cclsp MCP tools** over grep/glob-based searches: + +### Available cclsp Tools + +| Tool | Purpose | +|------|---------| +| `mcp__cclsp__find_definition` | Jump to symbol definition | +| `mcp__cclsp__find_references` | Find all usages of a symbol | +| `mcp__cclsp__rename_symbol` | Rename a symbol across the codebase | +| `mcp__cclsp__get_diagnostics` | Get LSP diagnostics (errors, warnings) | +| `mcp__cclsp__restart_server` | Restart the LSP server if needed | + +### When to Use cclsp + +1. **Understanding call chains**: Use `find_references` to trace how functions are called +2. **Finding implementations**: Use `find_definition` to jump to actual code +3. **Refactoring**: Use `rename_symbol` for safe cross-file renames +4. **Code quality**: Use `get_diagnostics` to check for issues + +### Example Workflow + +``` +1. User asks: "How does the prefill flow work?" +2. Use find_definition to locate key entry points (e.g., run_chunked_offload_prefill) +3. Use find_references to trace the call chain through the codebase +4. Read relevant code sections to understand the implementation +``` + +### Benefits over grep/glob + +- **Semantic understanding**: cclsp understands code structure, not just text patterns +- **Accurate references**: Finds actual usages, not just text matches +- **Cross-file navigation**: Follows imports and definitions across modules +- **Type-aware**: Understands Python types and class hierarchies diff --git a/.claude/rules/testing.md b/.claude/rules/testing.md index 85c2999..aa32abc 100644 --- a/.claude/rules/testing.md +++ b/.claude/rules/testing.md @@ -1,20 +1,98 @@ # Testing -## Chunked Attention Test +## Test File Guidelines -```bash -CUDA_VISIBLE_DEVICES=4,5 python tests/test_chunked_attention.py 6 2048 64 2 -# Args: num_gpu_blocks input_len output_len num_prefetch_blocks +### Naming Convention + +- All test files must be named `test_*.py` +- Example: `test_offload_engine.py`, `test_ring_buffer.py` + +### Purpose + +Tests are **educational scripts** for understanding module behavior, NOT traditional unit tests: +- Focus on demonstrating how modules work +- Show the flow and interaction between components +- Help developers understand implementation details + +### Code Style + +1. **Script-based structure**: Write tests as executable scripts, not pytest-style functions +2. **Utility functions**: Extract reusable steps as helper functions at the top of the file +3. **Main flow as script**: The actual test/demonstration logic runs as top-level script code + +```python +# Example structure: + +import torch +from nanovllm.kvcache import SomeModule + +# ============================================================ +# Utility Functions +# ============================================================ + +def verify(tensor, expected, name): + actual = tensor.mean().item() + assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}" + +# ============================================================ +# Main Test Script +# ============================================================ + +# 1. Initialize +module = SomeModule(param=value) + +# 2. Test feature X +result = module.do_something() +assert result == expected_value + +# 3. 
Test feature Y +... + +print("test_xxx: PASSED") ``` -## CPU Offload Testing +### Comments + +- Keep comments concise and clear +- Only add comments where the code isn't self-explanatory +- Use section headers (`# === Section ===`) to organize logical blocks + +### Output + +- **Minimize print statements** - the code should be self-explanatory +- Only print a final "PASSED" message at the end +- Use `assert` for verification instead of printing results +- If the user needs explanation, they will ask + +## Running Tests ```bash -# Basic test with limited GPU blocks to trigger offload -CUDA_VISIBLE_DEVICES=4,5 python tests/test_chunked_attention.py 6 2048 64 2 +# Run a specific test +python tests/test_offload_engine.py -# Verify consistency (run multiple times, output should be identical) -for i in 1 2 3; do - CUDA_VISIBLE_DEVICES=4,5 python tests/test_chunked_attention.py 6 2048 32 2 2>&1 | tail -3 -done +# Run with specific GPU +CUDA_VISIBLE_DEVICES=0 python tests/test_ring_buffer.py +``` + +## Benchmarks + +```bash +# Standard GPU benchmark +python bench.py + +# CPU offload benchmark +python bench_offload.py + +# vLLM comparison benchmark +python bench_vllm.py +``` + +## Quick Verification + +```bash +# Import test +python -c "from nanovllm import LLM" + +# Run offload benchmark (tests CPU-primary ring buffer mode) +python bench_offload.py ``` diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index a7f8ba7..c0946e0 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -125,6 +125,11 @@ class OffloadEngine: dtype=torch.int64, device="cuda" ) + # Log memory allocation + gpu_mem_mb = self.gpu_memory_bytes() / (1024 * 1024) + cpu_mem_mb = self.cpu_memory_bytes() / (1024 * 1024) + logger.info(f" GPU memory: {gpu_mem_mb:.1f} MB, CPU memory: {cpu_mem_mb:.1f} MB") + # ========== Transfer streams for async operations ========== self.transfer_streams = [torch.cuda.Stream() for _ in range(num_streams)] self.compute_stream = torch.cuda.current_stream() diff --git a/tests/__init__.py b/tests/__init__.py index aa26d43..11754ee 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -"""Test suite for nano-vllm KV cache offload.""" +# Tests module diff --git a/tests/test_chunked_attention.py b/tests/test_chunked_attention.py deleted file mode 100644 index 52ca8e6..0000000 --- a/tests/test_chunked_attention.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Test chunked attention with small num_gpu_blocks to trigger CPU offload. - -For 8K tokens with block_size=256: -- Total blocks needed: 8192 / 256 = 32 blocks -- With num_gpu_blocks=10, 22 blocks go to CPU -> triggers chunked attention -""" -import os -import sys - -# Enable debug logging before importing nanovllm -os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG" - -from nanovllm import LLM, SamplingParams - - -def create_long_context_prompt(target_tokens: int) -> str: - """ - Create a meaningful long context prompt with a question at the end. - The answer depends on information scattered throughout the context. - """ - # Key facts to embed in the context - facts = [ - "The capital of France is Paris.", - "The Eiffel Tower was built in 1889.", - "Python was created by Guido van Rossum.", - "The speed of light is approximately 299,792 kilometers per second.", - "Mount Everest is 8,848 meters tall.", - ] - - # Padding text to reach target length - padding_paragraph = """ -This is additional context information that helps extend the length of the prompt. 
-Machine learning has revolutionized many fields including computer vision, natural language processing, and robotics. -Deep neural networks can learn complex patterns from large amounts of data. -The transformer architecture has become the foundation of modern language models. -Attention mechanisms allow models to focus on relevant parts of the input. -""" - - # Build the prompt - prompt_parts = [] - - # Add instruction - prompt_parts.append("Please read the following information carefully and answer the question at the end.\n\n") - - # Add facts at different positions - current_tokens = 50 # approximate tokens so far - tokens_per_padding = 80 # approximate tokens per padding paragraph - fact_interval = target_tokens // (len(facts) + 1) - - fact_idx = 0 - while current_tokens < target_tokens - 100: - # Add padding - prompt_parts.append(padding_paragraph) - current_tokens += tokens_per_padding - - # Add a fact at intervals - if fact_idx < len(facts) and current_tokens > fact_interval * (fact_idx + 1): - prompt_parts.append(f"\n[Important Fact #{fact_idx + 1}]: {facts[fact_idx]}\n") - current_tokens += 20 - fact_idx += 1 - - # Add the question at the end - prompt_parts.append("\n\nQuestion: Based on the information above, what is the speed of light?\n\nAnswer:") - - return "".join(prompt_parts) - - -def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64, num_prefetch_blocks=2): - """Test chunked prefill with limited GPU blocks.""" - path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") - - print(f"=" * 60) - print(f"Chunked Prefill Test (Chunked Offload)") - print(f"=" * 60) - print(f" target_input_len: ~{input_len} tokens") - print(f" num_gpu_blocks: {num_gpu_blocks}") - print(f" num_prefetch_blocks: {num_prefetch_blocks}") - print() - - llm = LLM( - path, - enforce_eager=False, - max_model_len=128 * 1024, - max_num_batched_tokens=128 * 1024, - enable_cpu_offload=True, - num_gpu_blocks=num_gpu_blocks, - num_prefetch_blocks=num_prefetch_blocks, - ) - print() - - # Create meaningful prompt - prompt = create_long_context_prompt(input_len) - - print(f"Running generation...") - outputs = llm.generate( - [prompt], - SamplingParams(temperature=0.1, max_tokens=output_len), # low temperature for more deterministic output - use_tqdm=False, - ) - - print() - print(f"Output tokens: {len(outputs[0]['token_ids'])}") - print(f"Output text:\n{outputs[0]['text']}") - print() - return outputs - - -def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128, num_prefetch_blocks=2): - """Test chunked decode with limited GPU blocks.""" - path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") - - print(f"=" * 60) - print(f"Chunked Decode Test (Chunked Offload)") - print(f"=" * 60) - print(f" target_input_len: ~{input_len} tokens") - print(f" output_len: {output_len} tokens") - print(f" num_gpu_blocks: {num_gpu_blocks}") - print(f" num_prefetch_blocks: {num_prefetch_blocks}") - print() - - llm = LLM( - path, - enforce_eager=False, - max_model_len=128 * 1024, - max_num_batched_tokens=128 * 1024, - enable_cpu_offload=True, - num_gpu_blocks=num_gpu_blocks, - num_prefetch_blocks=num_prefetch_blocks, - ) - print() - - # Create meaningful prompt - prompt = create_long_context_prompt(input_len) - - print(f"Running generation...") - outputs = llm.generate( - [prompt], - SamplingParams(temperature=0.1, max_tokens=output_len), - use_tqdm=False, - ) - - print() - print(f"Output tokens: {len(outputs[0]['token_ids'])}") - print(f"Output text:\n{outputs[0]['text']}") - print() - 
return outputs - - -if __name__ == "__main__": - # Parse arguments: num_gpu_blocks input_len output_len [num_prefetch_blocks] - num_gpu_blocks = int(sys.argv[1]) if len(sys.argv) > 1 else 10 - input_len = int(sys.argv[2]) if len(sys.argv) > 2 else 2048 - output_len = int(sys.argv[3]) if len(sys.argv) > 3 else 64 - num_prefetch_blocks = int(sys.argv[4]) if len(sys.argv) > 4 else 2 - - test_chunked_prefill(num_gpu_blocks, input_len, output_len, num_prefetch_blocks) \ No newline at end of file diff --git a/tests/test_kernels.py b/tests/test_kernels.py deleted file mode 100644 index af2dc59..0000000 --- a/tests/test_kernels.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Tests for Triton gathered copy kernels.""" - -import pytest -import torch - -from nanovllm.kvcache.kernels import gathered_copy, gathered_copy_kv - - -class TestGatheredCopy: - """Tests for gathered copy kernel.""" - - @pytest.fixture - def setup_tensors(self): - """Create test tensors.""" - torch.cuda.manual_seed(42) - num_src_blocks = 16 - num_dst_blocks = 8 - block_size = 256 - kv_dim = 64 - - src = torch.randn(num_src_blocks, block_size, kv_dim, - dtype=torch.float16, device="cuda") - dst = torch.zeros(num_dst_blocks, block_size, kv_dim, - dtype=torch.float16, device="cuda") - - # Indices: dst[i] = src[indices[i]] - indices = torch.randint(0, num_src_blocks, (num_dst_blocks,), - dtype=torch.int64, device="cuda") - - return src, dst, indices - - def test_basic_copy(self, setup_tensors): - """Test basic gathered copy.""" - src, dst, indices = setup_tensors - - gathered_copy(src, dst, indices) - - # Verify copy - for i in range(len(indices)): - src_idx = indices[i].item() - assert torch.allclose(dst[i], src[src_idx]), f"Mismatch at index {i}" - - def test_skip_negative_indices(self, setup_tensors): - """Test that negative indices are skipped.""" - src, dst, indices = setup_tensors - - # Set some indices to -1 - indices[2] = -1 - indices[5] = -1 - - # Fill dst with a known value - dst.fill_(999.0) - - gathered_copy(src, dst, indices) - - # Skipped slots should be unchanged - assert (dst[2] == 999.0).all() - assert (dst[5] == 999.0).all() - - # Non-skipped slots should be copied - for i in [0, 1, 3, 4, 6, 7]: - src_idx = indices[i].item() - assert torch.allclose(dst[i], src[src_idx]) - - def test_single_block(self): - """Test copying a single block.""" - src = torch.randn(4, 256, 64, dtype=torch.float16, device="cuda") - dst = torch.zeros(1, 256, 64, dtype=torch.float16, device="cuda") - indices = torch.tensor([2], dtype=torch.int64, device="cuda") - - gathered_copy(src, dst, indices) - - assert torch.allclose(dst[0], src[2]) - - -class TestGatheredCopyKV: - """Tests for gathered K/V cache copy kernel.""" - - @pytest.fixture - def setup_kv_tensors(self): - """Create K/V test tensors.""" - torch.cuda.manual_seed(42) - num_src_blocks = 16 - num_dst_blocks = 8 - block_size = 256 - num_kv_heads = 4 - head_dim = 64 - - k_src = torch.randn(num_src_blocks, block_size, num_kv_heads, head_dim, - dtype=torch.float16, device="cuda") - v_src = torch.randn(num_src_blocks, block_size, num_kv_heads, head_dim, - dtype=torch.float16, device="cuda") - k_dst = torch.zeros(num_dst_blocks, block_size, num_kv_heads, head_dim, - dtype=torch.float16, device="cuda") - v_dst = torch.zeros(num_dst_blocks, block_size, num_kv_heads, head_dim, - dtype=torch.float16, device="cuda") - - indices = torch.randint(0, num_src_blocks, (num_dst_blocks,), - dtype=torch.int64, device="cuda") - - return k_src, v_src, k_dst, v_dst, indices - - def test_kv_copy(self, 
setup_kv_tensors): - """Test K/V gathered copy.""" - k_src, v_src, k_dst, v_dst, indices = setup_kv_tensors - - gathered_copy_kv(k_src, v_src, k_dst, v_dst, indices) - - # Verify copy - for i in range(len(indices)): - src_idx = indices[i].item() - assert torch.allclose(k_dst[i], k_src[src_idx]), f"K mismatch at {i}" - assert torch.allclose(v_dst[i], v_src[src_idx]), f"V mismatch at {i}" - - def test_kv_skip_negative(self, setup_kv_tensors): - """Test that negative indices are skipped for K/V.""" - k_src, v_src, k_dst, v_dst, indices = setup_kv_tensors - - indices[0] = -1 - k_dst.fill_(999.0) - v_dst.fill_(999.0) - - gathered_copy_kv(k_src, v_src, k_dst, v_dst, indices) - - assert (k_dst[0] == 999.0).all() - assert (v_dst[0] == 999.0).all() - - -class TestPerformance: - """Performance benchmarks for gathered copy.""" - - @pytest.mark.parametrize("num_blocks", [8, 32, 128]) - def test_throughput(self, num_blocks): - """Benchmark copy throughput.""" - block_size = 256 - kv_dim = 64 - - src = torch.randn(num_blocks * 2, block_size, kv_dim, - dtype=torch.float16, device="cuda") - dst = torch.zeros(num_blocks, block_size, kv_dim, - dtype=torch.float16, device="cuda") - indices = torch.arange(num_blocks, dtype=torch.int64, device="cuda") - - # Warmup - for _ in range(10): - gathered_copy(src, dst, indices) - torch.cuda.synchronize() - - # Benchmark - import time - start = time.perf_counter() - num_iters = 100 - for _ in range(num_iters): - gathered_copy(src, dst, indices) - torch.cuda.synchronize() - elapsed = time.perf_counter() - start - - bytes_copied = num_blocks * block_size * kv_dim * 2 * num_iters # fp16 - bandwidth_gbps = bytes_copied / elapsed / 1e9 - - print(f"\n{num_blocks} blocks: {bandwidth_gbps:.2f} GB/s") - - # Should achieve reasonable bandwidth (lower threshold for small blocks due to kernel launch overhead) - min_bandwidth = 5 if num_blocks <= 16 else 10 - assert bandwidth_gbps > min_bandwidth, f"Bandwidth too low: {bandwidth_gbps} GB/s" - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_kvcache_manager.py b/tests/test_kvcache_manager.py deleted file mode 100644 index e3798b3..0000000 --- a/tests/test_kvcache_manager.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Tests for KV cache managers.""" - -import pytest -import torch - -from nanovllm.engine.sequence import Sequence -from nanovllm.kvcache.gpu_manager import GPUOnlyManager - - -class MockSequence: - """Mock sequence for testing block allocation.""" - - def __init__(self, token_ids: list[int], block_size: int = 256): - self._token_ids = token_ids - self._block_size = block_size - self.block_table: list[int] = [] - self.num_cached_tokens = 0 - - def __len__(self): - return len(self._token_ids) - - @property - def num_blocks(self) -> int: - return (len(self) + self._block_size - 1) // self._block_size - - def block(self, i: int) -> list[int]: - start = i * self._block_size - end = min((i + 1) * self._block_size, len(self)) - return self._token_ids[start:end] - - -class TestGPUOnlyManager: - """Tests for GPU-only KV cache manager.""" - - @pytest.fixture - def manager(self): - """Create a small manager for testing.""" - return GPUOnlyManager(num_blocks=16, block_size=256) - - def test_initialization(self, manager): - """Test manager initialization.""" - assert manager.block_size == 256 - assert manager.num_free_blocks == 16 - assert len(manager.blocks) == 16 - - def test_allocate_cache(self, manager): - """Test cache allocation.""" - manager.allocate_cache( - num_layers=4, - num_kv_heads=8, - 
head_dim=64, - dtype=torch.float16, - ) - - assert manager.kv_cache is not None - assert manager.kv_cache.shape == (2, 4, 16, 256, 8, 64) - assert manager.kv_cache.device.type == "cuda" - - def test_get_layer_cache(self, manager): - """Test getting layer cache.""" - manager.allocate_cache( - num_layers=4, - num_kv_heads=8, - head_dim=64, - dtype=torch.float16, - ) - - k_cache, v_cache = manager.get_layer_cache(0) - assert k_cache.shape == (16, 256, 8, 64) - assert v_cache.shape == (16, 256, 8, 64) - - def test_can_allocate(self, manager): - """Test allocation check.""" - seq = MockSequence([0] * 300) # Needs 2 blocks - assert manager.can_allocate(seq) - - # Fill up all blocks with unique tokens to avoid prefix caching - for i in range(8): - # Each sequence has unique tokens to prevent prefix cache hits - s = MockSequence([i * 1000 + j for j in range(300)]) - manager.allocate(s) - - # Now should not be able to allocate - new_seq = MockSequence([9999] * 300) - assert not manager.can_allocate(new_seq) - - def test_allocate_and_deallocate(self, manager): - """Test block allocation and deallocation.""" - seq = MockSequence([0] * 600) # Needs 3 blocks - initial_free = manager.num_free_blocks - - manager.allocate(seq) - assert len(seq.block_table) == 3 - assert manager.num_free_blocks == initial_free - 3 - - manager.deallocate(seq) - assert len(seq.block_table) == 0 - assert manager.num_free_blocks == initial_free - - def test_can_append(self, manager): - """Test append check.""" - seq = MockSequence([0] * 256) # Exactly 1 block - manager.allocate(seq) - - # Can append without new block (still in same block) - seq._token_ids = [0] * 257 - assert manager.can_append(seq) - - def test_prepare_for_attention_noop(self, manager): - """Test that prepare_for_attention is a no-op for GPU-only.""" - seq = MockSequence([0] * 100) - manager.allocate(seq) - - # Should not raise - manager.prepare_for_attention([seq], is_prefill=True) - manager.prepare_for_attention([seq], is_prefill=False) - - def test_get_gpu_block_tables(self, manager): - """Test getting GPU block tables.""" - seq1 = MockSequence([0] * 300) - seq2 = MockSequence([0] * 600) - - manager.allocate(seq1) - manager.allocate(seq2) - - tables = manager.get_gpu_block_tables([seq1, seq2]) - - assert len(tables) == 2 - assert tables[0] == list(seq1.block_table) - assert tables[1] == list(seq2.block_table) - - -class TestGPUOnlyManagerPrefixCaching: - """Tests for prefix caching in GPU-only manager.""" - - @pytest.fixture - def manager(self): - """Create manager for testing.""" - return GPUOnlyManager(num_blocks=32, block_size=256) - - def test_prefix_cache_hit(self, manager): - """Test that identical prefixes are cached.""" - # Create two sequences with same prefix - tokens = list(range(512)) # 2 full blocks - seq1 = MockSequence(tokens) - seq2 = MockSequence(tokens) - - manager.allocate(seq1) - initial_free = manager.num_free_blocks - - manager.allocate(seq2) - - # Second sequence should reuse cached blocks - assert seq2.num_cached_tokens >= 256 # At least first block cached - # Should use fewer new blocks - assert manager.num_free_blocks >= initial_free - 2 - - def test_prefix_cache_different_suffix(self, manager): - """Test cache with same prefix but different suffix.""" - prefix = list(range(256)) # 1 full block - - seq1 = MockSequence(prefix + [1000, 1001]) - seq2 = MockSequence(prefix + [2000, 2001]) - - manager.allocate(seq1) - manager.allocate(seq2) - - # First block should be shared - assert seq1.block_table[0] == seq2.block_table[0] - # 
Second block should be different - assert seq1.block_table[1] != seq2.block_table[1] - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_offload_engine.py b/tests/test_offload_engine.py index 8613ee5..2df77bc 100644 --- a/tests/test_offload_engine.py +++ b/tests/test_offload_engine.py @@ -1,196 +1,119 @@ -"""Tests for CPU-GPU offload engine.""" +""" +Test script for OffloadEngine - CPU-GPU KV cache transfer engine. + +Demonstrates: ring buffer, H2D/D2H transfers, CUDA events, KV access. +""" -import pytest import torch - from nanovllm.kvcache.offload_engine import OffloadEngine +# ============================================================ +# Utility Functions +# ============================================================ -class TestOffloadEngine: - """Tests for OffloadEngine.""" +def verify(tensor: torch.Tensor, expected: float, name: str) -> None: + """Verify tensor contains expected value.""" + actual = tensor.mean().item() + assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}" - @pytest.fixture - def engine(self): - """Create a small engine for testing.""" - return OffloadEngine( - num_layers=2, - num_gpu_blocks=4, - num_cpu_blocks=8, - block_size=256, - num_kv_heads=4, - head_dim=64, - dtype=torch.float16, - num_streams=2, - ) +# ============================================================ +# Configuration +# ============================================================ - def test_initialization(self, engine): - """Test engine initialization.""" - # Check GPU cache shape - assert engine.k_cache_gpu.shape == (2, 4, 256, 4, 64) - assert engine.v_cache_gpu.shape == (2, 4, 256, 4, 64) +NUM_LAYERS = 4 +NUM_GPU_BLOCKS = 8 +NUM_CPU_BLOCKS = 16 +BLOCK_SIZE = 64 +NUM_KV_HEADS = 4 +HEAD_DIM = 32 - # Check CPU cache shape - assert engine.k_cache_cpu.shape == (2, 8, 256, 4, 64) - assert engine.v_cache_cpu.shape == (2, 8, 256, 4, 64) +# ============================================================ +# Main Test Script +# ============================================================ - # Check pinned memory - assert engine.k_cache_cpu.is_pinned() - assert engine.v_cache_cpu.is_pinned() +# 1. Initialize +engine = OffloadEngine( + num_layers=NUM_LAYERS, + num_gpu_blocks=NUM_GPU_BLOCKS, + num_cpu_blocks=NUM_CPU_BLOCKS, + block_size=BLOCK_SIZE, + num_kv_heads=NUM_KV_HEADS, + head_dim=HEAD_DIM, + dtype=torch.float16, +) - # Check gather indices - assert engine.gather_indices_cpu.shape == (2, 4) - assert engine.gather_indices_gpu.shape == (2, 4) +# 2. Ring buffer slot management +for chunk_idx in range(12): + write_slot = engine.get_write_slot_for_prefill(chunk_idx) + load_slots = engine.get_load_slots_for_prefill(write_slot) + + print("chunk idx", chunk_idx, "write slots:", write_slot, "load slots:", load_slots) + + assert write_slot == chunk_idx % engine.num_ring_slots + assert write_slot not in load_slots - def test_get_layer_cache(self, engine): - """Test getting layer cache.""" - k, v = engine.get_layer_cache(0) - assert k.shape == (4, 256, 4, 64) - assert v.shape == (4, 256, 4, 64) - assert k.device.type == "cuda" - assert v.device.type == "cuda" +assert engine.decode_slot == 0 +assert engine.get_load_slots_for_decode() == list(range(1, NUM_GPU_BLOCKS)) - def test_prefetch_and_offload(self, engine): - """Test async prefetch and offload.""" - # Write some data to CPU block 0 - engine.k_cache_cpu[0, 0].fill_(1.0) - engine.v_cache_cpu[0, 0].fill_(2.0) +# 3. 
Per-slot per-layer H2D transfer +engine.k_cache_cpu[0, 0].fill_(42.0) +engine.v_cache_cpu[0, 0].fill_(42.5) - # Prefetch to GPU block 2 - event = engine.prefetch_block_async( - layer_id=0, - cpu_block_id=0, - gpu_block_id=2, - ) - event.synchronize() +engine.load_to_slot_layer(slot_idx=1, layer_id=0, cpu_block_id=0) +engine.wait_slot_layer(slot_idx=1, layer_id=0) - # Verify data was copied (move GPU to CPU for comparison) - assert torch.allclose(engine.k_cache_gpu[0, 2].cpu(), engine.k_cache_cpu[0, 0]) - assert torch.allclose(engine.v_cache_gpu[0, 2].cpu(), engine.v_cache_cpu[0, 0]) +verify(engine.k_cache_gpu[0, 1], 42.0, "H2D K") +verify(engine.v_cache_gpu[0, 1], 42.5, "H2D V") - # Modify GPU data - engine.k_cache_gpu[0, 2].fill_(3.0) - engine.v_cache_gpu[0, 2].fill_(4.0) +# 4. Compute-done event (pipeline safety) +engine.record_slot_compute_done(slot_idx=1, layer_id=0) - # Offload to CPU block 5 - event = engine.offload_block_async( - layer_id=0, - gpu_block_id=2, - cpu_block_id=5, - ) - event.synchronize() +engine.k_cache_cpu[0, 1].fill_(100.0) +engine.v_cache_cpu[0, 1].fill_(100.5) +engine.load_to_slot_layer(slot_idx=1, layer_id=0, cpu_block_id=1) +engine.wait_slot_layer(slot_idx=1, layer_id=0) - # Verify data was copied - assert torch.allclose(engine.k_cache_cpu[0, 5], engine.k_cache_gpu[0, 2].cpu()) - assert torch.allclose(engine.v_cache_cpu[0, 5], engine.v_cache_gpu[0, 2].cpu()) +verify(engine.k_cache_gpu[0, 1], 100.0, "Reuse K") +verify(engine.v_cache_gpu[0, 1], 100.5, "Reuse V") - def test_update_gather_indices(self, engine): - """Test updating gather indices.""" - # Manually set CPU data - for i in range(8): - engine.k_cache_cpu[0, i].fill_(float(i)) - engine.v_cache_cpu[0, i].fill_(float(i + 100)) +# 5. D2H offload +engine.k_cache_gpu[1, 2].fill_(77.0) +engine.v_cache_gpu[1, 2].fill_(77.5) - # Update indices for layer 0: (cpu_block_id, gpu_slot) - mappings = [(2, 0), (5, 1), (1, 2), (7, 3)] - engine.update_gather_indices(layer_id=0, mappings=mappings) - torch.cuda.synchronize() +engine.offload_slot_to_cpu(slot_idx=2, cpu_block_id=5) +engine.wait_slot_offload(slot_idx=2) - # Verify indices were set - expected = torch.tensor([2, 5, 1, 7], dtype=torch.int64) - assert torch.equal(engine.gather_indices_cpu[0], expected) +verify(engine.k_cache_cpu[1, 5], 77.0, "D2H K") +verify(engine.v_cache_cpu[1, 5], 77.5, "D2H V") - def test_gathered_h2d_layer(self, engine): - """Test gathered H2D copy for a layer.""" - # Set up CPU data with known values - for i in range(8): - engine.k_cache_cpu[0, i].fill_(float(i)) - engine.v_cache_cpu[0, i].fill_(float(i + 100)) +# 6. KV access methods +k, v = engine.get_kv_for_slot(slot_idx=1, layer_id=0) +assert k.shape == (1, BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM) - # Set gather indices: (cpu_block_id, gpu_slot) - # GPU slot 0 gets CPU block 3, GPU slot 1 gets CPU block 0, etc. 
- mappings = [(3, 0), (0, 1), (7, 2), (2, 3)] - engine.update_gather_indices(layer_id=0, mappings=mappings) - torch.cuda.synchronize() +k, v = engine.get_kv_for_slots(layer_id=0, slot_indices=[0, 1, 2]) +assert k.shape == (1, 3 * BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM) - # Execute gathered H2D - engine.gathered_h2d_layer(layer_id=0) - torch.cuda.synchronize() +engine.k_cache_gpu[0, engine.decode_slot].fill_(33.0) +k, v = engine.get_kv_for_decode_slot_accumulated(layer_id=0, num_tokens=10) +assert k.shape == (1, 10, NUM_KV_HEADS, HEAD_DIM) +verify(k, 33.0, "Decode slot K") - # Verify: GPU slot 0 should have CPU block 3's data - assert torch.allclose(engine.k_cache_gpu[0, 0], - torch.full_like(engine.k_cache_gpu[0, 0], 3.0)) - # GPU slot 1 should have CPU block 0's data - assert torch.allclose(engine.k_cache_gpu[0, 1], - torch.full_like(engine.k_cache_gpu[0, 1], 0.0)) - # GPU slot 2 should have CPU block 7's data - assert torch.allclose(engine.k_cache_gpu[0, 2], - torch.full_like(engine.k_cache_gpu[0, 2], 7.0)) - # GPU slot 3 should have CPU block 2's data - assert torch.allclose(engine.k_cache_gpu[0, 3], - torch.full_like(engine.k_cache_gpu[0, 3], 2.0)) +# 7. Batch transfer +cpu_blocks = [2, 3, 4] +gpu_slots = [3, 4, 5] +for cpu_id in cpu_blocks: + engine.k_cache_cpu[0, cpu_id].fill_(50.0 + cpu_id) - def test_multi_layer_independence(self, engine): - """Test that layers are independent.""" - # Set different data for each layer - engine.k_cache_cpu[0, 0].fill_(1.0) - engine.k_cache_cpu[1, 0].fill_(2.0) +engine.load_cpu_blocks_to_gpu_slots(layer_id=0, cpu_block_ids=cpu_blocks, gpu_slot_ids=gpu_slots) - # Prefetch layer 0 - event = engine.prefetch_block_async(0, 0, 0) - event.synchronize() +for cpu_id, gpu_slot in zip(cpu_blocks, gpu_slots): + verify(engine.k_cache_gpu[0, gpu_slot], 50.0 + cpu_id, f"Batch slot {gpu_slot}") - # Verify only layer 0 was affected - assert torch.allclose(engine.k_cache_gpu[0, 0], - torch.full_like(engine.k_cache_gpu[0, 0], 1.0)) - # Layer 1 should be zeros (initial state) - assert not torch.allclose(engine.k_cache_gpu[1, 0], - torch.full_like(engine.k_cache_gpu[1, 0], 2.0)) +# 8. 
Gather indices (CUDA graph compatible) +engine.update_gather_indices(layer_id=0, mappings=[(0, 0), (1, 1), (2, 2)]) +assert engine.gather_indices_gpu[0, :3].tolist() == [0, 1, 2] +engine.clear_gather_indices(layer_id=0) +assert engine.gather_indices_gpu[0, 0].item() == -1 -class TestOffloadEngineFixedAddresses: - """Tests verifying fixed address property for CUDA Graph compatibility.""" - - @pytest.fixture - def engine(self): - """Create engine for address tests.""" - return OffloadEngine( - num_layers=2, - num_gpu_blocks=4, - num_cpu_blocks=8, - block_size=256, - num_kv_heads=4, - head_dim=64, - dtype=torch.float16, - num_streams=2, - ) - - def test_gpu_cache_address_fixed(self, engine): - """Verify GPU cache addresses don't change.""" - k_ptr_before = engine.k_cache_gpu.data_ptr() - v_ptr_before = engine.v_cache_gpu.data_ptr() - - # Perform some operations - mappings is List[(cpu_block_id, gpu_slot)] - mappings = [(0, 0), (1, 1), (2, 2), (3, 3)] - engine.update_gather_indices(0, mappings) - engine.gathered_h2d_layer(0) - torch.cuda.synchronize() - - # Addresses should be the same - assert engine.k_cache_gpu.data_ptr() == k_ptr_before - assert engine.v_cache_gpu.data_ptr() == v_ptr_before - - def test_gather_indices_gpu_address_fixed(self, engine): - """Verify gather indices GPU tensor address doesn't change.""" - ptr_before = engine.gather_indices_gpu.data_ptr() - - # Update indices multiple times - mappings is List[(cpu_block_id, gpu_slot)] - mappings = [(0, 0), (1, 1), (2, 2), (3, 3)] - for _ in range(10): - engine.update_gather_indices(0, mappings) - torch.cuda.synchronize() - - assert engine.gather_indices_gpu.data_ptr() == ptr_before - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) +print("test_offload_engine: PASSED") diff --git a/tests/test_policies.py b/tests/test_policies.py deleted file mode 100644 index d241148..0000000 --- a/tests/test_policies.py +++ /dev/null @@ -1,167 +0,0 @@ -"""Tests for eviction policies.""" - -import pytest -from nanovllm.kvcache.policies.lru_policy import LRUPolicy -from nanovllm.kvcache.policies.fifo_policy import FIFOPolicy -from nanovllm.kvcache.policies import get_policy - - -class TestLRUPolicy: - """Tests for LRU eviction policy.""" - - def test_basic_eviction(self): - """Test that LRU evicts least recently used block.""" - policy = LRUPolicy() - - # Allocate blocks 0, 1, 2 in order - policy.on_block_allocated(0, step=1) - policy.on_block_allocated(1, step=2) - policy.on_block_allocated(2, step=3) - - # Access block 0 (makes it most recently used) - policy.on_block_access(0, step=4) - - # Should evict block 1 (least recently used) - candidates = {0, 1, 2} - victim = policy.select_victim(candidates) - assert victim == 1, f"Expected block 1, got {victim}" - - def test_access_updates_order(self): - """Test that access updates LRU order.""" - policy = LRUPolicy() - - policy.on_block_allocated(0, step=1) - policy.on_block_allocated(1, step=2) - policy.on_block_allocated(2, step=3) - - # Access all in reverse order - policy.on_block_access(2, step=4) - policy.on_block_access(1, step=5) - policy.on_block_access(0, step=6) - - # Block 2 is now LRU (accessed earliest after allocation update) - candidates = {0, 1, 2} - victim = policy.select_victim(candidates) - assert victim == 2, f"Expected block 2, got {victim}" - - def test_eviction_removes_from_tracking(self): - """Test that evicted blocks are removed from tracking.""" - policy = LRUPolicy() - - policy.on_block_allocated(0, step=1) - policy.on_block_allocated(1, step=2) - - 
policy.on_block_evicted(0) - - # Only block 1 should be a candidate - candidates = {0, 1} - victim = policy.select_victim(candidates) - assert victim == 1, "Should select block 1 since 0 was evicted" - - def test_batch_eviction_order(self): - """Test get_eviction_order returns blocks in LRU order.""" - policy = LRUPolicy() - - for i in range(5): - policy.on_block_allocated(i, step=i) - - # Access blocks 2 and 4 - policy.on_block_access(2, step=10) - policy.on_block_access(4, step=11) - - candidates = {0, 1, 2, 3, 4} - order = policy.get_eviction_order(candidates, count=3) - - # Should be 0, 1, 3 (in that order, skipping 2 and 4 until needed) - assert order == [0, 1, 3], f"Expected [0, 1, 3], got {order}" - - -class TestFIFOPolicy: - """Tests for FIFO eviction policy.""" - - def test_basic_eviction(self): - """Test that FIFO evicts oldest allocated block.""" - policy = FIFOPolicy() - - policy.on_block_allocated(0, step=1) - policy.on_block_allocated(1, step=2) - policy.on_block_allocated(2, step=3) - - # Access doesn't change FIFO order - policy.on_block_access(0, step=4) - - candidates = {0, 1, 2} - victim = policy.select_victim(candidates) - assert victim == 0, f"Expected block 0 (oldest), got {victim}" - - def test_access_does_not_update_order(self): - """Test that FIFO ignores access patterns.""" - policy = FIFOPolicy() - - policy.on_block_allocated(0, step=1) - policy.on_block_allocated(1, step=2) - policy.on_block_allocated(2, step=3) - - # Multiple accesses to block 0 - for i in range(10): - policy.on_block_access(0, step=10 + i) - - # Block 0 should still be evicted first (FIFO order) - candidates = {0, 1, 2} - victim = policy.select_victim(candidates) - assert victim == 0, f"Expected block 0, got {victim}" - - def test_prefetch_resets_order(self): - """Test that prefetch moves block to end of queue.""" - policy = FIFOPolicy() - - policy.on_block_allocated(0, step=1) - policy.on_block_allocated(1, step=2) - policy.on_block_allocated(2, step=3) - - # Prefetch block 0 (moves to end) - policy.on_block_prefetched(0, step=4) - - candidates = {0, 1, 2} - victim = policy.select_victim(candidates) - assert victim == 1, f"Expected block 1 (now oldest), got {victim}" - - def test_batch_eviction_order(self): - """Test get_eviction_order returns blocks in FIFO order.""" - policy = FIFOPolicy() - - for i in range(5): - policy.on_block_allocated(i, step=i) - - candidates = {0, 1, 2, 3, 4} - order = policy.get_eviction_order(candidates, count=3) - - assert order == [0, 1, 2], f"Expected [0, 1, 2], got {order}" - - -class TestGetPolicy: - """Tests for policy factory function.""" - - def test_get_lru(self): - """Test getting LRU policy by name.""" - policy = get_policy("lru") - assert isinstance(policy, LRUPolicy) - - def test_get_fifo(self): - """Test getting FIFO policy by name.""" - policy = get_policy("fifo") - assert isinstance(policy, FIFOPolicy) - - def test_get_by_class_path(self): - """Test getting policy by full class path.""" - policy = get_policy("nanovllm.kvcache.policies.lru_policy.LRUPolicy") - assert isinstance(policy, LRUPolicy) - - def test_invalid_policy_name(self): - """Test that invalid policy name raises error.""" - with pytest.raises((ValueError, ImportError)): - get_policy("invalid_policy") - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_sparse_policy.py b/tests/test_sparse_policy.py deleted file mode 100644 index e8c4b51..0000000 --- a/tests/test_sparse_policy.py +++ /dev/null @@ -1,252 +0,0 @@ -""" -Test sparse attention policies. 
- -Usage: - CUDA_VISIBLE_DEVICES=4,5 python tests/test_sparse_policy.py [policy_name] - -Policy names: full, vertical_slash, streaming_llm, quest -""" - -import sys -import os - -os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING" - -import torch -from typing import List - -# Test the sparse policy implementations -from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext -from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy -from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig -from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig -from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager - - -def test_full_attention_policy(): - """Test FullAttentionPolicy returns all blocks.""" - print("\n=== Testing FullAttentionPolicy ===") - policy = FullAttentionPolicy() - - available_blocks = list(range(10)) - ctx = PolicyContext( - query_chunk_idx=5, - num_query_chunks=10, - layer_id=0, - query=None, - is_prefill=True, - ) - - selected = policy.select_blocks(available_blocks, ctx) - assert selected == available_blocks, f"Expected all blocks, got {selected}" - print(f" Prefill: input={available_blocks}, selected={selected} [PASS]") - - # Test decode - ctx.is_prefill = False - selected = policy.select_blocks(available_blocks, ctx) - assert selected == available_blocks, f"Expected all blocks, got {selected}" - print(f" Decode: input={available_blocks}, selected={selected} [PASS]") - - -def test_vertical_slash_policy(): - """Test VerticalSlashPolicy selects sink + local window.""" - print("\n=== Testing VerticalSlashPolicy ===") - config = VerticalSlashConfig( - num_sink_blocks=2, - local_window_blocks=3, - threshold_blocks=4, - ) - policy = VerticalSlashPolicy(config) - - # Test with 10 blocks, chunk 7 (should select sink[0,1] + local[4,5,6]) - available_blocks = list(range(10)) - ctx = PolicyContext( - query_chunk_idx=7, - num_query_chunks=10, - layer_id=0, - query=None, - is_prefill=True, - ) - - selected = policy.select_blocks(available_blocks, ctx) - expected = [0, 1, 4, 5, 6] # sink + local window before chunk 7 - assert selected == expected, f"Expected {expected}, got {selected}" - print(f" Prefill chunk 7: input={available_blocks}, selected={selected} [PASS]") - - # Test with small number of blocks (below threshold) - available_blocks = [0, 1, 2] - selected = policy.select_blocks(available_blocks, ctx) - assert selected == [0, 1, 2], f"Expected all blocks for small input, got {selected}" - print(f" Below threshold: input={[0,1,2]}, selected={selected} [PASS]") - - # Test decode (local window is last M blocks) - available_blocks = list(range(10)) - ctx.is_prefill = False - selected = policy.select_blocks(available_blocks, ctx) - expected = [0, 1, 7, 8, 9] # sink + last 3 blocks - assert selected == expected, f"Expected {expected}, got {selected}" - print(f" Decode: input={available_blocks}, selected={selected} [PASS]") - - -def test_streaming_llm_policy(): - """Test StreamingLLMPolicy selects sink + recent only.""" - print("\n=== Testing StreamingLLMPolicy ===") - config = StreamingLLMConfig( - num_sink_blocks=1, - num_recent_blocks=2, - ) - policy = StreamingLLMPolicy(config) - - available_blocks = list(range(10)) - ctx = PolicyContext( - query_chunk_idx=0, - num_query_chunks=1, - layer_id=0, - query=None, - is_prefill=False, - ) - - selected = policy.select_blocks(available_blocks, ctx) - expected = [0, 8, 9] # sink[0] + recent[8,9] - assert selected == expected, 
f"Expected {expected}, got {selected}" - print(f" 10 blocks: selected={selected} [PASS]") - - # Test with 3 blocks (all fit in sink+recent) - available_blocks = [0, 1, 2] - selected = policy.select_blocks(available_blocks, ctx) - assert selected == [0, 1, 2], f"Expected all blocks, got {selected}" - print(f" 3 blocks: selected={selected} [PASS]") - - -def test_quest_policy(): - """Test QuestPolicy with mock metadata.""" - print("\n=== Testing QuestPolicy ===") - - # Create metadata manager - num_blocks = 10 - num_layers = 2 - num_kv_heads = 4 - head_dim = 64 - dtype = torch.float32 - - metadata = BlockMetadataManager( - num_blocks=num_blocks, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - dtype=dtype, - ) - - # Simulate offloading blocks with different key patterns - # Blocks 0, 5, 9 will have high scores (keys aligned with query) - for block_id in range(num_blocks): - for layer_id in range(num_layers): - k_cache = torch.randn(100, num_kv_heads, head_dim) # 100 tokens per block - if block_id in [0, 5, 9]: - # Make these blocks have keys that score high - k_cache = k_cache.abs() # All positive - else: - k_cache = -k_cache.abs() # All negative - metadata.update_metadata(block_id, layer_id, k_cache, 100) - - config = QuestConfig( - topk_blocks=4, - threshold_blocks=3, - ) - policy = QuestPolicy(config, metadata) - - available_blocks = list(range(10)) - - # Create query that scores high with positive keys - query = torch.ones(1, num_kv_heads, head_dim, device='cuda') - - ctx = PolicyContext( - query_chunk_idx=0, - num_query_chunks=1, - layer_id=0, - query=query, - is_prefill=False, - ) - - selected = policy.select_blocks(available_blocks, ctx) - print(f" Top-4 selection: input={available_blocks}, selected={selected}") - - # High-scoring blocks [0, 5, 9] should be in selection - for expected_block in [0, 5, 9]: - assert expected_block in selected, f"Expected block {expected_block} in selection" - print(f" High-score blocks [0, 5, 9] in selection [PASS]") - - # Test below threshold (should return all) - available_blocks = [0, 1, 2] - selected = policy.select_blocks(available_blocks, ctx) - assert selected == [0, 1, 2], f"Expected all blocks below threshold, got {selected}" - print(f" Below threshold: selected={selected} [PASS]") - - # Test without query (should return all) - ctx.query = None - available_blocks = list(range(10)) - selected = policy.select_blocks(available_blocks, ctx) - assert selected == available_blocks, f"Expected all blocks without query, got {selected}" - print(f" No query: selected all [PASS]") - - -def test_custom_policy(): - """Test creating a custom policy.""" - print("\n=== Testing Custom Policy ===") - - class EveryOtherPolicy(SparsePolicy): - """Select every other block.""" - - def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]: - return [available_blocks[i] for i in range(0, len(available_blocks), 2)] - - policy = EveryOtherPolicy() - available_blocks = list(range(10)) - ctx = PolicyContext( - query_chunk_idx=0, - num_query_chunks=1, - layer_id=0, - query=None, - is_prefill=True, - ) - - selected = policy.select_blocks(available_blocks, ctx) - expected = [0, 2, 4, 6, 8] - assert selected == expected, f"Expected {expected}, got {selected}" - print(f" Every other: input={available_blocks}, selected={selected} [PASS]") - - -def run_all_tests(): - """Run all policy tests.""" - print("Running Sparse Policy Tests...") - - test_full_attention_policy() - test_vertical_slash_policy() - 
test_streaming_llm_policy() - test_quest_policy() - test_custom_policy() - - print("\n" + "=" * 50) - print("All tests passed!") - print("=" * 50) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - policy_name = sys.argv[1].lower() - if policy_name == "full": - test_full_attention_policy() - elif policy_name == "vertical_slash": - test_vertical_slash_policy() - elif policy_name == "streaming_llm": - test_streaming_llm_policy() - elif policy_name == "quest": - test_quest_policy() - elif policy_name == "custom": - test_custom_policy() - else: - print(f"Unknown policy: {policy_name}") - print("Available: full, vertical_slash, streaming_llm, quest, custom") - sys.exit(1) - else: - run_all_tests()
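
For reference, a minimal sketch of what a script-style replacement for the deleted sparse-policy tests could look like under the new `testing.md` convention. This is only an illustration, assuming the policy classes, configs, and `PolicyContext` fields shown in the deleted `tests/test_sparse_policy.py` are unchanged; the expected block selections are taken from that file's assertions.

```python
"""Sketch: script-style sparse policy test (hypothetical, not part of this patch)."""

from nanovllm.kvcache.sparse.policy import PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig

# ============================================================
# Main Test Script
# ============================================================

available_blocks = list(range(10))
ctx = PolicyContext(
    query_chunk_idx=0,
    num_query_chunks=1,
    layer_id=0,
    query=None,
    is_prefill=False,
)

# 1. Full attention keeps every block
assert FullAttentionPolicy().select_blocks(available_blocks, ctx) == available_blocks

# 2. StreamingLLM keeps sink + recent blocks only (values from the deleted test)
policy = StreamingLLMPolicy(StreamingLLMConfig(num_sink_blocks=1, num_recent_blocks=2))
assert policy.select_blocks(available_blocks, ctx) == [0, 8, 9]

print("test_sparse_policy: PASSED")
```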