253 lines
8.3 KiB
Python
253 lines
8.3 KiB
Python
"""
|
|
Test sparse attention policies.
|
|
|
|
Usage:
|
|
CUDA_VISIBLE_DEVICES=4,5 python tests/test_sparse_policy.py [policy_name]
|
|
|
|
Policy names: full, vertical_slash, streaming_llm, quest, custom
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
|
|
# Set NANOVLLM_LOG_LEVEL to WARNING before the nanovllm imports further down
# execute (presumably the package reads this variable at import time -- TODO confirm).
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
|
|
|
|
import torch
|
|
from typing import List
|
|
|
|
# Test the sparse policy implementations
|
|
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
|
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
|
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
|
|
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
|
|
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
|
|
|
|
|
def test_full_attention_policy():
    """Check that FullAttentionPolicy hands back every available block.

    The same ten-block list is run through the policy twice: once with the
    context flagged as prefill, then again after the flag is cleared (decode).
    Both runs must return the input list unchanged.
    """
    print("\n=== Testing FullAttentionPolicy ===")
    full_policy = FullAttentionPolicy()
    available_blocks = list(range(10))
    context = PolicyContext(
        query_chunk_idx=5,
        num_query_chunks=10,
        layer_id=0,
        query=None,
        is_prefill=True,
    )

    def select_everything():
        # Shared check for both phases: nothing may be filtered out.
        chosen = full_policy.select_blocks(available_blocks, context)
        assert chosen == available_blocks, f"Expected all blocks, got {chosen}"
        return chosen

    # Prefill phase.
    selected = select_everything()
    print(f" Prefill: input={available_blocks}, selected={selected} [PASS]")

    # Decode phase: reuse the same context with the prefill flag switched off.
    context.is_prefill = False
    selected = select_everything()
    print(f" Decode: input={available_blocks}, selected={selected} [PASS]")
|
|
|
|
|
|
def test_vertical_slash_policy():
    """Exercise VerticalSlashPolicy configured with 2 sink and 3 local-window blocks.

    Three cases are asserted, all sharing one context object:
      * prefill at query chunk 7 over blocks 0-9 -> [0, 1, 4, 5, 6]
      * a three-block input with threshold_blocks=4 -> returned unchanged
      * decode over blocks 0-9 -> [0, 1, 7, 8, 9]
    """
    print("\n=== Testing VerticalSlashPolicy ===")
    slash_policy = VerticalSlashPolicy(
        VerticalSlashConfig(
            num_sink_blocks=2,
            local_window_blocks=3,
            threshold_blocks=4,
        )
    )
    context = PolicyContext(
        query_chunk_idx=7,
        num_query_chunks=10,
        layer_id=0,
        query=None,
        is_prefill=True,
    )

    # Case 1: ten blocks, chunk 7 -> sink blocks [0, 1] plus the window [4, 5, 6].
    available_blocks = list(range(10))
    expected = [0, 1, 4, 5, 6]
    selected = slash_policy.select_blocks(available_blocks, context)
    assert selected == expected, f"Expected {expected}, got {selected}"
    print(f" Prefill chunk 7: input={available_blocks}, selected={selected} [PASS]")

    # Case 2: too few blocks for the configured threshold, so all come back.
    available_blocks = [0, 1, 2]
    selected = slash_policy.select_blocks(available_blocks, context)
    assert selected == [0, 1, 2], f"Expected all blocks for small input, got {selected}"
    print(f" Below threshold: input={[0,1,2]}, selected={selected} [PASS]")

    # Case 3: decode, where the local window is the last three blocks.
    available_blocks = list(range(10))
    expected = [0, 1, 7, 8, 9]
    context.is_prefill = False
    selected = slash_policy.select_blocks(available_blocks, context)
    assert selected == expected, f"Expected {expected}, got {selected}"
    print(f" Decode: input={available_blocks}, selected={selected} [PASS]")
|
|
|
|
|
|
def test_streaming_llm_policy():
    """Exercise StreamingLLMPolicy configured with 1 sink and 2 recent blocks.

    With ten blocks the test expects [0, 8, 9]; with only three blocks the
    test expects the whole input back.
    """
    print("\n=== Testing StreamingLLMPolicy ===")
    streaming_policy = StreamingLLMPolicy(
        StreamingLLMConfig(num_sink_blocks=1, num_recent_blocks=2)
    )
    context = PolicyContext(
        query_chunk_idx=0,
        num_query_chunks=1,
        layer_id=0,
        query=None,
        is_prefill=False,
    )

    # Ten blocks: sink block 0 plus the two most recent blocks, 8 and 9.
    available_blocks = list(range(10))
    expected = [0, 8, 9]
    selected = streaming_policy.select_blocks(available_blocks, context)
    assert selected == expected, f"Expected {expected}, got {selected}"
    print(f" 10 blocks: selected={selected} [PASS]")

    # Three blocks: sink + recent already cover the entire input.
    available_blocks = [0, 1, 2]
    selected = streaming_policy.select_blocks(available_blocks, context)
    assert selected == [0, 1, 2], f"Expected all blocks, got {selected}"
    print(f" 3 blocks: selected={selected} [PASS]")
|
|
|
|
|
|
def test_quest_policy():
    """Exercise QuestPolicy against synthetic block metadata.

    A BlockMetadataManager is filled with random key tensors whose sign is
    forced per block: blocks 0, 5 and 9 are all-positive, every other block is
    all-negative. The query is an all-ones tensor, so the positive blocks are
    the ones the test expects inside the top-4 selection.

    Also asserted: a three-block input comes back unchanged, and a context
    without a query returns every block.

    Note: the query tensor is allocated with device='cuda', so a CUDA device
    is required to run this test.
    """
    print("\n=== Testing QuestPolicy ===")

    # Dimensions of the synthetic metadata store.
    num_blocks = 10
    num_layers = 2
    num_kv_heads = 4
    head_dim = 64
    dtype = torch.float32
    tokens_per_block = 100
    high_score_blocks = (0, 5, 9)

    block_metadata = BlockMetadataManager(
        num_blocks=num_blocks,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        dtype=dtype,
    )

    # Simulate offloading: one random key tensor per (block, layer), with the
    # sign chosen so that only the designated blocks align with the query.
    for block_id in range(num_blocks):
        for layer_id in range(num_layers):
            magnitude = torch.randn(tokens_per_block, num_kv_heads, head_dim).abs()
            k_cache = magnitude if block_id in high_score_blocks else -magnitude
            block_metadata.update_metadata(block_id, layer_id, k_cache, tokens_per_block)

    quest_policy = QuestPolicy(
        QuestConfig(topk_blocks=4, threshold_blocks=3),
        block_metadata,
    )

    available_blocks = list(range(10))

    # An all-ones query scores high against all-positive keys.
    query = torch.ones(1, num_kv_heads, head_dim, device='cuda')
    context = PolicyContext(
        query_chunk_idx=0,
        num_query_chunks=1,
        layer_id=0,
        query=query,
        is_prefill=False,
    )

    selected = quest_policy.select_blocks(available_blocks, context)
    print(f" Top-4 selection: input={available_blocks}, selected={selected}")

    # Each positive-key block must have made it into the selection.
    for expected_block in high_score_blocks:
        assert expected_block in selected, f"Expected block {expected_block} in selection"
    print(f" High-score blocks [0, 5, 9] in selection [PASS]")

    # Small input: the test expects every block back.
    available_blocks = [0, 1, 2]
    selected = quest_policy.select_blocks(available_blocks, context)
    assert selected == [0, 1, 2], f"Expected all blocks below threshold, got {selected}"
    print(f" Below threshold: selected={selected} [PASS]")

    # No query on the context: the test expects every block back.
    context.query = None
    available_blocks = list(range(10))
    selected = quest_policy.select_blocks(available_blocks, context)
    assert selected == available_blocks, f"Expected all blocks without query, got {selected}"
    print(f" No query: selected all [PASS]")
|
|
|
|
|
|
def test_custom_policy():
    """Check that a user-defined SparsePolicy subclass can be used directly.

    A throwaway subclass keeping only the even-indexed blocks is defined
    inline and run over blocks 0-9; the result must be [0, 2, 4, 6, 8].
    """
    print("\n=== Testing Custom Policy ===")

    class EveryOtherPolicy(SparsePolicy):
        """Keep the blocks at even positions of the input list."""

        def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
            # A stride-2 slice picks positions 0, 2, 4, ...
            return available_blocks[::2]

    alternating_policy = EveryOtherPolicy()
    context = PolicyContext(
        query_chunk_idx=0,
        num_query_chunks=1,
        layer_id=0,
        query=None,
        is_prefill=True,
    )
    available_blocks = list(range(10))
    expected = [0, 2, 4, 6, 8]

    selected = alternating_policy.select_blocks(available_blocks, context)
    assert selected == expected, f"Expected {expected}, got {selected}"
    print(f" Every other: input={available_blocks}, selected={selected} [PASS]")
|
|
|
|
|
|
def run_all_tests():
    """Run every policy test in a fixed order, then print a closing banner.

    The tests signal failure by raising (AssertionError), so the banner is
    only reached when all of them returned normally.
    """
    print("Running Sparse Policy Tests...")

    suite = (
        test_full_attention_policy,
        test_vertical_slash_policy,
        test_streaming_llm_policy,
        test_quest_policy,
        test_custom_policy,
    )
    for policy_test in suite:
        policy_test()

    rule = "=" * 50
    print("\n" + rule)
    print("All tests passed!")
    print(rule)
|
|
|
|
|
|
if __name__ == "__main__":
    # With no argument the whole suite runs; otherwise the first argument
    # (case-insensitive) names the single policy test to execute.
    if len(sys.argv) <= 1:
        run_all_tests()
    else:
        policy_name = sys.argv[1].lower()
        tests_by_name = {
            "full": test_full_attention_policy,
            "vertical_slash": test_vertical_slash_policy,
            "streaming_llm": test_streaming_llm_policy,
            "quest": test_quest_policy,
            "custom": test_custom_policy,
        }
        chosen_test = tests_by_name.get(policy_name)
        if chosen_test is None:
            # Unrecognized name: list the valid choices and exit with status 1.
            print(f"Unknown policy: {policy_name}")
            print("Available: full, vertical_slash, streaming_llm, quest, custom")
            sys.exit(1)
        chosen_test()
|