diff --git a/DEBUG_SUMMARY.md b/DEBUG_SUMMARY.md new file mode 100644 index 0000000..ac7dfc4 --- /dev/null +++ b/DEBUG_SUMMARY.md @@ -0,0 +1,103 @@ +# Chunked Prefill Bug Debug Summary + +## Problem +`test_needle.py --enable-offload --input-len 8192` fails with garbage output. + +The model generates completely wrong tokens instead of the expected "7492". + +## Investigation Progress + +### 1. Stream Synchronization Fix (Completed) +- Replaced Triton `store_kvcache` kernel with pure PyTorch operations +- Moved `store_kvcache` to `compute_stream` in chunked prefill mode +- Added sync: `compute_stream.wait_event(offload_done)` after per-layer offload +- Added sync: `default_stream.wait_stream(compute_stream)` before return + +### 2. KV Cache Alignment Verification (Completed) +Created alignment tests to compare K/V tensors between torch reference and nanovllm: + +**RoPE Alignment:** +- RoPE implementations match perfectly (max_diff=0.002, cosine ~1.0) +- Confirmed RoPE is NOT the cause of the bug + +**K/V Cache Alignment (Chunk 0):** +- Cosine similarity: ~1.0 for all layers +- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision) +- Mean diff: < 0.001 +- **Conclusion: K/V cache offload is working correctly** + +### 3. Layer Output Divergence Analysis (Completed) +Created per-chunk layer output comparison: + +**Chunk 0 (tokens 0-4096):** +- All layers pass with excellent cosine similarity (0.999+) +- Max diff grows in later layers but within acceptable range + +**Chunk 1 (tokens 4096-8192):** +- Layers 0-19: OK (cosine ~1.0) +- Layers 20-27: Diverge (cosine 0.83-0.96, max_diff up to 114) +- Divergence correlates with later transformer layers + +### 4. Critical Discovery: Single-Chunk Offload Also Fails +**Key finding:** Even with input_len=2048 (single chunk, no chunked attention), the model produces garbage output with CPU offload enabled. + +``` +# Without offload: PASSES +python tests/test_needle.py --input-len 2048 +# Output: "7492" (correct) + +# With offload: FAILS +python tests/test_needle.py --enable-offload --input-len 2048 +# Output: "The Ble White Th G Lopsiswin..." (garbage) +``` + +**This proves the bug is NOT in:** +- Chunked attention logic (merge_attention_outputs) +- Multi-chunk KV loading +- Ring buffer pipeline + +**The bug IS in:** +- The decode path when CPU offload is enabled +- How prefilled KV is loaded/used during decode + +### 5. Decode Path Analysis (In Progress) +The decode path in CPU offload mode: +1. Prefill writes KV to GPU, offloads to CPU +2. Decode loads prefilled KV from CPU via `_decode_ring_buffer_pipeline` +3. Attend to prefilled KV + accumulated decode tokens +4. Merge results + +**Observations:** +- `prefilled_blocks` set is empty after decode (should contain block IDs) +- CPU cache has valid data (reasonable mean/std values) +- Decode buffer has zeros (decode tokens not being stored correctly?) + +## Current Status + +### Working +- Stream synchronization fixes +- K/V cache offload to CPU (verified alignment) +- RoPE implementation +- Chunked prefill attention for first chunk + +### Not Working +- Decode with CPU offload (even for single-chunk inputs) +- Multi-chunk attention (divergence in later layers for chunk 1) + +## Next Steps +1. Debug why `prefilled_blocks` is empty after decode +2. Check if decode path correctly loads KV from CPU +3. Verify decode buffer is being written correctly +4. 
Compare decode attention outputs between offload and non-offload modes + +## Key Files +- `nanovllm/layers/attention.py` - Main attention implementation with chunked paths +- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine +- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks` +- `nanovllm/engine/model_runner.py` - Prefill/decode orchestration + +## Hypothesis +The decode path fails because: +1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return empty +2. OR the decode attention is not correctly loading/using the prefilled KV from CPU +3. OR there's a stream synchronization issue specific to decode path diff --git a/nanovllm/debug/adapters/nanovllm_adapter.py b/nanovllm/debug/adapters/nanovllm_adapter.py index 8ab38be..d18eacb 100644 --- a/nanovllm/debug/adapters/nanovllm_adapter.py +++ b/nanovllm/debug/adapters/nanovllm_adapter.py @@ -61,8 +61,14 @@ class NanovllmSteppable(SteppableModel): def make_layer_hook(idx): def hook(module, input, output): # Decoder layer returns (hidden_states, residual) - hidden_states = output[0] if isinstance(output, tuple) else output - self._captured[f"layer_{idx}"] = hidden_states.detach().clone() + # hidden_states is MLP output, residual is accumulated residual + # To match torch reference, we need hidden_states + residual + if isinstance(output, tuple) and len(output) >= 2: + hidden_states, residual = output[0], output[1] + full_output = hidden_states + residual + else: + full_output = output + self._captured[f"layer_{idx}"] = full_output.detach().clone() return hook self._hooks.append( diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index de7dbff..b75a109 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -2,8 +2,6 @@ import logging import torch import torch.cuda.nvtx from torch import nn -import triton -import triton.language as tl from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache from nanovllm.utils.context import get_context @@ -12,37 +10,49 @@ from nanovllm.kvcache.sparse.policy import PolicyContext logger = logging.getLogger(__name__) -@triton.jit -def store_kvcache_kernel( - key_ptr, - key_stride, - value_ptr, - value_stride, - k_cache_ptr, - v_cache_ptr, - slot_mapping_ptr, - D: tl.constexpr, +def store_kvcache( + key: torch.Tensor, + value: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + slot_mapping: torch.Tensor, ): - idx = tl.program_id(0) - slot = tl.load(slot_mapping_ptr + idx) - if slot == -1: return - key_offsets = idx * key_stride + tl.arange(0, D) - value_offsets = idx * value_stride + tl.arange(0, D) - key = tl.load(key_ptr + key_offsets) - value = tl.load(value_ptr + value_offsets) - cache_offsets = slot * D + tl.arange(0, D) - tl.store(k_cache_ptr + cache_offsets, key) - tl.store(v_cache_ptr + cache_offsets, value) + """ + Store key/value tensors into KV cache using slot mapping. + This is a pure PyTorch implementation replacing the previous Triton kernel. + Uses index_copy_ for efficient in-place scatter operation. 
-def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor): - N, num_heads, head_dim = key.shape - D = num_heads * head_dim - assert key.stride(-1) == 1 and value.stride(-1) == 1 - assert key.stride(1) == head_dim and value.stride(1) == head_dim - assert k_cache.stride(1) == D and v_cache.stride(1) == D - assert slot_mapping.numel() == N - store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D) + Args: + key: [N, num_kv_heads, head_dim] + value: [N, num_kv_heads, head_dim] + k_cache: [num_blocks, block_size, num_kv_heads, head_dim] or similar + v_cache: same shape as k_cache + slot_mapping: [N] with values as flat indices, -1 means skip + """ + # Filter out invalid slots (slot == -1) + valid_mask = slot_mapping >= 0 + if not valid_mask.any(): + return + + valid_slots = slot_mapping[valid_mask] + valid_keys = key[valid_mask] # [M, num_kv_heads, head_dim] + valid_values = value[valid_mask] + + # Flatten cache and KV for scatter operation + # Cache is viewed as [total_slots, D] where D = num_kv_heads * head_dim + N, num_kv_heads, head_dim = key.shape + D = num_kv_heads * head_dim + total_slots = k_cache.numel() // D + + k_cache_flat = k_cache.view(total_slots, D) + v_cache_flat = v_cache.view(total_slots, D) + valid_keys_flat = valid_keys.reshape(-1, D) + valid_values_flat = valid_values.reshape(-1, D) + + # In-place scatter using index_copy_ + k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat) + v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat) class Attention(nn.Module): @@ -66,8 +76,26 @@ class Attention(nn.Module): def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): context = get_context() k_cache, v_cache = self.k_cache, self.v_cache - if k_cache.numel() and v_cache.numel(): - store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) + + # Determine if we're in chunked offload mode + is_chunked_offload = ( + context.is_chunked_prefill and + hasattr(context, 'kvcache_manager') and + context.kvcache_manager is not None and + hasattr(context.kvcache_manager, 'offload_engine') + ) + + if is_chunked_offload: + # Chunked offload mode: use compute_stream for store_kvcache + # This ensures proper synchronization with per-layer offload + compute_stream = context.kvcache_manager.offload_engine.compute_stream + if k_cache.numel() and v_cache.numel(): + with torch.cuda.stream(compute_stream): + store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) + else: + # Normal mode: store on default stream + if k_cache.numel() and v_cache.numel(): + store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) if context.is_prefill: if context.is_chunked_prefill: @@ -182,31 +210,48 @@ class Attention(nn.Module): current_chunk_idx ) + # Get compute stream for all attention operations + compute_stream = None + if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'): + compute_stream = kvcache_manager.offload_engine.compute_stream # Compute attention against current chunk's KV (with causal mask) - torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)") - current_o, current_lse = flash_attn_with_lse( - q_batched, - k_batched, - v_batched, - softmax_scale=self.scale, - causal=True, - ) - torch.cuda.nvtx.range_pop() + # Use compute_stream to ensure proper sync with store_kvcache and offload + if compute_stream is not None: + with torch.cuda.stream(compute_stream): + 
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)") + current_o, current_lse = flash_attn_with_lse( + q_batched, + k_batched, + v_batched, + softmax_scale=self.scale, + causal=True, + ) + torch.cuda.nvtx.range_pop() + else: + torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)") + current_o, current_lse = flash_attn_with_lse( + q_batched, + k_batched, + v_batched, + softmax_scale=self.scale, + causal=True, + ) + torch.cuda.nvtx.range_pop() - # Merge with accumulated + # Merge with accumulated (all on compute_stream for consistency) if o_acc is None: final_o = current_o else: - # IMPORTANT: o_acc was computed on compute_stream. We need to sync before - # reading it on the default stream for the merge operation. - if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'): - offload_engine = kvcache_manager.offload_engine - torch.cuda.default_stream().wait_stream(offload_engine.compute_stream) - - torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}") - final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) - torch.cuda.nvtx.range_pop() + if compute_stream is not None: + with torch.cuda.stream(compute_stream): + torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}") + final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) + torch.cuda.nvtx.range_pop() + else: + torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}") + final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) + torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop() # ChunkedPrefill @@ -222,6 +267,16 @@ class Attention(nn.Module): cpu_block_id = cpu_block_ids[current_chunk_idx] offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id) + # CRITICAL: compute_stream must wait for offload to complete + # before the next layer's store_kvcache can overwrite the GPU slot. + # Without this, Layer N+1's store races with Layer N's offload copy. 
+ compute_stream.wait_event(offload_engine.ring_slot_offload_done[write_slot]) + + # Sync default stream with compute_stream before returning + # This ensures the result is ready for the rest of the model (layernorm, MLP) + if compute_stream is not None: + torch.cuda.default_stream().wait_stream(compute_stream) + # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim] return final_o.squeeze(0) @@ -318,6 +373,7 @@ class Attention(nn.Module): offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id) prev_k, prev_v = offload_engine.get_kv_for_slot(slot) + prev_o, prev_lse = flash_attn_with_lse( q_batched, prev_k, prev_v, softmax_scale=self.scale, @@ -364,6 +420,7 @@ class Attention(nn.Module): torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}") prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot) + prev_o, prev_lse = flash_attn_with_lse( q_batched, prev_k, prev_v, softmax_scale=self.scale, @@ -427,12 +484,13 @@ class Attention(nn.Module): raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available") # Calculate valid tokens in the last block - # Note: For chunked prefill, each block is exactly block_size tokens - # The cpu_block_table only contains full prefill blocks + # The last prefill chunk might be partial (less than block_size tokens) block_size = kvcache_manager.block_size num_prefill_blocks = len(cpu_block_table) - # All prefill blocks are full (block_size tokens each) - last_block_valid_tokens = block_size + total_prefill_tokens = len(seq) - 1 # Exclude the current decode token + last_block_valid_tokens = total_prefill_tokens % block_size + if last_block_valid_tokens == 0 and total_prefill_tokens > 0: + last_block_valid_tokens = block_size # Last block was exactly full # Apply sparse policy if enabled if kvcache_manager.sparse_policy is not None:
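For reference, a minimal CPU-only sanity check of the new pure-PyTorch `store_kvcache` introduced in this diff. Shapes are small hypothetical values, and it assumes `nanovllm.layers.attention` imports cleanly in the test environment; it compares the vectorized `index_copy_` scatter against a naive per-token loop.

```python
import torch
from nanovllm.layers.attention import store_kvcache  # the pure-PyTorch version added in this diff

def store_kvcache_ref(key, value, k_cache, v_cache, slot_mapping):
    # Naive per-token loop, used only as a reference for the scatter.
    D = key.shape[1] * key.shape[2]
    for i, slot in enumerate(slot_mapping.tolist()):
        if slot == -1:
            continue
        k_cache.view(-1, D)[slot] = key[i].reshape(D)
        v_cache.view(-1, D)[slot] = value[i].reshape(D)

N, num_kv_heads, head_dim = 5, 2, 8
num_blocks, block_size = 3, 4
key = torch.randn(N, num_kv_heads, head_dim)
value = torch.randn(N, num_kv_heads, head_dim)
slot_mapping = torch.tensor([0, 3, -1, 7, 11])  # -1 means "skip this token"

k1 = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim)
v1 = torch.zeros_like(k1)
k2, v2 = k1.clone(), v1.clone()

store_kvcache(key, value, k1, v1, slot_mapping)      # vectorized index_copy_ path
store_kvcache_ref(key, value, k2, v2, slot_mapping)  # reference loop
assert torch.equal(k1, k2) and torch.equal(v1, v2)
```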
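`merge_attention_outputs` is referenced in the chunked-prefill path but not shown in this diff. For readers following the chunk-merging logic, here is a minimal sketch of the standard log-sum-exp merge it presumably performs. The shapes are assumptions (`o`: `[batch, tokens, heads, head_dim]`, `lse`: `[batch, heads, tokens]`, matching flash-attn's `softmax_lse` convention); the real helper's signature may differ.

```python
import torch

def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
    """Combine two partial attention results computed over disjoint KV chunks.

    o1, o2:     [batch, tokens, heads, head_dim] partial attention outputs
    lse1, lse2: [batch, heads, tokens] log-sum-exp of the attention scores
    """
    lse = torch.logaddexp(lse1, lse2)                         # combined softmax normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # -> [batch, tokens, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse
```

If chunk 1's divergence in layers 20-27 turned out to be a merge bug, comparing a reference like this against the in-tree helper on captured `o`/`lse` pairs would localize it quickly.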
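The race called out above ("Layer N+1's store races with Layer N's offload copy") follows a standard produce/offload/reuse pattern. Below is a self-contained sketch of the event-based ordering, using generic streams and buffers rather than the real `offload_engine` attributes, which are assumptions here.

```python
import torch

if torch.cuda.is_available():
    compute_stream = torch.cuda.Stream()
    offload_stream = torch.cuda.Stream()
    offload_done = torch.cuda.Event()

    gpu_slot = torch.randn(256, 16, 128, device="cuda")      # stands in for one ring-buffer slot
    cpu_block = torch.empty(gpu_slot.shape, pin_memory=True)  # pinned host buffer for async D2H

    # Offload stream: wait for the producer, copy the slot to CPU, and record an
    # event that completes once the copy has finished.
    with torch.cuda.stream(offload_stream):
        offload_stream.wait_stream(torch.cuda.current_stream())
        cpu_block.copy_(gpu_slot, non_blocking=True)
        offload_done.record(offload_stream)

    # Compute stream: must not overwrite the slot until the offload copy is done.
    with torch.cuda.stream(compute_stream):
        compute_stream.wait_event(offload_done)
        gpu_slot.fill_(0)  # stands in for the next layer's store_kvcache into the same slot

    torch.cuda.synchronize()
```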