"""
Test the Attention layer with KV-cache offload in isolation.

This test shows how to drive Attention + HybridKVCacheManager directly,
without requiring a full LLMEngine/ModelRunner setup.
"""
import time

import torch

from nanovllm.layers.attention import Attention
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.engine.sequence import Sequence
from nanovllm.utils.context import set_context, reset_context
# ============================================================
# Configuration
# ============================================================

# Model shape: every attention layer uses the same head layout.
NUM_LAYERS = 8  # several layers so profiles look like a real model
NUM_HEADS = 8
NUM_KV_HEADS = 8
HEAD_DIM = 64

# KV-cache geometry: for simplicity one prefill chunk fills exactly one block.
BLOCK_SIZE = 1024  # tokens per block
CHUNK_SIZE = BLOCK_SIZE  # tokens per chunk

# Cache capacity: a small GPU ring buffer plus a larger CPU block pool.
NUM_GPU_SLOTS = 4
NUM_CPU_BLOCKS = 16

DTYPE = torch.float16
DEVICE = "cuda"
# ============================================================
# Setup: Create Manager and Attention Layers
# ============================================================

def create_manager():
    """Build a HybridKVCacheManager sized by the module constants and allocate its cache.

    Returns:
        The manager, after ``allocate_cache`` has run for all NUM_LAYERS layers.
    """
    pool_config = dict(
        num_gpu_slots=NUM_GPU_SLOTS,
        num_cpu_blocks=NUM_CPU_BLOCKS,
        block_size=BLOCK_SIZE,
    )
    layer_config = dict(
        num_layers=NUM_LAYERS,
        num_kv_heads=NUM_KV_HEADS,
        head_dim=HEAD_DIM,
        dtype=DTYPE,
    )

    mgr = HybridKVCacheManager(**pool_config)
    # Initializes the offload engine, which creates the GPU and CPU
    # K/V cache tensors (k_cache_gpu/cpu, v_cache_gpu/cpu).
    mgr.allocate_cache(**layer_config)
    return mgr
def create_attention_layers(manager):
    """Build NUM_LAYERS Attention modules wired to the manager's per-layer KV cache.

    Args:
        manager: Manager whose ``get_layer_cache(layer_id)`` returns a
            ``(k_cache, v_cache)`` pair for each layer.

    Returns:
        List of Attention modules (moved to DEVICE), indexed by layer id.
    """
    def _build(layer_idx):
        # Construct one layer, tag it with its index and attach its cache tensors.
        layer = Attention(
            num_heads=NUM_HEADS,
            head_dim=HEAD_DIM,
            scale=HEAD_DIM ** -0.5,
            num_kv_heads=NUM_KV_HEADS,
        )
        layer.layer_id = layer_idx
        layer.k_cache, layer.v_cache = manager.get_layer_cache(layer_idx)
        return layer.to(DEVICE)

    return [_build(idx) for idx in range(NUM_LAYERS)]
def create_test_sequence(manager, num_chunks=3):
    """Create a dummy sequence of ``num_chunks * CHUNK_SIZE`` tokens and allocate its blocks.

    Args:
        manager: KV-cache manager whose ``allocate(seq)`` reserves the blocks.
        num_chunks: Number of CHUNK_SIZE-token chunks in the sequence.

    Returns:
        The allocated Sequence.
    """
    # The token values are irrelevant for this test; a simple ramp is enough.
    token_ids = list(range(num_chunks * CHUNK_SIZE))
    seq = Sequence(token_ids=token_ids)

    # Override block_size on this instance so it matches BLOCK_SIZE
    # (the Sequence default may differ; TODO confirm).
    seq.block_size = BLOCK_SIZE

    # In CPU-primary mode the manager is expected to place these blocks on CPU.
    manager.allocate(seq)
    return seq
# ============================================================
# Chunked Prefill Simulation
# ============================================================

def simulate_chunk_forward(
    layers,
    manager,
    seq,
    chunk_idx,
    chunk_size,
):
    """
    Simulate the forward pass of one prefill chunk through all layers.

    A random query tensor is fed to the first layer. Each layer also receives
    fresh random K/V for the chunk, and the context's slot mapping points
    those tokens at the GPU slot chosen for ``chunk_idx``. Afterwards that slot
    is offloaded to the chunk's CPU block and the logical block is recorded
    in ``manager.prefilled_blocks``.

    Args:
        layers: Attention layers with their KV cache already bound.
        manager: HybridKVCacheManager that owns ``offload_engine``.
        seq: Sequence whose ``block_table`` has an entry for ``chunk_idx``.
        chunk_idx: Index of the chunk (and of its block) within ``seq``.
        chunk_size: Number of tokens in this chunk. It must lie in
            ``[1, BLOCK_SIZE]`` because the whole chunk is mapped into a
            single GPU slot.

    Returns:
        output: Output of the last attention layer.

    Raises:
        ValueError: If ``chunk_size`` is not in ``[1, BLOCK_SIZE]``.
    """
    # The slot mapping below addresses one BLOCK_SIZE-token GPU slot. A larger
    # chunk would spill into the next slot's range (or past the cache end), and
    # an empty chunk would build degenerate attention metadata.
    if not 0 < chunk_size <= BLOCK_SIZE:
        raise ValueError(
            f"chunk_size must be in [1, {BLOCK_SIZE}], got {chunk_size}"
        )

    # Random query tensor fed to the first layer.
    hidden = torch.randn(chunk_size, NUM_HEADS, HEAD_DIM, dtype=DTYPE, device=DEVICE)

    # Map the chunk's tokens to consecutive cache positions inside its GPU slot.
    write_slot = manager.offload_engine.get_write_slot_for_prefill(chunk_idx)
    slot_mapping = (
        torch.arange(chunk_size, dtype=torch.long, device=DEVICE)
        + write_slot * BLOCK_SIZE
    )

    # Single-sequence varlen metadata: Q and K both span only this chunk.
    cu_seqlens = torch.tensor([0, chunk_size], dtype=torch.int32, device=DEVICE)

    # Publish the chunked-prefill context that the Attention layers read.
    set_context(
        is_prefill=True,
        is_chunked_prefill=True,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=chunk_size,
        max_seqlen_k=chunk_size,
        slot_mapping=slot_mapping,
        kvcache_manager=manager,
        chunked_seq=seq,
        current_chunk_idx=chunk_idx,
    )

    # Each layer consumes the previous layer's output as its query.
    output = hidden
    for layer in layers:
        k = torch.randn(chunk_size, NUM_KV_HEADS, HEAD_DIM, dtype=DTYPE, device=DEVICE)
        v = torch.randn(chunk_size, NUM_KV_HEADS, HEAD_DIM, dtype=DTYPE, device=DEVICE)
        output = layer.forward(output, k, v)

    # Offload the slot written by this chunk to the chunk's CPU block.
    logical_id = seq.block_table[chunk_idx]
    cpu_block_id = manager.logical_blocks[logical_id].cpu_block_id
    manager.offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)
    manager.prefilled_blocks.add(logical_id)

    return output
