# nano-vllm/tests/test_offload_engine.py
"""Tests for CPU-GPU offload engine."""
import pytest
import torch
from nanovllm.kvcache.offload_engine import OffloadEngine
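
# OffloadEngine allocates CUDA tensors (get_layer_cache asserts device "cuda"),
# so skip the whole module when no GPU is available.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="OffloadEngine tests require CUDA"
)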


class TestOffloadEngine:
    """Tests for OffloadEngine."""

    @pytest.fixture
    def engine(self):
        """Create a small engine for testing."""
        return OffloadEngine(
            num_layers=2,
            num_gpu_blocks=4,
            num_cpu_blocks=8,
            block_size=256,
            num_kv_heads=4,
            head_dim=64,
            dtype=torch.float16,
            num_streams=2,
        )
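
    # Shape convention assumed by the assertions below:
    # [num_layers, num_blocks, block_size, num_kv_heads, head_dim],
    # with separate K and V pools on the GPU and in pinned CPU memory.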

    def test_initialization(self, engine):
        """Test engine initialization."""
        # Check GPU cache shape
        assert engine.k_cache_gpu.shape == (2, 4, 256, 4, 64)
        assert engine.v_cache_gpu.shape == (2, 4, 256, 4, 64)
        # Check CPU cache shape
        assert engine.k_cache_cpu.shape == (2, 8, 256, 4, 64)
        assert engine.v_cache_cpu.shape == (2, 8, 256, 4, 64)
        # Check pinned memory
        assert engine.k_cache_cpu.is_pinned()
        assert engine.v_cache_cpu.is_pinned()
        # Check gather indices
        assert engine.gather_indices_cpu.shape == (2, 4)
        assert engine.gather_indices_gpu.shape == (2, 4)

    def test_get_layer_cache(self, engine):
        """Test getting layer cache."""
        k, v = engine.get_layer_cache(0)
        assert k.shape == (4, 256, 4, 64)
        assert v.shape == (4, 256, 4, 64)
        assert k.device.type == "cuda"
        assert v.device.type == "cuda"
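
    # Note: the fixed-address tests below suggest get_layer_cache returns
    # views into k_cache_gpu/v_cache_gpu rather than copies; we assume that
    # here but do not rely on it in this test.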

    def test_prefetch_and_offload(self, engine):
        """Test async prefetch and offload."""
        # Write some data to CPU block 0
        engine.k_cache_cpu[0, 0].fill_(1.0)
        engine.v_cache_cpu[0, 0].fill_(2.0)
        # Prefetch to GPU block 2
        event = engine.prefetch_block_async(
            layer_id=0,
            cpu_block_id=0,
            gpu_block_id=2,
        )
        event.synchronize()
        # Verify data was copied (move GPU to CPU for comparison)
        assert torch.allclose(engine.k_cache_gpu[0, 2].cpu(), engine.k_cache_cpu[0, 0])
        assert torch.allclose(engine.v_cache_gpu[0, 2].cpu(), engine.v_cache_cpu[0, 0])
        # Modify GPU data
        engine.k_cache_gpu[0, 2].fill_(3.0)
        engine.v_cache_gpu[0, 2].fill_(4.0)
        # Offload to CPU block 5
        event = engine.offload_block_async(
            layer_id=0,
            gpu_block_id=2,
            cpu_block_id=5,
        )
        event.synchronize()
        # Verify data was copied
        assert torch.allclose(engine.k_cache_cpu[0, 5], engine.k_cache_gpu[0, 2].cpu())
        assert torch.allclose(engine.v_cache_cpu[0, 5], engine.v_cache_gpu[0, 2].cpu())
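
    # Note on the pattern above: both *_async calls return a torch.cuda.Event;
    # event.synchronize() blocks the host until the copy that the event records
    # has finished, which is what makes the allclose checks safe. (We assume
    # the event is recorded on one of the engine's copy streams.)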

    def test_update_gather_indices(self, engine):
        """Test updating gather indices."""
        # Manually set CPU data
        for i in range(8):
            engine.k_cache_cpu[0, i].fill_(float(i))
            engine.v_cache_cpu[0, i].fill_(float(i + 100))
        # Update indices for layer 0: (cpu_block_id, gpu_slot)
        mappings = [(2, 0), (5, 1), (1, 2), (7, 3)]
        engine.update_gather_indices(layer_id=0, mappings=mappings)
        # The update appears to be asynchronous (the indices are mirrored to
        # gather_indices_gpu), so synchronize before inspecting either copy.
        torch.cuda.synchronize()
        # Verify indices were set
        expected = torch.tensor([2, 5, 1, 7], dtype=torch.int64)
        assert torch.equal(engine.gather_indices_cpu[0], expected)

    def test_gathered_h2d_layer(self, engine):
        """Test gathered H2D copy for a layer."""
        # Set up CPU data with known values
        for i in range(8):
            engine.k_cache_cpu[0, i].fill_(float(i))
            engine.v_cache_cpu[0, i].fill_(float(i + 100))
        # Set gather indices: (cpu_block_id, gpu_slot)
        # GPU slot 0 gets CPU block 3, GPU slot 1 gets CPU block 0, etc.
        mappings = [(3, 0), (0, 1), (7, 2), (2, 3)]
        engine.update_gather_indices(layer_id=0, mappings=mappings)
        torch.cuda.synchronize()
        # Execute gathered H2D
        engine.gathered_h2d_layer(layer_id=0)
        torch.cuda.synchronize()
        # Verify: GPU slot 0 should have CPU block 3's data
        assert torch.allclose(engine.k_cache_gpu[0, 0],
                              torch.full_like(engine.k_cache_gpu[0, 0], 3.0))
        # GPU slot 1 should have CPU block 0's data
        assert torch.allclose(engine.k_cache_gpu[0, 1],
                              torch.full_like(engine.k_cache_gpu[0, 1], 0.0))
        # GPU slot 2 should have CPU block 7's data
        assert torch.allclose(engine.k_cache_gpu[0, 2],
                              torch.full_like(engine.k_cache_gpu[0, 2], 7.0))
        # GPU slot 3 should have CPU block 2's data
        assert torch.allclose(engine.k_cache_gpu[0, 3],
                              torch.full_like(engine.k_cache_gpu[0, 3], 2.0))
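
    # Conceptually, gathered_h2d_layer(l) behaves like the sketch below
    # (semantics inferred from the assertions above, not necessarily the
    # engine's actual implementation):
    #   idx = engine.gather_indices_cpu[l]  # one CPU block id per GPU slot
    #   engine.k_cache_gpu[l].copy_(engine.k_cache_cpu[l, idx], non_blocking=True)
    #   engine.v_cache_gpu[l].copy_(engine.v_cache_cpu[l, idx], non_blocking=True)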

    def test_multi_layer_independence(self, engine):
        """Test that layers are independent."""
        # Set different data for each layer
        engine.k_cache_cpu[0, 0].fill_(1.0)
        engine.k_cache_cpu[1, 0].fill_(2.0)
        # Prefetch layer 0 only
        event = engine.prefetch_block_async(0, 0, 0)
        event.synchronize()
        # Verify layer 0 received the data...
        assert torch.allclose(engine.k_cache_gpu[0, 0],
                              torch.full_like(engine.k_cache_gpu[0, 0], 1.0))
        # ...while layer 1's GPU block was left untouched: it must not have
        # received layer 1's CPU data
        assert not torch.allclose(engine.k_cache_gpu[1, 0],
                                  torch.full_like(engine.k_cache_gpu[1, 0], 2.0))


class TestOffloadEngineFixedAddresses:
    """Tests verifying the fixed-address property for CUDA Graph compatibility."""

    @pytest.fixture
    def engine(self):
        """Create engine for address tests."""
        return OffloadEngine(
            num_layers=2,
            num_gpu_blocks=4,
            num_cpu_blocks=8,
            block_size=256,
            num_kv_heads=4,
            head_dim=64,
            dtype=torch.float16,
            num_streams=2,
        )

    def test_gpu_cache_address_fixed(self, engine):
        """Verify GPU cache addresses don't change."""
        k_ptr_before = engine.k_cache_gpu.data_ptr()
        v_ptr_before = engine.v_cache_gpu.data_ptr()
        # Perform some operations; mappings is a list of (cpu_block_id, gpu_slot) pairs
        mappings = [(0, 0), (1, 1), (2, 2), (3, 3)]
        engine.update_gather_indices(0, mappings)
        engine.gathered_h2d_layer(0)
        torch.cuda.synchronize()
        # Addresses should be the same
        assert engine.k_cache_gpu.data_ptr() == k_ptr_before
        assert engine.v_cache_gpu.data_ptr() == v_ptr_before

    def test_gather_indices_gpu_address_fixed(self, engine):
        """Verify gather indices GPU tensor address doesn't change."""
        ptr_before = engine.gather_indices_gpu.data_ptr()
        # Update indices multiple times; mappings is a list of (cpu_block_id, gpu_slot) pairs
        mappings = [(0, 0), (1, 1), (2, 2), (3, 3)]
        for _ in range(10):
            engine.update_gather_indices(0, mappings)
        torch.cuda.synchronize()
        assert engine.gather_indices_gpu.data_ptr() == ptr_before
if __name__ == "__main__":
pytest.main([__file__, "-v"])