[feat] Add debug hook to offload_engine.py.

This commit is contained in:
Zijie Tian
2025-12-31 19:44:39 +08:00
parent 7af721c12c
commit 484d0de9f9
5 changed files with 383 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
import os
from dataclasses import dataclass
from transformers import AutoConfig
import torch
@dataclass
@@ -16,6 +17,7 @@ class Config:
eos: int = -1
kvcache_block_size: int = 4096
num_kvcache_blocks: int = -1
dtype: str | None = None # "float16", "bfloat16", or None (use model default)
# CPU Offload configuration
enable_cpu_offload: bool = False
@@ -41,3 +43,17 @@ class Config:
self.hf_config = AutoConfig.from_pretrained(self.model)
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
assert self.max_num_batched_tokens >= self.max_model_len
# Override torch_dtype if user specified
if self.dtype is not None:
dtype_map = {
"float16": torch.float16,
"fp16": torch.float16,
"bfloat16": torch.bfloat16,
"bf16": torch.bfloat16,
"float32": torch.float32,
"fp32": torch.float32,
}
if self.dtype not in dtype_map:
raise ValueError(f"Invalid dtype: {self.dtype}. Choose from: {list(dtype_map.keys())}")
self.hf_config.torch_dtype = dtype_map[self.dtype]