"""
|
|
Test Attention layer with KV cache offload - N-way Pipeline.
|
|
|
|
This test demonstrates and verifies the N-way pipeline with:
|
|
- Per-slot transfer streams for parallel H2D
|
|
- Dedicated compute stream (avoids CUDA default stream implicit sync)
|
|
- Pre-load phase + main loop with immediate slot reuse
|
|
|
|
Key difference from previous test:
|
|
- We first pre-fill many chunks to CPU cache
|
|
- Then simulate processing a new chunk that loads ALL previous blocks
|
|
- This exercises the full N-way pipeline with many blocks in flight
|
|
"""
|
|
|
|
import torch

from nanovllm.layers.attention import Attention
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
from nanovllm.engine.sequence import Sequence
from nanovllm.utils.context import set_context, reset_context


# ============================================================
# Configuration
# ============================================================

NUM_LAYERS = 8
NUM_HEADS = 8
NUM_KV_HEADS = 8
HEAD_DIM = 64
BLOCK_SIZE = 1024
CHUNK_SIZE = 1024

NUM_GPU_SLOTS = 6    # N-way pipeline with 6 slots
NUM_CPU_BLOCKS = 16  # Many blocks to load from CPU

DTYPE = torch.bfloat16
DEVICE = "cuda"

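# For reference, each offloaded block holds BLOCK_SIZE * NUM_KV_HEADS * HEAD_DIM
# bf16 values per tensor: 1024 * 8 * 64 * 2 bytes = 1 MiB for K and 1 MiB for V
# per layer, i.e. 2 MiB * NUM_LAYERS = 16 MiB of KV data per block across all
# 8 layers. This is the data volume the pipeline streams H2D per loaded block.
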
# ============================================================
# Setup
# ============================================================

def create_manager():
    manager = HybridKVCacheManager(
        num_gpu_slots=NUM_GPU_SLOTS,
        num_cpu_blocks=NUM_CPU_BLOCKS,
        block_size=BLOCK_SIZE,
    )
    manager.allocate_cache(
        num_layers=NUM_LAYERS,
        num_kv_heads=NUM_KV_HEADS,
        head_dim=HEAD_DIM,
        dtype=DTYPE,
    )
    return manager


def create_attention_layers(manager):
    layers = []
    for layer_id in range(NUM_LAYERS):
        attn = Attention(
            num_heads=NUM_HEADS,
            head_dim=HEAD_DIM,
            scale=HEAD_DIM ** -0.5,
            num_kv_heads=NUM_KV_HEADS,
        )
        attn.layer_id = layer_id
        k_cache, v_cache = manager.get_layer_cache(layer_id)
        attn.k_cache = k_cache
        attn.v_cache = v_cache
        layers.append(attn.to(DEVICE))
    return layers


# ============================================================
# Pre-fill CPU cache with random data
# ============================================================

def prefill_cpu_cache(manager, num_blocks):
    """
    Fill CPU cache with random KV data for num_blocks blocks.
    This simulates having already processed many chunks.
    """
    offload_engine = manager.offload_engine

    for block_id in range(num_blocks):
        # Generate random KV data for all layers
        for layer_id in range(NUM_LAYERS):
            k_data = torch.randn(
                BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM,
                dtype=DTYPE, device=DEVICE,
            )
            v_data = torch.randn(
                BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM,
                dtype=DTYPE, device=DEVICE,
            )

            # Copy to CPU cache
            offload_engine.k_cache_cpu[layer_id, block_id].copy_(k_data)
            offload_engine.v_cache_cpu[layer_id, block_id].copy_(v_data)

    return list(range(num_blocks))


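# Note: prefill_cpu_cache generates k_data/v_data on the GPU, so each copy_()
# call is a device-to-host transfer into the engine's CPU-side cache tensors
# (k_cache_cpu / v_cache_cpu). That is fine for test setup; it simply populates
# the blocks that the pipeline will later stream back to the GPU.
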
# ============================================================
# Simulate N-way Pipeline (mirrors attention.py logic)
# ============================================================

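# Schedule sketch (for num_slots = 5 load slots and 12 CPU blocks): blocks 0-4
# are pre-loaded into the five load slots on their per-slot transfer streams.
# The main loop then processes block i on the compute stream in slot
# load_slots[i % 5]; as soon as block i's attention is issued, the transfer of
# block i + 5 is started into the slot block i just freed. Transfers and
# compute for different blocks therefore overlap, which is the point of the
# N-way pipeline.
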
def simulate_nway_pipeline(
    layer_id: int,
    q_batched: torch.Tensor,
    cpu_block_table: list,
    load_slots: list,
    offload_engine,
    scale: float,
):
    """
    Simulate the N-way pipeline for a single layer.
    This mirrors the logic in Attention._ring_buffer_pipeline_load().
    """
    num_blocks = len(cpu_block_table)
    num_slots = len(load_slots)

    o_acc, lse_acc = None, None

    # Phase 1: Pre-load up to num_slots blocks
    num_preload = min(num_slots, num_blocks)
    torch.cuda.nvtx.range_push(f"Phase1_Preload: L{layer_id}")
    for i in range(num_preload):
        offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
    torch.cuda.nvtx.range_pop()

    # Phase 2: Main loop on the dedicated compute stream
    compute_stream = offload_engine.compute_stream

    for block_idx in range(num_blocks):
        torch.cuda.nvtx.range_push(f"Block: L{layer_id} B{block_idx}")

        current_slot = load_slots[block_idx % num_slots]

        # Wait for this slot's transfer to finish
        offload_engine.wait_slot_layer(current_slot, layer_id)

        # Compute on the dedicated stream
        with torch.cuda.stream(compute_stream):
            torch.cuda.nvtx.range_push(f"FlashAttn: L{layer_id} B{block_idx}")
            prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, layer_id)
            prev_o, prev_lse = flash_attn_with_lse(
                q_batched, prev_k, prev_v,
                softmax_scale=scale,
                causal=False,
            )
            torch.cuda.nvtx.range_pop()
            offload_engine.record_slot_compute_done(current_slot, layer_id)

        # Start the next transfer (reuse current_slot)
        next_block_idx = block_idx + num_slots
        if next_block_idx < num_blocks:
            offload_engine.load_to_slot_layer(
                current_slot, layer_id, cpu_block_table[next_block_idx]
            )

        # Merge this block's partial output into the accumulator
        with torch.cuda.stream(compute_stream):
            if o_acc is None:
                o_acc, lse_acc = prev_o, prev_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

        torch.cuda.nvtx.range_pop()

    return o_acc, lse_acc


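# The merge step above relies on flash_attn_with_lse returning, alongside each
# partial output o_i, the log-sum-exp (lse_i) of that block's attention scores.
# Two partials over disjoint KV blocks can then be combined exactly, as in
# chunked/flash attention:
#     lse = logaddexp(lse_1, lse_2)
#     o   = exp(lse_1 - lse) * o_1 + exp(lse_2 - lse) * o_2
# merge_attention_outputs is expected to implement this reduction, which is why
# per-block outputs can be merged pairwise in the order the pipeline produces
# them.
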
def simulate_full_forward(layers, manager, cpu_block_table, chunk_size):
    """
    Simulate a forward pass through all layers, loading previous blocks from CPU.
    This is the key test: many blocks loaded via the N-way pipeline.
    """
    offload_engine = manager.offload_engine

    # Current chunk index (we're processing the "next" chunk after all prefilled ones)
    current_chunk_idx = len(cpu_block_table)
    write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
    load_slots = offload_engine.get_load_slots_for_prefill(write_slot)

    # Random query for attention
    q = torch.randn(1, chunk_size, NUM_HEADS, HEAD_DIM, dtype=DTYPE, device=DEVICE)

    outputs = []
    for layer in layers:
        torch.cuda.nvtx.range_push(f"Layer: {layer.layer_id}")

        o_acc, lse_acc = simulate_nway_pipeline(
            layer.layer_id,
            q,
            cpu_block_table,
            load_slots,
            offload_engine,
            layer.scale,
        )

        outputs.append(o_acc)
        torch.cuda.nvtx.range_pop()

    return outputs


# ============================================================
# Main Test
# ============================================================

print("=" * 60)
print("Test: N-way Pipeline with CPU Offload")
print("=" * 60)

# 1. Setup
print("\n[1] Creating manager and attention layers...")
manager = create_manager()
layers = create_attention_layers(manager)
offload_engine = manager.offload_engine

print(f" - GPU slots: {NUM_GPU_SLOTS}")
print(f" - CPU blocks: {NUM_CPU_BLOCKS}")
print(f" - Per-slot streams: {len(offload_engine.slot_transfer_streams)}")
print(f" - Compute stream: {offload_engine.compute_stream}")

# 2. Pre-fill CPU cache
NUM_PREV_BLOCKS = 12  # Many blocks to load via N-way pipeline
print(f"\n[2] Pre-filling {NUM_PREV_BLOCKS} blocks to CPU cache...")
cpu_block_table = prefill_cpu_cache(manager, NUM_PREV_BLOCKS)
print(f" - CPU blocks filled: {cpu_block_table}")

# 3. Verify pipeline configuration
current_chunk_idx = NUM_PREV_BLOCKS
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
print(f"\n[3] Pipeline configuration for chunk {current_chunk_idx}:")
print(f" - Write slot: {write_slot}")
print(f" - Load slots: {load_slots}")
print(f" - Pipeline depth (N-way): {len(load_slots)}")
assert len(load_slots) == NUM_GPU_SLOTS - 1, f"Expected {NUM_GPU_SLOTS - 1} load slots"

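# With NUM_GPU_SLOTS = 6, one slot is reserved as the write slot for the chunk
# currently being processed, so the remaining NUM_GPU_SLOTS - 1 = 5 slots act
# as load slots; that is the pipeline depth asserted above.
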
# 4. Warmup
print("\n[4] Warmup (3 iterations)...")
for i in range(3):
    outputs = simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE)
    torch.cuda.synchronize()
    print(f" - Warmup {i+1}/3 done")

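# Warmup keeps one-time costs (CUDA context setup, allocator growth, kernel
# loading for the flash-attention path) out of the timed iterations below.
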
# 5. Benchmark
NUM_ITERS = 10
print(f"\n[5] Benchmark ({NUM_ITERS} iterations)...")

torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
for i in range(NUM_ITERS):
    torch.cuda.nvtx.range_push(f"Iteration_{i}")
    outputs = simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE)
    torch.cuda.nvtx.range_pop()
end_event.record()

torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)

# Stats
total_blocks_loaded = NUM_PREV_BLOCKS * NUM_LAYERS * NUM_ITERS
blocks_per_sec = total_blocks_loaded / (elapsed_ms / 1000)
total_tokens = NUM_PREV_BLOCKS * BLOCK_SIZE * NUM_LAYERS * NUM_ITERS
tokens_per_sec = total_tokens / (elapsed_ms / 1000)

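# Interpretation: each iteration streams NUM_PREV_BLOCKS blocks per layer, i.e.
# 12 * 8 = 96 block-loads of 2 MiB each (K + V in bf16), roughly 192 MiB of
# H2D traffic per iteration, so blocks_per_sec mostly reflects how well the
# transfers overlap with the attention compute.
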
print(f"\n[6] Results:")
|
|
print(f" - Total time: {elapsed_ms:.2f} ms")
|
|
print(f" - Per iteration: {elapsed_ms / NUM_ITERS:.2f} ms")
|
|
print(f" - Blocks loaded: {total_blocks_loaded} ({blocks_per_sec:.0f} blocks/s)")
|
|
print(f" - Tokens processed: {total_tokens} ({tokens_per_sec:.0f} tok/s)")
|
|
|
|
# 7. Verification
print("\n[7] Verification:")
assert len(outputs) == NUM_LAYERS, f"Expected {NUM_LAYERS} outputs"
for i, o in enumerate(outputs):
    assert o is not None, f"Layer {i} output is None"
    assert o.shape == (1, CHUNK_SIZE, NUM_HEADS, HEAD_DIM), f"Layer {i} shape mismatch"
print(" - All layer outputs valid ✓")
print(" - N-way pipeline executed correctly ✓")

# Cleanup
reset_context()

print("\n" + "=" * 60)
print("test_attention_offload: PASSED")
print("=" * 60)