[WIP] Before adding Quest policy.

This commit is contained in:
Zijie Tian
2026-01-07 02:32:30 +08:00
parent f240903013
commit c99a6f3d3f
11 changed files with 166 additions and 191 deletions

View File

@@ -3,11 +3,6 @@ import time
from random import randint, seed
from nanovllm import LLM, SamplingParams
# Import sparse policy classes
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
def bench_decode(llm, num_seqs, input_len, output_len):
"""Benchmark decode performance (original test)"""
@@ -38,58 +33,6 @@ def bench_prefill(llm, num_seqs, input_len):
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
def setup_quest_policy(llm, topk_blocks=8, threshold_blocks=4):
    """
    Install a Quest sparse policy for the decode phase on a built LLM.

    Builds a HybridPolicy that uses FullAttentionPolicy for prefill and a
    QuestPolicy for decode, then registers it on the model runner's
    KV-cache manager via ``set_sparse_policy``.

    Args:
        llm: Engine object exposing ``model_runner.kvcache_manager``; that
            manager must provide ``offload_engine``, ``num_cpu_blocks`` and
            ``set_sparse_policy``.
        topk_blocks: Forwarded to ``QuestConfig(topk_blocks=...)``.
        threshold_blocks: Forwarded to ``QuestConfig(threshold_blocks=...)``.

    Returns:
        The HybridPolicy instance that was set on the KV-cache manager.
    """
    kvcache_manager = llm.model_runner.kvcache_manager
    offload_engine = kvcache_manager.offload_engine

    # Model geometry is read from the offload engine so the metadata
    # buffers are sized consistently with the CPU cache it owns.
    num_layers = offload_engine.num_layers
    num_kv_heads = offload_engine.num_kv_heads
    head_dim = offload_engine.head_dim
    num_cpu_blocks = kvcache_manager.num_cpu_blocks
    dtype = offload_engine.k_cache_cpu.dtype

    print("Setting up Quest policy:")
    print(f" num_layers={num_layers}, num_kv_heads={num_kv_heads}, head_dim={head_dim}")
    print(f" num_cpu_blocks={num_cpu_blocks}, dtype={dtype}")
    print(f" topk_blocks={topk_blocks}, threshold_blocks={threshold_blocks}")

    # Per-block key metadata (min/max keys) consumed by the Quest policy.
    metadata = BlockMetadataManager(
        num_blocks=num_cpu_blocks,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        dtype=dtype,
    )

    # Quest policy used for decode.
    quest_config = QuestConfig(
        topk_blocks=topk_blocks,
        threshold_blocks=threshold_blocks,
    )
    quest_policy = QuestPolicy(quest_config, metadata)

    # Hybrid wrapper: full attention for prefill, Quest for decode.
    hybrid_policy = HybridPolicy(
        prefill_policy=FullAttentionPolicy(),
        decode_policy=quest_policy,
    )

    kvcache_manager.set_sparse_policy(hybrid_policy)
    print(" Policy set: HybridPolicy(prefill=Full, decode=Quest)")
    return hybrid_policy
def main():
import argparse
parser = argparse.ArgumentParser()
@@ -101,7 +44,18 @@ def main():
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
# Note: Qwen3-4B-Instruct-2507 max_position_embeddings = 262144
max_len = 131072 # 128K tokens
max_len = 32 * 1024 # 32K tokens
# Setup policy configuration
if not args.no_sparse:
prefill_policy = "full" # Full attention for prefill
decode_policy = "quest" # Quest Top-K for decode
print(f"\n[Quest Sparse Attention] prefill={prefill_policy}, decode={decode_policy}, topk={args.topk}")
else:
prefill_policy = "full" # Full attention for both phases
decode_policy = "full"
print("\n[Full Attention] No sparse policy (baseline)")
llm = LLM(
path,
enforce_eager=False,
@@ -109,15 +63,12 @@ def main():
max_num_batched_tokens=max_len,
enable_cpu_offload=True,
num_gpu_blocks=6, # Small GPU buffer for offload testing
prefill_policy=prefill_policy,
decode_policy=decode_policy,
sparse_topk_blocks=args.topk,
sparse_threshold_blocks=4,
)
if not args.no_sparse:
# Setup Quest policy for decode (Top-K blocks, apply when > 4 blocks)
setup_quest_policy(llm, topk_blocks=args.topk, threshold_blocks=4)
print(f"\n[Quest Sparse Attention] topk={args.topk}")
else:
print("\n[Full Attention] No sparse policy (baseline)")
# Warmup
llm.generate(["Benchmark: "], SamplingParams())

View File

