diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py index 637b5f6..d91326a 100644 --- a/nanovllm/engine/llm_engine.py +++ b/nanovllm/engine/llm_engine.py @@ -31,6 +31,8 @@ class LLMEngine: self.model_runner = ModelRunner(config, 0, self.events) self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True) config.eos = self.tokenizer.eos_token_id + # Set Sequence.block_size to match the KV cache block size + Sequence.block_size = config.kvcache_block_size self.scheduler = Scheduler(config, self.model_runner.kvcache_manager) atexit.register(self.exit) diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 212036a..3e281e6 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -521,6 +521,7 @@ class ModelRunner: print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr) # Sample from last logits + # For chunked prefill, ParallelLMHead automatically selects last position's logits temperatures = self.prepare_sample(seqs) if self.rank == 0 else None if logits is not None: token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None diff --git a/nanovllm/kvcache/chunked_attention.py b/nanovllm/kvcache/chunked_attention.py index 862fd5a..6f92c33 100644 --- a/nanovllm/kvcache/chunked_attention.py +++ b/nanovllm/kvcache/chunked_attention.py @@ -281,7 +281,11 @@ def _merge_lse_kernel( num_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - """Fused kernel for merging LSE values.""" + """Fused kernel for merging LSE values. + + IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss. + bf16 has only 7 bits of mantissa, causing significant errors in exp/log. + """ # Each program handles BLOCK_SIZE elements pid = tl.program_id(0) block_start = pid * BLOCK_SIZE @@ -289,21 +293,21 @@ def _merge_lse_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < num_elements - # Load lse values - lse1 = tl.load(lse1_ptr + offsets, mask=mask) - lse2 = tl.load(lse2_ptr + offsets, mask=mask) + # Load lse values and convert to fp32 for precision + lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32) + lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32) - # Compute max for numerical stability + # Compute max for numerical stability (in fp32) max_lse = tl.maximum(lse1, lse2) - # Compute exp(lse - max_lse) + # Compute exp(lse - max_lse) in fp32 exp1 = tl.exp(lse1 - max_lse) exp2 = tl.exp(lse2 - max_lse) - # Compute merged LSE: max_lse + log(exp1 + exp2) + # Compute merged LSE: max_lse + log(exp1 + exp2) in fp32 lse_merged = max_lse + tl.log(exp1 + exp2) - # Store result + # Store result (convert back to original dtype) tl.store(lse_out_ptr + offsets, lse_merged, mask=mask) @@ -313,7 +317,11 @@ def _merge_output_kernel( batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - """Fused kernel for merging attention outputs.""" + """Fused kernel for merging attention outputs. + + IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss. + This is critical for numerical accuracy in chunked attention. 
+ """ # Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position pid_batch = tl.program_id(0) pid_seq = tl.program_id(1) @@ -322,11 +330,11 @@ def _merge_output_kernel( # Compute LSE index: [batch, nheads, seqlen_q] lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq - # Load LSE values - lse1 = tl.load(lse1_ptr + lse_idx) - lse2 = tl.load(lse2_ptr + lse_idx) + # Load LSE values and convert to fp32 for precision + lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32) + lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32) - # Compute max and scaling factors + # Compute max and scaling factors in fp32 max_lse = tl.maximum(lse1, lse2) exp1 = tl.exp(lse1 - max_lse) exp2 = tl.exp(lse2 - max_lse) @@ -343,14 +351,14 @@ def _merge_output_kernel( pid_head * headdim) o_idx = base_idx + d_idx - # Load o1, o2 - o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0) - o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0) + # Load o1, o2 and convert to fp32 for weighted sum + o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32) + o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32) - # Compute merged output: (o1 * exp1 + o2 * exp2) / sum_exp + # Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp - # Store result + # Store result (Triton will convert back to original dtype) tl.store(o_out_ptr + o_idx, o_merged, mask=mask) diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py index a4004b6..937c626 100644 --- a/nanovllm/kvcache/hybrid_manager.py +++ b/nanovllm/kvcache/hybrid_manager.py @@ -337,10 +337,10 @@ class HybridKVCacheManager(KVCacheManager): block = self.logical_blocks[logical_id] if block.location == BlockLocation.CPU: cpu_blocks.append(block.cpu_block_id) - logger.debug( - f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, " - f"returned cpu_blocks={cpu_blocks}" - ) + # logger.debug( + # f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, " + # f"returned cpu_blocks={cpu_blocks}" + # ) return cpu_blocks # ========== Ring Buffer CPU-primary support ========== diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index 46011b0..bf8cc16 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -538,7 +538,7 @@ class OffloadEngine: def sync_indices(self) -> None: """Synchronize to ensure all index updates are complete.""" - torch.cuda.current_stream().synchronize() + torch.cuda.default_stream().synchronize() # ========== Cache access methods ========== @@ -682,8 +682,9 @@ class OffloadEngine: Async load a single CPU block to a ring buffer slot for one layer. This is the core building block for ring buffer pipelining. - Before starting the transfer, waits for any previous compute on this slot - to complete (using compute_done event). + Before starting the transfer, waits for: + 1. Any previous compute on this slot to complete + 2. 
Any pending offload of this slot to complete Args: slot_idx: Target GPU slot index @@ -701,6 +702,10 @@ class OffloadEngine: # This prevents data race: transfer must not start until attention finishes reading stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id]) + # Also wait for any pending offload of this slot to complete + # This prevents race: load must not write GPU slot while offload is reading from it + stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx]) + self.k_cache_gpu[layer_id, slot_idx].copy_( self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True ) @@ -763,7 +768,11 @@ class OffloadEngine: torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]") with torch.cuda.stream(self.transfer_stream_main): + # Wait for both compute_stream and default stream + # - compute_stream: for flash attention operations + # - default_stream: for store_kvcache which runs on default stream self.transfer_stream_main.wait_stream(self.compute_stream) + self.transfer_stream_main.wait_stream(torch.cuda.default_stream()) memcpy_2d_async( self.k_cache_cpu[:, cpu_block_id], self.k_cache_gpu[:, slot_idx], @@ -793,7 +802,9 @@ class OffloadEngine: cpu_block_id: Target CPU block ID """ with torch.cuda.stream(self.transfer_stream_main): + # Wait for both compute_stream and default stream self.transfer_stream_main.wait_stream(self.compute_stream) + self.transfer_stream_main.wait_stream(torch.cuda.default_stream()) self.k_cache_cpu[layer_id, cpu_block_id].copy_( self.k_cache_gpu[layer_id, slot_idx], non_blocking=True ) diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 33e8d2a..2caac7e 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -169,9 +169,11 @@ class Attention(nn.Module): else: # Use ring buffer pipeline o_acc, lse_acc = self._ring_buffer_pipeline_load( - q_batched, cpu_block_table, load_slots, offload_engine + q_batched, cpu_block_table, load_slots, offload_engine, + current_chunk_idx ) + # Compute attention against current chunk's KV (with causal mask) torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)") current_o, current_lse = flash_attn_with_lse( @@ -187,11 +189,18 @@ class Attention(nn.Module): if o_acc is None: final_o = current_o else: + # IMPORTANT: o_acc was computed on compute_stream. We need to sync before + # reading it on the default stream for the merge operation. 
+ if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'): + offload_engine = kvcache_manager.offload_engine + torch.cuda.default_stream().wait_stream(offload_engine.compute_stream) + torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}") final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop() # ChunkedPrefill + # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim] return final_o.squeeze(0) @@ -205,24 +214,27 @@ class Attention(nn.Module): from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs o_acc, lse_acc = None, None + compute_stream = offload_engine.compute_stream for block_idx, cpu_block_id in enumerate(cpu_block_table): # Load to slot 0 (single slot) offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id) offload_engine.wait_slot_layer(0, self.layer_id) - prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id) + # IMPORTANT: Must use compute_stream to match wait_slot_layer + with torch.cuda.stream(compute_stream): + prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id) - prev_o, prev_lse = flash_attn_with_lse( - q_batched, prev_k, prev_v, - softmax_scale=self.scale, - causal=False, - ) + prev_o, prev_lse = flash_attn_with_lse( + q_batched, prev_k, prev_v, + softmax_scale=self.scale, + causal=False, + ) - if o_acc is None: - o_acc, lse_acc = prev_o, prev_lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + if o_acc is None: + o_acc, lse_acc = prev_o, prev_lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) return o_acc, lse_acc @@ -232,6 +244,7 @@ class Attention(nn.Module): cpu_block_table: list, load_slots: list, offload_engine, + current_chunk_idx: int = -1, ): """ Ring buffer async pipeline loading with double buffering. 
@@ -269,22 +282,26 @@ class Attention(nn.Module): if pipeline_depth == 1: # Only 1 slot available, cannot pipeline - use synchronous mode + # IMPORTANT: Must use compute_stream to match synchronization in + # load_to_slot_layer (waits for compute_done) and wait_slot_layer slot = load_slots[0] + compute_stream = offload_engine.compute_stream for block_idx in range(num_blocks): offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx]) offload_engine.wait_slot_layer(slot, self.layer_id) - prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id) - prev_o, prev_lse = flash_attn_with_lse( - q_batched, prev_k, prev_v, - softmax_scale=self.scale, - causal=False, - ) - # Record compute done so next load can safely reuse this slot - offload_engine.record_slot_compute_done(slot, self.layer_id) - if o_acc is None: - o_acc, lse_acc = prev_o, prev_lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + with torch.cuda.stream(compute_stream): + prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id) + prev_o, prev_lse = flash_attn_with_lse( + q_batched, prev_k, prev_v, + softmax_scale=self.scale, + causal=False, + ) + # Record compute done so next load can safely reuse this slot + offload_engine.record_slot_compute_done(slot, self.layer_id) + if o_acc is None: + o_acc, lse_acc = prev_o, prev_lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) return o_acc, lse_acc # N-way pipeline: use ALL available slots for maximum overlap @@ -378,12 +395,13 @@ class Attention(nn.Module): kvcache_manager = context.kvcache_manager seq = context.chunked_seq - # Get all CPU blocks for this sequence - cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq) + # Get only PREFILLED CPU blocks (exclude the current decode block) + # The decode block's KV is still in GPU decode_slot, not yet offloaded to CPU + cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) if self.layer_id == 0: logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}") if not cpu_block_table: - raise RuntimeError("Chunked decode attention failed: no CPU blocks available") + raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available") # Apply sparse policy if enabled if kvcache_manager.sparse_policy is not None: @@ -401,12 +419,17 @@ class Attention(nn.Module): ) offload_engine = kvcache_manager.offload_engine + compute_stream = offload_engine.compute_stream # Chunk size = capacity of each double buffer region (compute/prefetch) # Each region uses half of decode_load_slots chunk_size = max(1, len(offload_engine.decode_load_slots) // 2) num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size + # Check if double buffering is possible (need at least 2 separate regions) + # With only 1 load slot, compute and prefetch regions overlap -> can't double buffer + can_double_buffer = len(offload_engine.decode_load_slots) >= 2 + o_acc = None lse_acc = None @@ -422,49 +445,53 @@ class Attention(nn.Module): end = min(start + chunk_size, len(cpu_block_table)) num_blocks_in_chunk = end - start - # Wait for current buffer to be ready - if use_compute: - offload_engine.wait_compute_layer(self.layer_id) - else: - offload_engine.wait_prefetch_layer(self.layer_id) + # Wait for current buffer to be ready on compute_stream + # The load runs on transfer_stream_main, compute runs on compute_stream + 
compute_stream.wait_stream(offload_engine.transfer_stream_main) - # Trigger async prefetch of next chunk to the OTHER buffer - # This overlaps transfer with current chunk's computation + # All computation on explicit compute_stream + with torch.cuda.stream(compute_stream): + # Get KV from current buffer FIRST, before prefetching overwrites it + if use_compute: + k_chunk, v_chunk = offload_engine.get_kv_for_compute( + self.layer_id, num_blocks_in_chunk + ) + else: + k_chunk, v_chunk = offload_engine.get_kv_for_prefetch( + self.layer_id, num_blocks_in_chunk + ) + + # Compute attention for this chunk + o_chunk, lse_chunk = flash_attn_with_lse( + q_batched, k_chunk, v_chunk, + softmax_scale=self.scale, + causal=False, + ) + + # Merge with accumulated + if o_acc is None: + o_acc, lse_acc = o_chunk, lse_chunk + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk) + + # Trigger async prefetch/load of next chunk to the OTHER buffer + # This happens AFTER attention completes, so the data is no longer needed if chunk_idx + 1 < num_chunks: next_start = end next_end = min(next_start + chunk_size, len(cpu_block_table)) next_chunk_ids = cpu_block_table[next_start:next_end] - if use_compute: - # Current in Compute, prefetch next to Prefetch region - offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids) + if can_double_buffer: + if use_compute: + # Current in Compute, prefetch next to Prefetch region + offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids) + else: + # Current in Prefetch, prefetch next to Compute region + offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids) else: - # Current in Prefetch, prefetch next to Compute region + # Sync fallback: load next chunk to same slot (always compute region) offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids) - # Get KV from current buffer - if use_compute: - k_chunk, v_chunk = offload_engine.get_kv_for_compute( - self.layer_id, num_blocks_in_chunk - ) - else: - k_chunk, v_chunk = offload_engine.get_kv_for_prefetch( - self.layer_id, num_blocks_in_chunk - ) - - # Compute attention for this chunk - o_chunk, lse_chunk = flash_attn_with_lse( - q_batched, k_chunk, v_chunk, - softmax_scale=self.scale, - causal=False, - ) - - # Merge with accumulated - if o_acc is None: - o_acc, lse_acc = o_chunk, lse_chunk - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk) - - # Swap buffers for next iteration + # Swap buffers for next iteration (only matters if can_double_buffer) use_compute = not use_compute # Now attend to Decode region (contains accumulated decode tokens) @@ -472,24 +499,29 @@ class Attention(nn.Module): start_pos = context.decode_start_pos_in_block num_accumulated = pos_in_block - start_pos + 1 - if num_accumulated > 0: - decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1] - decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1] - decode_k = decode_k.unsqueeze(0) - decode_v = decode_v.unsqueeze(0) + with torch.cuda.stream(compute_stream): + if num_accumulated > 0: + decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1] + decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1] + decode_k = decode_k.unsqueeze(0) + decode_v = decode_v.unsqueeze(0) - decode_o, decode_lse = flash_attn_with_lse( - q_batched, decode_k, decode_v, - 
softmax_scale=self.scale, - causal=False, - ) + decode_o, decode_lse = flash_attn_with_lse( + q_batched, decode_k, decode_v, + softmax_scale=self.scale, + causal=False, + ) - if o_acc is None: - o_acc = decode_o - else: - o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse) + if o_acc is None: + o_acc = decode_o + else: + o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse) if o_acc is None: raise RuntimeError("Chunked decode attention failed: no KV available") + # Sync back to default stream before returning + # Caller expects result to be ready on default stream + torch.cuda.default_stream().wait_stream(compute_stream) + return o_acc diff --git a/tests/test_chunked_attention.py b/tests/test_chunked_attention.py index 48d114c..08928e3 100644 --- a/tests/test_chunked_attention.py +++ b/tests/test_chunked_attention.py @@ -93,9 +93,9 @@ TEST_CASES = [ (1, 4, 256, 8, 128), (1, 4, 512, 8, 128), (1, 8, 512, 8, 128), - (1, 4, 1024, 8, 128), - (1, 4, 1024, 32, 128), # More heads - (1, 8, 256, 8, 64), # Smaller head dim + (1, 32, 1024, 8, 128), + (1, 32, 1024, 32, 128), # More heads + (1, 32, 256, 8, 64), # Smaller head dim ] DTYPES = [torch.float16, torch.bfloat16] diff --git a/tests/test_chunked_decode_hook.py b/tests/test_chunked_decode_hook.py new file mode 100644 index 0000000..90cce90 --- /dev/null +++ b/tests/test_chunked_decode_hook.py @@ -0,0 +1,374 @@ +""" +Hook-based correctness test for chunked decode attention. + +Uses PyTorch register_forward_hook() to capture real inference I/O, +then compares against reference computation to locate bugs. + +This test targets the decode phase with CPU offload - after prefill, +the model generates tokens one by one while attending to all previous context. +""" + +import os +os.environ["NANOVLLM_LOG_LEVEL"] = "INFO" + +import torch +from random import randint, seed +from nanovllm import LLM, SamplingParams +from nanovllm.utils.context import get_context +from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs + + +# ============================================================ +# Configuration +# ============================================================ + +MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/") +MAX_MODEL_LEN = 8 * 1024 +NUM_GPU_BLOCKS = 2 +INPUT_LEN = 2 * 1024 # 2K tokens for prefill +NUM_DECODE_TOKENS = 5 # Generate 5 tokens to test decode +BLOCK_SIZE = 1024 + + +# ============================================================ +# Global capture storage +# ============================================================ + +captures = [] +prefill_kv = {} # Store prefill k,v for reference computation + + +# ============================================================ +# Hook Functions +# ============================================================ + +def make_hook(layer_id): + """Create a forward hook for a specific layer.""" + def hook(module, inputs, output): + q, k, v = inputs + ctx = get_context() + + is_prefill = ctx.is_prefill + + capture_entry = { + 'layer_id': layer_id, + 'is_prefill': is_prefill, + 'q': q.clone().cpu(), + 'k': k.clone().cpu(), + 'v': v.clone().cpu(), + 'output': output.clone().cpu(), + 'is_chunked_prefill': ctx.is_chunked_prefill, + } + + if is_prefill: + # Store prefill k,v for reference computation + chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 + capture_entry['chunk_idx'] = chunk_idx + if layer_id not in prefill_kv: + prefill_kv[layer_id] = [] + prefill_kv[layer_id].append({ + 'chunk_idx': chunk_idx, + 
'k': k.clone().cpu(), + 'v': v.clone().cpu(), + }) + else: + # Decode phase - capture decode token info + capture_entry['decode_step'] = len([c for c in captures + if c['layer_id'] == layer_id and not c['is_prefill']]) + + captures.append(capture_entry) + return hook + + +def register_hooks(llm): + """Register forward hooks on all Attention modules.""" + hooks = [] + model = llm.model_runner.model + + for layer_idx, decoder_layer in enumerate(model.model.layers): + attn_module = decoder_layer.self_attn.attn + hook = attn_module.register_forward_hook(make_hook(layer_idx)) + hooks.append(hook) + + return hooks + + +# ============================================================ +# Reference Computation +# ============================================================ + +def compute_decode_reference(layer_id, decode_step, scale, debug=False): + """ + Compute reference decode attention output for a specific layer. + + For decode, the query is a single token that attends to: + 1. All prefill KV (from CPU cache) + 2. All previous decode tokens (stored in GPU decode slot) + """ + # Get the decode capture + decode_captures = [c for c in captures + if c['layer_id'] == layer_id and not c['is_prefill']] + if decode_step >= len(decode_captures): + return None + + decode_capture = decode_captures[decode_step] + q = decode_capture['q'].cuda() # [1, num_heads, head_dim] + q_batched = q.unsqueeze(1) # [1, 1, num_heads, head_dim] + + if debug: + print(f" Reference for L{layer_id} D{decode_step}:") + print(f" q shape: {q_batched.shape}, mean={q_batched.mean().item():.4f}") + + o_acc, lse_acc = None, None + + # Attend to all prefill chunks + if layer_id in prefill_kv: + for chunk_data in sorted(prefill_kv[layer_id], key=lambda x: x['chunk_idx']): + k = chunk_data['k'].cuda().unsqueeze(0) # [1, seqlen, kv_heads, head_dim] + v = chunk_data['v'].cuda().unsqueeze(0) + + o, lse = flash_attn_with_lse(q_batched, k, v, softmax_scale=scale, causal=False) + + if debug: + print(f" Prefill chunk {chunk_data['chunk_idx']}: o.mean={o.mean().item():.6f}") + + if o_acc is None: + o_acc, lse_acc = o, lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse) + + # Attend to previous decode tokens (including current) + # In decode, the current token's k,v are stored, and we need to attend to all previous decode tokens + # For step 0, we just have the current token's k,v + # For step 1, we have tokens 0 and 1's k,v + # etc. 
+ + # Collect k,v from all decode steps up to and including current + decode_kv = [] + for i in range(decode_step + 1): + if i < len(decode_captures): + decode_kv.append({ + 'k': decode_captures[i]['k'].cuda(), + 'v': decode_captures[i]['v'].cuda(), + }) + + if decode_kv: + # Stack decode k,v into a single tensor + decode_k = torch.cat([d['k'] for d in decode_kv], dim=0).unsqueeze(0) # [1, num_decode, kv_heads, head_dim] + decode_v = torch.cat([d['v'] for d in decode_kv], dim=0).unsqueeze(0) + + if debug: + print(f" Decode tokens: {len(decode_kv)}, k.shape={decode_k.shape}") + + # For decode, we use causal=False since we're attending to all decode tokens + # (the causal masking was already handled by only including tokens up to current) + o_decode, lse_decode = flash_attn_with_lse(q_batched, decode_k, decode_v, + softmax_scale=scale, causal=False) + + if debug: + print(f" Decode attention: o.mean={o_decode.mean().item():.6f}") + + if o_acc is None: + o_acc, lse_acc = o_decode, lse_decode + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_decode, lse_decode) + + if o_acc is None: + return None + + if debug: + print(f" Final: o.mean={o_acc.mean().item():.6f}") + + return o_acc.squeeze(0).squeeze(0).cpu() # [num_heads, head_dim] + + +# ============================================================ +# Test Runner +# ============================================================ + +def run_test(verbose=True): + """Run the hook-based chunked decode correctness test.""" + global captures, prefill_kv + captures = [] + prefill_kv = {} + + if verbose: + print("=" * 70) + print("Test: Hook-Based Chunked Decode Correctness") + print("=" * 70) + print(f"Model: {MODEL_PATH}") + print(f"Input length: {INPUT_LEN} tokens") + print(f"Decode tokens: {NUM_DECODE_TOKENS}") + print(f"Block size: {BLOCK_SIZE}") + print() + + # Initialize LLM with CPU offload + llm = LLM( + MODEL_PATH, + enforce_eager=True, + max_model_len=MAX_MODEL_LEN, + max_num_batched_tokens=MAX_MODEL_LEN, + enable_cpu_offload=True, + kvcache_block_size=BLOCK_SIZE, + num_gpu_blocks=NUM_GPU_BLOCKS, + ) + + # Get model info + num_layers = len(llm.model_runner.model.model.layers) + head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim + scale = head_dim ** -0.5 + + if verbose: + print(f"Num layers: {num_layers}") + print(f"Head dim: {head_dim}") + print() + + # Register hooks + hooks = register_hooks(llm) + if verbose: + print(f"Registered {len(hooks)} hooks") + + # Generate random prompt + seed(42) + prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] + + # Run prefill and decode + if verbose: + print(f"Running inference with {NUM_DECODE_TOKENS} decode tokens...") + sampling_params = SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS) + outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) + + # Remove hooks + for hook in hooks: + hook.remove() + + # =========== VERIFICATION: Check CPU cache after prefill =========== + # Verify that CPU cache data matches captured prefill k,v + if verbose: + print("\n--- CPU Cache Verification (After Prefill) ---") + offload_engine = llm.model_runner.kvcache_manager.offload_engine + + # For each prefill capture, check if CPU cache matches + for layer_id in [0]: # Only check layer 0 for brevity + if layer_id not in prefill_kv: + continue + + for chunk_data in prefill_kv[layer_id]: + chunk_idx = chunk_data['chunk_idx'] + captured_k = chunk_data['k'] # [block_size, kv_heads, head_dim] + + # CPU block ID should be chunk_idx (based on 
allocation order) + cpu_block_id = chunk_idx + cpu_k = offload_engine.k_cache_cpu[layer_id, cpu_block_id].cpu() + + diff = (captured_k - cpu_k).abs().max().item() + print(f"Layer {layer_id}, Chunk {chunk_idx}: captured_k vs cpu_k max_diff={diff:.6f}") + if diff > 1e-3: + print(f" WARNING: CPU cache doesn't match captured k!") + print(f" captured_k[0,0,:5] = {captured_k[0,0,:5].tolist()}") + print(f" cpu_k[0,0,:5] = {cpu_k[0,0,:5].tolist()}") + print() + + # Analyze captures + prefill_count = sum(1 for c in captures if c['is_prefill']) + decode_count = sum(1 for c in captures if not c['is_prefill']) + if verbose: + print(f"\nCaptured {prefill_count} prefill calls, {decode_count} decode calls") + + # Count decode steps per layer + decode_per_layer = {} + for c in captures: + if not c['is_prefill']: + layer_id = c['layer_id'] + if layer_id not in decode_per_layer: + decode_per_layer[layer_id] = 0 + decode_per_layer[layer_id] += 1 + + if verbose: + print(f"Decode calls per layer: {decode_per_layer}") + print() + + # Verify decode correctness + all_passed = True + results = [] + first_fail_debug = True + + for c in captures: + if c['is_prefill']: + continue # Skip prefill (already tested in test_chunked_prefill_hook.py) + + layer_id = c['layer_id'] + decode_step = c['decode_step'] + + # Only test first decode step for now (simpler reference computation) + if decode_step > 0: + continue + + # Compute reference (debug first failure) + debug_this = (layer_id == 0 and first_fail_debug) + ref_output = compute_decode_reference(layer_id, decode_step, scale, debug=debug_this) + if ref_output is None: + continue + + # Compare + actual_output = c['output'].squeeze(0) # Remove seq dim for decode + if actual_output.dim() == 3: + actual_output = actual_output.squeeze(0) # Handle [1, heads, dim] case + + diff = (actual_output - ref_output).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + tol = 1e-2 + passed = max_diff < tol + all_passed = all_passed and passed + + status = "PASS" if passed else "FAIL" + results.append((layer_id, decode_step, passed, max_diff, mean_diff)) + + if verbose: + print(f"[{status}] Layer {layer_id:2d}, Decode {decode_step}: " + f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}") + + # Debug first failure + if not passed and first_fail_debug: + first_fail_debug = False + print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}") + print(f" Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}") + # Find where max diff is + max_idx = diff.argmax() + flat_actual = actual_output.flatten() + flat_ref = ref_output.flatten() + print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}") + + print() + print("=" * 70) + + # Summary + total_tests = len(results) + passed_count = sum(1 for r in results if r[2]) + + print(f"Results: {passed_count}/{total_tests} tests passed") + + if not all_passed: + print("\nFailed tests:") + for layer_id, decode_step, passed, max_diff, mean_diff in results: + if not passed: + print(f" - Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}") + + print() + return all_passed + + +# ============================================================ +# Main +# ============================================================ + +if __name__ == "__main__": + passed = run_test(verbose=True) + + if passed: + print("test_chunked_decode_hook: PASSED") + else: + print("test_chunked_decode_hook: FAILED") + exit(1) 
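Both the fp32 rewrite of _merge_lse_kernel / _merge_output_kernel above and the reference computations in this test rely on the same log-sum-exp merge identity for combining partial attention results computed over disjoint key chunks. Below is a minimal PyTorch sketch of that merge, assuming the [batch, seqlen_q, nheads, headdim] output layout and [batch, nheads, seqlen_q] LSE layout used by the kernels; merge_reference is a hypothetical helper for illustration, not part of this patch.

import torch

def merge_reference(o1, lse1, o2, lse2):
    # o1, o2:    [batch, seqlen_q, nheads, headdim] partial attention outputs
    # lse1, lse2: [batch, nheads, seqlen_q] log-sum-exp of the softmax denominators
    lse1f, lse2f = lse1.float(), lse2.float()        # do exp/log in fp32
    m = torch.maximum(lse1f, lse2f)                  # shared max for numerical stability
    w1, w2 = torch.exp(lse1f - m), torch.exp(lse2f - m)
    lse = m + torch.log(w1 + w2)                     # merged LSE
    # reshape weights to [batch, seqlen_q, nheads, 1] so they broadcast over headdim
    w1 = w1.transpose(1, 2).unsqueeze(-1)
    w2 = w2.transpose(1, 2).unsqueeze(-1)
    o = (o1.float() * w1 + o2.float() * w2) / (w1 + w2)
    return o.to(o1.dtype), lse

Keeping the exp/log and the weighted sum in fp32 and casting back only on the final store mirrors the kernel changes above; their docstrings note that bf16's 7-bit mantissa is not enough for stable exp/log in this merge.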
diff --git a/tests/test_chunked_prefill_hook.py b/tests/test_chunked_prefill_hook.py new file mode 100644 index 0000000..46b7278 --- /dev/null +++ b/tests/test_chunked_prefill_hook.py @@ -0,0 +1,473 @@ +""" +Hook-based correctness test for chunked prefill attention. + +Uses PyTorch register_forward_hook() to capture real inference I/O, +then compares against reference computation to locate bugs. + +This test targets the integration layer (context setup, cpu_block_table management) +which is where the needle test fails despite isolated attention tests passing. +""" + +import os +os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG" + +import torch +from random import randint, seed +from nanovllm import LLM, SamplingParams +from nanovllm.utils.context import get_context +from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs +from flash_attn.flash_attn_interface import flash_attn_varlen_func + + +# ============================================================ +# Configuration +# ============================================================ + +MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") +MAX_MODEL_LEN = 32 * 1024 +NUM_GPU_BLOCKS = 2 +INPUT_LEN = 16 * 1024 # 4K tokens = 4 chunks with 1K block size +BLOCK_SIZE = 1024 + + +# ============================================================ +# Global capture storage +# ============================================================ + +captures = [] + + +# ============================================================ +# Hook Functions +# ============================================================ + +def make_hook(layer_id): + """Create a forward hook for a specific layer.""" + def hook(module, inputs, output): + q, k, v = inputs + ctx = get_context() + + # Only capture prefill phase + if not ctx.is_prefill: + return + + chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 + + capture_entry = { + 'layer_id': layer_id, + 'chunk_idx': chunk_idx, + 'q': q.clone().cpu(), + 'k': k.clone().cpu(), + 'v': v.clone().cpu(), + 'output': output.clone().cpu(), + 'is_chunked_prefill': ctx.is_chunked_prefill, + } + + # For debugging: also capture CPU cache state for layer 0 + if layer_id == 0 and chunk_idx >= 2: + kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None + if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'): + oe = kvcache_manager.offload_engine + # Get what should have been loaded from CPU + cpu_k0 = oe.k_cache_cpu[0, 0].clone().cpu() # Layer 0, CPU block 0 + cpu_k1 = oe.k_cache_cpu[0, 1].clone().cpu() # Layer 0, CPU block 1 + capture_entry['cpu_k0'] = cpu_k0 + capture_entry['cpu_k1'] = cpu_k1 + + captures.append(capture_entry) + return hook + + +def register_hooks(llm): + """Register forward hooks on all Attention modules.""" + hooks = [] + model = llm.model_runner.model + + for layer_idx, decoder_layer in enumerate(model.model.layers): + attn_module = decoder_layer.self_attn.attn + hook = attn_module.register_forward_hook(make_hook(layer_idx)) + hooks.append(hook) + + return hooks + + +# ============================================================ +# Reference Computation +# ============================================================ + +def compute_reference(layer_id, chunk_idx, scale, debug=False): + """ + Compute reference attention output for a specific layer and chunk. + + Uses the captured k, v from all chunks up to and including chunk_idx. 
+ """ + # Filter captures for this layer + layer_captures = [c for c in captures + if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx] + + if not layer_captures: + return None + + # Get current chunk's q + current_capture = [c for c in layer_captures if c['chunk_idx'] == chunk_idx][0] + q = current_capture['q'].cuda().unsqueeze(0) # [1, seqlen, nheads, headdim] + + # Collect all k, v up to current chunk + kv_list = [] + for c in sorted(layer_captures, key=lambda x: x['chunk_idx']): + k = c['k'].cuda().unsqueeze(0) # [1, seqlen, nheads, headdim] + v = c['v'].cuda().unsqueeze(0) + kv_list.append((k, v, c['chunk_idx'])) + + if debug: + print(f" Reference for L{layer_id} C{chunk_idx}:") + print(f" q shape: {q.shape}, mean={q.mean().item():.4f}") + print(f" kv_list: {len(kv_list)} chunks") + for i, (k, v, cidx) in enumerate(kv_list): + print(f" chunk {cidx}: k.mean={k.mean().item():.4f}, v.mean={v.mean().item():.4f}") + + o_acc, lse_acc = None, None + + # Previous chunks: non-causal attention + for i in range(len(kv_list) - 1): + k, v, _ = kv_list[i] + o, lse = flash_attn_with_lse(q, k, v, softmax_scale=scale, causal=False) + if o_acc is None: + o_acc, lse_acc = o, lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse) + + # Current chunk: causal attention + k_cur, v_cur, _ = kv_list[-1] + o_cur, lse_cur = flash_attn_with_lse(q, k_cur, v_cur, softmax_scale=scale, causal=True) + + if o_acc is None: + return o_cur.squeeze(0).cpu() + + final_o, _ = merge_attention_outputs(o_acc, lse_acc, o_cur, lse_cur) + return final_o.squeeze(0).cpu() + + +def compute_standard_reference(layer_id, chunk_idx, scale, debug=False): + """ + Compute reference using standard flash attention (single pass with all K, V). + + This simulates what standard (non-chunked) prefill would produce. + Concatenates all Q, K, V from chunks 0 to chunk_idx and runs a single + causal attention pass, then extracts the output for the current chunk. 
+ """ + # Filter captures for this layer + layer_captures = [c for c in captures + if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx] + + if not layer_captures: + return None + + # Sort by chunk index + layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx']) + + # Concatenate all Q, K, V + all_q = [] + all_k = [] + all_v = [] + chunk_lengths = [] + + for c in layer_captures: + q = c['q'].cuda() # [seqlen, nheads, headdim] + k = c['k'].cuda() + v = c['v'].cuda() + all_q.append(q) + all_k.append(k) + all_v.append(v) + chunk_lengths.append(q.shape[0]) + + # Concatenate along sequence dimension + full_q = torch.cat(all_q, dim=0) # [total_seqlen, nheads, headdim] + full_k = torch.cat(all_k, dim=0) + full_v = torch.cat(all_v, dim=0) + + total_len = full_q.shape[0] + + if debug: + print(f" Standard Reference for L{layer_id} C{chunk_idx}:") + print(f" full_q shape: {full_q.shape}, mean={full_q.mean().item():.4f}") + print(f" full_k shape: {full_k.shape}, mean={full_k.mean().item():.4f}") + print(f" chunk_lengths: {chunk_lengths}") + + # Run standard causal flash attention + # flash_attn_varlen_func expects: q, k, v with shape [total_seqlen, nheads, headdim] + cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda') + + full_o = flash_attn_varlen_func( + full_q, full_k, full_v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=total_len, + max_seqlen_k=total_len, + softmax_scale=scale, + causal=True, + ) + + # Extract output for current chunk only + start_pos = sum(chunk_lengths[:-1]) + end_pos = sum(chunk_lengths) + chunk_output = full_o[start_pos:end_pos] + + if debug: + print(f" full_o shape: {full_o.shape}") + print(f" extracting positions [{start_pos}:{end_pos}]") + print(f" chunk_output shape: {chunk_output.shape}, mean={chunk_output.mean().item():.4f}") + + return chunk_output.cpu() + + +# ============================================================ +# Test Runner +# ============================================================ + +def run_test(verbose=True): + """Run the hook-based chunked prefill correctness test.""" + global captures + captures = [] + + if verbose: + print("=" * 70) + print("Test: Hook-Based Chunked Prefill Correctness") + print("=" * 70) + print(f"Model: {MODEL_PATH}") + print(f"Input length: {INPUT_LEN} tokens") + print(f"Block size: {BLOCK_SIZE}") + print(f"Expected chunks: {INPUT_LEN // BLOCK_SIZE}") + print() + + # Initialize LLM with CPU offload + llm = LLM( + MODEL_PATH, + enforce_eager=True, + max_model_len=MAX_MODEL_LEN, + max_num_batched_tokens=MAX_MODEL_LEN, + enable_cpu_offload=True, + kvcache_block_size=BLOCK_SIZE, + num_gpu_blocks=NUM_GPU_BLOCKS, + ) + + # Get model info + num_layers = len(llm.model_runner.model.model.layers) + head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim + scale = head_dim ** -0.5 + + if verbose: + print(f"Num layers: {num_layers}") + print(f"Head dim: {head_dim}") + print() + + # Register hooks + hooks = register_hooks(llm) + if verbose: + print(f"Registered {len(hooks)} hooks") + + # Generate random prompt + seed(42) + prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] + + # Run prefill only (max_tokens=1) + if verbose: + print("Running inference...") + sampling_params = SamplingParams(temperature=0.6, max_tokens=1) + outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) + + # Remove hooks + for hook in hooks: + hook.remove() + + # Analyze captures + if verbose: + print(f"\nCaptured {len(captures)} attention 
calls") + + # Group by layer and chunk + chunks_per_layer = {} + for c in captures: + layer_id = c['layer_id'] + chunk_idx = c['chunk_idx'] + if layer_id not in chunks_per_layer: + chunks_per_layer[layer_id] = set() + chunks_per_layer[layer_id].add(chunk_idx) + + if verbose: + print("Chunks per layer:", {k: sorted(v) for k, v in chunks_per_layer.items()}) + print() + + # First, verify CPU cache data integrity + if verbose: + print("\n--- CPU Cache Verification (Layer 0) ---") + # Get original k from chunk 0 and chunk 1 captures + chunk0_k = None + chunk1_k = None + chunk2_capture = None + for c in captures: + if c['layer_id'] == 0: + if c['chunk_idx'] == 0: + chunk0_k = c['k'] + elif c['chunk_idx'] == 1: + chunk1_k = c['k'] + elif c['chunk_idx'] == 2: + chunk2_capture = c + + if chunk0_k is not None and chunk2_capture is not None and 'cpu_k0' in chunk2_capture: + cpu_k0 = chunk2_capture['cpu_k0'] + diff_k0 = (chunk0_k - cpu_k0).abs().max().item() + print(f"Chunk 0 k vs CPU cache block 0: max_diff={diff_k0:.6f}") + if diff_k0 > 1e-3: + print(f" WARNING: CPU cache block 0 differs from original chunk 0 k!") + print(f" Original k[0,0,:5] = {chunk0_k[0,0,:5].tolist()}") + print(f" CPU k0[0,0,:5] = {cpu_k0[0,0,:5].tolist()}") + + if chunk1_k is not None and chunk2_capture is not None and 'cpu_k1' in chunk2_capture: + cpu_k1 = chunk2_capture['cpu_k1'] + diff_k1 = (chunk1_k - cpu_k1).abs().max().item() + print(f"Chunk 1 k vs CPU cache block 1: max_diff={diff_k1:.6f}") + if diff_k1 > 1e-3: + print(f" WARNING: CPU cache block 1 differs from original chunk 1 k!") + print(f" Original k[0,0,:5] = {chunk1_k[0,0,:5].tolist()}") + print(f" CPU k1[0,0,:5] = {cpu_k1[0,0,:5].tolist()}") + + print() + + # ================================================================ + # Test 1: Verify against merge-based reference (same algorithm) + # ================================================================ + if verbose: + print("--- Test 1: Merge-based Reference (verifies merge algorithm) ---") + + all_passed_merge = True + results_merge = [] + first_fail_debug = True + + for c in captures: + layer_id = c['layer_id'] + chunk_idx = c['chunk_idx'] + + if chunk_idx == 0: + continue + + debug_this = (chunk_idx >= 2 and layer_id == 0 and first_fail_debug) + ref_output = compute_reference(layer_id, chunk_idx, scale, debug=debug_this) + if ref_output is None: + continue + + actual_output = c['output'] + diff = (actual_output - ref_output).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + tol = 1e-2 + passed = max_diff < tol + all_passed_merge = all_passed_merge and passed + + status = "PASS" if passed else "FAIL" + results_merge.append((layer_id, chunk_idx, passed, max_diff, mean_diff)) + + if verbose: + print(f"[{status}] Layer {layer_id:2d}, Chunk {chunk_idx}: " + f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}") + + if not passed and first_fail_debug: + first_fail_debug = False + print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}") + print(f" Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}") + max_idx = diff.argmax() + flat_actual = actual_output.flatten() + flat_ref = ref_output.flatten() + print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}") + + print() + + # ================================================================ + # Test 2: Verify against standard flash attention (single pass) + # 
================================================================ + if verbose: + print("--- Test 2: Standard FlashAttn Reference (verifies correctness vs non-chunked) ---") + + all_passed_standard = True + results_standard = [] + first_fail_debug = True + + for c in captures: + layer_id = c['layer_id'] + chunk_idx = c['chunk_idx'] + + if chunk_idx == 0: + continue + + debug_this = (chunk_idx >= 2 and layer_id == 0 and first_fail_debug) + std_ref_output = compute_standard_reference(layer_id, chunk_idx, scale, debug=debug_this) + if std_ref_output is None: + continue + + actual_output = c['output'] + diff = (actual_output - std_ref_output).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + tol = 1e-2 + passed = max_diff < tol + all_passed_standard = all_passed_standard and passed + + status = "PASS" if passed else "FAIL" + results_standard.append((layer_id, chunk_idx, passed, max_diff, mean_diff)) + + if verbose: + print(f"[{status}] Layer {layer_id:2d}, Chunk {chunk_idx}: " + f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}") + + if not passed and first_fail_debug: + first_fail_debug = False + print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}") + print(f" Debug: std_ref_output shape={std_ref_output.shape}, mean={std_ref_output.mean().item():.4f}") + max_idx = diff.argmax() + flat_actual = actual_output.flatten() + flat_ref = std_ref_output.flatten() + print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}") + + print() + print("=" * 70) + + # Summary + total_merge = len(results_merge) + passed_merge = sum(1 for r in results_merge if r[2]) + total_standard = len(results_standard) + passed_standard = sum(1 for r in results_standard if r[2]) + + print(f"Merge-based reference: {passed_merge}/{total_merge} tests passed") + print(f"Standard FlashAttn ref: {passed_standard}/{total_standard} tests passed") + + all_passed = all_passed_merge and all_passed_standard + + if not all_passed_merge: + print("\nFailed merge-based tests:") + for layer_id, chunk_idx, passed, max_diff, mean_diff in results_merge: + if not passed: + print(f" - Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}") + + if not all_passed_standard: + print("\nFailed standard FlashAttn tests:") + for layer_id, chunk_idx, passed, max_diff, mean_diff in results_standard: + if not passed: + print(f" - Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}") + + print() + return all_passed + + +# ============================================================ +# Main +# ============================================================ + +if __name__ == "__main__": + passed = run_test(verbose=True) + + if passed: + print("test_chunked_prefill_hook: PASSED") + else: + print("test_chunked_prefill_hook: FAILED") + exit(1) diff --git a/tests/test_flash_attn_kvcache.py b/tests/test_flash_attn_kvcache.py new file mode 100644 index 0000000..d488d6a --- /dev/null +++ b/tests/test_flash_attn_kvcache.py @@ -0,0 +1,276 @@ +""" +Test script for flash_attn_with_kvcache based chunked prefill. + +Verifies that chunked prefill produces identical results to full attention. +""" + +import torch +from flash_attn import flash_attn_func, flash_attn_with_kvcache + + +def chunk_prefill(q_full, k_full, v_full, k_cache, v_cache, cache_seqlens, chunk_size): + """ + Chunked prefill using flash_attn_with_kvcache. 
+ + Args: + q_full, k_full, v_full: [batch, total_seq_len, heads, head_dim] + k_cache, v_cache: [batch, max_seq_len, kv_heads, head_dim] + cache_seqlens: [batch] - current cache lengths + chunk_size: size of each chunk + + Returns: + output: [batch, total_seq_len, heads, head_dim] + """ + total_len = q_full.shape[1] + outputs = [] + + for start in range(0, total_len, chunk_size): + end = min(start + chunk_size, total_len) + + q_chunk = q_full[:, start:end] + k_chunk = k_full[:, start:end] + v_chunk = v_full[:, start:end] + + out = flash_attn_with_kvcache( + q_chunk, + k_cache, + v_cache, + k=k_chunk, + v=v_chunk, + cache_seqlens=cache_seqlens, + causal=True, + ) + outputs.append(out) + + cache_seqlens += (end - start) + + return torch.cat(outputs, dim=1) + + +def reference_attention(q, k, v): + """Standard flash attention as reference.""" + return flash_attn_func(q, k, v, causal=True) + + +def test_chunked_prefill_correctness(): + """Test that chunked prefill matches full attention.""" + + batch_size = 1 + num_heads = 32 + num_kv_heads = 8 # GQA + head_dim = 128 + max_seq_len = 131072 # 128K + + test_configs = [ + (1024, 256), # 1K tokens, 256 chunk + (2048, 512), # 2K tokens, 512 chunk + (4096, 1024), # 4K tokens, 1K chunk + (4096, 2048), # 4K tokens, 2K chunk (2 chunks) + (8192, 2048), # 8K tokens, 2K chunk (4 chunks) + (16384, 4096), # 16K tokens, 4K chunk + (32768, 4096), # 32K tokens, 4K chunk + (65536, 8192), # 64K tokens, 8K chunk + (131072, 8192), # 128K tokens, 8K chunk (16 chunks) + ] + + for seq_len, chunk_size in test_configs: + print(f"\nTesting seq_len={seq_len}, chunk_size={chunk_size}...") + + # Generate random input + torch.manual_seed(42) + q = torch.randn(batch_size, seq_len, num_heads, head_dim, + dtype=torch.float16, device='cuda') + k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + + # Expand K/V for non-GQA reference + k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2) + v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2) + + # Reference: full attention + ref_out = reference_attention(q, k_expanded, v_expanded) + + # Chunked prefill with KV cache + k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda') + + chunked_out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size) + + # Compare + max_diff = (ref_out - chunked_out).abs().max().item() + mean_diff = (ref_out - chunked_out).abs().mean().item() + + # Verify cache was filled correctly + assert cache_seqlens[0].item() == seq_len, f"Cache seqlen mismatch: {cache_seqlens[0].item()} != {seq_len}" + + # Check K/V cache content + k_cache_diff = (k_cache[:, :seq_len] - k).abs().max().item() + v_cache_diff = (v_cache[:, :seq_len] - v).abs().max().item() + + print(f" Output max_diff: {max_diff:.6f}, mean_diff: {mean_diff:.6f}") + print(f" KV cache diff: k={k_cache_diff:.6f}, v={v_cache_diff:.6f}") + + # Tolerance for fp16 + tolerance = 1e-2 + if max_diff < tolerance: + print(f" PASSED") + else: + print(f" FAILED (max_diff {max_diff:.6f} >= {tolerance})") + return False + + return True + + +def test_incremental_decode(): + """Test that decode after chunked prefill works correctly.""" + + 
batch_size = 1 + num_heads = 32 + num_kv_heads = 8 + head_dim = 128 + max_seq_len = 8192 + + prefill_len = 2048 + chunk_size = 512 + num_decode_steps = 10 + + print(f"\nTesting incremental decode after chunked prefill...") + print(f" Prefill: {prefill_len} tokens, chunk_size={chunk_size}") + print(f" Decode: {num_decode_steps} steps") + + torch.manual_seed(42) + + # Prefill phase + q_prefill = torch.randn(batch_size, prefill_len, num_heads, head_dim, + dtype=torch.float16, device='cuda') + k_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + v_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + + k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda') + + # Run chunked prefill + prefill_out = chunk_prefill(q_prefill, k_prefill, v_prefill, + k_cache, v_cache, cache_seqlens, chunk_size) + + print(f" After prefill: cache_seqlens={cache_seqlens[0].item()}") + + # Decode phase - one token at a time + for step in range(num_decode_steps): + q_decode = torch.randn(batch_size, 1, num_heads, head_dim, + dtype=torch.float16, device='cuda') + k_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + v_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + + decode_out = flash_attn_with_kvcache( + q_decode, + k_cache, + v_cache, + k=k_decode, + v=v_decode, + cache_seqlens=cache_seqlens, + causal=True, + ) + + cache_seqlens += 1 + + assert decode_out.shape == (batch_size, 1, num_heads, head_dim) + + expected_len = prefill_len + num_decode_steps + actual_len = cache_seqlens[0].item() + + print(f" After decode: cache_seqlens={actual_len}") + + if actual_len == expected_len: + print(f" PASSED") + return True + else: + print(f" FAILED: expected {expected_len}, got {actual_len}") + return False + + +def test_batch_processing(): + """Test chunked prefill with batch > 1.""" + + batch_size = 4 + num_heads = 32 + num_kv_heads = 8 + head_dim = 128 + max_seq_len = 4096 + seq_len = 2048 + chunk_size = 512 + + print(f"\nTesting batch processing (batch_size={batch_size})...") + + torch.manual_seed(42) + + q = torch.randn(batch_size, seq_len, num_heads, head_dim, + dtype=torch.float16, device='cuda') + k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + + k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim, + dtype=torch.float16, device='cuda') + cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda') + + out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size) + + # Verify all batches have correct cache length + assert (cache_seqlens == seq_len).all(), f"Cache seqlens mismatch: {cache_seqlens}" + assert out.shape == (batch_size, seq_len, num_heads, head_dim) + + # Compare with reference for each batch item + k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2) + v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2) + ref_out = 
reference_attention(q, k_expanded, v_expanded) + + max_diff = (ref_out - out).abs().max().item() + + print(f" Output shape: {out.shape}") + print(f" Max diff vs reference: {max_diff:.6f}") + + if max_diff < 1e-2: + print(f" PASSED") + return True + else: + print(f" FAILED") + return False + + +# ============================================================ +# Main Test Script +# ============================================================ + +if __name__ == "__main__": + print("=" * 60) + print("Testing flash_attn_with_kvcache chunked prefill") + print("=" * 60) + + all_passed = True + + all_passed &= test_chunked_prefill_correctness() + all_passed &= test_incremental_decode() + all_passed &= test_batch_processing() + + print("\n" + "=" * 60) + if all_passed: + print("test_flash_attn_kvcache: ALL TESTS PASSED") + else: + print("test_flash_attn_kvcache: SOME TESTS FAILED") + print("=" * 60) diff --git a/tests/test_needle.py b/tests/test_needle.py new file mode 100644 index 0000000..5288c88 --- /dev/null +++ b/tests/test_needle.py @@ -0,0 +1,322 @@ +""" +Needle-in-a-haystack test for LLM. + +Tests: Long context retrieval capability with configurable sequence length. + +NOTE: CPU offload mode has a known bug that causes incorrect outputs for +sequences longer than ~200 tokens. Use --no-offload for correctness testing. +""" + +import os +os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG" + +import argparse +from nanovllm import LLM, SamplingParams + + +# ============================================================ +# Needle Test Generator +# ============================================================ + +def generate_needle_prompt( + tokenizer, + target_length: int, + needle_position: float = 0.5, + needle_value: str = "7492", + use_chat_template: bool = True, +) -> tuple[str, str]: + """ + Generate a needle-in-haystack prompt of approximately target_length tokens. + + Args: + tokenizer: HuggingFace tokenizer for length estimation + target_length: Target total sequence length in tokens + needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end) + needle_value: The secret value to hide in the haystack + use_chat_template: Whether to use chat template for instruct models + + Returns: + (prompt, expected_answer): The full prompt and the expected needle value + """ + # Haystack filler paragraphs (various topics to create realistic context) + haystack_paragraphs = [ + "The weather today is quite pleasant with clear skies and moderate temperatures. " + "Many people are enjoying outdoor activities in the park. " + "Birds are singing in the trees and children are playing on the swings. ", + + "In the world of technology, new innovations continue to emerge every day. " + "Researchers are working on advanced algorithms and computing systems. " + "The future of artificial intelligence looks promising with many breakthroughs. ", + + "The history of human civilization spans thousands of years. " + "Ancient cultures developed writing, mathematics, and astronomy. " + "Trade routes connected distant lands and facilitated cultural exchange. ", + + "Modern cooking combines traditional techniques with new ingredients. " + "Chefs around the world experiment with flavors and presentations. " + "Food brings people together and creates memorable experiences. ", + + "The ocean covers more than seventy percent of Earth's surface. " + "Marine ecosystems support an incredible diversity of life forms. " + "Scientists continue to discover new species in the deep sea. 
", + + "Music has been a part of human culture since prehistoric times. " + "Different genres evolved across various regions and time periods. " + "Today, people can access millions of songs through digital platforms. ", + + "Space exploration has revealed many secrets about our universe. " + "Telescopes can observe galaxies billions of light years away. " + "Future missions aim to establish human presence on other planets. ", + + "The study of languages reveals patterns in human cognition. " + "Linguists analyze grammar, semantics, and phonetics across cultures. " + "Language continues to evolve with new words and expressions. ", + ] + + # The needle sentence + needle = f"The secret number you need to remember is {needle_value}. This is very important. " + + # Question at the end + question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is" + + # Estimate tokens for fixed parts + needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False)) + question_text = "What is the secret number mentioned in the text above? Answer with just the number." + question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False)) + # Buffer for chat template, special tokens, etc. + overhead_tokens = 100 if use_chat_template else 50 + + # Available tokens for haystack + haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens + if haystack_target_tokens < 100: + raise ValueError(f"target_length {target_length} is too short for needle test") + + # Build haystack by repeating paragraphs + haystack_parts = [] + current_tokens = 0 + para_idx = 0 + + while current_tokens < haystack_target_tokens: + para = haystack_paragraphs[para_idx % len(haystack_paragraphs)] + para_tokens = len(tokenizer.encode(para, add_special_tokens=False)) + if current_tokens + para_tokens > haystack_target_tokens: + break + haystack_parts.append(para) + current_tokens += para_tokens + para_idx += 1 + + # Calculate needle insertion point + needle_idx = int(len(haystack_parts) * needle_position) + needle_idx = max(0, min(needle_idx, len(haystack_parts))) + + # Insert needle + haystack_parts.insert(needle_idx, needle) + + # Assemble prompt + full_text = "".join(haystack_parts) + + if use_chat_template and hasattr(tokenizer, 'apply_chat_template'): + # Use chat template for instruct models + # For Qwen3, add /no_think to disable thinking mode + question_text = "/no_think Answer only with the secret number mentioned above, nothing else:" + messages = [ + {"role": "user", "content": f"{full_text}\n\n{question_text}"} + ] + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + else: + # Raw text format for base models + question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is" + prompt = full_text + question + + # Verify length + actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False)) + print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens") + print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)") + print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}") + + return prompt, needle_value + + +def check_needle_answer(output_text: str, expected: str) -> bool: + """Check if the model output contains the expected needle value.""" + import re + # Clean output - remove special tokens and 
whitespace + output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ') + output_clean = ' '.join(output_clean.split()).lower() + expected_clean = expected.strip().lower() + + # Check if expected value appears in output + # Also try to find it as a standalone number + if expected_clean in output_clean: + return True + + # Try to extract numbers and check if expected is among them + numbers = re.findall(r'\d+', output_clean) + return expected_clean in numbers + + +# ============================================================ +# Main Test +# ============================================================ + +def run_needle_test( + model_path: str, + max_model_len: int, + input_len: int, + num_gpu_blocks: int = 4, + needle_position: float = 0.5, + needle_value: str = "7492", + max_new_tokens: int = 32, + enable_cpu_offload: bool = False, + verbose: bool = True, +) -> bool: + """ + Run a needle-in-haystack test. + + Args: + model_path: Path to model + max_model_len: Maximum model context length + input_len: Target input sequence length + num_gpu_blocks: Number of GPU blocks for offload + needle_position: Where to place needle (0.0-1.0) + needle_value: The secret value to find + max_new_tokens: Maximum tokens to generate + enable_cpu_offload: Enable CPU offload mode + verbose: Print detailed output + + Returns: + True if test passed, False otherwise + """ + if verbose: + print(f"\n{'='*60}") + print(f"Needle-in-Haystack Test") + print(f"{'='*60}") + print(f"Model: {model_path}") + print(f"Max model len: {max_model_len}") + print(f"Input length: {input_len}") + print(f"Needle position: {needle_position:.0%}") + print(f"Needle value: {needle_value}") + print(f"CPU offload: {enable_cpu_offload}") + print(f"{'='*60}\n") + + # 1. Initialize LLM + llm_kwargs = { + "enforce_eager": True, + "max_model_len": max_model_len, + "max_num_batched_tokens": max_model_len, + "enable_cpu_offload": enable_cpu_offload, + } + if enable_cpu_offload: + llm_kwargs["num_gpu_blocks"] = num_gpu_blocks + + llm = LLM(model_path, **llm_kwargs) + + # 2. Generate needle prompt + prompt, expected = generate_needle_prompt( + tokenizer=llm.tokenizer, + target_length=input_len, + needle_position=needle_position, + needle_value=needle_value, + ) + + # 3. Generate output + sampling_params = SamplingParams( + temperature=0.6, # Moderate temperature + max_tokens=max_new_tokens, + ) + outputs = llm.generate([prompt], sampling_params, use_tqdm=True) + + # 4. 
Check result + output_text = outputs[0]["text"] + output_token_ids = outputs[0]["token_ids"] + passed = check_needle_answer(output_text, expected) + + if verbose: + print(f"\n{'='*60}") + print(f"Result") + print(f"{'='*60}") + print(f"Expected: {expected}") + print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}") + print(f"Output: {output_text[:200]}...") + print(f"Status: {'PASSED' if passed else 'FAILED'}") + print(f"{'='*60}\n") + + return passed + + +# ============================================================ +# CLI Entry Point +# ============================================================ + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Needle-in-haystack test for long context LLM") + parser.add_argument( + "--model", "-m", + type=str, + default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"), + help="Path to model" + ) + parser.add_argument( + "--max-model-len", + type=int, + default=32 * 1024, + help="Maximum model context length" + ) + parser.add_argument( + "--input-len", + type=int, + default=8 * 1024, + help="Target input sequence length" + ) + parser.add_argument( + "--num-gpu-blocks", + type=int, + default=2, + help="Number of GPU blocks for CPU offload" + ) + parser.add_argument( + "--needle-position", + type=float, + default=0.5, + help="Needle position (0.0=start, 0.5=middle, 1.0=end)" + ) + parser.add_argument( + "--needle-value", + type=str, + default="7492", + help="The secret value to hide" + ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=32, + help="Maximum tokens to generate" + ) + parser.add_argument( + "--enable-offload", + action="store_true", + help="Enable CPU offload (has known bug for long sequences)" + ) + args = parser.parse_args() + + passed = run_needle_test( + model_path=args.model, + max_model_len=args.max_model_len, + input_len=args.input_len, + num_gpu_blocks=args.num_gpu_blocks, + needle_position=args.needle_position, + needle_value=args.needle_value, + max_new_tokens=args.max_new_tokens, + enable_cpu_offload=args.enable_offload, + verbose=True, + ) + + if passed: + print("test_needle: PASSED") + else: + print("test_needle: FAILED") + exit(1) diff --git a/tests/test_offload_correctness.py b/tests/test_offload_correctness.py new file mode 100644 index 0000000..4808015 --- /dev/null +++ b/tests/test_offload_correctness.py @@ -0,0 +1,573 @@ +""" +Correctness test for chunked attention with CPU offload. + +Validates that the offload pipeline (CPU -> GPU transfer + chunked attention) +produces the same result as direct GPU computation. + +Test scenario: +1. Generate Q, K, V data +2. Reference: Compute full causal attention on GPU +3. Offload: Store K, V in CPU cache, load via pipeline, compute chunked attention +4. 
Compare results + +This test is designed to identify bugs in: +- CPU <-> GPU data transfer (sgDMA) +- Ring buffer slot management +- N-way pipeline ordering +- Triton merge kernel correctness +""" + +import torch +from flash_attn.flash_attn_interface import flash_attn_func +from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager +from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs + + +# ============================================================ +# Configuration +# ============================================================ + +NUM_LAYERS = 4 +NUM_HEADS = 8 +NUM_KV_HEADS = 8 +HEAD_DIM = 64 +BLOCK_SIZE = 256 # Smaller for faster testing +DTYPE = torch.bfloat16 +DEVICE = "cuda" + + +# ============================================================ +# Reference Implementation (GPU only, no offload) +# ============================================================ + +def compute_reference_causal(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Compute reference causal attention using flash_attn_func. + + Args: + q, k, v: [batch, seqlen, nheads, headdim] + + Returns: + out: [batch, seqlen, nheads, headdim] + """ + return flash_attn_func(q, k, v, causal=True) + + +def compute_reference_chunked( + q_chunks: list, + kv_chunks: list, + scale: float, +) -> torch.Tensor: + """ + Compute chunked prefill attention directly on GPU (no offload). + + This is the "gold standard" for chunked attention correctness. + + Args: + q_chunks: List of [batch, chunk_size, nheads, headdim] + kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim] + scale: Softmax scale + + Returns: + out: [batch, total_seqlen, nheads, headdim] + """ + out_chunks = [] + + for chunk_idx, q_chunk in enumerate(q_chunks): + o_acc, lse_acc = None, None + + # Attend to all previous chunks (no causal mask) + for i in range(chunk_idx): + k_chunk, v_chunk = kv_chunks[i] + chunk_o, chunk_lse = flash_attn_with_lse( + q_chunk, k_chunk, v_chunk, + softmax_scale=scale, + causal=False, + ) + if o_acc is None: + o_acc, lse_acc = chunk_o, chunk_lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse) + + # Attend to current chunk (with causal mask) + k_chunk, v_chunk = kv_chunks[chunk_idx] + current_o, current_lse = flash_attn_with_lse( + q_chunk, k_chunk, v_chunk, + softmax_scale=scale, + causal=True, + ) + + if o_acc is None: + final_o = current_o + else: + final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) + + out_chunks.append(final_o) + + return torch.cat(out_chunks, dim=1) + + +# ============================================================ +# Offload Implementation +# ============================================================ + +def create_manager(num_gpu_slots: int, num_cpu_blocks: int): + """Create HybridKVCacheManager with specified configuration.""" + manager = HybridKVCacheManager( + num_gpu_slots=num_gpu_slots, + num_cpu_blocks=num_cpu_blocks, + block_size=BLOCK_SIZE, + ) + manager.allocate_cache( + num_layers=NUM_LAYERS, + num_kv_heads=NUM_KV_HEADS, + head_dim=HEAD_DIM, + dtype=DTYPE, + ) + return manager + + +def store_kv_to_cpu_cache(manager, kv_chunks: list, layer_id: int): + """ + Store K, V chunks to CPU cache. 
+ + Args: + manager: HybridKVCacheManager + kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim] + layer_id: Layer index + + Returns: + cpu_block_ids: List of CPU block IDs + """ + offload_engine = manager.offload_engine + cpu_block_ids = [] + + for block_idx, (k_chunk, v_chunk) in enumerate(kv_chunks): + # k_chunk, v_chunk: [batch, chunk_size, nheads, headdim] + # CPU cache layout: [num_layers, num_blocks, block_size, nheads, headdim] + k_data = k_chunk.squeeze(0) # [chunk_size, nheads, headdim] + v_data = v_chunk.squeeze(0) + + offload_engine.k_cache_cpu[layer_id, block_idx, :k_data.shape[0]].copy_(k_data) + offload_engine.v_cache_cpu[layer_id, block_idx, :v_data.shape[0]].copy_(v_data) + + cpu_block_ids.append(block_idx) + + return cpu_block_ids + + +def compute_offload_chunked_single_layer( + manager, + q_chunks: list, + cpu_block_ids: list, + layer_id: int, + scale: float, +) -> torch.Tensor: + """ + Compute chunked attention for a single layer using offload pipeline. + + This mimics the behavior of Attention._ring_buffer_pipeline_load(). + + Args: + manager: HybridKVCacheManager + q_chunks: List of [batch, chunk_size, nheads, headdim] + cpu_block_ids: List of CPU block IDs containing K, V data + layer_id: Layer index + scale: Softmax scale + + Returns: + out: [batch, total_seqlen, nheads, headdim] + """ + offload_engine = manager.offload_engine + out_chunks = [] + + for chunk_idx, q_chunk in enumerate(q_chunks): + # CPU blocks to load: all blocks before current chunk + blocks_to_load = cpu_block_ids[:chunk_idx] + + # Get slots for this chunk + write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx) + load_slots = offload_engine.get_load_slots_for_prefill(write_slot) + + # Load and compute attention for previous chunks + o_acc, lse_acc = None, None + + if len(blocks_to_load) > 0 and len(load_slots) > 0: + o_acc, lse_acc = _pipeline_load_and_compute( + offload_engine, + q_chunk, + blocks_to_load, + load_slots, + layer_id, + scale, + ) + + # Current chunk's K, V (load from CPU to GPU slot) + current_cpu_block = cpu_block_ids[chunk_idx] + offload_engine.load_to_slot_layer(write_slot, layer_id, current_cpu_block) + offload_engine.wait_slot_layer(write_slot, layer_id) + + current_k, current_v = offload_engine.get_kv_for_slot(write_slot, layer_id) + + # Compute attention with causal mask + current_o, current_lse = flash_attn_with_lse( + q_chunk, current_k, current_v, + softmax_scale=scale, + causal=True, + ) + + # Merge + if o_acc is None: + final_o = current_o + else: + final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) + + out_chunks.append(final_o) + + return torch.cat(out_chunks, dim=1) + + +def _pipeline_load_and_compute( + offload_engine, + q_chunk: torch.Tensor, + cpu_block_table: list, + load_slots: list, + layer_id: int, + scale: float, +): + """ + Pipeline loading from CPU and computing attention. + + Mirrors Attention._ring_buffer_pipeline_load() logic. 
+ """ + num_blocks = len(cpu_block_table) + num_slots = len(load_slots) + + o_acc, lse_acc = None, None + + # Phase 1: Pre-load up to num_slots blocks + num_preload = min(num_slots, num_blocks) + for i in range(num_preload): + offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i]) + + # Phase 2: Main loop + compute_stream = offload_engine.compute_stream + + for block_idx in range(num_blocks): + current_slot = load_slots[block_idx % num_slots] + + # Wait for transfer + offload_engine.wait_slot_layer(current_slot, layer_id) + + # Compute on dedicated stream + with torch.cuda.stream(compute_stream): + prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, layer_id) + prev_o, prev_lse = flash_attn_with_lse( + q_chunk, prev_k, prev_v, + softmax_scale=scale, + causal=False, + ) + offload_engine.record_slot_compute_done(current_slot, layer_id) + + # Start next transfer + next_block_idx = block_idx + num_slots + if next_block_idx < num_blocks: + offload_engine.load_to_slot_layer( + current_slot, layer_id, cpu_block_table[next_block_idx] + ) + + # Merge + with torch.cuda.stream(compute_stream): + if o_acc is None: + o_acc, lse_acc = prev_o, prev_lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + + # Sync compute stream + compute_stream.synchronize() + + return o_acc, lse_acc + + +# ============================================================ +# Test Runner +# ============================================================ + +def run_correctness_test( + num_chunks: int, + num_gpu_slots: int, + verbose: bool = True, +) -> tuple[bool, float, float]: + """ + Run a single correctness test. + + Args: + num_chunks: Number of chunks (= number of CPU blocks) + num_gpu_slots: Number of GPU ring buffer slots + verbose: Print detailed info + + Returns: + (passed, max_diff, mean_diff) + """ + torch.manual_seed(42) + + seqlen = num_chunks * BLOCK_SIZE + scale = HEAD_DIM ** -0.5 + + # Generate Q, K, V + q_full = torch.randn(1, seqlen, NUM_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE) + k_full = torch.randn(1, seqlen, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE) + v_full = torch.randn(1, seqlen, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE) + + # Split into chunks + q_chunks = [q_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE] for i in range(num_chunks)] + kv_chunks = [ + (k_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE], + v_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE]) + for i in range(num_chunks) + ] + + # Reference: chunked attention on GPU (no offload) + out_ref = compute_reference_chunked(q_chunks, kv_chunks, scale) + + # Create manager with enough CPU blocks + manager = create_manager(num_gpu_slots, num_chunks) + + # Test each layer + all_passed = True + max_diff_all = 0.0 + mean_diff_all = 0.0 + + for layer_id in range(NUM_LAYERS): + # Store K, V to CPU cache + cpu_block_ids = store_kv_to_cpu_cache(manager, kv_chunks, layer_id) + + # Compute with offload + out_offload = compute_offload_chunked_single_layer( + manager, q_chunks, cpu_block_ids, layer_id, scale + ) + + # Compare + diff = (out_ref - out_offload).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + max_diff_all = max(max_diff_all, max_diff) + mean_diff_all = max(mean_diff_all, mean_diff) + + tol = 1e-2 + passed = max_diff < tol + all_passed = all_passed and passed + + if verbose and not passed: + print(f" Layer {layer_id}: FAIL max_diff={max_diff:.6f}") + + return all_passed, max_diff_all, mean_diff_all + + +# 
============================================================ +# Decode Phase Test +# ============================================================ + +def run_decode_correctness_test( + num_prefill_chunks: int, + num_gpu_slots: int, + num_decode_steps: int = 4, + verbose: bool = True, +) -> tuple[bool, float, float]: + """ + Test decode phase correctness with CPU offload. + + Simulates: + 1. Prefill: Store K, V for multiple chunks in CPU cache + 2. Decode: Single token queries against all prefilled K, V + + This tests the scenario in needle test where decode reads all previous KV. + """ + torch.manual_seed(42) + + scale = HEAD_DIM ** -0.5 + prefill_len = num_prefill_chunks * BLOCK_SIZE + + # Generate prefill K, V (store in CPU) + k_prefill = torch.randn(1, prefill_len, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE) + v_prefill = torch.randn(1, prefill_len, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE) + + # Split into chunks for CPU storage + kv_chunks = [ + (k_prefill[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE], + v_prefill[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE]) + for i in range(num_prefill_chunks) + ] + + # Create manager + manager = create_manager(num_gpu_slots, num_prefill_chunks) + offload_engine = manager.offload_engine + + all_passed = True + max_diff_all = 0.0 + mean_diff_all = 0.0 + + for layer_id in range(NUM_LAYERS): + # Store prefilled K, V to CPU cache + cpu_block_ids = store_kv_to_cpu_cache(manager, kv_chunks, layer_id) + + for decode_step in range(num_decode_steps): + # Decode query: single token + q_decode = torch.randn(1, 1, NUM_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE) + + # Reference: direct attention on GPU + # Concat all prefilled K, V and compute attention + out_ref = flash_attn_func( + q_decode, + k_prefill, + v_prefill, + causal=False, # Decode query can attend to all prefilled tokens + ) + + # Offload: load from CPU and compute + load_slots = offload_engine.get_load_slots_for_prefill(0) # Use all slots except decode slot + + if len(load_slots) == 0 or len(cpu_block_ids) == 0: + # No previous chunks to load + out_offload = out_ref # Trivially equal + else: + o_acc, lse_acc = _pipeline_load_and_compute( + offload_engine, + q_decode, + cpu_block_ids, + load_slots, + layer_id, + scale, + ) + out_offload = o_acc + + # Compare + diff = (out_ref - out_offload).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + max_diff_all = max(max_diff_all, max_diff) + mean_diff_all = max(mean_diff_all, mean_diff) + + tol = 1e-2 + passed = max_diff < tol + all_passed = all_passed and passed + + if verbose and not passed: + print(f" Layer {layer_id} Step {decode_step}: FAIL max_diff={max_diff:.6f}") + + return all_passed, max_diff_all, mean_diff_all + + +# ============================================================ +# Main Test Script +# ============================================================ + +if __name__ == "__main__": + print("=" * 70) + print("Test: Offload Chunked Attention Correctness") + print("=" * 70) + print(f"Config: layers={NUM_LAYERS}, heads={NUM_HEADS}, kv_heads={NUM_KV_HEADS}, " + f"head_dim={HEAD_DIM}, block_size={BLOCK_SIZE}, dtype={DTYPE}") + print() + print("Comparing: Reference (GPU chunked) vs Offload (CPU->GPU pipeline)") + print() + + # Test configurations: (num_chunks, num_gpu_slots) + TEST_CASES = [ + # Basic tests + (2, 2), # Minimal: 2 chunks, 2 slots (no pipeline) + (2, 3), # 2 chunks, 3 slots (1-slot pipeline) + (4, 2), # 4 chunks, 2 slots (heavy slot reuse) + (4, 3), # 4 chunks, 3 slots + (4, 4), # 4 chunks, 4 slots + # Stress 
tests + (8, 3), # Many chunks, few slots + (8, 4), # Many chunks, moderate slots + (8, 6), # Many chunks, many slots (like bench_offload) + # Edge cases + (1, 2), # Single chunk + (3, 5), # Fewer chunks than slots + ] + + all_passed = True + results = [] + + for num_chunks, num_gpu_slots in TEST_CASES: + seqlen = num_chunks * BLOCK_SIZE + passed, max_diff, mean_diff = run_correctness_test( + num_chunks, num_gpu_slots, verbose=False + ) + + all_passed = all_passed and passed + status = "PASS" if passed else "FAIL" + + results.append((num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff)) + + print(f"[{status}] chunks={num_chunks:2d} slots={num_gpu_slots:2d} " + f"seqlen={seqlen:5d} max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}") + + print() + + # ================================================================ + # Part 2: Decode Phase Tests + # ================================================================ + print("=" * 70) + print("Part 2: Decode Phase Correctness") + print("=" * 70) + print("Testing: Decode query (single token) against all prefilled K, V") + print() + + DECODE_TEST_CASES = [ + # (num_prefill_chunks, num_gpu_slots) + (2, 2), + (4, 3), + (4, 4), + (8, 4), + (8, 6), + ] + + decode_results = [] + + for num_prefill_chunks, num_gpu_slots in DECODE_TEST_CASES: + prefill_len = num_prefill_chunks * BLOCK_SIZE + passed, max_diff, mean_diff = run_decode_correctness_test( + num_prefill_chunks, num_gpu_slots, num_decode_steps=4, verbose=False + ) + + all_passed = all_passed and passed + status = "PASS" if passed else "FAIL" + + decode_results.append((num_prefill_chunks, num_gpu_slots, prefill_len, passed, max_diff, mean_diff)) + + print(f"[{status}] prefill_chunks={num_prefill_chunks:2d} slots={num_gpu_slots:2d} " + f"prefill_len={prefill_len:5d} max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}") + + print() + print("=" * 70) + + # Summary + prefill_passed = sum(1 for r in results if r[3]) + decode_passed = sum(1 for r in decode_results if r[3]) + total_tests = len(results) + len(decode_results) + total_passed = prefill_passed + decode_passed + + print(f"Results: {total_passed}/{total_tests} tests passed") + print(f" - Prefill: {prefill_passed}/{len(results)}") + print(f" - Decode: {decode_passed}/{len(decode_results)}") + + if not all_passed: + print("\nFailed tests:") + for num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff in results: + if not passed: + print(f" - [Prefill] chunks={num_chunks}, slots={num_gpu_slots}, " + f"seqlen={seqlen}, max_diff={max_diff:.6f}") + for num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff in decode_results: + if not passed: + print(f" - [Decode] prefill_chunks={num_chunks}, slots={num_gpu_slots}, " + f"prefill_len={seqlen}, max_diff={max_diff:.6f}") + + print() + assert all_passed, "Some correctness tests failed!" + print("test_offload_correctness: PASSED")