def get_prefill_len(self, seq: Sequence) -> int:
    """
    Return the prefill length recorded for ``seq``, recording it if absent.

    The value is stored in ``self._prefill_len`` under ``id(seq)``. On the
    first call for a given sequence it is computed as ``len(seq) - 1`` and
    stored; every later call returns the stored value unchanged, even if
    the sequence has grown in the meantime.

    NOTE(review): the ``- 1`` assumes the first call happens on the first
    decode step, when the sequence already contains one decode token.
    Confirm that callers never invoke this earlier or later than that.

    Args:
        seq: Sequence whose prefill length is requested.

    Returns:
        Number of tokens from prefill (before decode started).
    """
    key = id(seq)
    try:
        # Already recorded on an earlier call: return the frozen value.
        return self._prefill_len[key]
    except KeyError:
        # Nothing recorded yet: drop the one decode token assumed to be
        # included in the current length, then remember the result.
        recorded = len(seq) - 1
        self._prefill_len[key] = recorded
        return recorded
@@ -553,6 +577,7 @@ class HybridKVCacheManager(KVCacheManager): """ seq_id = id(seq) self._decode_start_pos.pop(seq_id, None) + self._prefill_len.pop(seq_id, None) def __repr__(self) -> str: return ( diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 914c488..9d4d579 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -487,11 +487,12 @@ class Attention(nn.Module): if not cpu_block_table: raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available") - # Calculate valid tokens in the last block - # The last prefill chunk might be partial (less than block_size tokens) + # Calculate valid tokens in the last CPU block + # CRITICAL: Use original prefill length, not current seq length! + # CPU blocks are fixed after prefill, their content doesn't change during decode. block_size = kvcache_manager.block_size num_prefill_blocks = len(cpu_block_table) - total_prefill_tokens = len(seq) - 1 # Exclude the current decode token + total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length last_block_valid_tokens = total_prefill_tokens % block_size if last_block_valid_tokens == 0 and total_prefill_tokens > 0: last_block_valid_tokens = block_size # Last block was exactly full