[fix] Fixed decode misalign.

This commit is contained in:
Zijie Tian
2026-01-05 19:00:44 +08:00
parent 054aaff403
commit 247c5312d9
2 changed files with 29 additions and 3 deletions

View File

@@ -146,6 +146,10 @@ class HybridKVCacheManager(KVCacheManager):
# Key: sequence id, Value: starting position where decode began in current block
self._decode_start_pos: Dict[int, int] = {}
# Track original prefill length (for correct last_block_valid_tokens calculation)
# Key: sequence id, Value: number of tokens from prefill (before decode started)
self._prefill_len: Dict[int, int] = {}
# Sparse attention policy (optional)
self.sparse_policy: Optional["SparsePolicy"] = None
@@ -542,6 +546,26 @@ class HybridKVCacheManager(KVCacheManager):
seq_id = id(seq)
self._decode_start_pos[seq_id] = 0
def get_prefill_len(self, seq: Sequence) -> int:
"""
Get the original prefill length for a sequence.
This is cached on first call to ensure correct last_block_valid_tokens
calculation during decode (the CPU blocks don't change after prefill).
Args:
seq: Sequence
Returns:
Number of tokens from prefill (before decode started)
"""
seq_id = id(seq)
if seq_id not in self._prefill_len:
# First decode step - store the prefill length
# len(seq) - 1 because current len includes the first decode token
self._prefill_len[seq_id] = len(seq) - 1
return self._prefill_len[seq_id]
def clear_decode_tracking(self, seq: Sequence) -> None:
"""
Clear decode position tracking for sequence.
@@ -553,6 +577,7 @@ class HybridKVCacheManager(KVCacheManager):
"""
seq_id = id(seq)
self._decode_start_pos.pop(seq_id, None)
self._prefill_len.pop(seq_id, None)
def __repr__(self) -> str:
return (

View File

@@ -487,11 +487,12 @@ class Attention(nn.Module):
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Calculate valid tokens in the last block
# The last prefill chunk might be partial (less than block_size tokens)
# Calculate valid tokens in the last CPU block
# CRITICAL: Use original prefill length, not current seq length!
# CPU blocks are fixed after prefill, their content doesn't change during decode.
block_size = kvcache_manager.block_size
num_prefill_blocks = len(cpu_block_table)
total_prefill_tokens = len(seq) - 1 # Exclude the current decode token
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
last_block_valid_tokens = total_prefill_tokens % block_size
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
last_block_valid_tokens = block_size # Last block was exactly full