[WIP] Added basic test for quest.
This commit is contained in:
32
CLAUDE.md
32
CLAUDE.md
@@ -10,6 +10,38 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
|
||||
|
||||
For sparse attention related content (block sparse attention, MInference, FlexPrefill, XAttention, AvgPool, etc.), refer to [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md).
|
||||
|
||||
### Quest Sparse Policy
|
||||
|
||||
**Files**: `nanovllm/kvcache/sparse/quest.py`, `nanovllm/kvcache/sparse/policy.py`
|
||||
|
||||
Quest policy selects Top-K blocks based on query-key similarity bounds using min/max key metadata.
|
||||
|
||||
**Scoring Mechanism**:
|
||||
```python
|
||||
score_min = torch.einsum('hd,bhd->bh', q, key_min) # [num_blocks, kv_heads]
|
||||
score_max = torch.einsum('hd,bhd->bh', q, key_max) # [num_blocks, kv_heads]
|
||||
scores = torch.maximum(score_min, score_max).mean(dim=-1) # [num_blocks] ← averaged!
|
||||
```
|
||||
|
||||
**Critical Limitation - No Per-Head Scheduling**:
|
||||
|
||||
The `.mean(dim=-1)` averages scores across all heads, making a **unified** block selection for all heads:
|
||||
|
||||
```
|
||||
Block A: head0 needs (+4), head1 doesn't (-4) → avg = 0 → NOT selected
|
||||
Block B: head0 doesn't (-4), head1 needs (+4) → avg = 0 → NOT selected
|
||||
Block C: both heads moderately need (+2, +2) → avg = +2 → selected
|
||||
```
|
||||
|
||||
**Why Per-Head Scheduling is Infeasible**:
|
||||
1. **Memory Layout**: GPU cache stores all heads together `[block_size, kv_heads, head_dim]`
|
||||
2. **FlashAttention**: Requires complete heads - partial heads cause dimension mismatch
|
||||
3. **Block Granularity**: If any head needs a block, the entire block (all heads) must be loaded
|
||||
|
||||
**Policy Types**:
|
||||
- `FullAttentionPolicy`: `supports_prefill=True, supports_decode=True` - loads all blocks
|
||||
- `QuestPolicy`: `supports_prefill=False, supports_decode=True` - decode-only Top-K selection
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
123
tests/test_quest_policy.py
Normal file
123
tests/test_quest_policy.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Test for QuestPolicy block selection with GQA (Grouped Query Attention).
|
||||
|
||||
Demonstrates the key limitation: scores are AVERAGED across heads,
|
||||
so blocks strongly needed by one head but not others may be dropped.
|
||||
|
||||
This is the expected Quest behavior - not a bug.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from nanovllm.kvcache.sparse import (
|
||||
create_sparse_policy,
|
||||
SparsePolicyType,
|
||||
PolicyContext,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# Test: Per-Head Score Averaging in GQA
|
||||
# ============================================================
|
||||
|
||||
# Setup: 2 KV heads, 4 query heads (GQA group_size=2)
|
||||
# topk=2 to make selection competitive
|
||||
|
||||
quest = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=2, threshold_blocks=0)
|
||||
quest.initialize(
|
||||
num_layers=1,
|
||||
num_kv_heads=2,
|
||||
head_dim=4,
|
||||
num_cpu_blocks=6,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
metadata = quest.metadata
|
||||
|
||||
def set_key(block_id, head_id, values):
|
||||
"""Set both key_min and key_max to same values for deterministic scoring."""
|
||||
metadata.key_min[block_id, 0, head_id, :] = torch.tensor(values)
|
||||
metadata.key_max[block_id, 0, head_id, :] = torch.tensor(values)
|
||||
|
||||
# ============================================================
|
||||
# Design: Different heads want different blocks
|
||||
# ============================================================
|
||||
#
|
||||
# Query = [1,1,1,1] for all heads, so score = sum(key values)
|
||||
#
|
||||
# Block | Head 0 | Head 1 | Average | Result
|
||||
# ------|--------|--------|---------|--------
|
||||
# 0 | +4 | -4 | 0 | Head0 wants, Head1 doesn't → DROPPED
|
||||
# 1 | -4 | +4 | 0 | Head1 wants, Head0 doesn't → DROPPED
|
||||
# 2 | +4 | +4 | +4 | Both want → SELECTED (rank 1)
|
||||
# 3 | +3 | +3 | +3 | Both want → SELECTED (rank 2)
|
||||
# 4 | +4 | 0 | +2 | Head0 strongly wants, Head1 neutral → rank 3
|
||||
# 5 | 0 | +4 | +2 | Head1 strongly wants, Head0 neutral → rank 3
|
||||
|
||||
# Block 0: Head 0 strongly wants, Head 1 strongly rejects
|
||||
set_key(0, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
|
||||
set_key(0, 1, [-1.0, -1.0, -1.0, -1.0]) # head1: -4
|
||||
|
||||
# Block 1: Head 1 strongly wants, Head 0 strongly rejects
|
||||
set_key(1, 0, [-1.0, -1.0, -1.0, -1.0]) # head0: -4
|
||||
set_key(1, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
|
||||
|
||||
# Block 2: Both heads want equally (highest average)
|
||||
set_key(2, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
|
||||
set_key(2, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
|
||||
|
||||
# Block 3: Both heads want moderately
|
||||
set_key(3, 0, [0.75, 0.75, 0.75, 0.75]) # head0: +3
|
||||
set_key(3, 1, [0.75, 0.75, 0.75, 0.75]) # head1: +3
|
||||
|
||||
# Block 4: Head 0 strongly wants, Head 1 neutral
|
||||
set_key(4, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
|
||||
set_key(4, 1, [0.0, 0.0, 0.0, 0.0]) # head1: 0
|
||||
|
||||
# Block 5: Head 1 strongly wants, Head 0 neutral
|
||||
set_key(5, 0, [0.0, 0.0, 0.0, 0.0]) # head0: 0
|
||||
set_key(5, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
|
||||
|
||||
# ============================================================
|
||||
# Run selection
|
||||
# ============================================================
|
||||
|
||||
query = torch.ones(1, 4, 4) # GQA: 4 query heads → 2 KV heads
|
||||
|
||||
ctx = PolicyContext(
|
||||
query_chunk_idx=0,
|
||||
num_query_chunks=1,
|
||||
layer_id=0,
|
||||
query=query,
|
||||
is_prefill=False,
|
||||
block_size=1024,
|
||||
total_kv_len=6144,
|
||||
)
|
||||
|
||||
available = list(range(6))
|
||||
selected = quest.select_blocks(available, ctx)
|
||||
|
||||
# ============================================================
|
||||
# Verify: Averaging behavior
|
||||
# ============================================================
|
||||
|
||||
# topk=2, so only blocks 2 (+4 avg) and 3 (+3 avg) should be selected
|
||||
assert len(selected) == 2, f"Expected 2 blocks, got {len(selected)}"
|
||||
assert selected == [2, 3], f"Expected [2, 3], got {selected}"
|
||||
|
||||
# Key insight: blocks 0 and 1 have score +4 for ONE head,
|
||||
# but they cancel out due to averaging with the other head's -4
|
||||
assert 0 not in selected, "Block 0 should NOT be selected (head scores cancel out)"
|
||||
assert 1 not in selected, "Block 1 should NOT be selected (head scores cancel out)"
|
||||
|
||||
# Blocks 4 and 5 have +4 for one head, 0 for other → avg=+2
|
||||
# But +2 < +3 (block 3), so they don't make the top-2
|
||||
assert 4 not in selected, "Block 4 avg=+2 < block 3 avg=+3"
|
||||
assert 5 not in selected, "Block 5 avg=+2 < block 3 avg=+3"
|
||||
|
||||
print("✓ Block 2 selected: both heads want it (+4, +4) → avg=+4")
|
||||
print("✓ Block 3 selected: both heads want it (+3, +3) → avg=+3")
|
||||
print("✓ Block 0 NOT selected: head0=+4, head1=-4 → avg=0 (cancel out)")
|
||||
print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
|
||||
print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
|
||||
print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")
|
||||
|
||||
print("\ntest_quest_policy: PASSED")
|
||||
Reference in New Issue
Block a user