[feat] Added Quest Sparsity Policy.
This commit is contained in:
@@ -1,9 +1,16 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from transformers import AutoConfig
|
||||
import torch
|
||||
|
||||
|
||||
class SparsePolicyType(Enum):
|
||||
"""Sparse attention policy types."""
|
||||
FULL = auto() # No sparse attention (load all blocks)
|
||||
QUEST = auto() # Query-aware Top-K block selection (decode only)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model: str
|
||||
@@ -29,11 +36,10 @@ class Config:
|
||||
num_gpu_kvcache_blocks: int = -1
|
||||
num_cpu_kvcache_blocks: int = -1
|
||||
|
||||
# Sparse attention configuration (dual policy architecture)
|
||||
prefill_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm"
|
||||
decode_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm"
|
||||
sparse_num_sink_blocks: int = 1 # Number of sink blocks for sparse patterns
|
||||
sparse_local_window_blocks: int = 2 # Local window size for VerticalSlash
|
||||
# Sparse attention configuration
|
||||
# Quest: decode-only sparse attention with Top-K block selection
|
||||
# FULL: no sparse attention (load all blocks)
|
||||
sparse_policy: SparsePolicyType = SparsePolicyType.FULL
|
||||
sparse_topk_blocks: int = 8 # Top-K blocks for Quest
|
||||
sparse_threshold_blocks: int = 4 # Apply sparse only when blocks > threshold
|
||||
|
||||
|
||||
@@ -156,12 +156,9 @@ class ModelRunner:
|
||||
dtype=hf_config.torch_dtype,
|
||||
)
|
||||
|
||||
# Initialize sparse policies if manager has them (CPU offload mode)
|
||||
if hasattr(self.kvcache_manager, 'prefill_policy') and hasattr(self.kvcache_manager, 'decode_policy'):
|
||||
# Initialize both policies with model config
|
||||
for policy in [self.kvcache_manager.prefill_policy, self.kvcache_manager.decode_policy]:
|
||||
if policy is not None:
|
||||
policy.initialize(
|
||||
# Initialize sparse policy if manager has one (CPU offload mode)
|
||||
if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
|
||||
self.kvcache_manager.sparse_policy.initialize(
|
||||
num_layers=hf_config.num_hidden_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
@@ -171,7 +168,7 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Sparse policies initialized: prefill={config.prefill_policy}, decode={config.decode_policy} "
|
||||
f"Sparse policy initialized: {config.sparse_policy.name} "
|
||||
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
|
||||
)
|
||||
|
||||
|
||||
@@ -56,36 +56,26 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
||||
# Need CPU offload: use hybrid manager
|
||||
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
|
||||
from nanovllm.kvcache.policies import get_policy
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy
|
||||
from nanovllm.config import SparsePolicyType
|
||||
|
||||
eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru'))
|
||||
|
||||
# Create sparse policies from config
|
||||
prefill_policy_type = getattr(config, 'prefill_policy', 'full')
|
||||
decode_policy_type = getattr(config, 'decode_policy', 'full')
|
||||
|
||||
def create_policy(policy_type_str):
|
||||
"""Create a sparse policy from config string."""
|
||||
if policy_type_str.lower() == 'full':
|
||||
return create_sparse_policy(SparsePolicyType.FULL)
|
||||
policy_type = SparsePolicyType[policy_type_str.upper()]
|
||||
return create_sparse_policy(
|
||||
policy_type,
|
||||
# Create sparse policy from config enum
|
||||
# Quest is decode-only: prefill returns all blocks (query=None), decode does Top-K
|
||||
sparse_policy_type = getattr(config, 'sparse_policy', SparsePolicyType.FULL)
|
||||
sparse_policy = create_sparse_policy(
|
||||
sparse_policy_type,
|
||||
topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
|
||||
threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
|
||||
include_sink_blocks=getattr(config, 'sparse_num_sink_blocks', 1),
|
||||
)
|
||||
|
||||
prefill_policy = create_policy(prefill_policy_type)
|
||||
decode_policy = create_policy(decode_policy_type)
|
||||
|
||||
return HybridKVCacheManager(
|
||||
num_gpu_slots=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=config.kvcache_block_size,
|
||||
policy=eviction_policy,
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
sparse_policy=sparse_policy,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -90,8 +90,7 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_cpu_blocks: int,
|
||||
block_size: int,
|
||||
policy: Optional[EvictionPolicy] = None,
|
||||
prefill_policy: "SparsePolicy" = None,
|
||||
decode_policy: "SparsePolicy" = None,
|
||||
sparse_policy: "SparsePolicy" = None,
|
||||
):
|
||||
"""
|
||||
Initialize hybrid manager with CPU-primary ring buffer design.
|
||||
@@ -104,8 +103,7 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_cpu_blocks: Number of CPU pool blocks (primary storage)
|
||||
block_size: Tokens per block
|
||||
policy: Eviction policy (default: LRU, used for prefix cache management)
|
||||
prefill_policy: Sparse attention policy for prefill phase
|
||||
decode_policy: Sparse attention policy for decode phase
|
||||
sparse_policy: Sparse attention policy (Quest for decode-only sparse)
|
||||
"""
|
||||
self._block_size = block_size
|
||||
self.num_gpu_slots = num_gpu_slots
|
||||
@@ -117,9 +115,8 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
# Eviction policy
|
||||
self.policy = policy or LRUPolicy()
|
||||
|
||||
# Sparse attention policies (set at construction time, immutable)
|
||||
self.prefill_policy = prefill_policy
|
||||
self.decode_policy = decode_policy
|
||||
# Sparse attention policy (set at construction time, immutable)
|
||||
self.sparse_policy = sparse_policy
|
||||
|
||||
# Logical blocks (what sequences reference) - one per CPU block
|
||||
self.logical_blocks: List[LogicalBlock] = [
|
||||
@@ -185,8 +182,7 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
dtype=dtype,
|
||||
prefill_policy=self.prefill_policy,
|
||||
decode_policy=self.decode_policy,
|
||||
sparse_policy=self.sparse_policy,
|
||||
)
|
||||
|
||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
@@ -194,18 +190,6 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
assert self.offload_engine is not None
|
||||
return self.offload_engine.get_layer_cache(layer_id)
|
||||
|
||||
def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]:
|
||||
"""
|
||||
Get sparse policy for the specified phase.
|
||||
|
||||
Args:
|
||||
is_prefill: True for prefill phase, False for decode phase
|
||||
|
||||
Returns:
|
||||
SparsePolicy for the phase, or None if not set
|
||||
"""
|
||||
return self.prefill_policy if is_prefill else self.decode_policy
|
||||
|
||||
def can_allocate(self, seq: Sequence) -> bool:
|
||||
"""Check if we can allocate blocks for a new sequence."""
|
||||
return len(self.free_logical_ids) >= seq.num_blocks
|
||||
|
||||
@@ -60,8 +60,7 @@ class OffloadEngine:
|
||||
head_dim: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_streams: int = 4,
|
||||
prefill_policy: "SparsePolicy" = None,
|
||||
decode_policy: "SparsePolicy" = None,
|
||||
sparse_policy: "SparsePolicy" = None,
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
@@ -217,9 +216,8 @@ class OffloadEngine:
|
||||
self._debug_mode = False
|
||||
self._debug_hooks: List = [] # External hooks for debug events
|
||||
|
||||
# ========== Sparse attention policies (set at construction time) ==========
|
||||
self.prefill_policy = prefill_policy
|
||||
self.decode_policy = decode_policy
|
||||
# ========== Sparse attention policy (set at construction time) ==========
|
||||
self.sparse_policy = sparse_policy
|
||||
|
||||
def _get_next_stream(self) -> torch.cuda.Stream:
|
||||
"""Round-robin stream selection for parallel transfers."""
|
||||
@@ -765,20 +763,14 @@ class OffloadEngine:
|
||||
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
|
||||
|
||||
# Collect metadata BEFORE offload (while k_cache is still on GPU)
|
||||
# Both policies' callbacks are called - each decides whether to respond
|
||||
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
|
||||
k_cache = self.k_cache_gpu[slot_idx]
|
||||
|
||||
if self.sparse_policy is not None:
|
||||
if is_prefill:
|
||||
if self.prefill_policy is not None:
|
||||
self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
if self.decode_policy is not None:
|
||||
self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
else:
|
||||
if self.prefill_policy is not None:
|
||||
self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
if self.decode_policy is not None:
|
||||
self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
|
||||
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
|
||||
@@ -19,7 +19,8 @@ Usage:
|
||||
return available_blocks[:5] # Just first 5 blocks
|
||||
"""
|
||||
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
|
||||
from nanovllm.config import SparsePolicyType
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||
|
||||
|
||||
@@ -7,15 +7,11 @@ from CPU for each query chunk during chunked attention computation.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import List, Optional, Any
|
||||
import torch
|
||||
|
||||
|
||||
class SparsePolicyType(Enum):
|
||||
"""Built-in sparse attention policy types."""
|
||||
FULL = auto() # prefill + decode
|
||||
QUEST = auto() # decode only
|
||||
# Import SparsePolicyType from config to avoid circular imports
|
||||
from nanovllm.config import SparsePolicyType
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -188,9 +188,9 @@ class Attention(nn.Module):
|
||||
# Get prefilled CPU blocks (blocks from previous chunks)
|
||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
|
||||
# Apply sparse policy if enabled
|
||||
prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True)
|
||||
if cpu_block_table and prefill_policy is not None:
|
||||
# Apply sparse policy if enabled (Quest returns all blocks for prefill since query=None)
|
||||
sparse_policy = kvcache_manager.sparse_policy
|
||||
if cpu_block_table and sparse_policy is not None:
|
||||
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=current_chunk_idx,
|
||||
@@ -201,7 +201,7 @@ class Attention(nn.Module):
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
cpu_block_table = prefill_policy.select_blocks(
|
||||
cpu_block_table = sparse_policy.select_blocks(
|
||||
cpu_block_table, policy_ctx
|
||||
)
|
||||
|
||||
@@ -512,9 +512,9 @@ class Attention(nn.Module):
|
||||
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
||||
last_block_valid_tokens = block_size # Last block was exactly full
|
||||
|
||||
# Apply sparse policy if enabled
|
||||
decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False)
|
||||
if decode_policy is not None:
|
||||
# Apply sparse policy if enabled (Quest does Top-K selection for decode)
|
||||
sparse_policy = kvcache_manager.sparse_policy
|
||||
if sparse_policy is not None:
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=0,
|
||||
num_query_chunks=1,
|
||||
@@ -524,7 +524,7 @@ class Attention(nn.Module):
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
cpu_block_table = decode_policy.select_blocks(
|
||||
cpu_block_table = sparse_policy.select_blocks(
|
||||
cpu_block_table, policy_ctx
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
||||
|
||||
import argparse
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from nanovllm.config import SparsePolicyType
|
||||
from utils import generate_needle_prompt, check_needle_answer
|
||||
|
||||
|
||||
@@ -29,6 +30,9 @@ def run_needle_test(
|
||||
needle_value: str = "7492",
|
||||
max_new_tokens: int = 32,
|
||||
enable_cpu_offload: bool = False,
|
||||
enable_quest: bool = False,
|
||||
sparse_topk: int = 8,
|
||||
sparse_threshold: int = 4,
|
||||
verbose: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
@@ -44,11 +48,16 @@ def run_needle_test(
|
||||
needle_value: The secret value to find
|
||||
max_new_tokens: Maximum tokens to generate
|
||||
enable_cpu_offload: Enable CPU offload mode
|
||||
enable_quest: Enable Quest sparse attention (decode-only Top-K)
|
||||
sparse_topk: Top-K blocks for Quest
|
||||
sparse_threshold: Apply sparse only when blocks > threshold
|
||||
verbose: Print detailed output
|
||||
|
||||
Returns:
|
||||
True if test passed, False otherwise
|
||||
"""
|
||||
sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL
|
||||
|
||||
if verbose:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Needle-in-Haystack Test")
|
||||
@@ -60,6 +69,8 @@ def run_needle_test(
|
||||
print(f"Needle position: {needle_position:.0%}")
|
||||
print(f"Needle value: {needle_value}")
|
||||
print(f"CPU offload: {enable_cpu_offload}")
|
||||
if enable_cpu_offload:
|
||||
print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# 1. Initialize LLM
|
||||
@@ -72,6 +83,9 @@ def run_needle_test(
|
||||
}
|
||||
if enable_cpu_offload:
|
||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
||||
llm_kwargs["sparse_policy"] = sparse_policy
|
||||
llm_kwargs["sparse_topk_blocks"] = sparse_topk
|
||||
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
|
||||
|
||||
llm = LLM(model_path, **llm_kwargs)
|
||||
|
||||
@@ -167,6 +181,23 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Enable CPU offload (has known bug for long sequences)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-quest",
|
||||
action="store_true",
|
||||
help="Enable Quest sparse attention (decode-only Top-K selection)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sparse-topk",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Top-K blocks for Quest sparse attention"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sparse-threshold",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Apply sparse only when blocks > threshold"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
passed = run_needle_test(
|
||||
@@ -179,6 +210,9 @@ if __name__ == "__main__":
|
||||
needle_value=args.needle_value,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
enable_cpu_offload=args.enable_offload,
|
||||
enable_quest=args.enable_quest,
|
||||
sparse_topk=args.sparse_topk,
|
||||
sparse_threshold=args.sparse_threshold,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user