refactor
This commit is contained in:
@@ -21,9 +21,9 @@ prompts = [
|
|||||||
)
|
)
|
||||||
for prompt in prompts
|
for prompt in prompts
|
||||||
]
|
]
|
||||||
completions = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
for p, c in zip(prompts, completions):
|
for prompt, output in zip(prompts, outputs):
|
||||||
print("\n")
|
print("\n")
|
||||||
print(f"Prompt: {p}")
|
print(f"Prompt: {prompt}")
|
||||||
print(f"Completion: {c}")
|
print(f"Completion: {output["text"]}")
|
||||||
|
|||||||
@@ -17,4 +17,4 @@ class Config:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert self.model
|
assert self.model
|
||||||
assert self.kvcache_block_size == 256
|
assert self.kvcache_block_size % 256 == 0
|
||||||
@@ -23,7 +23,6 @@ class Block:
|
|||||||
|
|
||||||
def update(self, hash: int, token_ids: list[int]):
|
def update(self, hash: int, token_ids: list[int]):
|
||||||
assert hash != -1
|
assert hash != -1
|
||||||
assert len(token_ids) == 256
|
|
||||||
self.hash = hash
|
self.hash = hash
|
||||||
self.token_ids = token_ids
|
self.token_ids = token_ids
|
||||||
|
|
||||||
@@ -38,8 +37,7 @@ class Block:
|
|||||||
|
|
||||||
class BlockManager:
|
class BlockManager:
|
||||||
|
|
||||||
def __init__(self, num_blocks: int, block_size: int = 256):
|
def __init__(self, num_blocks: int, block_size: int):
|
||||||
assert block_size == 256
|
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
|
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
|
||||||
self.hash_to_block_id: dict[int, int] = dict()
|
self.hash_to_block_id: dict[int, int] = dict()
|
||||||
@@ -59,15 +57,18 @@ class BlockManager:
|
|||||||
self.used_block_ids.remove(block_id)
|
self.used_block_ids.remove(block_id)
|
||||||
self.free_block_ids.append(block_id)
|
self.free_block_ids.append(block_id)
|
||||||
|
|
||||||
|
def can_prefill(self):
|
||||||
|
return len(self.free_block_ids) > 0.1 * len(self.blocks)
|
||||||
|
|
||||||
def can_allocate(self, seq: Sequence):
|
def can_allocate(self, seq: Sequence):
|
||||||
return seq.num_blocks <= len(self.free_block_ids)
|
return len(self.free_block_ids) >= seq.num_blocks
|
||||||
|
|
||||||
def allocate(self, seq: Sequence):
|
def allocate(self, seq: Sequence):
|
||||||
assert not seq.block_table
|
assert not seq.block_table
|
||||||
h = -1
|
h = -1
|
||||||
cache_miss = False
|
cache_miss = False
|
||||||
for i in range(seq.num_blocks):
|
for i in range(seq.num_blocks):
|
||||||
token_ids = seq.block(i, self.block_size)
|
token_ids = seq.block(i)
|
||||||
h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
|
h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
|
||||||
block_id = self.hash_to_block_id.get(h, -1)
|
block_id = self.hash_to_block_id.get(h, -1)
|
||||||
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
||||||
@@ -96,8 +97,8 @@ class BlockManager:
|
|||||||
seq.num_cached_tokens = 0
|
seq.num_cached_tokens = 0
|
||||||
seq.block_table.clear()
|
seq.block_table.clear()
|
||||||
|
|
||||||
def can_append(self):
|
def can_append(self, seq: Sequence):
|
||||||
return len(self.free_block_ids) >= 1
|
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
|
||||||
|
|
||||||
def may_append(self, seq: Sequence):
|
def may_append(self, seq: Sequence):
|
||||||
block_table = seq.block_table
|
block_table = seq.block_table
|
||||||
@@ -109,7 +110,7 @@ class BlockManager:
|
|||||||
block_table.append(block_id)
|
block_table.append(block_id)
|
||||||
elif len(seq) % self.block_size == 0:
|
elif len(seq) % self.block_size == 0:
|
||||||
assert last_block.hash == -1
|
assert last_block.hash == -1
|
||||||
token_ids = seq.last_block(self.block_size)
|
token_ids = seq.last_block()
|
||||||
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
|
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
|
||||||
h = compute_hash(token_ids, prefix)
|
h = compute_hash(token_ids, prefix)
|
||||||
last_block.update(h, token_ids)
|
last_block.update(h, token_ids)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class LLMEngine:
|
|||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if hasattr(config, k):
|
if hasattr(config, k):
|
||||||
setattr(config, k, v)
|
setattr(config, k, v)
|
||||||
|
Sequence.block_size = config.kvcache_block_size
|
||||||
config.hf_config = AutoConfig.from_pretrained(config.model)
|
config.hf_config = AutoConfig.from_pretrained(config.model)
|
||||||
config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
|
config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
||||||
@@ -74,7 +75,7 @@ class LLMEngine:
|
|||||||
if finish and use_tqdm:
|
if finish and use_tqdm:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
|
outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
|
||||||
outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs]
|
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
||||||
if use_tqdm:
|
if use_tqdm:
|
||||||
pbar.close()
|
pbar.close()
|
||||||
return outputs
|
return outputs
|
||||||
@@ -169,7 +169,7 @@ class ModelRunner:
|
|||||||
context_lens = torch.zeros(max_bs, dtype=torch.int32)
|
context_lens = torch.zeros(max_bs, dtype=torch.int32)
|
||||||
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
|
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
|
||||||
outputs = torch.zeros(max_bs, hf_config.hidden_size)
|
outputs = torch.zeros(max_bs, hf_config.hidden_size)
|
||||||
self.graph_bs = [1, 2, 4, 8, 16] + list(range(32, max_bs + 1, 32))
|
self.graph_bs = [1, 2, 4, 8, 16] + list(range(16, max_bs + 1, 16))
|
||||||
self.graphs = {}
|
self.graphs = {}
|
||||||
self.graph_pool = None
|
self.graph_pool = None
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class Scheduler:
|
|||||||
# self.running = deque(sorted(self.running))
|
# self.running = deque(sorted(self.running))
|
||||||
while self.running and num_seqs < self.max_num_seqs:
|
while self.running and num_seqs < self.max_num_seqs:
|
||||||
seq = self.running.popleft()
|
seq = self.running.popleft()
|
||||||
while not self.block_manager.can_append():
|
while not self.block_manager.can_append(seq):
|
||||||
if self.running:
|
if self.running:
|
||||||
self.preempt(self.running.pop())
|
self.preempt(self.running.pop())
|
||||||
else:
|
else:
|
||||||
@@ -66,7 +66,6 @@ class Scheduler:
|
|||||||
seq.status = SequenceStatus.WAITING
|
seq.status = SequenceStatus.WAITING
|
||||||
self.block_manager.deallocate(seq)
|
self.block_manager.deallocate(seq)
|
||||||
self.waiting.appendleft(seq)
|
self.waiting.appendleft(seq)
|
||||||
return True
|
|
||||||
|
|
||||||
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
||||||
self.num_tokens += len(token_ids)
|
self.num_tokens += len(token_ids)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class SequenceStatus(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class Sequence:
|
class Sequence:
|
||||||
|
block_size = 256
|
||||||
counter = count()
|
counter = count()
|
||||||
|
|
||||||
def __init__(self, token_ids: list[int], sampling_params: SamplingParams):
|
def __init__(self, token_ids: list[int], sampling_params: SamplingParams):
|
||||||
@@ -44,27 +45,27 @@ class Sequence:
|
|||||||
|
|
||||||
@num_cached_tokens.setter
|
@num_cached_tokens.setter
|
||||||
def num_cached_tokens(self, num_cached_tokens):
|
def num_cached_tokens(self, num_cached_tokens):
|
||||||
assert num_cached_tokens % 256 == 0
|
assert num_cached_tokens % self.block_size == 0
|
||||||
self._num_cached_tokens = num_cached_tokens
|
self._num_cached_tokens = num_cached_tokens
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_cached_blocks(self):
|
def num_cached_blocks(self):
|
||||||
return self.num_cached_tokens // 256
|
return self.num_cached_tokens // self.block_size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_blocks(self):
|
def num_blocks(self):
|
||||||
return (len(self.token_ids) + 255) // 256
|
return (len(self.token_ids) + self.block_size - 1) // self.block_size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def last_token(self):
|
def last_token(self):
|
||||||
return self.token_ids[-1]
|
return self.token_ids[-1]
|
||||||
|
|
||||||
def block(self, i, block_size=256):
|
def block(self, i):
|
||||||
return self.token_ids[i*block_size: (i+1)*block_size]
|
return self.token_ids[i*self.block_size: (i+1)*self.block_size]
|
||||||
|
|
||||||
def last_block(self, block_size=256):
|
def last_block(self):
|
||||||
n = self.num_blocks
|
n = self.num_blocks
|
||||||
return self.token_ids[(n-1)*block_size:]
|
return self.token_ids[(n-1)*self.block_size:]
|
||||||
|
|
||||||
def append_token(self, token_id: int):
|
def append_token(self, token_id: int):
|
||||||
self.token_ids.append(token_id)
|
self.token_ids.append(token_id)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from torch import nn
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
from nanovllm.utils.context import get_context
|
from nanovllm.utils.context import get_context
|
||||||
|
|
||||||
|
|
||||||
@@ -65,18 +65,12 @@ class Attention(nn.Module):
|
|||||||
v_cache = self.v_cache
|
v_cache = self.v_cache
|
||||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||||
if context.is_prefill:
|
if context.is_prefill:
|
||||||
if context.block_tables is None: # normal prefill
|
if context.block_tables is not None: # prefix cache
|
||||||
cu_seqlens_k = context.cu_seqlens_k
|
|
||||||
seqused_k = None
|
|
||||||
else: # prefix cache
|
|
||||||
cu_seqlens_k = None
|
|
||||||
seqused_k = context.context_lens
|
|
||||||
k, v = k_cache, v_cache
|
k, v = k_cache, v_cache
|
||||||
o = flash_attn_varlen_func(q, k, v,
|
o = flash_attn_varlen_func(q, k, v,
|
||||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=cu_seqlens_k,
|
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||||
seqused_k=seqused_k, softmax_scale=self.scale,
|
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||||
causal=True, block_table=context.block_tables)
|
|
||||||
else: # decode
|
else: # decode
|
||||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
||||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||||
|
|||||||
Reference in New Issue
Block a user