fix
@@ -16,5 +16,5 @@ sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=ma

 t = time.time()
 completions = llm.generate(prompt_token_ids, sampling_params)
-troughput = batch_size * max_tokens / (time.time() - t)
-print(f"Throughput: {troughput: .2f}")
+throughput = batch_size * max_tokens / (time.time() - t)
+print(f"Throughput: {throughput: .2f}")

@@ -24,6 +24,6 @@ prompts = [
 completions = llm.generate(prompts, sampling_params)

 for p, c in zip(prompts, completions):
-    print("\n\n")
+    print("\n")
     print(f"Prompt: {p}")
     print(f"Completion: {c}")

@@ -5,7 +5,7 @@ from transformers import AutoConfig
 @dataclass
 class Config:
     model: str = ''
-    max_num_batched_tokens: int = 16384
+    max_num_batched_tokens: int = 32768
     max_num_seqs: int = 512
     max_model_len: int = 4096
     gpu_memory_utilization: float = 0.95

@@ -1,4 +1,5 @@
 from collections import defaultdict
+from time import perf_counter
 from tqdm.auto import tqdm
 from transformers import AutoConfig, AutoTokenizer


@@ -33,7 +34,7 @@ class LLMEngine:
         seqs, is_prefill = self.scheduler.schedule()
         token_ids = self.model_runner.run(seqs, is_prefill)
         finished = self.scheduler.postprocess(seqs, token_ids)
-        return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)]
+        return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)], sum(len(seq) for seq in seqs) if is_prefill else len(seqs)

     def is_finished(self):
         return self.scheduler.is_finished()

@@ -45,19 +46,32 @@ class LLMEngine:
         use_tqdm: bool = True,
     ) -> list[str]:
         if use_tqdm:
-            pbar = tqdm(total=len(prompts),
-                        desc="Processed prompts",
+            pbar = tqdm(
+                total=len(prompts),
+                desc="Generating",
+                dynamic_ncols=True,
             )
         if not isinstance(SamplingParams, list):
             sampling_params = [sampling_params] * len(prompts)
         for prompt, sp in zip(prompts, sampling_params):
             self.add_request(prompt, sp)
         outputs = defaultdict(list)
+        prefill_throughput = decode_throughput = 0.
         while not self.is_finished():
-            output = self.step()
+            t = perf_counter()
+            output, num_tokens = self.step()
+            if use_tqdm:
+                if num_tokens > len(output):
+                    prefill_throughput = num_tokens / (perf_counter() - t)
+                else:
+                    decode_throughput = num_tokens / (perf_counter() - t)
+                pbar.set_postfix({
+                    "Prefill": f"{int(prefill_throughput)}tok/s",
+                    "Decode": f"{int(decode_throughput)}tok/s",
+                })
             for seq_id, token_id, finish in output:
                 outputs[seq_id].append(token_id)
-                if use_tqdm and finish:
+                if finish and use_tqdm:
                     pbar.update(1)
         outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
         outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs]

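A note on the throughput readout added above (an illustration only, not part of the commit): step() now also returns the number of tokens it processed, which is the summed sequence lengths during prefill and one token per running sequence during decode, so generate() can tell the two phases apart by comparing that count with the number of per-sequence outputs. With hypothetical numbers:

seq_lens = [5, 9, 3]                   # lengths of the scheduled sequences
output_len = 3                         # step() returns one entry per sequence
num_tokens_prefill = sum(seq_lens)     # 17 > 3  -> reported as prefill throughput
num_tokens_decode = len(seq_lens)      # 3 == 3  -> reported as decode throughput
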
@@ -24,7 +24,7 @@ class ModelRunner:
         self.sampler = Sampler()
         self.allocate_kv_cache(config.gpu_memory_utilization)
         if not self.enforce_eager:
-            self.capture_model()
+            self.capture_cudagraph()
         torch.set_default_device("cpu")
         torch.set_default_dtype(default_dtype)


@@ -101,7 +101,7 @@ class ModelRunner:
             input_ids.append(seq.last_token)
             positions.append(len(seq))
             context_lens.append(len(seq))
-            slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()))
+            slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()) - 1)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

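Why the "- 1" matters (a sketch with made-up numbers, not code from the repo, assuming len(seq.last_block()) already counts the token being decoded): during decode the input token is the last entry of the sequence, so its KV-cache entry belongs in the last occupied slot of the last block, not one past it.

block_size = 256                  # assumed block size
last_block_id = 7                 # hypothetical physical block index from block_table[-1]
tokens_in_last_block = 2          # len(seq.last_block()) including the token being decoded
slot = last_block_id * block_size + tokens_in_last_block - 1   # 1793: the token's own slot
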
@@ -152,7 +152,7 @@ class ModelRunner:
         return token_ids

     @torch.inference_mode()
-    def capture_model(self):
+    def capture_cudagraph(self):
         get_rng_state = torch.cuda.get_rng_state
         set_rng_state = torch.cuda.set_rng_state
         rng_state = torch.cuda.get_rng_state()

@@ -73,7 +73,7 @@ class Scheduler:
         finished = []
         for seq, token_id in zip(seqs, token_ids):
             seq.append_token(token_id)
-            if token_id == self.eos or seq.num_completion_tokens == seq.max_tokens:
+            if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
                 seq.status = SequenceStatus.FINISHED
                 self.block_manager.deallocate(seq)
                 self.running.remove(seq)

@@ -64,10 +64,7 @@ class Sequence:

     def last_block(self, block_size=256):
         n = self.num_blocks
-        t = len(self) + block_size - self.num_blocks * block_size
-        x = self.token_ids[(n-1)*block_size:]
-        assert len(x) == t
-        return x
+        return self.token_ids[(n-1)*block_size:]

     def append_token(self, token_id: int):
         self.token_ids.append(token_id)

@@ -4,7 +4,6 @@ import triton
 import triton.language as tl

 from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
-# from nanovllm.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 from nanovllm.utils.context import get_context


@@ -64,8 +63,8 @@ class Attention(nn.Module):
         context = get_context()
         k_cache = self.k_cache
         v_cache = self.v_cache
+        store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill:
-            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
             if context.block_tables is None: # normal prefill
                 cu_seqlens_k = context.cu_seqlens_k
                 seqused_k = None

@@ -79,7 +78,7 @@ class Attention(nn.Module):
                                        seqused_k=seqused_k, softmax_scale=self.scale,
                                        causal=True, block_table=context.block_tables)
         else: # decode
-            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, k.unsqueeze(1), v.unsqueeze(1),
+            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                         cache_seqlens=context.context_lens, block_table=context.block_tables,
                                         softmax_scale=self.scale, causal=True)
         o = o.view(-1, self.num_heads * self.head_dim)

@@ -7,11 +7,11 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()

-    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor | None = None):
+    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
         logits = logits.to(torch.float)
-        if temperatures is not None:
-            logits.div_(temperatures.unsqueeze(dim=1))
+        greedy_tokens = logits.argmax(dim=-1)
+        logits.div_(temperatures.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
         # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-        sampled_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
-        return sampled_tokens
+        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
+        return torch.where(temperatures == 0, greedy_tokens, sample_tokens)

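The sampler change above mixes greedy and stochastic decoding in one pass. A minimal standalone sketch of the same trick (dummy shapes and values, not the repo's code): dividing the softmax probabilities by Exponential(1) noise and taking the argmax draws a sample from that distribution (a Gumbel-max variant), while rows with temperature 0 fall back to the plain argmax via torch.where.

import torch

logits = torch.randn(2, 8)                      # (batch, vocab) dummy logits
temperatures = torch.tensor([0.0, 0.7])         # row 0 greedy, row 1 sampled

greedy_tokens = logits.argmax(dim=-1)
# clamp avoids dividing by zero on greedy rows; their sampled result is discarded anyway
scaled = logits / temperatures.clamp(min=1e-6).unsqueeze(1)
probs = torch.softmax(scaled, dim=-1)
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
tokens = torch.where(temperatures == 0, greedy_tokens, sample_tokens)
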
@@ -212,7 +212,8 @@ class Qwen3ForCausalLM(nn.Module):
         super().__init__()
         self.model = Qwen3Model(config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        if config.tie_word_embeddings:
+        self.tie_word_embeddings = config.tie_word_embeddings
+        if self.tie_word_embeddings:
             self.lm_head.weight.data = self.model.embed_tokens.weight.data

     def forward(

@@ -236,6 +237,8 @@ class Qwen3ForCausalLM(nn.Module):
         default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
         with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
             for n, p in self.named_parameters():
+                if self.tie_word_embeddings and "lm_head" in n:
+                    continue
                 for x in self.packed_modules_mapping:
                     if x in n:
                         weight_loader = getattr(p, "weight_loader", default_weight_loader)