@@ -29,8 +29,9 @@ class Config:
num_gpu_kvcache_blocks: int = -1
num_cpu_kvcache_blocks: int = -1
# Sparse attention configuration
sparse_policy: str | None = None # "vertical_slash", "quest", "streaming_llm", or None
# Sparse attention configuration (dual policy architecture)
prefill_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm"
decode_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm"
sparse_num_sink_blocks: int = 1 # Number of sink blocks for sparse patterns
sparse_local_window_blocks: int = 2 # Local window size for VerticalSlash
sparse_topk_blocks: int = 8 # Top-K blocks for Quest

View File

@@ -156,6 +156,25 @@ class ModelRunner:
dtype=hf_config.torch_dtype,
)
# Initialize sparse policies if manager has them (CPU offload mode)
if hasattr(self.kvcache_manager, 'prefill_policy') and hasattr(self.kvcache_manager, 'decode_policy'):
# Initialize both policies with model config
for policy in [self.kvcache_manager.prefill_policy, self.kvcache_manager.decode_policy]:
if policy is not None:
policy.initialize(
num_layers=hf_config.num_hidden_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
num_cpu_blocks=config.num_cpu_kvcache_blocks,
dtype=hf_config.torch_dtype,
device=torch.device("cuda"),
)
logger.info(
f"Sparse policies initialized: prefill={config.prefill_policy}, decode={config.decode_policy} "
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
)
# Log KV cache allocation info with detailed per-token breakdown
gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)

View File

@@ -56,14 +56,36 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
# Need CPU offload: use hybrid manager
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.policies import get_policy
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
policy = get_policy(getattr(config, 'offload_policy', 'lru'))
eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru'))
# Create sparse policies from config
prefill_policy_type = getattr(config, 'prefill_policy', 'full')
decode_policy_type = getattr(config, 'decode_policy', 'full')
def create_policy(policy_type_str):
    """
    Build the sparse policy named by a config string.

    ``"full"`` (case-insensitive) yields the FULL policy with no extra
    arguments. Any other name is upper-cased and looked up by member name
    in ``SparsePolicyType``; the policy is then created with the sparse
    tuning values read from the enclosing ``config``.
    """
    if policy_type_str.lower() != 'full':
        selected_type = SparsePolicyType[policy_type_str.upper()]
        # Tuning values fall back to these defaults when config lacks them.
        # NOTE(review): sparse_num_sink_blocks is passed as
        # include_sink_blocks - confirm the factory expects a block count.
        tuning = {
            'topk_blocks': getattr(config, 'sparse_topk_blocks', 8),
            'threshold_blocks': getattr(config, 'sparse_threshold_blocks', 4),
            'include_sink_blocks': getattr(config, 'sparse_num_sink_blocks', 1),
        }
        return create_sparse_policy(selected_type, **tuning)
    return create_sparse_policy(SparsePolicyType.FULL)
prefill_policy = create_policy(prefill_policy_type)
decode_policy = create_policy(decode_policy_type)
return HybridKVCacheManager(
num_gpu_slots=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=config.kvcache_block_size,
policy=policy,
policy=eviction_policy,
prefill_policy=prefill_policy,
decode_policy=decode_policy,
)

View File

