[feat] Reduce warmup memory overhead.
@@ -95,9 +95,14 @@ class ModelRunner:
     def warmup_model(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
-        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
-        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
-        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        # Use a reasonable warmup length instead of max_model_len
+        # Warmup only needs to trigger CUDA kernel JIT compilation
+        # Using 2 blocks is sufficient and avoids huge memory allocation
+        warmup_len = min(self.block_size * 2, self.config.max_model_len)
+        warmup_len = max(warmup_len, 128)  # At least 128 tokens
+        num_seqs = min(self.config.max_num_batched_tokens // warmup_len, self.config.max_num_seqs, 4)
+        num_seqs = max(num_seqs, 1)
+        seqs = [Sequence([0] * warmup_len) for _ in range(num_seqs)]
         self.run(seqs, True)
         torch.cuda.empty_cache()
 
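To illustrate the effect of the change, here is a small back-of-the-envelope sketch comparing how many dummy tokens the old and new warmup paths allocate. The config values below (max_model_len, max_num_batched_tokens, max_num_seqs, block_size) are hypothetical placeholders chosen for the example, not values taken from this repository.

# Hypothetical config values, for illustration only; the real values come from
# the engine's Config object and may differ.
max_model_len = 4096
max_num_batched_tokens = 16384
max_num_seqs = 256
block_size = 256

# Old warmup: every dummy sequence is padded to max_model_len.
old_num_seqs = min(max_num_batched_tokens // max_model_len, max_num_seqs)          # 4
old_tokens = old_num_seqs * max_model_len                                           # 16384 dummy tokens

# New warmup: two KV-cache blocks (at least 128 tokens), at most 4 sequences.
warmup_len = max(min(block_size * 2, max_model_len), 128)                           # 512
new_num_seqs = max(min(max_num_batched_tokens // warmup_len, max_num_seqs, 4), 1)   # 4
new_tokens = new_num_seqs * warmup_len                                              # 2048 dummy tokens

print(f"warmup tokens: {old_tokens} -> {new_tokens}")  # 16384 -> 2048

Under these assumed values the warmup batch shrinks by roughly 8x, which is the kind of memory overhead the commit title refers to; the exact saving depends on the actual configuration.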