diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py
index 688f8be..cd3ad01 100644
--- a/nanovllm/kvcache/offload_engine.py
+++ b/nanovllm/kvcache/offload_engine.py
@@ -1007,9 +1007,8 @@ class OffloadEngine:
         if not self._debug_mode or not self._debug_hooks:
             return
 
-        # GPU cache has no layer dimension
-        k = self.k_cache_gpu[slot_idx]
-        v = self.v_cache_gpu[slot_idx]
+        # Use get_kv_for_slot for consistency with attention.py
+        k, v = self.get_kv_for_slot(slot_idx)
 
         for hook in self._debug_hooks:
             try:
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index 3cc170a..88dfce1 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -426,6 +426,14 @@ class Attention(nn.Module):
         if not cpu_block_table:
             raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
 
+        # Calculate valid tokens in the last block
+        # prefill_len = total prefilled tokens (current decode token not yet in CPU)
+        block_size = kvcache_manager.block_size
+        prefill_len = len(seq) - 1  # Exclude current decode token
+        last_block_valid_tokens = prefill_len % block_size
+        if last_block_valid_tokens == 0 and prefill_len > 0:
+            last_block_valid_tokens = block_size  # Last block is full
+
         # Apply sparse policy if enabled
         if kvcache_manager.sparse_policy is not None:
             policy_ctx = PolicyContext(
@@ -480,6 +488,18 @@ class Attention(nn.Module):
             else:
                 k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(num_blocks_in_chunk)
 
+            # Handle partial last block: slice to only include valid tokens
+            # This is critical because the rest of the block contains stale data
+            is_last_chunk = (end == len(cpu_block_table))
+            if is_last_chunk and last_block_valid_tokens < block_size:
+                # Calculate total valid tokens in this chunk
+                # All blocks except the last are full, last block has last_block_valid_tokens
+                full_blocks = num_blocks_in_chunk - 1
+                valid_tokens = full_blocks * block_size + last_block_valid_tokens
+                # Slice KV: [batch, seqlen, heads, dim] -> [batch, valid_tokens, heads, dim]
+                k_chunk = k_chunk[:, :valid_tokens, :, :]
+                v_chunk = v_chunk[:, :valid_tokens, :, :]
+
             # Compute attention for this chunk
             o_chunk, lse_chunk = flash_attn_with_lse(
                 q_batched, k_chunk, v_chunk,
@@ -518,6 +538,11 @@ class Attention(nn.Module):
         start_pos = context.decode_start_pos_in_block
         num_accumulated = pos_in_block - start_pos + 1
 
+        # IMPORTANT: Sync compute_stream with default stream before reading decode_slot
+        # store_kvcache writes to decode_slot on default stream (before entering this function)
+        # We need to ensure that write is complete before reading on compute_stream
+        compute_stream.wait_stream(torch.cuda.default_stream())
+
         with torch.cuda.stream(compute_stream):
             if num_accumulated > 0:
                 # GPU cache has no layer dimension
diff --git a/tests/test_debug_verification.py b/tests/test_debug_verification.py
index 35d5179..532de2b 100644
--- a/tests/test_debug_verification.py
+++ b/tests/test_debug_verification.py
@@ -6,6 +6,7 @@ Injects distinctive K/V values, verifies loaded tensors match expected patterns.
 import os
 os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
 
+import inspect
 from random import randint, seed
 from typing import Dict, List
 import torch
@@ -30,6 +31,27 @@ def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor,
     if layer_id != 0:
         return
 
+    # Go up the stack to find kvcache_manager and print k_cache_gpu[*][0,0,0] for all slots
+    frame = inspect.currentframe()
+    try:
+        caller_frame = frame.f_back
+        if caller_frame is not None:
+            local_vars = caller_frame.f_locals
+            if 'self' in local_vars:
+                self_obj = local_vars['self']
+                if hasattr(self_obj, 'k_cache_gpu'):
+                    num_slots = self_obj.k_cache_gpu.shape[0]
+                    vals = []
+                    for i in range(num_slots):
+                        v = self_obj.k_cache_gpu[i][0,0,0].item()
+                        if i == slot_idx:
+                            vals.append(f"[{v}]")
+                        else:
+                            vals.append(str(v))
+                    print(f"[DEBUG] k_cache_gpu[0..{num_slots-1}][0,0,0] = [{', '.join(vals)}]")
+    finally:
+        del frame
+
     load_log.append({
         "chunk_idx": current_chunk[0],
         "cpu_block_id": cpu_block_id,
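
To illustrate the partial-last-block handling added in attention.py above, here is a minimal standalone sketch of the same token arithmetic. The helper name valid_tokens_in_chunk and the concrete sizes (block_size=32, 70 prefilled tokens, a 3-block chunk) are illustrative assumptions, not values taken from the repository.

import torch

def valid_tokens_in_chunk(prefill_len: int, block_size: int, num_blocks_in_chunk: int) -> int:
    # Tokens actually written into the last CPU block of the sequence.
    last_block_valid = prefill_len % block_size
    if last_block_valid == 0 and prefill_len > 0:
        last_block_valid = block_size  # the last block is exactly full
    # Every block in the chunk except the last is full.
    return (num_blocks_in_chunk - 1) * block_size + last_block_valid

# 70 prefilled tokens with block_size 32 -> the third block holds only 6 valid tokens.
block_size, prefill_len = 32, 70
k_chunk = torch.randn(1, 3 * block_size, 8, 64)   # [batch, seqlen, heads, dim]
valid = valid_tokens_in_chunk(prefill_len, block_size, num_blocks_in_chunk=3)
k_chunk = k_chunk[:, :valid, :, :]                # drop the stale tail of the last block
assert k_chunk.shape[1] == 70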