[feat] Added chunked prefill and KV cache offload mechanism.
nanovllm/kvcache/gpu_manager.py (new file, 262 lines added)
@@ -0,0 +1,262 @@
"""
GPU-only KV cache manager.

This is the default manager when CPU offload is disabled.
Refactored from the original block_manager.py to implement
the KVCacheManager interface.
"""

from collections import deque
from typing import List, Tuple, Dict, Optional

import torch
from torch import Tensor

from nanovllm.engine.sequence import Sequence
from nanovllm.kvcache.base_manager import KVCacheManager


class Block:
    """Physical block in GPU memory."""

    def __init__(self, block_id: int):
        self.block_id = block_id
        self.ref_count = 0
        self.hash = -1
        self.token_ids: List[int] = []

    def update(self, hash: int, token_ids: List[int]):
        self.hash = hash
        self.token_ids = token_ids

    def reset(self):
        self.ref_count = 1
        self.hash = -1
        self.token_ids = []


class GPUOnlyManager(KVCacheManager):
    """
    Pure GPU KV cache manager.

    This is the default implementation when enable_cpu_offload=False.
    All KV cache resides in GPU memory.

    Features:
    - Paged attention with configurable block size
    - Prefix caching via xxhash
    - Reference counting for block sharing

    This manager is fully compatible with CUDA graphs since
    all data stays on GPU at fixed addresses.
    """

    def __init__(self, num_blocks: int, block_size: int):
        """
        Initialize GPU-only manager.

        Args:
            num_blocks: Total number of blocks to manage
            block_size: Tokens per block (default 256)
        """
        self._block_size = block_size
        self._num_blocks = num_blocks

        # Block metadata
        self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]

        # Prefix cache: hash -> block_id
        self.hash_to_block_id: Dict[int, int] = {}

        # Free/used tracking
        self.free_block_ids: deque[int] = deque(range(num_blocks))
        self.used_block_ids: set[int] = set()

        # KV cache tensors (set by allocate_cache)
        self.kv_cache: Optional[Tensor] = None
        self.num_layers: int = 0
        self.num_kv_heads: int = 0
        self.head_dim: int = 0

    @property
    def block_size(self) -> int:
        return self._block_size

    @property
    def num_free_blocks(self) -> int:
        return len(self.free_block_ids)

    def allocate_cache(
        self,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype,
    ) -> None:
        """Allocate GPU KV cache tensor."""
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim

        # Shape: [2, num_layers, num_blocks, block_size, kv_heads, head_dim]
        # 2 for K and V
        self.kv_cache = torch.empty(
            2, num_layers, self._num_blocks, self._block_size,
            num_kv_heads, head_dim,
            dtype=dtype, device="cuda"
        )
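        # The allocation above occupies
        # 2 * num_layers * num_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize
        # bytes, e.g. 4 GiB for 32 layers, 128 blocks of 256 tokens,
        # 8 KV heads, and head_dim 128 in fp16.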

    def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
        """Get K/V cache for a layer."""
        assert self.kv_cache is not None, "Cache not allocated"
        return self.kv_cache[0, layer_id], self.kv_cache[1, layer_id]

    def _allocate_block(self, block_id: int) -> Block:
        """Internal: allocate a specific block."""
        block = self.blocks[block_id]
        assert block.ref_count == 0, f"Block {block_id} is not free"
        block.reset()
        self.free_block_ids.remove(block_id)
        self.used_block_ids.add(block_id)
        return block

    def _deallocate_block(self, block_id: int) -> None:
        """Internal: deallocate a block."""
        assert self.blocks[block_id].ref_count == 0
        self.used_block_ids.remove(block_id)
        self.free_block_ids.append(block_id)

    def can_allocate(self, seq: Sequence) -> bool:
        """Check if we have enough blocks for the sequence."""
        return len(self.free_block_ids) >= seq.num_blocks

    def allocate(self, seq: Sequence) -> None:
        """
        Allocate blocks for a sequence during prefill.

        Implements prefix caching: if a block's content matches
        a previously cached block, reuse it instead of allocating new.
        """
        assert not seq.block_table, "Sequence already has blocks allocated"

        h = -1  # Hash chain
        cache_miss = False

        for i in range(seq.num_blocks):
            token_ids = seq.block(i)

            # Only compute hash for full blocks
            if len(token_ids) == self._block_size:
                h = self.compute_hash(token_ids, h)
            else:
                h = -1

            # Try prefix cache lookup
            block_id = self.hash_to_block_id.get(h, -1)
            if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
                cache_miss = True

            if cache_miss:
                # Cache miss: allocate new block
                block_id = self.free_block_ids[0]
                block = self._allocate_block(block_id)
            else:
                # Cache hit: reuse existing block
                seq.num_cached_tokens += self._block_size
                if block_id in self.used_block_ids:
                    # Block is in use, increment ref count
                    block = self.blocks[block_id]
                    block.ref_count += 1
                else:
                    # Block was freed but hash still valid
                    block = self._allocate_block(block_id)

            # Update hash mapping for full blocks
            if h != -1:
                block.update(h, token_ids)
                self.hash_to_block_id[h] = block_id

            seq.block_table.append(block_id)

    def deallocate(self, seq: Sequence) -> None:
        """Release all blocks for a sequence."""
        for block_id in reversed(seq.block_table):
            block = self.blocks[block_id]
            block.ref_count -= 1
            if block.ref_count == 0:
                self._deallocate_block(block_id)

        seq.num_cached_tokens = 0
        seq.block_table.clear()

    def can_append(self, seq: Sequence) -> bool:
        """Check if we can append a token (may need new block)."""
        # A new physical block is needed only when the most recently appended
        # token starts a new logical block (len(seq) % block_size == 1).
        need_new_block = (len(seq) % self._block_size == 1)
        return len(self.free_block_ids) >= int(need_new_block)

    def may_append(self, seq: Sequence) -> None:
        """Handle potential new block allocation during decode."""
        block_table = seq.block_table
        last_block = self.blocks[block_table[-1]]

        seq_len = len(seq)
        pos_in_block = seq_len % self._block_size

        if pos_in_block == 1:
            # Just crossed into new block, need to allocate
            assert last_block.hash != -1, "Previous block should be complete"
            block_id = self.free_block_ids[0]
            self._allocate_block(block_id)
            block_table.append(block_id)

        elif pos_in_block == 0:
            # Just filled a block, compute hash for prefix cache
            assert last_block.hash == -1, "Block should not have hash yet"
            token_ids = seq.block(seq.num_blocks - 1)
            prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
            h = self.compute_hash(token_ids, prefix)
            last_block.update(h, token_ids)
            self.hash_to_block_id[h] = last_block.block_id

        else:
            # Middle of block, nothing to do
            assert last_block.hash == -1
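
        # Illustrative trace (hypothetical numbers): with block_size = 4 and an
        # 8-token prompt, appending token 9 gives pos_in_block == 1, so a third
        # physical block is allocated; tokens 10-11 fall in the middle branch;
        # token 12 gives pos_in_block == 0, so the now-full block is hashed and
        # becomes reusable through the prefix cache.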

    def prepare_for_attention(
        self,
        seqs: List[Sequence],
        is_prefill: bool,
    ) -> None:
        """
        No-op for GPU-only manager.

        All blocks are already on GPU, no preparation needed.
        """
        pass

    def get_gpu_block_tables(
        self,
        seqs: List[Sequence],
    ) -> List[List[int]]:
        """
        Return block tables directly (logical = physical for GPU-only).
        """
        return [list(seq.block_table) for seq in seqs]

    def post_attention_cleanup(
        self,
        seqs: List[Sequence],
        is_prefill: bool,
    ) -> None:
        """No-op for GPU-only manager."""
        pass

    def __repr__(self) -> str:
        return (
            f"GPUOnlyManager("
            f"num_blocks={self._num_blocks}, "
            f"block_size={self._block_size}, "
            f"free={len(self.free_block_ids)}, "
            f"used={len(self.used_block_ids)}"
            f")"
        )
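

# Minimal smoke-test sketch (illustrative only; the engine normally constructs
# this manager itself). It assumes a CUDA device is available and exercises
# only cache allocation and the per-layer views; sequence-level allocation is
# driven by the scheduler in practice.
if __name__ == "__main__":
    if torch.cuda.is_available():
        manager = GPUOnlyManager(num_blocks=8, block_size=16)
        manager.allocate_cache(num_layers=2, num_kv_heads=4, head_dim=64,
                               dtype=torch.float16)
        k_cache, v_cache = manager.get_layer_cache(0)
        # Per-layer views are [num_blocks, block_size, num_kv_heads, head_dim].
        assert k_cache.shape == (8, 16, 4, 64) == v_cache.shape
        print(manager)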