"""
|
|
Correctness test for chunked attention with CPU offload.
|
|
|
|
Validates that the offload pipeline (CPU -> GPU transfer + chunked attention)
|
|
produces the same result as direct GPU computation.
|
|
|
|
Test scenario:
|
|
1. Generate Q, K, V data
|
|
2. Reference: Compute full causal attention on GPU
|
|
3. Offload: Store K, V in CPU cache, load via pipeline, compute chunked attention
|
|
4. Compare results
|
|
|
|
This test is designed to identify bugs in:
|
|
- CPU <-> GPU data transfer (sgDMA)
|
|
- Ring buffer slot management
|
|
- N-way pipeline ordering
|
|
- Triton merge kernel correctness
|
|
"""

import torch
from flash_attn.flash_attn_interface import flash_attn_func
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs


# ============================================================
# Configuration
# ============================================================

NUM_LAYERS = 4
NUM_HEADS = 8
NUM_KV_HEADS = 8
HEAD_DIM = 64
BLOCK_SIZE = 256  # Smaller for faster testing
DTYPE = torch.bfloat16
DEVICE = "cuda"


# ============================================================
# Reference Implementation (GPU only, no offload)
# ============================================================

def compute_reference_causal(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Compute reference causal attention using flash_attn_func.

    Args:
        q, k, v: [batch, seqlen, nheads, headdim]

    Returns:
        out: [batch, seqlen, nheads, headdim]
    """
    return flash_attn_func(q, k, v, causal=True)


def compute_reference_chunked(
    q_chunks: list,
    kv_chunks: list,
    scale: float,
) -> torch.Tensor:
    """
    Compute chunked prefill attention directly on GPU (no offload).

    This is the "gold standard" for chunked attention correctness.

    Args:
        q_chunks: List of [batch, chunk_size, nheads, headdim]
        kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim]
        scale: Softmax scale

    Returns:
        out: [batch, total_seqlen, nheads, headdim]
    """
    out_chunks = []

    for chunk_idx, q_chunk in enumerate(q_chunks):
        o_acc, lse_acc = None, None

        # Attend to all previous chunks (no causal mask)
        for i in range(chunk_idx):
            k_chunk, v_chunk = kv_chunks[i]
            chunk_o, chunk_lse = flash_attn_with_lse(
                q_chunk, k_chunk, v_chunk,
                softmax_scale=scale,
                causal=False,
            )
            if o_acc is None:
                o_acc, lse_acc = chunk_o, chunk_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)

        # Attend to current chunk (with causal mask)
        k_chunk, v_chunk = kv_chunks[chunk_idx]
        current_o, current_lse = flash_attn_with_lse(
            q_chunk, k_chunk, v_chunk,
            softmax_scale=scale,
            causal=True,
        )

        if o_acc is None:
            final_o = current_o
        else:
            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)

        out_chunks.append(final_o)

    return torch.cat(out_chunks, dim=1)
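

# For reference, a minimal pure-PyTorch sketch of the log-sum-exp merge that
# merge_attention_outputs is assumed to implement (shape convention assumed to
# follow flash_attn: o is [batch, seqlen, nheads, headdim], lse is
# [batch, nheads, seqlen]). It is not used by the tests below; it only
# documents the math the Triton merge kernel is expected to reproduce.
def _merge_reference_sketch(o1, lse1, o2, lse2):
    lse = torch.logaddexp(lse1, lse2)                         # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # [batch, seqlen, nheads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return (o1 * w1 + o2 * w2).to(o1.dtype), lse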


# ============================================================
# Offload Implementation
# ============================================================

def create_manager(num_gpu_slots: int, num_cpu_blocks: int):
    """Create HybridKVCacheManager with specified configuration."""
    manager = HybridKVCacheManager(
        num_gpu_slots=num_gpu_slots,
        num_cpu_blocks=num_cpu_blocks,
        block_size=BLOCK_SIZE,
    )
    manager.allocate_cache(
        num_layers=NUM_LAYERS,
        num_kv_heads=NUM_KV_HEADS,
        head_dim=HEAD_DIM,
        dtype=DTYPE,
    )
    return manager


def store_kv_to_cpu_cache(manager, kv_chunks: list, layer_id: int):
    """
    Store K, V chunks to CPU cache.

    Args:
        manager: HybridKVCacheManager
        kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim]
        layer_id: Layer index

    Returns:
        cpu_block_ids: List of CPU block IDs
    """
    offload_engine = manager.offload_engine
    cpu_block_ids = []

    for block_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
        # k_chunk, v_chunk: [batch, chunk_size, nheads, headdim]
        # CPU cache layout: [num_layers, num_blocks, block_size, nheads, headdim]
        k_data = k_chunk.squeeze(0)  # [chunk_size, nheads, headdim]
        v_data = v_chunk.squeeze(0)

        offload_engine.k_cache_cpu[layer_id, block_idx, :k_data.shape[0]].copy_(k_data)
        offload_engine.v_cache_cpu[layer_id, block_idx, :v_data.shape[0]].copy_(v_data)

        cpu_block_ids.append(block_idx)

    return cpu_block_ids
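

# Example: with NUM_LAYERS=4 and num_cpu_blocks=8, the layout comment above
# implies k_cache_cpu has shape [4, 8, 256, 8, 64] (likewise v_cache_cpu), so
# block_idx doubles as the CPU block ID returned to the caller. The exact shape
# is owned by the offload engine; this is only the layout the test assumes.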


def compute_offload_chunked_single_layer(
    manager,
    q_chunks: list,
    cpu_block_ids: list,
    layer_id: int,
    scale: float,
) -> torch.Tensor:
    """
    Compute chunked attention for a single layer using the offload pipeline.

    This mimics the behavior of Attention._ring_buffer_pipeline_load().

    Args:
        manager: HybridKVCacheManager
        q_chunks: List of [batch, chunk_size, nheads, headdim]
        cpu_block_ids: List of CPU block IDs containing K, V data
        layer_id: Layer index
        scale: Softmax scale

    Returns:
        out: [batch, total_seqlen, nheads, headdim]
    """
    offload_engine = manager.offload_engine
    out_chunks = []

    for chunk_idx, q_chunk in enumerate(q_chunks):
        # CPU blocks to load: all blocks before the current chunk
        blocks_to_load = cpu_block_ids[:chunk_idx]

        # Get slots for this chunk
        write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx)
        load_slots = offload_engine.get_load_slots_for_prefill(write_slot)

        # Load and compute attention for previous chunks
        o_acc, lse_acc = None, None

        if len(blocks_to_load) > 0 and len(load_slots) > 0:
            o_acc, lse_acc = _pipeline_load_and_compute(
                offload_engine,
                q_chunk,
                blocks_to_load,
                load_slots,
                layer_id,
                scale,
            )

        # Current chunk's K, V (load from CPU to GPU slot)
        current_cpu_block = cpu_block_ids[chunk_idx]
        offload_engine.load_to_slot_layer(write_slot, layer_id, current_cpu_block)
        offload_engine.wait_slot_layer(write_slot, layer_id)

        current_k, current_v = offload_engine.get_kv_for_slot(write_slot, layer_id)

        # Compute attention with causal mask
        current_o, current_lse = flash_attn_with_lse(
            q_chunk, current_k, current_v,
            softmax_scale=scale,
            causal=True,
        )

        # Merge
        if o_acc is None:
            final_o = current_o
        else:
            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)

        out_chunks.append(final_o)

    return torch.cat(out_chunks, dim=1)
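

# Note: get_load_slots_for_prefill(write_slot) is assumed to return every ring
# buffer slot except write_slot, so the current chunk's K, V (staged in
# write_slot) is never overwritten by the pipelined loads of earlier blocks.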


def _pipeline_load_and_compute(
    offload_engine,
    q_chunk: torch.Tensor,
    cpu_block_table: list,
    load_slots: list,
    layer_id: int,
    scale: float,
):
    """
    Pipeline loading from CPU and computing attention.

    Mirrors Attention._ring_buffer_pipeline_load() logic.
    """
    num_blocks = len(cpu_block_table)
    num_slots = len(load_slots)

    o_acc, lse_acc = None, None

    # Phase 1: pre-load up to num_slots blocks
    num_preload = min(num_slots, num_blocks)
    for i in range(num_preload):
        offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])

    # Phase 2: main loop
    compute_stream = offload_engine.compute_stream

    for block_idx in range(num_blocks):
        current_slot = load_slots[block_idx % num_slots]

        # Wait for the transfer into this slot to complete
        offload_engine.wait_slot_layer(current_slot, layer_id)

        # Compute on the dedicated compute stream
        with torch.cuda.stream(compute_stream):
            prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, layer_id)
            prev_o, prev_lse = flash_attn_with_lse(
                q_chunk, prev_k, prev_v,
                softmax_scale=scale,
                causal=False,
            )
            offload_engine.record_slot_compute_done(current_slot, layer_id)

        # Start the next transfer into this slot, if more blocks remain
        next_block_idx = block_idx + num_slots
        if next_block_idx < num_blocks:
            offload_engine.load_to_slot_layer(
                current_slot, layer_id, cpu_block_table[next_block_idx]
            )

        # Merge with the running accumulator
        with torch.cuda.stream(compute_stream):
            if o_acc is None:
                o_acc, lse_acc = prev_o, prev_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

    # Sync compute stream before returning results to the caller's stream
    compute_stream.synchronize()

    return o_acc, lse_acc
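

# Illustrative schedule for the loop above, assuming 4 CPU blocks and 2 load
# slots: block i lands in slot i % 2, and the H2D copy of block i+2 is issued
# as soon as the attention over block i has been enqueued on compute_stream,
# so transfers overlap with compute.
#
#   copies  : B0->S0, B1->S1 (preload), then B2->S0 after attn(B0), B3->S1 after attn(B1)
#   compute : attn(B0@S0) -> attn(B1@S1) -> attn(B2@S0) -> attn(B3@S1), merged via LSE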


# ============================================================
# Test Runner
# ============================================================

def run_correctness_test(
    num_chunks: int,
    num_gpu_slots: int,
    verbose: bool = True,
) -> tuple[bool, float, float]:
    """
    Run a single correctness test.

    Args:
        num_chunks: Number of chunks (= number of CPU blocks)
        num_gpu_slots: Number of GPU ring buffer slots
        verbose: Print detailed info

    Returns:
        (passed, max_diff, mean_diff)
    """
    torch.manual_seed(42)

    seqlen = num_chunks * BLOCK_SIZE
    scale = HEAD_DIM ** -0.5

    # Generate Q, K, V
    q_full = torch.randn(1, seqlen, NUM_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
    k_full = torch.randn(1, seqlen, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
    v_full = torch.randn(1, seqlen, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)

    # Split into chunks
    q_chunks = [q_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE] for i in range(num_chunks)]
    kv_chunks = [
        (k_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE],
         v_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE])
        for i in range(num_chunks)
    ]

    # Reference: chunked attention on GPU (no offload)
    out_ref = compute_reference_chunked(q_chunks, kv_chunks, scale)

    # Create manager with enough CPU blocks
    manager = create_manager(num_gpu_slots, num_chunks)

    # Test each layer
    all_passed = True
    max_diff_all = 0.0
    mean_diff_all = 0.0

    for layer_id in range(NUM_LAYERS):
        # Store K, V to CPU cache
        cpu_block_ids = store_kv_to_cpu_cache(manager, kv_chunks, layer_id)

        # Compute with offload
        out_offload = compute_offload_chunked_single_layer(
            manager, q_chunks, cpu_block_ids, layer_id, scale
        )

        # Compare
        diff = (out_ref - out_offload).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()

        max_diff_all = max(max_diff_all, max_diff)
        mean_diff_all = max(mean_diff_all, mean_diff)

        tol = 1e-2
        passed = max_diff < tol
        all_passed = all_passed and passed

        if verbose and not passed:
            print(f" Layer {layer_id}: FAIL max_diff={max_diff:.6f}")

    return all_passed, max_diff_all, mean_diff_all
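

# A single configuration can also be checked in isolation, e.g. from a REPL:
#   passed, max_diff, mean_diff = run_correctness_test(num_chunks=4, num_gpu_slots=2)
#   assert passed, f"offload mismatch: max_diff={max_diff}"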


# ============================================================
# Decode Phase Test
# ============================================================

def run_decode_correctness_test(
    num_prefill_chunks: int,
    num_gpu_slots: int,
    num_decode_steps: int = 4,
    verbose: bool = True,
) -> tuple[bool, float, float]:
    """
    Test decode phase correctness with CPU offload.

    Simulates:
    1. Prefill: store K, V for multiple chunks in the CPU cache
    2. Decode: single-token queries attend to all prefilled K, V

    This mirrors the needle test scenario, where every decode step reads all
    previously prefilled KV.
    """
    torch.manual_seed(42)

    scale = HEAD_DIM ** -0.5
    prefill_len = num_prefill_chunks * BLOCK_SIZE

    # Generate prefill K, V (to be stored in the CPU cache)
    k_prefill = torch.randn(1, prefill_len, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
    v_prefill = torch.randn(1, prefill_len, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)

    # Split into chunks for CPU storage
    kv_chunks = [
        (k_prefill[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE],
         v_prefill[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE])
        for i in range(num_prefill_chunks)
    ]

    # Create manager
    manager = create_manager(num_gpu_slots, num_prefill_chunks)
    offload_engine = manager.offload_engine

    all_passed = True
    max_diff_all = 0.0
    mean_diff_all = 0.0

    for layer_id in range(NUM_LAYERS):
        # Store prefilled K, V to CPU cache
        cpu_block_ids = store_kv_to_cpu_cache(manager, kv_chunks, layer_id)

        for decode_step in range(num_decode_steps):
            # Decode query: a single token
            q_decode = torch.randn(1, 1, NUM_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)

            # Reference: attend to the full prefilled K, V directly on GPU
            out_ref = flash_attn_func(
                q_decode,
                k_prefill,
                v_prefill,
                causal=False,  # the decode query may attend to all prefilled tokens
            )

            # Offload: load from CPU and compute.
            # Slot 0 plays the role of the decode/write slot; load through the rest.
            load_slots = offload_engine.get_load_slots_for_prefill(0)

            if len(load_slots) == 0 or len(cpu_block_ids) == 0:
                # No previous chunks to load
                out_offload = out_ref  # trivially equal
            else:
                o_acc, lse_acc = _pipeline_load_and_compute(
                    offload_engine,
                    q_decode,
                    cpu_block_ids,
                    load_slots,
                    layer_id,
                    scale,
                )
                out_offload = o_acc

            # Compare
            diff = (out_ref - out_offload).abs()
            max_diff = diff.max().item()
            mean_diff = diff.mean().item()

            max_diff_all = max(max_diff_all, max_diff)
            mean_diff_all = max(mean_diff_all, mean_diff)

            tol = 1e-2
            passed = max_diff < tol
            all_passed = all_passed and passed

            if verbose and not passed:
                print(f" Layer {layer_id} Step {decode_step}: FAIL max_diff={max_diff:.6f}")

    return all_passed, max_diff_all, mean_diff_all


# ============================================================
# Main Test Script
# ============================================================

if __name__ == "__main__":
    print("=" * 70)
    print("Test: Offload Chunked Attention Correctness")
    print("=" * 70)
    print(f"Config: layers={NUM_LAYERS}, heads={NUM_HEADS}, kv_heads={NUM_KV_HEADS}, "
          f"head_dim={HEAD_DIM}, block_size={BLOCK_SIZE}, dtype={DTYPE}")
    print()
    print("Comparing: Reference (GPU chunked) vs Offload (CPU->GPU pipeline)")
    print()

    # Test configurations: (num_chunks, num_gpu_slots)
    TEST_CASES = [
        # Basic tests
        (2, 2),  # Minimal: 2 chunks, 2 slots (no pipeline)
        (2, 3),  # 2 chunks, 3 slots (1-slot pipeline)
        (4, 2),  # 4 chunks, 2 slots (heavy slot reuse)
        (4, 3),  # 4 chunks, 3 slots
        (4, 4),  # 4 chunks, 4 slots
        # Stress tests
        (8, 3),  # Many chunks, few slots
        (8, 4),  # Many chunks, moderate slots
        (8, 6),  # Many chunks, many slots (like bench_offload)
        # Edge cases
        (1, 2),  # Single chunk
        (3, 5),  # Fewer chunks than slots
    ]

    all_passed = True
    results = []

    for num_chunks, num_gpu_slots in TEST_CASES:
        seqlen = num_chunks * BLOCK_SIZE
        passed, max_diff, mean_diff = run_correctness_test(
            num_chunks, num_gpu_slots, verbose=False
        )

        all_passed = all_passed and passed
        status = "PASS" if passed else "FAIL"

        results.append((num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff))

        print(f"[{status}] chunks={num_chunks:2d} slots={num_gpu_slots:2d} "
              f"seqlen={seqlen:5d} max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")

    print()

    # ================================================================
    # Part 2: Decode Phase Tests
    # ================================================================
    print("=" * 70)
    print("Part 2: Decode Phase Correctness")
    print("=" * 70)
    print("Testing: Decode query (single token) against all prefilled K, V")
    print()

    DECODE_TEST_CASES = [
        # (num_prefill_chunks, num_gpu_slots)
        (2, 2),
        (4, 3),
        (4, 4),
        (8, 4),
        (8, 6),
    ]

    decode_results = []

    for num_prefill_chunks, num_gpu_slots in DECODE_TEST_CASES:
        prefill_len = num_prefill_chunks * BLOCK_SIZE
        passed, max_diff, mean_diff = run_decode_correctness_test(
            num_prefill_chunks, num_gpu_slots, num_decode_steps=4, verbose=False
        )

        all_passed = all_passed and passed
        status = "PASS" if passed else "FAIL"

        decode_results.append((num_prefill_chunks, num_gpu_slots, prefill_len, passed, max_diff, mean_diff))

        print(f"[{status}] prefill_chunks={num_prefill_chunks:2d} slots={num_gpu_slots:2d} "
              f"prefill_len={prefill_len:5d} max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")

    print()
    print("=" * 70)

    # Summary
    prefill_passed = sum(1 for r in results if r[3])
    decode_passed = sum(1 for r in decode_results if r[3])
    total_tests = len(results) + len(decode_results)
    total_passed = prefill_passed + decode_passed

    print(f"Results: {total_passed}/{total_tests} tests passed")
    print(f" - Prefill: {prefill_passed}/{len(results)}")
    print(f" - Decode: {decode_passed}/{len(decode_results)}")

    if not all_passed:
        print("\nFailed tests:")
        for num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff in results:
            if not passed:
                print(f" - [Prefill] chunks={num_chunks}, slots={num_gpu_slots}, "
                      f"seqlen={seqlen}, max_diff={max_diff:.6f}")
        for num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff in decode_results:
            if not passed:
                print(f" - [Decode] prefill_chunks={num_chunks}, slots={num_gpu_slots}, "
                      f"prefill_len={seqlen}, max_diff={max_diff:.6f}")

    print()
    assert all_passed, "Some correctness tests failed!"
    print("test_offload_correctness: PASSED")