[WIP] move metadata to GPU.
This commit is contained in:
@@ -18,6 +18,10 @@ from nanovllm.kvcache.sparse import (
|
||||
# Test: Per-Head Score Averaging in GQA
|
||||
# ============================================================
|
||||
|
||||
# Determine device (GPU if available, else CPU)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Running test on device: {device}")
|
||||
|
||||
# Setup: 2 KV heads, 4 query heads (GQA group_size=2)
|
||||
# topk=2 to make selection competitive
|
||||
|
||||
@@ -28,14 +32,17 @@ quest.initialize(
|
||||
head_dim=4,
|
||||
num_cpu_blocks=6,
|
||||
dtype=torch.float32,
|
||||
device=device, # Metadata stored on GPU
|
||||
)
|
||||
|
||||
metadata = quest.metadata
|
||||
|
||||
def set_key(block_id, head_id, values):
|
||||
"""Set both key_min and key_max to same values for deterministic scoring."""
|
||||
metadata.key_min[block_id, 0, head_id, :] = torch.tensor(values)
|
||||
metadata.key_max[block_id, 0, head_id, :] = torch.tensor(values)
|
||||
# Values need to be on the same device as metadata
|
||||
tensor = torch.tensor(values, device=device)
|
||||
metadata.key_min[block_id, 0, head_id, :] = tensor
|
||||
metadata.key_max[block_id, 0, head_id, :] = tensor
|
||||
|
||||
# ============================================================
|
||||
# Design: Different heads want different blocks
|
||||
@@ -80,7 +87,8 @@ set_key(5, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
|
||||
# Run selection
|
||||
# ============================================================
|
||||
|
||||
query = torch.ones(1, 4, 4) # GQA: 4 query heads → 2 KV heads
|
||||
# Query on same device as metadata
|
||||
query = torch.ones(1, 4, 4, device=device) # GQA: 4 query heads → 2 KV heads
|
||||
|
||||
ctx = PolicyContext(
|
||||
query_chunk_idx=0,
|
||||
@@ -120,4 +128,9 @@ print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
|
||||
print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
|
||||
print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")
|
||||
|
||||
# Verify metadata is on correct device
|
||||
assert metadata.key_min.device.type == device.type, f"key_min on wrong device: {metadata.key_min.device}"
|
||||
assert metadata.key_max.device.type == device.type, f"key_max on wrong device: {metadata.key_max.device}"
|
||||
print(f"✓ Metadata stored on {device.type.upper()}")
|
||||
|
||||
print("\ntest_quest_policy: PASSED")
|
||||
|
||||
Reference in New Issue
Block a user