[opt] Optimize nanovllm performance to be comparable to vllm.

Zijie Tian
2025-12-25 03:47:07 +08:00
parent 16fcf8350b
commit 82ed34fc2d
7 changed files with 450 additions and 208 deletions


@@ -1,13 +1,21 @@
"""
Test Attention layer with KV cache offload in isolation.
Test Attention layer with KV cache offload - N-way Pipeline.
This test demonstrates how to use Attention + HybridKVCacheManager directly
without requiring full LLMEngine/ModelRunner setup.
This test demonstrates and verifies the N-way pipeline with:
- Per-slot transfer streams for parallel H2D
- Dedicated compute stream (avoids CUDA default stream implicit sync)
- Pre-load phase + main loop with immediate slot reuse
Key difference from previous test:
- We first pre-fill many chunks to CPU cache
- Then simulate processing a new chunk that loads ALL previous blocks
- This exercises the full N-way pipeline with many blocks in flight
"""
import torch
from nanovllm.layers.attention import Attention
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
from nanovllm.engine.sequence import Sequence
from nanovllm.utils.context import set_context, reset_context
@@ -16,45 +24,40 @@ from nanovllm.utils.context import set_context, reset_context
# Configuration
# ============================================================
NUM_LAYERS = 8 # Multi-layer for realistic profiling
NUM_LAYERS = 8
NUM_HEADS = 8
NUM_KV_HEADS = 8
HEAD_DIM = 64
BLOCK_SIZE = 1024 # tokens per block
CHUNK_SIZE = 1024 # tokens per chunk (same as block for simplicity)
BLOCK_SIZE = 1024
CHUNK_SIZE = 1024
NUM_GPU_SLOTS = 4
NUM_CPU_BLOCKS = 16
NUM_GPU_SLOTS = 6 # N-way pipeline with 6 slots
NUM_CPU_BLOCKS = 16 # Many blocks to load from CPU
DTYPE = torch.float16
DTYPE = torch.bfloat16
DEVICE = "cuda"
# ============================================================
# Setup: Create Manager and Attention Layers
# Setup
# ============================================================
def create_manager():
"""Create and initialize HybridKVCacheManager with OffloadEngine."""
manager = HybridKVCacheManager(
num_gpu_slots=NUM_GPU_SLOTS,
num_cpu_blocks=NUM_CPU_BLOCKS,
block_size=BLOCK_SIZE,
)
# Initialize offload engine (this creates k_cache_gpu/cpu, v_cache_gpu/cpu)
manager.allocate_cache(
num_layers=NUM_LAYERS,
num_kv_heads=NUM_KV_HEADS,
head_dim=HEAD_DIM,
dtype=DTYPE,
)
return manager
def create_attention_layers(manager):
"""Create attention layers and bind KV cache."""
layers = []
for layer_id in range(NUM_LAYERS):
attn = Attention(
@@ -64,89 +67,145 @@ def create_attention_layers(manager):
num_kv_heads=NUM_KV_HEADS,
)
attn.layer_id = layer_id
# Bind KV cache from manager
k_cache, v_cache = manager.get_layer_cache(layer_id)
attn.k_cache = k_cache
attn.v_cache = v_cache
layers.append(attn.to(DEVICE))
return layers
def create_test_sequence(manager, num_chunks=3):
"""Create a test sequence and allocate blocks."""
total_tokens = num_chunks * CHUNK_SIZE
# ============================================================
# Pre-fill CPU cache with random data
# ============================================================
# Sequence only takes token_ids
seq = Sequence(token_ids=list(range(total_tokens)))
def prefill_cpu_cache(manager, num_blocks):
"""
Fill CPU cache with random KV data for num_blocks blocks.
This simulates having already processed many chunks.
"""
offload_engine = manager.offload_engine
# Set block_size for this test
seq.block_size = BLOCK_SIZE
for block_id in range(num_blocks):
# Generate random KV data for all layers
for layer_id in range(NUM_LAYERS):
k_data = torch.randn(
BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM,
dtype=DTYPE, device=DEVICE
)
v_data = torch.randn(
BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM,
dtype=DTYPE, device=DEVICE
)
# Allocate blocks (will be on CPU in CPU-primary mode)
manager.allocate(seq)
# Copy to CPU cache
offload_engine.k_cache_cpu[layer_id, block_id].copy_(k_data)
offload_engine.v_cache_cpu[layer_id, block_id].copy_(v_data)
return seq
return list(range(num_blocks))
# ============================================================
# Chunked Prefill Simulation
# Simulate N-way Pipeline (mirrors attention.py logic)
# ============================================================
def simulate_chunk_forward(
layers,
manager,
seq,
chunk_idx,
chunk_size,
def simulate_nway_pipeline(
layer_id: int,
q_batched: torch.Tensor,
cpu_block_table: list,
load_slots: list,
offload_engine,
scale: float,
):
"""
Simulate forward pass for one chunk through all layers.
Returns:
output: Final layer attention output
Simulate N-way pipeline for a single layer.
This mirrors the logic in Attention._ring_buffer_pipeline_load().
"""
# Generate random Q, K, V for this chunk
hidden = torch.randn(chunk_size, NUM_HEADS, HEAD_DIM, dtype=DTYPE, device=DEVICE)
num_blocks = len(cpu_block_table)
num_slots = len(load_slots)
# Build slot_mapping: maps token positions to GPU slots
write_slot = manager.offload_engine.get_write_slot_for_prefill(chunk_idx)
slot_mapping = torch.full((chunk_size,), write_slot * BLOCK_SIZE, dtype=torch.long, device=DEVICE)
slot_mapping += torch.arange(chunk_size, device=DEVICE)
o_acc, lse_acc = None, None
# Build cu_seqlens for flash attention
cu_seqlens = torch.tensor([0, chunk_size], dtype=torch.int32, device=DEVICE)
# Phase 1: Pre-load up to num_slots blocks
num_preload = min(num_slots, num_blocks)
torch.cuda.nvtx.range_push(f"Phase1_Preload: L{layer_id}")
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
torch.cuda.nvtx.range_pop()
# Set context for this chunk
set_context(
is_prefill=True,
is_chunked_prefill=True,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=chunk_size,
max_seqlen_k=chunk_size,
slot_mapping=slot_mapping,
kvcache_manager=manager,
chunked_seq=seq,
current_chunk_idx=chunk_idx,
)
# Phase 2: Main loop with compute_stream
compute_stream = offload_engine.compute_stream
# Forward through all layers
output = hidden
for block_idx in range(num_blocks):
torch.cuda.nvtx.range_push(f"Block: L{layer_id} B{block_idx}")
current_slot = load_slots[block_idx % num_slots]
# Wait for transfer
offload_engine.wait_slot_layer(current_slot, layer_id)
# Compute on dedicated stream
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"FlashAttn: L{layer_id} B{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=scale,
causal=False,
)
torch.cuda.nvtx.range_pop()
offload_engine.record_slot_compute_done(current_slot, layer_id)
# Start next transfer (reuse current_slot)
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(
current_slot, layer_id, cpu_block_table[next_block_idx]
)
# Merge
with torch.cuda.stream(compute_stream):
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
torch.cuda.nvtx.range_pop()
return o_acc, lse_acc
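# Reference semantics for the LSE merge used above (illustrative; the exact tensor
# layouts of flash_attn_with_lse / merge_attention_outputs are not shown in this
# diff). Two partial attention results computed over disjoint key sets can be
# combined exactly via their log-sum-exp normalizers, which is what makes the
# block-by-block pipeline equivalent to one full attention pass. Assumed layouts:
# o* is [batch, seqlen_q, heads, head_dim], lse* is [batch, heads, seqlen_q].
def reference_merge(o1, lse1, o2, lse2):
    lse = torch.logaddexp(lse1, lse2)                         # combined normalizer in log-space
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # -> [batch, seqlen_q, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse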
def simulate_full_forward(layers, manager, cpu_block_table, chunk_size):
"""
Simulate forward pass through all layers, loading previous blocks from CPU.
This is the key test: many blocks loaded via N-way pipeline.
"""
offload_engine = manager.offload_engine
# Current chunk index (we're processing the "next" chunk after all prefilled ones)
current_chunk_idx = len(cpu_block_table)
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
# Random query for attention
q = torch.randn(1, chunk_size, NUM_HEADS, HEAD_DIM, dtype=DTYPE, device=DEVICE)
outputs = []
for layer in layers:
k = torch.randn(chunk_size, NUM_KV_HEADS, HEAD_DIM, dtype=DTYPE, device=DEVICE)
v = torch.randn(chunk_size, NUM_KV_HEADS, HEAD_DIM, dtype=DTYPE, device=DEVICE)
output = layer.forward(output, k, v)
torch.cuda.nvtx.range_push(f"Layer: {layer.layer_id}")
# Offload current chunk to CPU
logical_id = seq.block_table[chunk_idx]
cpu_block_id = manager.logical_blocks[logical_id].cpu_block_id
manager.offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)
manager.prefilled_blocks.add(logical_id)
o_acc, lse_acc = simulate_nway_pipeline(
layer.layer_id,
q,
cpu_block_table,
load_slots,
offload_engine,
layer.scale,
)
return output
outputs.append(o_acc)
torch.cuda.nvtx.range_pop()
return outputs
# ============================================================
@@ -154,64 +213,81 @@ def simulate_chunk_forward(
# ============================================================
print("=" * 60)
print("Test: Attention Layer with KV Cache Offload")
print("Test: N-way Pipeline with CPU Offload")
print("=" * 60)
# 1. Setup
print("\n[1] Creating manager and attention layers...")
manager = create_manager()
layers = create_attention_layers(manager)
print(f" - Manager: {NUM_GPU_SLOTS} GPU slots, {NUM_CPU_BLOCKS} CPU blocks")
print(f" - Layers: {NUM_LAYERS} layers, {NUM_HEADS} heads, {HEAD_DIM} head_dim")
print(f" - OffloadEngine initialized: {manager.offload_engine is not None}")
offload_engine = manager.offload_engine
# 2. Setup
print("\n[2] Test configuration...")
NUM_CHUNKS = NUM_CPU_BLOCKS # Use all CPU blocks
print(f" - Total tokens: {NUM_CHUNKS * CHUNK_SIZE}")
print(f" - Chunks: {NUM_CHUNKS}")
print(f" - GPU slots: {NUM_GPU_SLOTS}")
print(f" - CPU blocks: {NUM_CPU_BLOCKS}")
print(f" - Per-slot streams: {len(offload_engine.slot_transfer_streams)}")
print(f" - Compute stream: {offload_engine.compute_stream}")
# 3. Warmup runs
print(f"\n[3] Warmup runs (3 iterations)...")
for warmup_iter in range(3):
manager.prefilled_blocks.clear()
seq = create_test_sequence(manager, num_chunks=NUM_CHUNKS)
# 2. Pre-fill CPU cache
NUM_PREV_BLOCKS = 12 # Many blocks to load via N-way pipeline
print(f"\n[2] Pre-filling {NUM_PREV_BLOCKS} blocks to CPU cache...")
cpu_block_table = prefill_cpu_cache(manager, NUM_PREV_BLOCKS)
print(f" - CPU blocks filled: {cpu_block_table}")
for chunk_idx in range(NUM_CHUNKS):
write_slot = manager.offload_engine.get_write_slot_for_prefill(chunk_idx)
output = simulate_chunk_forward(layers, manager, seq, chunk_idx, CHUNK_SIZE)
# 3. Verify pipeline configuration
current_chunk_idx = NUM_PREV_BLOCKS
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
print(f"\n[3] Pipeline configuration for chunk {current_chunk_idx}:")
print(f" - Write slot: {write_slot}")
print(f" - Load slots: {load_slots}")
print(f" - Pipeline depth (N-way): {len(load_slots)}")
assert len(load_slots) == NUM_GPU_SLOTS - 1, f"Expected {NUM_GPU_SLOTS - 1} load slots"
manager.deallocate(seq)
print(f" - Warmup {warmup_iter + 1}/3 completed")
# 4. Warmup
print("\n[4] Warmup (3 iterations)...")
for i in range(3):
outputs = simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE)
torch.cuda.synchronize()
print(f" - Warmup {i+1}/3 done")
# 4. Benchmark runs
print(f"\n[4] Benchmark runs (10 iterations)...")
for bench_iter in range(10):
manager.prefilled_blocks.clear()
seq = create_test_sequence(manager, num_chunks=NUM_CHUNKS)
# 5. Benchmark
NUM_ITERS = 10
print(f"\n[5] Benchmark ({NUM_ITERS} iterations)...")
for chunk_idx in range(NUM_CHUNKS):
write_slot = manager.offload_engine.get_write_slot_for_prefill(chunk_idx)
load_slots = manager.offload_engine.get_load_slots_for_prefill(write_slot)
output = simulate_chunk_forward(layers, manager, seq, chunk_idx, CHUNK_SIZE)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
manager.deallocate(seq)
print(f" - Iteration {bench_iter + 1}/10 completed")
start_event.record()
for i in range(NUM_ITERS):
torch.cuda.nvtx.range_push(f"Iteration_{i}")
outputs = simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE)
torch.cuda.nvtx.range_pop()
end_event.record()
# 5. Verify results (using last iteration's seq)
print("\n[5] Verifying ring buffer and offload...")
for chunk_idx in range(NUM_CHUNKS):
expected_slot = chunk_idx % NUM_GPU_SLOTS
actual_slot = manager.offload_engine.get_write_slot_for_prefill(chunk_idx)
assert actual_slot == expected_slot, f"Chunk {chunk_idx}: expected slot {expected_slot}, got {actual_slot}"
torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)
cpu_block_table = manager.get_prefilled_cpu_blocks(seq)
assert cpu_block_table == seq.block_table[:NUM_CHUNKS], "CPU block table mismatch"
print(" - Ring buffer cycling verified ✓")
print(" - CPU offload verified ✓")
# Stats
total_blocks_loaded = NUM_PREV_BLOCKS * NUM_LAYERS * NUM_ITERS
blocks_per_sec = total_blocks_loaded / (elapsed_ms / 1000)
total_tokens = NUM_PREV_BLOCKS * BLOCK_SIZE * NUM_LAYERS * NUM_ITERS
tokens_per_sec = total_tokens / (elapsed_ms / 1000)
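# With the values above (NUM_PREV_BLOCKS=12, NUM_LAYERS=8, NUM_ITERS=10,
# BLOCK_SIZE=1024): total_blocks_loaded = 12 * 8 * 10 = 960 H2D block copies and
# total_tokens = 12 * 1024 * 8 * 10 = 983,040 KV tokens attended; e.g. an elapsed
# time of 500 ms would report 1,920 blocks/s and ~1.97M tok/s (these numbers are
# purely illustrative, not measured results).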
# Cleanup
manager.deallocate(seq)
print(f"\n[6] Results:")
print(f" - Total time: {elapsed_ms:.2f} ms")
print(f" - Per iteration: {elapsed_ms / NUM_ITERS:.2f} ms")
print(f" - Blocks loaded: {total_blocks_loaded} ({blocks_per_sec:.0f} blocks/s)")
print(f" - Tokens processed: {total_tokens} ({tokens_per_sec:.0f} tok/s)")
# 7. Verification
print("\n[7] Verification:")
assert len(outputs) == NUM_LAYERS, f"Expected {NUM_LAYERS} outputs"
for i, o in enumerate(outputs):
assert o is not None, f"Layer {i} output is None"
assert o.shape == (1, CHUNK_SIZE, NUM_HEADS, HEAD_DIM), f"Layer {i} shape mismatch"
print(" - All layer outputs valid ✓")
print(" - N-way pipeline executed correctly ✓")
# Cleanup
reset_context()
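# Optional (not part of the original test): the NVTX ranges above are intended for
# an external profiler such as Nsight Systems. A self-contained alternative is a
# torch.profiler capture of one extra iteration, exported as a Chrome trace; the
# output filename below is made up for illustration.
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE)
    torch.cuda.synchronize()
prof.export_chrome_trace("nway_pipeline_trace.json")  # view in chrome://tracing or Perfetto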