[WIP] move metadata to GPU.

This commit is contained in:
Zijie Tian
2026-01-06 23:32:32 +08:00
parent edb5273e34
commit 0e691f2d85
3 changed files with 35 additions and 16 deletions

View File

@@ -35,6 +35,7 @@ class BlockMetadataManager:
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype = torch.bfloat16,
device: torch.device = None,
):
"""
Initialize metadata storage.
@@ -45,20 +46,23 @@ class BlockMetadataManager:
num_kv_heads: Number of KV attention heads
head_dim: Dimension per head
dtype: Data type for metadata storage
device: Device for metadata storage (default: CUDA if available)
"""
self.num_blocks = num_blocks
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.dtype = dtype
self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-block min/max key values: [num_blocks, num_layers, num_heads, head_dim]
# Stored on GPU for efficient score computation during decode
shape = (num_blocks, num_layers, num_kv_heads, head_dim)
self.key_min = torch.zeros(shape, dtype=dtype, pin_memory=True)
self.key_max = torch.zeros(shape, dtype=dtype, pin_memory=True)
self.key_min = torch.zeros(shape, dtype=dtype, device=self.device)
self.key_max = torch.zeros(shape, dtype=dtype, device=self.device)
# Track which blocks have valid metadata
self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool)
self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool, device=self.device)
def update_metadata(
self,
@@ -70,21 +74,21 @@ class BlockMetadataManager:
"""
Update min/max key bounds for a block.
Called when a block is offloaded to CPU.
Called BEFORE offload to CPU, while k_cache is still on GPU.
Args:
block_id: CPU block ID
layer_id: Layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
"""
if num_valid_tokens == 0:
return
# Get valid keys only
k_valid = k_cache[:num_valid_tokens].cpu() # [num_tokens, heads, dim]
# Get valid keys only (k_cache is on GPU, metadata is on GPU)
k_valid = k_cache[:num_valid_tokens] # [num_tokens, heads, dim]
# Compute min/max across token dimension
# Compute min/max across token dimension (all on GPU)
self.key_min[block_id, layer_id] = k_valid.min(dim=0).values
self.key_max[block_id, layer_id] = k_valid.max(dim=0).values
self.valid_blocks[block_id] = True
@@ -172,14 +176,16 @@ class QuestPolicy(SparsePolicy):
head_dim: int,
num_cpu_blocks: int,
dtype: torch.dtype,
device: torch.device = None,
) -> None:
"""Create BlockMetadataManager for storing min/max keys."""
"""Create BlockMetadataManager for storing min/max keys on GPU."""
self.metadata = BlockMetadataManager(
num_blocks=num_cpu_blocks,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
device=device,
)
def select_blocks(
@@ -209,15 +215,13 @@ class QuestPolicy(SparsePolicy):
# No query available - cannot compute scores
return available_blocks
# Get metadata for available blocks
# Get metadata for available blocks (already on GPU)
key_min, key_max = self.metadata.get_block_metadata(
available_blocks, ctx.layer_id
)
# Move to query device for computation
# Metadata is already on GPU, same device as query
device = ctx.query.device
key_min = key_min.to(device, non_blocking=True)
key_max = key_max.to(device, non_blocking=True)
# Compute upper bound scores
# query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]