Files
nano-vllm/tests/test_sparse_policy.py
2025-12-22 08:51:02 +08:00

253 lines
8.3 KiB
Python

"""
Test sparse attention policies.
Usage:
    CUDA_VISIBLE_DEVICES=4,5 python tests/test_sparse_policy.py [policy_name]
Policy names: full, vertical_slash, streaming_llm, quest, custom
(omit policy_name to run every test)
"""
import sys
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import torch
from typing import List
# Test the sparse policy implementations
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
def test_full_attention_policy():
    """FullAttentionPolicy must hand back every available block.

    The same context is checked once in prefill mode and once in decode mode.
    """
    print("\n=== Testing FullAttentionPolicy ===")
    policy = FullAttentionPolicy()
    blocks = list(range(10))
    ctx = PolicyContext(
        query_chunk_idx=5,
        num_query_chunks=10,
        layer_id=0,
        query=None,
        is_prefill=True,
    )
    # Flip the phase flag on the shared context rather than building a new one.
    for phase, prefill_flag in (("Prefill", True), ("Decode", False)):
        ctx.is_prefill = prefill_flag
        chosen = policy.select_blocks(blocks, ctx)
        assert chosen == blocks, f"Expected all blocks, got {chosen}"
        print(f" {phase}: input={blocks}, selected={chosen} [PASS]")
def test_vertical_slash_policy():
    """VerticalSlashPolicy should keep the sink blocks plus a local window.

    Three scenarios are checked with the same policy:
    a mid-sequence prefill chunk, an input shorter than ``threshold_blocks``,
    and decode mode (window taken from the most recent blocks).
    """
    print("\n=== Testing VerticalSlashPolicy ===")
    policy = VerticalSlashPolicy(
        VerticalSlashConfig(
            num_sink_blocks=2,
            local_window_blocks=3,
            threshold_blocks=4,
        )
    )

    # Prefill of chunk 7 over 10 blocks: sink [0, 1] + window [4, 5, 6].
    blocks = list(range(10))
    ctx = PolicyContext(
        query_chunk_idx=7,
        num_query_chunks=10,
        layer_id=0,
        query=None,
        is_prefill=True,
    )
    chosen = policy.select_blocks(blocks, ctx)
    want = [0, 1, 4, 5, 6]
    assert chosen == want, f"Expected {want}, got {chosen}"
    print(f" Prefill chunk 7: input={blocks}, selected={chosen} [PASS]")

    # Fewer blocks than threshold_blocks: nothing is dropped.
    small_blocks = [0, 1, 2]
    chosen = policy.select_blocks(small_blocks, ctx)
    assert chosen == [0, 1, 2], f"Expected all blocks for small input, got {chosen}"
    print(f" Below threshold: input={small_blocks}, selected={chosen} [PASS]")

    # Decode: sink [0, 1] + the last three blocks [7, 8, 9].
    blocks = list(range(10))
    ctx.is_prefill = False
    chosen = policy.select_blocks(blocks, ctx)
    want = [0, 1, 7, 8, 9]
    assert chosen == want, f"Expected {want}, got {chosen}"
    print(f" Decode: input={blocks}, selected={chosen} [PASS]")
def test_streaming_llm_policy():
    """StreamingLLMPolicy should keep only the sink block(s) and the most recent ones."""
    print("\n=== Testing StreamingLLMPolicy ===")
    policy = StreamingLLMPolicy(
        StreamingLLMConfig(
            num_sink_blocks=1,
            num_recent_blocks=2,
        )
    )
    blocks = list(range(10))
    decode_ctx = PolicyContext(
        query_chunk_idx=0,
        num_query_chunks=1,
        layer_id=0,
        query=None,
        is_prefill=False,
    )

    # 10 blocks: sink [0] followed by the two newest blocks [8, 9].
    chosen = policy.select_blocks(blocks, decode_ctx)
    want = [0, 8, 9]
    assert chosen == want, f"Expected {want}, got {chosen}"
    print(f" 10 blocks: selected={chosen} [PASS]")

    # 3 blocks already fit inside sink + recent, so all of them survive.
    blocks = [0, 1, 2]
    chosen = policy.select_blocks(blocks, decode_ctx)
    assert chosen == [0, 1, 2], f"Expected all blocks, got {chosen}"
    print(f" 3 blocks: selected={chosen} [PASS]")
