[WIP] Before integrating the xattn operator.

Zijie Tian
2026-01-19 21:19:21 +08:00
parent 9e6fdc0650
commit b5da802dff
11 changed files with 949 additions and 32 deletions


@@ -9,6 +9,7 @@ class SparsePolicyType(Enum):
     """Sparse attention policy types."""
     FULL = auto()       # No sparse attention (load all blocks)
     QUEST = auto()      # Query-aware Top-K block selection (decode only)
+    XATTN_BSA = auto()  # XAttention Block Sparse Attention (prefill only, chunked)


 @dataclass
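For orientation, a minimal sketch (not part of this commit) of how a caller might dispatch on the extended enum. The helper name and the phase flag are hypothetical; the decode-only/prefill-only gating follows the comments in the hunk above.

from enum import Enum, auto

class SparsePolicyType(Enum):
    """Mirror of the enum in this diff."""
    FULL = auto()
    QUEST = auto()
    XATTN_BSA = auto()

def select_attention_path(policy: SparsePolicyType, is_prefill: bool) -> str:
    """Pick an attention implementation for the current phase.

    QUEST applies only during decode and XATTN_BSA only during prefill;
    anything that does not match falls back to full attention.
    """
    if policy is SparsePolicyType.QUEST and not is_prefill:
        return "quest_decode"       # Top-K KV-block selection
    if policy is SparsePolicyType.XATTN_BSA and is_prefill:
        return "xattn_bsa_prefill"  # chunked block-sparse prefill
    return "full"                   # load all KV blocks

assert select_attention_path(SparsePolicyType.XATTN_BSA, is_prefill=True) == "xattn_bsa_prefill"
assert select_attention_path(SparsePolicyType.XATTN_BSA, is_prefill=False) == "full"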
@@ -37,12 +38,20 @@ class Config:
     num_cpu_kvcache_blocks: int = -1

     # Sparse attention configuration
-    # Quest: decode-only sparse attention with Top-K block selection
+    # FULL: no sparse attention (load all blocks)
+    # QUEST: decode-only sparse attention with Top-K block selection
+    # XATTN_BSA: prefill-only block sparse attention with chunk-level selection
     sparse_policy: SparsePolicyType = SparsePolicyType.FULL
     sparse_topk_blocks: int = 8       # Top-K blocks for Quest
     sparse_threshold_blocks: int = 4  # Apply sparse only when blocks > threshold
+    # XAttention BSA specific parameters
+    sparse_block_size: int = 128         # Block size for BSA (tokens per block)
+    sparse_samples_per_chunk: int = 128  # Samples per chunk for estimation
+    sparse_threshold: float = 0.9        # Cumulative attention threshold (0-1)
+    sparse_use_triton: bool = True       # Use Triton kernels for estimation
+    sparse_stride: int = 8               # Stride for Q/K downsampling

     def __post_init__(self):
         assert os.path.isdir(self.model)
         assert self.kvcache_block_size % 256 == 0
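To make the new knobs concrete, here is a rough, self-contained PyTorch sketch of threshold-based block selection. It only illustrates how sparse_stride, sparse_block_size, and sparse_threshold interact (strided Q/K sampling, per-block mass estimation, cumulative-threshold cutoff); the actual XAttention estimator, and the Triton path behind sparse_use_triton, will differ.

import torch

def select_blocks_by_threshold(
    q: torch.Tensor,         # (q_len, head_dim) queries for one head
    k: torch.Tensor,         # (k_len, head_dim) keys for one head
    block_size: int = 128,   # cf. sparse_block_size
    stride: int = 8,         # cf. sparse_stride
    threshold: float = 0.9,  # cf. sparse_threshold
) -> torch.Tensor:
    """Return a boolean mask over key blocks: keep the smallest set of
    blocks whose estimated cumulative attention mass reaches `threshold`."""
    # Downsample Q and K by the stride so the estimate is cheap.
    q_s, k_s = q[::stride], k[::stride]
    probs = torch.softmax(q_s @ k_s.T / q.shape[-1] ** 0.5, dim=-1)

    # Pool the estimated mass into key blocks (block_size tokens per
    # block -> block_size // stride sampled columns per block).
    cols = max(block_size // stride, 1)
    n_blocks = -(-k_s.shape[0] // cols)  # ceil division
    pad = n_blocks * cols - k_s.shape[0]
    if pad:
        probs = torch.nn.functional.pad(probs, (0, pad))
    block_mass = probs.reshape(-1, n_blocks, cols).sum(-1).mean(0)

    # Greedily keep the highest-mass blocks until their cumulative share
    # of the estimated attention crosses the threshold.
    order = torch.argsort(block_mass, descending=True)
    cum = torch.cumsum(block_mass[order], 0) / block_mass.sum()
    keep = order[: int((cum < threshold).sum()) + 1]
    mask = torch.zeros(n_blocks, dtype=torch.bool)
    mask[keep] = True
    return mask

# Example: one 2K-token prefill chunk; the more concentrated the
# attention mass, the fewer key blocks survive the threshold.
q, k = torch.randn(2048, 64), torch.randn(2048, 64)
mask = select_blocks_by_threshold(q, k)
print(f"kept {int(mask.sum())}/{mask.numel()} key blocks")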