[refactor] Delete unnecessary test, and refactor the offload prefix cache.
This commit is contained in:
@@ -62,6 +62,8 @@ class LLMEngine:
|
||||
token_ids = self.model_runner.call("run", seqs, is_prefill)
|
||||
self.scheduler.postprocess(seqs, token_ids)
|
||||
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
|
||||
|
||||
#> Calculate number of tokens processed
|
||||
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
|
||||
return outputs, num_tokens
|
||||
|
||||
|
||||
@@ -128,6 +128,9 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
self.cpu_block_to_logical: Dict[int, int] = {} # cpu_block -> logical_id
|
||||
|
||||
# Prefix cache (uses logical block IDs)
|
||||
# NOTE: Currently WRITE-ONLY in offload mode - hashes are stored but never
|
||||
#> used for cache hit detection. This is intentional: offload mode always
|
||||
#> allocates new blocks and doesn't reuse existing ones.
|
||||
self.hash_to_logical_id: Dict[int, int] = {}
|
||||
|
||||
# Step counter for policy
|
||||
@@ -258,14 +261,10 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
pos_in_block = seq_len % self._block_size
|
||||
|
||||
if pos_in_block == 1:
|
||||
# Need new block
|
||||
assert last_block.hash != -1
|
||||
|
||||
# Need new block (previous block is full)
|
||||
logical_id = self.free_logical_ids.popleft()
|
||||
block = self.logical_blocks[logical_id]
|
||||
block.ref_count = 1
|
||||
block.hash = -1
|
||||
block.token_ids = []
|
||||
|
||||
# Allocate new block to CPU (ring buffer mode)
|
||||
if not self.free_cpu_blocks:
|
||||
@@ -279,17 +278,13 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
block_table.append(logical_id)
|
||||
|
||||
elif pos_in_block == 0:
|
||||
# Block is full, update hash for prefix cache
|
||||
assert last_block.hash == -1
|
||||
token_ids = seq.block(seq.num_blocks - 1)
|
||||
prefix_hash = (
|
||||
self.logical_blocks[block_table[-2]].hash
|
||||
if len(block_table) > 1 else -1
|
||||
)
|
||||
h = self.compute_hash(token_ids, prefix_hash)
|
||||
last_block.hash = h
|
||||
last_block.token_ids = token_ids.copy()
|
||||
self.hash_to_logical_id[h] = last_logical_id
|
||||
# Block is full
|
||||
# NOTE: Prefix cache disabled in offload mode
|
||||
# If enabled, would compute hash and update:
|
||||
# h = self.compute_hash(seq.block(seq.num_blocks - 1), prefix_hash)
|
||||
# last_block.hash = h
|
||||
# self.hash_to_logical_id[h] = last_logical_id
|
||||
pass
|
||||
|
||||
def prepare_for_attention(
|
||||
self,
|
||||
@@ -369,8 +364,6 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
"""
|
||||
assert not seq.block_table, "Sequence already has blocks"
|
||||
|
||||
h = -1 # Running hash for prefix cache
|
||||
|
||||
for i in range(seq.num_blocks):
|
||||
# Allocate CPU block
|
||||
if not self.free_cpu_blocks:
|
||||
@@ -381,19 +374,10 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
|
||||
cpu_block_id = self.free_cpu_blocks.popleft()
|
||||
|
||||
# Get token IDs for this block and compute hash
|
||||
token_ids = seq.block(i)
|
||||
if len(token_ids) == self._block_size:
|
||||
h = self.compute_hash(token_ids, h)
|
||||
else:
|
||||
h = -1 # Incomplete block
|
||||
|
||||
# Allocate logical block
|
||||
logical_id = self.free_logical_ids.popleft()
|
||||
block = self.logical_blocks[logical_id]
|
||||
block.ref_count = 1
|
||||
block.hash = h
|
||||
block.token_ids = token_ids.copy() if len(token_ids) == self._block_size else []
|
||||
block.location = BlockLocation.CPU
|
||||
block.cpu_block_id = cpu_block_id
|
||||
block.gpu_slot = -1
|
||||
@@ -401,9 +385,11 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
self.cpu_block_to_logical[cpu_block_id] = logical_id
|
||||
seq.block_table.append(logical_id)
|
||||
|
||||
# Update prefix cache
|
||||
if h != -1:
|
||||
self.hash_to_logical_id[h] = logical_id
|
||||
# NOTE: Prefix cache disabled in offload mode
|
||||
# If enabled, would compute hash and update:
|
||||
# h = self.compute_hash(seq.block(i), prefix_hash)
|
||||
# block.hash = h
|
||||
# self.hash_to_logical_id[h] = logical_id
|
||||
|
||||
def get_cpu_block_table(self, seq: Sequence) -> List[int]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user