refactor
@@ -21,9 +21,9 @@ prompts = [
     )
     for prompt in prompts
 ]
-completions = llm.generate(prompts, sampling_params)
+outputs = llm.generate(prompts, sampling_params)
 
-for p, c in zip(prompts, completions):
+for prompt, output in zip(prompts, outputs):
     print("\n")
-    print(f"Prompt: {p}")
-    print(f"Completion: {c}")
+    print(f"Prompt: {prompt}")
+    print(f"Completion: {output['text']}")
@@ -17,4 +17,4 @@ class Config:
 
     def __post_init__(self):
         assert self.model
-        assert self.kvcache_block_size == 256
+        assert self.kvcache_block_size % 256 == 0
@@ -23,7 +23,6 @@ class Block:
 
     def update(self, hash: int, token_ids: list[int]):
         assert hash != -1
-        assert len(token_ids) == 256
         self.hash = hash
         self.token_ids = token_ids
 
@@ -38,8 +37,7 @@ class Block:
 
 class BlockManager:
 
-    def __init__(self, num_blocks: int, block_size: int = 256):
-        assert block_size == 256
+    def __init__(self, num_blocks: int, block_size: int):
         self.block_size = block_size
         self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
         self.hash_to_block_id: dict[int, int] = dict()
@@ -59,15 +57,18 @@ class BlockManager:
         self.used_block_ids.remove(block_id)
         self.free_block_ids.append(block_id)
 
+    def can_prefill(self):
+        return len(self.free_block_ids) > 0.1 * len(self.blocks)
+
     def can_allocate(self, seq: Sequence):
-        return seq.num_blocks <= len(self.free_block_ids)
+        return len(self.free_block_ids) >= seq.num_blocks
 
     def allocate(self, seq: Sequence):
         assert not seq.block_table
         h = -1
         cache_miss = False
         for i in range(seq.num_blocks):
-            token_ids = seq.block(i, self.block_size)
+            token_ids = seq.block(i)
             h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
             block_id = self.hash_to_block_id.get(h, -1)
             if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
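compute_hash is not shown in this commit; the call above only relies on each full block's hash being chained from the previous block's hash (h), which is what makes hash_to_block_id usable as a prefix-cache index. A minimal sketch of a compatible helper, assuming an xxhash-based digest (an assumption, not taken from this diff):

import numpy as np
import xxhash

def compute_hash(token_ids: list[int], prefix: int = -1) -> int:
    # Chain the previous block's hash into this block's digest, so equal hashes
    # imply an equal prefix of blocks rather than just equal block contents.
    h = xxhash.xxh64()
    if prefix != -1:
        h.update(prefix.to_bytes(8, "little"))
    h.update(np.array(token_ids, dtype=np.int64).tobytes())
    return h.intdigest()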
@@ -96,8 +97,8 @@ class BlockManager:
         seq.num_cached_tokens = 0
         seq.block_table.clear()
 
-    def can_append(self):
-        return len(self.free_block_ids) >= 1
+    def can_append(self, seq: Sequence):
+        return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
 
     def may_append(self, seq: Sequence):
         block_table = seq.block_table
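The new can_append leans on Python treating True as 1 and False as 0: a free block is only required when the sequence sits exactly one token past a block boundary, which is the case may_append handles in the next hunk. An equivalent, more explicit form of the same check (a reading aid, not part of the commit):

def can_append(self, seq) -> bool:
    # The last appended token opens a fresh block only when len(seq) is one past
    # a block boundary; only then does may_append below need a free block.
    needs_new_block = len(seq) % self.block_size == 1
    return len(self.free_block_ids) >= (1 if needs_new_block else 0)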
@@ -109,7 +110,7 @@ class BlockManager:
             block_table.append(block_id)
         elif len(seq) % self.block_size == 0:
             assert last_block.hash == -1
-            token_ids = seq.last_block(self.block_size)
+            token_ids = seq.last_block()
             prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
             h = compute_hash(token_ids, prefix)
             last_block.update(h, token_ids)
@@ -17,6 +17,7 @@ class LLMEngine:
         for k, v in kwargs.items():
             if hasattr(config, k):
                 setattr(config, k, v)
+        Sequence.block_size = config.kvcache_block_size
         config.hf_config = AutoConfig.from_pretrained(config.model)
         config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
         self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
@@ -74,7 +75,7 @@ class LLMEngine:
             if finish and use_tqdm:
                 pbar.update(1)
         outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
-        outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs]
+        outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
         if use_tqdm:
             pbar.close()
-            return outputs
+        return outputs
@@ -169,7 +169,7 @@ class ModelRunner:
         context_lens = torch.zeros(max_bs, dtype=torch.int32)
         block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
         outputs = torch.zeros(max_bs, hf_config.hidden_size)
-        self.graph_bs = [1, 2, 4, 8, 16] + list(range(32, max_bs + 1, 32))
+        self.graph_bs = [1, 2, 4, 8, 16] + list(range(16, max_bs + 1, 16))
         self.graphs = {}
         self.graph_pool = None
 
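Halving the capture step from 32 to 16 roughly doubles the number of decode batch sizes that get a captured CUDA graph, so less padding is wasted per step. The dispatch side is not part of this hunk; the usual pattern, sketched here as an assumption, is to replay the smallest captured graph that covers the live batch:

import bisect

def pick_graph_bs(graph_bs: list[int], bs: int) -> int:
    # graph_bs is sorted; return the smallest captured batch size >= bs,
    # and the runner pads its inputs up to that size before replaying.
    return graph_bs[bisect.bisect_left(graph_bs, bs)]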
@@ -46,7 +46,7 @@ class Scheduler:
         # self.running = deque(sorted(self.running))
         while self.running and num_seqs < self.max_num_seqs:
             seq = self.running.popleft()
-            while not self.block_manager.can_append():
+            while not self.block_manager.can_append(seq):
                 if self.running:
                     self.preempt(self.running.pop())
                 else:
@@ -66,7 +66,6 @@ class Scheduler:
         seq.status = SequenceStatus.WAITING
         self.block_manager.deallocate(seq)
         self.waiting.appendleft(seq)
-        return True
 
     def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
         self.num_tokens += len(token_ids)
@@ -12,6 +12,7 @@ class SequenceStatus(Enum):
 
 
 class Sequence:
+    block_size = 256
     counter = count()
 
     def __init__(self, token_ids: list[int], sampling_params: SamplingParams):
@@ -44,27 +45,27 @@ class Sequence:
 
     @num_cached_tokens.setter
     def num_cached_tokens(self, num_cached_tokens):
-        assert num_cached_tokens % 256 == 0
+        assert num_cached_tokens % self.block_size == 0
         self._num_cached_tokens = num_cached_tokens
 
     @property
     def num_cached_blocks(self):
-        return self.num_cached_tokens // 256
+        return self.num_cached_tokens // self.block_size
 
     @property
     def num_blocks(self):
-        return (len(self.token_ids) + 255) // 256
+        return (len(self.token_ids) + self.block_size - 1) // self.block_size
 
     @property
     def last_token(self):
         return self.token_ids[-1]
 
-    def block(self, i, block_size=256):
-        return self.token_ids[i*block_size: (i+1)*block_size]
+    def block(self, i):
+        return self.token_ids[i*self.block_size: (i+1)*self.block_size]
 
-    def last_block(self, block_size=256):
+    def last_block(self):
         n = self.num_blocks
-        return self.token_ids[(n-1)*block_size:]
+        return self.token_ids[(n-1)*self.block_size:]
 
     def append_token(self, token_id: int):
         self.token_ids.append(token_id)
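Combined with the LLMEngine change above (Sequence.block_size = config.kvcache_block_size) and the relaxed Config assert, the hard-coded 256 becomes a configurable class attribute. A usage sketch, assuming the package's LLM wrapper forwards extra keyword arguments into Config as the engine hunk shows (the model name is only a placeholder):

from nanovllm import LLM

# kvcache_block_size still has to be a multiple of 256 per the new Config assert.
llm = LLM("Qwen/Qwen3-0.6B", kvcache_block_size=512)
# The engine sets Sequence.block_size = 512 before any Sequence is created, so
# block(), last_block(), num_blocks and num_cached_blocks all follow the config.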
@@ -3,7 +3,7 @@ from torch import nn
 import triton
 import triton.language as tl
 
-from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 from nanovllm.utils.context import get_context
 
 
@@ -65,18 +65,12 @@ class Attention(nn.Module):
         v_cache = self.v_cache
         store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill:
-            if context.block_tables is None: # normal prefill
-                cu_seqlens_k = context.cu_seqlens_k
-                seqused_k = None
-            else: # prefix cache
-                cu_seqlens_k = None
-                seqused_k = context.context_lens
+            if context.block_tables is not None: # prefix cache
                 k, v = k_cache, v_cache
             o = flash_attn_varlen_func(q, k, v,
                                        max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
-                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=cu_seqlens_k,
-                                       seqused_k=seqused_k, softmax_scale=self.scale,
-                                       causal=True, block_table=context.block_tables)
+                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
+                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
         else: # decode
             o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                         cache_seqlens=context.context_lens, block_table=context.block_tables,