[WIP] Before adding Quest policy.
This commit is contained in:
@@ -3,11 +3,6 @@ import time
|
|||||||
from random import randint, seed
|
from random import randint, seed
|
||||||
from nanovllm import LLM, SamplingParams
|
from nanovllm import LLM, SamplingParams
|
||||||
|
|
||||||
# Import sparse policy classes
|
|
||||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
|
||||||
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
|
|
||||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
|
||||||
|
|
||||||
|
|
||||||
def bench_decode(llm, num_seqs, input_len, output_len):
|
def bench_decode(llm, num_seqs, input_len, output_len):
|
||||||
"""Benchmark decode performance (original test)"""
|
"""Benchmark decode performance (original test)"""
|
||||||
@@ -38,58 +33,6 @@ def bench_prefill(llm, num_seqs, input_len):
|
|||||||
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
||||||
|
|
||||||
|
|
||||||
def setup_quest_policy(llm, topk_blocks=8, threshold_blocks=4):
|
|
||||||
"""
|
|
||||||
Setup Quest sparse policy for decode phase.
|
|
||||||
|
|
||||||
Uses HybridPolicy: Full attention for prefill, Quest Top-K for decode.
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
|
|
||||||
kvcache_manager = llm.model_runner.kvcache_manager
|
|
||||||
offload_engine = kvcache_manager.offload_engine
|
|
||||||
|
|
||||||
# Get model parameters from offload engine
|
|
||||||
num_layers = offload_engine.num_layers
|
|
||||||
num_kv_heads = offload_engine.num_kv_heads
|
|
||||||
head_dim = offload_engine.head_dim
|
|
||||||
num_cpu_blocks = kvcache_manager.num_cpu_blocks
|
|
||||||
dtype = offload_engine.k_cache_cpu.dtype
|
|
||||||
|
|
||||||
print(f"Setting up Quest policy:")
|
|
||||||
print(f" num_layers={num_layers}, num_kv_heads={num_kv_heads}, head_dim={head_dim}")
|
|
||||||
print(f" num_cpu_blocks={num_cpu_blocks}, dtype={dtype}")
|
|
||||||
print(f" topk_blocks={topk_blocks}, threshold_blocks={threshold_blocks}")
|
|
||||||
|
|
||||||
# Create BlockMetadataManager for storing min/max keys
|
|
||||||
metadata = BlockMetadataManager(
|
|
||||||
num_blocks=num_cpu_blocks,
|
|
||||||
num_layers=num_layers,
|
|
||||||
num_kv_heads=num_kv_heads,
|
|
||||||
head_dim=head_dim,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create Quest policy for decode
|
|
||||||
quest_config = QuestConfig(
|
|
||||||
topk_blocks=topk_blocks,
|
|
||||||
threshold_blocks=threshold_blocks,
|
|
||||||
)
|
|
||||||
quest_policy = QuestPolicy(quest_config, metadata)
|
|
||||||
|
|
||||||
# Create Hybrid policy: Full for prefill, Quest for decode
|
|
||||||
hybrid_policy = HybridPolicy(
|
|
||||||
prefill_policy=FullAttentionPolicy(),
|
|
||||||
decode_policy=quest_policy,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set the policy
|
|
||||||
kvcache_manager.set_sparse_policy(hybrid_policy)
|
|
||||||
print(f" Policy set: HybridPolicy(prefill=Full, decode=Quest)")
|
|
||||||
|
|
||||||
return hybrid_policy
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@@ -101,7 +44,18 @@ def main():
|
|||||||
|
|
||||||
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
||||||
# Note: Qwen3-4B-Instruct-2507 max_position_embeddings = 262144
|
# Note: Qwen3-4B-Instruct-2507 max_position_embeddings = 262144
|
||||||
max_len = 131072 # 128K tokens
|
max_len = 32 * 1024 # 32K tokens
|
||||||
|
|
||||||
|
# Setup policy configuration
|
||||||
|
if not args.no_sparse:
|
||||||
|
prefill_policy = "full" # Full attention for prefill
|
||||||
|
decode_policy = "quest" # Quest Top-K for decode
|
||||||
|
print(f"\n[Quest Sparse Attention] prefill={prefill_policy}, decode={decode_policy}, topk={args.topk}")
|
||||||
|
else:
|
||||||
|
prefill_policy = "full" # Full attention for both phases
|
||||||
|
decode_policy = "full"
|
||||||
|
print("\n[Full Attention] No sparse policy (baseline)")
|
||||||
|
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
path,
|
path,
|
||||||
enforce_eager=False,
|
enforce_eager=False,
|
||||||
@@ -109,15 +63,12 @@ def main():
|
|||||||
max_num_batched_tokens=max_len,
|
max_num_batched_tokens=max_len,
|
||||||
enable_cpu_offload=True,
|
enable_cpu_offload=True,
|
||||||
num_gpu_blocks=6, # Small GPU buffer for offload testing
|
num_gpu_blocks=6, # Small GPU buffer for offload testing
|
||||||
|
prefill_policy=prefill_policy,
|
||||||
|
decode_policy=decode_policy,
|
||||||
|
sparse_topk_blocks=args.topk,
|
||||||
|
sparse_threshold_blocks=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not args.no_sparse:
|
|
||||||
# Setup Quest policy for decode (Top-K blocks, apply when > 4 blocks)
|
|
||||||
setup_quest_policy(llm, topk_blocks=args.topk, threshold_blocks=4)
|
|
||||||
print(f"\n[Quest Sparse Attention] topk={args.topk}")
|
|
||||||
else:
|
|
||||||
print("\n[Full Attention] No sparse policy (baseline)")
|
|
||||||
|
|
||||||
# Warmup
|
# Warmup
|
||||||
llm.generate(["Benchmark: "], SamplingParams())
|
llm.generate(["Benchmark: "], SamplingParams())
|
||||||
|
|
||||||
|
|||||||
@@ -29,8 +29,9 @@ class Config:
|
|||||||
num_gpu_kvcache_blocks: int = -1
|
num_gpu_kvcache_blocks: int = -1
|
||||||
num_cpu_kvcache_blocks: int = -1
|
num_cpu_kvcache_blocks: int = -1
|
||||||
|
|
||||||
# Sparse attention configuration
|
# Sparse attention configuration (dual policy architecture)
|
||||||
sparse_policy: str | None = None # "vertical_slash", "quest", "streaming_llm", or None
|
prefill_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm"
|
||||||
|
decode_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm"
|
||||||
sparse_num_sink_blocks: int = 1 # Number of sink blocks for sparse patterns
|
sparse_num_sink_blocks: int = 1 # Number of sink blocks for sparse patterns
|
||||||
sparse_local_window_blocks: int = 2 # Local window size for VerticalSlash
|
sparse_local_window_blocks: int = 2 # Local window size for VerticalSlash
|
||||||
sparse_topk_blocks: int = 8 # Top-K blocks for Quest
|
sparse_topk_blocks: int = 8 # Top-K blocks for Quest
|
||||||
|
|||||||
@@ -156,6 +156,25 @@ class ModelRunner:
|
|||||||
dtype=hf_config.torch_dtype,
|
dtype=hf_config.torch_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize sparse policies if manager has them (CPU offload mode)
|
||||||
|
if hasattr(self.kvcache_manager, 'prefill_policy') and hasattr(self.kvcache_manager, 'decode_policy'):
|
||||||
|
# Initialize both policies with model config
|
||||||
|
for policy in [self.kvcache_manager.prefill_policy, self.kvcache_manager.decode_policy]:
|
||||||
|
if policy is not None:
|
||||||
|
policy.initialize(
|
||||||
|
num_layers=hf_config.num_hidden_layers,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
head_dim=head_dim,
|
||||||
|
num_cpu_blocks=config.num_cpu_kvcache_blocks,
|
||||||
|
dtype=hf_config.torch_dtype,
|
||||||
|
device=torch.device("cuda"),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Sparse policies initialized: prefill={config.prefill_policy}, decode={config.decode_policy} "
|
||||||
|
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
|
||||||
|
)
|
||||||
|
|
||||||
# Log KV cache allocation info with detailed per-token breakdown
|
# Log KV cache allocation info with detailed per-token breakdown
|
||||||
gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
|
gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
|
||||||
cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
|
cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
|
||||||
|
|||||||
@@ -56,14 +56,36 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
|||||||
# Need CPU offload: use hybrid manager
|
# Need CPU offload: use hybrid manager
|
||||||
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
|
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
|
||||||
from nanovllm.kvcache.policies import get_policy
|
from nanovllm.kvcache.policies import get_policy
|
||||||
|
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
|
||||||
|
|
||||||
policy = get_policy(getattr(config, 'offload_policy', 'lru'))
|
eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru'))
|
||||||
|
|
||||||
|
# Create sparse policies from config
|
||||||
|
prefill_policy_type = getattr(config, 'prefill_policy', 'full')
|
||||||
|
decode_policy_type = getattr(config, 'decode_policy', 'full')
|
||||||
|
|
||||||
|
def create_policy(policy_type_str):
|
||||||
|
"""Create a sparse policy from config string."""
|
||||||
|
if policy_type_str.lower() == 'full':
|
||||||
|
return create_sparse_policy(SparsePolicyType.FULL)
|
||||||
|
policy_type = SparsePolicyType[policy_type_str.upper()]
|
||||||
|
return create_sparse_policy(
|
||||||
|
policy_type,
|
||||||
|
topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
|
||||||
|
threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
|
||||||
|
include_sink_blocks=getattr(config, 'sparse_num_sink_blocks', 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
prefill_policy = create_policy(prefill_policy_type)
|
||||||
|
decode_policy = create_policy(decode_policy_type)
|
||||||
|
|
||||||
return HybridKVCacheManager(
|
return HybridKVCacheManager(
|
||||||
num_gpu_slots=num_gpu_blocks,
|
num_gpu_slots=num_gpu_blocks,
|
||||||
num_cpu_blocks=num_cpu_blocks,
|
num_cpu_blocks=num_cpu_blocks,
|
||||||
block_size=config.kvcache_block_size,
|
block_size=config.kvcache_block_size,
|
||||||
policy=policy,
|
policy=eviction_policy,
|
||||||
|
prefill_policy=prefill_policy,
|
||||||
|
decode_policy=decode_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -90,6 +90,8 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
num_cpu_blocks: int,
|
num_cpu_blocks: int,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
policy: Optional[EvictionPolicy] = None,
|
policy: Optional[EvictionPolicy] = None,
|
||||||
|
prefill_policy: "SparsePolicy" = None,
|
||||||
|
decode_policy: "SparsePolicy" = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize hybrid manager with CPU-primary ring buffer design.
|
Initialize hybrid manager with CPU-primary ring buffer design.
|
||||||
@@ -102,6 +104,8 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
num_cpu_blocks: Number of CPU pool blocks (primary storage)
|
num_cpu_blocks: Number of CPU pool blocks (primary storage)
|
||||||
block_size: Tokens per block
|
block_size: Tokens per block
|
||||||
policy: Eviction policy (default: LRU, used for prefix cache management)
|
policy: Eviction policy (default: LRU, used for prefix cache management)
|
||||||
|
prefill_policy: Sparse attention policy for prefill phase
|
||||||
|
decode_policy: Sparse attention policy for decode phase
|
||||||
"""
|
"""
|
||||||
self._block_size = block_size
|
self._block_size = block_size
|
||||||
self.num_gpu_slots = num_gpu_slots
|
self.num_gpu_slots = num_gpu_slots
|
||||||
@@ -113,6 +117,10 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
# Eviction policy
|
# Eviction policy
|
||||||
self.policy = policy or LRUPolicy()
|
self.policy = policy or LRUPolicy()
|
||||||
|
|
||||||
|
# Sparse attention policies (set at construction time, immutable)
|
||||||
|
self.prefill_policy = prefill_policy
|
||||||
|
self.decode_policy = decode_policy
|
||||||
|
|
||||||
# Logical blocks (what sequences reference) - one per CPU block
|
# Logical blocks (what sequences reference) - one per CPU block
|
||||||
self.logical_blocks: List[LogicalBlock] = [
|
self.logical_blocks: List[LogicalBlock] = [
|
||||||
LogicalBlock(i) for i in range(self.total_blocks)
|
LogicalBlock(i) for i in range(self.total_blocks)
|
||||||
@@ -153,9 +161,6 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
# Key: sequence id, Value: number of tokens from prefill (before decode started)
|
# Key: sequence id, Value: number of tokens from prefill (before decode started)
|
||||||
self._prefill_len: Dict[int, int] = {}
|
self._prefill_len: Dict[int, int] = {}
|
||||||
|
|
||||||
# Sparse attention policy (optional)
|
|
||||||
self.sparse_policy: Optional["SparsePolicy"] = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def block_size(self) -> int:
|
def block_size(self) -> int:
|
||||||
return self._block_size
|
return self._block_size
|
||||||
@@ -180,6 +185,8 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
num_kv_heads=num_kv_heads,
|
num_kv_heads=num_kv_heads,
|
||||||
head_dim=head_dim,
|
head_dim=head_dim,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
prefill_policy=self.prefill_policy,
|
||||||
|
decode_policy=self.decode_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||||
@@ -187,23 +194,17 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
assert self.offload_engine is not None
|
assert self.offload_engine is not None
|
||||||
return self.offload_engine.get_layer_cache(layer_id)
|
return self.offload_engine.get_layer_cache(layer_id)
|
||||||
|
|
||||||
def set_sparse_policy(self, policy: "SparsePolicy") -> None:
|
def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]:
|
||||||
"""
|
"""
|
||||||
Set sparse attention policy for block selection.
|
Get sparse policy for the specified phase.
|
||||||
|
|
||||||
The sparse policy determines which KV blocks to load from CPU
|
|
||||||
for each query chunk during chunked attention computation.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
policy: SparsePolicy instance (e.g., VerticalSlashPolicy, QuestPolicy)
|
is_prefill: True for prefill phase, False for decode phase
|
||||||
|
|
||||||
Example:
|
Returns:
|
||||||
from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig
|
SparsePolicy for the phase, or None if not set
|
||||||
policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=2))
|
|
||||||
manager.set_sparse_policy(policy)
|
|
||||||
"""
|
"""
|
||||||
self.sparse_policy = policy
|
return self.prefill_policy if is_prefill else self.decode_policy
|
||||||
logger.info(f"Sparse attention policy set: {policy}")
|
|
||||||
|
|
||||||
def can_allocate(self, seq: Sequence) -> bool:
|
def can_allocate(self, seq: Sequence) -> bool:
|
||||||
"""Check if we can allocate blocks for a new sequence."""
|
"""Check if we can allocate blocks for a new sequence."""
|
||||||
|
|||||||
@@ -17,6 +17,11 @@ from nanovllm.kvcache.kernels import gathered_copy_kv
|
|||||||
from nanovllm.comm import memcpy_2d_async
|
from nanovllm.comm import memcpy_2d_async
|
||||||
from nanovllm.utils.logger import get_logger
|
from nanovllm.utils.logger import get_logger
|
||||||
|
|
||||||
|
# Import for type hints only (avoid circular import)
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nanovllm.kvcache.sparse import SparsePolicy
|
||||||
|
|
||||||
logger = get_logger("offload_engine")
|
logger = get_logger("offload_engine")
|
||||||
|
|
||||||
|
|
||||||
@@ -55,6 +60,8 @@ class OffloadEngine:
|
|||||||
head_dim: int,
|
head_dim: int,
|
||||||
dtype: torch.dtype = torch.float16,
|
dtype: torch.dtype = torch.float16,
|
||||||
num_streams: int = 4,
|
num_streams: int = 4,
|
||||||
|
prefill_policy: "SparsePolicy" = None,
|
||||||
|
decode_policy: "SparsePolicy" = None,
|
||||||
):
|
):
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.num_gpu_blocks = num_gpu_blocks
|
self.num_gpu_blocks = num_gpu_blocks
|
||||||
@@ -210,6 +217,10 @@ class OffloadEngine:
|
|||||||
self._debug_mode = False
|
self._debug_mode = False
|
||||||
self._debug_hooks: List = [] # External hooks for debug events
|
self._debug_hooks: List = [] # External hooks for debug events
|
||||||
|
|
||||||
|
# ========== Sparse attention policies (set at construction time) ==========
|
||||||
|
self.prefill_policy = prefill_policy
|
||||||
|
self.decode_policy = decode_policy
|
||||||
|
|
||||||
def _get_next_stream(self) -> torch.cuda.Stream:
|
def _get_next_stream(self) -> torch.cuda.Stream:
|
||||||
"""Round-robin stream selection for parallel transfers."""
|
"""Round-robin stream selection for parallel transfers."""
|
||||||
stream = self.transfer_streams[self._stream_idx]
|
stream = self.transfer_streams[self._stream_idx]
|
||||||
@@ -730,7 +741,14 @@ class OffloadEngine:
|
|||||||
"""Wait for slot offload to complete."""
|
"""Wait for slot offload to complete."""
|
||||||
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
|
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
|
||||||
|
|
||||||
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
def offload_slot_layer_to_cpu(
|
||||||
|
self,
|
||||||
|
slot_idx: int,
|
||||||
|
layer_id: int,
|
||||||
|
cpu_block_id: int,
|
||||||
|
num_valid_tokens: int = -1,
|
||||||
|
is_prefill: bool = True,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Async offload a ring buffer slot to CPU for one layer.
|
Async offload a ring buffer slot to CPU for one layer.
|
||||||
|
|
||||||
@@ -741,9 +759,27 @@ class OffloadEngine:
|
|||||||
slot_idx: Source GPU slot index
|
slot_idx: Source GPU slot index
|
||||||
layer_id: Target layer in CPU cache
|
layer_id: Target layer in CPU cache
|
||||||
cpu_block_id: Target CPU block ID
|
cpu_block_id: Target CPU block ID
|
||||||
|
num_valid_tokens: Number of valid tokens in this block (-1 = use block_size)
|
||||||
|
is_prefill: True if in prefill phase, False if in decode phase
|
||||||
"""
|
"""
|
||||||
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
|
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
|
||||||
|
|
||||||
|
# Collect metadata BEFORE offload (while k_cache is still on GPU)
|
||||||
|
# Both policies' callbacks are called - each decides whether to respond
|
||||||
|
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
|
||||||
|
k_cache = self.k_cache_gpu[slot_idx]
|
||||||
|
|
||||||
|
if is_prefill:
|
||||||
|
if self.prefill_policy is not None:
|
||||||
|
self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||||
|
if self.decode_policy is not None:
|
||||||
|
self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||||
|
else:
|
||||||
|
if self.prefill_policy is not None:
|
||||||
|
self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||||
|
if self.decode_policy is not None:
|
||||||
|
self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||||
|
|
||||||
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
|
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
|
||||||
with torch.cuda.stream(self.transfer_stream_main):
|
with torch.cuda.stream(self.transfer_stream_main):
|
||||||
# Wait for both compute_stream and default stream
|
# Wait for both compute_stream and default stream
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ Usage:
|
|||||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
|
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
|
||||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||||
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
|
|
||||||
|
|
||||||
|
|
||||||
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
||||||
@@ -67,6 +66,5 @@ __all__ = [
|
|||||||
"QuestPolicy",
|
"QuestPolicy",
|
||||||
"QuestConfig",
|
"QuestConfig",
|
||||||
"BlockMetadataManager",
|
"BlockMetadataManager",
|
||||||
"HybridPolicy",
|
|
||||||
"create_sparse_policy",
|
"create_sparse_policy",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,93 +0,0 @@
|
|||||||
"""
|
|
||||||
Hybrid sparse attention policy.
|
|
||||||
|
|
||||||
Allows using different policies for prefill vs decode phases.
|
|
||||||
This is useful because optimal sparsity patterns often differ:
|
|
||||||
- Prefill: fixed patterns work well (e.g., VerticalSlash)
|
|
||||||
- Decode: query-aware selection helps (e.g., Quest)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
import torch
|
|
||||||
from .policy import SparsePolicy, PolicyContext
|
|
||||||
|
|
||||||
|
|
||||||
class HybridPolicy(SparsePolicy):
|
|
||||||
"""
|
|
||||||
Hybrid policy that uses different policies for prefill and decode.
|
|
||||||
|
|
||||||
Example usage:
|
|
||||||
```python
|
|
||||||
from nanovllm.kvcache.sparse import (
|
|
||||||
HybridPolicy, VerticalSlashPolicy, QuestPolicy,
|
|
||||||
VerticalSlashConfig, QuestConfig, BlockMetadataManager
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prefill: use fast fixed pattern
|
|
||||||
prefill_policy = VerticalSlashPolicy(VerticalSlashConfig(
|
|
||||||
num_sink_blocks=1,
|
|
||||||
local_window_blocks=3,
|
|
||||||
))
|
|
||||||
|
|
||||||
# Decode: use query-aware selection
|
|
||||||
metadata = BlockMetadataManager(num_blocks, num_layers, num_heads, head_dim)
|
|
||||||
decode_policy = QuestPolicy(QuestConfig(topk_blocks=8), metadata)
|
|
||||||
|
|
||||||
# Combine
|
|
||||||
policy = HybridPolicy(prefill_policy, decode_policy)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
prefill_policy: SparsePolicy,
|
|
||||||
decode_policy: SparsePolicy,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize hybrid policy.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prefill_policy: Policy to use during prefill phase
|
|
||||||
decode_policy: Policy to use during decode phase
|
|
||||||
"""
|
|
||||||
self.prefill_policy = prefill_policy
|
|
||||||
self.decode_policy = decode_policy
|
|
||||||
|
|
||||||
def select_blocks(
|
|
||||||
self,
|
|
||||||
available_blocks: List[int],
|
|
||||||
ctx: PolicyContext,
|
|
||||||
) -> List[int]:
|
|
||||||
"""Delegate to appropriate policy based on phase."""
|
|
||||||
if ctx.is_prefill:
|
|
||||||
return self.prefill_policy.select_blocks(available_blocks, ctx)
|
|
||||||
else:
|
|
||||||
return self.decode_policy.select_blocks(available_blocks, ctx)
|
|
||||||
|
|
||||||
def on_block_offloaded(
|
|
||||||
self,
|
|
||||||
cpu_block_id: int,
|
|
||||||
layer_id: int,
|
|
||||||
k_cache: torch.Tensor,
|
|
||||||
num_valid_tokens: int,
|
|
||||||
) -> None:
|
|
||||||
"""Forward to both policies (both may need metadata updates)."""
|
|
||||||
self.prefill_policy.on_block_offloaded(
|
|
||||||
cpu_block_id, layer_id, k_cache, num_valid_tokens
|
|
||||||
)
|
|
||||||
self.decode_policy.on_block_offloaded(
|
|
||||||
cpu_block_id, layer_id, k_cache, num_valid_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset both policies."""
|
|
||||||
self.prefill_policy.reset()
|
|
||||||
self.decode_policy.reset()
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (
|
|
||||||
f"HybridPolicy(\n"
|
|
||||||
f" prefill={self.prefill_policy},\n"
|
|
||||||
f" decode={self.decode_policy}\n"
|
|
||||||
f")"
|
|
||||||
)
|
|
||||||
@@ -134,7 +134,7 @@ class SparsePolicy(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_block_offloaded(
|
def on_prefill_offload(
|
||||||
self,
|
self,
|
||||||
cpu_block_id: int,
|
cpu_block_id: int,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
@@ -142,15 +142,38 @@ class SparsePolicy(ABC):
|
|||||||
num_valid_tokens: int,
|
num_valid_tokens: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Hook called when a block is offloaded from GPU to CPU.
|
Hook called when a block is offloaded during prefill phase.
|
||||||
|
|
||||||
|
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
|
||||||
Override this to collect metadata about blocks (e.g., min/max keys
|
Override this to collect metadata about blocks (e.g., min/max keys
|
||||||
for Quest-style selection). Default implementation does nothing.
|
for Quest-style selection). Default implementation does nothing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cpu_block_id: The CPU block ID that was written
|
cpu_block_id: The CPU block ID that will be written
|
||||||
layer_id: Transformer layer index
|
layer_id: Transformer layer index
|
||||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
|
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
|
||||||
|
num_valid_tokens: Number of valid tokens in this block
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_decode_offload(
|
||||||
|
self,
|
||||||
|
cpu_block_id: int,
|
||||||
|
layer_id: int,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
num_valid_tokens: int,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Hook called when a block is offloaded during decode phase.
|
||||||
|
|
||||||
|
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
|
||||||
|
Override this to update metadata about blocks. Default implementation
|
||||||
|
does nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cpu_block_id: The CPU block ID that will be written
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
|
||||||
num_valid_tokens: Number of valid tokens in this block
|
num_valid_tokens: Number of valid tokens in this block
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -289,14 +289,25 @@ class QuestPolicy(SparsePolicy):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def on_block_offloaded(
|
def on_prefill_offload(
|
||||||
self,
|
self,
|
||||||
cpu_block_id: int,
|
cpu_block_id: int,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
k_cache: torch.Tensor,
|
k_cache: torch.Tensor,
|
||||||
num_valid_tokens: int,
|
num_valid_tokens: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update min/max key metadata when block is offloaded."""
|
"""Update min/max key metadata during prefill offload."""
|
||||||
|
if self.metadata is not None:
|
||||||
|
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||||
|
|
||||||
|
def on_decode_offload(
|
||||||
|
self,
|
||||||
|
cpu_block_id: int,
|
||||||
|
layer_id: int,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
num_valid_tokens: int,
|
||||||
|
) -> None:
|
||||||
|
"""Update min/max key metadata during decode offload (for new blocks)."""
|
||||||
if self.metadata is not None:
|
if self.metadata is not None:
|
||||||
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||||
|
|
||||||
|
|||||||
@@ -189,7 +189,8 @@ class Attention(nn.Module):
|
|||||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||||
|
|
||||||
# Apply sparse policy if enabled
|
# Apply sparse policy if enabled
|
||||||
if cpu_block_table and kvcache_manager.sparse_policy is not None:
|
prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True)
|
||||||
|
if cpu_block_table and prefill_policy is not None:
|
||||||
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
|
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
|
||||||
policy_ctx = PolicyContext(
|
policy_ctx = PolicyContext(
|
||||||
query_chunk_idx=current_chunk_idx,
|
query_chunk_idx=current_chunk_idx,
|
||||||
@@ -200,7 +201,7 @@ class Attention(nn.Module):
|
|||||||
block_size=kvcache_manager.block_size,
|
block_size=kvcache_manager.block_size,
|
||||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||||
)
|
)
|
||||||
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
|
cpu_block_table = prefill_policy.select_blocks(
|
||||||
cpu_block_table, policy_ctx
|
cpu_block_table, policy_ctx
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -279,7 +280,11 @@ class Attention(nn.Module):
|
|||||||
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
||||||
if current_chunk_idx < len(cpu_block_ids):
|
if current_chunk_idx < len(cpu_block_ids):
|
||||||
cpu_block_id = cpu_block_ids[current_chunk_idx]
|
cpu_block_id = cpu_block_ids[current_chunk_idx]
|
||||||
offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
|
# k.shape[0] = number of tokens in current chunk
|
||||||
|
num_valid_tokens = k.shape[0]
|
||||||
|
offload_engine.offload_slot_layer_to_cpu(
|
||||||
|
write_slot, self.layer_id, cpu_block_id, num_valid_tokens
|
||||||
|
)
|
||||||
|
|
||||||
# CRITICAL: compute_stream must wait for offload to complete
|
# CRITICAL: compute_stream must wait for offload to complete
|
||||||
# before the next layer's store_kvcache can overwrite the GPU slot.
|
# before the next layer's store_kvcache can overwrite the GPU slot.
|
||||||
@@ -508,7 +513,8 @@ class Attention(nn.Module):
|
|||||||
last_block_valid_tokens = block_size # Last block was exactly full
|
last_block_valid_tokens = block_size # Last block was exactly full
|
||||||
|
|
||||||
# Apply sparse policy if enabled
|
# Apply sparse policy if enabled
|
||||||
if kvcache_manager.sparse_policy is not None:
|
decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False)
|
||||||
|
if decode_policy is not None:
|
||||||
policy_ctx = PolicyContext(
|
policy_ctx = PolicyContext(
|
||||||
query_chunk_idx=0,
|
query_chunk_idx=0,
|
||||||
num_query_chunks=1,
|
num_query_chunks=1,
|
||||||
@@ -518,7 +524,7 @@ class Attention(nn.Module):
|
|||||||
block_size=kvcache_manager.block_size,
|
block_size=kvcache_manager.block_size,
|
||||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||||
)
|
)
|
||||||
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
|
cpu_block_table = decode_policy.select_blocks(
|
||||||
cpu_block_table, policy_ctx
|
cpu_block_table, policy_ctx
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user