[feat] Added Quest Sparsity Policy.

@@ -1,9 +1,16 @@
 import os
 from dataclasses import dataclass
+from enum import Enum, auto
 from transformers import AutoConfig
 import torch
+
+
+class SparsePolicyType(Enum):
+    """Sparse attention policy types."""
+    FULL = auto()   # No sparse attention (load all blocks)
+    QUEST = auto()  # Query-aware Top-K block selection (decode only)
 
 
 @dataclass
 class Config:
     model: str
@@ -29,11 +36,10 @@ class Config:
     num_gpu_kvcache_blocks: int = -1
     num_cpu_kvcache_blocks: int = -1
 
-    # Sparse attention configuration (dual policy architecture)
-    prefill_policy: str = "full"  # "full", "quest", "vertical_slash", "streaming_llm"
-    decode_policy: str = "full"  # "full", "quest", "vertical_slash", "streaming_llm"
-    sparse_num_sink_blocks: int = 1  # Number of sink blocks for sparse patterns
-    sparse_local_window_blocks: int = 2  # Local window size for VerticalSlash
+    # Sparse attention configuration
+    # Quest: decode-only sparse attention with Top-K block selection
+    # FULL: no sparse attention (load all blocks)
+    sparse_policy: SparsePolicyType = SparsePolicyType.FULL
+    sparse_topk_blocks: int = 8  # Top-K blocks for Quest
+    sparse_threshold_blocks: int = 4  # Apply sparse only when blocks > threshold
 
 
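For context: the Quest policy, at each decode step, loads only the Top-K KV-cache blocks whose keys could score highest against the current query, and falls back to full attention for short sequences. The sketch below is illustrative only and is not part of this commit; it assumes per-block min/max key summaries are kept alongside the cache, and the helper name select_quest_blocks and its arguments are hypothetical.

import torch


def select_quest_blocks(
    query: torch.Tensor,          # [head_dim] decode-time query for one head
    block_key_min: torch.Tensor,  # [num_blocks, head_dim] per-block key minima
    block_key_max: torch.Tensor,  # [num_blocks, head_dim] per-block key maxima
    topk_blocks: int = 8,         # mirrors Config.sparse_topk_blocks
    threshold_blocks: int = 4,    # mirrors Config.sparse_threshold_blocks
) -> torch.Tensor:
    """Return indices of the KV-cache blocks to load for this decode step."""
    num_blocks = block_key_min.shape[0]
    # Below the threshold, selection overhead outweighs the savings: load all blocks.
    if num_blocks <= threshold_blocks:
        return torch.arange(num_blocks)
    # Quest-style upper bound on each block's attention score: per channel,
    # take the larger of q * k_min and q * k_max, then sum over channels.
    upper_bound = torch.maximum(query * block_key_min, query * block_key_max).sum(dim=-1)
    return torch.topk(upper_bound, k=min(topk_blocks, num_blocks)).indices

The threshold mirrors the comment on sparse_threshold_blocks: when a sequence spans only a handful of blocks, scoring and gathering them costs more than simply attending to everything, so the sparse path only kicks in past that point.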