warmup and allocate
@@ -6,7 +6,7 @@ from transformers import AutoConfig
 @dataclass
 class Config:
     model: str
-    max_num_batched_tokens: int = 32768
+    max_num_batched_tokens: int = 16384
     max_num_seqs: int = 512
     max_model_len: int = 4096
     gpu_memory_utilization: float = 0.9
@@ -23,3 +23,4 @@ class Config:
         assert 1 <= self.tensor_parallel_size <= 8
         self.hf_config = AutoConfig.from_pretrained(self.model)
         self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
+        assert self.max_num_batched_tokens >= self.max_model_len
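
The new assertion ties the two limits together: the warmup pass added below builds max_num_batched_tokens // max_model_len full-length dummy sequences, so the ratio must be at least 1, and a max-length prompt must fit into a single prefill batch. A quick standalone check of the new defaults (illustrative only, not part of the commit):

    max_num_batched_tokens = 16384  # new default (was 32768)
    max_model_len = 4096
    assert max_num_batched_tokens >= max_model_len
    num_warmup_seqs = max_num_batched_tokens // max_model_len
    print(num_warmup_seqs)  # 4 full-length sequences per warmup batch
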
@@ -31,7 +31,8 @@ class ModelRunner:
         self.model = Qwen3ForCausalLM(hf_config)
         load_model(self.model, config.model)
         self.sampler = Sampler()
-        self.allocate_kv_cache(config.gpu_memory_utilization)
+        peak = self.warmup_model()
+        self.allocate_kv_cache(config.gpu_memory_utilization, peak)
         if not self.enforce_eager:
             self.capture_cudagraph()
         torch.set_default_device("cpu")
@@ -46,6 +47,18 @@ class ModelRunner:
             self.shm = SharedMemory(name="nanovllm")
             self.loop()
 
+    def warmup_model(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+        before = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
+        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
+        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
+        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        self.run(seqs, True)
+        torch.cuda.empty_cache()
+        after = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
+        return after - before
+
     def exit(self):
         if self.world_size > 1:
             self.shm.close()
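
warmup_model measures the transient activation memory of a worst-case forward pass: reset the CUDA peak-allocation counter, run the largest batch the limits allow, and return the growth in peak allocated bytes. The pattern generalizes to any workload; a minimal sketch, assuming only plain PyTorch (the workload callable is a placeholder, not repo API):

    import torch

    def measure_peak_bytes(workload) -> int:
        torch.cuda.empty_cache()               # drop cached blocks first
        torch.cuda.reset_peak_memory_stats()   # peak now equals current usage
        before = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
        workload()                             # e.g. one worst-case prefill step
        torch.cuda.empty_cache()
        after = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
        return after - before                  # transient memory of the pass
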
@@ -89,14 +102,15 @@ class ModelRunner:
         assert callable(method)
         return method(*args)
 
-    def allocate_kv_cache(self, gpu_memory_utilization):
+    def allocate_kv_cache(self, gpu_memory_utilization, peak):
         config = self.config
         hf_config = config.hf_config
         free, total = torch.cuda.mem_get_info()
         used = total - free
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
-        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used) // block_bytes
+        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used - peak) // block_bytes
+        print(f"{config.num_kvcache_blocks=}")
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
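
block_bytes is the footprint of one paged-KV block: 2 tensors (K and V) x num_hidden_layers x block_size tokens x per-GPU KV heads x head_dim x bytes per element, and the free budget is now reduced by the measured warmup peak before dividing into blocks. Plugging in hypothetical model dimensions shows the scale (numbers below are made up for illustration, not taken from the commit):

    num_hidden_layers = 28
    block_size = 256
    num_kv_heads = 8    # per GPU, after the tensor-parallel split
    head_dim = 128
    itemsize = 2        # bfloat16
    block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * itemsize
    print(block_bytes)  # 29360128 bytes, i.e. 28 MiB per 256-token block
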
@@ -133,6 +147,8 @@ class ModelRunner:
             cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
             max_seqlen_q = max(seqlen_q, max_seqlen_q)
             max_seqlen_k = max(seqlen_k, max_seqlen_k)
+            if not seq.block_table:
+                continue
             for i in range(seq.num_cached_blocks, seq.num_blocks):
                 start = seq.block_table[i] * self.block_size
                 if i != seq.num_blocks - 1:
@@ -140,7 +156,6 @@ class ModelRunner:
                 else:
                     end = start + seq.last_block_num_tokens
                 slot_mapping.extend(list(range(start, end)))
-        assert len(input_ids) == len(slot_mapping)
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
             block_tables = self.prepare_block_tables(seqs)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
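
Both changes in the two hunks above are warmup fallout: the dummy warmup sequences run before any KV cache exists, so they have no block_table; the new guard skips slot mapping for them, and the old assert len(input_ids) == len(slot_mapping) would now fail by design. A toy illustration of why the lengths diverge during warmup (hypothetical values, not repo code):

    input_ids = [0] * 4096       # one full-length warmup sequence
    block_table = []             # no blocks allocated yet
    slot_mapping = []
    if block_table:              # the new guard: nothing to map
        slot_mapping.extend(range(4096))
    print(len(input_ids), len(slot_mapping))  # 4096 0 -> old assert would fail
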
@@ -177,7 +192,7 @@ class ModelRunner:
         return temperatures
 
     @torch.inference_mode()
-    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill):
+    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
         if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
             return self.model.compute_logits(self.model(input_ids, positions))
         else:
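
The only change in this hunk is the is_prefill: bool annotation, but the branch it sits in explains the 512: decode batches larger than that fall back to eager execution, presumably because capture_cudagraph captures graphs only up to batch size 512, and a CUDA graph replays a fixed set of kernels on fixed-size buffers. A generic capture/replay sketch in plain PyTorch (not repo code; the multiply stands in for the decode step):

    import torch

    static_in = torch.zeros(512, 8, device="cuda")  # buffer baked into the graph
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = static_in * 2                  # captured, not run eagerly
    static_in.copy_(torch.randn(512, 8, device="cuda"))
    g.replay()                                      # reruns captured kernels on new data
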
@@ -15,7 +15,7 @@ class Sequence:
     block_size = 256
     counter = count()
 
-    def __init__(self, token_ids: list[int], sampling_params: SamplingParams):
+    def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
         self.seq_id = next(Sequence.counter)
         self.status = SequenceStatus.WAITING
         self.token_ids = copy(token_ids)
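
Defaulting sampling_params to SamplingParams() is what lets warmup build bare Sequence([0] * max_model_len) objects. One caveat: a Python default argument is evaluated once at definition time, so every Sequence created without explicit params shares one SamplingParams instance, which is only safe if that object is never mutated. A standalone demonstration of the pitfall (illustrative, with a stand-in dataclass):

    from dataclasses import dataclass

    @dataclass
    class Params:
        temperature: float = 1.0

    def make(params=Params()):  # default is built once, then shared
        return params

    a, b = make(), make()
    assert a is b               # the same object across calls
    a.temperature = 0.5
    print(b.temperature)        # 0.5 -- mutation leaks across "instances"
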
@@ -63,7 +63,8 @@ class Attention(nn.Module):
         context = get_context()
         k_cache = self.k_cache
         v_cache = self.v_cache
-        store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+        if k_cache.numel() and v_cache.numel():
+            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill:
             if context.block_tables is not None: # prefix cache
                 k, v = k_cache, v_cache
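
This guard is the attention-side half of the warmup change: warmup_model runs before allocate_kv_cache, so during warmup each layer's k_cache/v_cache has not been bound to real storage yet and there is nothing to write into. A tiny illustration, assuming the pre-allocation placeholders are zero-element tensors (which is what the numel() check implies):

    import torch

    k_cache = torch.empty(0)          # placeholder before allocate_kv_cache
    print(bool(k_cache.numel()))      # False -> store_kvcache is skipped
    k_cache = torch.zeros(2, 256, 8)  # after allocation
    print(bool(k_cache.numel()))      # True  -> KV pairs are written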