warmup and allocate

GeeeekExplorer
2025-06-27 01:51:57 +08:00
parent cfc4cb6710
commit 658520b788
4 changed files with 25 additions and 8 deletions

View File

@@ -6,7 +6,7 @@ from transformers import AutoConfig
 @dataclass
 class Config:
     model: str
-    max_num_batched_tokens: int = 32768
+    max_num_batched_tokens: int = 16384
     max_num_seqs: int = 512
     max_model_len: int = 4096
     gpu_memory_utilization: float = 0.9
@@ -23,3 +23,4 @@ class Config:
         assert 1 <= self.tensor_parallel_size <= 8
         self.hf_config = AutoConfig.from_pretrained(self.model)
         self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
+        assert self.max_num_batched_tokens >= self.max_model_len
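
The new assertion ties the two budgets together: the warmup pass added below builds dummy sequences of full max_model_len, so the batched-token budget has to cover at least one of them. A minimal sketch of that relationship, using the default values from this diff (plain Python; names mirror the Config fields):

max_num_batched_tokens = 16384
max_num_seqs = 512
max_model_len = 4096

# warmup_model() sizes its dummy batch this way; without the new assert,
# a token budget below max_model_len would give num_seqs == 0 and the
# warmup (and its peak-memory measurement) would run on an empty batch.
num_seqs = min(max_num_batched_tokens // max_model_len, max_num_seqs)
assert num_seqs >= 1    # follows from max_num_batched_tokens >= max_model_len
print(num_seqs)         # 4 full-length dummy sequences with the defaults above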

View File

@@ -31,7 +31,8 @@ class ModelRunner:
         self.model = Qwen3ForCausalLM(hf_config)
         load_model(self.model, config.model)
         self.sampler = Sampler()
-        self.allocate_kv_cache(config.gpu_memory_utilization)
+        peak = self.warmup_model()
+        self.allocate_kv_cache(config.gpu_memory_utilization, peak)
         if not self.enforce_eager:
             self.capture_cudagraph()
         torch.set_default_device("cpu")
@@ -46,6 +47,18 @@ class ModelRunner:
                 self.shm = SharedMemory(name="nanovllm")
                 self.loop()
 
+    def warmup_model(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+        before = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
+        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
+        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
+        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        self.run(seqs, True)
+        torch.cuda.empty_cache()
+        after = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
+        return after - before
+
     def exit(self):
         if self.world_size > 1:
             self.shm.close()
@@ -89,14 +102,15 @@ class ModelRunner:
         assert callable(method)
         return method(*args)
 
-    def allocate_kv_cache(self, gpu_memory_utilization):
+    def allocate_kv_cache(self, gpu_memory_utilization, peak):
         config = self.config
         hf_config = config.hf_config
         free, total = torch.cuda.mem_get_info()
         used = total - free
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
-        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used) // block_bytes
+        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used - peak) // block_bytes
+        print(f"{config.num_kvcache_blocks=}")
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
@@ -133,6 +147,8 @@ class ModelRunner:
             cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
             max_seqlen_q = max(seqlen_q, max_seqlen_q)
             max_seqlen_k = max(seqlen_k, max_seqlen_k)
+            if not seq.block_table:
+                continue
             for i in range(seq.num_cached_blocks, seq.num_blocks):
                 start = seq.block_table[i] * self.block_size
                 if i != seq.num_blocks - 1:
@@ -140,7 +156,6 @@ class ModelRunner:
                 else:
                     end = start + seq.last_block_num_tokens
                 slot_mapping.extend(list(range(start, end)))
-        assert len(input_ids) == len(slot_mapping)
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
             block_tables = self.prepare_block_tables(seqs)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
@@ -177,7 +192,7 @@ class ModelRunner:
         return temperatures
 
     @torch.inference_mode()
-    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill):
+    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
         if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
             return self.model.compute_logits(self.model(input_ids, positions))
         else:
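
The two new pieces of ModelRunner form a measure-then-allocate scheme: run one worst-case prefill on dummy sequences, record the peak bytes the allocator reached, then size the KV cache from whatever headroom remains under the utilization target. A standalone sketch of that scheme, assuming a CUDA-capable PyTorch build; measure_peak and kvcache_blocks are hypothetical helpers and the model dimensions are placeholder values, not the project's API:

import torch

def measure_peak(run_worst_case_batch) -> int:
    # Peak bytes the caching allocator reached while running one worst-case batch.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    before = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
    run_worst_case_batch()          # e.g. a prefill over full-length dummy sequences
    torch.cuda.empty_cache()
    after = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
    return after - before

def kvcache_blocks(peak: int, gpu_memory_utilization: float = 0.9,
                   num_layers: int = 28, block_size: int = 256,
                   num_kv_heads: int = 8, head_dim: int = 128,
                   dtype: torch.dtype = torch.bfloat16) -> int:
    # Bytes held by one cache block: K and V (factor 2) across every layer.
    block_bytes = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize
    free, total = torch.cuda.mem_get_info()
    used = total - free
    # Budget = target fraction of total memory, minus what is already resident,
    # minus the activation peak observed during warmup.
    return int(total * gpu_memory_utilization - used - peak) // block_bytes

Subtracting the measured peak before dividing by the per-block footprint is what keeps a model whose prefill activations spike from overcommitting the cache and hitting an out-of-memory error later.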

View File

@@ -15,7 +15,7 @@ class Sequence:
     block_size = 256
     counter = count()
 
-    def __init__(self, token_ids: list[int], sampling_params: SamplingParams):
+    def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
         self.seq_id = next(Sequence.counter)
         self.status = SequenceStatus.WAITING
         self.token_ids = copy(token_ids)
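
Making sampling_params optional is what allows the warmup path above to build bare dummy sequences. A hypothetical usage sketch (token ids and parameter values are made up; field names follow the project's SamplingParams):

# Warmup-style construction: no sampling parameters needed.
dummy = Sequence([0] * 4096)

# Regular construction, unchanged by this commit.
real = Sequence([1, 2, 3], SamplingParams(temperature=0.8, max_tokens=256))

The default SamplingParams() is evaluated once at definition time, as with any Python default argument; as long as sequences only read from it, sharing that single instance is harmless.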

View File

@@ -63,7 +63,8 @@ class Attention(nn.Module):
         context = get_context()
         k_cache = self.k_cache
         v_cache = self.v_cache
-        store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+        if k_cache.numel() and v_cache.numel():
+            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill:
             if context.block_tables is not None:    # prefix cache
                 k, v = k_cache, v_cache
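
This guard exists because warmup now runs before allocate_kv_cache: at that point the layer's k_cache and v_cache are still empty placeholder tensors, so the cache write has to be skipped while attention itself still runs. A minimal sketch of the pattern with a stubbed store_kvcache (the project's real one is a custom kernel; shapes here are illustrative):

import torch

def store_kvcache_stub(k, v, k_cache, v_cache, slot_mapping):
    # Stand-in for the real kernel: scatter K/V rows into their cache slots.
    k_cache[slot_mapping] = k
    v_cache[slot_mapping] = v

# State during warmup: the caches have not been allocated yet.
k_cache = torch.empty(0)
v_cache = torch.empty(0)

k = torch.randn(4, 8, 128)
v = torch.randn(4, 8, 128)
slot_mapping = torch.tensor([0, 1, 2, 3])

# The new guard: only write once real cache buffers exist.
if k_cache.numel() and v_cache.numel():
    store_kvcache_stub(k, v, k_cache, v_cache, slot_mapping)
# During warmup the branch is skipped; after allocate_kv_cache it always runs.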