[feat] Added Quest Sparsity Policy.

This commit is contained in:
Zijie Tian
2026-01-07 03:29:21 +08:00
parent c99a6f3d3f
commit 2a6e0a2c02
9 changed files with 92 additions and 92 deletions

View File

@@ -1,9 +1,16 @@
import os
from dataclasses import dataclass
from enum import Enum, auto
from transformers import AutoConfig
import torch
class SparsePolicyType(Enum):
"""Sparse attention policy types."""
FULL = auto() # No sparse attention (load all blocks)
QUEST = auto() # Query-aware Top-K block selection (decode only)
@dataclass
class Config:
model: str
@@ -29,11 +36,10 @@ class Config:
num_gpu_kvcache_blocks: int = -1
num_cpu_kvcache_blocks: int = -1
# Sparse attention configuration (dual policy architecture)
prefill_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm"
decode_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm"
sparse_num_sink_blocks: int = 1 # Number of sink blocks for sparse patterns
sparse_local_window_blocks: int = 2 # Local window size for VerticalSlash
# Sparse attention configuration
# Quest: decode-only sparse attention with Top-K block selection
# FULL: no sparse attention (load all blocks)
sparse_policy: SparsePolicyType = SparsePolicyType.FULL
sparse_topk_blocks: int = 8 # Top-K blocks for Quest
sparse_threshold_blocks: int = 4 # Apply sparse only when blocks > threshold