@@ -90,6 +90,8 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: int,
block_size: int,
policy: Optional[EvictionPolicy] = None,
prefill_policy: "SparsePolicy" = None,
decode_policy: "SparsePolicy" = None,
):
"""
Initialize hybrid manager with CPU-primary ring buffer design.
@@ -102,6 +104,8 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: Number of CPU pool blocks (primary storage)
block_size: Tokens per block
policy: Eviction policy (default: LRU, used for prefix cache management)
prefill_policy: Sparse attention policy for prefill phase
decode_policy: Sparse attention policy for decode phase
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
@@ -113,6 +117,10 @@ class HybridKVCacheManager(KVCacheManager):
# Eviction policy
self.policy = policy or LRUPolicy()
# Sparse attention policies (set at construction time, immutable)
self.prefill_policy = prefill_policy
self.decode_policy = decode_policy
# Logical blocks (what sequences reference) - one per CPU block
self.logical_blocks: List[LogicalBlock] = [
LogicalBlock(i) for i in range(self.total_blocks)
@@ -153,9 +161,6 @@ class HybridKVCacheManager(KVCacheManager):
# Key: sequence id, Value: number of tokens from prefill (before decode started)
self._prefill_len: Dict[int, int] = {}
# Sparse attention policy (optional)
self.sparse_policy: Optional["SparsePolicy"] = None
@property
def block_size(self) -> int:
return self._block_size
@@ -180,6 +185,8 @@ class HybridKVCacheManager(KVCacheManager):
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
prefill_policy=self.prefill_policy,
decode_policy=self.decode_policy,
)
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -187,23 +194,17 @@ class HybridKVCacheManager(KVCacheManager):
assert self.offload_engine is not None
return self.offload_engine.get_layer_cache(layer_id)
def set_sparse_policy(self, policy: "SparsePolicy") -> None:
def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]:
"""
Set sparse attention policy for block selection.
The sparse policy determines which KV blocks to load from CPU
for each query chunk during chunked attention computation.
Get sparse policy for the specified phase.
Args:
policy: SparsePolicy instance (e.g., VerticalSlashPolicy, QuestPolicy)
is_prefill: True for prefill phase, False for decode phase
Example:
from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig
policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=2))
manager.set_sparse_policy(policy)
Returns:
SparsePolicy for the phase, or None if not set
"""
self.sparse_policy = policy
logger.info(f"Sparse attention policy set: {policy}")
return self.prefill_policy if is_prefill else self.decode_policy
def can_allocate(self, seq: Sequence) -> bool:
"""Check if we can allocate blocks for a new sequence."""

View File

@@ -17,6 +17,11 @@ from nanovllm.kvcache.kernels import gathered_copy_kv
from nanovllm.comm import memcpy_2d_async
from nanovllm.utils.logger import get_logger
# Import for type hints only (avoid circular import)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from nanovllm.kvcache.sparse import SparsePolicy
logger = get_logger("offload_engine")
@@ -55,6 +60,8 @@ class OffloadEngine:
head_dim: int,
dtype: torch.dtype = torch.float16,
num_streams: int = 4,
prefill_policy: "SparsePolicy" = None,
decode_policy: "SparsePolicy" = None,
):
self.num_layers = num_layers
self.num_gpu_blocks = num_gpu_blocks
@@ -210,6 +217,10 @@ class OffloadEngine:
self._debug_mode = False
self._debug_hooks: List = [] # External hooks for debug events
# ========== Sparse attention policies (set at construction time) ==========
self.prefill_policy = prefill_policy
self.decode_policy = decode_policy
def _get_next_stream(self) -> torch.cuda.Stream:
"""Round-robin stream selection for parallel transfers."""
stream = self.transfer_streams[self._stream_idx]
@@ -730,7 +741,14 @@ class OffloadEngine:
"""Wait for slot offload to complete."""
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
def offload_slot_layer_to_cpu(
self,
slot_idx: int,
layer_id: int,
cpu_block_id: int,
num_valid_tokens: int = -1,
is_prefill: bool = True,
) -> None:
"""
Async offload a ring buffer slot to CPU for one layer.
@@ -741,9 +759,27 @@ class OffloadEngine:
slot_idx: Source GPU slot index
layer_id: Target layer in CPU cache
cpu_block_id: Target CPU block ID
num_valid_tokens: Number of valid tokens in this block (-1 = use block_size)
is_prefill: True if in prefill phase, False if in decode phase
"""
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
# Collect metadata BEFORE offload (while k_cache is still on GPU)
# Both policies' callbacks are called - each decides whether to respond
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
k_cache = self.k_cache_gpu[slot_idx]
if is_prefill:
if self.prefill_policy is not None:
self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
if self.decode_policy is not None:
self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
else:
if self.prefill_policy is not None:
self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
if self.decode_policy is not None:
self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for both compute_stream and default stream

View File

@@ -22,7 +22,6 @@ Usage:
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
@@ -67,6 +66,5 @@ __all__ = [
"QuestPolicy",
"QuestConfig",
"BlockMetadataManager",
"HybridPolicy",
"create_sparse_policy",
]

View File

@@ -1,93 +0,0 @@
"""
Hybrid sparse attention policy.
Allows using different policies for prefill vs decode phases.
This is useful because optimal sparsity patterns often differ:
- Prefill: fixed patterns work well (e.g., VerticalSlash)
- Decode: query-aware selection helps (e.g., Quest)
"""
from typing import List
import torch
from .policy import SparsePolicy, PolicyContext
class HybridPolicy(SparsePolicy):
    """
    Sparse policy that dispatches on the attention phase.

    Wraps two policies. Block selection is delegated to exactly one of them
    depending on ``ctx.is_prefill``; offload notifications and resets are
    forwarded to both.

    Example usage:
    ```python
    from nanovllm.kvcache.sparse import (
        HybridPolicy, VerticalSlashPolicy, QuestPolicy,
        VerticalSlashConfig, QuestConfig, BlockMetadataManager
    )

    # Prefill: fixed pattern
    prefill_policy = VerticalSlashPolicy(VerticalSlashConfig(
        num_sink_blocks=1,
        local_window_blocks=3,
    ))

    # Decode: query-aware selection
    metadata = BlockMetadataManager(num_blocks, num_layers, num_heads, head_dim)
    decode_policy = QuestPolicy(QuestConfig(topk_blocks=8), metadata)

    policy = HybridPolicy(prefill_policy, decode_policy)
    ```
    """

    def __init__(
        self,
        prefill_policy: SparsePolicy,
        decode_policy: SparsePolicy,
    ):
        """
        Store the two wrapped policies.

        Args:
            prefill_policy: Policy consulted while ``ctx.is_prefill`` is true
            decode_policy: Policy consulted otherwise
        """
        self.prefill_policy, self.decode_policy = prefill_policy, decode_policy

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """Return the selection made by the policy matching the current phase."""
        phase_policy = self.prefill_policy if ctx.is_prefill else self.decode_policy
        return phase_policy.select_blocks(available_blocks, ctx)

    def on_block_offloaded(
        self,
        cpu_block_id: int,
        layer_id: int,
        k_cache: torch.Tensor,
        num_valid_tokens: int,
    ) -> None:
        """Notify both wrapped policies, prefill first, of the offloaded block."""
        for wrapped in (self.prefill_policy, self.decode_policy):
            wrapped.on_block_offloaded(
                cpu_block_id, layer_id, k_cache, num_valid_tokens
            )

    def reset(self) -> None:
        """Reset the prefill policy, then the decode policy."""
        for wrapped in (self.prefill_policy, self.decode_policy):
            wrapped.reset()

    def __repr__(self) -> str:
        # Multi-line form: one wrapped policy per line.
        pieces = (
            "HybridPolicy(",
            f" prefill={self.prefill_policy},",
            f" decode={self.decode_policy}",
            ")",
        )
        return "\n".join(pieces)

View File

@@ -134,7 +134,7 @@ class SparsePolicy(ABC):
"""
pass
def on_block_offloaded(
def on_prefill_offload(
self,
cpu_block_id: int,
layer_id: int,
@@ -142,15 +142,38 @@ class SparsePolicy(ABC):
num_valid_tokens: int,
) -> None:
"""
Hook called when a block is offloaded from GPU to CPU.
Hook called when a block is offloaded during prefill phase.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
Override this to collect metadata about blocks (e.g., min/max keys
for Quest-style selection). Default implementation does nothing.
Args:
cpu_block_id: The CPU block ID that was written
cpu_block_id: The CPU block ID that will be written
layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
"""
pass
def on_decode_offload(
    self,
    cpu_block_id: int,
    layer_id: int,
    k_cache: torch.Tensor,
    num_valid_tokens: int,
) -> None:
    """
    Decode-phase offload hook; this base implementation is a no-op.

    Invoked BEFORE the GPU→CPU copy, while k_cache is still on GPU.
    Subclasses may override it to refresh whatever per-block metadata
    they keep.

    Args:
        cpu_block_id: The CPU block ID that will be written
        layer_id: Transformer layer index
        k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
        num_valid_tokens: Number of valid tokens in this block
    """
    return None

View File

@@ -289,14 +289,25 @@ class QuestPolicy(SparsePolicy):
return result
def on_block_offloaded(
def on_prefill_offload(
self,
cpu_block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""Update min/max key metadata when block is offloaded."""
"""Update min/max key metadata during prefill offload."""
if self.metadata is not None:
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
def on_decode_offload(
    self,
    cpu_block_id: int,
    layer_id: int,
    k_cache: torch.Tensor,
    num_valid_tokens: int,
) -> None:
    """
    Refresh min/max key metadata for a block offloaded during decode.

    Forwards all arguments to ``self.metadata.update_metadata``; does
    nothing when no metadata manager is attached (``self.metadata`` is None).
    """
    tracker = self.metadata
    if tracker is None:
        return
    tracker.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)

View File

@@ -189,7 +189,8 @@ class Attention(nn.Module):
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Apply sparse policy if enabled
if cpu_block_table and kvcache_manager.sparse_policy is not None:
prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True)
if cpu_block_table and prefill_policy is not None:
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
@@ -200,7 +201,7 @@ class Attention(nn.Module):
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
cpu_block_table = prefill_policy.select_blocks(
cpu_block_table, policy_ctx
)
@@ -279,7 +280,11 @@ class Attention(nn.Module):
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
if current_chunk_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[current_chunk_idx]
offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
# k.shape[0] = number of tokens in current chunk
num_valid_tokens = k.shape[0]
offload_engine.offload_slot_layer_to_cpu(
write_slot, self.layer_id, cpu_block_id, num_valid_tokens
)
# CRITICAL: compute_stream must wait for offload to complete
# before the next layer's store_kvcache can overwrite the GPU slot.
@@ -508,7 +513,8 @@ class Attention(nn.Module):
last_block_valid_tokens = block_size # Last block was exactly full
# Apply sparse policy if enabled
if kvcache_manager.sparse_policy is not None:
decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False)
if decode_policy is not None:
policy_ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
@@ -518,7 +524,7 @@ class Attention(nn.Module):
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
cpu_block_table = decode_policy.select_blocks(
cpu_block_table, policy_ctx
)