[feat] Add pluggable sparse KV cache attention policies (needs verification).
bench.py (8 lines changed)

@@ -43,17 +43,17 @@ def main():
     print("=" * 60)
     print("Prefill Benchmark")
     print("=" * 60)
-    bench_prefill(llm, num_seqs=1, input_len=1024)
+    # bench_prefill(llm, num_seqs=1, input_len=1024)
     # bench_prefill(llm, num_seqs=1, input_len=2048)
-    # bench_prefill(llm, num_seqs=1, input_len=4095)
+    bench_prefill(llm, num_seqs=1, input_len=4095)
     # bench_prefill(llm, num_seqs=16, input_len=1024)
     # bench_prefill(llm, num_seqs=64, input_len=1024)

     print("=" * 60)
     print("Decode Benchmark")
     print("=" * 60)
-    bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=1024)
-    # bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024)
+    # bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=1024)
+    bench_decode(llm, num_seqs=1, max_input_len=4072, max_output_len=16)


 if __name__ == "__main__":

@@ -3,12 +3,17 @@ import time
 from random import randint, seed
 from nanovllm import LLM, SamplingParams

+# Import sparse policy classes
+from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
+from nanovllm.kvcache.sparse.hybrid import HybridPolicy
+from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
+

 def bench_decode(llm, num_seqs, input_len, max_output_len):
     """Benchmark decode performance (original test)"""
     seed(0)
-    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, input_len))] for _ in range(num_seqs)]
-    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
+    prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
+    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=max_output_len) for _ in range(num_seqs)]

     t = time.time()
     llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)

@@ -33,7 +38,67 @@ def bench_prefill(llm, num_seqs, input_len):
     print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")


+
+def setup_quest_policy(llm, topk_blocks=8, threshold_blocks=4):
+    """
+    Setup Quest sparse policy for decode phase.
+
+    Uses HybridPolicy: Full attention for prefill, Quest Top-K for decode.
+    """
+    import torch
+
+    kvcache_manager = llm.model_runner.kvcache_manager
+    offload_engine = kvcache_manager.offload_engine
+
+    # Get model parameters from offload engine
+    num_layers = offload_engine.num_layers
+    num_kv_heads = offload_engine.num_kv_heads
+    head_dim = offload_engine.head_dim
+    num_cpu_blocks = kvcache_manager.num_cpu_blocks
+    dtype = offload_engine.k_cache_cpu.dtype
+
+    print(f"Setting up Quest policy:")
+    print(f"  num_layers={num_layers}, num_kv_heads={num_kv_heads}, head_dim={head_dim}")
+    print(f"  num_cpu_blocks={num_cpu_blocks}, dtype={dtype}")
+    print(f"  topk_blocks={topk_blocks}, threshold_blocks={threshold_blocks}")
+
+    # Create BlockMetadataManager for storing min/max keys
+    metadata = BlockMetadataManager(
+        num_blocks=num_cpu_blocks,
+        num_layers=num_layers,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        dtype=dtype,
+    )
+
+    # Create Quest policy for decode
+    quest_config = QuestConfig(
+        topk_blocks=topk_blocks,
+        threshold_blocks=threshold_blocks,
+    )
+    quest_policy = QuestPolicy(quest_config, metadata)
+
+    # Create Hybrid policy: Full for prefill, Quest for decode
+    hybrid_policy = HybridPolicy(
+        prefill_policy=FullAttentionPolicy(),
+        decode_policy=quest_policy,
+    )
+
+    # Set the policy
+    kvcache_manager.set_sparse_policy(hybrid_policy)
+    print(f"  Policy set: HybridPolicy(prefill=Full, decode=Quest)")
+
+    return hybrid_policy
+
+
 def main():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--no-sparse", action="store_true", help="Disable sparse attention (baseline)")
+    parser.add_argument("--topk", type=int, default=8, help="Top-K blocks for Quest")
+    parser.add_argument("--input-len", type=int, default=128 * 1024, help="Input length in tokens")
+    parser.add_argument("--output-len", type=int, default=128, help="Output length in tokens")
+    args = parser.parse_args()
+
     path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
     llm = LLM(
         path,

@@ -45,22 +110,25 @@ def main():
         num_prefetch_blocks=4,
     )

+    if not args.no_sparse:
+        # Setup Quest policy for decode (Top-K blocks, apply when > 4 blocks)
+        setup_quest_policy(llm, topk_blocks=args.topk, threshold_blocks=4)
+        print(f"\n[Quest Sparse Attention] topk={args.topk}")
+    else:
+        print("\n[Full Attention] No sparse policy (baseline)")
+
     # Warmup
     llm.generate(["Benchmark: "], SamplingParams())

     print("=" * 60)
     print("Prefill Benchmark (CPU Offload)")
     print("=" * 60)
-    # bench_prefill(llm, num_seqs=1, input_len=1024)
-    # bench_prefill(llm, num_seqs=1, input_len=2048)
-    # bench_prefill(llm, num_seqs=1, input_len=4096)
-    bench_prefill(llm, num_seqs=1, input_len=128 * 1024)
+    bench_prefill(llm, num_seqs=1, input_len=args.input_len)

     print("=" * 60)
     print("Decode Benchmark (CPU Offload)")
     print("=" * 60)
-    bench_decode(llm, num_seqs=1, input_len=128 * 1024, max_output_len=128)
-    # bench_decode(llm, num_seqs=1, input_len=2048, max_output_len=128)
+    bench_decode(llm, num_seqs=1, input_len=args.input_len, max_output_len=args.output_len)


 if __name__ == "__main__":

@@ -28,6 +28,13 @@ class Config:
     num_gpu_kvcache_blocks: int = -1
     num_cpu_kvcache_blocks: int = -1
+
+    # Sparse attention configuration
+    sparse_policy: str | None = None  # "vertical_slash", "quest", "streaming_llm", or None
+    sparse_num_sink_blocks: int = 1  # Number of sink blocks for sparse patterns
+    sparse_local_window_blocks: int = 2  # Local window size for VerticalSlash
+    sparse_topk_blocks: int = 8  # Top-K blocks for Quest
+    sparse_threshold_blocks: int = 4  # Apply sparse only when blocks > threshold

     def __post_init__(self):
         assert os.path.isdir(self.model)
         assert self.kvcache_block_size % 256 == 0
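
Nothing in this commit wires these Config fields to a policy yet. A hedged sketch of how they could feed the factory added in nanovllm/kvcache/sparse/__init__.py below; the function name and call site are assumptions, not part of the diff:

# Hypothetical wiring (not in this commit): build a policy from the Config fields.
from nanovllm.kvcache.sparse import get_sparse_policy

def policy_from_config(config):
    # config.sparse_policy is None for the full-attention baseline
    if config.sparse_policy is None:
        return None
    # Note: "quest" raises ValueError in the factory because it needs a BlockMetadataManager.
    return get_sparse_policy(
        config.sparse_policy,
        num_sink_blocks=config.sparse_num_sink_blocks,
        local_window_blocks=config.sparse_local_window_blocks,
        threshold_blocks=config.sparse_threshold_blocks,
    )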

@@ -700,6 +700,20 @@ class ModelRunner:
             # Offload this chunk's ring buffer slot to CPU (async)
             if block_idx < len(cpu_block_ids):
                 cpu_block_id = cpu_block_ids[block_idx]
+
+                # Call sparse policy hook before offload (to capture metadata)
+                sparse_policy = self.kvcache_manager.sparse_policy
+                if sparse_policy is not None:
+                    num_tokens = chunk_end - chunk_start
+                    for layer_id in range(offload_engine.num_layers):
+                        k_cache = offload_engine.k_cache_gpu[layer_id, write_slot, :num_tokens]
+                        sparse_policy.on_block_offloaded(
+                            cpu_block_id=cpu_block_id,
+                            layer_id=layer_id,
+                            k_cache=k_cache,
+                            num_valid_tokens=num_tokens,
+                        )
+
                 offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)

             # Wait for offload to complete before next chunk
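
The hook above is invoked once per layer for each block that leaves the GPU, before the copy to CPU is issued. A minimal standalone sketch of a policy that relies only on this hook; the class and its selection criterion are illustrative, not from the diff:

# Illustrative sketch: record how many tokens each offloaded block holds,
# and prefer blocks that are completely filled.
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext

class BlockSizeTrackingPolicy(SparsePolicy):
    def __init__(self):
        self.tokens_per_block = {}

    def on_block_offloaded(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        # Called once per layer for every block offloaded to CPU.
        self.tokens_per_block[cpu_block_id] = num_valid_tokens

    def select_blocks(self, available_blocks, ctx):
        # Toy criterion: keep only blocks known to be full; fall back to all blocks.
        full = [b for b in available_blocks
                if self.tokens_per_block.get(b, 0) >= ctx.block_size]
        return full or available_blocks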

@@ -25,6 +25,11 @@ from nanovllm.kvcache.offload_engine import OffloadEngine
 from nanovllm.kvcache.policies.base_policy import EvictionPolicy
 from nanovllm.kvcache.policies.lru_policy import LRUPolicy

+# Type checking import for sparse policy
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from nanovllm.kvcache.sparse.policy import SparsePolicy
+

 class BlockLocation(Enum):
     """Where a logical block's data currently resides."""

@@ -142,6 +147,9 @@ class HybridKVCacheManager(KVCacheManager):
         # Key: sequence id, Value: starting position where decode began in current block
         self._decode_start_pos: Dict[int, int] = {}

+        # Sparse attention policy (optional)
+        self.sparse_policy: Optional["SparsePolicy"] = None
+
     @property
     def block_size(self) -> int:
         return self._block_size

@@ -174,6 +182,24 @@ class HybridKVCacheManager(KVCacheManager):
         assert self.offload_engine is not None
         return self.offload_engine.get_layer_cache(layer_id)

+    def set_sparse_policy(self, policy: "SparsePolicy") -> None:
+        """
+        Set sparse attention policy for block selection.
+
+        The sparse policy determines which KV blocks to load from CPU
+        for each query chunk during chunked attention computation.
+
+        Args:
+            policy: SparsePolicy instance (e.g., VerticalSlashPolicy, QuestPolicy)
+
+        Example:
+            from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig
+            policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=2))
+            manager.set_sparse_policy(policy)
+        """
+        self.sparse_policy = policy
+        logger.info(f"Sparse attention policy set: {policy}")
+
     def _allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> int:
         """
         Get a free GPU slot, evicting if necessary.
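
From user code, set_sparse_policy is reachable through the same attribute path the benchmark's setup_quest_policy uses (llm.model_runner.kvcache_manager). A minimal sketch, with engine kwargs omitted and the model path taken from the benchmark above:

import os
from nanovllm import LLM
from nanovllm.kvcache.sparse import StreamingLLMPolicy, StreamingLLMConfig

llm = LLM(os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"))  # plus engine kwargs as needed
policy = StreamingLLMPolicy(StreamingLLMConfig(num_sink_blocks=1, num_recent_blocks=3))
llm.model_runner.kvcache_manager.set_sparse_policy(policy)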

nanovllm/kvcache/sparse/__init__.py (new file, 90 lines)
@@ -0,0 +1,90 @@
"""
Sparse Attention Policy module.

Provides pluggable policies for selecting which KV blocks to load
during chunked attention with CPU offload.

Usage:
    from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext
    from nanovllm.kvcache.sparse import VerticalSlashPolicy, QuestPolicy

    # Use built-in policy
    policy = VerticalSlashPolicy(VerticalSlashConfig())

    # Or create custom policy
    class MyPolicy(SparsePolicy):
        def select_blocks(self, available_blocks, ctx):
            return available_blocks[:5]  # Just first 5 blocks
"""

from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
from nanovllm.kvcache.sparse.hybrid import HybridPolicy

# Built-in policy registry
BUILTIN_SPARSE_POLICIES = {
    "full": FullAttentionPolicy,
    "vertical_slash": VerticalSlashPolicy,
    "streaming_llm": StreamingLLMPolicy,
}


def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy:
    """
    Get a sparse attention policy instance by name.

    Args:
        policy_name: Policy name ("full", "vertical_slash", "streaming_llm", "quest")
        **kwargs: Policy-specific configuration

    Returns:
        SparsePolicy instance
    """
    policy_name = policy_name.lower()

    if policy_name == "full":
        return FullAttentionPolicy()
    elif policy_name == "vertical_slash":
        config = VerticalSlashConfig(
            num_sink_blocks=kwargs.get("num_sink_blocks", 1),
            local_window_blocks=kwargs.get("local_window_blocks", 2),
            threshold_blocks=kwargs.get("threshold_blocks", 4),
        )
        return VerticalSlashPolicy(config)
    elif policy_name == "streaming_llm":
        config = StreamingLLMConfig(
            num_sink_blocks=kwargs.get("num_sink_blocks", 1),
            num_recent_blocks=kwargs.get("num_recent_blocks", 3),
        )
        return StreamingLLMPolicy(config)
    elif policy_name == "quest":
        # Quest requires metadata_manager to be passed separately
        raise ValueError(
            "Quest policy requires BlockMetadataManager. "
            "Use QuestPolicy(config, metadata_manager) directly."
        )
    else:
        raise ValueError(
            f"Unknown sparse policy '{policy_name}'. "
            f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())}"
        )


__all__ = [
    "SparsePolicy",
    "PolicyContext",
    "FullAttentionPolicy",
    "VerticalSlashPolicy",
    "VerticalSlashConfig",
    "QuestPolicy",
    "QuestConfig",
    "BlockMetadataManager",
    "StreamingLLMPolicy",
    "StreamingLLMConfig",
    "HybridPolicy",
    "get_sparse_policy",
    "BUILTIN_SPARSE_POLICIES",
]
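
A short example of the factory above, using only names exported in __all__:

from nanovllm.kvcache.sparse import get_sparse_policy

policy = get_sparse_policy("vertical_slash", num_sink_blocks=2, local_window_blocks=3)
print(policy)  # VerticalSlashPolicy(sink=2, window=3, threshold=4)

# "quest" is deliberately not constructible by name: it needs a
# BlockMetadataManager, so the factory raises ValueError for it.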

nanovllm/kvcache/sparse/full_policy.py (new file, 34 lines)
@@ -0,0 +1,34 @@
"""
Full attention policy - loads all blocks (no sparsity).

This serves as a baseline and default policy when sparse
attention is not needed.
"""

from typing import List
from .policy import SparsePolicy, PolicyContext


class FullAttentionPolicy(SparsePolicy):
    """
    Full attention policy that loads all available blocks.

    This is the default behavior with no sparsity - all previous
    KV cache blocks are loaded for each query chunk.

    Use this as:
    - A baseline for comparing sparse policies
    - When you need full attention accuracy
    - For short sequences where sparsity isn't beneficial
    """

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """Return all blocks - no sparsity."""
        return available_blocks

    def __repr__(self) -> str:
        return "FullAttentionPolicy()"

nanovllm/kvcache/sparse/hybrid.py (new file, 93 lines)
@@ -0,0 +1,93 @@
"""
Hybrid sparse attention policy.

Allows using different policies for prefill vs decode phases.
This is useful because optimal sparsity patterns often differ:
- Prefill: fixed patterns work well (e.g., VerticalSlash)
- Decode: query-aware selection helps (e.g., Quest)
"""

from typing import List
import torch
from .policy import SparsePolicy, PolicyContext


class HybridPolicy(SparsePolicy):
    """
    Hybrid policy that uses different policies for prefill and decode.

    Example usage:
    ```python
    from nanovllm.kvcache.sparse import (
        HybridPolicy, VerticalSlashPolicy, QuestPolicy,
        VerticalSlashConfig, QuestConfig, BlockMetadataManager
    )

    # Prefill: use fast fixed pattern
    prefill_policy = VerticalSlashPolicy(VerticalSlashConfig(
        num_sink_blocks=1,
        local_window_blocks=3,
    ))

    # Decode: use query-aware selection
    metadata = BlockMetadataManager(num_blocks, num_layers, num_heads, head_dim)
    decode_policy = QuestPolicy(QuestConfig(topk_blocks=8), metadata)

    # Combine
    policy = HybridPolicy(prefill_policy, decode_policy)
    ```
    """

    def __init__(
        self,
        prefill_policy: SparsePolicy,
        decode_policy: SparsePolicy,
    ):
        """
        Initialize hybrid policy.

        Args:
            prefill_policy: Policy to use during prefill phase
            decode_policy: Policy to use during decode phase
        """
        self.prefill_policy = prefill_policy
        self.decode_policy = decode_policy

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """Delegate to appropriate policy based on phase."""
        if ctx.is_prefill:
            return self.prefill_policy.select_blocks(available_blocks, ctx)
        else:
            return self.decode_policy.select_blocks(available_blocks, ctx)

    def on_block_offloaded(
        self,
        cpu_block_id: int,
        layer_id: int,
        k_cache: torch.Tensor,
        num_valid_tokens: int,
    ) -> None:
        """Forward to both policies (both may need metadata updates)."""
        self.prefill_policy.on_block_offloaded(
            cpu_block_id, layer_id, k_cache, num_valid_tokens
        )
        self.decode_policy.on_block_offloaded(
            cpu_block_id, layer_id, k_cache, num_valid_tokens
        )

    def reset(self) -> None:
        """Reset both policies."""
        self.prefill_policy.reset()
        self.decode_policy.reset()

    def __repr__(self) -> str:
        return (
            f"HybridPolicy(\n"
            f"  prefill={self.prefill_policy},\n"
            f"  decode={self.decode_policy}\n"
            f")"
        )
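
The docstring shows the VerticalSlash + Quest combination; another composition of the same classes, matching what the CPU-offload benchmark does for prefill but with a sink+recent window for decode:

from nanovllm.kvcache.sparse import (
    HybridPolicy, FullAttentionPolicy, StreamingLLMPolicy, StreamingLLMConfig,
)

# Full attention during prefill, StreamingLLM-style sink+recent window during decode.
policy = HybridPolicy(
    prefill_policy=FullAttentionPolicy(),
    decode_policy=StreamingLLMPolicy(StreamingLLMConfig(num_sink_blocks=1, num_recent_blocks=3)),
)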

nanovllm/kvcache/sparse/policy.py (new file, 124 lines)
@@ -0,0 +1,124 @@
"""
Base class for sparse attention policies.

Sparse attention policies determine which KV cache blocks to load
from CPU for each query chunk during chunked attention computation.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Any
import torch


@dataclass
class PolicyContext:
    """
    Context passed to sparse policy for block selection.

    This dataclass contains all information needed by a sparse policy
    to decide which blocks to load for the current query chunk.
    """

    query_chunk_idx: int
    """Index of the current query chunk (0-indexed)."""

    num_query_chunks: int
    """Total number of query chunks in this prefill."""

    layer_id: int
    """Current transformer layer index."""

    query: Optional[torch.Tensor]
    """
    Query tensor for current chunk.
    Shape: [1, num_heads, head_dim] for decode, [1, seq_len, num_heads, head_dim] for prefill.
    May be None if not available (e.g., some prefill scenarios).
    """

    is_prefill: bool
    """True if in prefill phase, False if in decode phase."""

    block_size: int = 4096
    """Number of tokens per block."""

    total_kv_len: int = 0
    """Total KV sequence length so far (for reference)."""


class SparsePolicy(ABC):
    """
    Abstract base class for sparse attention policies.

    Subclass this and implement select_blocks() to create custom
    sparse attention patterns. The policy receives context about
    the current query chunk and returns which KV blocks to load.

    Example:
        class MySparsePolicy(SparsePolicy):
            def select_blocks(self, available_blocks, ctx):
                # Load first block and last 2 blocks
                if len(available_blocks) <= 3:
                    return available_blocks
                return [available_blocks[0]] + available_blocks[-2:]
    """

    @abstractmethod
    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select which KV blocks to load for the current query chunk.

        This is the core method that defines the sparse attention pattern.
        The returned blocks will be loaded from CPU to GPU for attention
        computation against the current query chunk.

        Args:
            available_blocks: List of CPU block IDs that contain KV cache
                              from previous chunks. These are ordered by
                              their position in the sequence.
            ctx: PolicyContext with information about the current query
                 chunk, layer, phase (prefill/decode), etc.

        Returns:
            List of block IDs to load (must be a subset of available_blocks).
            The order may affect performance (sequential access is faster).
            Returning [] means no previous blocks will be loaded.
        """
        pass

    def on_block_offloaded(
        self,
        cpu_block_id: int,
        layer_id: int,
        k_cache: torch.Tensor,
        num_valid_tokens: int,
    ) -> None:
        """
        Hook called when a block is offloaded from GPU to CPU.

        Override this to collect metadata about blocks (e.g., min/max keys
        for Quest-style selection). Default implementation does nothing.

        Args:
            cpu_block_id: The CPU block ID that was written
            layer_id: Transformer layer index
            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
            num_valid_tokens: Number of valid tokens in this block
        """
        pass

    def reset(self) -> None:
        """
        Reset policy state.

        Called when starting a new sequence or clearing state.
        Default implementation does nothing.
        """
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
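
A self-contained sketch of how a caller builds a PolicyContext and drives a policy by hand, mirroring what the attention layer does for a decode step (tensor shape follows the PolicyContext docstring; the concrete numbers are illustrative):

import torch
from nanovllm.kvcache.sparse.policy import PolicyContext
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig

policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=1, local_window_blocks=2))
ctx = PolicyContext(
    query_chunk_idx=0,
    num_query_chunks=1,
    layer_id=0,
    query=torch.randn(1, 8, 128),  # decode-shaped query: [1, num_heads, head_dim]
    is_prefill=False,
    block_size=4096,
    total_kv_len=10 * 4096,
)
print(policy.select_blocks(list(range(10)), ctx))  # -> [0, 8, 9]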

nanovllm/kvcache/sparse/quest.py (new file, 284 lines)
@@ -0,0 +1,284 @@
"""
Quest-style sparse attention policy.

Uses min/max key bounds per block to estimate attention scores
and select Top-K blocks most relevant to the current query.

Reference: Quest paper on query-aware KV cache selection.
"""

import logging
import torch
from dataclasses import dataclass
from typing import List, Tuple, Optional
from .policy import SparsePolicy, PolicyContext

logger = logging.getLogger(__name__)


class BlockMetadataManager:
    """
    Manages per-block metadata for Quest-style sparse selection.

    Stores min/max key values for each block, which are used to
    compute upper bounds on attention scores without loading the
    full KV cache.

    Memory usage: 2 * num_blocks * num_layers * num_kv_heads * head_dim * dtype_size
    Example: 1000 blocks, 28 layers, 4 heads, 128 dim, bf16 = ~57 MB
    """

    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Initialize metadata storage.

        Args:
            num_blocks: Maximum number of CPU blocks
            num_layers: Number of transformer layers
            num_kv_heads: Number of KV attention heads
            head_dim: Dimension per head
            dtype: Data type for metadata storage
        """
        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.dtype = dtype

        # Per-block min/max key values: [num_blocks, num_layers, num_heads, head_dim]
        shape = (num_blocks, num_layers, num_kv_heads, head_dim)
        self.key_min = torch.zeros(shape, dtype=dtype, pin_memory=True)
        self.key_max = torch.zeros(shape, dtype=dtype, pin_memory=True)

        # Track which blocks have valid metadata
        self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool)

    def update_metadata(
        self,
        block_id: int,
        layer_id: int,
        k_cache: torch.Tensor,
        num_valid_tokens: int,
    ) -> None:
        """
        Update min/max key bounds for a block.

        Called when a block is offloaded to CPU.

        Args:
            block_id: CPU block ID
            layer_id: Layer index
            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
            num_valid_tokens: Number of valid tokens in this block
        """
        if num_valid_tokens == 0:
            return

        # Get valid keys only
        k_valid = k_cache[:num_valid_tokens].cpu()  # [num_tokens, heads, dim]

        # Compute min/max across token dimension
        self.key_min[block_id, layer_id] = k_valid.min(dim=0).values
        self.key_max[block_id, layer_id] = k_valid.max(dim=0).values
        self.valid_blocks[block_id] = True

    def get_block_metadata(
        self,
        block_ids: List[int],
        layer_id: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get min/max keys for specified blocks.

        Args:
            block_ids: List of CPU block IDs
            layer_id: Layer index

        Returns:
            Tuple of (key_min, key_max) tensors
            Shape: [num_blocks, num_kv_heads, head_dim]
        """
        key_min = self.key_min[block_ids, layer_id]
        key_max = self.key_max[block_ids, layer_id]
        return key_min, key_max

    def reset(self) -> None:
        """Reset all metadata."""
        self.key_min.zero_()
        self.key_max.zero_()
        self.valid_blocks.zero_()


@dataclass
class QuestConfig:
    """Configuration for QuestPolicy."""

    topk_blocks: int = 8
    """Number of top blocks to select based on estimated attention scores."""

    threshold_blocks: int = 4
    """If total blocks <= threshold, load all (no scoring needed)."""

    include_sink_blocks: int = 0
    """Always include this many sink blocks (first N blocks), in addition to Top-K."""

    include_recent_blocks: int = 0
    """Always include this many recent blocks (last N blocks), in addition to Top-K."""


class QuestPolicy(SparsePolicy):
    """
    Quest-style Top-K block selection using min/max key bounds.

    For each query, computes an upper bound on attention scores for
    each block using the stored min/max keys, then selects the Top-K
    blocks with highest estimated scores.

    Score computation:
        score(q, block) = max(q · key_min, q · key_max)

    This upper bound is derived from the fact that for any key k in
    the block: min_k <= k <= max_k (element-wise), so the actual
    attention score is bounded by the maximum of the two extremes.
    """

    def __init__(
        self,
        config: QuestConfig,
        metadata_manager: BlockMetadataManager,
    ):
        """
        Initialize Quest policy.

        Args:
            config: QuestConfig with selection parameters
            metadata_manager: BlockMetadataManager for min/max key storage
        """
        self.config = config
        self.metadata = metadata_manager

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select Top-K blocks based on query-key similarity bounds.

        If query is not available (some prefill scenarios), falls back
        to loading all blocks.
        """
        n = len(available_blocks)

        # If below threshold or no query, load all
        if n <= self.config.threshold_blocks:
            return available_blocks

        if ctx.query is None:
            # No query available - cannot compute scores
            return available_blocks

        # Get metadata for available blocks
        key_min, key_max = self.metadata.get_block_metadata(
            available_blocks, ctx.layer_id
        )

        # Move to query device for computation
        device = ctx.query.device
        key_min = key_min.to(device, non_blocking=True)
        key_max = key_max.to(device, non_blocking=True)

        # Compute upper bound scores
        # query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
        q = ctx.query
        if q.dim() == 4:
            # Prefill: use mean over sequence length
            q = q.mean(dim=1)  # [1, num_heads, head_dim]
        q = q.squeeze(0)  # [num_q_heads, head_dim]

        # Handle GQA: query may have more heads than KV
        # key_min/key_max: [num_blocks, num_kv_heads, head_dim]
        num_q_heads = q.shape[0]
        num_kv_heads = key_min.shape[1]

        if num_q_heads != num_kv_heads:
            # GQA: group query heads and average per KV group
            # Reshape q: [num_q_heads, head_dim] -> [num_kv_heads, group_size, head_dim]
            group_size = num_q_heads // num_kv_heads
            q = q.view(num_kv_heads, group_size, -1).mean(dim=1)  # [num_kv_heads, head_dim]

        # Score: max(q·k_min, q·k_max) averaged over heads
        # key_min/key_max: [num_blocks, num_kv_heads, head_dim]
        # q: [num_kv_heads, head_dim]
        score_min = torch.einsum('hd,bhd->bh', q, key_min)  # [num_blocks, kv_heads]
        score_max = torch.einsum('hd,bhd->bh', q, key_max)
        scores = torch.maximum(score_min, score_max).mean(dim=-1)  # [num_blocks]

        # Build selection set
        selected_indices = set()

        # Always include sink blocks
        for i in range(min(self.config.include_sink_blocks, n)):
            selected_indices.add(i)

        # Always include recent blocks
        for i in range(max(0, n - self.config.include_recent_blocks), n):
            selected_indices.add(i)

        # Top-K selection from remaining
        remaining_k = max(0, self.config.topk_blocks - len(selected_indices))
        if remaining_k > 0:
            # Mask out already selected
            mask = torch.ones(n, dtype=torch.bool, device=device)
            for idx in selected_indices:
                mask[idx] = False

            if mask.any():
                masked_scores = scores.clone()
                masked_scores[~mask] = float('-inf')
                topk_count = min(remaining_k, mask.sum().item())
                if topk_count > 0:
                    topk_indices = masked_scores.topk(topk_count).indices.cpu().tolist()
                    selected_indices.update(topk_indices)

        # Return in sequential order for better memory access
        result = [available_blocks[i] for i in sorted(selected_indices)]

        # Log selection info (only for layer 0 to avoid spam)
        if ctx.layer_id == 0:
            logger.debug(
                f"Quest select: {len(result)}/{n} blocks "
                f"(topk={self.config.topk_blocks}, sink={self.config.include_sink_blocks}, "
                f"recent={self.config.include_recent_blocks})"
            )

        return result

    def on_block_offloaded(
        self,
        cpu_block_id: int,
        layer_id: int,
        k_cache: torch.Tensor,
        num_valid_tokens: int,
    ) -> None:
        """Update min/max key metadata when block is offloaded."""
        self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)

    def reset(self) -> None:
        """Reset metadata."""
        self.metadata.reset()

    def __repr__(self) -> str:
        return (
            f"QuestPolicy(topk={self.config.topk_blocks}, "
            f"threshold={self.config.threshold_blocks}, "
            f"sink={self.config.include_sink_blocks}, "
            f"recent={self.config.include_recent_blocks})"
        )
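
The scoring path inside select_blocks reduces to three tensor operations. A toy, self-contained illustration of exactly those operations on random data (shapes and values are illustrative only):

import torch

num_blocks, num_kv_heads, head_dim = 3, 2, 4
q = torch.ones(num_kv_heads, head_dim)             # query after squeeze/GQA grouping
key_min = torch.randn(num_blocks, num_kv_heads, head_dim)
key_max = key_min + torch.rand_like(key_min)        # ensure max >= min element-wise

score_min = torch.einsum('hd,bhd->bh', q, key_min)  # per-block, per-head bound with key_min
score_max = torch.einsum('hd,bhd->bh', q, key_max)  # per-block, per-head bound with key_max
scores = torch.maximum(score_min, score_max).mean(dim=-1)  # [num_blocks]
print(scores.topk(2).indices)  # indices of the 2 highest-scoring blocks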

nanovllm/kvcache/sparse/streaming_llm.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""
StreamingLLM sparse attention policy.

Only keeps sink tokens (beginning) + recent tokens (end).
Intermediate context is discarded. This enables infinite-length
generation but loses intermediate context.

Reference: StreamingLLM paper on attention sinks.
"""

from dataclasses import dataclass
from typing import List
from .policy import SparsePolicy, PolicyContext


@dataclass
class StreamingLLMConfig:
    """Configuration for StreamingLLMPolicy."""

    num_sink_blocks: int = 1
    """Number of blocks at the beginning to always include (attention sinks)."""

    num_recent_blocks: int = 3
    """Number of most recent blocks to include (sliding window)."""


class StreamingLLMPolicy(SparsePolicy):
    """
    StreamingLLM pattern: sink tokens + recent tokens only.

    This is the most aggressive sparsity pattern - only keeps a small
    fixed window of context. Suitable for:
    - Very long streaming generation
    - When intermediate context can be safely discarded
    - Maximizing throughput over accuracy

    Pattern visualization:
    ```
    Blocks: [0] [1] [2] [3] [4] [5] [6] [7] [8]
             ↑   ×   ×   ×   ×   ×   ↑   ↑   ↑
           sink      (discarded)     recent window
    ```

    Warning: This loses information from intermediate blocks!
    Use only when this trade-off is acceptable.
    """

    def __init__(self, config: StreamingLLMConfig = None):
        self.config = config or StreamingLLMConfig()

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select sink blocks + recent blocks only.

        Intermediate blocks are not loaded (effectively discarded).
        """
        n = len(available_blocks)

        # If total blocks fit in sink + recent, load all
        total_keep = self.config.num_sink_blocks + self.config.num_recent_blocks
        if n <= total_keep:
            return available_blocks

        selected_indices = set()

        # Sink blocks (first N)
        for i in range(min(self.config.num_sink_blocks, n)):
            selected_indices.add(i)

        # Recent blocks (last M)
        for i in range(max(0, n - self.config.num_recent_blocks), n):
            selected_indices.add(i)

        return [available_blocks[i] for i in sorted(selected_indices)]

    def __repr__(self) -> str:
        return (
            f"StreamingLLMPolicy(sink={self.config.num_sink_blocks}, "
            f"recent={self.config.num_recent_blocks})"
        )
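
A quick check of the selection behaviour above, using the same numbers as the unit test further down:

from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
from nanovllm.kvcache.sparse.policy import PolicyContext

policy = StreamingLLMPolicy(StreamingLLMConfig(num_sink_blocks=1, num_recent_blocks=2))
ctx = PolicyContext(query_chunk_idx=0, num_query_chunks=1, layer_id=0, query=None, is_prefill=False)
print(policy.select_blocks(list(range(10)), ctx))  # -> [0, 8, 9]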

nanovllm/kvcache/sparse/vertical_slash.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""
Vertical-Slash sparse attention policy (MInference-style).

Selects sink blocks (beginning of sequence) + local window blocks
(near the current query position). This pattern captures:
- Important initial context (system prompt, instructions)
- Recent context (relevant for local dependencies)
"""

from dataclasses import dataclass
from typing import List
from .policy import SparsePolicy, PolicyContext


@dataclass
class VerticalSlashConfig:
    """Configuration for VerticalSlashPolicy."""

    num_sink_blocks: int = 1
    """Number of blocks at the beginning to always include (sink tokens)."""

    local_window_blocks: int = 2
    """Number of blocks in the local window near current query position."""

    threshold_blocks: int = 4
    """If total blocks <= threshold, load all (no sparsity applied)."""


class VerticalSlashPolicy(SparsePolicy):
    """
    Vertical-Slash pattern: sink tokens + local window.

    This pattern is inspired by MInference and observations that:
    1. Initial tokens (sink) often receive high attention
    2. Local context (recent tokens) is important for dependencies

    Pattern visualization:
    ```
    Blocks: [0] [1] [2] [3] [4] [5] [6] [7] [8]
             ↑                       ↑   ↑   ↑
           sink            local window (for query at block 9)
    ```

    For prefill chunk K, the local window is blocks [K-window, K-1].
    For decode, the local window is the last N blocks.
    """

    def __init__(self, config: VerticalSlashConfig = None):
        self.config = config or VerticalSlashConfig()

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select sink blocks + local window blocks.

        For prefill: local window is relative to current chunk position.
        For decode: local window is the most recent blocks.
        """
        n = len(available_blocks)

        # If below threshold, load all
        if n <= self.config.threshold_blocks:
            return available_blocks

        selected_indices = set()

        # Sink blocks (first N blocks)
        for i in range(min(self.config.num_sink_blocks, n)):
            selected_indices.add(i)

        # Local window
        if ctx.is_prefill:
            # For prefill chunk K, local window is blocks [K-window, K-1]
            # (blocks before current chunk, not including current)
            window_end = min(ctx.query_chunk_idx, n)
            window_start = max(0, window_end - self.config.local_window_blocks)
            for i in range(window_start, window_end):
                selected_indices.add(i)
        else:
            # For decode, local window is the last M blocks
            for i in range(max(0, n - self.config.local_window_blocks), n):
                selected_indices.add(i)

        # Return blocks in order (maintains sequential access pattern)
        return [available_blocks[i] for i in sorted(selected_indices)]

    def __repr__(self) -> str:
        return (
            f"VerticalSlashPolicy(sink={self.config.num_sink_blocks}, "
            f"window={self.config.local_window_blocks}, "
            f"threshold={self.config.threshold_blocks})"
        )
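
A quick check of the prefill window logic above, using the same numbers as the unit test further down:

from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
from nanovllm.kvcache.sparse.policy import PolicyContext

policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=2, local_window_blocks=3, threshold_blocks=4))
ctx = PolicyContext(query_chunk_idx=7, num_query_chunks=10, layer_id=0, query=None, is_prefill=True)
print(policy.select_blocks(list(range(10)), ctx))  # -> [0, 1, 4, 5, 6]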

@@ -6,6 +6,7 @@ import triton.language as tl

 from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 from nanovllm.utils.context import get_context
+from nanovllm.kvcache.sparse.policy import PolicyContext

 logger = logging.getLogger(__name__)


@@ -133,6 +134,22 @@ class Attention(nn.Module):
         # Get prefilled CPU blocks (blocks from previous chunks)
         cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

+        # Apply sparse policy if enabled
+        if cpu_block_table and kvcache_manager.sparse_policy is not None:
+            num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
+            policy_ctx = PolicyContext(
+                query_chunk_idx=current_chunk_idx,
+                num_query_chunks=num_chunks,
+                layer_id=self.layer_id,
+                query=None,  # Prefill typically doesn't use query for selection
+                is_prefill=True,
+                block_size=kvcache_manager.block_size,
+                total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
+            )
+            cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
+                cpu_block_table, policy_ctx
+            )
+
         if cpu_block_table:
             offload_engine = kvcache_manager.offload_engine


@@ -344,6 +361,21 @@ class Attention(nn.Module):
         if not cpu_block_table:
             raise RuntimeError("Chunked decode attention failed: no CPU blocks available")

+        # Apply sparse policy if enabled
+        if kvcache_manager.sparse_policy is not None:
+            policy_ctx = PolicyContext(
+                query_chunk_idx=0,
+                num_query_chunks=1,
+                layer_id=self.layer_id,
+                query=q_batched,  # Decode provides query for query-aware selection
+                is_prefill=False,
+                block_size=kvcache_manager.block_size,
+                total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
+            )
+            cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
+                cpu_block_table, policy_ctx
+            )
+
         offload_engine = kvcache_manager.offload_engine

         # Use prefetch_size as chunk size for double buffering
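
The two hunks above differ in what they hand the policy: the prefill path passes query=None, the decode path passes q_batched. A small standalone sketch of the consequence for a query-aware policy (illustrative values; BlockMetadataManager pins its buffers, so this assumes a CUDA-capable host):

import torch
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.policy import PolicyContext

metadata = BlockMetadataManager(num_blocks=16, num_layers=2, num_kv_heads=2, head_dim=8, dtype=torch.float32)
policy = QuestPolicy(QuestConfig(topk_blocks=2, threshold_blocks=1), metadata)

# Prefill-style context: query is None, so QuestPolicy returns every block.
ctx = PolicyContext(query_chunk_idx=3, num_query_chunks=8, layer_id=0, query=None, is_prefill=True)
print(policy.select_blocks(list(range(8)), ctx))  # -> all 8 blocks; Top-K only applies in decode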

tests/test_sparse_policy.py (new file, 252 lines)
@@ -0,0 +1,252 @@
"""
Test sparse attention policies.

Usage:
    CUDA_VISIBLE_DEVICES=4,5 python tests/test_sparse_policy.py [policy_name]

Policy names: full, vertical_slash, streaming_llm, quest
"""

import sys
import os

os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"

import torch
from typing import List

# Test the sparse policy implementations
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager


def test_full_attention_policy():
    """Test FullAttentionPolicy returns all blocks."""
    print("\n=== Testing FullAttentionPolicy ===")
    policy = FullAttentionPolicy()

    available_blocks = list(range(10))
    ctx = PolicyContext(
        query_chunk_idx=5,
        num_query_chunks=10,
        layer_id=0,
        query=None,
        is_prefill=True,
    )

    selected = policy.select_blocks(available_blocks, ctx)
    assert selected == available_blocks, f"Expected all blocks, got {selected}"
    print(f"  Prefill: input={available_blocks}, selected={selected} [PASS]")

    # Test decode
    ctx.is_prefill = False
    selected = policy.select_blocks(available_blocks, ctx)
    assert selected == available_blocks, f"Expected all blocks, got {selected}"
    print(f"  Decode: input={available_blocks}, selected={selected} [PASS]")


def test_vertical_slash_policy():
    """Test VerticalSlashPolicy selects sink + local window."""
    print("\n=== Testing VerticalSlashPolicy ===")
    config = VerticalSlashConfig(
        num_sink_blocks=2,
        local_window_blocks=3,
        threshold_blocks=4,
    )
    policy = VerticalSlashPolicy(config)

    # Test with 10 blocks, chunk 7 (should select sink[0,1] + local[4,5,6])
    available_blocks = list(range(10))
    ctx = PolicyContext(
        query_chunk_idx=7,
        num_query_chunks=10,
        layer_id=0,
        query=None,
        is_prefill=True,
    )

    selected = policy.select_blocks(available_blocks, ctx)
    expected = [0, 1, 4, 5, 6]  # sink + local window before chunk 7
    assert selected == expected, f"Expected {expected}, got {selected}"
    print(f"  Prefill chunk 7: input={available_blocks}, selected={selected} [PASS]")

    # Test with small number of blocks (below threshold)
    available_blocks = [0, 1, 2]
    selected = policy.select_blocks(available_blocks, ctx)
    assert selected == [0, 1, 2], f"Expected all blocks for small input, got {selected}"
    print(f"  Below threshold: input={[0,1,2]}, selected={selected} [PASS]")

    # Test decode (local window is last M blocks)
    available_blocks = list(range(10))
    ctx.is_prefill = False
    selected = policy.select_blocks(available_blocks, ctx)
    expected = [0, 1, 7, 8, 9]  # sink + last 3 blocks
    assert selected == expected, f"Expected {expected}, got {selected}"
    print(f"  Decode: input={available_blocks}, selected={selected} [PASS]")


def test_streaming_llm_policy():
    """Test StreamingLLMPolicy selects sink + recent only."""
    print("\n=== Testing StreamingLLMPolicy ===")
    config = StreamingLLMConfig(
        num_sink_blocks=1,
        num_recent_blocks=2,
    )
    policy = StreamingLLMPolicy(config)

    available_blocks = list(range(10))
    ctx = PolicyContext(
        query_chunk_idx=0,
        num_query_chunks=1,
        layer_id=0,
        query=None,
        is_prefill=False,
    )

    selected = policy.select_blocks(available_blocks, ctx)
    expected = [0, 8, 9]  # sink[0] + recent[8,9]
    assert selected == expected, f"Expected {expected}, got {selected}"
    print(f"  10 blocks: selected={selected} [PASS]")

    # Test with 3 blocks (all fit in sink+recent)
    available_blocks = [0, 1, 2]
    selected = policy.select_blocks(available_blocks, ctx)
    assert selected == [0, 1, 2], f"Expected all blocks, got {selected}"
    print(f"  3 blocks: selected={selected} [PASS]")


def test_quest_policy():
    """Test QuestPolicy with mock metadata."""
    print("\n=== Testing QuestPolicy ===")

    # Create metadata manager
    num_blocks = 10
    num_layers = 2
    num_kv_heads = 4
    head_dim = 64
    dtype = torch.float32

    metadata = BlockMetadataManager(
        num_blocks=num_blocks,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        dtype=dtype,
    )

    # Simulate offloading blocks with different key patterns
    # Blocks 0, 5, 9 will have high scores (keys aligned with query)
    for block_id in range(num_blocks):
        for layer_id in range(num_layers):
            k_cache = torch.randn(100, num_kv_heads, head_dim)  # 100 tokens per block
            if block_id in [0, 5, 9]:
                # Make these blocks have keys that score high
                k_cache = k_cache.abs()  # All positive
            else:
                k_cache = -k_cache.abs()  # All negative
            metadata.update_metadata(block_id, layer_id, k_cache, 100)

    config = QuestConfig(
        topk_blocks=4,
        threshold_blocks=3,
    )
    policy = QuestPolicy(config, metadata)

    available_blocks = list(range(10))

    # Create query that scores high with positive keys
    query = torch.ones(1, num_kv_heads, head_dim, device='cuda')

    ctx = PolicyContext(
        query_chunk_idx=0,
        num_query_chunks=1,
        layer_id=0,
        query=query,
        is_prefill=False,
    )

    selected = policy.select_blocks(available_blocks, ctx)
    print(f"  Top-4 selection: input={available_blocks}, selected={selected}")

    # High-scoring blocks [0, 5, 9] should be in selection
    for expected_block in [0, 5, 9]:
        assert expected_block in selected, f"Expected block {expected_block} in selection"
    print(f"  High-score blocks [0, 5, 9] in selection [PASS]")

    # Test below threshold (should return all)
    available_blocks = [0, 1, 2]
    selected = policy.select_blocks(available_blocks, ctx)
    assert selected == [0, 1, 2], f"Expected all blocks below threshold, got {selected}"
    print(f"  Below threshold: selected={selected} [PASS]")

    # Test without query (should return all)
    ctx.query = None
    available_blocks = list(range(10))
    selected = policy.select_blocks(available_blocks, ctx)
    assert selected == available_blocks, f"Expected all blocks without query, got {selected}"
    print(f"  No query: selected all [PASS]")


def test_custom_policy():
    """Test creating a custom policy."""
    print("\n=== Testing Custom Policy ===")

    class EveryOtherPolicy(SparsePolicy):
        """Select every other block."""

        def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
            return [available_blocks[i] for i in range(0, len(available_blocks), 2)]

    policy = EveryOtherPolicy()
    available_blocks = list(range(10))
    ctx = PolicyContext(
        query_chunk_idx=0,
        num_query_chunks=1,
        layer_id=0,
        query=None,
        is_prefill=True,
    )

    selected = policy.select_blocks(available_blocks, ctx)
    expected = [0, 2, 4, 6, 8]
    assert selected == expected, f"Expected {expected}, got {selected}"
    print(f"  Every other: input={available_blocks}, selected={selected} [PASS]")


def run_all_tests():
    """Run all policy tests."""
    print("Running Sparse Policy Tests...")

    test_full_attention_policy()
    test_vertical_slash_policy()
    test_streaming_llm_policy()
    test_quest_policy()
    test_custom_policy()

    print("\n" + "=" * 50)
    print("All tests passed!")
    print("=" * 50)


if __name__ == "__main__":
    if len(sys.argv) > 1:
        policy_name = sys.argv[1].lower()
        if policy_name == "full":
            test_full_attention_policy()
        elif policy_name == "vertical_slash":
            test_vertical_slash_policy()
        elif policy_name == "streaming_llm":
            test_streaming_llm_policy()
        elif policy_name == "quest":
            test_quest_policy()
        elif policy_name == "custom":
            test_custom_policy()
        else:
            print(f"Unknown policy: {policy_name}")
            print("Available: full, vertical_slash, streaming_llm, quest, custom")
            sys.exit(1)
    else:
        run_all_tests()