simplify
This commit is contained in:
@@ -5,14 +5,6 @@ import numpy as np
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
|
||||
|
||||
def compute_hash(token_ids: list[int], prefix: int = -1):
    """Return the xxh64 integer digest for a list of token ids.

    If ``prefix`` is anything other than -1 it is fed to the hasher first,
    encoded as 8 little-endian bytes, so the digest depends on both the
    prefix and ``token_ids``. With the default -1 only the token ids are
    hashed.
    """
    hasher = xxhash.xxh64()
    has_prefix = prefix != -1
    if has_prefix:
        prefix_bytes = prefix.to_bytes(8, "little")
        hasher.update(prefix_bytes)
    # The token ids are hashed through their raw NumPy buffer.
    token_bytes = np.array(token_ids).tobytes()
    hasher.update(token_bytes)
    return hasher.intdigest()
|
||||
|
||||
|
||||
class Block:
|
||||
|
||||
def __init__(self, block_id):
|
||||
@@ -22,7 +14,6 @@ class Block:
|
||||
self.token_ids = []
|
||||
|
||||
def update(self, hash: int, token_ids: list[int]):
    """Store the given hash and token ids on this block.

    ``hash`` must not be -1; this is checked with an assertion before
    anything is written.
    """
    assert hash != -1
    self.hash, self.token_ids = hash, token_ids
|
||||
|
||||
@@ -42,7 +33,15 @@ class BlockManager:
|
||||
self.free_block_ids: deque[int] = deque(range(num_blocks))
|
||||
self.used_block_ids: set[int] = set()
|
||||
|
||||
def _allocate_block(self, block_id: int):
|
||||
@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
    """Return the xxh64 integer digest for a list of token ids.

    When ``prefix`` is not -1, its 8-byte little-endian encoding is hashed
    ahead of the token ids, so the result depends on both values. ``cls``
    is not used by the computation.
    """
    # Collect the byte chunks in hashing order, then stream them through
    # the hasher; this yields the same digest as successive update() calls.
    chunks = []
    if prefix != -1:
        chunks.append(prefix.to_bytes(8, "little"))
    chunks.append(np.array(token_ids).tobytes())

    hasher = xxhash.xxh64()
    for chunk in chunks:
        hasher.update(chunk)
    return hasher.intdigest()
|
||||
|
||||
def _allocate_block(self, block_id: int) -> Block:
|
||||
block = self.blocks[block_id]
|
||||
assert block.ref_count == 0
|
||||
block.reset()
|
||||
@@ -50,12 +49,12 @@ class BlockManager:
|
||||
self.used_block_ids.add(block_id)
|
||||
return self.blocks[block_id]
|
||||
|
||||
def _deallocate_block(self, block_id: int):
|
||||
def _deallocate_block(self, block_id: int) -> None:
    """Move an unreferenced block from the used set back to the free queue.

    ``block_id`` is removed from ``used_block_ids`` and appended to the
    tail of ``free_block_ids``. The method has no return statement, so it
    is annotated ``-> None`` (it was previously annotated ``-> Block``).

    Raises:
        AssertionError: if the block's ``ref_count`` is not 0.
        KeyError: if ``block_id`` is not currently in ``used_block_ids``.
    """
    # Only blocks with no remaining references may be released.
    assert self.blocks[block_id].ref_count == 0
    self.used_block_ids.remove(block_id)
    self.free_block_ids.append(block_id)
|
||||
|
||||
def can_allocate(self, seq: Sequence):
|
||||
def can_allocate(self, seq: Sequence) -> bool:
    """Return True when there are at least ``seq.num_blocks`` free blocks."""
    blocks_free = len(self.free_block_ids)
    blocks_needed = seq.num_blocks
    return blocks_needed <= blocks_free
|
||||
|
||||
def allocate(self, seq: Sequence):
|
||||
@@ -64,7 +63,7 @@ class BlockManager:
|
||||
cache_miss = False
|
||||
for i in range(seq.num_blocks):
|
||||
token_ids = seq.block(i)
|
||||
h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
|
||||
h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
|
||||
block_id = self.hash_to_block_id.get(h, -1)
|
||||
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
||||
cache_miss = True
|
||||
@@ -92,7 +91,7 @@ class BlockManager:
|
||||
seq.num_cached_tokens = 0
|
||||
seq.block_table.clear()
|
||||
|
||||
def can_append(self, seq: Sequence):
|
||||
def can_append(self, seq: Sequence) -> bool:
    """Return True when the free pool can cover the sequence at its current length.

    One free block is required only when ``len(seq) % block_size == 1``;
    for every other length no free block is required and the result is
    True.
    """
    free_count = len(self.free_block_ids)
    if len(seq) % self.block_size != 1:
        # No free block is required at this length.
        return True
    return free_count >= 1
|
||||
|
||||
def may_append(self, seq: Sequence):
|
||||
@@ -107,7 +106,7 @@ class BlockManager:
|
||||
assert last_block.hash == -1
|
||||
token_ids = seq.block(seq.num_blocks-1)
|
||||
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
|
||||
h = compute_hash(token_ids, prefix)
|
||||
h = self.compute_hash(token_ids, prefix)
|
||||
last_block.update(h, token_ids)
|
||||
self.hash_to_block_id[h] = last_block.block_id
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user