[fix] Reduce warmup memory overhead.

This commit is contained in:
Zijie Tian
2025-12-15 21:39:14 +08:00
parent 91a0f09a24
commit dc7807a211
2 changed files with 20 additions and 72 deletions

@@ -95,9 +95,14 @@ class ModelRunner:
     def warmup_model(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
-        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
-        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
-        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        # Use a reasonable warmup length instead of max_model_len
+        # Warmup only needs to trigger CUDA kernel JIT compilation
+        # Using 2 blocks is sufficient and avoids huge memory allocation
+        warmup_len = min(self.block_size * 2, self.config.max_model_len)
+        warmup_len = max(warmup_len, 128)  # At least 128 tokens
+        num_seqs = min(self.config.max_num_batched_tokens // warmup_len, self.config.max_num_seqs, 4)
+        num_seqs = max(num_seqs, 1)
+        seqs = [Sequence([0] * warmup_len) for _ in range(num_seqs)]
         self.run(seqs, True)
         torch.cuda.empty_cache()
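
For a sense of scale, below is a minimal standalone sketch of the new sizing logic with hypothetical config values (block_size=256, max_model_len=4096, max_num_batched_tokens=16384, max_num_seqs=256). The variable names mirror the diff above, but the numbers are assumptions for illustration, not values taken from the repository's defaults.

    # Sketch of the warmup sizing change, using assumed (hypothetical) config values.
    block_size = 256
    max_model_len = 4096
    max_num_batched_tokens = 16384
    max_num_seqs = 256

    # Before: each warmup sequence spans the full context length.
    old_num_seqs = min(max_num_batched_tokens // max_model_len, max_num_seqs)          # 4
    old_tokens = old_num_seqs * max_model_len                                           # 16384 warmup tokens

    # After: two KV-cache blocks (at least 128 tokens) per sequence, at most 4 sequences.
    warmup_len = max(min(block_size * 2, max_model_len), 128)                           # 512
    new_num_seqs = max(min(max_num_batched_tokens // warmup_len, max_num_seqs, 4), 1)   # 4
    new_tokens = new_num_seqs * warmup_len                                              # 2048 warmup tokens

    print(f"warmup tokens: {old_tokens} -> {new_tokens}")  # warmup tokens: 16384 -> 2048

With these assumed values the warmup pass touches 2048 tokens instead of 16384, while still exercising the same kernels once before serving.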