[feat] Added debug hook to offload_engine.py.
@@ -1,6 +1,7 @@
 import os
 from dataclasses import dataclass
 from transformers import AutoConfig
+import torch
 
 
 @dataclass
@@ -16,6 +17,7 @@ class Config:
     eos: int = -1
     kvcache_block_size: int = 4096
     num_kvcache_blocks: int = -1
+    dtype: str | None = None  # "float16", "bfloat16", or None (use model default)
 
     # CPU Offload configuration
     enable_cpu_offload: bool = False
@@ -41,3 +43,17 @@ class Config:
         self.hf_config = AutoConfig.from_pretrained(self.model)
         self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
         assert self.max_num_batched_tokens >= self.max_model_len
+
+        # Override torch_dtype if user specified
+        if self.dtype is not None:
+            dtype_map = {
+                "float16": torch.float16,
+                "fp16": torch.float16,
+                "bfloat16": torch.bfloat16,
+                "bf16": torch.bfloat16,
+                "float32": torch.float32,
+                "fp32": torch.float32,
+            }
+            if self.dtype not in dtype_map:
+                raise ValueError(f"Invalid dtype: {self.dtype}. Choose from: {list(dtype_map.keys())}")
+            self.hf_config.torch_dtype = dtype_map[self.dtype]
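For context, a minimal usage sketch of the new dtype override. The nanovllm.config import path and the model name are assumptions for illustration, not part of this commit; note that Config.__post_init__ still fetches the Hugging Face config, so this needs the model available locally or over the network.

# Usage sketch: import path and model name are assumed, adjust to the repo layout.
from nanovllm.config import Config

cfg = Config(model="Qwen/Qwen2.5-0.5B-Instruct", dtype="bf16")
print(cfg.hf_config.torch_dtype)  # torch.bfloat16, overriding the checkpoint default

# Unrecognized strings now fail fast with a ValueError:
try:
    Config(model="Qwen/Qwen2.5-0.5B-Instruct", dtype="half")
except ValueError as err:
    print(err)  # Invalid dtype: half. Choose from: ['float16', 'fp16', ...]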