fix
@@ -1,4 +1,5 @@
 from collections import defaultdict
+from time import perf_counter
 from tqdm.auto import tqdm
 from transformers import AutoConfig, AutoTokenizer
 
@@ -33,7 +34,7 @@ class LLMEngine:
         seqs, is_prefill = self.scheduler.schedule()
         token_ids = self.model_runner.run(seqs, is_prefill)
         finished = self.scheduler.postprocess(seqs, token_ids)
-        return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)]
+        return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)], sum(len(seq) for seq in seqs) if is_prefill else len(seqs)
 
     def is_finished(self):
         return self.scheduler.is_finished()
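The second return value added here is a token count whose meaning depends on the phase: a prefill step counts every prompt token pushed through the model, while a decode step counts one token per running sequence. A minimal sketch with hypothetical sequence lengths, illustrating the convention rather than the engine itself:

seq_lens = [7, 5, 9]                 # hypothetical lengths of the scheduled sequences

num_tokens_prefill = sum(seq_lens)   # 21: every prompt token is processed
num_tokens_decode = len(seq_lens)    # 3: exactly one new token per sequence

# A prefill step therefore reports num_tokens > len(output), which is how
# generate() below tells the two phases apart for the throughput readout.
assert num_tokens_prefill > len(seq_lens)
assert num_tokens_decode == len(seq_lens)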
@@ -45,19 +46,32 @@ class LLMEngine:
         use_tqdm: bool = True,
     ) -> list[str]:
         if use_tqdm:
-            pbar = tqdm(total=len(prompts),
-                        desc="Processed prompts",
+            pbar = tqdm(
+                total=len(prompts),
+                desc="Generating",
+                dynamic_ncols=True,
+            )
         if not isinstance(sampling_params, list):
             sampling_params = [sampling_params] * len(prompts)
         for prompt, sp in zip(prompts, sampling_params):
             self.add_request(prompt, sp)
         outputs = defaultdict(list)
+        prefill_throughput = decode_throughput = 0.
         while not self.is_finished():
-            output = self.step()
+            t = perf_counter()
+            output, num_tokens = self.step()
+            if use_tqdm:
+                if num_tokens > len(output):
+                    prefill_throughput = num_tokens / (perf_counter() - t)
+                else:
+                    decode_throughput = num_tokens / (perf_counter() - t)
+                pbar.set_postfix({
+                    "Prefill": f"{int(prefill_throughput)}tok/s",
+                    "Decode": f"{int(decode_throughput)}tok/s",
+                })
             for seq_id, token_id, finish in output:
                 outputs[seq_id].append(token_id)
-                if use_tqdm and finish:
+                if finish and use_tqdm:
                     pbar.update(1)
         outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
         outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs]
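The postfix values are plain tokens divided by the wall time of one step. A worked example with made-up numbers, only to show the units the bar displays:

prefill_throughput = 512 / 0.25   # 4 prompts, 512 prompt tokens in 0.25 s -> "Prefill: 2048tok/s"
decode_throughput = 4 / 0.02      # one new token for each of 4 sequences in 0.02 s -> "Decode: 200tok/s"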
@@ -24,7 +24,7 @@ class ModelRunner:
         self.sampler = Sampler()
         self.allocate_kv_cache(config.gpu_memory_utilization)
         if not self.enforce_eager:
-            self.capture_model()
+            self.capture_cudagraph()
         torch.set_default_device("cpu")
         torch.set_default_dtype(default_dtype)
 
@@ -101,7 +101,7 @@ class ModelRunner:
             input_ids.append(seq.last_token)
             positions.append(len(seq))
             context_lens.append(len(seq))
-            slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()))
+            slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()) - 1)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
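The "- 1" fixes an off-by-one in the decode slot index: the KV entry written this step belongs to seq.last_token, which postprocess has already appended to the sequence, so its 0-based offset inside the last block is len(seq.last_block()) - 1, not len(seq.last_block()). A worked example with hypothetical sizes:

block_size = 256
block_table = [3, 7]      # the sequence's last block is physical block 7
len_last_block = 44       # that block currently holds 44 tokens

# Slots of block 7 are numbered 7*256 .. 7*256 + 255 (0-based); the newest
# token sits at the 44th position of the block, i.e. offset 43, hence the subtraction.
slot = block_table[-1] * block_size + len_last_block - 1
assert slot == 7 * 256 + 43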
@@ -152,7 +152,7 @@ class ModelRunner:
         return token_ids
 
     @torch.inference_mode()
-    def capture_model(self):
+    def capture_cudagraph(self):
         get_rng_state = torch.cuda.get_rng_state
         set_rng_state = torch.cuda.set_rng_state
         rng_state = torch.cuda.get_rng_state()
@@ -73,7 +73,7 @@ class Scheduler:
         finished = []
         for seq, token_id in zip(seqs, token_ids):
             seq.append_token(token_id)
-            if token_id == self.eos or seq.num_completion_tokens == seq.max_tokens:
+            if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
                 seq.status = SequenceStatus.FINISHED
                 self.block_manager.deallocate(seq)
                 self.running.remove(seq)
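The new condition lets a request opt out of EOS-based stopping while keeping the max_tokens budget. A minimal standalone sketch of that predicate, with field names taken from the diff (ignore_eos, num_completion_tokens, max_tokens) and a hypothetical helper name:

def should_finish(token_id, eos, ignore_eos, num_completion_tokens, max_tokens):
    # EOS only ends the sequence when the request did not set ignore_eos;
    # the completion-length budget always applies.
    return (not ignore_eos and token_id == eos) or num_completion_tokens == max_tokens

assert should_finish(2, eos=2, ignore_eos=False, num_completion_tokens=5, max_tokens=16)
assert not should_finish(2, eos=2, ignore_eos=True, num_completion_tokens=5, max_tokens=16)
assert should_finish(7, eos=2, ignore_eos=True, num_completion_tokens=16, max_tokens=16)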
@@ -64,10 +64,7 @@ class Sequence:
 
     def last_block(self, block_size=256):
         n = self.num_blocks
-        t = len(self) + block_size - self.num_blocks * block_size
-        x = self.token_ids[(n-1)*block_size:]
-        assert len(x) == t
-        return x
+        return self.token_ids[(n-1)*block_size:]
 
     def append_token(self, token_id: int):
         self.token_ids.append(token_id)
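The deleted lines only re-derived and asserted the length of the slice that is returned anyway, so dropping them does not change behaviour. A worked example with a hypothetical 300-token sequence and the default block_size of 256:

block_size = 256
token_ids = list(range(300))
num_blocks = (len(token_ids) + block_size - 1) // block_size   # 2 blocks
last_block = token_ids[(num_blocks - 1) * block_size:]         # tokens 256..299
assert len(last_block) == 44
# The removed assert checked the same identity:
assert len(last_block) == len(token_ids) + block_size - num_blocks * block_size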