diff --git a/CLAUDE.md b/CLAUDE.md
index 6584a15..16c2b37 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -10,6 +10,38 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 
 For sparse attention related content (block sparse attention, MInference, FlexPrefill, XAttention, AvgPool, etc.), refer to [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md).
 
+### Quest Sparse Policy
+
+**Files**: `nanovllm/kvcache/sparse/quest.py`, `nanovllm/kvcache/sparse/policy.py`
+
+Quest policy selects Top-K blocks based on query-key similarity bounds using min/max key metadata.
+
+**Scoring Mechanism**:
+```python
+score_min = torch.einsum('hd,bhd->bh', q, key_min) # [num_blocks, kv_heads]
+score_max = torch.einsum('hd,bhd->bh', q, key_max) # [num_blocks, kv_heads]
+scores = torch.maximum(score_min, score_max).mean(dim=-1) # [num_blocks] ← averaged!
+```
+
+**Critical Limitation - No Per-Head Scheduling**:
+
+The `.mean(dim=-1)` averages scores across all heads, making a **unified** block selection for all heads:
+
+```
+Block A: head0 needs (+4), head1 doesn't (-4) → avg = 0 → NOT selected
+Block B: head0 doesn't (-4), head1 needs (+4) → avg = 0 → NOT selected
+Block C: both heads moderately need (+2, +2) → avg = +2 → selected
+```
+
+**Why Per-Head Scheduling is Infeasible**:
+1. **Memory Layout**: GPU cache stores all heads together `[block_size, kv_heads, head_dim]`
+2. **FlashAttention**: Requires complete heads - partial heads cause dimension mismatch
+3. **Block Granularity**: If any head needs a block, the entire block (all heads) must be loaded
+
+**Policy Types**:
+- `FullAttentionPolicy`: `supports_prefill=True, supports_decode=True` - loads all blocks
+- `QuestPolicy`: `supports_prefill=False, supports_decode=True` - decode-only Top-K selection
+
 ## Architecture
 
 ### Core Components
diff --git a/tests/test_quest_policy.py b/tests/test_quest_policy.py
new file mode 100644
index 0000000..c82be74
--- /dev/null
+++ b/tests/test_quest_policy.py
@@ -0,0 +1,123 @@
+"""
+Test for QuestPolicy block selection with GQA (Grouped Query Attention).
+
+Demonstrates the key limitation: scores are AVERAGED across heads,
+so blocks strongly needed by one head but not others may be dropped.
+
+This is the expected Quest behavior - not a bug.
+"""
+
+import torch
+from nanovllm.kvcache.sparse import (
+    create_sparse_policy,
+    SparsePolicyType,
+    PolicyContext,
+)
+
+# ============================================================
+# Test: Per-Head Score Averaging in GQA
+# ============================================================
+
+# Setup: 2 KV heads, 4 query heads (GQA group_size=2)
+# topk=2 to make selection competitive
+
+quest = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=2, threshold_blocks=0)
+quest.initialize(
+    num_layers=1,
+    num_kv_heads=2,
+    head_dim=4,
+    num_cpu_blocks=6,
+    dtype=torch.float32,
+)
+
+metadata = quest.metadata
+
+def set_key(block_id, head_id, values):
+    """Set both key_min and key_max to same values for deterministic scoring."""
+    metadata.key_min[block_id, 0, head_id, :] = torch.tensor(values)
+    metadata.key_max[block_id, 0, head_id, :] = torch.tensor(values)
+
+# ============================================================
+# Design: Different heads want different blocks
+# ============================================================
+#
+# Query = [1,1,1,1] for all heads, so score = sum(key values)
+#
+# Block | Head 0 | Head 1 | Average | Result
+# ------|--------|--------|---------|--------
+# 0 | +4 | -4 | 0 | Head0 wants, Head1 doesn't → DROPPED
+# 1 | -4 | +4 | 0 | Head1 wants, Head0 doesn't → DROPPED
+# 2 | +4 | +4 | +4 | Both want → SELECTED (rank 1)
+# 3 | +3 | +3 | +3 | Both want → SELECTED (rank 2)
+# 4 | +4 | 0 | +2 | Head0 strongly wants, Head1 neutral → rank 3
+# 5 | 0 | +4 | +2 | Head1 strongly wants, Head0 neutral → rank 3
+
+# Block 0: Head 0 strongly wants, Head 1 strongly rejects
+set_key(0, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
+set_key(0, 1, [-1.0, -1.0, -1.0, -1.0]) # head1: -4
+
+# Block 1: Head 1 strongly wants, Head 0 strongly rejects
+set_key(1, 0, [-1.0, -1.0, -1.0, -1.0]) # head0: -4
+set_key(1, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
+
+# Block 2: Both heads want equally (highest average)
+set_key(2, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
+set_key(2, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
+
+# Block 3: Both heads want moderately
+set_key(3, 0, [0.75, 0.75, 0.75, 0.75]) # head0: +3
+set_key(3, 1, [0.75, 0.75, 0.75, 0.75]) # head1: +3
+
+# Block 4: Head 0 strongly wants, Head 1 neutral
+set_key(4, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
+set_key(4, 1, [0.0, 0.0, 0.0, 0.0]) # head1: 0
+
+# Block 5: Head 1 strongly wants, Head 0 neutral
+set_key(5, 0, [0.0, 0.0, 0.0, 0.0]) # head0: 0
+set_key(5, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
+
+# ============================================================
+# Run selection
+# ============================================================
+
+query = torch.ones(1, 4, 4) # GQA: 4 query heads → 2 KV heads
+
+ctx = PolicyContext(
+    query_chunk_idx=0,
+    num_query_chunks=1,
+    layer_id=0,
+    query=query,
+    is_prefill=False,
+    block_size=1024,
+    total_kv_len=6144,
+)
+
+available = list(range(6))
+selected = quest.select_blocks(available, ctx)
+
+# ============================================================
+# Verify: Averaging behavior
+# ============================================================
+
+# topk=2, so only blocks 2 (+4 avg) and 3 (+3 avg) should be selected
+assert len(selected) == 2, f"Expected 2 blocks, got {len(selected)}"
+assert selected == [2, 3], f"Expected [2, 3], got {selected}"
+
+# Key insight: blocks 0 and 1 have score +4 for ONE head,
+# but they cancel out due to averaging with the other head's -4
+assert 0 not in selected, "Block 0 should NOT be selected (head scores cancel out)"
+assert 1 not in selected, "Block 1 should NOT be selected (head scores cancel out)"
+
+# Blocks 4 and 5 have +4 for one head, 0 for other → avg=+2
+# But +2 < +3 (block 3), so they don't make the top-2
+assert 4 not in selected, "Block 4 avg=+2 < block 3 avg=+3"
+assert 5 not in selected, "Block 5 avg=+2 < block 3 avg=+3"
+
+print("✓ Block 2 selected: both heads want it (+4, +4) → avg=+4")
+print("✓ Block 3 selected: both heads want it (+3, +3) → avg=+3")
+print("✓ Block 0 NOT selected: head0=+4, head1=-4 → avg=0 (cancel out)")
+print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
+print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
+print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")
+
+print("\ntest_quest_policy: PASSED")