# nano-vllm/tests/test_offload_correctness.py
"""
Correctness test for chunked attention with CPU offload.
Validates that the offload pipeline (CPU -> GPU transfer + chunked attention)
produces the same result as direct GPU computation.
Test scenario:
1. Generate Q, K, V data
2. Reference: Compute full causal attention on GPU
3. Offload: Store K, V in CPU cache, load via pipeline, compute chunked attention
4. Compare results
This test is designed to identify bugs in:
- CPU <-> GPU data transfer (sgDMA)
- Ring buffer slot management
- N-way pipeline ordering
- Triton merge kernel correctness
"""
import torch
from flash_attn.flash_attn_interface import flash_attn_func
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# ============================================================
# Configuration
# ============================================================
NUM_LAYERS = 4
NUM_HEADS = 8
NUM_KV_HEADS = 8
HEAD_DIM = 64
BLOCK_SIZE = 256 # Smaller for faster testing
DTYPE = torch.bfloat16
DEVICE = "cuda"
# ============================================================
# Reference Implementation (GPU only, no offload)
# ============================================================
def compute_reference_causal(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""
Compute reference causal attention using flash_attn_func.
Args:
q, k, v: [batch, seqlen, nheads, headdim]
Returns:
out: [batch, seqlen, nheads, headdim]
"""
return flash_attn_func(q, k, v, causal=True)
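# Optional sanity check (not run by default): since the chunks used below are a
# contiguous split of one sequence, the chunked reference should agree with full
# causal attention to within bf16 tolerance. flash_attn_func defaults to the same
# 1/sqrt(head_dim) softmax scale, so for example:
#   out_full = compute_reference_causal(q_full, k_full, v_full)
#   out_chunked = compute_reference_chunked(q_chunks, kv_chunks, HEAD_DIM ** -0.5)
#   assert (out_full - out_chunked).abs().max().item() < 1e-2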
def compute_reference_chunked(
    q_chunks: list,
    kv_chunks: list,
    scale: float,
) -> torch.Tensor:
    """
    Compute chunked prefill attention directly on GPU (no offload).
    This is the "gold standard" for chunked attention correctness.

    Args:
        q_chunks: List of [batch, chunk_size, nheads, headdim]
        kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim]
        scale: Softmax scale

    Returns:
        out: [batch, total_seqlen, nheads, headdim]
    """
    out_chunks = []
    for chunk_idx, q_chunk in enumerate(q_chunks):
        o_acc, lse_acc = None, None

        # Attend to all previous chunks (no causal mask)
        for i in range(chunk_idx):
            k_chunk, v_chunk = kv_chunks[i]
            chunk_o, chunk_lse = flash_attn_with_lse(
                q_chunk, k_chunk, v_chunk,
                softmax_scale=scale,
                causal=False,
            )
            if o_acc is None:
                o_acc, lse_acc = chunk_o, chunk_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)

        # Attend to current chunk (with causal mask)
        k_chunk, v_chunk = kv_chunks[chunk_idx]
        current_o, current_lse = flash_attn_with_lse(
            q_chunk, k_chunk, v_chunk,
            softmax_scale=scale,
            causal=True,
        )
        if o_acc is None:
            final_o = current_o
        else:
            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
        out_chunks.append(final_o)

    return torch.cat(out_chunks, dim=1)
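# For debugging the Triton merge kernel, a plain-PyTorch log-sum-exp merge can be
# useful. This is only a sketch, assuming the layouts used in this file: o is
# [batch, seqlen, nheads, headdim] and lse is [batch, nheads, seqlen] (flash-attn's
# softmax_lse layout); it is not the production merge_attention_outputs.
def _merge_attention_outputs_reference(o1, lse1, o2, lse2):
    lse = torch.logaddexp(lse1.float(), lse2.float())  # combined log-sum-exp, [batch, nheads, seqlen]
    w1 = torch.exp(lse1.float() - lse)                 # renormalization weight for the first partial output
    w2 = torch.exp(lse2.float() - lse)
    # Move weights to [batch, seqlen, nheads, 1] so they broadcast over headdim
    w1 = w1.transpose(1, 2).unsqueeze(-1)
    w2 = w2.transpose(1, 2).unsqueeze(-1)
    o = o1.float() * w1 + o2.float() * w2
    return o.to(o1.dtype), lse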
# ============================================================
# Offload Implementation
# ============================================================
def create_manager(num_gpu_slots: int, num_cpu_blocks: int):
"""Create HybridKVCacheManager with specified configuration."""
manager = HybridKVCacheManager(
num_gpu_slots=num_gpu_slots,
num_cpu_blocks=num_cpu_blocks,
block_size=BLOCK_SIZE,
)
manager.allocate_cache(
num_layers=NUM_LAYERS,
num_kv_heads=NUM_KV_HEADS,
head_dim=HEAD_DIM,
dtype=DTYPE,
)
return manager
def store_kv_to_cpu_cache(manager, kv_chunks: list, layer_id: int):
"""
Store K, V chunks to CPU cache.
Args:
manager: HybridKVCacheManager
kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim]
layer_id: Layer index
Returns:
cpu_block_ids: List of CPU block IDs
"""
offload_engine = manager.offload_engine
cpu_block_ids = []
for block_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
# k_chunk, v_chunk: [batch, chunk_size, nheads, headdim]
# CPU cache layout: [num_layers, num_blocks, block_size, nheads, headdim]
k_data = k_chunk.squeeze(0) # [chunk_size, nheads, headdim]
v_data = v_chunk.squeeze(0)
offload_engine.k_cache_cpu[layer_id, block_idx, :k_data.shape[0]].copy_(k_data)
offload_engine.v_cache_cpu[layer_id, block_idx, :v_data.shape[0]].copy_(v_data)
cpu_block_ids.append(block_idx)
return cpu_block_ids
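# NOTE: store_kv_to_cpu_cache() writes chunk i straight into CPU block i and
# returns block ids [0, 1, ..., len(kv_chunks) - 1], bypassing whatever block
# allocation HybridKVCacheManager performs in production. This is only valid
# here because create_manager() is always called with num_cpu_blocks equal to
# the number of chunks being stored.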
def compute_offload_chunked_single_layer(
    manager,
    q_chunks: list,
    cpu_block_ids: list,
    layer_id: int,
    scale: float,
) -> torch.Tensor:
    """
    Compute chunked attention for a single layer using the offload pipeline.
    This mimics the behavior of Attention._ring_buffer_pipeline_load().

    Args:
        manager: HybridKVCacheManager
        q_chunks: List of [batch, chunk_size, nheads, headdim]
        cpu_block_ids: List of CPU block IDs containing K, V data
        layer_id: Layer index
        scale: Softmax scale

    Returns:
        out: [batch, total_seqlen, nheads, headdim]
    """
    offload_engine = manager.offload_engine
    out_chunks = []
    for chunk_idx, q_chunk in enumerate(q_chunks):
        # CPU blocks to load: all blocks before the current chunk
        blocks_to_load = cpu_block_ids[:chunk_idx]

        # Get slots for this chunk
        write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx)
        load_slots = offload_engine.get_load_slots_for_prefill(write_slot)

        # Load and compute attention for previous chunks
        o_acc, lse_acc = None, None
        if len(blocks_to_load) > 0 and len(load_slots) > 0:
            o_acc, lse_acc = _pipeline_load_and_compute(
                offload_engine,
                q_chunk,
                blocks_to_load,
                load_slots,
                layer_id,
                scale,
            )

        # Current chunk's K, V (load from CPU to GPU slot)
        current_cpu_block = cpu_block_ids[chunk_idx]
        offload_engine.load_to_slot_layer(write_slot, layer_id, current_cpu_block)
        offload_engine.wait_slot_layer(write_slot, layer_id)
        current_k, current_v = offload_engine.get_kv_for_slot(write_slot, layer_id)

        # Compute attention with causal mask
        current_o, current_lse = flash_attn_with_lse(
            q_chunk, current_k, current_v,
            softmax_scale=scale,
            causal=True,
        )

        # Merge with the accumulated output from previous chunks
        if o_acc is None:
            final_o = current_o
        else:
            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
        out_chunks.append(final_o)

    return torch.cat(out_chunks, dim=1)
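# For example, with chunk_idx == 2 the calls above amount to:
#   - cpu_block_ids[:2] are streamed through load_slots by _pipeline_load_and_compute
#   - cpu_block_ids[2] is loaded into write_slot and attended to with causal=True
#   - the two partial results are merged via merge_attention_outputs
# (the slot numbering itself comes from get_write_slot_for_prefill /
# get_load_slots_for_prefill).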
def _pipeline_load_and_compute(
    offload_engine,
    q_chunk: torch.Tensor,
    cpu_block_table: list,
    load_slots: list,
    layer_id: int,
    scale: float,
):
    """
    Overlap CPU -> GPU block loads with attention computation.
    Mirrors the Attention._ring_buffer_pipeline_load() logic.
    """
    num_blocks = len(cpu_block_table)
    num_slots = len(load_slots)
    o_acc, lse_acc = None, None

    # Phase 1: Pre-load up to num_slots blocks
    num_preload = min(num_slots, num_blocks)
    for i in range(num_preload):
        offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])

    # Phase 2: Main loop
    compute_stream = offload_engine.compute_stream
    for block_idx in range(num_blocks):
        current_slot = load_slots[block_idx % num_slots]

        # Wait for the transfer into this slot
        offload_engine.wait_slot_layer(current_slot, layer_id)

        # Compute on the dedicated compute stream
        with torch.cuda.stream(compute_stream):
            prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, layer_id)
            prev_o, prev_lse = flash_attn_with_lse(
                q_chunk, prev_k, prev_v,
                softmax_scale=scale,
                causal=False,
            )
            offload_engine.record_slot_compute_done(current_slot, layer_id)

        # Start the next transfer into this slot
        next_block_idx = block_idx + num_slots
        if next_block_idx < num_blocks:
            offload_engine.load_to_slot_layer(
                current_slot, layer_id, cpu_block_table[next_block_idx]
            )

        # Merge partial results
        with torch.cuda.stream(compute_stream):
            if o_acc is None:
                o_acc, lse_acc = prev_o, prev_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

    # Sync the compute stream before handing results back to the default stream
    compute_stream.synchronize()
    return o_acc, lse_acc
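# Illustrative schedule for num_blocks=5, num_slots=2 (slots A and B), one layer:
#   preload: load block0 -> A, load block1 -> B
#   iter 0:  wait A, attend(q, block0 @ A), record done, load block2 -> A
#   iter 1:  wait B, attend(q, block1 @ B), record done, load block3 -> B
#   iter 2:  wait A, attend(q, block2 @ A), record done, load block4 -> A
#   iter 3:  wait B, attend(q, block3 @ B)
#   iter 4:  wait A, attend(q, block4 @ A)
# Each re-load of a slot is issued right after record_slot_compute_done() for that
# slot, so (assuming the engine orders the copy after that event) transfers overlap
# with attention running on compute_stream.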
# ============================================================
# Test Runner
# ============================================================
def run_correctness_test(
    num_chunks: int,
    num_gpu_slots: int,
    verbose: bool = True,
) -> tuple[bool, float, float]:
    """
    Run a single correctness test.

    Args:
        num_chunks: Number of chunks (= number of CPU blocks)
        num_gpu_slots: Number of GPU ring buffer slots
        verbose: Print detailed info

    Returns:
        (passed, max_diff, mean_diff)
    """
    torch.manual_seed(42)
    seqlen = num_chunks * BLOCK_SIZE
    scale = HEAD_DIM ** -0.5

    # Generate Q, K, V
    q_full = torch.randn(1, seqlen, NUM_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
    k_full = torch.randn(1, seqlen, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
    v_full = torch.randn(1, seqlen, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)

    # Split into chunks
    q_chunks = [q_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE] for i in range(num_chunks)]
    kv_chunks = [
        (k_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE],
         v_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE])
        for i in range(num_chunks)
    ]

    # Reference: chunked attention on GPU (no offload)
    out_ref = compute_reference_chunked(q_chunks, kv_chunks, scale)

    # Create manager with enough CPU blocks
    manager = create_manager(num_gpu_slots, num_chunks)

    # Test each layer
    all_passed = True
    max_diff_all = 0.0
    mean_diff_all = 0.0
    for layer_id in range(NUM_LAYERS):
        # Store K, V to CPU cache
        cpu_block_ids = store_kv_to_cpu_cache(manager, kv_chunks, layer_id)

        # Compute with offload
        out_offload = compute_offload_chunked_single_layer(
            manager, q_chunks, cpu_block_ids, layer_id, scale
        )

        # Compare
        diff = (out_ref - out_offload).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()
        max_diff_all = max(max_diff_all, max_diff)
        mean_diff_all = max(mean_diff_all, mean_diff)

        tol = 1e-2
        passed = max_diff < tol
        all_passed = all_passed and passed
        if verbose and not passed:
            print(f" Layer {layer_id}: FAIL max_diff={max_diff:.6f}")

    return all_passed, max_diff_all, mean_diff_all
# ============================================================
# Decode Phase Test
# ============================================================
def run_decode_correctness_test(
    num_prefill_chunks: int,
    num_gpu_slots: int,
    num_decode_steps: int = 4,
    verbose: bool = True,
) -> tuple[bool, float, float]:
    """
    Test decode-phase correctness with CPU offload.

    Simulates:
    1. Prefill: Store K, V for multiple chunks in the CPU cache
    2. Decode: Single-token queries against all prefilled K, V

    This tests the scenario in the needle test where decode reads all previous KV.
    """
    torch.manual_seed(42)
    scale = HEAD_DIM ** -0.5
    prefill_len = num_prefill_chunks * BLOCK_SIZE

    # Generate prefill K, V (stored in the CPU cache)
    k_prefill = torch.randn(1, prefill_len, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
    v_prefill = torch.randn(1, prefill_len, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)

    # Split into chunks for CPU storage
    kv_chunks = [
        (k_prefill[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE],
         v_prefill[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE])
        for i in range(num_prefill_chunks)
    ]

    # Create manager
    manager = create_manager(num_gpu_slots, num_prefill_chunks)
    offload_engine = manager.offload_engine

    all_passed = True
    max_diff_all = 0.0
    mean_diff_all = 0.0
    for layer_id in range(NUM_LAYERS):
        # Store prefilled K, V to the CPU cache
        cpu_block_ids = store_kv_to_cpu_cache(manager, kv_chunks, layer_id)

        for decode_step in range(num_decode_steps):
            # Decode query: single token
            q_decode = torch.randn(1, 1, NUM_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)

            # Reference: direct attention on GPU over all prefilled K, V in one call
            out_ref = flash_attn_func(
                q_decode,
                k_prefill,
                v_prefill,
                causal=False,  # Decode query can attend to all prefilled tokens
            )

            # Offload: load from CPU and compute
            load_slots = offload_engine.get_load_slots_for_prefill(0)  # Use all slots except decode slot
            if len(load_slots) == 0 or len(cpu_block_ids) == 0:
                # No previous chunks to load
                out_offload = out_ref  # Trivially equal
            else:
                o_acc, lse_acc = _pipeline_load_and_compute(
                    offload_engine,
                    q_decode,
                    cpu_block_ids,
                    load_slots,
                    layer_id,
                    scale,
                )
                out_offload = o_acc

            # Compare
            diff = (out_ref - out_offload).abs()
            max_diff = diff.max().item()
            mean_diff = diff.mean().item()
            max_diff_all = max(max_diff_all, max_diff)
            mean_diff_all = max(mean_diff_all, mean_diff)

            tol = 1e-2
            passed = max_diff < tol
            all_passed = all_passed and passed
            if verbose and not passed:
                print(f" Layer {layer_id} Step {decode_step}: FAIL max_diff={max_diff:.6f}")

    return all_passed, max_diff_all, mean_diff_all
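# Minimal pytest-style entry points (a sketch: assumes pytest and a CUDA device are
# available; running this file directly via the __main__ block below remains the
# primary way to exercise the full test matrix).
def test_prefill_offload_smoke():
    passed, max_diff, _ = run_correctness_test(num_chunks=4, num_gpu_slots=3, verbose=False)
    assert passed, f"prefill offload mismatch: max_diff={max_diff:.6f}"


def test_decode_offload_smoke():
    passed, max_diff, _ = run_decode_correctness_test(
        num_prefill_chunks=4, num_gpu_slots=3, num_decode_steps=2, verbose=False
    )
    assert passed, f"decode offload mismatch: max_diff={max_diff:.6f}"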
# ============================================================
# Main Test Script
# ============================================================
if __name__ == "__main__":
print("=" * 70)
print("Test: Offload Chunked Attention Correctness")
print("=" * 70)
print(f"Config: layers={NUM_LAYERS}, heads={NUM_HEADS}, kv_heads={NUM_KV_HEADS}, "
f"head_dim={HEAD_DIM}, block_size={BLOCK_SIZE}, dtype={DTYPE}")
print()
print("Comparing: Reference (GPU chunked) vs Offload (CPU->GPU pipeline)")
print()
# Test configurations: (num_chunks, num_gpu_slots)
TEST_CASES = [
# Basic tests
(2, 2), # Minimal: 2 chunks, 2 slots (no pipeline)
(2, 3), # 2 chunks, 3 slots (1-slot pipeline)
(4, 2), # 4 chunks, 2 slots (heavy slot reuse)
(4, 3), # 4 chunks, 3 slots
(4, 4), # 4 chunks, 4 slots
# Stress tests
(8, 3), # Many chunks, few slots
(8, 4), # Many chunks, moderate slots
(8, 6), # Many chunks, many slots (like bench_offload)
# Edge cases
(1, 2), # Single chunk
(3, 5), # Fewer chunks than slots
]
all_passed = True
results = []
for num_chunks, num_gpu_slots in TEST_CASES:
seqlen = num_chunks * BLOCK_SIZE
passed, max_diff, mean_diff = run_correctness_test(
num_chunks, num_gpu_slots, verbose=False
)
all_passed = all_passed and passed
status = "PASS" if passed else "FAIL"
results.append((num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff))
print(f"[{status}] chunks={num_chunks:2d} slots={num_gpu_slots:2d} "
f"seqlen={seqlen:5d} max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
print()
# ================================================================
# Part 2: Decode Phase Tests
# ================================================================
print("=" * 70)
print("Part 2: Decode Phase Correctness")
print("=" * 70)
print("Testing: Decode query (single token) against all prefilled K, V")
print()
DECODE_TEST_CASES = [
# (num_prefill_chunks, num_gpu_slots)
(2, 2),
(4, 3),
(4, 4),
(8, 4),
(8, 6),
]
decode_results = []
for num_prefill_chunks, num_gpu_slots in DECODE_TEST_CASES:
prefill_len = num_prefill_chunks * BLOCK_SIZE
passed, max_diff, mean_diff = run_decode_correctness_test(
num_prefill_chunks, num_gpu_slots, num_decode_steps=4, verbose=False
)
all_passed = all_passed and passed
status = "PASS" if passed else "FAIL"
decode_results.append((num_prefill_chunks, num_gpu_slots, prefill_len, passed, max_diff, mean_diff))
print(f"[{status}] prefill_chunks={num_prefill_chunks:2d} slots={num_gpu_slots:2d} "
f"prefill_len={prefill_len:5d} max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
print()
print("=" * 70)
# Summary
prefill_passed = sum(1 for r in results if r[3])
decode_passed = sum(1 for r in decode_results if r[3])
total_tests = len(results) + len(decode_results)
total_passed = prefill_passed + decode_passed
print(f"Results: {total_passed}/{total_tests} tests passed")
print(f" - Prefill: {prefill_passed}/{len(results)}")
print(f" - Decode: {decode_passed}/{len(decode_results)}")
if not all_passed:
print("\nFailed tests:")
for num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff in results:
if not passed:
print(f" - [Prefill] chunks={num_chunks}, slots={num_gpu_slots}, "
f"seqlen={seqlen}, max_diff={max_diff:.6f}")
for num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff in decode_results:
if not passed:
print(f" - [Decode] prefill_chunks={num_chunks}, slots={num_gpu_slots}, "
f"prefill_len={seqlen}, max_diff={max_diff:.6f}")
print()
assert all_passed, "Some correctness tests failed!"
print("test_offload_correctness: PASSED")