[feat] Added Quest Sparsity Policy.
@@ -1,9 +1,16 @@
 import os
 from dataclasses import dataclass
+from enum import Enum, auto
 from transformers import AutoConfig
 import torch


+class SparsePolicyType(Enum):
+    """Sparse attention policy types."""
+    FULL = auto()   # No sparse attention (load all blocks)
+    QUEST = auto()  # Query-aware Top-K block selection (decode only)
+
+
 @dataclass
 class Config:
     model: str
@@ -29,11 +36,10 @@ class Config:
     num_gpu_kvcache_blocks: int = -1
     num_cpu_kvcache_blocks: int = -1

-    # Sparse attention configuration (dual policy architecture)
-    prefill_policy: str = "full"  # "full", "quest", "vertical_slash", "streaming_llm"
-    decode_policy: str = "full"  # "full", "quest", "vertical_slash", "streaming_llm"
-    sparse_num_sink_blocks: int = 1  # Number of sink blocks for sparse patterns
-    sparse_local_window_blocks: int = 2  # Local window size for VerticalSlash
+    # Sparse attention configuration
+    # Quest: decode-only sparse attention with Top-K block selection
+    # FULL: no sparse attention (load all blocks)
+    sparse_policy: SparsePolicyType = SparsePolicyType.FULL
     sparse_topk_blocks: int = 8  # Top-K blocks for Quest
     sparse_threshold_blocks: int = 4  # Apply sparse only when blocks > threshold

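The net effect on the configuration surface is a single `sparse_policy` enum plus two tuning knobs. A minimal usage sketch, assuming the `LLM` constructor forwards these keyword arguments into `Config` the same way the needle test at the bottom of this commit does (the model path and block count are placeholders, not values from this commit):

    from nanovllm import LLM
    from nanovllm.config import SparsePolicyType

    llm = LLM(
        "/path/to/model",                      # placeholder
        num_gpu_blocks=64,                     # CPU-offload sizing, as in the needle test
        sparse_policy=SparsePolicyType.QUEST,  # decode-only Top-K block selection
        sparse_topk_blocks=8,
        sparse_threshold_blocks=4,
    )
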
@@ -156,12 +156,9 @@ class ModelRunner:
             dtype=hf_config.torch_dtype,
         )

-        # Initialize sparse policies if manager has them (CPU offload mode)
-        if hasattr(self.kvcache_manager, 'prefill_policy') and hasattr(self.kvcache_manager, 'decode_policy'):
-            # Initialize both policies with model config
-            for policy in [self.kvcache_manager.prefill_policy, self.kvcache_manager.decode_policy]:
-                if policy is not None:
-                    policy.initialize(
+        # Initialize sparse policy if manager has one (CPU offload mode)
+        if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
+            self.kvcache_manager.sparse_policy.initialize(
                 num_layers=hf_config.num_hidden_layers,
                 num_kv_heads=num_kv_heads,
                 head_dim=head_dim,
@@ -171,7 +168,7 @@ class ModelRunner:
             )

             logger.info(
-                f"Sparse policies initialized: prefill={config.prefill_policy}, decode={config.decode_policy} "
+                f"Sparse policy initialized: {config.sparse_policy.name} "
                 f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
             )

@@ -56,36 +56,26 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
         # Need CPU offload: use hybrid manager
         from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
         from nanovllm.kvcache.policies import get_policy
-        from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
+        from nanovllm.kvcache.sparse import create_sparse_policy
+        from nanovllm.config import SparsePolicyType

         eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru'))

-        # Create sparse policies from config
-        prefill_policy_type = getattr(config, 'prefill_policy', 'full')
-        decode_policy_type = getattr(config, 'decode_policy', 'full')
-
-        def create_policy(policy_type_str):
-            """Create a sparse policy from config string."""
-            if policy_type_str.lower() == 'full':
-                return create_sparse_policy(SparsePolicyType.FULL)
-            policy_type = SparsePolicyType[policy_type_str.upper()]
-            return create_sparse_policy(
-                policy_type,
-                topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
-                threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
-                include_sink_blocks=getattr(config, 'sparse_num_sink_blocks', 1),
-            )
-
-        prefill_policy = create_policy(prefill_policy_type)
-        decode_policy = create_policy(decode_policy_type)
+        # Create sparse policy from config enum
+        # Quest is decode-only: prefill returns all blocks (query=None), decode does Top-K
+        sparse_policy_type = getattr(config, 'sparse_policy', SparsePolicyType.FULL)
+        sparse_policy = create_sparse_policy(
+            sparse_policy_type,
+            topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
+            threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
+        )

         return HybridKVCacheManager(
             num_gpu_slots=num_gpu_blocks,
             num_cpu_blocks=num_cpu_blocks,
             block_size=config.kvcache_block_size,
             policy=eviction_policy,
-            prefill_policy=prefill_policy,
-            decode_policy=decode_policy,
+            sparse_policy=sparse_policy,
         )

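The factory above now maps the config enum to one policy object and hands it to the hybrid manager. As a rough sketch of what `create_sparse_policy` presumably does given the call shape in the diff (the `QuestPolicy` constructor arguments are assumptions; only the exported names come from `nanovllm.kvcache.sparse`):

    from nanovllm.config import SparsePolicyType
    from nanovllm.kvcache.sparse import FullAttentionPolicy, QuestPolicy

    def create_sparse_policy_sketch(policy_type, topk_blocks=8, threshold_blocks=4):
        # Illustrative mirror of create_sparse_policy; not the commit's code.
        if policy_type == SparsePolicyType.QUEST:
            return QuestPolicy(topk_blocks=topk_blocks, threshold_blocks=threshold_blocks)
        return FullAttentionPolicy()
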
@@ -90,8 +90,7 @@ class HybridKVCacheManager(KVCacheManager):
         num_cpu_blocks: int,
         block_size: int,
         policy: Optional[EvictionPolicy] = None,
-        prefill_policy: "SparsePolicy" = None,
-        decode_policy: "SparsePolicy" = None,
+        sparse_policy: "SparsePolicy" = None,
     ):
         """
         Initialize hybrid manager with CPU-primary ring buffer design.
@@ -104,8 +103,7 @@ class HybridKVCacheManager(KVCacheManager):
             num_cpu_blocks: Number of CPU pool blocks (primary storage)
             block_size: Tokens per block
             policy: Eviction policy (default: LRU, used for prefix cache management)
-            prefill_policy: Sparse attention policy for prefill phase
-            decode_policy: Sparse attention policy for decode phase
+            sparse_policy: Sparse attention policy (Quest for decode-only sparse)
         """
         self._block_size = block_size
         self.num_gpu_slots = num_gpu_slots
@@ -117,9 +115,8 @@ class HybridKVCacheManager(KVCacheManager):
         # Eviction policy
         self.policy = policy or LRUPolicy()

-        # Sparse attention policies (set at construction time, immutable)
-        self.prefill_policy = prefill_policy
-        self.decode_policy = decode_policy
+        # Sparse attention policy (set at construction time, immutable)
+        self.sparse_policy = sparse_policy

         # Logical blocks (what sequences reference) - one per CPU block
         self.logical_blocks: List[LogicalBlock] = [
@@ -185,8 +182,7 @@ class HybridKVCacheManager(KVCacheManager):
             num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=dtype,
-            prefill_policy=self.prefill_policy,
-            decode_policy=self.decode_policy,
+            sparse_policy=self.sparse_policy,
         )

     def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -194,18 +190,6 @@ class HybridKVCacheManager(KVCacheManager):
         assert self.offload_engine is not None
         return self.offload_engine.get_layer_cache(layer_id)

-    def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]:
-        """
-        Get sparse policy for the specified phase.
-
-        Args:
-            is_prefill: True for prefill phase, False for decode phase
-
-        Returns:
-            SparsePolicy for the phase, or None if not set
-        """
-        return self.prefill_policy if is_prefill else self.decode_policy
-
     def can_allocate(self, seq: Sequence) -> bool:
         """Check if we can allocate blocks for a new sequence."""
         return len(self.free_logical_ids) >= seq.num_blocks
@@ -60,8 +60,7 @@ class OffloadEngine:
         head_dim: int,
         dtype: torch.dtype = torch.float16,
         num_streams: int = 4,
-        prefill_policy: "SparsePolicy" = None,
-        decode_policy: "SparsePolicy" = None,
+        sparse_policy: "SparsePolicy" = None,
     ):
         self.num_layers = num_layers
         self.num_gpu_blocks = num_gpu_blocks
@@ -217,9 +216,8 @@ class OffloadEngine:
         self._debug_mode = False
         self._debug_hooks: List = []  # External hooks for debug events

-        # ========== Sparse attention policies (set at construction time) ==========
-        self.prefill_policy = prefill_policy
-        self.decode_policy = decode_policy
+        # ========== Sparse attention policy (set at construction time) ==========
+        self.sparse_policy = sparse_policy

     def _get_next_stream(self) -> torch.cuda.Stream:
         """Round-robin stream selection for parallel transfers."""
@@ -765,20 +763,14 @@
         logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")

         # Collect metadata BEFORE offload (while k_cache is still on GPU)
-        # Both policies' callbacks are called - each decides whether to respond
         valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
         k_cache = self.k_cache_gpu[slot_idx]

-        if is_prefill:
-            if self.prefill_policy is not None:
-                self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
-            if self.decode_policy is not None:
-                self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
-        else:
-            if self.prefill_policy is not None:
-                self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
-            if self.decode_policy is not None:
-                self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
+        if self.sparse_policy is not None:
+            if is_prefill:
+                self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
+            else:
+                self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)

         torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
         with torch.cuda.stream(self.transfer_stream_main):
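With a single policy, both offload paths call the same object, and the policy decides internally whether to record anything. The Quest idea is to capture per-block key statistics at offload time so decode can later score blocks without touching their full KV data. A hedged sketch of that metadata capture, assuming a [block_size, num_kv_heads, head_dim] key layout (the real bookkeeping lives in BlockMetadataManager in nanovllm/kvcache/sparse/quest.py and may differ):

    import torch

    class BlockMetadataSketch:
        """Illustrative per-block key min/max capture for Quest-style scoring."""

        def __init__(self):
            self.key_min = {}  # (layer_id, cpu_block_id) -> [num_kv_heads, head_dim]
            self.key_max = {}

        def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, valid_tokens):
            k = k_cache[:valid_tokens].float()  # assumed [tokens, heads, dim]
            self.key_min[(layer_id, cpu_block_id)] = k.amin(dim=0)
            self.key_max[(layer_id, cpu_block_id)] = k.amax(dim=0)

        # Blocks written during decode get the same treatment.
        on_decode_offload = on_prefill_offload
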
@@ -19,7 +19,8 @@ Usage:
         return available_blocks[:5]  # Just first 5 blocks
 """

-from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
+from nanovllm.config import SparsePolicyType
+from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
 from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
 from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager

@@ -7,15 +7,11 @@ from CPU for each query chunk during chunked attention computation.

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from enum import Enum, auto
 from typing import List, Optional, Any
 import torch

-
-class SparsePolicyType(Enum):
-    """Built-in sparse attention policy types."""
-    FULL = auto()   # prefill + decode
-    QUEST = auto()  # decode only
+# Import SparsePolicyType from config to avoid circular imports
+from nanovllm.config import SparsePolicyType


 @dataclass
@@ -188,9 +188,9 @@ class Attention(nn.Module):
             # Get prefilled CPU blocks (blocks from previous chunks)
             cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

-            # Apply sparse policy if enabled
-            prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True)
-            if cpu_block_table and prefill_policy is not None:
+            # Apply sparse policy if enabled (Quest returns all blocks for prefill since query=None)
+            sparse_policy = kvcache_manager.sparse_policy
+            if cpu_block_table and sparse_policy is not None:
                 num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
                 policy_ctx = PolicyContext(
                     query_chunk_idx=current_chunk_idx,
@@ -201,7 +201,7 @@ class Attention(nn.Module):
                     block_size=kvcache_manager.block_size,
                     total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
                 )
-                cpu_block_table = prefill_policy.select_blocks(
+                cpu_block_table = sparse_policy.select_blocks(
                     cpu_block_table, policy_ctx
                 )

@@ -512,9 +512,9 @@ class Attention(nn.Module):
         if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
             last_block_valid_tokens = block_size  # Last block was exactly full

-        # Apply sparse policy if enabled
-        decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False)
-        if decode_policy is not None:
+        # Apply sparse policy if enabled (Quest does Top-K selection for decode)
+        sparse_policy = kvcache_manager.sparse_policy
+        if sparse_policy is not None:
             policy_ctx = PolicyContext(
                 query_chunk_idx=0,
                 num_query_chunks=1,
@@ -524,7 +524,7 @@ class Attention(nn.Module):
                 block_size=kvcache_manager.block_size,
                 total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
             )
-            cpu_block_table = decode_policy.select_blocks(
+            cpu_block_table = sparse_policy.select_blocks(
                 cpu_block_table, policy_ctx
             )

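select_blocks is where the decode-time Top-K actually happens. A minimal sketch of the usual Quest scoring rule, using the stored per-block key min/max to upper-bound each block's attention contribution to the current query; this illustrates the technique, not the exact code in QuestPolicy:

    import torch

    def quest_topk_blocks(query, key_min, key_max, block_table, topk_blocks, threshold_blocks):
        # query: [num_kv_heads, head_dim]; key_min/key_max: [num_blocks, num_kv_heads, head_dim]
        if len(block_table) <= threshold_blocks:
            return block_table  # below threshold: keep full attention
        # Quest bound: per block, sum over dims of max(q * k_min, q * k_max).
        bound = torch.maximum(query * key_min, query * key_max).sum(dim=(-1, -2))
        k = min(topk_blocks, len(block_table))
        top = torch.topk(bound, k).indices.tolist()
        return [block_table[i] for i in sorted(top)]  # preserve block order
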
@@ -12,6 +12,7 @@ os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"

 import argparse
 from nanovllm import LLM, SamplingParams
+from nanovllm.config import SparsePolicyType
 from utils import generate_needle_prompt, check_needle_answer

@@ -29,6 +30,9 @@ def run_needle_test(
     needle_value: str = "7492",
     max_new_tokens: int = 32,
     enable_cpu_offload: bool = False,
+    enable_quest: bool = False,
+    sparse_topk: int = 8,
+    sparse_threshold: int = 4,
     verbose: bool = True,
 ) -> bool:
     """
@@ -44,11 +48,16 @@
         needle_value: The secret value to find
         max_new_tokens: Maximum tokens to generate
         enable_cpu_offload: Enable CPU offload mode
+        enable_quest: Enable Quest sparse attention (decode-only Top-K)
+        sparse_topk: Top-K blocks for Quest
+        sparse_threshold: Apply sparse only when blocks > threshold
         verbose: Print detailed output

     Returns:
         True if test passed, False otherwise
     """
+    sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL
+
     if verbose:
         print(f"\n{'='*60}")
         print(f"Needle-in-Haystack Test")
@@ -60,6 +69,8 @@
         print(f"Needle position: {needle_position:.0%}")
         print(f"Needle value: {needle_value}")
         print(f"CPU offload: {enable_cpu_offload}")
+        if enable_cpu_offload:
+            print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})")
         print(f"{'='*60}\n")

     # 1. Initialize LLM
@@ -72,6 +83,9 @@
     }
     if enable_cpu_offload:
         llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
+        llm_kwargs["sparse_policy"] = sparse_policy
+        llm_kwargs["sparse_topk_blocks"] = sparse_topk
+        llm_kwargs["sparse_threshold_blocks"] = sparse_threshold

     llm = LLM(model_path, **llm_kwargs)

@@ -167,6 +181,23 @@ if __name__ == "__main__":
         action="store_true",
         help="Enable CPU offload (has known bug for long sequences)"
     )
+    parser.add_argument(
+        "--enable-quest",
+        action="store_true",
+        help="Enable Quest sparse attention (decode-only Top-K selection)"
+    )
+    parser.add_argument(
+        "--sparse-topk",
+        type=int,
+        default=8,
+        help="Top-K blocks for Quest sparse attention"
+    )
+    parser.add_argument(
+        "--sparse-threshold",
+        type=int,
+        default=4,
+        help="Apply sparse only when blocks > threshold"
+    )
     args = parser.parse_args()

     passed = run_needle_test(
@@ -179,6 +210,9 @@
         needle_value=args.needle_value,
         max_new_tokens=args.max_new_tokens,
         enable_cpu_offload=args.enable_offload,
+        enable_quest=args.enable_quest,
+        sparse_topk=args.sparse_topk,
+        sparse_threshold=args.sparse_threshold,
         verbose=True,
     )

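For completeness, an illustrative call into the updated test helper with Quest turned on; the keyword arguments mirror the new signature above, while model_path and the numeric values are examples rather than values from this commit. The same settings are exposed on the command line via --enable-quest, --sparse-topk, and --sparse-threshold.

    passed = run_needle_test(
        model_path="/path/to/model",  # assumed parameter; not shown in this diff
        enable_cpu_offload=True,
        enable_quest=True,
        sparse_topk=8,
        sparse_threshold=4,
    )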