[WIP] Need to refactor nanovllm mechanism.
.claude/rules/code-analysis.md (new file, 38 lines added)
@@ -0,0 +1,38 @@
+# Code Analysis
+
+## Use cclsp MCP for Code Navigation
+
+When analyzing code, understanding call chains, or exploring the codebase, **prefer using the cclsp MCP tools** over grep/glob-based searches:
+
+### Available cclsp Tools
+
+| Tool | Purpose |
+|------|---------|
+| `mcp__cclsp__find_definition` | Jump to symbol definition |
+| `mcp__cclsp__find_references` | Find all usages of a symbol |
+| `mcp__cclsp__rename_symbol` | Rename a symbol across the codebase |
+| `mcp__cclsp__get_diagnostics` | Get LSP diagnostics (errors, warnings) |
+| `mcp__cclsp__restart_server` | Restart the LSP server if needed |
+
+### When to Use cclsp
+
+1. **Understanding call chains**: Use `find_references` to trace how functions are called
+2. **Finding implementations**: Use `find_definition` to jump to actual code
+3. **Refactoring**: Use `rename_symbol` for safe cross-file renames
+4. **Code quality**: Use `get_diagnostics` to check for issues
+
+### Example Workflow
+
+```
+1. User asks: "How does the prefill flow work?"
+2. Use find_definition to locate key entry points (e.g., run_chunked_offload_prefill)
+3. Use find_references to trace the call chain through the codebase
+4. Read relevant code sections to understand the implementation
+```
+
+### Benefits over grep/glob
+
+- **Semantic understanding**: cclsp understands code structure, not just text patterns
+- **Accurate references**: Finds actual usages, not just text matches
+- **Cross-file navigation**: Follows imports and definitions across modules
+- **Type-aware**: Understands Python types and class hierarchies
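Editor's note: the cclsp tools above are backed by a language server (see `get_diagnostics` and `restart_server`), so a useful mental model is the LSP request each tool corresponds to. The sketch below is illustrative only: it builds a generic LSP `textDocument/definition` JSON-RPC payload, not cclsp's actual MCP wire format, which this rule file does not specify.

```python
# Illustrative only: the kind of LSP request a "find definition" tool
# conceptually issues. cclsp's real MCP request schema is not documented here.
import json

def definition_request(request_id: int, file_uri: str, line: int, character: int) -> str:
    """Build a JSON-RPC 2.0 textDocument/definition request for a symbol position."""
    return json.dumps({
        "jsonrpc": "2.0",
        "id": request_id,
        "method": "textDocument/definition",
        "params": {
            "textDocument": {"uri": file_uri},
            "position": {"line": line, "character": character},  # both 0-based
        },
    })

# "Where is the symbol at line 42, column 11 of offload_engine.py defined?"
print(definition_request(1, "file:///repo/nanovllm/kvcache/offload_engine.py", 41, 10))
# textDocument/references has the same shape plus a
# "context": {"includeDeclaration": true} field.
```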
@@ -1,20 +1,98 @@
 # Testing
 
-## Chunked Attention Test
-
-```bash
-CUDA_VISIBLE_DEVICES=4,5 python tests/test_chunked_attention.py 6 2048 64 2
-# Args: num_gpu_blocks input_len output_len num_prefetch_blocks
-```
-
-## CPU Offload Testing
-
-```bash
-# Basic test with limited GPU blocks to trigger offload
-CUDA_VISIBLE_DEVICES=4,5 python tests/test_chunked_attention.py 6 2048 64 2
-
-# Verify consistency (run multiple times, output should be identical)
-for i in 1 2 3; do
-CUDA_VISIBLE_DEVICES=4,5 python tests/test_chunked_attention.py 6 2048 32 2 2>&1 | tail -3
-done
-```
+## Test File Guidelines
+
+### Naming Convention
+
+- All test files must be named `test_*.py`
+- Example: `test_offload_engine.py`, `test_ring_buffer.py`
+
+### Purpose
+
+Tests are **educational scripts** for understanding module behavior, NOT traditional unit tests:
+- Focus on demonstrating how modules work
+- Show the flow and interaction between components
+- Help developers understand implementation details
+
+### Code Style
+
+1. **Script-based structure**: Write tests as executable scripts, not pytest-style functions
+2. **Utility functions**: Extract reusable steps as helper functions at the top of the file
+3. **Main flow as script**: The actual test/demonstration logic runs as top-level script code
+
+```python
+# Example structure:
+
+import torch
+from nanovllm.kvcache import SomeModule
+
+# ============================================================
+# Utility Functions
+# ============================================================
+
+def verify(tensor, expected, name):
+    actual = tensor.mean().item()
+    assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}"
+
+# ============================================================
+# Main Test Script
+# ============================================================
+
+# 1. Initialize
+module = SomeModule(param=value)
+
+# 2. Test feature X
+result = module.do_something()
+assert result == expected_value
+
+# 3. Test feature Y
+...
+
+print("test_xxx: PASSED")
+```
+
+### Comments
+
+- Keep comments concise and clear
+- Only add comments where the code isn't self-explanatory
+- Use section headers (`# === Section ===`) to organize logical blocks
+
+### Output
+
+- **Minimize print statements** - the code should be self-explanatory
+- Only print a final "PASSED" message at the end
+- Use `assert` for verification instead of printing results
+- If the user needs explanation, they will ask
+
+## Running Tests
+
+```bash
+# Run a specific test
+python tests/test_offload_engine.py
+
+# Run with specific GPU
+CUDA_VISIBLE_DEVICES=0 python tests/test_ring_buffer.py
+```
+
+## Benchmarks
+
+```bash
+# Standard GPU benchmark
+python bench.py
+
+# CPU offload benchmark
+python bench_offload.py
+
+# vLLM comparison benchmark
+python bench_vllm.py
+```
+
+## Quick Verification
+
+```bash
+# Import test
+python -c "from nanovllm import LLM"
+
+# Run offload benchmark (tests CPU-primary ring buffer mode)
+python bench_offload.py
+```
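Editor's note: to make the guidelines above concrete, here is a complete, runnable test written in the prescribed style. `RingAccumulator` is a toy stand-in invented for this example (it is not a nanovllm module); only the layout matters: helpers first, a top-level script body, `assert`-based verification, and a single final PASSED line.

```python
# tests/test_example_style.py -- illustrative only; RingAccumulator is a toy
# stand-in, not a real nanovllm module.
import torch


# ============================================================
# Utility Functions
# ============================================================

def verify(tensor: torch.Tensor, expected: float, name: str) -> None:
    """Assert that the tensor mean matches the expected value."""
    actual = tensor.mean().item()
    assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}"


class RingAccumulator:
    """Toy module under test: writes chunk values into a ring of slots."""

    def __init__(self, num_slots: int, slot_size: int):
        self.num_slots = num_slots
        self.buffer = torch.zeros(num_slots, slot_size)

    def write(self, chunk_idx: int, value: float) -> int:
        slot = chunk_idx % self.num_slots
        self.buffer[slot].fill_(value)
        return slot


# ============================================================
# Main Test Script
# ============================================================

# 1. Initialize
acc = RingAccumulator(num_slots=4, slot_size=16)

# 2. Writes wrap around the ring
for chunk_idx in range(8):
    slot = acc.write(chunk_idx, float(chunk_idx))
    assert slot == chunk_idx % 4

# 3. The most recent wrap-around write wins in each slot
verify(acc.buffer[0], 4.0, "slot 0")
verify(acc.buffer[3], 7.0, "slot 3")

print("test_example_style: PASSED")
```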
@@ -125,6 +125,11 @@ class OffloadEngine:
             dtype=torch.int64, device="cuda"
         )
 
+        # Log memory allocation
+        gpu_mem_mb = self.gpu_memory_bytes() / (1024 * 1024)
+        cpu_mem_mb = self.cpu_memory_bytes() / (1024 * 1024)
+        logger.info(f" GPU memory: {gpu_mem_mb:.1f} MB, CPU memory: {cpu_mem_mb:.1f} MB")
+
         # ========== Transfer streams for async operations ==========
         self.transfer_streams = [torch.cuda.Stream() for _ in range(num_streams)]
         self.compute_stream = torch.cuda.current_stream()
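Editor's note: the new logging calls rely on `gpu_memory_bytes()` / `cpu_memory_bytes()` helpers that are not shown in this hunk. A plausible sketch, assuming they simply sum the sizes of the K/V cache tensors the engine already owns (`k_cache_gpu`, `v_cache_gpu`, `k_cache_cpu`, `v_cache_cpu`), is given below; the actual methods in offload_engine.py may differ.

```python
# Sketch only -- assumes the helpers just sum the K/V cache tensor sizes.
import torch

def _tensor_bytes(t: torch.Tensor) -> int:
    return t.numel() * t.element_size()

class _MemoryAccountingSketch:
    """Illustrative stand-ins for OffloadEngine.gpu_memory_bytes / cpu_memory_bytes."""

    def gpu_memory_bytes(self) -> int:
        # Total bytes held by the GPU-resident K and V caches.
        return _tensor_bytes(self.k_cache_gpu) + _tensor_bytes(self.v_cache_gpu)

    def cpu_memory_bytes(self) -> int:
        # Total bytes held by the pinned CPU-resident K and V caches.
        return _tensor_bytes(self.k_cache_cpu) + _tensor_bytes(self.v_cache_cpu)

# Example: a (num_layers=2, num_blocks=4, block_size=256, heads=4, head_dim=64)
# float16 cache is 2 * 4 * 256 * 4 * 64 * 2 bytes = 1 MiB per K or V tensor,
# so gpu_memory_bytes() would report 2 MiB and the log line prints "2.0 MB".
```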
@@ -1 +1 @@
-"""Test suite for nano-vllm KV cache offload."""
+# Tests module
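Editor's note: later in this diff, the rewritten tests/test_offload_engine.py exercises the engine's ring-buffer slot scheduling: it expects `get_write_slot_for_prefill(chunk_idx) == chunk_idx % num_ring_slots`, a reserved `decode_slot == 0`, decode loads spread over the remaining slots, and the prefill write slot never appearing in its own load-slot set. A minimal model consistent with those assertions is sketched below; the real OffloadEngine may implement the slot policy differently.

```python
# Minimal model of the ring-buffer slot policy implied by the new
# test_offload_engine.py assertions; not the actual OffloadEngine code.

class RingSlotPolicy:
    def __init__(self, num_ring_slots: int):
        self.num_ring_slots = num_ring_slots
        self.decode_slot = 0  # slot 0 is reserved for decode-time KV writes

    def get_write_slot_for_prefill(self, chunk_idx: int) -> int:
        # Prefill chunks cycle through the ring slots.
        return chunk_idx % self.num_ring_slots

    def get_load_slots_for_prefill(self, write_slot: int) -> list[int]:
        # Every slot except the one currently being written can receive
        # previously offloaded blocks loaded back from CPU.
        return [s for s in range(self.num_ring_slots) if s != write_slot]

    def get_load_slots_for_decode(self) -> list[int]:
        # Decode keeps slot 0 for itself and streams CPU blocks through the rest.
        return list(range(1, self.num_ring_slots))


policy = RingSlotPolicy(num_ring_slots=8)
for chunk_idx in range(12):
    write_slot = policy.get_write_slot_for_prefill(chunk_idx)
    load_slots = policy.get_load_slots_for_prefill(write_slot)
    assert write_slot == chunk_idx % 8
    assert write_slot not in load_slots
assert policy.decode_slot == 0
assert policy.get_load_slots_for_decode() == list(range(1, 8))
```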
@@ -1,157 +0,0 @@
|
|||||||
"""
|
|
||||||
Test chunked attention with small num_gpu_blocks to trigger CPU offload.
|
|
||||||
|
|
||||||
For 8K tokens with block_size=256:
|
|
||||||
- Total blocks needed: 8192 / 256 = 32 blocks
|
|
||||||
- With num_gpu_blocks=10, 22 blocks go to CPU -> triggers chunked attention
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# Enable debug logging before importing nanovllm
|
|
||||||
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
|
|
||||||
|
|
||||||
from nanovllm import LLM, SamplingParams
|
|
||||||
|
|
||||||
|
|
||||||
def create_long_context_prompt(target_tokens: int) -> str:
|
|
||||||
"""
|
|
||||||
Create a meaningful long context prompt with a question at the end.
|
|
||||||
The answer depends on information scattered throughout the context.
|
|
||||||
"""
|
|
||||||
# Key facts to embed in the context
|
|
||||||
facts = [
|
|
||||||
"The capital of France is Paris.",
|
|
||||||
"The Eiffel Tower was built in 1889.",
|
|
||||||
"Python was created by Guido van Rossum.",
|
|
||||||
"The speed of light is approximately 299,792 kilometers per second.",
|
|
||||||
"Mount Everest is 8,848 meters tall.",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Padding text to reach target length
|
|
||||||
padding_paragraph = """
|
|
||||||
This is additional context information that helps extend the length of the prompt.
|
|
||||||
Machine learning has revolutionized many fields including computer vision, natural language processing, and robotics.
|
|
||||||
Deep neural networks can learn complex patterns from large amounts of data.
|
|
||||||
The transformer architecture has become the foundation of modern language models.
|
|
||||||
Attention mechanisms allow models to focus on relevant parts of the input.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Build the prompt
|
|
||||||
prompt_parts = []
|
|
||||||
|
|
||||||
# Add instruction
|
|
||||||
prompt_parts.append("Please read the following information carefully and answer the question at the end.\n\n")
|
|
||||||
|
|
||||||
# Add facts at different positions
|
|
||||||
current_tokens = 50 # approximate tokens so far
|
|
||||||
tokens_per_padding = 80 # approximate tokens per padding paragraph
|
|
||||||
fact_interval = target_tokens // (len(facts) + 1)
|
|
||||||
|
|
||||||
fact_idx = 0
|
|
||||||
while current_tokens < target_tokens - 100:
|
|
||||||
# Add padding
|
|
||||||
prompt_parts.append(padding_paragraph)
|
|
||||||
current_tokens += tokens_per_padding
|
|
||||||
|
|
||||||
# Add a fact at intervals
|
|
||||||
if fact_idx < len(facts) and current_tokens > fact_interval * (fact_idx + 1):
|
|
||||||
prompt_parts.append(f"\n[Important Fact #{fact_idx + 1}]: {facts[fact_idx]}\n")
|
|
||||||
current_tokens += 20
|
|
||||||
fact_idx += 1
|
|
||||||
|
|
||||||
# Add the question at the end
|
|
||||||
prompt_parts.append("\n\nQuestion: Based on the information above, what is the speed of light?\n\nAnswer:")
|
|
||||||
|
|
||||||
return "".join(prompt_parts)
|
|
||||||
|
|
||||||
|
|
||||||
def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64, num_prefetch_blocks=2):
|
|
||||||
"""Test chunked prefill with limited GPU blocks."""
|
|
||||||
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
|
||||||
|
|
||||||
print(f"=" * 60)
|
|
||||||
print(f"Chunked Prefill Test (Chunked Offload)")
|
|
||||||
print(f"=" * 60)
|
|
||||||
print(f" target_input_len: ~{input_len} tokens")
|
|
||||||
print(f" num_gpu_blocks: {num_gpu_blocks}")
|
|
||||||
print(f" num_prefetch_blocks: {num_prefetch_blocks}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
llm = LLM(
|
|
||||||
path,
|
|
||||||
enforce_eager=False,
|
|
||||||
max_model_len=128 * 1024,
|
|
||||||
max_num_batched_tokens=128 * 1024,
|
|
||||||
enable_cpu_offload=True,
|
|
||||||
num_gpu_blocks=num_gpu_blocks,
|
|
||||||
num_prefetch_blocks=num_prefetch_blocks,
|
|
||||||
)
|
|
||||||
print()
|
|
||||||
|
|
||||||
# Create meaningful prompt
|
|
||||||
prompt = create_long_context_prompt(input_len)
|
|
||||||
|
|
||||||
print(f"Running generation...")
|
|
||||||
outputs = llm.generate(
|
|
||||||
[prompt],
|
|
||||||
SamplingParams(temperature=0.1, max_tokens=output_len), # low temperature for more deterministic output
|
|
||||||
use_tqdm=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
print()
|
|
||||||
print(f"Output tokens: {len(outputs[0]['token_ids'])}")
|
|
||||||
print(f"Output text:\n{outputs[0]['text']}")
|
|
||||||
print()
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128, num_prefetch_blocks=2):
|
|
||||||
"""Test chunked decode with limited GPU blocks."""
|
|
||||||
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
|
||||||
|
|
||||||
print(f"=" * 60)
|
|
||||||
print(f"Chunked Decode Test (Chunked Offload)")
|
|
||||||
print(f"=" * 60)
|
|
||||||
print(f" target_input_len: ~{input_len} tokens")
|
|
||||||
print(f" output_len: {output_len} tokens")
|
|
||||||
print(f" num_gpu_blocks: {num_gpu_blocks}")
|
|
||||||
print(f" num_prefetch_blocks: {num_prefetch_blocks}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
llm = LLM(
|
|
||||||
path,
|
|
||||||
enforce_eager=False,
|
|
||||||
max_model_len=128 * 1024,
|
|
||||||
max_num_batched_tokens=128 * 1024,
|
|
||||||
enable_cpu_offload=True,
|
|
||||||
num_gpu_blocks=num_gpu_blocks,
|
|
||||||
num_prefetch_blocks=num_prefetch_blocks,
|
|
||||||
)
|
|
||||||
print()
|
|
||||||
|
|
||||||
# Create meaningful prompt
|
|
||||||
prompt = create_long_context_prompt(input_len)
|
|
||||||
|
|
||||||
print(f"Running generation...")
|
|
||||||
outputs = llm.generate(
|
|
||||||
[prompt],
|
|
||||||
SamplingParams(temperature=0.1, max_tokens=output_len),
|
|
||||||
use_tqdm=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
print()
|
|
||||||
print(f"Output tokens: {len(outputs[0]['token_ids'])}")
|
|
||||||
print(f"Output text:\n{outputs[0]['text']}")
|
|
||||||
print()
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Parse arguments: num_gpu_blocks input_len output_len [num_prefetch_blocks]
|
|
||||||
num_gpu_blocks = int(sys.argv[1]) if len(sys.argv) > 1 else 10
|
|
||||||
input_len = int(sys.argv[2]) if len(sys.argv) > 2 else 2048
|
|
||||||
output_len = int(sys.argv[3]) if len(sys.argv) > 3 else 64
|
|
||||||
num_prefetch_blocks = int(sys.argv[4]) if len(sys.argv) > 4 else 2
|
|
||||||
|
|
||||||
test_chunked_prefill(num_gpu_blocks, input_len, output_len, num_prefetch_blocks)
|
|
||||||
@@ -1,169 +0,0 @@
|
|||||||
"""Tests for Triton gathered copy kernels."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from nanovllm.kvcache.kernels import gathered_copy, gathered_copy_kv
|
|
||||||
|
|
||||||
|
|
||||||
class TestGatheredCopy:
|
|
||||||
"""Tests for gathered copy kernel."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def setup_tensors(self):
|
|
||||||
"""Create test tensors."""
|
|
||||||
torch.cuda.manual_seed(42)
|
|
||||||
num_src_blocks = 16
|
|
||||||
num_dst_blocks = 8
|
|
||||||
block_size = 256
|
|
||||||
kv_dim = 64
|
|
||||||
|
|
||||||
src = torch.randn(num_src_blocks, block_size, kv_dim,
|
|
||||||
dtype=torch.float16, device="cuda")
|
|
||||||
dst = torch.zeros(num_dst_blocks, block_size, kv_dim,
|
|
||||||
dtype=torch.float16, device="cuda")
|
|
||||||
|
|
||||||
# Indices: dst[i] = src[indices[i]]
|
|
||||||
indices = torch.randint(0, num_src_blocks, (num_dst_blocks,),
|
|
||||||
dtype=torch.int64, device="cuda")
|
|
||||||
|
|
||||||
return src, dst, indices
|
|
||||||
|
|
||||||
def test_basic_copy(self, setup_tensors):
|
|
||||||
"""Test basic gathered copy."""
|
|
||||||
src, dst, indices = setup_tensors
|
|
||||||
|
|
||||||
gathered_copy(src, dst, indices)
|
|
||||||
|
|
||||||
# Verify copy
|
|
||||||
for i in range(len(indices)):
|
|
||||||
src_idx = indices[i].item()
|
|
||||||
assert torch.allclose(dst[i], src[src_idx]), f"Mismatch at index {i}"
|
|
||||||
|
|
||||||
def test_skip_negative_indices(self, setup_tensors):
|
|
||||||
"""Test that negative indices are skipped."""
|
|
||||||
src, dst, indices = setup_tensors
|
|
||||||
|
|
||||||
# Set some indices to -1
|
|
||||||
indices[2] = -1
|
|
||||||
indices[5] = -1
|
|
||||||
|
|
||||||
# Fill dst with a known value
|
|
||||||
dst.fill_(999.0)
|
|
||||||
|
|
||||||
gathered_copy(src, dst, indices)
|
|
||||||
|
|
||||||
# Skipped slots should be unchanged
|
|
||||||
assert (dst[2] == 999.0).all()
|
|
||||||
assert (dst[5] == 999.0).all()
|
|
||||||
|
|
||||||
# Non-skipped slots should be copied
|
|
||||||
for i in [0, 1, 3, 4, 6, 7]:
|
|
||||||
src_idx = indices[i].item()
|
|
||||||
assert torch.allclose(dst[i], src[src_idx])
|
|
||||||
|
|
||||||
def test_single_block(self):
|
|
||||||
"""Test copying a single block."""
|
|
||||||
src = torch.randn(4, 256, 64, dtype=torch.float16, device="cuda")
|
|
||||||
dst = torch.zeros(1, 256, 64, dtype=torch.float16, device="cuda")
|
|
||||||
indices = torch.tensor([2], dtype=torch.int64, device="cuda")
|
|
||||||
|
|
||||||
gathered_copy(src, dst, indices)
|
|
||||||
|
|
||||||
assert torch.allclose(dst[0], src[2])
|
|
||||||
|
|
||||||
|
|
||||||
class TestGatheredCopyKV:
|
|
||||||
"""Tests for gathered K/V cache copy kernel."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def setup_kv_tensors(self):
|
|
||||||
"""Create K/V test tensors."""
|
|
||||||
torch.cuda.manual_seed(42)
|
|
||||||
num_src_blocks = 16
|
|
||||||
num_dst_blocks = 8
|
|
||||||
block_size = 256
|
|
||||||
num_kv_heads = 4
|
|
||||||
head_dim = 64
|
|
||||||
|
|
||||||
k_src = torch.randn(num_src_blocks, block_size, num_kv_heads, head_dim,
|
|
||||||
dtype=torch.float16, device="cuda")
|
|
||||||
v_src = torch.randn(num_src_blocks, block_size, num_kv_heads, head_dim,
|
|
||||||
dtype=torch.float16, device="cuda")
|
|
||||||
k_dst = torch.zeros(num_dst_blocks, block_size, num_kv_heads, head_dim,
|
|
||||||
dtype=torch.float16, device="cuda")
|
|
||||||
v_dst = torch.zeros(num_dst_blocks, block_size, num_kv_heads, head_dim,
|
|
||||||
dtype=torch.float16, device="cuda")
|
|
||||||
|
|
||||||
indices = torch.randint(0, num_src_blocks, (num_dst_blocks,),
|
|
||||||
dtype=torch.int64, device="cuda")
|
|
||||||
|
|
||||||
return k_src, v_src, k_dst, v_dst, indices
|
|
||||||
|
|
||||||
def test_kv_copy(self, setup_kv_tensors):
|
|
||||||
"""Test K/V gathered copy."""
|
|
||||||
k_src, v_src, k_dst, v_dst, indices = setup_kv_tensors
|
|
||||||
|
|
||||||
gathered_copy_kv(k_src, v_src, k_dst, v_dst, indices)
|
|
||||||
|
|
||||||
# Verify copy
|
|
||||||
for i in range(len(indices)):
|
|
||||||
src_idx = indices[i].item()
|
|
||||||
assert torch.allclose(k_dst[i], k_src[src_idx]), f"K mismatch at {i}"
|
|
||||||
assert torch.allclose(v_dst[i], v_src[src_idx]), f"V mismatch at {i}"
|
|
||||||
|
|
||||||
def test_kv_skip_negative(self, setup_kv_tensors):
|
|
||||||
"""Test that negative indices are skipped for K/V."""
|
|
||||||
k_src, v_src, k_dst, v_dst, indices = setup_kv_tensors
|
|
||||||
|
|
||||||
indices[0] = -1
|
|
||||||
k_dst.fill_(999.0)
|
|
||||||
v_dst.fill_(999.0)
|
|
||||||
|
|
||||||
gathered_copy_kv(k_src, v_src, k_dst, v_dst, indices)
|
|
||||||
|
|
||||||
assert (k_dst[0] == 999.0).all()
|
|
||||||
assert (v_dst[0] == 999.0).all()
|
|
||||||
|
|
||||||
|
|
||||||
class TestPerformance:
|
|
||||||
"""Performance benchmarks for gathered copy."""
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_blocks", [8, 32, 128])
|
|
||||||
def test_throughput(self, num_blocks):
|
|
||||||
"""Benchmark copy throughput."""
|
|
||||||
block_size = 256
|
|
||||||
kv_dim = 64
|
|
||||||
|
|
||||||
src = torch.randn(num_blocks * 2, block_size, kv_dim,
|
|
||||||
dtype=torch.float16, device="cuda")
|
|
||||||
dst = torch.zeros(num_blocks, block_size, kv_dim,
|
|
||||||
dtype=torch.float16, device="cuda")
|
|
||||||
indices = torch.arange(num_blocks, dtype=torch.int64, device="cuda")
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
for _ in range(10):
|
|
||||||
gathered_copy(src, dst, indices)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Benchmark
|
|
||||||
import time
|
|
||||||
start = time.perf_counter()
|
|
||||||
num_iters = 100
|
|
||||||
for _ in range(num_iters):
|
|
||||||
gathered_copy(src, dst, indices)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elapsed = time.perf_counter() - start
|
|
||||||
|
|
||||||
bytes_copied = num_blocks * block_size * kv_dim * 2 * num_iters # fp16
|
|
||||||
bandwidth_gbps = bytes_copied / elapsed / 1e9
|
|
||||||
|
|
||||||
print(f"\n{num_blocks} blocks: {bandwidth_gbps:.2f} GB/s")
|
|
||||||
|
|
||||||
# Should achieve reasonable bandwidth (lower threshold for small blocks due to kernel launch overhead)
|
|
||||||
min_bandwidth = 5 if num_blocks <= 16 else 10
|
|
||||||
assert bandwidth_gbps > min_bandwidth, f"Bandwidth too low: {bandwidth_gbps} GB/s"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pytest.main([__file__, "-v"])
|
|
||||||
@@ -1,175 +0,0 @@
|
|||||||
"""Tests for KV cache managers."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from nanovllm.engine.sequence import Sequence
|
|
||||||
from nanovllm.kvcache.gpu_manager import GPUOnlyManager
|
|
||||||
|
|
||||||
|
|
||||||
class MockSequence:
|
|
||||||
"""Mock sequence for testing block allocation."""
|
|
||||||
|
|
||||||
def __init__(self, token_ids: list[int], block_size: int = 256):
|
|
||||||
self._token_ids = token_ids
|
|
||||||
self._block_size = block_size
|
|
||||||
self.block_table: list[int] = []
|
|
||||||
self.num_cached_tokens = 0
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._token_ids)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_blocks(self) -> int:
|
|
||||||
return (len(self) + self._block_size - 1) // self._block_size
|
|
||||||
|
|
||||||
def block(self, i: int) -> list[int]:
|
|
||||||
start = i * self._block_size
|
|
||||||
end = min((i + 1) * self._block_size, len(self))
|
|
||||||
return self._token_ids[start:end]
|
|
||||||
|
|
||||||
|
|
||||||
class TestGPUOnlyManager:
|
|
||||||
"""Tests for GPU-only KV cache manager."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def manager(self):
|
|
||||||
"""Create a small manager for testing."""
|
|
||||||
return GPUOnlyManager(num_blocks=16, block_size=256)
|
|
||||||
|
|
||||||
def test_initialization(self, manager):
|
|
||||||
"""Test manager initialization."""
|
|
||||||
assert manager.block_size == 256
|
|
||||||
assert manager.num_free_blocks == 16
|
|
||||||
assert len(manager.blocks) == 16
|
|
||||||
|
|
||||||
def test_allocate_cache(self, manager):
|
|
||||||
"""Test cache allocation."""
|
|
||||||
manager.allocate_cache(
|
|
||||||
num_layers=4,
|
|
||||||
num_kv_heads=8,
|
|
||||||
head_dim=64,
|
|
||||||
dtype=torch.float16,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert manager.kv_cache is not None
|
|
||||||
assert manager.kv_cache.shape == (2, 4, 16, 256, 8, 64)
|
|
||||||
assert manager.kv_cache.device.type == "cuda"
|
|
||||||
|
|
||||||
def test_get_layer_cache(self, manager):
|
|
||||||
"""Test getting layer cache."""
|
|
||||||
manager.allocate_cache(
|
|
||||||
num_layers=4,
|
|
||||||
num_kv_heads=8,
|
|
||||||
head_dim=64,
|
|
||||||
dtype=torch.float16,
|
|
||||||
)
|
|
||||||
|
|
||||||
k_cache, v_cache = manager.get_layer_cache(0)
|
|
||||||
assert k_cache.shape == (16, 256, 8, 64)
|
|
||||||
assert v_cache.shape == (16, 256, 8, 64)
|
|
||||||
|
|
||||||
def test_can_allocate(self, manager):
|
|
||||||
"""Test allocation check."""
|
|
||||||
seq = MockSequence([0] * 300) # Needs 2 blocks
|
|
||||||
assert manager.can_allocate(seq)
|
|
||||||
|
|
||||||
# Fill up all blocks with unique tokens to avoid prefix caching
|
|
||||||
for i in range(8):
|
|
||||||
# Each sequence has unique tokens to prevent prefix cache hits
|
|
||||||
s = MockSequence([i * 1000 + j for j in range(300)])
|
|
||||||
manager.allocate(s)
|
|
||||||
|
|
||||||
# Now should not be able to allocate
|
|
||||||
new_seq = MockSequence([9999] * 300)
|
|
||||||
assert not manager.can_allocate(new_seq)
|
|
||||||
|
|
||||||
def test_allocate_and_deallocate(self, manager):
|
|
||||||
"""Test block allocation and deallocation."""
|
|
||||||
seq = MockSequence([0] * 600) # Needs 3 blocks
|
|
||||||
initial_free = manager.num_free_blocks
|
|
||||||
|
|
||||||
manager.allocate(seq)
|
|
||||||
assert len(seq.block_table) == 3
|
|
||||||
assert manager.num_free_blocks == initial_free - 3
|
|
||||||
|
|
||||||
manager.deallocate(seq)
|
|
||||||
assert len(seq.block_table) == 0
|
|
||||||
assert manager.num_free_blocks == initial_free
|
|
||||||
|
|
||||||
def test_can_append(self, manager):
|
|
||||||
"""Test append check."""
|
|
||||||
seq = MockSequence([0] * 256) # Exactly 1 block
|
|
||||||
manager.allocate(seq)
|
|
||||||
|
|
||||||
# Can append without new block (still in same block)
|
|
||||||
seq._token_ids = [0] * 257
|
|
||||||
assert manager.can_append(seq)
|
|
||||||
|
|
||||||
def test_prepare_for_attention_noop(self, manager):
|
|
||||||
"""Test that prepare_for_attention is a no-op for GPU-only."""
|
|
||||||
seq = MockSequence([0] * 100)
|
|
||||||
manager.allocate(seq)
|
|
||||||
|
|
||||||
# Should not raise
|
|
||||||
manager.prepare_for_attention([seq], is_prefill=True)
|
|
||||||
manager.prepare_for_attention([seq], is_prefill=False)
|
|
||||||
|
|
||||||
def test_get_gpu_block_tables(self, manager):
|
|
||||||
"""Test getting GPU block tables."""
|
|
||||||
seq1 = MockSequence([0] * 300)
|
|
||||||
seq2 = MockSequence([0] * 600)
|
|
||||||
|
|
||||||
manager.allocate(seq1)
|
|
||||||
manager.allocate(seq2)
|
|
||||||
|
|
||||||
tables = manager.get_gpu_block_tables([seq1, seq2])
|
|
||||||
|
|
||||||
assert len(tables) == 2
|
|
||||||
assert tables[0] == list(seq1.block_table)
|
|
||||||
assert tables[1] == list(seq2.block_table)
|
|
||||||
|
|
||||||
|
|
||||||
class TestGPUOnlyManagerPrefixCaching:
|
|
||||||
"""Tests for prefix caching in GPU-only manager."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def manager(self):
|
|
||||||
"""Create manager for testing."""
|
|
||||||
return GPUOnlyManager(num_blocks=32, block_size=256)
|
|
||||||
|
|
||||||
def test_prefix_cache_hit(self, manager):
|
|
||||||
"""Test that identical prefixes are cached."""
|
|
||||||
# Create two sequences with same prefix
|
|
||||||
tokens = list(range(512)) # 2 full blocks
|
|
||||||
seq1 = MockSequence(tokens)
|
|
||||||
seq2 = MockSequence(tokens)
|
|
||||||
|
|
||||||
manager.allocate(seq1)
|
|
||||||
initial_free = manager.num_free_blocks
|
|
||||||
|
|
||||||
manager.allocate(seq2)
|
|
||||||
|
|
||||||
# Second sequence should reuse cached blocks
|
|
||||||
assert seq2.num_cached_tokens >= 256 # At least first block cached
|
|
||||||
# Should use fewer new blocks
|
|
||||||
assert manager.num_free_blocks >= initial_free - 2
|
|
||||||
|
|
||||||
def test_prefix_cache_different_suffix(self, manager):
|
|
||||||
"""Test cache with same prefix but different suffix."""
|
|
||||||
prefix = list(range(256)) # 1 full block
|
|
||||||
|
|
||||||
seq1 = MockSequence(prefix + [1000, 1001])
|
|
||||||
seq2 = MockSequence(prefix + [2000, 2001])
|
|
||||||
|
|
||||||
manager.allocate(seq1)
|
|
||||||
manager.allocate(seq2)
|
|
||||||
|
|
||||||
# First block should be shared
|
|
||||||
assert seq1.block_table[0] == seq2.block_table[0]
|
|
||||||
# Second block should be different
|
|
||||||
assert seq1.block_table[1] != seq2.block_table[1]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pytest.main([__file__, "-v"])
|
|
||||||
@@ -1,196 +1,119 @@
|
|||||||
"""Tests for CPU-GPU offload engine."""
|
"""
|
||||||
|
Test script for OffloadEngine - CPU-GPU KV cache transfer engine.
|
||||||
|
|
||||||
|
Demonstrates: ring buffer, H2D/D2H transfers, CUDA events, KV access.
|
||||||
|
"""
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from nanovllm.kvcache.offload_engine import OffloadEngine
|
from nanovllm.kvcache.offload_engine import OffloadEngine
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Utility Functions
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
class TestOffloadEngine:
|
def verify(tensor: torch.Tensor, expected: float, name: str) -> None:
|
||||||
"""Tests for OffloadEngine."""
|
"""Verify tensor contains expected value."""
|
||||||
|
actual = tensor.mean().item()
|
||||||
|
assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}"
|
||||||
|
|
||||||
@pytest.fixture
|
# ============================================================
|
||||||
def engine(self):
|
# Configuration
|
||||||
"""Create a small engine for testing."""
|
# ============================================================
|
||||||
return OffloadEngine(
|
|
||||||
num_layers=2,
|
|
||||||
num_gpu_blocks=4,
|
|
||||||
num_cpu_blocks=8,
|
|
||||||
block_size=256,
|
|
||||||
num_kv_heads=4,
|
|
||||||
head_dim=64,
|
|
||||||
dtype=torch.float16,
|
|
||||||
num_streams=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_initialization(self, engine):
|
NUM_LAYERS = 4
|
||||||
"""Test engine initialization."""
|
NUM_GPU_BLOCKS = 8
|
||||||
# Check GPU cache shape
|
NUM_CPU_BLOCKS = 16
|
||||||
assert engine.k_cache_gpu.shape == (2, 4, 256, 4, 64)
|
BLOCK_SIZE = 64
|
||||||
assert engine.v_cache_gpu.shape == (2, 4, 256, 4, 64)
|
NUM_KV_HEADS = 4
|
||||||
|
HEAD_DIM = 32
|
||||||
|
|
||||||
# Check CPU cache shape
|
# ============================================================
|
||||||
assert engine.k_cache_cpu.shape == (2, 8, 256, 4, 64)
|
# Main Test Script
|
||||||
assert engine.v_cache_cpu.shape == (2, 8, 256, 4, 64)
|
# ============================================================
|
||||||
|
|
||||||
# Check pinned memory
|
# 1. Initialize
|
||||||
assert engine.k_cache_cpu.is_pinned()
|
engine = OffloadEngine(
|
||||||
assert engine.v_cache_cpu.is_pinned()
|
num_layers=NUM_LAYERS,
|
||||||
|
num_gpu_blocks=NUM_GPU_BLOCKS,
|
||||||
|
num_cpu_blocks=NUM_CPU_BLOCKS,
|
||||||
|
block_size=BLOCK_SIZE,
|
||||||
|
num_kv_heads=NUM_KV_HEADS,
|
||||||
|
head_dim=HEAD_DIM,
|
||||||
|
dtype=torch.float16,
|
||||||
|
)
|
||||||
|
|
||||||
# Check gather indices
|
# 2. Ring buffer slot management
|
||||||
assert engine.gather_indices_cpu.shape == (2, 4)
|
for chunk_idx in range(12):
|
||||||
assert engine.gather_indices_gpu.shape == (2, 4)
|
write_slot = engine.get_write_slot_for_prefill(chunk_idx)
|
||||||
|
load_slots = engine.get_load_slots_for_prefill(write_slot)
|
||||||
|
|
||||||
|
print("chunk idx", chunk_idx, "write slots:", write_slot, "load slots:", load_slots)
|
||||||
|
|
||||||
|
assert write_slot == chunk_idx % engine.num_ring_slots
|
||||||
|
assert write_slot not in load_slots
|
||||||
|
|
||||||
def test_get_layer_cache(self, engine):
|
assert engine.decode_slot == 0
|
||||||
"""Test getting layer cache."""
|
assert engine.get_load_slots_for_decode() == list(range(1, NUM_GPU_BLOCKS))
|
||||||
k, v = engine.get_layer_cache(0)
|
|
||||||
assert k.shape == (4, 256, 4, 64)
|
|
||||||
assert v.shape == (4, 256, 4, 64)
|
|
||||||
assert k.device.type == "cuda"
|
|
||||||
assert v.device.type == "cuda"
|
|
||||||
|
|
||||||
def test_prefetch_and_offload(self, engine):
|
# 3. Per-slot per-layer H2D transfer
|
||||||
"""Test async prefetch and offload."""
|
engine.k_cache_cpu[0, 0].fill_(42.0)
|
||||||
# Write some data to CPU block 0
|
engine.v_cache_cpu[0, 0].fill_(42.5)
|
||||||
engine.k_cache_cpu[0, 0].fill_(1.0)
|
|
||||||
engine.v_cache_cpu[0, 0].fill_(2.0)
|
|
||||||
|
|
||||||
# Prefetch to GPU block 2
|
engine.load_to_slot_layer(slot_idx=1, layer_id=0, cpu_block_id=0)
|
||||||
event = engine.prefetch_block_async(
|
engine.wait_slot_layer(slot_idx=1, layer_id=0)
|
||||||
layer_id=0,
|
|
||||||
cpu_block_id=0,
|
|
||||||
gpu_block_id=2,
|
|
||||||
)
|
|
||||||
event.synchronize()
|
|
||||||
|
|
||||||
# Verify data was copied (move GPU to CPU for comparison)
|
verify(engine.k_cache_gpu[0, 1], 42.0, "H2D K")
|
||||||
assert torch.allclose(engine.k_cache_gpu[0, 2].cpu(), engine.k_cache_cpu[0, 0])
|
verify(engine.v_cache_gpu[0, 1], 42.5, "H2D V")
|
||||||
assert torch.allclose(engine.v_cache_gpu[0, 2].cpu(), engine.v_cache_cpu[0, 0])
|
|
||||||
|
|
||||||
# Modify GPU data
|
# 4. Compute-done event (pipeline safety)
|
||||||
engine.k_cache_gpu[0, 2].fill_(3.0)
|
engine.record_slot_compute_done(slot_idx=1, layer_id=0)
|
||||||
engine.v_cache_gpu[0, 2].fill_(4.0)
|
|
||||||
|
|
||||||
# Offload to CPU block 5
|
engine.k_cache_cpu[0, 1].fill_(100.0)
|
||||||
event = engine.offload_block_async(
|
engine.v_cache_cpu[0, 1].fill_(100.5)
|
||||||
layer_id=0,
|
engine.load_to_slot_layer(slot_idx=1, layer_id=0, cpu_block_id=1)
|
||||||
gpu_block_id=2,
|
engine.wait_slot_layer(slot_idx=1, layer_id=0)
|
||||||
cpu_block_id=5,
|
|
||||||
)
|
|
||||||
event.synchronize()
|
|
||||||
|
|
||||||
# Verify data was copied
|
verify(engine.k_cache_gpu[0, 1], 100.0, "Reuse K")
|
||||||
assert torch.allclose(engine.k_cache_cpu[0, 5], engine.k_cache_gpu[0, 2].cpu())
|
verify(engine.v_cache_gpu[0, 1], 100.5, "Reuse V")
|
||||||
assert torch.allclose(engine.v_cache_cpu[0, 5], engine.v_cache_gpu[0, 2].cpu())
|
|
||||||
|
|
||||||
def test_update_gather_indices(self, engine):
|
# 5. D2H offload
|
||||||
"""Test updating gather indices."""
|
engine.k_cache_gpu[1, 2].fill_(77.0)
|
||||||
# Manually set CPU data
|
engine.v_cache_gpu[1, 2].fill_(77.5)
|
||||||
for i in range(8):
|
|
||||||
engine.k_cache_cpu[0, i].fill_(float(i))
|
|
||||||
engine.v_cache_cpu[0, i].fill_(float(i + 100))
|
|
||||||
|
|
||||||
# Update indices for layer 0: (cpu_block_id, gpu_slot)
|
engine.offload_slot_to_cpu(slot_idx=2, cpu_block_id=5)
|
||||||
mappings = [(2, 0), (5, 1), (1, 2), (7, 3)]
|
engine.wait_slot_offload(slot_idx=2)
|
||||||
engine.update_gather_indices(layer_id=0, mappings=mappings)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Verify indices were set
|
verify(engine.k_cache_cpu[1, 5], 77.0, "D2H K")
|
||||||
expected = torch.tensor([2, 5, 1, 7], dtype=torch.int64)
|
verify(engine.v_cache_cpu[1, 5], 77.5, "D2H V")
|
||||||
assert torch.equal(engine.gather_indices_cpu[0], expected)
|
|
||||||
|
|
||||||
def test_gathered_h2d_layer(self, engine):
|
# 6. KV access methods
|
||||||
"""Test gathered H2D copy for a layer."""
|
k, v = engine.get_kv_for_slot(slot_idx=1, layer_id=0)
|
||||||
# Set up CPU data with known values
|
assert k.shape == (1, BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM)
|
||||||
for i in range(8):
|
|
||||||
engine.k_cache_cpu[0, i].fill_(float(i))
|
|
||||||
engine.v_cache_cpu[0, i].fill_(float(i + 100))
|
|
||||||
|
|
||||||
# Set gather indices: (cpu_block_id, gpu_slot)
|
k, v = engine.get_kv_for_slots(layer_id=0, slot_indices=[0, 1, 2])
|
||||||
# GPU slot 0 gets CPU block 3, GPU slot 1 gets CPU block 0, etc.
|
assert k.shape == (1, 3 * BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM)
|
||||||
mappings = [(3, 0), (0, 1), (7, 2), (2, 3)]
|
|
||||||
engine.update_gather_indices(layer_id=0, mappings=mappings)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Execute gathered H2D
|
engine.k_cache_gpu[0, engine.decode_slot].fill_(33.0)
|
||||||
engine.gathered_h2d_layer(layer_id=0)
|
k, v = engine.get_kv_for_decode_slot_accumulated(layer_id=0, num_tokens=10)
|
||||||
torch.cuda.synchronize()
|
assert k.shape == (1, 10, NUM_KV_HEADS, HEAD_DIM)
|
||||||
|
verify(k, 33.0, "Decode slot K")
|
||||||
|
|
||||||
# Verify: GPU slot 0 should have CPU block 3's data
|
# 7. Batch transfer
|
||||||
assert torch.allclose(engine.k_cache_gpu[0, 0],
|
cpu_blocks = [2, 3, 4]
|
||||||
torch.full_like(engine.k_cache_gpu[0, 0], 3.0))
|
gpu_slots = [3, 4, 5]
|
||||||
# GPU slot 1 should have CPU block 0's data
|
for cpu_id in cpu_blocks:
|
||||||
assert torch.allclose(engine.k_cache_gpu[0, 1],
|
engine.k_cache_cpu[0, cpu_id].fill_(50.0 + cpu_id)
|
||||||
torch.full_like(engine.k_cache_gpu[0, 1], 0.0))
|
|
||||||
# GPU slot 2 should have CPU block 7's data
|
|
||||||
assert torch.allclose(engine.k_cache_gpu[0, 2],
|
|
||||||
torch.full_like(engine.k_cache_gpu[0, 2], 7.0))
|
|
||||||
# GPU slot 3 should have CPU block 2's data
|
|
||||||
assert torch.allclose(engine.k_cache_gpu[0, 3],
|
|
||||||
torch.full_like(engine.k_cache_gpu[0, 3], 2.0))
|
|
||||||
|
|
||||||
def test_multi_layer_independence(self, engine):
|
engine.load_cpu_blocks_to_gpu_slots(layer_id=0, cpu_block_ids=cpu_blocks, gpu_slot_ids=gpu_slots)
|
||||||
"""Test that layers are independent."""
|
|
||||||
# Set different data for each layer
|
|
||||||
engine.k_cache_cpu[0, 0].fill_(1.0)
|
|
||||||
engine.k_cache_cpu[1, 0].fill_(2.0)
|
|
||||||
|
|
||||||
# Prefetch layer 0
|
for cpu_id, gpu_slot in zip(cpu_blocks, gpu_slots):
|
||||||
event = engine.prefetch_block_async(0, 0, 0)
|
verify(engine.k_cache_gpu[0, gpu_slot], 50.0 + cpu_id, f"Batch slot {gpu_slot}")
|
||||||
event.synchronize()
|
|
||||||
|
|
||||||
# Verify only layer 0 was affected
|
# 8. Gather indices (CUDA graph compatible)
|
||||||
assert torch.allclose(engine.k_cache_gpu[0, 0],
|
engine.update_gather_indices(layer_id=0, mappings=[(0, 0), (1, 1), (2, 2)])
|
||||||
torch.full_like(engine.k_cache_gpu[0, 0], 1.0))
|
assert engine.gather_indices_gpu[0, :3].tolist() == [0, 1, 2]
|
||||||
# Layer 1 should be zeros (initial state)
|
|
||||||
assert not torch.allclose(engine.k_cache_gpu[1, 0],
|
|
||||||
torch.full_like(engine.k_cache_gpu[1, 0], 2.0))
|
|
||||||
|
|
||||||
|
engine.clear_gather_indices(layer_id=0)
|
||||||
|
assert engine.gather_indices_gpu[0, 0].item() == -1
|
||||||
|
|
||||||
class TestOffloadEngineFixedAddresses:
|
print("test_offload_engine: PASSED")
|
||||||
"""Tests verifying fixed address property for CUDA Graph compatibility."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def engine(self):
|
|
||||||
"""Create engine for address tests."""
|
|
||||||
return OffloadEngine(
|
|
||||||
num_layers=2,
|
|
||||||
num_gpu_blocks=4,
|
|
||||||
num_cpu_blocks=8,
|
|
||||||
block_size=256,
|
|
||||||
num_kv_heads=4,
|
|
||||||
head_dim=64,
|
|
||||||
dtype=torch.float16,
|
|
||||||
num_streams=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_gpu_cache_address_fixed(self, engine):
|
|
||||||
"""Verify GPU cache addresses don't change."""
|
|
||||||
k_ptr_before = engine.k_cache_gpu.data_ptr()
|
|
||||||
v_ptr_before = engine.v_cache_gpu.data_ptr()
|
|
||||||
|
|
||||||
# Perform some operations - mappings is List[(cpu_block_id, gpu_slot)]
|
|
||||||
mappings = [(0, 0), (1, 1), (2, 2), (3, 3)]
|
|
||||||
engine.update_gather_indices(0, mappings)
|
|
||||||
engine.gathered_h2d_layer(0)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Addresses should be the same
|
|
||||||
assert engine.k_cache_gpu.data_ptr() == k_ptr_before
|
|
||||||
assert engine.v_cache_gpu.data_ptr() == v_ptr_before
|
|
||||||
|
|
||||||
def test_gather_indices_gpu_address_fixed(self, engine):
|
|
||||||
"""Verify gather indices GPU tensor address doesn't change."""
|
|
||||||
ptr_before = engine.gather_indices_gpu.data_ptr()
|
|
||||||
|
|
||||||
# Update indices multiple times - mappings is List[(cpu_block_id, gpu_slot)]
|
|
||||||
mappings = [(0, 0), (1, 1), (2, 2), (3, 3)]
|
|
||||||
for _ in range(10):
|
|
||||||
engine.update_gather_indices(0, mappings)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
assert engine.gather_indices_gpu.data_ptr() == ptr_before
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pytest.main([__file__, "-v"])
|
|
||||||
|
|||||||
@@ -1,167 +0,0 @@
|
|||||||
"""Tests for eviction policies."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from nanovllm.kvcache.policies.lru_policy import LRUPolicy
|
|
||||||
from nanovllm.kvcache.policies.fifo_policy import FIFOPolicy
|
|
||||||
from nanovllm.kvcache.policies import get_policy
|
|
||||||
|
|
||||||
|
|
||||||
class TestLRUPolicy:
|
|
||||||
"""Tests for LRU eviction policy."""
|
|
||||||
|
|
||||||
def test_basic_eviction(self):
|
|
||||||
"""Test that LRU evicts least recently used block."""
|
|
||||||
policy = LRUPolicy()
|
|
||||||
|
|
||||||
# Allocate blocks 0, 1, 2 in order
|
|
||||||
policy.on_block_allocated(0, step=1)
|
|
||||||
policy.on_block_allocated(1, step=2)
|
|
||||||
policy.on_block_allocated(2, step=3)
|
|
||||||
|
|
||||||
# Access block 0 (makes it most recently used)
|
|
||||||
policy.on_block_access(0, step=4)
|
|
||||||
|
|
||||||
# Should evict block 1 (least recently used)
|
|
||||||
candidates = {0, 1, 2}
|
|
||||||
victim = policy.select_victim(candidates)
|
|
||||||
assert victim == 1, f"Expected block 1, got {victim}"
|
|
||||||
|
|
||||||
def test_access_updates_order(self):
|
|
||||||
"""Test that access updates LRU order."""
|
|
||||||
policy = LRUPolicy()
|
|
||||||
|
|
||||||
policy.on_block_allocated(0, step=1)
|
|
||||||
policy.on_block_allocated(1, step=2)
|
|
||||||
policy.on_block_allocated(2, step=3)
|
|
||||||
|
|
||||||
# Access all in reverse order
|
|
||||||
policy.on_block_access(2, step=4)
|
|
||||||
policy.on_block_access(1, step=5)
|
|
||||||
policy.on_block_access(0, step=6)
|
|
||||||
|
|
||||||
# Block 2 is now LRU (accessed earliest after allocation update)
|
|
||||||
candidates = {0, 1, 2}
|
|
||||||
victim = policy.select_victim(candidates)
|
|
||||||
assert victim == 2, f"Expected block 2, got {victim}"
|
|
||||||
|
|
||||||
def test_eviction_removes_from_tracking(self):
|
|
||||||
"""Test that evicted blocks are removed from tracking."""
|
|
||||||
policy = LRUPolicy()
|
|
||||||
|
|
||||||
policy.on_block_allocated(0, step=1)
|
|
||||||
policy.on_block_allocated(1, step=2)
|
|
||||||
|
|
||||||
policy.on_block_evicted(0)
|
|
||||||
|
|
||||||
# Only block 1 should be a candidate
|
|
||||||
candidates = {0, 1}
|
|
||||||
victim = policy.select_victim(candidates)
|
|
||||||
assert victim == 1, "Should select block 1 since 0 was evicted"
|
|
||||||
|
|
||||||
def test_batch_eviction_order(self):
|
|
||||||
"""Test get_eviction_order returns blocks in LRU order."""
|
|
||||||
policy = LRUPolicy()
|
|
||||||
|
|
||||||
for i in range(5):
|
|
||||||
policy.on_block_allocated(i, step=i)
|
|
||||||
|
|
||||||
# Access blocks 2 and 4
|
|
||||||
policy.on_block_access(2, step=10)
|
|
||||||
policy.on_block_access(4, step=11)
|
|
||||||
|
|
||||||
candidates = {0, 1, 2, 3, 4}
|
|
||||||
order = policy.get_eviction_order(candidates, count=3)
|
|
||||||
|
|
||||||
# Should be 0, 1, 3 (in that order, skipping 2 and 4 until needed)
|
|
||||||
assert order == [0, 1, 3], f"Expected [0, 1, 3], got {order}"
|
|
||||||
|
|
||||||
|
|
||||||
class TestFIFOPolicy:
|
|
||||||
"""Tests for FIFO eviction policy."""
|
|
||||||
|
|
||||||
def test_basic_eviction(self):
|
|
||||||
"""Test that FIFO evicts oldest allocated block."""
|
|
||||||
policy = FIFOPolicy()
|
|
||||||
|
|
||||||
policy.on_block_allocated(0, step=1)
|
|
||||||
policy.on_block_allocated(1, step=2)
|
|
||||||
policy.on_block_allocated(2, step=3)
|
|
||||||
|
|
||||||
# Access doesn't change FIFO order
|
|
||||||
policy.on_block_access(0, step=4)
|
|
||||||
|
|
||||||
candidates = {0, 1, 2}
|
|
||||||
victim = policy.select_victim(candidates)
|
|
||||||
assert victim == 0, f"Expected block 0 (oldest), got {victim}"
|
|
||||||
|
|
||||||
def test_access_does_not_update_order(self):
|
|
||||||
"""Test that FIFO ignores access patterns."""
|
|
||||||
policy = FIFOPolicy()
|
|
||||||
|
|
||||||
policy.on_block_allocated(0, step=1)
|
|
||||||
policy.on_block_allocated(1, step=2)
|
|
||||||
policy.on_block_allocated(2, step=3)
|
|
||||||
|
|
||||||
# Multiple accesses to block 0
|
|
||||||
for i in range(10):
|
|
||||||
policy.on_block_access(0, step=10 + i)
|
|
||||||
|
|
||||||
# Block 0 should still be evicted first (FIFO order)
|
|
||||||
candidates = {0, 1, 2}
|
|
||||||
victim = policy.select_victim(candidates)
|
|
||||||
assert victim == 0, f"Expected block 0, got {victim}"
|
|
||||||
|
|
||||||
def test_prefetch_resets_order(self):
|
|
||||||
"""Test that prefetch moves block to end of queue."""
|
|
||||||
policy = FIFOPolicy()
|
|
||||||
|
|
||||||
policy.on_block_allocated(0, step=1)
|
|
||||||
policy.on_block_allocated(1, step=2)
|
|
||||||
policy.on_block_allocated(2, step=3)
|
|
||||||
|
|
||||||
# Prefetch block 0 (moves to end)
|
|
||||||
policy.on_block_prefetched(0, step=4)
|
|
||||||
|
|
||||||
candidates = {0, 1, 2}
|
|
||||||
victim = policy.select_victim(candidates)
|
|
||||||
assert victim == 1, f"Expected block 1 (now oldest), got {victim}"
|
|
||||||
|
|
||||||
def test_batch_eviction_order(self):
|
|
||||||
"""Test get_eviction_order returns blocks in FIFO order."""
|
|
||||||
policy = FIFOPolicy()
|
|
||||||
|
|
||||||
for i in range(5):
|
|
||||||
policy.on_block_allocated(i, step=i)
|
|
||||||
|
|
||||||
candidates = {0, 1, 2, 3, 4}
|
|
||||||
order = policy.get_eviction_order(candidates, count=3)
|
|
||||||
|
|
||||||
assert order == [0, 1, 2], f"Expected [0, 1, 2], got {order}"
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetPolicy:
|
|
||||||
"""Tests for policy factory function."""
|
|
||||||
|
|
||||||
def test_get_lru(self):
|
|
||||||
"""Test getting LRU policy by name."""
|
|
||||||
policy = get_policy("lru")
|
|
||||||
assert isinstance(policy, LRUPolicy)
|
|
||||||
|
|
||||||
def test_get_fifo(self):
|
|
||||||
"""Test getting FIFO policy by name."""
|
|
||||||
policy = get_policy("fifo")
|
|
||||||
assert isinstance(policy, FIFOPolicy)
|
|
||||||
|
|
||||||
def test_get_by_class_path(self):
|
|
||||||
"""Test getting policy by full class path."""
|
|
||||||
policy = get_policy("nanovllm.kvcache.policies.lru_policy.LRUPolicy")
|
|
||||||
assert isinstance(policy, LRUPolicy)
|
|
||||||
|
|
||||||
def test_invalid_policy_name(self):
|
|
||||||
"""Test that invalid policy name raises error."""
|
|
||||||
with pytest.raises((ValueError, ImportError)):
|
|
||||||
get_policy("invalid_policy")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pytest.main([__file__, "-v"])
|
|
||||||
@@ -1,252 +0,0 @@
|
|||||||
"""
|
|
||||||
Test sparse attention policies.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
CUDA_VISIBLE_DEVICES=4,5 python tests/test_sparse_policy.py [policy_name]
|
|
||||||
|
|
||||||
Policy names: full, vertical_slash, streaming_llm, quest
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
# Test the sparse policy implementations
|
|
||||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
|
||||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
|
||||||
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
|
|
||||||
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
|
|
||||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
|
||||||
|
|
||||||
|
|
||||||
def test_full_attention_policy():
|
|
||||||
"""Test FullAttentionPolicy returns all blocks."""
|
|
||||||
print("\n=== Testing FullAttentionPolicy ===")
|
|
||||||
policy = FullAttentionPolicy()
|
|
||||||
|
|
||||||
available_blocks = list(range(10))
|
|
||||||
ctx = PolicyContext(
|
|
||||||
query_chunk_idx=5,
|
|
||||||
num_query_chunks=10,
|
|
||||||
layer_id=0,
|
|
||||||
query=None,
|
|
||||||
is_prefill=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
selected = policy.select_blocks(available_blocks, ctx)
|
|
||||||
assert selected == available_blocks, f"Expected all blocks, got {selected}"
|
|
||||||
print(f" Prefill: input={available_blocks}, selected={selected} [PASS]")
|
|
||||||
|
|
||||||
# Test decode
|
|
||||||
ctx.is_prefill = False
|
|
||||||
selected = policy.select_blocks(available_blocks, ctx)
|
|
||||||
assert selected == available_blocks, f"Expected all blocks, got {selected}"
|
|
||||||
print(f" Decode: input={available_blocks}, selected={selected} [PASS]")
|
|
||||||
|
|
||||||
|
|
||||||
def test_vertical_slash_policy():
|
|
||||||
"""Test VerticalSlashPolicy selects sink + local window."""
|
|
||||||
print("\n=== Testing VerticalSlashPolicy ===")
|
|
||||||
config = VerticalSlashConfig(
|
|
||||||
num_sink_blocks=2,
|
|
||||||
local_window_blocks=3,
|
|
||||||
threshold_blocks=4,
|
|
||||||
)
|
|
||||||
policy = VerticalSlashPolicy(config)
|
|
||||||
|
|
||||||
# Test with 10 blocks, chunk 7 (should select sink[0,1] + local[4,5,6])
|
|
||||||
available_blocks = list(range(10))
|
|
||||||
ctx = PolicyContext(
|
|
||||||
query_chunk_idx=7,
|
|
||||||
num_query_chunks=10,
|
|
||||||
layer_id=0,
|
|
||||||
query=None,
|
|
||||||
is_prefill=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
selected = policy.select_blocks(available_blocks, ctx)
|
|
||||||
expected = [0, 1, 4, 5, 6] # sink + local window before chunk 7
|
|
||||||
assert selected == expected, f"Expected {expected}, got {selected}"
|
|
||||||
print(f" Prefill chunk 7: input={available_blocks}, selected={selected} [PASS]")
|
|
||||||
|
|
||||||
# Test with small number of blocks (below threshold)
|
|
||||||
available_blocks = [0, 1, 2]
|
|
||||||
selected = policy.select_blocks(available_blocks, ctx)
|
|
||||||
assert selected == [0, 1, 2], f"Expected all blocks for small input, got {selected}"
|
|
||||||
print(f" Below threshold: input={[0,1,2]}, selected={selected} [PASS]")
|
|
||||||
|
|
||||||
# Test decode (local window is last M blocks)
|
|
||||||
available_blocks = list(range(10))
|
|
||||||
ctx.is_prefill = False
|
|
||||||
selected = policy.select_blocks(available_blocks, ctx)
|
|
||||||
expected = [0, 1, 7, 8, 9] # sink + last 3 blocks
|
|
||||||
assert selected == expected, f"Expected {expected}, got {selected}"
|
|
||||||
print(f" Decode: input={available_blocks}, selected={selected} [PASS]")
|
|
||||||
|
|
||||||
|
|
||||||
def test_streaming_llm_policy():
|
|
||||||
"""Test StreamingLLMPolicy selects sink + recent only."""
|
|
||||||
print("\n=== Testing StreamingLLMPolicy ===")
|
|
||||||
config = StreamingLLMConfig(
|
|
||||||
num_sink_blocks=1,
|
|
||||||
num_recent_blocks=2,
|
|
||||||
)
|
|
||||||
policy = StreamingLLMPolicy(config)
|
|
||||||
|
|
||||||
available_blocks = list(range(10))
|
|
||||||
ctx = PolicyContext(
|
|
||||||
query_chunk_idx=0,
|
|
||||||
num_query_chunks=1,
|
|
||||||
layer_id=0,
|
|
||||||
query=None,
|
|
||||||
is_prefill=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
selected = policy.select_blocks(available_blocks, ctx)
|
|
||||||
expected = [0, 8, 9] # sink[0] + recent[8,9]
|
|
||||||
assert selected == expected, f"Expected {expected}, got {selected}"
|
|
||||||
print(f" 10 blocks: selected={selected} [PASS]")
|
|
||||||
|
|
||||||
# Test with 3 blocks (all fit in sink+recent)
|
|
||||||
available_blocks = [0, 1, 2]
|
|
||||||
selected = policy.select_blocks(available_blocks, ctx)
|
|
||||||
assert selected == [0, 1, 2], f"Expected all blocks, got {selected}"
|
|
||||||
print(f" 3 blocks: selected={selected} [PASS]")
|
|
||||||
|
|
||||||
|
|
||||||
def test_quest_policy():
|
|
||||||
"""Test QuestPolicy with mock metadata."""
|
|
||||||
print("\n=== Testing QuestPolicy ===")
|
|
||||||
|
|
||||||
# Create metadata manager
|
|
||||||
num_blocks = 10
|
|
||||||
num_layers = 2
|
|
||||||
num_kv_heads = 4
|
|
||||||
head_dim = 64
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
metadata = BlockMetadataManager(
|
|
||||||
num_blocks=num_blocks,
|
|
||||||
num_layers=num_layers,
|
|
||||||
num_kv_heads=num_kv_heads,
|
|
||||||
head_dim=head_dim,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Simulate offloading blocks with different key patterns
|
|
||||||
# Blocks 0, 5, 9 will have high scores (keys aligned with query)
|
|
||||||
for block_id in range(num_blocks):
|
|
||||||
for layer_id in range(num_layers):
|
|
||||||
k_cache = torch.randn(100, num_kv_heads, head_dim) # 100 tokens per block
|
|
||||||
if block_id in [0, 5, 9]:
|
|
||||||
# Make these blocks have keys that score high
|
|
||||||
k_cache = k_cache.abs() # All positive
|
|
||||||
else:
|
|
||||||
k_cache = -k_cache.abs() # All negative
|
|
||||||
metadata.update_metadata(block_id, layer_id, k_cache, 100)
|
|
||||||
|
|
||||||
config = QuestConfig(
|
|
||||||
topk_blocks=4,
|
|
||||||
threshold_blocks=3,
|
|
||||||
)
|
|
||||||
policy = QuestPolicy(config, metadata)
|
|
||||||
|
|
||||||
available_blocks = list(range(10))
|
|
||||||
|
|
||||||
# Create query that scores high with positive keys
|
|
||||||
query = torch.ones(1, num_kv_heads, head_dim, device='cuda')
|
|
||||||
|
|
||||||
ctx = PolicyContext(
|
|
||||||
query_chunk_idx=0,
|
|
||||||
num_query_chunks=1,
|
|
||||||
layer_id=0,
|
|
||||||
query=query,
|
|
||||||
is_prefill=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
selected = policy.select_blocks(available_blocks, ctx)
|
|
||||||
print(f" Top-4 selection: input={available_blocks}, selected={selected}")
|
|
||||||
|
|
||||||
# High-scoring blocks [0, 5, 9] should be in selection
|
|
||||||
for expected_block in [0, 5, 9]:
|
|
||||||
assert expected_block in selected, f"Expected block {expected_block} in selection"
|
|
||||||
print(f" High-score blocks [0, 5, 9] in selection [PASS]")
|
|
||||||
|
|
||||||
# Test below threshold (should return all)
|
|
||||||
available_blocks = [0, 1, 2]
|
|
||||||
selected = policy.select_blocks(available_blocks, ctx)
|
|
||||||
assert selected == [0, 1, 2], f"Expected all blocks below threshold, got {selected}"
|
|
||||||
print(f" Below threshold: selected={selected} [PASS]")
|
|
||||||
|
|
||||||
# Test without query (should return all)
|
|
||||||
ctx.query = None
|
|
||||||
available_blocks = list(range(10))
|
|
||||||
selected = policy.select_blocks(available_blocks, ctx)
|
|
||||||
assert selected == available_blocks, f"Expected all blocks without query, got {selected}"
|
|
||||||
print(f" No query: selected all [PASS]")
|
|
||||||
|
|
||||||
|
|
||||||
def test_custom_policy():
|
|
||||||
"""Test creating a custom policy."""
|
|
||||||
print("\n=== Testing Custom Policy ===")
|
|
||||||
|
|
||||||
class EveryOtherPolicy(SparsePolicy):
|
|
||||||
"""Select every other block."""
|
|
||||||
|
|
||||||
def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
|
|
||||||
return [available_blocks[i] for i in range(0, len(available_blocks), 2)]
|
|
||||||
|
|
||||||
policy = EveryOtherPolicy()
|
|
||||||
available_blocks = list(range(10))
|
|
||||||
ctx = PolicyContext(
|
|
||||||
query_chunk_idx=0,
|
|
||||||
num_query_chunks=1,
|
|
||||||
layer_id=0,
|
|
||||||
query=None,
|
|
||||||
is_prefill=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
selected = policy.select_blocks(available_blocks, ctx)
|
|
||||||
expected = [0, 2, 4, 6, 8]
|
|
||||||
assert selected == expected, f"Expected {expected}, got {selected}"
|
|
||||||
print(f" Every other: input={available_blocks}, selected={selected} [PASS]")
|
|
||||||
|
|
||||||
|
|
||||||
def run_all_tests():
|
|
||||||
"""Run all policy tests."""
|
|
||||||
print("Running Sparse Policy Tests...")
|
|
||||||
|
|
||||||
test_full_attention_policy()
|
|
||||||
test_vertical_slash_policy()
|
|
||||||
test_streaming_llm_policy()
|
|
||||||
test_quest_policy()
|
|
||||||
test_custom_policy()
|
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("All tests passed!")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
if len(sys.argv) > 1:
|
|
||||||
policy_name = sys.argv[1].lower()
|
|
||||||
if policy_name == "full":
|
|
||||||
test_full_attention_policy()
|
|
||||||
elif policy_name == "vertical_slash":
|
|
||||||
test_vertical_slash_policy()
|
|
||||||
elif policy_name == "streaming_llm":
|
|
||||||
test_streaming_llm_policy()
|
|
||||||
elif policy_name == "quest":
|
|
||||||
test_quest_policy()
|
|
||||||
elif policy_name == "custom":
|
|
||||||
test_custom_policy()
|
|
||||||
else:
|
|
||||||
print(f"Unknown policy: {policy_name}")
|
|
||||||
print("Available: full, vertical_slash, streaming_llm, quest, custom")
|
|
||||||
sys.exit(1)
|
|
||||||
else:
|
|
||||||
run_all_tests()
|
|
||||||