[WIP] move metadata to GPU.
@@ -88,6 +88,7 @@ class SparsePolicy(ABC):
         head_dim: int,
         num_cpu_blocks: int,
         dtype: torch.dtype,
+        device: torch.device = None,
     ) -> None:
         """
         Initialize policy resources.
@@ -102,6 +103,7 @@ class SparsePolicy(ABC):
             head_dim: Dimension per head
             num_cpu_blocks: Number of CPU blocks allocated
             dtype: Data type for tensors
+            device: Device for metadata storage (GPU recommended for performance)
         """
         pass
 
@@ -35,6 +35,7 @@ class BlockMetadataManager:
         num_kv_heads: int,
         head_dim: int,
         dtype: torch.dtype = torch.bfloat16,
+        device: torch.device = None,
     ):
         """
         Initialize metadata storage.
@@ -45,20 +46,23 @@ class BlockMetadataManager:
             num_kv_heads: Number of KV attention heads
             head_dim: Dimension per head
             dtype: Data type for metadata storage
+            device: Device for metadata storage (default: CUDA if available)
         """
         self.num_blocks = num_blocks
         self.num_layers = num_layers
         self.num_kv_heads = num_kv_heads
         self.head_dim = head_dim
         self.dtype = dtype
+        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         # Per-block min/max key values: [num_blocks, num_layers, num_heads, head_dim]
+        # Stored on GPU for efficient score computation during decode
         shape = (num_blocks, num_layers, num_kv_heads, head_dim)
-        self.key_min = torch.zeros(shape, dtype=dtype, pin_memory=True)
+        self.key_min = torch.zeros(shape, dtype=dtype, device=self.device)
-        self.key_max = torch.zeros(shape, dtype=dtype, pin_memory=True)
+        self.key_max = torch.zeros(shape, dtype=dtype, device=self.device)
 
         # Track which blocks have valid metadata
-        self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool)
+        self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool, device=self.device)
 
     def update_metadata(
         self,
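Pulled out of the hunk above as a standalone sketch: the device-default idiom and the GPU-resident allocation that replaces the old pinned-host (`pin_memory=True`) buffers. The helper name `resolve_device` and the literal shape are illustrative, not part of the diff.

```python
import torch

def resolve_device(device: torch.device = None) -> torch.device:
    # Same fallback as BlockMetadataManager: an explicit device wins,
    # otherwise CUDA when available, else CPU (keeps tests runnable anywhere).
    if device is not None:
        return device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = resolve_device()
shape = (6, 1, 2, 4)  # (num_blocks, num_layers, num_kv_heads, head_dim)
# Allocated directly on the target device; previously these were pinned CPU
# tensors, which forced a host-to-device copy on every score computation.
key_min = torch.zeros(shape, dtype=torch.float32, device=device)
key_max = torch.zeros(shape, dtype=torch.float32, device=device)
```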
@@ -70,21 +74,21 @@ class BlockMetadataManager:
         """
         Update min/max key bounds for a block.
 
-        Called when a block is offloaded to CPU.
+        Called BEFORE offload to CPU, while k_cache is still on GPU.
 
         Args:
             block_id: CPU block ID
             layer_id: Layer index
-            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
+            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
             num_valid_tokens: Number of valid tokens in this block
         """
         if num_valid_tokens == 0:
             return
 
-        # Get valid keys only
+        # Get valid keys only (k_cache is on GPU, metadata is on GPU)
-        k_valid = k_cache[:num_valid_tokens].cpu()  # [num_tokens, heads, dim]
+        k_valid = k_cache[:num_valid_tokens]  # [num_tokens, heads, dim]
 
-        # Compute min/max across token dimension
+        # Compute min/max across token dimension (all on GPU)
         self.key_min[block_id, layer_id] = k_valid.min(dim=0).values
         self.key_max[block_id, layer_id] = k_valid.max(dim=0).values
         self.valid_blocks[block_id] = True
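The reordering matters: `update_metadata` now reduces over the keys while they are still on the GPU, so the bounds never round-trip through host memory. A sketch of the resulting offload sequence; `offload_block` and `cpu_pool` are hypothetical glue, only `update_metadata` itself appears in the diff.

```python
import torch

def offload_block(metadata, cpu_pool, block_id, layer_id, k_cache_gpu, num_valid_tokens):
    # 1. Reduce min/max over the token dimension on the GPU, before offload;
    #    the per-block bounds stay resident in metadata.key_min / key_max.
    metadata.update_metadata(block_id, layer_id, k_cache_gpu, num_valid_tokens)
    # 2. Only the raw KV data crosses PCIe to the CPU block pool
    #    (non_blocking only overlaps if the destination is pinned).
    cpu_pool[block_id].copy_(k_cache_gpu, non_blocking=True)
```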
@@ -172,14 +176,16 @@ class QuestPolicy(SparsePolicy):
         head_dim: int,
         num_cpu_blocks: int,
         dtype: torch.dtype,
+        device: torch.device = None,
     ) -> None:
-        """Create BlockMetadataManager for storing min/max keys."""
+        """Create BlockMetadataManager for storing min/max keys on GPU."""
         self.metadata = BlockMetadataManager(
             num_blocks=num_cpu_blocks,
             num_layers=num_layers,
             num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=dtype,
+            device=device,
         )
 
     def select_blocks(
@@ -209,15 +215,13 @@ class QuestPolicy(SparsePolicy):
             # No query available - cannot compute scores
             return available_blocks
 
-        # Get metadata for available blocks
+        # Get metadata for available blocks (already on GPU)
        key_min, key_max = self.metadata.get_block_metadata(
             available_blocks, ctx.layer_id
         )
 
-        # Move to query device for computation
+        # Metadata is already on GPU, same device as query
         device = ctx.query.device
-        key_min = key_min.to(device, non_blocking=True)
-        key_max = key_max.to(device, non_blocking=True)
 
         # Compute upper bound scores
         # query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
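With the metadata already on `ctx.query.device`, the scoring that follows runs without the two deleted copies. For reference, a sketch of the Quest-style upper bound it computes; this function is a reconstruction assuming one KV head per query head (the GQA head-averaging is exercised by the test below), not code from the repo.

```python
import torch

def upper_bound_scores(query: torch.Tensor, key_min: torch.Tensor, key_max: torch.Tensor) -> torch.Tensor:
    # query:           [num_heads, head_dim]
    # key_min/key_max: [num_blocks, num_heads, head_dim]
    # For any key k with key_min <= k <= key_max per coordinate, q . k is
    # bounded above by sum_i max(q_i * min_i, q_i * max_i).
    ub = torch.maximum(query * key_min, query * key_max)
    return ub.sum(dim=-1)  # [num_blocks, num_heads]

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
q = torch.ones(2, 4, device=dev)
kmin = torch.randn(6, 2, 4, device=dev)
kmax = kmin + torch.rand_like(kmin)  # guarantee max >= min
print(upper_bound_scores(q, kmin, kmax).shape)  # torch.Size([6, 2])
```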
@@ -18,6 +18,10 @@ from nanovllm.kvcache.sparse import (
 # Test: Per-Head Score Averaging in GQA
 # ============================================================
 
+# Determine device (GPU if available, else CPU)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Running test on device: {device}")
+
 # Setup: 2 KV heads, 4 query heads (GQA group_size=2)
 # topk=2 to make selection competitive
 
@@ -28,14 +32,17 @@ quest.initialize(
     head_dim=4,
     num_cpu_blocks=6,
     dtype=torch.float32,
+    device=device,  # Metadata stored on GPU
 )
 
 metadata = quest.metadata
 
 def set_key(block_id, head_id, values):
     """Set both key_min and key_max to same values for deterministic scoring."""
-    metadata.key_min[block_id, 0, head_id, :] = torch.tensor(values)
-    metadata.key_max[block_id, 0, head_id, :] = torch.tensor(values)
+    # Values need to be on the same device as metadata
+    tensor = torch.tensor(values, device=device)
+    metadata.key_min[block_id, 0, head_id, :] = tensor
+    metadata.key_max[block_id, 0, head_id, :] = tensor
 
 # ============================================================
 # Design: Different heads want different blocks
@@ -80,7 +87,8 @@ set_key(5, 1, [1.0, 1.0, 1.0, 1.0])  # head1: +4
 # Run selection
 # ============================================================
 
-query = torch.ones(1, 4, 4)  # GQA: 4 query heads → 2 KV heads
+# Query on same device as metadata
+query = torch.ones(1, 4, 4, device=device)  # GQA: 4 query heads → 2 KV heads
 
 ctx = PolicyContext(
     query_chunk_idx=0,
@@ -120,4 +128,9 @@ print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
 print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
 print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")
 
+# Verify metadata is on correct device
+assert metadata.key_min.device.type == device.type, f"key_min on wrong device: {metadata.key_min.device}"
+assert metadata.key_max.device.type == device.type, f"key_max on wrong device: {metadata.key_max.device}"
+print(f"✓ Metadata stored on {device.type.upper()}")
+
 print("\ntest_quest_policy: PASSED")
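The ranks those prints assert can be checked by hand: with `key_min == key_max == k` and an all-ones query, a KV head's score for a block is just `sum(k)`, and GQA averages the two KV heads (group_size 2). Values below mirror the `set_key` calls; this is illustration, not repo code.

```python
# Hand-check of the head-averaging arithmetic the prints describe:
score_block1 = (-4 + 4) / 2  # 0.0: heads cancel out, never in top-2
score_block4 = (+4 + 0) / 2  # 2.0: one indifferent head, lower rank
score_block5 = (0 + 4) / 2   # 2.0: symmetric case, lower rank
```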