From 0e691f2d85861166fb3e64b528bc3e9f83e4c293 Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Tue, 6 Jan 2026 23:32:32 +0800
Subject: [PATCH] [WIP] move metadata to GPU.

---
 nanovllm/kvcache/sparse/policy.py |  2 ++
 nanovllm/kvcache/sparse/quest.py  | 30 +++++++++++++++++-------------
 tests/test_quest_policy.py        | 19 ++++++++++++++++---
 3 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/nanovllm/kvcache/sparse/policy.py b/nanovllm/kvcache/sparse/policy.py
index fab87ca..c935d06 100644
--- a/nanovllm/kvcache/sparse/policy.py
+++ b/nanovllm/kvcache/sparse/policy.py
@@ -88,6 +88,7 @@ class SparsePolicy(ABC):
         head_dim: int,
         num_cpu_blocks: int,
         dtype: torch.dtype,
+        device: torch.device | None = None,
     ) -> None:
         """
         Initialize policy resources.
@@ -102,6 +103,7 @@ class SparsePolicy(ABC):
             head_dim: Dimension per head
             num_cpu_blocks: Number of CPU blocks allocated
             dtype: Data type for tensors
+            device: Device for metadata storage (GPU recommended for performance)
         """
         pass
 
diff --git a/nanovllm/kvcache/sparse/quest.py b/nanovllm/kvcache/sparse/quest.py
index 3583905..d038832 100644
--- a/nanovllm/kvcache/sparse/quest.py
+++ b/nanovllm/kvcache/sparse/quest.py
@@ -35,6 +35,7 @@ class BlockMetadataManager:
         num_kv_heads: int,
         head_dim: int,
         dtype: torch.dtype = torch.bfloat16,
+        device: torch.device | None = None,
     ):
         """
         Initialize metadata storage.
@@ -45,20 +46,23 @@
             num_kv_heads: Number of KV attention heads
             head_dim: Dimension per head
             dtype: Data type for metadata storage
+            device: Device for metadata storage (default: CUDA if available)
         """
         self.num_blocks = num_blocks
         self.num_layers = num_layers
         self.num_kv_heads = num_kv_heads
         self.head_dim = head_dim
         self.dtype = dtype
+        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         # Per-block min/max key values: [num_blocks, num_layers, num_heads, head_dim]
+        # Stored on GPU for efficient score computation during decode
         shape = (num_blocks, num_layers, num_kv_heads, head_dim)
-        self.key_min = torch.zeros(shape, dtype=dtype, pin_memory=True)
-        self.key_max = torch.zeros(shape, dtype=dtype, pin_memory=True)
+        self.key_min = torch.zeros(shape, dtype=dtype, device=self.device)
+        self.key_max = torch.zeros(shape, dtype=dtype, device=self.device)
 
         # Track which blocks have valid metadata
-        self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool)
+        self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool, device=self.device)
 
     def update_metadata(
         self,
@@ -70,21 +74,21 @@
         """
        Update min/max key bounds for a block.
 
-        Called when a block is offloaded to CPU.
+        Called BEFORE offload to CPU, while k_cache is still on GPU.
 
         Args:
             block_id: CPU block ID
             layer_id: Layer index
-            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
+            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
             num_valid_tokens: Number of valid tokens in this block
         """
         if num_valid_tokens == 0:
             return
 
-        # Get valid keys only
-        k_valid = k_cache[:num_valid_tokens].cpu()  # [num_tokens, heads, dim]
+        # Get valid keys only (k_cache and metadata are both on GPU)
+        k_valid = k_cache[:num_valid_tokens]  # [num_tokens, heads, dim]
 
-        # Compute min/max across token dimension
+        # Compute min/max across token dimension (all on GPU)
         self.key_min[block_id, layer_id] = k_valid.min(dim=0).values
         self.key_max[block_id, layer_id] = k_valid.max(dim=0).values
         self.valid_blocks[block_id] = True
@@ -172,14 +176,16 @@
         head_dim: int,
         num_cpu_blocks: int,
         dtype: torch.dtype,
+        device: torch.device | None = None,
     ) -> None:
-        """Create BlockMetadataManager for storing min/max keys."""
+        """Create BlockMetadataManager for storing min/max keys on GPU."""
        self.metadata = BlockMetadataManager(
             num_blocks=num_cpu_blocks,
             num_layers=num_layers,
             num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=dtype,
+            device=device,
         )
 
     def select_blocks(
@@ -209,15 +215,13 @@
             # No query available - cannot compute scores
             return available_blocks
 
-        # Get metadata for available blocks
+        # Get metadata for available blocks (already on GPU)
         key_min, key_max = self.metadata.get_block_metadata(
             available_blocks, ctx.layer_id
         )
 
-        # Move to query device for computation
+        # Metadata already lives on the query's device; no transfer needed
         device = ctx.query.device
-        key_min = key_min.to(device, non_blocking=True)
-        key_max = key_max.to(device, non_blocking=True)
 
         # Compute upper bound scores
         # query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
diff --git a/tests/test_quest_policy.py b/tests/test_quest_policy.py
index c82be74..14a893f 100644
--- a/tests/test_quest_policy.py
+++ b/tests/test_quest_policy.py
@@ -18,6 +18,10 @@ from nanovllm.kvcache.sparse import (
 # Test: Per-Head Score Averaging in GQA
 # ============================================================
 
+# Determine device (GPU if available, else CPU)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Running test on device: {device}")
+
 # Setup: 2 KV heads, 4 query heads (GQA group_size=2)
 # topk=2 to make selection competitive
 
@@ -28,14 +32,17 @@
     head_dim=4,
     num_cpu_blocks=6,
     dtype=torch.float32,
+    device=device,  # Metadata stored on GPU when available
 )
 metadata = quest.metadata
 
 def set_key(block_id, head_id, values):
     """Set both key_min and key_max to same values for deterministic scoring."""
-    metadata.key_min[block_id, 0, head_id, :] = torch.tensor(values)
-    metadata.key_max[block_id, 0, head_id, :] = torch.tensor(values)
+    # Values must live on the same device as the metadata tensors
+    tensor = torch.tensor(values, device=device)
+    metadata.key_min[block_id, 0, head_id, :] = tensor
+    metadata.key_max[block_id, 0, head_id, :] = tensor
 
 # ============================================================
 # Design: Different heads want different blocks
 # ============================================================
@@ -80,7 +87,8 @@ set_key(5, 1, [1.0, 1.0, 1.0, 1.0])  # head1: +4
 # Run selection
 # ============================================================
 
-query = torch.ones(1, 4, 4)  # GQA: 4 query heads → 2 KV heads
+# Query on the same device as the metadata
+query = torch.ones(1, 4, 4, device=device)  # GQA: 4 query heads → 2 KV heads
 
 ctx = PolicyContext(
     query_chunk_idx=0,
@@ -120,4 +128,9 @@
print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)") print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)") print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)") +# Verify metadata is on correct device +assert metadata.key_min.device.type == device.type, f"key_min on wrong device: {metadata.key_min.device}" +assert metadata.key_max.device.type == device.type, f"key_max on wrong device: {metadata.key_max.device}" +print(f"✓ Metadata stored on {device.type.upper()}") + print("\ntest_quest_policy: PASSED")