This commit is contained in:
GeeeekExplorer
2025-06-11 21:12:57 +08:00
parent b98e1ca305
commit 386290d69e
8 changed files with 31 additions and 35 deletions

View File

@@ -23,7 +23,6 @@ class Block:
def update(self, hash: int, token_ids: list[int]):
    """Record the content hash and the token ids held by this block.

    Args:
        hash: Content hash for the block. Must not be ``-1``, the value
            used for a block that has no hash.
        token_ids: Token ids stored in the block. Their count is not
            checked here because the block size is chosen by the caller
            rather than fixed at 256.
    """
    # Internal invariant, not input validation: only hashed blocks are updated.
    assert hash != -1
    self.hash = hash
    self.token_ids = token_ids
@@ -38,8 +37,7 @@ class Block:
class BlockManager:
def __init__(self, num_blocks: int, block_size: int = 256):
assert block_size == 256
def __init__(self, num_blocks: int, block_size: int):
self.block_size = block_size
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
self.hash_to_block_id: dict[int, int] = dict()
@@ -59,15 +57,18 @@ class BlockManager:
self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)
def can_prefill(self):
    """Report whether strictly more than 10% of all blocks are free."""
    free_count = len(self.free_block_ids)
    # Strict comparison: having exactly 10% of the blocks free is not enough.
    reserve = 0.1 * len(self.blocks)
    return free_count > reserve
def can_allocate(self, seq: Sequence):
    """Return True when the free pool can hold every block of ``seq``.

    Compares ``seq.num_blocks`` with the number of ids currently in
    ``self.free_block_ids``; nothing is reserved or modified.
    """
    return len(self.free_block_ids) >= seq.num_blocks
def allocate(self, seq: Sequence):
assert not seq.block_table
h = -1
cache_miss = False
for i in range(seq.num_blocks):
token_ids = seq.block(i, self.block_size)
token_ids = seq.block(i)
h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
block_id = self.hash_to_block_id.get(h, -1)
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
@@ -96,8 +97,8 @@ class BlockManager:
seq.num_cached_tokens = 0
seq.block_table.clear()
def can_append(self, seq: Sequence):
    """Return whether the free pool is large enough for ``seq`` to grow.

    One free block is required only when ``len(seq)`` is exactly one past
    a multiple of ``self.block_size``; for any other length the result is
    True regardless of how many blocks are free.
    """
    # NOTE: with block_size == 1 the remainder is always 0, so this check
    # never asks for a free block.
    starts_new_block = len(seq) % self.block_size == 1
    if not starts_new_block:
        return True
    return len(self.free_block_ids) >= 1
def may_append(self, seq: Sequence):
block_table = seq.block_table
@@ -109,7 +110,7 @@ class BlockManager:
block_table.append(block_id)
elif len(seq) % self.block_size == 0:
assert last_block.hash == -1
token_ids = seq.last_block(self.block_size)
token_ids = seq.last_block()
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
h = compute_hash(token_ids, prefix)
last_block.update(h, token_ids)