def test_quest_policy():
    """Test QuestPolicy top-k selection against hand-built block metadata.

    Blocks listed in ``high_score_blocks`` are given all-positive keys and every
    other block all-negative keys. An all-ones query is then expected to rank
    the high-score blocks into the top-k selection.

    Two pass-through paths are also checked:
    fewer blocks than ``threshold_blocks``, and a context with no query
    (both should return every available block).
    """
    print("\n=== Testing QuestPolicy ===")
    num_blocks = 10
    num_layers = 2
    num_kv_heads = 4
    head_dim = 64
    tokens_per_block = 100
    dtype = torch.float32
    # Single source of truth for the blocks the query should prefer; used both
    # when building the fake keys and when checking the selection.
    high_score_blocks = (0, 5, 9)
    # Local, seeded RNG so the fake key caches -- and therefore the printed
    # selection -- are identical on every run, without touching global RNG state.
    gen = torch.Generator().manual_seed(0)

    metadata = BlockMetadataManager(
        num_blocks=num_blocks,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        dtype=dtype,
    )
    # Simulate offloading: every (block, layer) gets a key cache whose sign
    # pattern decides whether it should score high against the query.
    for block_id in range(num_blocks):
        for layer_id in range(num_layers):
            k_cache = torch.randn(
                tokens_per_block, num_kv_heads, head_dim,
                generator=gen, dtype=dtype,
            ).abs()  # all positive -> aligned with an all-ones query
            if block_id not in high_score_blocks:
                k_cache = -k_cache  # all negative -> anti-aligned
            metadata.update_metadata(block_id, layer_id, k_cache, tokens_per_block)

    config = QuestConfig(
        topk_blocks=4,
        threshold_blocks=3,
    )
    policy = QuestPolicy(config, metadata)
    available_blocks = list(range(num_blocks))
    # The query is placed on CUDA, matching the GPU requirement in the module
    # usage line.
    query = torch.ones(1, num_kv_heads, head_dim, device='cuda')
    ctx = PolicyContext(
        query_chunk_idx=0,
        num_query_chunks=1,
        layer_id=0,
        query=query,
        is_prefill=False,
    )

    selected = policy.select_blocks(available_blocks, ctx)
    print(f" Top-{config.topk_blocks} selection: input={available_blocks}, selected={selected}")
    for expected_block in high_score_blocks:
        assert expected_block in selected, f"Expected block {expected_block} in selection"
    print(f" High-score blocks {list(high_score_blocks)} in selection [PASS]")

    # Fewer blocks than threshold_blocks: the policy should not prune.
    available_blocks = [0, 1, 2]
    selected = policy.select_blocks(available_blocks, ctx)
    assert selected == [0, 1, 2], f"Expected all blocks below threshold, got {selected}"
    print(f" Below threshold: selected={selected} [PASS]")

    # No query to score against: the policy should fall back to all blocks.
    ctx.query = None
    available_blocks = list(range(num_blocks))
    selected = policy.select_blocks(available_blocks, ctx)
    assert selected == available_blocks, f"Expected all blocks without query, got {selected}"
    print(" No query: selected all [PASS]")
def test_custom_policy():
    """Check that a user-defined SparsePolicy subclass works through select_blocks."""
    print("\n=== Testing Custom Policy ===")

    class EveryOtherPolicy(SparsePolicy):
        """Select every other block (positions 0, 2, 4, ...)."""

        def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
            # Step-2 slice instead of an index loop; list() keeps the return a
            # list even if a tuple or other sequence is passed in.
            return list(available_blocks[::2])

    policy = EveryOtherPolicy()
    available_blocks = list(range(10))
    ctx = PolicyContext(
        query_chunk_idx=0,
        num_query_chunks=1,
        layer_id=0,
        query=None,
        is_prefill=True,
    )
    selected = policy.select_blocks(available_blocks, ctx)
    expected = [0, 2, 4, 6, 8]
    assert selected == expected, f"Expected {expected}, got {selected}"
    print(f" Every other: input={available_blocks}, selected={selected} [PASS]")
def run_all_tests():
    """Run each policy test in sequence; any failing assertion propagates and stops the run."""
    print("Running Sparse Policy Tests...")
    for test_fn in (
        test_full_attention_policy,
        test_vertical_slash_policy,
        test_streaming_llm_policy,
        test_quest_policy,
        test_custom_policy,
    ):
        test_fn()
    rule = "=" * 50
    print("\n" + rule)
    print("All tests passed!")
    print(rule)
if __name__ == "__main__":
    # CLI entry: an optional policy name selects one test; no argument runs all.
    # The dispatch table is also the source of the "Available" list, so the
    # two can never drift apart.
    tests_by_name = {
        "full": test_full_attention_policy,
        "vertical_slash": test_vertical_slash_policy,
        "streaming_llm": test_streaming_llm_policy,
        "quest": test_quest_policy,
        "custom": test_custom_policy,
    }
    if len(sys.argv) > 1:
        policy_name = sys.argv[1].lower()
        test_fn = tests_by_name.get(policy_name)
        if test_fn is None:
            # Usage errors go to stderr so they don't mix with test output.
            print(f"Unknown policy: {policy_name}", file=sys.stderr)
            print("Available: " + ", ".join(tests_by_name), file=sys.stderr)
            sys.exit(1)
        test_fn()
    else:
        run_all_tests()