[WIP] Fix attention compute error.

Zijie Tian
2025-12-30 00:31:48 +08:00
parent bf4c63c7ec
commit 89f8020d38
12 changed files with 2175 additions and 103 deletions

View File

@@ -31,6 +31,8 @@ class LLMEngine:
self.model_runner = ModelRunner(config, 0, self.events)
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
config.eos = self.tokenizer.eos_token_id
# Set Sequence.block_size to match the KV cache block size
Sequence.block_size = config.kvcache_block_size
self.scheduler = Scheduler(config, self.model_runner.kvcache_manager)
atexit.register(self.exit)
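For context, Sequence.block_size controls how each prompt is sliced into KV blocks, so it has to agree with the cache's block size; a toy sketch with hypothetical numbers (not code from this repo):
def num_blocks(seq_len: int, block_size: int) -> int:
    # Block-table length for a sequence: ceil(seq_len / block_size).
    return (seq_len + block_size - 1) // block_size
# If the scheduler assumed a stale block size, its block table would not line
# up with the blocks the KV cache manager actually allocates.
assert num_blocks(2048, 1024) == 2   # matches kvcache_block_size=1024
assert num_blocks(2048, 256) == 8    # a mismatched default yields 4x more entries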

View File

@@ -521,6 +521,7 @@ class ModelRunner:
print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)
# Sample from last logits
# For chunked prefill, ParallelLMHead automatically selects last position's logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
if logits is not None:
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None

View File

@@ -281,7 +281,11 @@ def _merge_lse_kernel(
num_elements: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging LSE values."""
"""Fused kernel for merging LSE values.
IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
"""
# Each program handles BLOCK_SIZE elements
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
@@ -289,21 +293,21 @@ def _merge_lse_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < num_elements
# Load lse values
lse1 = tl.load(lse1_ptr + offsets, mask=mask)
lse2 = tl.load(lse2_ptr + offsets, mask=mask)
# Load lse values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)
# Compute max for numerical stability
# Compute max for numerical stability (in fp32)
max_lse = tl.maximum(lse1, lse2)
# Compute exp(lse - max_lse)
# Compute exp(lse - max_lse) in fp32
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
# Compute merged LSE: max_lse + log(exp1 + exp2)
# Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
lse_merged = max_lse + tl.log(exp1 + exp2)
# Store result
# Store result (convert back to original dtype)
tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
@@ -313,7 +317,11 @@ def _merge_output_kernel(
batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging attention outputs."""
"""Fused kernel for merging attention outputs.
IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
This is critical for numerical accuracy in chunked attention.
"""
# Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
pid_batch = tl.program_id(0)
pid_seq = tl.program_id(1)
@@ -322,11 +330,11 @@ def _merge_output_kernel(
# Compute LSE index: [batch, nheads, seqlen_q]
lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq
# Load LSE values
lse1 = tl.load(lse1_ptr + lse_idx)
lse2 = tl.load(lse2_ptr + lse_idx)
# Load LSE values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)
# Compute max and scaling factors
# Compute max and scaling factors in fp32
max_lse = tl.maximum(lse1, lse2)
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
@@ -343,14 +351,14 @@ def _merge_output_kernel(
pid_head * headdim)
o_idx = base_idx + d_idx
# Load o1, o2
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0)
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0)
# Load o1, o2 and convert to fp32 for weighted sum
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
# Compute merged output: (o1 * exp1 + o2 * exp2) / sum_exp
# Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
# Store result
# Store result (Triton will convert back to original dtype)
tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
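For reference, both kernels implement the standard log-sum-exp merge of two partial attention results; below is a minimal PyTorch sketch of the same math, mirroring the fp32 path in the kernels (shapes follow the kernel layout: lse is [batch, nheads, seqlen_q], o is [batch, seqlen_q, nheads, headdim]):
import torch
def merge_reference(o1, lse1, o2, lse2):
    # Do the merge in fp32, then cast back to the inputs' dtype.
    lse1f, lse2f = lse1.float(), lse2.float()
    max_lse = torch.maximum(lse1f, lse2f)
    exp1, exp2 = torch.exp(lse1f - max_lse), torch.exp(lse2f - max_lse)
    lse_merged = max_lse + torch.log(exp1 + exp2)
    # Broadcast the per-(batch, head, query) weights over headdim: [b, sq, h, 1]
    w1 = exp1.permute(0, 2, 1).unsqueeze(-1)
    w2 = exp2.permute(0, 2, 1).unsqueeze(-1)
    o_merged = (o1.float() * w1 + o2.float() * w2) / (w1 + w2)
    return o_merged.to(o1.dtype), lse_merged.to(lse1.dtype)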

View File

@@ -337,10 +337,10 @@ class HybridKVCacheManager(KVCacheManager):
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
cpu_blocks.append(block.cpu_block_id)
logger.debug(
f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
f"returned cpu_blocks={cpu_blocks}"
)
# logger.debug(
# f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
# f"returned cpu_blocks={cpu_blocks}"
# )
return cpu_blocks
# ========== Ring Buffer CPU-primary support ==========

View File

@@ -538,7 +538,7 @@ class OffloadEngine:
def sync_indices(self) -> None:
"""Synchronize to ensure all index updates are complete."""
torch.cuda.current_stream().synchronize()
torch.cuda.default_stream().synchronize()
# ========== Cache access methods ==========
@@ -682,8 +682,9 @@ class OffloadEngine:
Async load a single CPU block to a ring buffer slot for one layer.
This is the core building block for ring buffer pipelining.
Before starting the transfer, waits for any previous compute on this slot
to complete (using compute_done event).
Before starting the transfer, waits for:
1. Any previous compute on this slot to complete
2. Any pending offload of this slot to complete
Args:
slot_idx: Target GPU slot index
@@ -701,6 +702,10 @@ class OffloadEngine:
# This prevents data race: transfer must not start until attention finishes reading
stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
# Also wait for any pending offload of this slot to complete
# This prevents race: load must not write GPU slot while offload is reading from it
stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
self.k_cache_gpu[layer_id, slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
@@ -763,7 +768,11 @@ class OffloadEngine:
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for both compute_stream and default stream
# - compute_stream: for flash attention operations
# - default_stream: for store_kvcache which runs on default stream
self.transfer_stream_main.wait_stream(self.compute_stream)
self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
memcpy_2d_async(
self.k_cache_cpu[:, cpu_block_id],
self.k_cache_gpu[:, slot_idx],
@@ -793,7 +802,9 @@ class OffloadEngine:
cpu_block_id: Target CPU block ID
"""
with torch.cuda.stream(self.transfer_stream_main):
# Wait for both compute_stream and default stream
self.transfer_stream_main.wait_stream(self.compute_stream)
self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
)
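The ordering rules added here follow a general CUDA pattern: a transfer may only touch a slot once every stream that still reads or writes it has been ordered before the copy. A minimal sketch of that pattern (illustrative names, not the OffloadEngine API):
import torch
transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()
compute_done = torch.cuda.Event()   # recorded after attention has read the slot
offload_done = torch.cuda.Event()   # recorded after the slot's D2H copy finishes
def load_slot(gpu_slot: torch.Tensor, cpu_block: torch.Tensor) -> None:
    with torch.cuda.stream(transfer_stream):
        transfer_stream.wait_event(compute_done)   # don't overwrite data still being read
        transfer_stream.wait_event(offload_done)   # don't overwrite data still being offloaded
        gpu_slot.copy_(cpu_block, non_blocking=True)
def offload_slot(cpu_block: torch.Tensor, gpu_slot: torch.Tensor) -> None:
    with torch.cuda.stream(transfer_stream):
        # KV is written by store_kvcache on the default stream and read by flash
        # attention on compute_stream, so the D2H copy must wait for both.
        transfer_stream.wait_stream(compute_stream)
        transfer_stream.wait_stream(torch.cuda.default_stream())
        cpu_block.copy_(gpu_slot, non_blocking=True)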

View File

@@ -169,9 +169,11 @@ class Attention(nn.Module):
else:
# Use ring buffer pipeline
o_acc, lse_acc = self._ring_buffer_pipeline_load(
q_batched, cpu_block_table, load_slots, offload_engine
q_batched, cpu_block_table, load_slots, offload_engine,
current_chunk_idx
)
# Compute attention against current chunk's KV (with causal mask)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
current_o, current_lse = flash_attn_with_lse(
@@ -187,11 +189,18 @@ class Attention(nn.Module):
if o_acc is None:
final_o = current_o
else:
# IMPORTANT: o_acc was computed on compute_stream. We need to sync before
# reading it on the default stream for the merge operation.
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
offload_engine = kvcache_manager.offload_engine
torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop() # ChunkedPrefill
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
@@ -205,12 +214,15 @@ class Attention(nn.Module):
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
o_acc, lse_acc = None, None
compute_stream = offload_engine.compute_stream
for block_idx, cpu_block_id in enumerate(cpu_block_table):
# Load to slot 0 (single slot)
offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
offload_engine.wait_slot_layer(0, self.layer_id)
# IMPORTANT: Must use compute_stream to match wait_slot_layer
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
@@ -232,6 +244,7 @@ class Attention(nn.Module):
cpu_block_table: list,
load_slots: list,
offload_engine,
current_chunk_idx: int = -1,
):
"""
Ring buffer async pipeline loading with double buffering.
@@ -269,10 +282,14 @@ class Attention(nn.Module):
if pipeline_depth == 1:
# Only 1 slot available, cannot pipeline - use synchronous mode
# IMPORTANT: Must use compute_stream to match synchronization in
# load_to_slot_layer (waits for compute_done) and wait_slot_layer
slot = load_slots[0]
compute_stream = offload_engine.compute_stream
for block_idx in range(num_blocks):
offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
offload_engine.wait_slot_layer(slot, self.layer_id)
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
@@ -378,12 +395,13 @@ class Attention(nn.Module):
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq
# Get all CPU blocks for this sequence
cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
# Get only PREFILLED CPU blocks (exclude the current decode block)
# The decode block's KV is still in GPU decode_slot, not yet offloaded to CPU
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
if self.layer_id == 0:
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Apply sparse policy if enabled
if kvcache_manager.sparse_policy is not None:
@@ -401,12 +419,17 @@ class Attention(nn.Module):
)
offload_engine = kvcache_manager.offload_engine
compute_stream = offload_engine.compute_stream
# Chunk size = capacity of each double buffer region (compute/prefetch)
# Each region uses half of decode_load_slots
chunk_size = max(1, len(offload_engine.decode_load_slots) // 2)
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
# Check if double buffering is possible (need at least 2 separate regions)
# With only 1 load slot, compute and prefetch regions overlap -> can't double buffer
can_double_buffer = len(offload_engine.decode_load_slots) >= 2
o_acc = None
lse_acc = None
@@ -422,26 +445,13 @@ class Attention(nn.Module):
end = min(start + chunk_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
# Wait for current buffer to be ready
if use_compute:
offload_engine.wait_compute_layer(self.layer_id)
else:
offload_engine.wait_prefetch_layer(self.layer_id)
# Wait for current buffer to be ready on compute_stream
# The load runs on transfer_stream_main, compute runs on compute_stream
compute_stream.wait_stream(offload_engine.transfer_stream_main)
# Trigger async prefetch of next chunk to the OTHER buffer
# This overlaps transfer with current chunk's computation
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + chunk_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if use_compute:
# Current in Compute, prefetch next to Prefetch region
offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
else:
# Current in Prefetch, prefetch next to Compute region
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
# Get KV from current buffer
# All computation on explicit compute_stream
with torch.cuda.stream(compute_stream):
# Get KV from current buffer FIRST, before prefetching overwrites it
if use_compute:
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
self.layer_id, num_blocks_in_chunk
@@ -464,7 +474,24 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Swap buffers for next iteration
# Trigger async prefetch/load of next chunk to the OTHER buffer
# This happens AFTER attention on the current chunk, so the data being overwritten is no longer needed
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + chunk_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if can_double_buffer:
if use_compute:
# Current in Compute, prefetch next to Prefetch region
offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
else:
# Current in Prefetch, prefetch next to Compute region
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
else:
# Sync fallback: load next chunk to same slot (always compute region)
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
# Swap buffers for next iteration (only matters if can_double_buffer)
use_compute = not use_compute
# Now attend to Decode region (contains accumulated decode tokens)
@@ -472,6 +499,7 @@ class Attention(nn.Module):
start_pos = context.decode_start_pos_in_block
num_accumulated = pos_in_block - start_pos + 1
with torch.cuda.stream(compute_stream):
if num_accumulated > 0:
decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
@@ -492,4 +520,8 @@ class Attention(nn.Module):
if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available")
# Sync back to default stream before returning
# Caller expects result to be ready on default stream
torch.cuda.default_stream().wait_stream(compute_stream)
return o_acc
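Putting the decode changes together, the loop now loads on the transfer stream, computes and merges on the compute stream, prefetches only after the current chunk's attention has been issued, and finally orders the default stream after the compute stream. A minimal self-contained sketch of that pattern (not the Attention module's code; stand-in tensors for the offloaded KV, and K and V share storage purely for brevity):
import torch
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
nheads, headdim, block = 8, 128, 1024
scale = headdim ** -0.5
q = torch.randn(1, 1, nheads, headdim, device="cuda", dtype=torch.float16)
cpu_blocks = [torch.randn(block, nheads, headdim, dtype=torch.float16, pin_memory=True)
              for _ in range(4)]                                  # stand-ins for CPU-resident KV
bufs = [torch.empty(block, nheads, headdim, device="cuda", dtype=torch.float16)
        for _ in range(2)]                                        # compute / prefetch regions
transfer, compute = torch.cuda.Stream(), torch.cuda.Stream()
compute.wait_stream(torch.cuda.default_stream())                  # q was produced on the default stream
with torch.cuda.stream(transfer):                                 # preload chunk 0
    bufs[0].copy_(cpu_blocks[0], non_blocking=True)
o_acc = lse_acc = None
for i in range(len(cpu_blocks)):
    cur, nxt = bufs[i % 2], bufs[(i + 1) % 2]
    compute.wait_stream(transfer)                                 # chunk i is resident in `cur`
    with torch.cuda.stream(compute):
        kv = cur.unsqueeze(0)                                     # [1, block, nheads, headdim]
        o, lse = flash_attn_with_lse(q, kv, kv, softmax_scale=scale, causal=False)
        o_acc, lse_acc = (o, lse) if o_acc is None else merge_attention_outputs(o_acc, lse_acc, o, lse)
    if i + 1 < len(cpu_blocks):
        with torch.cuda.stream(transfer):
            transfer.wait_stream(compute)                         # load only after attention was issued
            nxt.copy_(cpu_blocks[i + 1], non_blocking=True)
torch.cuda.default_stream().wait_stream(compute)                  # caller reads o_acc on the default stream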

View File

@@ -93,9 +93,9 @@ TEST_CASES = [
(1, 4, 256, 8, 128),
(1, 4, 512, 8, 128),
(1, 8, 512, 8, 128),
(1, 4, 1024, 8, 128),
(1, 4, 1024, 32, 128), # More heads
(1, 8, 256, 8, 64), # Smaller head dim
(1, 32, 1024, 8, 128),
(1, 32, 1024, 32, 128), # More heads
(1, 32, 256, 8, 64), # Smaller head dim
]
DTYPES = [torch.float16, torch.bfloat16]

View File

@@ -0,0 +1,374 @@
"""
Hook-based correctness test for chunked decode attention.
Uses PyTorch register_forward_hook() to capture real inference I/O,
then compares against reference computation to locate bugs.
This test targets the decode phase with CPU offload - after prefill,
the model generates tokens one by one while attending to all previous context.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# ============================================================
# Configuration
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 8 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 2 * 1024 # 2K tokens for prefill
NUM_DECODE_TOKENS = 5 # Generate 5 tokens to test decode
BLOCK_SIZE = 1024
# ============================================================
# Global capture storage
# ============================================================
captures = []
prefill_kv = {} # Store prefill k,v for reference computation
# ============================================================
# Hook Functions
# ============================================================
def make_hook(layer_id):
"""Create a forward hook for a specific layer."""
def hook(module, inputs, output):
q, k, v = inputs
ctx = get_context()
is_prefill = ctx.is_prefill
capture_entry = {
'layer_id': layer_id,
'is_prefill': is_prefill,
'q': q.clone().cpu(),
'k': k.clone().cpu(),
'v': v.clone().cpu(),
'output': output.clone().cpu(),
'is_chunked_prefill': ctx.is_chunked_prefill,
}
if is_prefill:
# Store prefill k,v for reference computation
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
capture_entry['chunk_idx'] = chunk_idx
if layer_id not in prefill_kv:
prefill_kv[layer_id] = []
prefill_kv[layer_id].append({
'chunk_idx': chunk_idx,
'k': k.clone().cpu(),
'v': v.clone().cpu(),
})
else:
# Decode phase - capture decode token info
capture_entry['decode_step'] = len([c for c in captures
if c['layer_id'] == layer_id and not c['is_prefill']])
captures.append(capture_entry)
return hook
def register_hooks(llm):
"""Register forward hooks on all Attention modules."""
hooks = []
model = llm.model_runner.model
for layer_idx, decoder_layer in enumerate(model.model.layers):
attn_module = decoder_layer.self_attn.attn
hook = attn_module.register_forward_hook(make_hook(layer_idx))
hooks.append(hook)
return hooks
# ============================================================
# Reference Computation
# ============================================================
def compute_decode_reference(layer_id, decode_step, scale, debug=False):
"""
Compute reference decode attention output for a specific layer.
For decode, the query is a single token that attends to:
1. All prefill KV (from CPU cache)
2. All previous decode tokens (stored in GPU decode slot)
"""
# Get the decode capture
decode_captures = [c for c in captures
if c['layer_id'] == layer_id and not c['is_prefill']]
if decode_step >= len(decode_captures):
return None
decode_capture = decode_captures[decode_step]
q = decode_capture['q'].cuda() # [1, num_heads, head_dim]
q_batched = q.unsqueeze(1) # [1, 1, num_heads, head_dim]
if debug:
print(f" Reference for L{layer_id} D{decode_step}:")
print(f" q shape: {q_batched.shape}, mean={q_batched.mean().item():.4f}")
o_acc, lse_acc = None, None
# Attend to all prefill chunks
if layer_id in prefill_kv:
for chunk_data in sorted(prefill_kv[layer_id], key=lambda x: x['chunk_idx']):
k = chunk_data['k'].cuda().unsqueeze(0) # [1, seqlen, kv_heads, head_dim]
v = chunk_data['v'].cuda().unsqueeze(0)
o, lse = flash_attn_with_lse(q_batched, k, v, softmax_scale=scale, causal=False)
if debug:
print(f" Prefill chunk {chunk_data['chunk_idx']}: o.mean={o.mean().item():.6f}")
if o_acc is None:
o_acc, lse_acc = o, lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)
# Attend to previous decode tokens (including current)
# In decode, the current token's k,v are stored, and we need to attend to all previous decode tokens
# For step 0, we just have the current token's k,v
# For step 1, we have tokens 0 and 1's k,v
# etc.
# Collect k,v from all decode steps up to and including current
decode_kv = []
for i in range(decode_step + 1):
if i < len(decode_captures):
decode_kv.append({
'k': decode_captures[i]['k'].cuda(),
'v': decode_captures[i]['v'].cuda(),
})
if decode_kv:
# Stack decode k,v into a single tensor
decode_k = torch.cat([d['k'] for d in decode_kv], dim=0).unsqueeze(0) # [1, num_decode, kv_heads, head_dim]
decode_v = torch.cat([d['v'] for d in decode_kv], dim=0).unsqueeze(0)
if debug:
print(f" Decode tokens: {len(decode_kv)}, k.shape={decode_k.shape}")
# For decode, we use causal=False since we're attending to all decode tokens
# (the causal masking was already handled by only including tokens up to current)
o_decode, lse_decode = flash_attn_with_lse(q_batched, decode_k, decode_v,
softmax_scale=scale, causal=False)
if debug:
print(f" Decode attention: o.mean={o_decode.mean().item():.6f}")
if o_acc is None:
o_acc, lse_acc = o_decode, lse_decode
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_decode, lse_decode)
if o_acc is None:
return None
if debug:
print(f" Final: o.mean={o_acc.mean().item():.6f}")
return o_acc.squeeze(0).squeeze(0).cpu() # [num_heads, head_dim]
# ============================================================
# Test Runner
# ============================================================
def run_test(verbose=True):
"""Run the hook-based chunked decode correctness test."""
global captures, prefill_kv
captures = []
prefill_kv = {}
if verbose:
print("=" * 70)
print("Test: Hook-Based Chunked Decode Correctness")
print("=" * 70)
print(f"Model: {MODEL_PATH}")
print(f"Input length: {INPUT_LEN} tokens")
print(f"Decode tokens: {NUM_DECODE_TOKENS}")
print(f"Block size: {BLOCK_SIZE}")
print()
# Initialize LLM with CPU offload
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
)
# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5
if verbose:
print(f"Num layers: {num_layers}")
print(f"Head dim: {head_dim}")
print()
# Register hooks
hooks = register_hooks(llm)
if verbose:
print(f"Registered {len(hooks)} hooks")
# Generate random prompt
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
# Run prefill and decode
if verbose:
print(f"Running inference with {NUM_DECODE_TOKENS} decode tokens...")
sampling_params = SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS)
outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
# Remove hooks
for hook in hooks:
hook.remove()
# =========== VERIFICATION: Check CPU cache after prefill ===========
# Verify that CPU cache data matches captured prefill k,v
if verbose:
print("\n--- CPU Cache Verification (After Prefill) ---")
offload_engine = llm.model_runner.kvcache_manager.offload_engine
# For each prefill capture, check if CPU cache matches
for layer_id in [0]: # Only check layer 0 for brevity
if layer_id not in prefill_kv:
continue
for chunk_data in prefill_kv[layer_id]:
chunk_idx = chunk_data['chunk_idx']
captured_k = chunk_data['k'] # [block_size, kv_heads, head_dim]
# CPU block ID should be chunk_idx (based on allocation order)
cpu_block_id = chunk_idx
cpu_k = offload_engine.k_cache_cpu[layer_id, cpu_block_id].cpu()
diff = (captured_k - cpu_k).abs().max().item()
print(f"Layer {layer_id}, Chunk {chunk_idx}: captured_k vs cpu_k max_diff={diff:.6f}")
if diff > 1e-3:
print(f" WARNING: CPU cache doesn't match captured k!")
print(f" captured_k[0,0,:5] = {captured_k[0,0,:5].tolist()}")
print(f" cpu_k[0,0,:5] = {cpu_k[0,0,:5].tolist()}")
print()
# Analyze captures
prefill_count = sum(1 for c in captures if c['is_prefill'])
decode_count = sum(1 for c in captures if not c['is_prefill'])
if verbose:
print(f"\nCaptured {prefill_count} prefill calls, {decode_count} decode calls")
# Count decode steps per layer
decode_per_layer = {}
for c in captures:
if not c['is_prefill']:
layer_id = c['layer_id']
if layer_id not in decode_per_layer:
decode_per_layer[layer_id] = 0
decode_per_layer[layer_id] += 1
if verbose:
print(f"Decode calls per layer: {decode_per_layer}")
print()
# Verify decode correctness
all_passed = True
results = []
first_fail_debug = True
for c in captures:
if c['is_prefill']:
continue # Skip prefill (already tested in test_chunked_prefill_hook.py)
layer_id = c['layer_id']
decode_step = c['decode_step']
# Only test first decode step for now (simpler reference computation)
if decode_step > 0:
continue
# Compute reference (debug first failure)
debug_this = (layer_id == 0 and first_fail_debug)
ref_output = compute_decode_reference(layer_id, decode_step, scale, debug=debug_this)
if ref_output is None:
continue
# Compare
actual_output = c['output'].squeeze(0) # Remove seq dim for decode
if actual_output.dim() == 3:
actual_output = actual_output.squeeze(0) # Handle [1, heads, dim] case
diff = (actual_output - ref_output).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
tol = 1e-2
passed = max_diff < tol
all_passed = all_passed and passed
status = "PASS" if passed else "FAIL"
results.append((layer_id, decode_step, passed, max_diff, mean_diff))
if verbose:
print(f"[{status}] Layer {layer_id:2d}, Decode {decode_step}: "
f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
# Debug first failure
if not passed and first_fail_debug:
first_fail_debug = False
print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
print(f" Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}")
# Find where max diff is
max_idx = diff.argmax()
flat_actual = actual_output.flatten()
flat_ref = ref_output.flatten()
print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")
print()
print("=" * 70)
# Summary
total_tests = len(results)
passed_count = sum(1 for r in results if r[2])
print(f"Results: {passed_count}/{total_tests} tests passed")
if not all_passed:
print("\nFailed tests:")
for layer_id, decode_step, passed, max_diff, mean_diff in results:
if not passed:
print(f" - Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")
print()
return all_passed
# ============================================================
# Main
# ============================================================
if __name__ == "__main__":
passed = run_test(verbose=True)
if passed:
print("test_chunked_decode_hook: PASSED")
else:
print("test_chunked_decode_hook: FAILED")
exit(1)

View File

@@ -0,0 +1,473 @@
"""
Hook-based correctness test for chunked prefill attention.
Uses PyTorch register_forward_hook() to capture real inference I/O,
then compares against reference computation to locate bugs.
This test targets the integration layer (context setup, cpu_block_table management)
which is where the needle test fails despite isolated attention tests passing.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# ============================================================
# Configuration
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 32 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024 # 16K tokens = 16 chunks with 1K block size
BLOCK_SIZE = 1024
# ============================================================
# Global capture storage
# ============================================================
captures = []
# ============================================================
# Hook Functions
# ============================================================
def make_hook(layer_id):
"""Create a forward hook for a specific layer."""
def hook(module, inputs, output):
q, k, v = inputs
ctx = get_context()
# Only capture prefill phase
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
capture_entry = {
'layer_id': layer_id,
'chunk_idx': chunk_idx,
'q': q.clone().cpu(),
'k': k.clone().cpu(),
'v': v.clone().cpu(),
'output': output.clone().cpu(),
'is_chunked_prefill': ctx.is_chunked_prefill,
}
# For debugging: also capture CPU cache state for layer 0
if layer_id == 0 and chunk_idx >= 2:
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
oe = kvcache_manager.offload_engine
# Get what should have been loaded from CPU
cpu_k0 = oe.k_cache_cpu[0, 0].clone().cpu() # Layer 0, CPU block 0
cpu_k1 = oe.k_cache_cpu[0, 1].clone().cpu() # Layer 0, CPU block 1
capture_entry['cpu_k0'] = cpu_k0
capture_entry['cpu_k1'] = cpu_k1
captures.append(capture_entry)
return hook
def register_hooks(llm):
"""Register forward hooks on all Attention modules."""
hooks = []
model = llm.model_runner.model
for layer_idx, decoder_layer in enumerate(model.model.layers):
attn_module = decoder_layer.self_attn.attn
hook = attn_module.register_forward_hook(make_hook(layer_idx))
hooks.append(hook)
return hooks
# ============================================================
# Reference Computation
# ============================================================
def compute_reference(layer_id, chunk_idx, scale, debug=False):
"""
Compute reference attention output for a specific layer and chunk.
Uses the captured k, v from all chunks up to and including chunk_idx.
"""
# Filter captures for this layer
layer_captures = [c for c in captures
if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
if not layer_captures:
return None
# Get current chunk's q
current_capture = [c for c in layer_captures if c['chunk_idx'] == chunk_idx][0]
q = current_capture['q'].cuda().unsqueeze(0) # [1, seqlen, nheads, headdim]
# Collect all k, v up to current chunk
kv_list = []
for c in sorted(layer_captures, key=lambda x: x['chunk_idx']):
k = c['k'].cuda().unsqueeze(0) # [1, seqlen, nheads, headdim]
v = c['v'].cuda().unsqueeze(0)
kv_list.append((k, v, c['chunk_idx']))
if debug:
print(f" Reference for L{layer_id} C{chunk_idx}:")
print(f" q shape: {q.shape}, mean={q.mean().item():.4f}")
print(f" kv_list: {len(kv_list)} chunks")
for i, (k, v, cidx) in enumerate(kv_list):
print(f" chunk {cidx}: k.mean={k.mean().item():.4f}, v.mean={v.mean().item():.4f}")
o_acc, lse_acc = None, None
# Previous chunks: non-causal attention
for i in range(len(kv_list) - 1):
k, v, _ = kv_list[i]
o, lse = flash_attn_with_lse(q, k, v, softmax_scale=scale, causal=False)
if o_acc is None:
o_acc, lse_acc = o, lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)
# Current chunk: causal attention
k_cur, v_cur, _ = kv_list[-1]
o_cur, lse_cur = flash_attn_with_lse(q, k_cur, v_cur, softmax_scale=scale, causal=True)
if o_acc is None:
return o_cur.squeeze(0).cpu()
final_o, _ = merge_attention_outputs(o_acc, lse_acc, o_cur, lse_cur)
return final_o.squeeze(0).cpu()
def compute_standard_reference(layer_id, chunk_idx, scale, debug=False):
"""
Compute reference using standard flash attention (single pass with all K, V).
This simulates what standard (non-chunked) prefill would produce.
Concatenates all Q, K, V from chunks 0 to chunk_idx and runs a single
causal attention pass, then extracts the output for the current chunk.
"""
# Filter captures for this layer
layer_captures = [c for c in captures
if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
if not layer_captures:
return None
# Sort by chunk index
layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])
# Concatenate all Q, K, V
all_q = []
all_k = []
all_v = []
chunk_lengths = []
for c in layer_captures:
q = c['q'].cuda() # [seqlen, nheads, headdim]
k = c['k'].cuda()
v = c['v'].cuda()
all_q.append(q)
all_k.append(k)
all_v.append(v)
chunk_lengths.append(q.shape[0])
# Concatenate along sequence dimension
full_q = torch.cat(all_q, dim=0) # [total_seqlen, nheads, headdim]
full_k = torch.cat(all_k, dim=0)
full_v = torch.cat(all_v, dim=0)
total_len = full_q.shape[0]
if debug:
print(f" Standard Reference for L{layer_id} C{chunk_idx}:")
print(f" full_q shape: {full_q.shape}, mean={full_q.mean().item():.4f}")
print(f" full_k shape: {full_k.shape}, mean={full_k.mean().item():.4f}")
print(f" chunk_lengths: {chunk_lengths}")
# Run standard causal flash attention
# flash_attn_varlen_func expects: q, k, v with shape [total_seqlen, nheads, headdim]
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')
full_o = flash_attn_varlen_func(
full_q, full_k, full_v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_len,
max_seqlen_k=total_len,
softmax_scale=scale,
causal=True,
)
# Extract output for current chunk only
start_pos = sum(chunk_lengths[:-1])
end_pos = sum(chunk_lengths)
chunk_output = full_o[start_pos:end_pos]
if debug:
print(f" full_o shape: {full_o.shape}")
print(f" extracting positions [{start_pos}:{end_pos}]")
print(f" chunk_output shape: {chunk_output.shape}, mean={chunk_output.mean().item():.4f}")
return chunk_output.cpu()
# ============================================================
# Test Runner
# ============================================================
def run_test(verbose=True):
"""Run the hook-based chunked prefill correctness test."""
global captures
captures = []
if verbose:
print("=" * 70)
print("Test: Hook-Based Chunked Prefill Correctness")
print("=" * 70)
print(f"Model: {MODEL_PATH}")
print(f"Input length: {INPUT_LEN} tokens")
print(f"Block size: {BLOCK_SIZE}")
print(f"Expected chunks: {INPUT_LEN // BLOCK_SIZE}")
print()
# Initialize LLM with CPU offload
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
)
# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5
if verbose:
print(f"Num layers: {num_layers}")
print(f"Head dim: {head_dim}")
print()
# Register hooks
hooks = register_hooks(llm)
if verbose:
print(f"Registered {len(hooks)} hooks")
# Generate random prompt
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
# Run prefill only (max_tokens=1)
if verbose:
print("Running inference...")
sampling_params = SamplingParams(temperature=0.6, max_tokens=1)
outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
# Remove hooks
for hook in hooks:
hook.remove()
# Analyze captures
if verbose:
print(f"\nCaptured {len(captures)} attention calls")
# Group by layer and chunk
chunks_per_layer = {}
for c in captures:
layer_id = c['layer_id']
chunk_idx = c['chunk_idx']
if layer_id not in chunks_per_layer:
chunks_per_layer[layer_id] = set()
chunks_per_layer[layer_id].add(chunk_idx)
if verbose:
print("Chunks per layer:", {k: sorted(v) for k, v in chunks_per_layer.items()})
print()
# First, verify CPU cache data integrity
if verbose:
print("\n--- CPU Cache Verification (Layer 0) ---")
# Get original k from chunk 0 and chunk 1 captures
chunk0_k = None
chunk1_k = None
chunk2_capture = None
for c in captures:
if c['layer_id'] == 0:
if c['chunk_idx'] == 0:
chunk0_k = c['k']
elif c['chunk_idx'] == 1:
chunk1_k = c['k']
elif c['chunk_idx'] == 2:
chunk2_capture = c
if chunk0_k is not None and chunk2_capture is not None and 'cpu_k0' in chunk2_capture:
cpu_k0 = chunk2_capture['cpu_k0']
diff_k0 = (chunk0_k - cpu_k0).abs().max().item()
print(f"Chunk 0 k vs CPU cache block 0: max_diff={diff_k0:.6f}")
if diff_k0 > 1e-3:
print(f" WARNING: CPU cache block 0 differs from original chunk 0 k!")
print(f" Original k[0,0,:5] = {chunk0_k[0,0,:5].tolist()}")
print(f" CPU k0[0,0,:5] = {cpu_k0[0,0,:5].tolist()}")
if chunk1_k is not None and chunk2_capture is not None and 'cpu_k1' in chunk2_capture:
cpu_k1 = chunk2_capture['cpu_k1']
diff_k1 = (chunk1_k - cpu_k1).abs().max().item()
print(f"Chunk 1 k vs CPU cache block 1: max_diff={diff_k1:.6f}")
if diff_k1 > 1e-3:
print(f" WARNING: CPU cache block 1 differs from original chunk 1 k!")
print(f" Original k[0,0,:5] = {chunk1_k[0,0,:5].tolist()}")
print(f" CPU k1[0,0,:5] = {cpu_k1[0,0,:5].tolist()}")
print()
# ================================================================
# Test 1: Verify against merge-based reference (same algorithm)
# ================================================================
if verbose:
print("--- Test 1: Merge-based Reference (verifies merge algorithm) ---")
all_passed_merge = True
results_merge = []
first_fail_debug = True
for c in captures:
layer_id = c['layer_id']
chunk_idx = c['chunk_idx']
if chunk_idx == 0:
continue
debug_this = (chunk_idx >= 2 and layer_id == 0 and first_fail_debug)
ref_output = compute_reference(layer_id, chunk_idx, scale, debug=debug_this)
if ref_output is None:
continue
actual_output = c['output']
diff = (actual_output - ref_output).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
tol = 1e-2
passed = max_diff < tol
all_passed_merge = all_passed_merge and passed
status = "PASS" if passed else "FAIL"
results_merge.append((layer_id, chunk_idx, passed, max_diff, mean_diff))
if verbose:
print(f"[{status}] Layer {layer_id:2d}, Chunk {chunk_idx}: "
f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
if not passed and first_fail_debug:
first_fail_debug = False
print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
print(f" Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}")
max_idx = diff.argmax()
flat_actual = actual_output.flatten()
flat_ref = ref_output.flatten()
print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")
print()
# ================================================================
# Test 2: Verify against standard flash attention (single pass)
# ================================================================
if verbose:
print("--- Test 2: Standard FlashAttn Reference (verifies correctness vs non-chunked) ---")
all_passed_standard = True
results_standard = []
first_fail_debug = True
for c in captures:
layer_id = c['layer_id']
chunk_idx = c['chunk_idx']
if chunk_idx == 0:
continue
debug_this = (chunk_idx >= 2 and layer_id == 0 and first_fail_debug)
std_ref_output = compute_standard_reference(layer_id, chunk_idx, scale, debug=debug_this)
if std_ref_output is None:
continue
actual_output = c['output']
diff = (actual_output - std_ref_output).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
tol = 1e-2
passed = max_diff < tol
all_passed_standard = all_passed_standard and passed
status = "PASS" if passed else "FAIL"
results_standard.append((layer_id, chunk_idx, passed, max_diff, mean_diff))
if verbose:
print(f"[{status}] Layer {layer_id:2d}, Chunk {chunk_idx}: "
f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
if not passed and first_fail_debug:
first_fail_debug = False
print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
print(f" Debug: std_ref_output shape={std_ref_output.shape}, mean={std_ref_output.mean().item():.4f}")
max_idx = diff.argmax()
flat_actual = actual_output.flatten()
flat_ref = std_ref_output.flatten()
print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")
print()
print("=" * 70)
# Summary
total_merge = len(results_merge)
passed_merge = sum(1 for r in results_merge if r[2])
total_standard = len(results_standard)
passed_standard = sum(1 for r in results_standard if r[2])
print(f"Merge-based reference: {passed_merge}/{total_merge} tests passed")
print(f"Standard FlashAttn ref: {passed_standard}/{total_standard} tests passed")
all_passed = all_passed_merge and all_passed_standard
if not all_passed_merge:
print("\nFailed merge-based tests:")
for layer_id, chunk_idx, passed, max_diff, mean_diff in results_merge:
if not passed:
print(f" - Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
if not all_passed_standard:
print("\nFailed standard FlashAttn tests:")
for layer_id, chunk_idx, passed, max_diff, mean_diff in results_standard:
if not passed:
print(f" - Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
print()
return all_passed
# ============================================================
# Main
# ============================================================
if __name__ == "__main__":
passed = run_test(verbose=True)
if passed:
print("test_chunked_prefill_hook: PASSED")
else:
print("test_chunked_prefill_hook: FAILED")
exit(1)

View File

@@ -0,0 +1,276 @@
"""
Test script for flash_attn_with_kvcache based chunked prefill.
Verifies that chunked prefill produces identical results to full attention.
"""
import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache
def chunk_prefill(q_full, k_full, v_full, k_cache, v_cache, cache_seqlens, chunk_size):
"""
Chunked prefill using flash_attn_with_kvcache.
Args:
q_full, k_full, v_full: [batch, total_seq_len, heads, head_dim]
k_cache, v_cache: [batch, max_seq_len, kv_heads, head_dim]
cache_seqlens: [batch] - current cache lengths
chunk_size: size of each chunk
Returns:
output: [batch, total_seq_len, heads, head_dim]
"""
total_len = q_full.shape[1]
outputs = []
for start in range(0, total_len, chunk_size):
end = min(start + chunk_size, total_len)
q_chunk = q_full[:, start:end]
k_chunk = k_full[:, start:end]
v_chunk = v_full[:, start:end]
out = flash_attn_with_kvcache(
q_chunk,
k_cache,
v_cache,
k=k_chunk,
v=v_chunk,
cache_seqlens=cache_seqlens,
causal=True,
)
outputs.append(out)
cache_seqlens += (end - start)
return torch.cat(outputs, dim=1)
def reference_attention(q, k, v):
"""Standard flash attention as reference."""
return flash_attn_func(q, k, v, causal=True)
def test_chunked_prefill_correctness():
"""Test that chunked prefill matches full attention."""
batch_size = 1
num_heads = 32
num_kv_heads = 8 # GQA
head_dim = 128
max_seq_len = 131072 # 128K
test_configs = [
(1024, 256), # 1K tokens, 256 chunk
(2048, 512), # 2K tokens, 512 chunk
(4096, 1024), # 4K tokens, 1K chunk
(4096, 2048), # 4K tokens, 2K chunk (2 chunks)
(8192, 2048), # 8K tokens, 2K chunk (4 chunks)
(16384, 4096), # 16K tokens, 4K chunk
(32768, 4096), # 32K tokens, 4K chunk
(65536, 8192), # 64K tokens, 8K chunk
(131072, 8192), # 128K tokens, 8K chunk (16 chunks)
]
for seq_len, chunk_size in test_configs:
print(f"\nTesting seq_len={seq_len}, chunk_size={chunk_size}...")
# Generate random input
torch.manual_seed(42)
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
# Expand K/V for non-GQA reference
k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
# Reference: full attention
ref_out = reference_attention(q, k_expanded, v_expanded)
# Chunked prefill with KV cache
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
chunked_out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)
# Compare
max_diff = (ref_out - chunked_out).abs().max().item()
mean_diff = (ref_out - chunked_out).abs().mean().item()
# Verify cache was filled correctly
assert cache_seqlens[0].item() == seq_len, f"Cache seqlen mismatch: {cache_seqlens[0].item()} != {seq_len}"
# Check K/V cache content
k_cache_diff = (k_cache[:, :seq_len] - k).abs().max().item()
v_cache_diff = (v_cache[:, :seq_len] - v).abs().max().item()
print(f" Output max_diff: {max_diff:.6f}, mean_diff: {mean_diff:.6f}")
print(f" KV cache diff: k={k_cache_diff:.6f}, v={v_cache_diff:.6f}")
# Tolerance for fp16
tolerance = 1e-2
if max_diff < tolerance:
print(f" PASSED")
else:
print(f" FAILED (max_diff {max_diff:.6f} >= {tolerance})")
return False
return True
def test_incremental_decode():
"""Test that decode after chunked prefill works correctly."""
batch_size = 1
num_heads = 32
num_kv_heads = 8
head_dim = 128
max_seq_len = 8192
prefill_len = 2048
chunk_size = 512
num_decode_steps = 10
print(f"\nTesting incremental decode after chunked prefill...")
print(f" Prefill: {prefill_len} tokens, chunk_size={chunk_size}")
print(f" Decode: {num_decode_steps} steps")
torch.manual_seed(42)
# Prefill phase
q_prefill = torch.randn(batch_size, prefill_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
# Run chunked prefill
prefill_out = chunk_prefill(q_prefill, k_prefill, v_prefill,
k_cache, v_cache, cache_seqlens, chunk_size)
print(f" After prefill: cache_seqlens={cache_seqlens[0].item()}")
# Decode phase - one token at a time
for step in range(num_decode_steps):
q_decode = torch.randn(batch_size, 1, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
decode_out = flash_attn_with_kvcache(
q_decode,
k_cache,
v_cache,
k=k_decode,
v=v_decode,
cache_seqlens=cache_seqlens,
causal=True,
)
cache_seqlens += 1
assert decode_out.shape == (batch_size, 1, num_heads, head_dim)
expected_len = prefill_len + num_decode_steps
actual_len = cache_seqlens[0].item()
print(f" After decode: cache_seqlens={actual_len}")
if actual_len == expected_len:
print(f" PASSED")
return True
else:
print(f" FAILED: expected {expected_len}, got {actual_len}")
return False
def test_batch_processing():
"""Test chunked prefill with batch > 1."""
batch_size = 4
num_heads = 32
num_kv_heads = 8
head_dim = 128
max_seq_len = 4096
seq_len = 2048
chunk_size = 512
print(f"\nTesting batch processing (batch_size={batch_size})...")
torch.manual_seed(42)
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)
# Verify all batches have correct cache length
assert (cache_seqlens == seq_len).all(), f"Cache seqlens mismatch: {cache_seqlens}"
assert out.shape == (batch_size, seq_len, num_heads, head_dim)
# Compare with reference for each batch item
k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
ref_out = reference_attention(q, k_expanded, v_expanded)
max_diff = (ref_out - out).abs().max().item()
print(f" Output shape: {out.shape}")
print(f" Max diff vs reference: {max_diff:.6f}")
if max_diff < 1e-2:
print(f" PASSED")
return True
else:
print(f" FAILED")
return False
# ============================================================
# Main Test Script
# ============================================================
if __name__ == "__main__":
print("=" * 60)
print("Testing flash_attn_with_kvcache chunked prefill")
print("=" * 60)
all_passed = True
all_passed &= test_chunked_prefill_correctness()
all_passed &= test_incremental_decode()
all_passed &= test_batch_processing()
print("\n" + "=" * 60)
if all_passed:
print("test_flash_attn_kvcache: ALL TESTS PASSED")
else:
print("test_flash_attn_kvcache: SOME TESTS FAILED")
print("=" * 60)

322
tests/test_needle.py Normal file
View File

@@ -0,0 +1,322 @@
"""
Needle-in-a-haystack test for LLM.
Tests: Long context retrieval capability with configurable sequence length.
NOTE: CPU offload mode has a known bug that causes incorrect outputs for
sequences longer than ~200 tokens. Run without --enable-offload (the default) for correctness testing.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
import argparse
from nanovllm import LLM, SamplingParams
# ============================================================
# Needle Test Generator
# ============================================================
def generate_needle_prompt(
tokenizer,
target_length: int,
needle_position: float = 0.5,
needle_value: str = "7492",
use_chat_template: bool = True,
) -> tuple[str, str]:
"""
Generate a needle-in-haystack prompt of approximately target_length tokens.
Args:
tokenizer: HuggingFace tokenizer for length estimation
target_length: Target total sequence length in tokens
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
needle_value: The secret value to hide in the haystack
use_chat_template: Whether to use chat template for instruct models
Returns:
(prompt, expected_answer): The full prompt and the expected needle value
"""
# Haystack filler paragraphs (various topics to create realistic context)
haystack_paragraphs = [
"The weather today is quite pleasant with clear skies and moderate temperatures. "
"Many people are enjoying outdoor activities in the park. "
"Birds are singing in the trees and children are playing on the swings. ",
"In the world of technology, new innovations continue to emerge every day. "
"Researchers are working on advanced algorithms and computing systems. "
"The future of artificial intelligence looks promising with many breakthroughs. ",
"The history of human civilization spans thousands of years. "
"Ancient cultures developed writing, mathematics, and astronomy. "
"Trade routes connected distant lands and facilitated cultural exchange. ",
"Modern cooking combines traditional techniques with new ingredients. "
"Chefs around the world experiment with flavors and presentations. "
"Food brings people together and creates memorable experiences. ",
"The ocean covers more than seventy percent of Earth's surface. "
"Marine ecosystems support an incredible diversity of life forms. "
"Scientists continue to discover new species in the deep sea. ",
"Music has been a part of human culture since prehistoric times. "
"Different genres evolved across various regions and time periods. "
"Today, people can access millions of songs through digital platforms. ",
"Space exploration has revealed many secrets about our universe. "
"Telescopes can observe galaxies billions of light years away. "
"Future missions aim to establish human presence on other planets. ",
"The study of languages reveals patterns in human cognition. "
"Linguists analyze grammar, semantics, and phonetics across cultures. "
"Language continues to evolve with new words and expressions. ",
]
# The needle sentence
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
# Question at the end
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
# Estimate tokens for fixed parts
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
# Buffer for chat template, special tokens, etc.
overhead_tokens = 100 if use_chat_template else 50
# Available tokens for haystack
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
if haystack_target_tokens < 100:
raise ValueError(f"target_length {target_length} is too short for needle test")
# Build haystack by repeating paragraphs
haystack_parts = []
current_tokens = 0
para_idx = 0
while current_tokens < haystack_target_tokens:
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
if current_tokens + para_tokens > haystack_target_tokens:
break
haystack_parts.append(para)
current_tokens += para_tokens
para_idx += 1
# Calculate needle insertion point
needle_idx = int(len(haystack_parts) * needle_position)
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
# Insert needle
haystack_parts.insert(needle_idx, needle)
# Assemble prompt
full_text = "".join(haystack_parts)
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
# Use chat template for instruct models
# For Qwen3, add /no_think to disable thinking mode
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
messages = [
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
else:
# Raw text format for base models
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
prompt = full_text + question
# Verify length
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
return prompt, needle_value
def check_needle_answer(output_text: str, expected: str) -> bool:
"""Check if the model output contains the expected needle value."""
import re
# Clean output - remove special tokens and whitespace
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
output_clean = ' '.join(output_clean.split()).lower()
expected_clean = expected.strip().lower()
# Check if expected value appears in output
# Also try to find it as a standalone number
if expected_clean in output_clean:
return True
# Try to extract numbers and check if expected is among them
numbers = re.findall(r'\d+', output_clean)
return expected_clean in numbers
# ============================================================
# Main Test
# ============================================================
def run_needle_test(
model_path: str,
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
needle_position: float = 0.5,
needle_value: str = "7492",
max_new_tokens: int = 32,
enable_cpu_offload: bool = False,
verbose: bool = True,
) -> bool:
"""
Run a needle-in-haystack test.
Args:
model_path: Path to model
max_model_len: Maximum model context length
input_len: Target input sequence length
num_gpu_blocks: Number of GPU blocks for offload
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
enable_cpu_offload: Enable CPU offload mode
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
if verbose:
print(f"\n{'='*60}")
print(f"Needle-in-Haystack Test")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"CPU offload: {enable_cpu_offload}")
print(f"{'='*60}\n")
# 1. Initialize LLM
llm_kwargs = {
"enforce_eager": True,
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm = LLM(model_path, **llm_kwargs)
# 2. Generate needle prompt
prompt, expected = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=needle_position,
needle_value=needle_value,
)
# 3. Generate output
sampling_params = SamplingParams(
temperature=0.6, # Moderate temperature
max_tokens=max_new_tokens,
)
outputs = llm.generate([prompt], sampling_params, use_tqdm=True)
# 4. Check result
output_text = outputs[0]["text"]
output_token_ids = outputs[0]["token_ids"]
passed = check_needle_answer(output_text, expected)
if verbose:
print(f"\n{'='*60}")
print(f"Result")
print(f"{'='*60}")
print(f"Expected: {expected}")
print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
print(f"Output: {output_text[:200]}...")
print(f"Status: {'PASSED' if passed else 'FAILED'}")
print(f"{'='*60}\n")
return passed
# ============================================================
# CLI Entry Point
# ============================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Needle-in-haystack test for long context LLM")
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
help="Path to model"
)
parser.add_argument(
"--max-model-len",
type=int,
default=32 * 1024,
help="Maximum model context length"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--num-gpu-blocks",
type=int,
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--needle-position",
type=float,
default=0.5,
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
)
parser.add_argument(
"--needle-value",
type=str,
default="7492",
help="The secret value to hide"
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=32,
help="Maximum tokens to generate"
)
parser.add_argument(
"--enable-offload",
action="store_true",
help="Enable CPU offload (has known bug for long sequences)"
)
args = parser.parse_args()
passed = run_needle_test(
model_path=args.model,
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
needle_position=args.needle_position,
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,
enable_cpu_offload=args.enable_offload,
verbose=True,
)
if passed:
print("test_needle: PASSED")
else:
print("test_needle: FAILED")
exit(1)

View File

@@ -0,0 +1,573 @@
"""
Correctness test for chunked attention with CPU offload.
Validates that the offload pipeline (CPU -> GPU transfer + chunked attention)
produces the same result as direct GPU computation.
Test scenario:
1. Generate Q, K, V data
2. Reference: Compute full causal attention on GPU
3. Offload: Store K, V in CPU cache, load via pipeline, compute chunked attention
4. Compare results
This test is designed to identify bugs in:
- CPU <-> GPU data transfer (sgDMA)
- Ring buffer slot management
- N-way pipeline ordering
- Triton merge kernel correctness
"""
import torch
from flash_attn.flash_attn_interface import flash_attn_func
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
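# Note: merge_attention_outputs(o1, lse1, o2, lse2) is assumed to combine two
# partial attention results over disjoint key sets via the standard
# log-sum-exp identity (reference sketch, not executed here):
#   m       = max(lse1, lse2)
#   w1, w2  = exp(lse1 - m), exp(lse2 - m)
#   lse_out = m + log(w1 + w2)
#   out     = (w1 * o1 + w2 * o2) / (w1 + w2)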
# ============================================================
# Configuration
# ============================================================
NUM_LAYERS = 4
NUM_HEADS = 8
NUM_KV_HEADS = 8
HEAD_DIM = 64
BLOCK_SIZE = 256 # Smaller for faster testing
DTYPE = torch.bfloat16
DEVICE = "cuda"
# ============================================================
# Reference Implementation (GPU only, no offload)
# ============================================================
def compute_reference_causal(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""
Compute reference causal attention using flash_attn_func.
Args:
q, k, v: [batch, seqlen, nheads, headdim]
Returns:
out: [batch, seqlen, nheads, headdim]
"""
return flash_attn_func(q, k, v, causal=True)
def compute_reference_chunked(
q_chunks: list,
kv_chunks: list,
scale: float,
) -> torch.Tensor:
"""
Compute chunked prefill attention directly on GPU (no offload).
This is the "gold standard" for chunked attention correctness.
Args:
q_chunks: List of [batch, chunk_size, nheads, headdim]
kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim]
scale: Softmax scale
Returns:
out: [batch, total_seqlen, nheads, headdim]
"""
out_chunks = []
for chunk_idx, q_chunk in enumerate(q_chunks):
o_acc, lse_acc = None, None
# Attend to all previous chunks (no causal mask)
for i in range(chunk_idx):
k_chunk, v_chunk = kv_chunks[i]
chunk_o, chunk_lse = flash_attn_with_lse(
q_chunk, k_chunk, v_chunk,
softmax_scale=scale,
causal=False,
)
if o_acc is None:
o_acc, lse_acc = chunk_o, chunk_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)
# Attend to current chunk (with causal mask)
k_chunk, v_chunk = kv_chunks[chunk_idx]
current_o, current_lse = flash_attn_with_lse(
q_chunk, k_chunk, v_chunk,
softmax_scale=scale,
causal=True,
)
if o_acc is None:
final_o = current_o
else:
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
out_chunks.append(final_o)
return torch.cat(out_chunks, dim=1)
# ============================================================
# Offload Implementation
# ============================================================
def create_manager(num_gpu_slots: int, num_cpu_blocks: int):
"""Create HybridKVCacheManager with specified configuration."""
manager = HybridKVCacheManager(
num_gpu_slots=num_gpu_slots,
num_cpu_blocks=num_cpu_blocks,
block_size=BLOCK_SIZE,
)
manager.allocate_cache(
num_layers=NUM_LAYERS,
num_kv_heads=NUM_KV_HEADS,
head_dim=HEAD_DIM,
dtype=DTYPE,
)
return manager
def store_kv_to_cpu_cache(manager, kv_chunks: list, layer_id: int):
"""
Store K, V chunks to CPU cache.
Args:
manager: HybridKVCacheManager
kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim]
layer_id: Layer index
Returns:
cpu_block_ids: List of CPU block IDs
"""
offload_engine = manager.offload_engine
cpu_block_ids = []
for block_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
# k_chunk, v_chunk: [batch, chunk_size, nheads, headdim]
# CPU cache layout: [num_layers, num_blocks, block_size, nheads, headdim]
k_data = k_chunk.squeeze(0) # [chunk_size, nheads, headdim]
v_data = v_chunk.squeeze(0)
offload_engine.k_cache_cpu[layer_id, block_idx, :k_data.shape[0]].copy_(k_data)
offload_engine.v_cache_cpu[layer_id, block_idx, :v_data.shape[0]].copy_(v_data)
cpu_block_ids.append(block_idx)
return cpu_block_ids
def compute_offload_chunked_single_layer(
manager,
q_chunks: list,
cpu_block_ids: list,
layer_id: int,
scale: float,
) -> torch.Tensor:
"""
Compute chunked attention for a single layer using offload pipeline.
This mimics the behavior of Attention._ring_buffer_pipeline_load().
Args:
manager: HybridKVCacheManager
q_chunks: List of [batch, chunk_size, nheads, headdim]
cpu_block_ids: List of CPU block IDs containing K, V data
layer_id: Layer index
scale: Softmax scale
Returns:
out: [batch, total_seqlen, nheads, headdim]
"""
offload_engine = manager.offload_engine
out_chunks = []
for chunk_idx, q_chunk in enumerate(q_chunks):
# CPU blocks to load: all blocks before current chunk
blocks_to_load = cpu_block_ids[:chunk_idx]
# Get slots for this chunk
write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
# Load and compute attention for previous chunks
o_acc, lse_acc = None, None
if len(blocks_to_load) > 0 and len(load_slots) > 0:
o_acc, lse_acc = _pipeline_load_and_compute(
offload_engine,
q_chunk,
blocks_to_load,
load_slots,
layer_id,
scale,
)
# Current chunk's K, V (load from CPU to GPU slot)
current_cpu_block = cpu_block_ids[chunk_idx]
offload_engine.load_to_slot_layer(write_slot, layer_id, current_cpu_block)
offload_engine.wait_slot_layer(write_slot, layer_id)
current_k, current_v = offload_engine.get_kv_for_slot(write_slot, layer_id)
# Compute attention with causal mask
current_o, current_lse = flash_attn_with_lse(
q_chunk, current_k, current_v,
softmax_scale=scale,
causal=True,
)
# Merge
if o_acc is None:
final_o = current_o
else:
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
out_chunks.append(final_o)
return torch.cat(out_chunks, dim=1)
def _pipeline_load_and_compute(
offload_engine,
q_chunk: torch.Tensor,
cpu_block_table: list,
load_slots: list,
layer_id: int,
scale: float,
):
"""
Pipeline loading from CPU and computing attention.
Mirrors Attention._ring_buffer_pipeline_load() logic.
"""
num_blocks = len(cpu_block_table)
num_slots = len(load_slots)
o_acc, lse_acc = None, None
# Phase 1: Pre-load up to num_slots blocks
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
# Phase 2: Main loop
compute_stream = offload_engine.compute_stream
for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots]
# Wait for transfer
offload_engine.wait_slot_layer(current_slot, layer_id)
# Compute on dedicated stream
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_chunk, prev_k, prev_v,
softmax_scale=scale,
causal=False,
)
offload_engine.record_slot_compute_done(current_slot, layer_id)
# Start next transfer
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(
current_slot, layer_id, cpu_block_table[next_block_idx]
)
# Merge
with torch.cuda.stream(compute_stream):
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Sync compute stream
compute_stream.synchronize()
return o_acc, lse_acc
# ============================================================
# Test Runner
# ============================================================
def run_correctness_test(
num_chunks: int,
num_gpu_slots: int,
verbose: bool = True,
) -> tuple[bool, float, float]:
"""
Run a single correctness test.
Args:
num_chunks: Number of chunks (= number of CPU blocks)
num_gpu_slots: Number of GPU ring buffer slots
verbose: Print detailed info
Returns:
(passed, max_diff, mean_diff)
"""
torch.manual_seed(42)
seqlen = num_chunks * BLOCK_SIZE
scale = HEAD_DIM ** -0.5
# Generate Q, K, V
q_full = torch.randn(1, seqlen, NUM_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
k_full = torch.randn(1, seqlen, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
v_full = torch.randn(1, seqlen, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
# Split into chunks
q_chunks = [q_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE] for i in range(num_chunks)]
kv_chunks = [
(k_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE],
v_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE])
for i in range(num_chunks)
]
# Reference: chunked attention on GPU (no offload)
out_ref = compute_reference_chunked(q_chunks, kv_chunks, scale)
# Create manager with enough CPU blocks
manager = create_manager(num_gpu_slots, num_chunks)
# Test each layer
all_passed = True
max_diff_all = 0.0
mean_diff_all = 0.0
for layer_id in range(NUM_LAYERS):
# Store K, V to CPU cache
cpu_block_ids = store_kv_to_cpu_cache(manager, kv_chunks, layer_id)
# Compute with offload
out_offload = compute_offload_chunked_single_layer(
manager, q_chunks, cpu_block_ids, layer_id, scale
)
# Compare
diff = (out_ref - out_offload).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
max_diff_all = max(max_diff_all, max_diff)
mean_diff_all = max(mean_diff_all, mean_diff)  # track the worst per-layer mean (not an average)
tol = 1e-2
passed = max_diff < tol
all_passed = all_passed and passed
if verbose and not passed:
print(f" Layer {layer_id}: FAIL max_diff={max_diff:.6f}")
return all_passed, max_diff_all, mean_diff_all
# ============================================================
# Decode Phase Test
# ============================================================
def run_decode_correctness_test(
num_prefill_chunks: int,
num_gpu_slots: int,
num_decode_steps: int = 4,
verbose: bool = True,
) -> tuple[bool, float, float]:
"""
Test decode phase correctness with CPU offload.
Simulates:
1. Prefill: Store K, V for multiple chunks in CPU cache
2. Decode: Single token queries against all prefilled K, V
This mirrors the needle-test scenario where each decode step reads all previously prefilled K, V.
"""
torch.manual_seed(42)
scale = HEAD_DIM ** -0.5
prefill_len = num_prefill_chunks * BLOCK_SIZE
# Generate prefill K, V (store in CPU)
k_prefill = torch.randn(1, prefill_len, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
v_prefill = torch.randn(1, prefill_len, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
# Split into chunks for CPU storage
kv_chunks = [
(k_prefill[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE],
v_prefill[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE])
for i in range(num_prefill_chunks)
]
# Create manager
manager = create_manager(num_gpu_slots, num_prefill_chunks)
offload_engine = manager.offload_engine
all_passed = True
max_diff_all = 0.0
mean_diff_all = 0.0
for layer_id in range(NUM_LAYERS):
# Store prefilled K, V to CPU cache
cpu_block_ids = store_kv_to_cpu_cache(manager, kv_chunks, layer_id)
for decode_step in range(num_decode_steps):
# Decode query: single token
q_decode = torch.randn(1, 1, NUM_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
# Reference: direct attention on GPU
# Concat all prefilled K, V and compute attention
out_ref = flash_attn_func(
q_decode,
k_prefill,
v_prefill,
causal=False, # Decode query can attend to all prefilled tokens
)
# Offload: load from CPU and compute
load_slots = offload_engine.get_load_slots_for_prefill(0)  # write_slot=0 is the decode slot; use the remaining slots for loading
if len(load_slots) == 0 or len(cpu_block_ids) == 0:
# No previous chunks to load
out_offload = out_ref # Trivially equal
else:
o_acc, lse_acc = _pipeline_load_and_compute(
offload_engine,
q_decode,
cpu_block_ids,
load_slots,
layer_id,
scale,
)
out_offload = o_acc
# Compare
diff = (out_ref - out_offload).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
max_diff_all = max(max_diff_all, max_diff)
mean_diff_all = max(mean_diff_all, mean_diff)  # track the worst per-step mean (not an average)
tol = 1e-2
passed = max_diff < tol
all_passed = all_passed and passed
if verbose and not passed:
print(f" Layer {layer_id} Step {decode_step}: FAIL max_diff={max_diff:.6f}")
return all_passed, max_diff_all, mean_diff_all
# ============================================================
# Main Test Script
# ============================================================
if __name__ == "__main__":
print("=" * 70)
print("Test: Offload Chunked Attention Correctness")
print("=" * 70)
print(f"Config: layers={NUM_LAYERS}, heads={NUM_HEADS}, kv_heads={NUM_KV_HEADS}, "
f"head_dim={HEAD_DIM}, block_size={BLOCK_SIZE}, dtype={DTYPE}")
print()
print("Comparing: Reference (GPU chunked) vs Offload (CPU->GPU pipeline)")
print()
# Test configurations: (num_chunks, num_gpu_slots)
TEST_CASES = [
# Basic tests
(2, 2), # Minimal: 2 chunks, 2 slots (no pipeline)
(2, 3), # 2 chunks, 3 slots (1-slot pipeline)
(4, 2), # 4 chunks, 2 slots (heavy slot reuse)
(4, 3), # 4 chunks, 3 slots
(4, 4), # 4 chunks, 4 slots
# Stress tests
(8, 3), # Many chunks, few slots
(8, 4), # Many chunks, moderate slots
(8, 6), # Many chunks, many slots (like bench_offload)
# Edge cases
(1, 2), # Single chunk
(3, 5), # Fewer chunks than slots
]
all_passed = True
results = []
for num_chunks, num_gpu_slots in TEST_CASES:
seqlen = num_chunks * BLOCK_SIZE
passed, max_diff, mean_diff = run_correctness_test(
num_chunks, num_gpu_slots, verbose=False
)
all_passed = all_passed and passed
status = "PASS" if passed else "FAIL"
results.append((num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff))
print(f"[{status}] chunks={num_chunks:2d} slots={num_gpu_slots:2d} "
f"seqlen={seqlen:5d} max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
print()
# ================================================================
# Part 2: Decode Phase Tests
# ================================================================
print("=" * 70)
print("Part 2: Decode Phase Correctness")
print("=" * 70)
print("Testing: Decode query (single token) against all prefilled K, V")
print()
DECODE_TEST_CASES = [
# (num_prefill_chunks, num_gpu_slots)
(2, 2),
(4, 3),
(4, 4),
(8, 4),
(8, 6),
]
decode_results = []
for num_prefill_chunks, num_gpu_slots in DECODE_TEST_CASES:
prefill_len = num_prefill_chunks * BLOCK_SIZE
passed, max_diff, mean_diff = run_decode_correctness_test(
num_prefill_chunks, num_gpu_slots, num_decode_steps=4, verbose=False
)
all_passed = all_passed and passed
status = "PASS" if passed else "FAIL"
decode_results.append((num_prefill_chunks, num_gpu_slots, prefill_len, passed, max_diff, mean_diff))
print(f"[{status}] prefill_chunks={num_prefill_chunks:2d} slots={num_gpu_slots:2d} "
f"prefill_len={prefill_len:5d} max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
print()
print("=" * 70)
# Summary
prefill_passed = sum(1 for r in results if r[3])
decode_passed = sum(1 for r in decode_results if r[3])
total_tests = len(results) + len(decode_results)
total_passed = prefill_passed + decode_passed
print(f"Results: {total_passed}/{total_tests} tests passed")
print(f" - Prefill: {prefill_passed}/{len(results)}")
print(f" - Decode: {decode_passed}/{len(decode_results)}")
if not all_passed:
print("\nFailed tests:")
for num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff in results:
if not passed:
print(f" - [Prefill] chunks={num_chunks}, slots={num_gpu_slots}, "
f"seqlen={seqlen}, max_diff={max_diff:.6f}")
for num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff in decode_results:
if not passed:
print(f" - [Decode] prefill_chunks={num_chunks}, slots={num_gpu_slots}, "
f"prefill_len={seqlen}, max_diff={max_diff:.6f}")
print()
assert all_passed, "Some correctness tests failed!"
print("test_offload_correctness: PASSED")