[refactor] Refactor the kvcache offload.

Zijie Tian
2026-01-04 19:37:03 +08:00
parent 00ed17c640
commit 772313db8f
3 changed files with 224 additions and 57 deletions

DEBUG_SUMMARY.md (new file, 103 lines)

@@ -0,0 +1,103 @@
# Chunked Prefill Bug Debug Summary
## Problem
`test_needle.py --enable-offload --input-len 8192` fails with garbage output.
The model generates completely wrong tokens instead of the expected "7492".
## Investigation Progress
### 1. Stream Synchronization Fix (Completed)
- Replaced Triton `store_kvcache` kernel with pure PyTorch operations
- Moved `store_kvcache` to `compute_stream` in chunked prefill mode
- Added sync: `compute_stream.wait_event(offload_done)` after per-layer offload
- Added sync: `default_stream.wait_stream(compute_stream)` before return (the full pattern is sketched below)
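A minimal, standalone sketch of the ordering these fixes enforce (see the list above). The stream and event names mirror the offload engine's, but the tensors and copy bodies are placeholders, not nanovllm code:
```python
import torch

compute_stream = torch.cuda.Stream()
offload_stream = torch.cuda.Stream()
offload_done = torch.cuda.Event()

k = torch.randn(4096, 8, 128, device="cuda", dtype=torch.float16)  # fake K for one layer
gpu_slot = torch.empty_like(k)                                      # stand-in ring-buffer slot
cpu_block = torch.empty(k.shape, dtype=k.dtype, pin_memory=True)    # stand-in pinned CPU block

with torch.cuda.stream(compute_stream):
    gpu_slot.copy_(k)                          # store_kvcache runs on compute_stream

offload_stream.wait_stream(compute_stream)     # offload must see the completed store
with torch.cuda.stream(offload_stream):
    cpu_block.copy_(gpu_slot, non_blocking=True)
    offload_done.record(offload_stream)

compute_stream.wait_event(offload_done)        # next layer must not overwrite the slot early
torch.cuda.default_stream().wait_stream(compute_stream)  # result ready for layernorm/MLP
```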
### 2. KV Cache Alignment Verification (Completed)
Created alignment tests to compare K/V tensors between the torch reference and nanovllm (a minimal comparison helper is sketched after the results below):
**RoPE Alignment:**
- RoPE implementations agree closely (max_diff=0.002, cosine ~1.0)
- Confirmed RoPE is NOT the cause of the bug
**K/V Cache Alignment (Chunk 0):**
- Cosine similarity: ~1.0 for all layers
- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision)
- Mean diff: < 0.001
- **Conclusion: K/V cache offload is working correctly**
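A minimal sketch of this kind of cosine/max-diff comparison; variable names are placeholders, not the actual test code:
```python
import torch
import torch.nn.functional as F

def report_alignment(name: str, ref: torch.Tensor, test: torch.Tensor) -> None:
    """Print cosine similarity and absolute-difference stats between two tensors."""
    ref_f, test_f = ref.float().flatten(), test.float().flatten()
    cos = F.cosine_similarity(ref_f, test_f, dim=0).item()
    diff = (ref_f - test_f).abs()
    print(f"{name}: cos={cos:.6f}  max_diff={diff.max().item():.4f}  mean_diff={diff.mean().item():.6f}")

# Hypothetical usage: k_ref captured from the torch reference, k_off read back from the CPU cache
# report_alignment("layer 0 K, chunk 0", k_ref, k_off)
```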
### 3. Layer Output Divergence Analysis (Completed)
Created a per-chunk layer output comparison (the capture approach is sketched after the results below):
**Chunk 0 (tokens 0-4096):**
- All layers pass with excellent cosine similarity (0.999+)
- Max diff grows in later layers but within acceptable range
**Chunk 1 (tokens 4096-8192):**
- Layers 0-19: OK (cosine ~1.0)
- Layers 20-27: Diverge (cosine 0.83-0.96, max_diff up to 114)
- Divergence correlates with later transformer layers
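A hedged sketch of capturing per-layer outputs for this kind of comparison with forward hooks; the `(hidden_states, residual)` return convention and the `model.model.layers` path are assumptions (based on the hook change later in this commit):
```python
import torch

captured = {}

def make_layer_hook(idx):
    def hook(module, inputs, output):
        # Decoder layers are assumed to return (hidden_states, residual);
        # capture their sum so it matches the reference model's layer output.
        if isinstance(output, tuple) and len(output) >= 2:
            full = output[0] + output[1]
        else:
            full = output
        captured[f"layer_{idx}"] = full.detach().clone()
    return hook

# Hypothetical registration; `model.model.layers` is an assumed module path.
# handles = [layer.register_forward_hook(make_layer_hook(i))
#            for i, layer in enumerate(model.model.layers)]
```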
### 4. Critical Discovery: Single-Chunk Offload Also Fails
**Key finding:** Even with input_len=2048 (single chunk, no chunked attention), the model produces garbage output with CPU offload enabled.
```
# Without offload: PASSES
python tests/test_needle.py --input-len 2048
# Output: "7492" (correct)
# With offload: FAILS
python tests/test_needle.py --enable-offload --input-len 2048
# Output: "The Ble White Th G Lopsiswin..." (garbage)
```
**This proves the bug is NOT in:**
- Chunked attention logic (merge_attention_outputs)
- Multi-chunk KV loading
- Ring buffer pipeline
**The bug IS in:**
- The decode path when CPU offload is enabled
- How prefilled KV is loaded/used during decode
### 5. Decode Path Analysis (In Progress)
The decode path in CPU offload mode:
1. Prefill writes KV to GPU, offloads to CPU
2. Decode loads prefilled KV from CPU via `_decode_ring_buffer_pipeline`
3. Attention runs over the prefilled KV plus the accumulated decode tokens
4. The partial results are merged (the LSE-based merge is sketched below)
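The merge in step 4 combines partial attention outputs via their log-sum-exp (LSE) statistics. A generic sketch of that merge, assuming flash-attention-style shapes (`o`: [batch, seqlen, heads, head_dim], `lse`: [batch, heads, seqlen]); this shows the standard formula, not necessarily the exact `merge_attention_outputs` implementation:
```python
import torch

def merge_partial_attention(o1, lse1, o2, lse2):
    """Merge two partial attention outputs computed over disjoint KV chunks."""
    lse_max = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - lse_max)                     # stabilized partial denominators
    w2 = torch.exp(lse2 - lse_max)
    denom = w1 + w2
    # Reshape weights to broadcast over the output layout: [batch, seqlen, heads, 1]
    w1 = (w1 / denom).permute(0, 2, 1).unsqueeze(-1)
    w2 = (w2 / denom).permute(0, 2, 1).unsqueeze(-1)
    merged = o1 * w1 + o2 * w2
    merged_lse = lse_max + torch.log(denom)
    return merged, merged_lse
```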
**Observations:**
- `prefilled_blocks` set is empty after decode (should contain block IDs)
- CPU cache has valid data (reasonable mean/std values)
- Decode buffer has zeros (decode tokens not being stored correctly?)
## Current Status
### Working
- Stream synchronization fixes
- K/V cache offload to CPU (verified alignment)
- RoPE implementation
- Chunked prefill attention for first chunk
### Not Working
- Decode with CPU offload (even for single-chunk inputs)
- Multi-chunk attention (divergence in later layers for chunk 1)
## Next Steps
1. Debug why `prefilled_blocks` is empty after decode
2. Check if decode path correctly loads KV from CPU
3. Verify decode buffer is being written correctly
4. Compare decode attention outputs between offload and non-offload modes
## Key Files
- `nanovllm/layers/attention.py` - Main attention implementation with chunked paths
- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine
- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks`
- `nanovllm/engine/model_runner.py` - Prefill/decode orchestration
## Hypothesis
The decode path fails because:
1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return empty (a quick probe for this is sketched below)
2. OR the decode attention is not correctly loading/using the prefilled KV from CPU
3. OR there's a stream synchronization issue specific to decode path
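A quick probe for hypothesis 1, sketched below. The `prefilled_blocks` attribute comes from `hybrid_manager.py` (see Key Files); where to call the probe (after prefill and again after the first decode step) and how to reach the manager are assumptions:
```python
def check_prefilled_blocks(kvcache_manager) -> None:
    """Hypothetical debug probe: hypothesis 1 predicts an empty (or missing) set."""
    blocks = getattr(kvcache_manager, "prefilled_blocks", None)
    print(f"prefilled_blocks: {blocks}")
    assert blocks, "prefilled_blocks is empty: prefilled KV blocks are not being tracked"
```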


@@ -61,8 +61,14 @@ class NanovllmSteppable(SteppableModel):
         def make_layer_hook(idx):
             def hook(module, input, output):
                 # Decoder layer returns (hidden_states, residual)
-                hidden_states = output[0] if isinstance(output, tuple) else output
-                self._captured[f"layer_{idx}"] = hidden_states.detach().clone()
+                # hidden_states is MLP output, residual is accumulated residual
+                # To match torch reference, we need hidden_states + residual
+                if isinstance(output, tuple) and len(output) >= 2:
+                    hidden_states, residual = output[0], output[1]
+                    full_output = hidden_states + residual
+                else:
+                    full_output = output
+                self._captured[f"layer_{idx}"] = full_output.detach().clone()
             return hook
         self._hooks.append(


@@ -2,8 +2,6 @@ import logging
 import torch
 import torch.cuda.nvtx
 from torch import nn
-import triton
-import triton.language as tl
 from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 from nanovllm.utils.context import get_context
@@ -12,37 +10,49 @@ from nanovllm.kvcache.sparse.policy import PolicyContext
 logger = logging.getLogger(__name__)
 
 
-@triton.jit
-def store_kvcache_kernel(
-    key_ptr,
-    key_stride,
-    value_ptr,
-    value_stride,
-    k_cache_ptr,
-    v_cache_ptr,
-    slot_mapping_ptr,
-    D: tl.constexpr,
-):
-    idx = tl.program_id(0)
-    slot = tl.load(slot_mapping_ptr + idx)
-    if slot == -1: return
-    key_offsets = idx * key_stride + tl.arange(0, D)
-    value_offsets = idx * value_stride + tl.arange(0, D)
-    key = tl.load(key_ptr + key_offsets)
-    value = tl.load(value_ptr + value_offsets)
-    cache_offsets = slot * D + tl.arange(0, D)
-    tl.store(k_cache_ptr + cache_offsets, key)
-    tl.store(v_cache_ptr + cache_offsets, value)
-
-
-def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
-    N, num_heads, head_dim = key.shape
-    D = num_heads * head_dim
-    assert key.stride(-1) == 1 and value.stride(-1) == 1
-    assert key.stride(1) == head_dim and value.stride(1) == head_dim
-    assert k_cache.stride(1) == D and v_cache.stride(1) == D
-    assert slot_mapping.numel() == N
-    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
+def store_kvcache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+):
+    """
+    Store key/value tensors into KV cache using slot mapping.
+
+    This is a pure PyTorch implementation replacing the previous Triton kernel.
+    Uses index_copy_ for efficient in-place scatter operation.
+
+    Args:
+        key: [N, num_kv_heads, head_dim]
+        value: [N, num_kv_heads, head_dim]
+        k_cache: [num_blocks, block_size, num_kv_heads, head_dim] or similar
+        v_cache: same shape as k_cache
+        slot_mapping: [N] with values as flat indices, -1 means skip
+    """
+    # Filter out invalid slots (slot == -1)
+    valid_mask = slot_mapping >= 0
+    if not valid_mask.any():
+        return
+    valid_slots = slot_mapping[valid_mask]
+    valid_keys = key[valid_mask]  # [M, num_kv_heads, head_dim]
+    valid_values = value[valid_mask]
+
+    # Flatten cache and KV for scatter operation
+    # Cache is viewed as [total_slots, D] where D = num_kv_heads * head_dim
+    N, num_kv_heads, head_dim = key.shape
+    D = num_kv_heads * head_dim
+    total_slots = k_cache.numel() // D
+    k_cache_flat = k_cache.view(total_slots, D)
+    v_cache_flat = v_cache.view(total_slots, D)
+    valid_keys_flat = valid_keys.reshape(-1, D)
+    valid_values_flat = valid_values.reshape(-1, D)
+
+    # In-place scatter using index_copy_
+    k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
+    v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)
 
 
 class Attention(nn.Module):
@@ -66,8 +76,26 @@ class Attention(nn.Module):
     def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
         context = get_context()
         k_cache, v_cache = self.k_cache, self.v_cache
-        if k_cache.numel() and v_cache.numel():
-            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+
+        # Determine if we're in chunked offload mode
+        is_chunked_offload = (
+            context.is_chunked_prefill and
+            hasattr(context, 'kvcache_manager') and
+            context.kvcache_manager is not None and
+            hasattr(context.kvcache_manager, 'offload_engine')
+        )
+
+        if is_chunked_offload:
+            # Chunked offload mode: use compute_stream for store_kvcache
+            # This ensures proper synchronization with per-layer offload
+            compute_stream = context.kvcache_manager.offload_engine.compute_stream
+            if k_cache.numel() and v_cache.numel():
+                with torch.cuda.stream(compute_stream):
+                    store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+        else:
+            # Normal mode: store on default stream
+            if k_cache.numel() and v_cache.numel():
+                store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+
         if context.is_prefill:
             if context.is_chunked_prefill:
@@ -182,31 +210,48 @@
             current_chunk_idx
         )
 
+        # Get compute stream for all attention operations
+        compute_stream = None
+        if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
+            compute_stream = kvcache_manager.offload_engine.compute_stream
+
         # Compute attention against current chunk's KV (with causal mask)
-        torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
-        current_o, current_lse = flash_attn_with_lse(
-            q_batched,
-            k_batched,
-            v_batched,
-            softmax_scale=self.scale,
-            causal=True,
-        )
-        torch.cuda.nvtx.range_pop()
+        # Use compute_stream to ensure proper sync with store_kvcache and offload
+        if compute_stream is not None:
+            with torch.cuda.stream(compute_stream):
+                torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
+                current_o, current_lse = flash_attn_with_lse(
+                    q_batched,
+                    k_batched,
+                    v_batched,
+                    softmax_scale=self.scale,
+                    causal=True,
+                )
+                torch.cuda.nvtx.range_pop()
+        else:
+            torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
+            current_o, current_lse = flash_attn_with_lse(
+                q_batched,
+                k_batched,
+                v_batched,
+                softmax_scale=self.scale,
+                causal=True,
+            )
+            torch.cuda.nvtx.range_pop()
 
-        # Merge with accumulated
+        # Merge with accumulated (all on compute_stream for consistency)
         if o_acc is None:
             final_o = current_o
         else:
-            # IMPORTANT: o_acc was computed on compute_stream. We need to sync before
-            # reading it on the default stream for the merge operation.
-            if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
-                offload_engine = kvcache_manager.offload_engine
-                torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
-            torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
-            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
-            torch.cuda.nvtx.range_pop()
+            if compute_stream is not None:
+                with torch.cuda.stream(compute_stream):
+                    torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
+                    final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
+                    torch.cuda.nvtx.range_pop()
+            else:
+                torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
+                final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
+                torch.cuda.nvtx.range_pop()
 
         torch.cuda.nvtx.range_pop()  # ChunkedPrefill
@@ -222,6 +267,16 @@ class Attention(nn.Module):
             cpu_block_id = cpu_block_ids[current_chunk_idx]
             offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
+            # CRITICAL: compute_stream must wait for offload to complete
+            # before the next layer's store_kvcache can overwrite the GPU slot.
+            # Without this, Layer N+1's store races with Layer N's offload copy.
+            compute_stream.wait_event(offload_engine.ring_slot_offload_done[write_slot])
+
+        # Sync default stream with compute_stream before returning
+        # This ensures the result is ready for the rest of the model (layernorm, MLP)
+        if compute_stream is not None:
+            torch.cuda.default_stream().wait_stream(compute_stream)
+
         # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
         return final_o.squeeze(0)
@@ -318,6 +373,7 @@ class Attention(nn.Module):
                 offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
 
+                prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
                 prev_o, prev_lse = flash_attn_with_lse(
                     q_batched, prev_k, prev_v,
                     softmax_scale=self.scale,
@@ -364,6 +420,7 @@ class Attention(nn.Module):
                 torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
+                prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
                 prev_o, prev_lse = flash_attn_with_lse(
                     q_batched, prev_k, prev_v,
                     softmax_scale=self.scale,
@@ -427,12 +484,13 @@ class Attention(nn.Module):
             raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
 
         # Calculate valid tokens in the last block
-        # Note: For chunked prefill, each block is exactly block_size tokens
-        # The cpu_block_table only contains full prefill blocks
+        # The last prefill chunk might be partial (less than block_size tokens)
         block_size = kvcache_manager.block_size
         num_prefill_blocks = len(cpu_block_table)
-        # All prefill blocks are full (block_size tokens each)
-        last_block_valid_tokens = block_size
+        total_prefill_tokens = len(seq) - 1  # Exclude the current decode token
+        last_block_valid_tokens = total_prefill_tokens % block_size
+        if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
+            last_block_valid_tokens = block_size  # Last block was exactly full
 
         # Apply sparse policy if enabled
         if kvcache_manager.sparse_policy is not None: