[fix] Fixed decode misalign.
This commit is contained in:
@@ -146,6 +146,10 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
# Key: sequence id, Value: starting position where decode began in current block
|
||||
self._decode_start_pos: Dict[int, int] = {}
|
||||
|
||||
# Track original prefill length (for correct last_block_valid_tokens calculation)
|
||||
# Key: sequence id, Value: number of tokens from prefill (before decode started)
|
||||
self._prefill_len: Dict[int, int] = {}
|
||||
|
||||
# Sparse attention policy (optional)
|
||||
self.sparse_policy: Optional["SparsePolicy"] = None
|
||||
|
||||
@@ -542,6 +546,26 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
seq_id = id(seq)
|
||||
self._decode_start_pos[seq_id] = 0
|
||||
|
||||
def get_prefill_len(self, seq: Sequence) -> int:
|
||||
"""
|
||||
Get the original prefill length for a sequence.
|
||||
|
||||
This is cached on first call to ensure correct last_block_valid_tokens
|
||||
calculation during decode (the CPU blocks don't change after prefill).
|
||||
|
||||
Args:
|
||||
seq: Sequence
|
||||
|
||||
Returns:
|
||||
Number of tokens from prefill (before decode started)
|
||||
"""
|
||||
seq_id = id(seq)
|
||||
if seq_id not in self._prefill_len:
|
||||
# First decode step - store the prefill length
|
||||
# len(seq) - 1 because current len includes the first decode token
|
||||
self._prefill_len[seq_id] = len(seq) - 1
|
||||
return self._prefill_len[seq_id]
|
||||
|
||||
def clear_decode_tracking(self, seq: Sequence) -> None:
|
||||
"""
|
||||
Clear decode position tracking for sequence.
|
||||
@@ -553,6 +577,7 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
"""
|
||||
seq_id = id(seq)
|
||||
self._decode_start_pos.pop(seq_id, None)
|
||||
self._prefill_len.pop(seq_id, None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
|
||||
@@ -487,11 +487,12 @@ class Attention(nn.Module):
|
||||
if not cpu_block_table:
|
||||
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
||||
|
||||
# Calculate valid tokens in the last block
|
||||
# The last prefill chunk might be partial (less than block_size tokens)
|
||||
# Calculate valid tokens in the last CPU block
|
||||
# CRITICAL: Use original prefill length, not current seq length!
|
||||
# CPU blocks are fixed after prefill, their content doesn't change during decode.
|
||||
block_size = kvcache_manager.block_size
|
||||
num_prefill_blocks = len(cpu_block_table)
|
||||
total_prefill_tokens = len(seq) - 1 # Exclude the current decode token
|
||||
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
|
||||
last_block_valid_tokens = total_prefill_tokens % block_size
|
||||
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
||||
last_block_valid_tokens = block_size # Last block was exactly full
|
||||
|
||||
Reference in New Issue
Block a user