GeeeekExplorer
2025-06-11 21:12:57 +08:00
parent b98e1ca305
commit 386290d69e
8 changed files with 31 additions and 35 deletions

View File

@@ -21,9 +21,9 @@ prompts = [
     )
     for prompt in prompts
 ]

-completions = llm.generate(prompts, sampling_params)
-for p, c in zip(prompts, completions):
+outputs = llm.generate(prompts, sampling_params)
+for prompt, output in zip(prompts, outputs):
     print("\n")
-    print(f"Prompt: {p}")
-    print(f"Completion: {c}")
+    print(f"Prompt: {prompt}")
+    print(f"Completion: {output['text']}")

View File

@@ -17,4 +17,4 @@ class Config:

     def __post_init__(self):
         assert self.model
-        assert self.kvcache_block_size == 256
+        assert self.kvcache_block_size % 256 == 0
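
Note: the config check is relaxed from requiring exactly 256 to requiring any multiple of 256. A minimal standalone sketch of the same check, using a toy dataclass rather than the real nanovllm Config and a placeholder model path:

from dataclasses import dataclass


@dataclass
class ToyConfig:
    model: str
    kvcache_block_size: int = 256

    def __post_init__(self):
        assert self.model
        # any multiple of 256 now passes, not just 256 itself
        assert self.kvcache_block_size % 256 == 0


ToyConfig("path/to/model", kvcache_block_size=512)    # ok
# ToyConfig("path/to/model", kvcache_block_size=300)  # would raise AssertionError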

View File

@@ -23,7 +23,6 @@ class Block:

     def update(self, hash: int, token_ids: list[int]):
         assert hash != -1
-        assert len(token_ids) == 256
         self.hash = hash
         self.token_ids = token_ids

@@ -38,8 +37,7 @@ class Block:

 class BlockManager:

-    def __init__(self, num_blocks: int, block_size: int = 256):
-        assert block_size == 256
+    def __init__(self, num_blocks: int, block_size: int):
         self.block_size = block_size
         self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
         self.hash_to_block_id: dict[int, int] = dict()

@@ -59,15 +57,18 @@ class BlockManager:
         self.used_block_ids.remove(block_id)
         self.free_block_ids.append(block_id)

+    def can_prefill(self):
+        return len(self.free_block_ids) > 0.1 * len(self.blocks)
+
     def can_allocate(self, seq: Sequence):
-        return seq.num_blocks <= len(self.free_block_ids)
+        return len(self.free_block_ids) >= seq.num_blocks

     def allocate(self, seq: Sequence):
         assert not seq.block_table
         h = -1
         cache_miss = False
         for i in range(seq.num_blocks):
-            token_ids = seq.block(i, self.block_size)
+            token_ids = seq.block(i)
             h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
             block_id = self.hash_to_block_id.get(h, -1)
             if block_id == -1 or self.blocks[block_id].token_ids != token_ids:

@@ -96,8 +97,8 @@ class BlockManager:
         seq.num_cached_tokens = 0
         seq.block_table.clear()

-    def can_append(self):
-        return len(self.free_block_ids) >= 1
+    def can_append(self, seq: Sequence):
+        return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)

     def may_append(self, seq: Sequence):
         block_table = seq.block_table

@@ -109,7 +110,7 @@ class BlockManager:
             block_table.append(block_id)
         elif len(seq) % self.block_size == 0:
             assert last_block.hash == -1
-            token_ids = seq.last_block(self.block_size)
+            token_ids = seq.last_block()
             prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
             h = compute_hash(token_ids, prefix)
             last_block.update(h, token_ids)
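
For context, allocate chains a hash from block to block and reuses a cached block only when both the hash and the exact token ids match; only full blocks are ever hashed. A self-contained sketch of that lookup, assuming compute_hash mixes the previous block's hash with the current block's tokens (the real hash function is not shown in this diff, and the collision check against stored token ids is omitted here):

from hashlib import sha256


def compute_hash(token_ids: list[int], prefix: int = -1) -> int:
    # assumed stand-in for the real compute_hash
    digest = sha256(repr((prefix, tuple(token_ids))).encode()).digest()
    return int.from_bytes(digest[:8], "little")


def cached_prefix_blocks(token_ids: list[int], block_size: int,
                         hash_to_block_id: dict[int, int]) -> list[int]:
    """Block ids of the longest cached prefix of a new sequence."""
    hits, h = [], -1
    num_blocks = (len(token_ids) + block_size - 1) // block_size
    for i in range(num_blocks):
        block = token_ids[i * block_size:(i + 1) * block_size]
        # a partial last block is never hashed, so it can never hit the cache
        h = compute_hash(block, h) if len(block) == block_size else -1
        block_id = hash_to_block_id.get(h, -1)
        if block_id == -1:
            break  # first miss ends the cached prefix
        hits.append(block_id)
    return hits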

View File

@@ -17,6 +17,7 @@ class LLMEngine:
         for k, v in kwargs.items():
             if hasattr(config, k):
                 setattr(config, k, v)
+        Sequence.block_size = config.kvcache_block_size
         config.hf_config = AutoConfig.from_pretrained(config.model)
         config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
         self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)

@@ -74,7 +75,7 @@ class LLMEngine:
                 if finish and use_tqdm:
                     pbar.update(1)
         outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
-        outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs]
+        outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
         if use_tqdm:
             pbar.close()
         return outputs
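
With this change generate returns a list of dicts rather than plain strings: the decoded completion is under "text" and the raw ids under "token_ids". A short usage sketch under that assumption (the top-level import path and the model path are placeholders/assumptions):

from nanovllm import LLM, SamplingParams  # top-level import path assumed

llm = LLM("path/to/model")
sampling_params = SamplingParams(temperature=0.6, max_tokens=64)
prompts = ["Hello, my name is"]

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print(f"Prompt: {prompt}")
    print(f"Completion: {output['text']}")              # decoded text
    print(f"{len(output['token_ids'])} tokens generated")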

View File

@@ -169,7 +169,7 @@ class ModelRunner:
         context_lens = torch.zeros(max_bs, dtype=torch.int32)
         block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
         outputs = torch.zeros(max_bs, hf_config.hidden_size)
-        self.graph_bs = [1, 2, 4, 8, 16] + list(range(32, max_bs + 1, 32))
+        self.graph_bs = [1, 2, 4, 8, 16] + list(range(16, max_bs + 1, 16))
         self.graphs = {}
         self.graph_pool = None
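
graph_bs lists the batch sizes for which CUDA graphs are captured; presumably a decode batch is then padded up to the nearest captured size at run time (that lookup is outside this diff). A standalone sketch of such a rounding step:

import bisect

max_bs = 256
graph_bs = [1, 2, 4, 8, 16] + list(range(16, max_bs + 1, 16))


def nearest_graph_bs(bs: int) -> int:
    """Smallest captured batch size that can hold bs sequences."""
    i = bisect.bisect_left(graph_bs, bs)
    if i == len(graph_bs):
        raise ValueError(f"batch size {bs} exceeds the largest captured graph")
    return graph_bs[i]


assert nearest_graph_bs(3) == 4
assert nearest_graph_bs(16) == 16
assert nearest_graph_bs(17) == 32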

View File

@@ -46,7 +46,7 @@ class Scheduler:
         # self.running = deque(sorted(self.running))
         while self.running and num_seqs < self.max_num_seqs:
             seq = self.running.popleft()
-            while not self.block_manager.can_append():
+            while not self.block_manager.can_append(seq):
                 if self.running:
                     self.preempt(self.running.pop())
                 else:

@@ -66,7 +66,6 @@ class Scheduler:
         seq.status = SequenceStatus.WAITING
         self.block_manager.deallocate(seq)
         self.waiting.appendleft(seq)
-        return True

     def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
         self.num_tokens += len(token_ids)
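
can_append now takes the sequence: a free block is required only when the sequence length sits exactly one token past a block boundary, i.e. the just-appended token opened a fresh block, and the boolean comparison coerces to the 0 or 1 blocks needed. A tiny standalone illustration of that expression:

block_size = 256


def blocks_needed(seq_len: int) -> int:
    # 1 when the sequence just spilled a single token into a new block, else 0
    return int(seq_len % block_size == 1)


assert blocks_needed(256) == 0   # last block exactly full, nothing to allocate yet
assert blocks_needed(257) == 1   # the 257th token opened a fresh block
assert blocks_needed(300) == 0   # mid-block, the block already exists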

View File

@@ -12,6 +12,7 @@ class SequenceStatus(Enum):

 class Sequence:
+    block_size = 256
     counter = count()

     def __init__(self, token_ids: list[int], sampling_params: SamplingParams):

@@ -44,27 +45,27 @@ class Sequence:
     @num_cached_tokens.setter
     def num_cached_tokens(self, num_cached_tokens):
-        assert num_cached_tokens % 256 == 0
+        assert num_cached_tokens % self.block_size == 0
         self._num_cached_tokens = num_cached_tokens

     @property
     def num_cached_blocks(self):
-        return self.num_cached_tokens // 256
+        return self.num_cached_tokens // self.block_size

     @property
     def num_blocks(self):
-        return (len(self.token_ids) + 255) // 256
+        return (len(self.token_ids) + self.block_size - 1) // self.block_size

     @property
     def last_token(self):
         return self.token_ids[-1]

-    def block(self, i, block_size=256):
-        return self.token_ids[i*block_size: (i+1)*block_size]
+    def block(self, i):
+        return self.token_ids[i*self.block_size: (i+1)*self.block_size]

-    def last_block(self, block_size=256):
+    def last_block(self):
         n = self.num_blocks
-        return self.token_ids[(n-1)*block_size:]
+        return self.token_ids[(n-1)*self.block_size:]

     def append_token(self, token_id: int):
         self.token_ids.append(token_id)
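
Sequence now reads the block size from a class attribute (set once by the engine from config.kvcache_block_size, see the LLMEngine hunk above) instead of a hard-coded 256. A self-contained toy showing how the block helpers slice token ids under that scheme:

class ToySequence:
    block_size = 256  # the engine overwrites this once at startup

    def __init__(self, token_ids: list[int]):
        self.token_ids = token_ids

    @property
    def num_blocks(self) -> int:
        # ceiling division without floats
        return (len(self.token_ids) + self.block_size - 1) // self.block_size

    def block(self, i: int) -> list[int]:
        return self.token_ids[i * self.block_size:(i + 1) * self.block_size]

    def last_block(self) -> list[int]:
        return self.token_ids[(self.num_blocks - 1) * self.block_size:]


ToySequence.block_size = 4            # pretend the configured block size is 4
seq = ToySequence(list(range(10)))    # 10 tokens -> blocks of 4, 4, 2
assert seq.num_blocks == 3
assert seq.block(1) == [4, 5, 6, 7]
assert seq.last_block() == [8, 9]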

View File

@@ -3,7 +3,7 @@ from torch import nn
 import triton
 import triton.language as tl

-from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

 from nanovllm.utils.context import get_context

@@ -65,18 +65,12 @@ class Attention(nn.Module):
         v_cache = self.v_cache
             store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill:
-            if context.block_tables is None:    # normal prefill
-                cu_seqlens_k = context.cu_seqlens_k
-                seqused_k = None
-            else:    # prefix cache
-                cu_seqlens_k = None
-                seqused_k = context.context_lens
+            if context.block_tables is not None:    # prefix cache
                 k, v = k_cache, v_cache
             o = flash_attn_varlen_func(q, k, v,
                                        max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
-                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=cu_seqlens_k,
-                                       seqused_k=seqused_k, softmax_scale=self.scale,
-                                       causal=True, block_table=context.block_tables)
+                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
+                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
         else:    # decode
             o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                         cache_seqlens=context.context_lens, block_table=context.block_tables,