# ============================================================
# Main Test
# ============================================================

NUM_WARMUP_ITERS = 3
NUM_BENCH_ITERS = 10


def run_prefill_pass(manager, layers, num_chunks):
    """Allocate a fresh sequence and run chunked prefill over all of its chunks.

    The caller owns the returned sequence and must pass it to
    ``manager.deallocate`` exactly once.

    Args:
        manager: HybridKVCacheManager used for allocation and offload.
        layers: Attention layers bound to the manager's KV cache.
        num_chunks: Number of CHUNK_SIZE-token chunks to prefill.

    Returns:
        The allocated, fully prefilled Sequence.
    """
    # prefilled_blocks is manager-global; clear stale ids from earlier passes.
    manager.prefilled_blocks.clear()
    seq = create_test_sequence(manager, num_chunks=num_chunks)
    for chunk_idx in range(num_chunks):
        simulate_chunk_forward(layers, manager, seq, chunk_idx, CHUNK_SIZE)
    return seq


print("=" * 60)
print("Test: Attention Layer with KV Cache Offload")
print("=" * 60)

# 1. Setup
print("\n[1] Creating manager and attention layers...")
manager = create_manager()
layers = create_attention_layers(manager)
print(f" - Manager: {NUM_GPU_SLOTS} GPU slots, {NUM_CPU_BLOCKS} CPU blocks")
print(f" - Layers: {NUM_LAYERS} layers, {NUM_HEADS} heads, {HEAD_DIM} head_dim")
print(f" - OffloadEngine initialized: {manager.offload_engine is not None}")

# 2. Configuration
print("\n[2] Test configuration...")
NUM_CHUNKS = NUM_CPU_BLOCKS  # one chunk per CPU block, i.e. use the whole CPU pool
print(f" - Total tokens: {NUM_CHUNKS * CHUNK_SIZE}")
print(f" - Chunks: {NUM_CHUNKS}")

# 3. Warmup runs
print(f"\n[3] Warmup runs ({NUM_WARMUP_ITERS} iterations)...")
for warmup_iter in range(NUM_WARMUP_ITERS):
    seq = run_prefill_pass(manager, layers, NUM_CHUNKS)
    manager.deallocate(seq)
    print(f" - Warmup {warmup_iter + 1}/{NUM_WARMUP_ITERS} completed")

# 4. Benchmark runs
print(f"\n[4] Benchmark runs ({NUM_BENCH_ITERS} iterations)...")
iter_times_ms = []
for bench_iter in range(NUM_BENCH_ITERS):
    # GPU work is asynchronous: synchronize on both sides so the measured
    # window contains exactly this pass's queued device work.
    torch.cuda.synchronize()
    start = time.perf_counter()
    seq = run_prefill_pass(manager, layers, NUM_CHUNKS)
    torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) * 1e3
    iter_times_ms.append(elapsed_ms)

    manager.deallocate(seq)
    print(f" - Iteration {bench_iter + 1}/{NUM_BENCH_ITERS} completed ({elapsed_ms:.2f} ms)")

avg_ms = sum(iter_times_ms) / len(iter_times_ms)
print(f" - Avg: {avg_ms:.2f} ms/iter, min: {min(iter_times_ms):.2f} ms/iter")

# 5. Verify results on a live sequence. Every benchmark sequence has already
# been deallocated, so checking one of those would read freed state.
print("\n[5] Verifying ring buffer and offload...")
seq = run_prefill_pass(manager, layers, NUM_CHUNKS)

for chunk_idx in range(NUM_CHUNKS):
    expected_slot = chunk_idx % NUM_GPU_SLOTS
    actual_slot = manager.offload_engine.get_write_slot_for_prefill(chunk_idx)
    assert actual_slot == expected_slot, f"Chunk {chunk_idx}: expected slot {expected_slot}, got {actual_slot}"

# NOTE(review): this compares the manager's prefilled CPU blocks with the
# sequence's logical block table; confirm the two id spaces coincide in
# CPU-primary mode.
cpu_block_table = manager.get_prefilled_cpu_blocks(seq)
assert cpu_block_table == seq.block_table[:NUM_CHUNKS], "CPU block table mismatch"
print(" - Ring buffer cycling verified ✓")
print(" - CPU offload verified ✓")

# Cleanup: release the verification sequence exactly once, then clear the
# global attention context.
manager.deallocate(seq)
reset_context()

print("\n" + "=" * 60)
print("test_attention_offload: PASSED")
print("=" * 60)