fix
bench.py
@@ -15,6 +15,6 @@ prompt_token_ids = torch.randint(0, 10240, (batch_size, seq_len)).tolist()
 sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=max_tokens)
 
 t = time.time()
-completions = llm.generate(prompt_token_ids, sampling_params)
+llm.generate(prompt_token_ids, sampling_params)
 throughput = batch_size * max_tokens / (time.time() - t)
 print(f"Throughput: {throughput: .2f}")
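
The only change in this hunk drops the unused completions binding; the benchmark needs only the wall-clock time, not the generated text. For context, a minimal sketch of the whole benchmark around this hunk, assuming the nanovllm import path and hypothetical size constants (the model path is a placeholder):

import time
import torch
from nanovllm import LLM, SamplingParams  # assumed import path

batch_size, seq_len, max_tokens = 64, 1024, 1024  # hypothetical values
llm = LLM("path/to/model")                        # placeholder model path

prompt_token_ids = torch.randint(0, 10240, (batch_size, seq_len)).tolist()
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=max_tokens)

t = time.time()
llm.generate(prompt_token_ids, sampling_params)  # return value no longer bound
throughput = batch_size * max_tokens / (time.time() - t)
print(f"Throughput: {throughput: .2f}")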
@@ -52,7 +52,7 @@ class LLMEngine:
             desc="Generating",
             dynamic_ncols=True,
         )
-        if not isinstance(SamplingParams, list):
+        if not isinstance(sampling_params, list):
             sampling_params = [sampling_params] * len(prompts)
         for prompt, sp in zip(prompts, sampling_params):
             self.add_request(prompt, sp)
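
The substantive fix in this hunk: the old guard tested the SamplingParams class itself, and a class object is never a list, so the broadcast branch always ran; a caller passing one SamplingParams per prompt would have its list wrapped into a list of lists. A minimal demonstration with a stand-in class:

class SamplingParams:  # stand-in for the real class
    pass

prompts = ["a", "b"]
sampling_params = [SamplingParams(), SamplingParams()]

# Old check: always True, so an existing list gets wrapped again.
assert not isinstance(SamplingParams, list)

# Fixed check: tests the argument, so a per-prompt list passes through.
if not isinstance(sampling_params, list):
    sampling_params = [sampling_params] * len(prompts)
assert len(sampling_params) == len(prompts)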
@@ -170,7 +170,7 @@ class ModelRunner:
         context_lens = torch.zeros(max_bs, dtype=torch.int32)
         block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
         outputs = torch.zeros(max_bs, hf_config.hidden_size)
-        self.graph_bs = [1, 2, 4, 8, 16] + list(range(16, max_bs + 1, 16))
+        self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
         self.graphs = {}
         self.graph_pool = None
 
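
This hunk removes a duplicated CUDA-graph capture size: 16 appeared both at the end of the literal prefix and as the first value of the range, so batch size 16 would be captured twice. With max_bs = 64 as a hypothetical example:

old = [1, 2, 4, 8, 16] + list(range(16, 64 + 1, 16))
new = [1, 2, 4, 8] + list(range(16, 64 + 1, 16))
print(old)  # [1, 2, 4, 8, 16, 16, 32, 48, 64] -- 16 listed twice
print(new)  # [1, 2, 4, 8, 16, 32, 48, 64]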
@@ -1,31 +0,0 @@
-from contextlib import contextmanager
-from collections import defaultdict
-import torch
-
-
-class CUDATimer:
-
-    def __init__(self):
-        self.events = defaultdict(list)
-
-    @contextmanager
-    def record(self, name, enabled=True):
-        if not enabled:
-            yield
-        else:
-            start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
-            self.events[name].append((start, end))
-            start.record()
-            yield
-            end.record()
-
-    def log(self):
-        torch.cuda.synchronize()
-        ret = []
-        for name, events in self.events.items():
-            total = 0
-            count = len(self.events)
-            for start, end in events:
-                total += start.elapsed_time(end)
-            ret.append(f"{name} {total:.2f}ms/{count}times")
-        return ", ".join(ret)
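
The deleted file defined a small CUDA-event profiling helper, CUDATimer. For the record, its log() carried a latent bug: count = len(self.events) counts the number of distinct timer names, where len(events), the number of recordings under the current name, was presumably intended. A corrected sketch of that method:

import torch

class CUDATimer:
    ...  # __init__ and record as in the deleted file above

    def log(self):
        # elapsed_time is only meaningful once the recorded GPU work is done
        torch.cuda.synchronize()
        ret = []
        for name, events in self.events.items():
            total = sum(start.elapsed_time(end) for start, end in events)
            count = len(events)  # recordings under this name
            ret.append(f"{name} {total:.2f}ms/{count}times")
        return ", ".join(ret)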
@@ -1,5 +1,4 @@
 torch
 triton
 transformers
-cmake
-ninja
+flash-attn
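
Presumably cmake and ninja were listed only as build tools for compiling flash-attn from source; declaring flash-attn itself makes the actual dependency explicit. Note that flash-attn's own README recommends installing it with torch already present, via pip install flash-attn --no-build-isolation.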