[feat] Add Quest sparsity policy

Zijie Tian
2026-01-07 03:29:21 +08:00
parent c99a6f3d3f
commit 2a6e0a2c02
9 changed files with 92 additions and 92 deletions

View File

@@ -1,9 +1,16 @@
 import os
 from dataclasses import dataclass
+from enum import Enum, auto
 from transformers import AutoConfig
 import torch


+class SparsePolicyType(Enum):
+    """Sparse attention policy types."""
+    FULL = auto()   # No sparse attention (load all blocks)
+    QUEST = auto()  # Query-aware Top-K block selection (decode only)
+
+
 @dataclass
 class Config:
     model: str
@@ -29,11 +36,10 @@ class Config:
     num_gpu_kvcache_blocks: int = -1
     num_cpu_kvcache_blocks: int = -1

-    # Sparse attention configuration (dual policy architecture)
-    prefill_policy: str = "full"  # "full", "quest", "vertical_slash", "streaming_llm"
-    decode_policy: str = "full"  # "full", "quest", "vertical_slash", "streaming_llm"
-    sparse_num_sink_blocks: int = 1  # Number of sink blocks for sparse patterns
-    sparse_local_window_blocks: int = 2  # Local window size for VerticalSlash
+    # Sparse attention configuration
+    # Quest: decode-only sparse attention with Top-K block selection
+    # FULL: no sparse attention (load all blocks)
+    sparse_policy: SparsePolicyType = SparsePolicyType.FULL
+    sparse_topk_blocks: int = 8  # Top-K blocks for Quest
+    sparse_threshold_blocks: int = 4  # Apply sparse only when blocks > threshold
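With the dual prefill/decode fields collapsed into a single `sparse_policy` enum, enabling Quest from user code is one field plus two tuning knobs. A minimal usage sketch (field names come from the hunk above; the model path is a placeholder):

    from nanovllm.config import Config, SparsePolicyType

    cfg = Config(
        model="/path/to/model",                # placeholder checkpoint path
        sparse_policy=SparsePolicyType.QUEST,  # decode-only Top-K block selection
        sparse_topk_blocks=8,                  # keep the 8 best-scoring KV blocks
        sparse_threshold_blocks=4,             # only sparsify past 4 blocks of context
    )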

View File

@@ -156,22 +156,19 @@ class ModelRunner:
             dtype=hf_config.torch_dtype,
         )
-        # Initialize sparse policies if manager has them (CPU offload mode)
-        if hasattr(self.kvcache_manager, 'prefill_policy') and hasattr(self.kvcache_manager, 'decode_policy'):
-            # Initialize both policies with model config
-            for policy in [self.kvcache_manager.prefill_policy, self.kvcache_manager.decode_policy]:
-                if policy is not None:
-                    policy.initialize(
-                        num_layers=hf_config.num_hidden_layers,
-                        num_kv_heads=num_kv_heads,
-                        head_dim=head_dim,
-                        num_cpu_blocks=config.num_cpu_kvcache_blocks,
-                        dtype=hf_config.torch_dtype,
-                        device=torch.device("cuda"),
-                    )
+        # Initialize sparse policy if manager has one (CPU offload mode)
+        if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
+            self.kvcache_manager.sparse_policy.initialize(
+                num_layers=hf_config.num_hidden_layers,
+                num_kv_heads=num_kv_heads,
+                head_dim=head_dim,
+                num_cpu_blocks=config.num_cpu_kvcache_blocks,
+                dtype=hf_config.torch_dtype,
+                device=torch.device("cuda"),
+            )
             logger.info(
-                f"Sparse policies initialized: prefill={config.prefill_policy}, decode={config.decode_policy} "
+                f"Sparse policy initialized: {config.sparse_policy.name} "
                 f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
             )
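The single `initialize()` call above also pins down the contract a `SparsePolicy` implementation must satisfy before serving: it receives the model geometry and can pre-allocate per-block state up front. A hedged sketch of a conforming policy (the metadata layout is an assumption; only the keyword names come from the call site):

    import torch
    from nanovllm.kvcache.sparse import SparsePolicy

    class MySparsePolicy(SparsePolicy):
        def initialize(self, num_layers, num_kv_heads, head_dim,
                       num_cpu_blocks, dtype, device):
            # One metadata slot per (layer, CPU block): channel-wise key bounds.
            self.k_min = torch.zeros(num_layers, num_cpu_blocks,
                                     num_kv_heads, head_dim,
                                     dtype=dtype, device=device)
            self.k_max = torch.zeros_like(self.k_min)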

View File

@@ -56,36 +56,26 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
     # Need CPU offload: use hybrid manager
     from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
     from nanovllm.kvcache.policies import get_policy
-    from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
+    from nanovllm.kvcache.sparse import create_sparse_policy
+    from nanovllm.config import SparsePolicyType

     eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru'))

-    # Create sparse policies from config
-    prefill_policy_type = getattr(config, 'prefill_policy', 'full')
-    decode_policy_type = getattr(config, 'decode_policy', 'full')
-
-    def create_policy(policy_type_str):
-        """Create a sparse policy from config string."""
-        if policy_type_str.lower() == 'full':
-            return create_sparse_policy(SparsePolicyType.FULL)
-        policy_type = SparsePolicyType[policy_type_str.upper()]
-        return create_sparse_policy(
-            policy_type,
-            topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
-            threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
-            include_sink_blocks=getattr(config, 'sparse_num_sink_blocks', 1),
-        )
-
-    prefill_policy = create_policy(prefill_policy_type)
-    decode_policy = create_policy(decode_policy_type)
+    # Create sparse policy from config enum
+    # Quest is decode-only: prefill returns all blocks (query=None), decode does Top-K
+    sparse_policy_type = getattr(config, 'sparse_policy', SparsePolicyType.FULL)
+    sparse_policy = create_sparse_policy(
+        sparse_policy_type,
+        topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
+        threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
+    )

     return HybridKVCacheManager(
         num_gpu_slots=num_gpu_blocks,
         num_cpu_blocks=num_cpu_blocks,
         block_size=config.kvcache_block_size,
         policy=eviction_policy,
-        prefill_policy=prefill_policy,
-        decode_policy=decode_policy,
+        sparse_policy=sparse_policy,
     )
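`create_sparse_policy` itself is not shown in this commit; judging from the call site above and the exports in `nanovllm/kvcache/sparse/__init__.py` further down, it plausibly dispatches on the enum along these lines (a sketch, not the actual factory; the `QuestConfig` field names are assumptions):

    from nanovllm.config import SparsePolicyType
    from nanovllm.kvcache.sparse import FullAttentionPolicy, QuestPolicy, QuestConfig

    def create_sparse_policy(policy_type, **kwargs):
        if policy_type is SparsePolicyType.FULL:
            return FullAttentionPolicy()  # no-op policy: keep every block
        if policy_type is SparsePolicyType.QUEST:
            return QuestPolicy(QuestConfig(
                topk_blocks=kwargs.get("topk_blocks", 8),
                threshold_blocks=kwargs.get("threshold_blocks", 4),
            ))
        raise ValueError(f"Unknown sparse policy: {policy_type}")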

View File

@@ -90,8 +90,7 @@ class HybridKVCacheManager(KVCacheManager):
         num_cpu_blocks: int,
         block_size: int,
         policy: Optional[EvictionPolicy] = None,
-        prefill_policy: "SparsePolicy" = None,
-        decode_policy: "SparsePolicy" = None,
+        sparse_policy: "SparsePolicy" = None,
     ):
         """
         Initialize hybrid manager with CPU-primary ring buffer design.
@@ -104,8 +103,7 @@ class HybridKVCacheManager(KVCacheManager):
             num_cpu_blocks: Number of CPU pool blocks (primary storage)
             block_size: Tokens per block
             policy: Eviction policy (default: LRU, used for prefix cache management)
-            prefill_policy: Sparse attention policy for prefill phase
-            decode_policy: Sparse attention policy for decode phase
+            sparse_policy: Sparse attention policy (Quest for decode-only sparse)
         """
         self._block_size = block_size
         self.num_gpu_slots = num_gpu_slots
@@ -117,9 +115,8 @@ class HybridKVCacheManager(KVCacheManager):
         # Eviction policy
         self.policy = policy or LRUPolicy()

-        # Sparse attention policies (set at construction time, immutable)
-        self.prefill_policy = prefill_policy
-        self.decode_policy = decode_policy
+        # Sparse attention policy (set at construction time, immutable)
+        self.sparse_policy = sparse_policy

         # Logical blocks (what sequences reference) - one per CPU block
         self.logical_blocks: List[LogicalBlock] = [
@@ -185,8 +182,7 @@ class HybridKVCacheManager(KVCacheManager):
             num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=dtype,
-            prefill_policy=self.prefill_policy,
-            decode_policy=self.decode_policy,
+            sparse_policy=self.sparse_policy,
         )

     def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -194,18 +190,6 @@ class HybridKVCacheManager(KVCacheManager):
         assert self.offload_engine is not None
         return self.offload_engine.get_layer_cache(layer_id)

-    def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]:
-        """
-        Get sparse policy for the specified phase.
-
-        Args:
-            is_prefill: True for prefill phase, False for decode phase
-
-        Returns:
-            SparsePolicy for the phase, or None if not set
-        """
-        return self.prefill_policy if is_prefill else self.decode_policy
-
     def can_allocate(self, seq: Sequence) -> bool:
         """Check if we can allocate blocks for a new sequence."""
         return len(self.free_logical_ids) >= seq.num_blocks
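Dropping `get_policy_for_phase` simplifies every call site: both phases now read the same attribute, and phase behavior lives inside the policy (Quest keeps all blocks during prefill, when no decode query exists). The resulting pattern, as the attention diff below shows:

    policy = kvcache_manager.sparse_policy  # one object for prefill and decode
    if cpu_block_table and policy is not None:
        cpu_block_table = policy.select_blocks(cpu_block_table, policy_ctx)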

View File

@@ -60,8 +60,7 @@ class OffloadEngine:
         head_dim: int,
         dtype: torch.dtype = torch.float16,
         num_streams: int = 4,
-        prefill_policy: "SparsePolicy" = None,
-        decode_policy: "SparsePolicy" = None,
+        sparse_policy: "SparsePolicy" = None,
     ):
         self.num_layers = num_layers
         self.num_gpu_blocks = num_gpu_blocks
@@ -217,9 +216,8 @@ class OffloadEngine:
         self._debug_mode = False
         self._debug_hooks: List = []  # External hooks for debug events

-        # ========== Sparse attention policies (set at construction time) ==========
-        self.prefill_policy = prefill_policy
-        self.decode_policy = decode_policy
+        # ========== Sparse attention policy (set at construction time) ==========
+        self.sparse_policy = sparse_policy

     def _get_next_stream(self) -> torch.cuda.Stream:
         """Round-robin stream selection for parallel transfers."""
@@ -765,20 +763,14 @@ class OffloadEngine:
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
# Collect metadata BEFORE offload (while k_cache is still on GPU)
# Both policies' callbacks are called - each decides whether to respond
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
k_cache = self.k_cache_gpu[slot_idx]
if is_prefill:
if self.prefill_policy is not None:
self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
if self.decode_policy is not None:
self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
else:
if self.prefill_policy is not None:
self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
if self.decode_policy is not None:
self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
if self.sparse_policy is not None:
if is_prefill:
self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
else:
self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
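`on_prefill_offload` fires while `k_cache` still resides on the GPU, which is what lets a Quest-style policy snapshot cheap per-block summaries before the block is evicted to CPU. A hedged sketch of what that callback might record (the Quest paper keeps channel-wise min/max key bounds; the tensor layout and attribute names here are assumptions, not the committed `QuestPolicy` internals):

    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, valid_tokens):
        k = k_cache[:valid_tokens]  # assumed layout: [tokens, kv_heads, head_dim]
        self.k_min[layer_id, cpu_block_id] = k.amin(dim=0)  # per-channel lower bound
        self.k_max[layer_id, cpu_block_id] = k.amax(dim=0)  # per-channel upper bound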

View File

@@ -19,7 +19,8 @@ Usage:
         return available_blocks[:5]  # Just first 5 blocks
 """

-from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
+from nanovllm.config import SparsePolicyType
+from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
 from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
 from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
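The module docstring's `return available_blocks[:5]` example expands into a complete custom policy in a few lines (a sketch; only `select_blocks` appears in the docstring, so any other abstract methods on `SparsePolicy` are ignored here):

    from nanovllm.kvcache.sparse import SparsePolicy

    class FirstFiveBlocksPolicy(SparsePolicy):
        """Toy policy: always attend to the first five KV blocks."""
        def select_blocks(self, available_blocks, ctx):
            return available_blocks[:5]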

View File

@@ -7,15 +7,11 @@ from CPU for each query chunk during chunked attention computation.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from enum import Enum, auto
 from typing import List, Optional, Any

 import torch

-class SparsePolicyType(Enum):
-    """Built-in sparse attention policy types."""
-    FULL = auto()   # prefill + decode
-    QUEST = auto()  # decode only
+# Import SparsePolicyType from config to avoid circular imports
+from nanovllm.config import SparsePolicyType


 @dataclass
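The hunk is truncated at the `@dataclass` decorator that opens `PolicyContext`. Its body is elided from the diff, but the attention-layer call sites below populate at least the following fields; a partial reconstruction (types and any trailing fields are assumptions):

    from dataclasses import dataclass

    @dataclass
    class PolicyContext:
        query_chunk_idx: int   # current query chunk (0 during decode)
        num_query_chunks: int  # total query chunks (1 during decode)
        block_size: int        # tokens per KV block
        total_kv_len: int      # len(cpu_block_table) * block_size
        # ...further fields (e.g. the decode query tensor) are elided in the diff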

View File

@@ -188,9 +188,9 @@ class Attention(nn.Module):
             # Get prefilled CPU blocks (blocks from previous chunks)
             cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

-            # Apply sparse policy if enabled
-            prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True)
-            if cpu_block_table and prefill_policy is not None:
+            # Apply sparse policy if enabled (Quest returns all blocks for prefill since query=None)
+            sparse_policy = kvcache_manager.sparse_policy
+            if cpu_block_table and sparse_policy is not None:
                 num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
                 policy_ctx = PolicyContext(
                     query_chunk_idx=current_chunk_idx,
@@ -201,7 +201,7 @@ class Attention(nn.Module):
                     block_size=kvcache_manager.block_size,
                     total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
                 )
-                cpu_block_table = prefill_policy.select_blocks(
+                cpu_block_table = sparse_policy.select_blocks(
                     cpu_block_table, policy_ctx
                 )
@@ -512,9 +512,9 @@ class Attention(nn.Module):
             if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
                 last_block_valid_tokens = block_size  # Last block was exactly full

-            # Apply sparse policy if enabled
-            decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False)
-            if decode_policy is not None:
+            # Apply sparse policy if enabled (Quest does Top-K selection for decode)
+            sparse_policy = kvcache_manager.sparse_policy
+            if sparse_policy is not None:
                 policy_ctx = PolicyContext(
                     query_chunk_idx=0,
                     num_query_chunks=1,
@@ -524,7 +524,7 @@ class Attention(nn.Module):
                     block_size=kvcache_manager.block_size,
                     total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
                 )
-                cpu_block_table = decode_policy.select_blocks(
+                cpu_block_table = sparse_policy.select_blocks(
                     cpu_block_table, policy_ctx
                 )
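During decode, Quest scores every candidate block with an upper bound on its attention logits against the current query and keeps only the Top-K. A hedged sketch of that selection built on the min/max metadata captured at offload time (the bound follows the Quest paper; `ctx` fields, attribute names, and shapes are assumptions):

    import torch

    def select_blocks(self, blocks, ctx):
        if ctx.query is None or len(blocks) <= self.threshold_blocks:
            return blocks  # prefill or short context: keep every block
        q = ctx.query                            # assumed [kv_heads, head_dim]
        kmin = self.k_min[ctx.layer_id, blocks]  # [n_blocks, kv_heads, head_dim]
        kmax = self.k_max[ctx.layer_id, blocks]
        # Channel-wise upper bound on q.k over the block, then reduce per block.
        bound = torch.maximum(q * kmin, q * kmax).sum(dim=-1)  # [n_blocks, kv_heads]
        score = bound.amax(dim=-1)               # one score per block
        keep = score.topk(min(self.topk_blocks, len(blocks))).indices
        return [blocks[i] for i in sorted(keep.tolist())]  # preserve block order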

View File

@@ -12,6 +12,7 @@ os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
 import argparse

 from nanovllm import LLM, SamplingParams
+from nanovllm.config import SparsePolicyType
 from utils import generate_needle_prompt, check_needle_answer
@@ -29,6 +30,9 @@ def run_needle_test(
needle_value: str = "7492",
max_new_tokens: int = 32,
enable_cpu_offload: bool = False,
enable_quest: bool = False,
sparse_topk: int = 8,
sparse_threshold: int = 4,
verbose: bool = True,
) -> bool:
"""
@@ -44,11 +48,16 @@ def run_needle_test(
         needle_value: The secret value to find
         max_new_tokens: Maximum tokens to generate
         enable_cpu_offload: Enable CPU offload mode
+        enable_quest: Enable Quest sparse attention (decode-only Top-K)
+        sparse_topk: Top-K blocks for Quest
+        sparse_threshold: Apply sparse only when blocks > threshold
         verbose: Print detailed output

     Returns:
         True if test passed, False otherwise
     """
+    sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL
+
     if verbose:
         print(f"\n{'='*60}")
         print(f"Needle-in-Haystack Test")
@@ -60,6 +69,8 @@ def run_needle_test(
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"CPU offload: {enable_cpu_offload}")
if enable_cpu_offload:
print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})")
print(f"{'='*60}\n")
# 1. Initialize LLM
@@ -72,6 +83,9 @@ def run_needle_test(
     }
     if enable_cpu_offload:
         llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
+        llm_kwargs["sparse_policy"] = sparse_policy
+        llm_kwargs["sparse_topk_blocks"] = sparse_topk
+        llm_kwargs["sparse_threshold_blocks"] = sparse_threshold

     llm = LLM(model_path, **llm_kwargs)
@@ -167,6 +181,23 @@ if __name__ == "__main__":
action="store_true",
help="Enable CPU offload (has known bug for long sequences)"
)
parser.add_argument(
"--enable-quest",
action="store_true",
help="Enable Quest sparse attention (decode-only Top-K selection)"
)
parser.add_argument(
"--sparse-topk",
type=int,
default=8,
help="Top-K blocks for Quest sparse attention"
)
parser.add_argument(
"--sparse-threshold",
type=int,
default=4,
help="Apply sparse only when blocks > threshold"
)
args = parser.parse_args()
passed = run_needle_test(
@@ -179,6 +210,9 @@ if __name__ == "__main__":
         needle_value=args.needle_value,
         max_new_tokens=args.max_new_tokens,
         enable_cpu_offload=args.enable_offload,
+        enable_quest=args.enable_quest,
+        sparse_topk=args.sparse_topk,
+        sparse_threshold=args.sparse_threshold,
         verbose=True,
     )
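Putting the pieces together, the new flags compose with the existing offload switch. An example invocation (the script filename is assumed; flag names come from the argparse additions above):

    python needle_test.py --enable-offload --enable-quest \
        --sparse-topk 8 --sparse-threshold 4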