refactor
This commit is contained in:
@@ -23,7 +23,6 @@ class Block:
|
||||
|
||||
def update(self, hash: int, token_ids: list[int]):
|
||||
assert hash != -1
|
||||
assert len(token_ids) == 256
|
||||
self.hash = hash
|
||||
self.token_ids = token_ids
|
||||
|
||||
@@ -38,8 +37,7 @@ class Block:
|
||||
|
||||
class BlockManager:
|
||||
|
||||
def __init__(self, num_blocks: int, block_size: int = 256):
|
||||
assert block_size == 256
|
||||
def __init__(self, num_blocks: int, block_size: int):
|
||||
self.block_size = block_size
|
||||
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
|
||||
self.hash_to_block_id: dict[int, int] = dict()
|
||||
@@ -59,15 +57,18 @@ class BlockManager:
|
||||
self.used_block_ids.remove(block_id)
|
||||
self.free_block_ids.append(block_id)
|
||||
|
||||
def can_prefill(self):
|
||||
return len(self.free_block_ids) > 0.1 * len(self.blocks)
|
||||
|
||||
def can_allocate(self, seq: Sequence):
|
||||
return seq.num_blocks <= len(self.free_block_ids)
|
||||
return len(self.free_block_ids) >= seq.num_blocks
|
||||
|
||||
def allocate(self, seq: Sequence):
|
||||
assert not seq.block_table
|
||||
h = -1
|
||||
cache_miss = False
|
||||
for i in range(seq.num_blocks):
|
||||
token_ids = seq.block(i, self.block_size)
|
||||
token_ids = seq.block(i)
|
||||
h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
|
||||
block_id = self.hash_to_block_id.get(h, -1)
|
||||
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
||||
@@ -96,8 +97,8 @@ class BlockManager:
|
||||
seq.num_cached_tokens = 0
|
||||
seq.block_table.clear()
|
||||
|
||||
def can_append(self):
|
||||
return len(self.free_block_ids) >= 1
|
||||
def can_append(self, seq: Sequence):
|
||||
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
|
||||
|
||||
def may_append(self, seq: Sequence):
|
||||
block_table = seq.block_table
|
||||
@@ -109,7 +110,7 @@ class BlockManager:
|
||||
block_table.append(block_id)
|
||||
elif len(seq) % self.block_size == 0:
|
||||
assert last_block.hash == -1
|
||||
token_ids = seq.last_block(self.block_size)
|
||||
token_ids = seq.last_block()
|
||||
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
|
||||
h = compute_hash(token_ids, prefix)
|
||||
last_block.update(h, token_ids)
|
||||
|
||||
@@ -17,6 +17,7 @@ class LLMEngine:
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(config, k):
|
||||
setattr(config, k, v)
|
||||
Sequence.block_size = config.kvcache_block_size
|
||||
config.hf_config = AutoConfig.from_pretrained(config.model)
|
||||
config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
||||
@@ -74,7 +75,7 @@ class LLMEngine:
|
||||
if finish and use_tqdm:
|
||||
pbar.update(1)
|
||||
outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
|
||||
outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs]
|
||||
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
||||
if use_tqdm:
|
||||
pbar.close()
|
||||
return outputs
|
||||
return outputs
|
||||
@@ -169,7 +169,7 @@ class ModelRunner:
|
||||
context_lens = torch.zeros(max_bs, dtype=torch.int32)
|
||||
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
|
||||
outputs = torch.zeros(max_bs, hf_config.hidden_size)
|
||||
self.graph_bs = [1, 2, 4, 8, 16] + list(range(32, max_bs + 1, 32))
|
||||
self.graph_bs = [1, 2, 4, 8, 16] + list(range(16, max_bs + 1, 16))
|
||||
self.graphs = {}
|
||||
self.graph_pool = None
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class Scheduler:
|
||||
# self.running = deque(sorted(self.running))
|
||||
while self.running and num_seqs < self.max_num_seqs:
|
||||
seq = self.running.popleft()
|
||||
while not self.block_manager.can_append():
|
||||
while not self.block_manager.can_append(seq):
|
||||
if self.running:
|
||||
self.preempt(self.running.pop())
|
||||
else:
|
||||
@@ -66,7 +66,6 @@ class Scheduler:
|
||||
seq.status = SequenceStatus.WAITING
|
||||
self.block_manager.deallocate(seq)
|
||||
self.waiting.appendleft(seq)
|
||||
return True
|
||||
|
||||
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
||||
self.num_tokens += len(token_ids)
|
||||
|
||||
@@ -12,6 +12,7 @@ class SequenceStatus(Enum):
|
||||
|
||||
|
||||
class Sequence:
|
||||
block_size = 256
|
||||
counter = count()
|
||||
|
||||
def __init__(self, token_ids: list[int], sampling_params: SamplingParams):
|
||||
@@ -44,27 +45,27 @@ class Sequence:
|
||||
|
||||
@num_cached_tokens.setter
|
||||
def num_cached_tokens(self, num_cached_tokens):
|
||||
assert num_cached_tokens % 256 == 0
|
||||
assert num_cached_tokens % self.block_size == 0
|
||||
self._num_cached_tokens = num_cached_tokens
|
||||
|
||||
@property
|
||||
def num_cached_blocks(self):
|
||||
return self.num_cached_tokens // 256
|
||||
return self.num_cached_tokens // self.block_size
|
||||
|
||||
@property
|
||||
def num_blocks(self):
|
||||
return (len(self.token_ids) + 255) // 256
|
||||
return (len(self.token_ids) + self.block_size - 1) // self.block_size
|
||||
|
||||
@property
|
||||
def last_token(self):
|
||||
return self.token_ids[-1]
|
||||
|
||||
def block(self, i, block_size=256):
|
||||
return self.token_ids[i*block_size: (i+1)*block_size]
|
||||
def block(self, i):
|
||||
return self.token_ids[i*self.block_size: (i+1)*self.block_size]
|
||||
|
||||
def last_block(self, block_size=256):
|
||||
def last_block(self):
|
||||
n = self.num_blocks
|
||||
return self.token_ids[(n-1)*block_size:]
|
||||
return self.token_ids[(n-1)*self.block_size:]
|
||||
|
||||
def append_token(self, token_id: int):
|
||||
self.token_ids.append(token_id)
|
||||
|
||||
Reference in New Issue
Block a user