[fix] Fixed decode misalign.
This commit is contained in:
@@ -146,6 +146,10 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
# Key: sequence id, Value: starting position where decode began in current block
|
||||
self._decode_start_pos: Dict[int, int] = {}
|
||||
|
||||
# Track original prefill length (for correct last_block_valid_tokens calculation)
|
||||
# Key: sequence id, Value: number of tokens from prefill (before decode started)
|
||||
self._prefill_len: Dict[int, int] = {}
|
||||
|
||||
# Sparse attention policy (optional)
|
||||
self.sparse_policy: Optional["SparsePolicy"] = None
|
||||
|
||||
@@ -542,6 +546,26 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
seq_id = id(seq)
|
||||
self._decode_start_pos[seq_id] = 0
|
||||
|
||||
def get_prefill_len(self, seq: Sequence) -> int:
|
||||
"""
|
||||
Get the original prefill length for a sequence.
|
||||
|
||||
This is cached on first call to ensure correct last_block_valid_tokens
|
||||
calculation during decode (the CPU blocks don't change after prefill).
|
||||
|
||||
Args:
|
||||
seq: Sequence
|
||||
|
||||
Returns:
|
||||
Number of tokens from prefill (before decode started)
|
||||
"""
|
||||
seq_id = id(seq)
|
||||
if seq_id not in self._prefill_len:
|
||||
# First decode step - store the prefill length
|
||||
# len(seq) - 1 because current len includes the first decode token
|
||||
self._prefill_len[seq_id] = len(seq) - 1
|
||||
return self._prefill_len[seq_id]
|
||||
|
||||
def clear_decode_tracking(self, seq: Sequence) -> None:
|
||||
"""
|
||||
Clear decode position tracking for sequence.
|
||||
@@ -553,6 +577,7 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
"""
|
||||
seq_id = id(seq)
|
||||
self._decode_start_pos.pop(seq_id, None)
|
||||
self._prefill_len.pop(seq_id, None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user