diff --git a/example.py b/example.py
index 3dc65f8..5ae260e 100644
--- a/example.py
+++ b/example.py
@@ -21,9 +21,9 @@ prompts = [
     )
     for prompt in prompts
 ]
-completions = llm.generate(prompts, sampling_params)
+outputs = llm.generate(prompts, sampling_params)
 
-for p, c in zip(prompts, completions):
+for prompt, output in zip(prompts, outputs):
     print("\n")
-    print(f"Prompt: {p}")
-    print(f"Completion: {c}")
+    print(f"Prompt: {prompt}")
+    print(f"Completion: {output['text']}")
diff --git a/nanovllm/config.py b/nanovllm/config.py
index 4c33837..ae76ade 100644
--- a/nanovllm/config.py
+++ b/nanovllm/config.py
@@ -17,4 +17,4 @@ class Config:
 
     def __post_init__(self):
         assert self.model
-        assert self.kvcache_block_size == 256
\ No newline at end of file
+        assert self.kvcache_block_size % 256 == 0
\ No newline at end of file
diff --git a/nanovllm/engine/block_manager.py b/nanovllm/engine/block_manager.py
index 4739c5b..72a0eef 100644
--- a/nanovllm/engine/block_manager.py
+++ b/nanovllm/engine/block_manager.py
@@ -23,7 +23,6 @@ class Block:
 
     def update(self, hash: int, token_ids: list[int]):
         assert hash != -1
-        assert len(token_ids) == 256
         self.hash = hash
         self.token_ids = token_ids
 
@@ -38,8 +37,7 @@ class Block:
 
 class BlockManager:
 
-    def __init__(self, num_blocks: int, block_size: int = 256):
-        assert block_size == 256
+    def __init__(self, num_blocks: int, block_size: int):
         self.block_size = block_size
         self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
         self.hash_to_block_id: dict[int, int] = dict()
@@ -59,15 +57,18 @@ class BlockManager:
         self.used_block_ids.remove(block_id)
         self.free_block_ids.append(block_id)
 
+    def can_prefill(self):
+        return len(self.free_block_ids) > 0.1 * len(self.blocks)
+
     def can_allocate(self, seq: Sequence):
-        return seq.num_blocks <= len(self.free_block_ids)
+        return len(self.free_block_ids) >= seq.num_blocks
 
     def allocate(self, seq: Sequence):
         assert not seq.block_table
         h = -1
         cache_miss = False
         for i in range(seq.num_blocks):
-            token_ids = seq.block(i, self.block_size)
+            token_ids = seq.block(i)
             h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
             block_id = self.hash_to_block_id.get(h, -1)
             if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
@@ -96,8 +97,8 @@
         seq.num_cached_tokens = 0
         seq.block_table.clear()
 
-    def can_append(self):
-        return len(self.free_block_ids) >= 1
+    def can_append(self, seq: Sequence):
+        return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
 
     def may_append(self, seq: Sequence):
         block_table = seq.block_table
@@ -109,7 +110,7 @@
             block_table.append(block_id)
         elif len(seq) % self.block_size == 0:
             assert last_block.hash == -1
-            token_ids = seq.last_block(self.block_size)
+            token_ids = seq.last_block()
             prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
             h = compute_hash(token_ids, prefix)
             last_block.update(h, token_ids)
diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py
index 6c48af9..0194e6a 100644
--- a/nanovllm/engine/llm_engine.py
+++ b/nanovllm/engine/llm_engine.py
@@ -17,6 +17,7 @@ class LLMEngine:
         for k, v in kwargs.items():
             if hasattr(config, k):
                 setattr(config, k, v)
+        Sequence.block_size = config.kvcache_block_size
        config.hf_config = AutoConfig.from_pretrained(config.model)
        config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
@@ -74,7 +75,7 @@
                if finish and use_tqdm:
                    pbar.update(1)
        outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
-        outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs]
+        outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
        if use_tqdm:
            pbar.close()
-        return outputs
+        return outputs
\ No newline at end of file
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 3d26e48..c8f8588 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -169,7 +169,7 @@
         context_lens = torch.zeros(max_bs, dtype=torch.int32)
         block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
         outputs = torch.zeros(max_bs, hf_config.hidden_size)
-        self.graph_bs = [1, 2, 4, 8, 16] + list(range(32, max_bs + 1, 32))
+        self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
         self.graphs = {}
         self.graph_pool = None
 
diff --git a/nanovllm/engine/scheduler.py b/nanovllm/engine/scheduler.py
index 28cc298..cb8bfd1 100644
--- a/nanovllm/engine/scheduler.py
+++ b/nanovllm/engine/scheduler.py
@@ -46,7 +46,7 @@
         # self.running = deque(sorted(self.running))
         while self.running and num_seqs < self.max_num_seqs:
             seq = self.running.popleft()
-            while not self.block_manager.can_append():
+            while not self.block_manager.can_append(seq):
                 if self.running:
                     self.preempt(self.running.pop())
                 else:
@@ -66,7 +66,6 @@
         seq.status = SequenceStatus.WAITING
         self.block_manager.deallocate(seq)
         self.waiting.appendleft(seq)
-        return True
 
     def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
         self.num_tokens += len(token_ids)
diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py
index 717434a..d519fb7 100644
--- a/nanovllm/engine/sequence.py
+++ b/nanovllm/engine/sequence.py
@@ -12,6 +12,7 @@ class SequenceStatus(Enum):
 
 
 class Sequence:
+    block_size = 256
     counter = count()
 
     def __init__(self, token_ids: list[int], sampling_params: SamplingParams):
@@ -44,27 +45,27 @@
 
     @num_cached_tokens.setter
     def num_cached_tokens(self, num_cached_tokens):
-        assert num_cached_tokens % 256 == 0
+        assert num_cached_tokens % self.block_size == 0
         self._num_cached_tokens = num_cached_tokens
 
     @property
     def num_cached_blocks(self):
-        return self.num_cached_tokens // 256
+        return self.num_cached_tokens // self.block_size
 
     @property
     def num_blocks(self):
-        return (len(self.token_ids) + 255) // 256
+        return (len(self.token_ids) + self.block_size - 1) // self.block_size
 
     @property
     def last_token(self):
         return self.token_ids[-1]
 
-    def block(self, i, block_size=256):
-        return self.token_ids[i*block_size: (i+1)*block_size]
+    def block(self, i):
+        return self.token_ids[i*self.block_size: (i+1)*self.block_size]
 
-    def last_block(self, block_size=256):
+    def last_block(self):
         n = self.num_blocks
-        return self.token_ids[(n-1)*block_size:]
+        return self.token_ids[(n-1)*self.block_size:]
 
     def append_token(self, token_id: int):
         self.token_ids.append(token_id)
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index 3f52a3b..bb5344e 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -3,7 +3,7 @@ from torch import nn
 import triton
 import triton.language as tl
 
-from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 
 from nanovllm.utils.context import get_context
 
@@ -65,18 +65,12 @@ class Attention(nn.Module):
         v_cache = self.v_cache
         store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill:
-            if context.block_tables is None:    # normal prefill
-                cu_seqlens_k = context.cu_seqlens_k
-                seqused_k = None
-            else:    # prefix cache
-                cu_seqlens_k = None
-                seqused_k = context.context_lens
+            if context.block_tables is not None:    # prefix cache
                 k, v = k_cache, v_cache
             o = flash_attn_varlen_func(q, k, v,
                                        max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
-                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=cu_seqlens_k,
-                                       seqused_k=seqused_k, softmax_scale=self.scale,
-                                       causal=True, block_table=context.block_tables)
+                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
+                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
         else:    # decode
             o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                         cache_seqlens=context.context_lens, block_table=context.block_tables,
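Note on the new BlockManager.can_append: the right-hand side of the comparison is a bool that counts as 0 or 1, so a free block is only demanded when the sequence is exactly one token past a block boundary, i.e. when may_append is about to allocate a fresh block for the latest token. A standalone sketch of that arithmetic (block size and lengths here are arbitrary, not taken from the diff):

    # Mirrors: len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
    block_size = 256
    for seq_len in (255, 256, 257):
        needs_block = int(seq_len % block_size == 1)  # bool -> 0 or 1
        print(f"len(seq)={seq_len} -> free blocks required: {needs_block}")
    # len(seq)=255 -> 0  (current block still has room)
    # len(seq)=256 -> 0  (block filled exactly; may_append hashes it instead)
    # len(seq)=257 -> 1  (latest token starts a new, not-yet-allocated block)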
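Usage note: with the llm_engine.py change, generate now returns a list of dicts rather than decoded strings, which is what the updated example.py consumes. A minimal caller-side sketch; it assumes the package exposes LLM and SamplingParams as in upstream nano-vllm, and the model path and sampling settings are placeholders:

    from nanovllm import LLM, SamplingParams  # assumed entry points

    # keyword args such as kvcache_block_size are copied onto Config by LLMEngine.__init__
    llm = LLM("Qwen/Qwen3-0.6B", kvcache_block_size=256)
    sampling_params = SamplingParams(temperature=0.6, max_tokens=64)

    outputs = llm.generate(["introduce yourself"], sampling_params)
    for output in outputs:
        # each element is now a dict, not a plain string
        print(output["text"])        # decoded completion
        print(output["token_ids"])   # raw generated token ids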