97 lines
4.3 KiB
Python
97 lines
4.3 KiB
Python
import os
|
|
from dataclasses import dataclass
|
|
from enum import Enum, auto
|
|
from transformers import AutoConfig
|
|
import torch
|
|
|
|
|
|
class SparsePolicyType(Enum):
    """Enumeration of the supported sparse attention policies."""

    # Dense path: no sparse attention, every block is loaded.
    FULL = auto()
    # Quest: query-aware Top-K block selection (decode only).
    QUEST = auto()
    # MInference: vertical + slash sparse prefill (GPU-only).
    MINFERENCE = auto()
    # XAttention: chunked estimation followed by block-sparse attention.
    XATTN = auto()
|
|
|
|
|
|
@dataclass
class Config:
    """Settings for model loading, batching, KV cache, CPU offload and sparse attention.

    Only ``model`` is required; every other field has a default.  After the
    dataclass-generated ``__init__`` runs, ``__post_init__`` validates the
    values, loads the Hugging Face config from ``model`` and normalizes a few
    fields (see ``__post_init__``).
    """

    model: str
    max_num_batched_tokens: int = 16384
    max_num_seqs: int = 512
    max_model_len: int = 4096
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    enforce_eager: bool = False
    # Always (re)loaded from `model` in __post_init__; a value passed in is replaced.
    hf_config: AutoConfig | None = None
    eos: int = -1
    kvcache_block_size: int = 1024  # Must be a multiple of 256
    num_kvcache_blocks: int = -1
    # "float16"/"fp16", "bfloat16"/"bf16", "float32"/"fp32",
    # or None (use model default)
    dtype: str | None = None

    # CPU Offload configuration
    enable_cpu_offload: bool = False
    offload_policy: str = "lru"  # "lru", "fifo", or full class path
    num_transfer_streams: int = 4  # Number of CUDA streams for async transfers
    num_gpu_blocks: int = -1  # User-specified GPU blocks count, -1 = auto (use max available)
    num_kv_buffers: int = 4  # Ring buffer size for layer-wise offload (decode H2D pipeline)

    # Computed fields for offload (set in __post_init__ or by ModelRunner)
    num_gpu_kvcache_blocks: int = -1
    num_cpu_kvcache_blocks: int = -1

    # Sparse attention configuration
    #   FULL:       no sparse attention (load all blocks)
    #   QUEST:      decode-only sparse attention with Top-K block selection
    #   MINFERENCE: MInference vertical + slash sparse prefill (GPU-only)
    #   XATTN:      XAttention chunked estimation + block-sparse attention
    sparse_policy: SparsePolicyType = SparsePolicyType.FULL
    sparse_topk_blocks: int = 8  # Top-K blocks for Quest
    sparse_threshold_blocks: int = 4  # Apply sparse only when blocks > threshold

    # MInference configuration (used when sparse_policy == MINFERENCE)
    minference_adaptive_budget: float | None = 0.3  # Budget as fraction of seq_len (None to use fixed sizes)
    minference_vertical_size: int = 1000  # Fixed vertical size (if adaptive_budget is None)
    minference_slash_size: int = 6096  # Fixed slash size (if adaptive_budget is None)
    minference_num_sink_tokens: int = 30  # Sink tokens to always keep
    minference_num_recent_diags: int = 100  # Recent diagonals to always keep

    # XAttention configuration (used when sparse_policy == XATTN)
    xattn_stride: int = 8  # Stride for reorganizing Q/K
    xattn_threshold: float = 0.9  # Block selection threshold (0-1)
    xattn_chunk_size: int | None = 16384  # Chunk size for estimation (auto if None)
    xattn_use_triton: bool = True  # Use Triton kernels (requires SM 80+)
    xattn_keep_sink: bool = False  # Always keep first block (sink tokens)
    xattn_keep_recent: bool = False  # Always keep recent diagonal blocks
    xattn_norm: float = 1.0  # Normalization factor for attention scores
    xattn_use_bsa: bool = True  # Use Block Sparse Attention library (requires installation)

    def __post_init__(self):
        """Validate the settings and derive the fields that depend on the model.

        Steps, in order:
          1. Check that ``model`` is an existing directory, that
             ``kvcache_block_size`` is a multiple of 256 and that
             ``tensor_parallel_size`` is within ``[1, 8]``.
          2. Load ``hf_config`` via ``AutoConfig.from_pretrained(model)``.
          3. Clamp ``max_model_len`` to ``hf_config.max_position_embeddings``
             and require ``max_num_batched_tokens >= max_model_len``.
          4. Force ``max_num_seqs`` to 1 (with a warning) when CPU offload is on.
          5. If ``dtype`` is given, store the matching torch dtype in
             ``hf_config.torch_dtype``.

        Raises:
            AssertionError: if any check in step 1 or 3 fails.  Raised
                explicitly rather than via ``assert`` so that the checks are
                still enforced when Python runs with ``-O``.
            ValueError: if ``dtype`` is not one of the accepted names.
        """
        if not os.path.isdir(self.model):
            raise AssertionError(
                f"model must be an existing directory, got {self.model!r}"
            )
        if self.kvcache_block_size % 256 != 0:
            raise AssertionError(
                f"kvcache_block_size must be a multiple of 256, "
                f"got {self.kvcache_block_size}"
            )
        if not 1 <= self.tensor_parallel_size <= 8:
            raise AssertionError(
                f"tensor_parallel_size must be between 1 and 8, "
                f"got {self.tensor_parallel_size}"
            )

        self.hf_config = AutoConfig.from_pretrained(self.model)
        # Never exceed the position-embedding limit declared by the model config.
        self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
        if self.max_num_batched_tokens < self.max_model_len:
            raise AssertionError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must be "
                f">= max_model_len ({self.max_model_len})"
            )

        # CPU offload mode only supports single sequence (layer-wise processing)
        if self.enable_cpu_offload and self.max_num_seqs != 1:
            import logging

            logging.warning(
                "CPU offload mode only supports single sequence. "
                f"Overriding max_num_seqs from {self.max_num_seqs} to 1."
            )
            self.max_num_seqs = 1

        # Override torch_dtype if user specified
        if self.dtype is not None:
            dtype_map = {
                "float16": torch.float16,
                "fp16": torch.float16,
                "bfloat16": torch.bfloat16,
                "bf16": torch.bfloat16,
                "float32": torch.float32,
                "fp32": torch.float32,
            }
            if self.dtype not in dtype_map:
                raise ValueError(f"Invalid dtype: {self.dtype}. Choose from: {list(dtype_map.keys())}")
            self.hf_config.torch_dtype = dtype_map[self.dtype]
|