[fix] Fixed decode misalignment.
This commit is contained in:
@@ -146,6 +146,10 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
# Key: sequence id, Value: starting position where decode began in current block
|
# Key: sequence id, Value: starting position where decode began in current block
|
||||||
self._decode_start_pos: Dict[int, int] = {}
|
self._decode_start_pos: Dict[int, int] = {}
|
||||||
|
|
||||||
|
# Track original prefill length (for correct last_block_valid_tokens calculation)
|
||||||
|
# Key: sequence id, Value: number of tokens from prefill (before decode started)
|
||||||
|
self._prefill_len: Dict[int, int] = {}
|
||||||
|
|
||||||
# Sparse attention policy (optional)
|
# Sparse attention policy (optional)
|
||||||
self.sparse_policy: Optional["SparsePolicy"] = None
|
self.sparse_policy: Optional["SparsePolicy"] = None
|
||||||
|
|
||||||
@@ -542,6 +546,26 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
seq_id = id(seq)
|
seq_id = id(seq)
|
||||||
self._decode_start_pos[seq_id] = 0
|
self._decode_start_pos[seq_id] = 0
|
||||||
|
|
||||||
|
def get_prefill_len(self, seq: Sequence) -> int:
    """
    Return the original prefill length for a sequence.

    The value is computed on the first call for a given sequence and cached
    in ``self._prefill_len`` under ``id(seq)``. Later calls return the cached
    value unchanged, even though ``len(seq)`` keeps growing during decode.

    NOTE(review): this presumes the first call happens on the first decode
    step, when ``len(seq)`` equals the prefill length plus one decode
    token - confirm against the caller.

    Args:
        seq: Sequence whose prefill length is requested.

    Returns:
        Number of tokens from prefill (before decode started); never negative.
    """
    seq_id = id(seq)
    try:
        # Fast path: a single dictionary lookup once the value is cached.
        return self._prefill_len[seq_id]
    except KeyError:
        # First call for this sequence: the current length includes the
        # first decode token, so subtract it. Clamp at zero so an empty
        # sequence cannot produce a negative length, which would wrap
        # around under Python's modulo in block-size arithmetic.
        prefill_len = max(len(seq) - 1, 0)
        self._prefill_len[seq_id] = prefill_len
        return prefill_len
|
||||||
|
|
||||||
def clear_decode_tracking(self, seq: Sequence) -> None:
|
def clear_decode_tracking(self, seq: Sequence) -> None:
|
||||||
"""
|
"""
|
||||||
Clear decode position tracking for sequence.
|
Clear decode position tracking for sequence.
|
||||||
@@ -553,6 +577,7 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
"""
|
"""
|
||||||
seq_id = id(seq)
|
seq_id = id(seq)
|
||||||
self._decode_start_pos.pop(seq_id, None)
|
self._decode_start_pos.pop(seq_id, None)
|
||||||
|
self._prefill_len.pop(seq_id, None)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -487,11 +487,12 @@ class Attention(nn.Module):
|
|||||||
if not cpu_block_table:
|
if not cpu_block_table:
|
||||||
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
||||||
|
|
||||||
# Calculate valid tokens in the last block
|
# Calculate valid tokens in the last CPU block
|
||||||
# The last prefill chunk might be partial (less than block_size tokens)
|
# CRITICAL: Use original prefill length, not current seq length!
|
||||||
|
# CPU blocks are fixed after prefill, their content doesn't change during decode.
|
||||||
block_size = kvcache_manager.block_size
|
block_size = kvcache_manager.block_size
|
||||||
num_prefill_blocks = len(cpu_block_table)
|
num_prefill_blocks = len(cpu_block_table)
|
||||||
total_prefill_tokens = len(seq) - 1 # Exclude the current decode token
|
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
|
||||||
last_block_valid_tokens = total_prefill_tokens % block_size
|
last_block_valid_tokens = total_prefill_tokens % block_size
|
||||||
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
||||||
last_block_valid_tokens = block_size # Last block was exactly full
|
last_block_valid_tokens = block_size # Last block was exactly full
|
||||||
|
|||||||
Reference in New Issue
Block a user