[claudesquad] update from 'lw-offload-2' on 08 Jan 26 20:53 CST
This commit is contained in:
@@ -65,23 +65,22 @@ class LogicalBlock:
|
||||
|
||||
class HybridKVCacheManager(KVCacheManager):
|
||||
"""
|
||||
Hybrid CPU-GPU KV cache manager with ring buffer design.
|
||||
Hybrid CPU-GPU KV cache manager with layer-wise offload design.
|
||||
|
||||
Architecture (CPU-primary mode):
|
||||
- CPU pool: Primary storage for all KV cache (num_cpu_blocks)
|
||||
- GPU buffer: Ring buffer for computation only (num_gpu_slots)
|
||||
- Logical blocks: What sequences reference (num_cpu_blocks)
|
||||
- GPU ring buffer: For decode H2D pipeline (num_kv_buffers)
|
||||
- Decode buffer: Per-layer accumulation of decode tokens (block_size)
|
||||
|
||||
Design:
|
||||
- All KV cache is stored on CPU as primary storage
|
||||
- GPU is used as a ring buffer for computation only (no persistent data)
|
||||
- During prefill: KV is written to GPU ring slot, then offloaded to CPU
|
||||
- During decode: Previous KV is loaded from CPU to GPU for attention
|
||||
- Ring buffer enables pipelined H2D transfers overlapped with computation
|
||||
- GPU ring buffer enables pipelined H2D transfers during decode
|
||||
- During prefill: KV is computed and offloaded layer-by-layer to CPU
|
||||
- During decode: Previous KV is loaded from CPU via ring buffer pipeline
|
||||
|
||||
Note:
|
||||
- Logical blocks map 1:1 with CPU blocks (total_blocks = num_cpu_blocks)
|
||||
- GPU slots are transient compute buffers, not tracked in logical blocks
|
||||
- GPU ring buffer is for decode pipeline, not persistent storage
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -91,25 +90,31 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
block_size: int,
|
||||
policy: Optional[EvictionPolicy] = None,
|
||||
sparse_policy: "SparsePolicy" = None,
|
||||
num_kv_buffers: int = 4,
|
||||
max_seq_len: int = 131072,
|
||||
):
|
||||
"""
|
||||
Initialize hybrid manager with CPU-primary ring buffer design.
|
||||
Initialize hybrid manager with layer-wise offload design.
|
||||
|
||||
All KV cache is stored on CPU as primary storage. GPU slots are used
|
||||
as a ring buffer for computation only.
|
||||
All KV cache is stored on CPU as primary storage. GPU ring buffer is used
|
||||
for decode H2D pipeline.
|
||||
|
||||
Args:
|
||||
num_gpu_slots: Number of GPU buffer slots (ring buffer for computation)
|
||||
num_gpu_slots: Number of GPU buffer slots (kept for backward compatibility; unused)
|
||||
num_cpu_blocks: Number of CPU pool blocks (primary storage)
|
||||
block_size: Tokens per block
|
||||
policy: Eviction policy (default: LRU, used for prefix cache management)
|
||||
sparse_policy: Sparse attention policy (Quest for decode-only sparse)
|
||||
num_kv_buffers: Ring buffer size for decode H2D pipeline
|
||||
max_seq_len: Maximum sequence length for GPU buffer allocation
|
||||
"""
|
||||
self._block_size = block_size
|
||||
self.num_gpu_slots = num_gpu_slots
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
self.num_kv_buffers = num_kv_buffers
|
||||
self.max_seq_len = max_seq_len
|
||||
# In CPU-primary mode, logical blocks map 1:1 with CPU blocks
|
||||
# GPU slots are transient compute buffers, not tracked as logical blocks
|
||||
# GPU ring buffer is for decode pipeline, not persistent storage
|
||||
self.total_blocks = num_cpu_blocks
|
||||
|
||||
# Eviction policy
|
||||
@@ -147,7 +152,7 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
# Track blocks pending GPU load (for decode graph)
|
||||
self.pending_gpu_loads: Set[int] = set() # logical_ids
|
||||
|
||||
# Track blocks that have been prefilled (KV written) for chunked prefill
|
||||
# Track blocks that have been prefilled (KV offloaded to CPU)
|
||||
self.prefilled_blocks: Set[int] = set() # logical_ids
|
||||
|
||||
# Track decode starting position within block (for batched offload optimization)
|
||||
@@ -182,13 +187,21 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
dtype=dtype,
|
||||
num_kv_buffers=self.num_kv_buffers,
|
||||
max_seq_len=self.max_seq_len,
|
||||
sparse_policy=self.sparse_policy,
|
||||
)
|
||||
|
||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
"""Get GPU K/V cache tensors for a layer."""
|
||||
"""
|
||||
Get GPU K/V cache tensors for a layer.
|
||||
|
||||
Note: In layer-wise offload mode, this returns empty tensors as KV
|
||||
is managed directly by the offload engine's ring buffer.
|
||||
"""
|
||||
assert self.offload_engine is not None
|
||||
return self.offload_engine.get_layer_cache(layer_id)
|
||||
# Return empty tensors - actual KV is in offload_engine's ring buffer
|
||||
return torch.empty(0), torch.empty(0)
|
||||
|
||||
def can_allocate(self, seq: Sequence) -> bool:
|
||||
"""Check if we can allocate blocks for a new sequence."""
|
||||
@@ -279,8 +292,8 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
"""
|
||||
Prepare KV cache for attention computation.
|
||||
|
||||
In ring buffer mode, this is a no-op because chunked offload
|
||||
paths handle H2D transfers directly in the attention layer.
|
||||
In layer-wise offload mode, this is a no-op because KV transfers
|
||||
are handled directly in model_runner's layer-by-layer methods.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -291,12 +304,12 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
"""
|
||||
Get GPU slot tables for sequences.
|
||||
|
||||
In ring buffer mode, all blocks are on CPU, so this raises an error
|
||||
if called. Use run_chunked_offload_* methods instead.
|
||||
In layer-wise offload mode, all blocks are on CPU, so this raises an error
|
||||
if called. Use run_layerwise_offload_* methods instead.
|
||||
"""
|
||||
raise RuntimeError(
|
||||
"get_gpu_block_tables should not be called in ring buffer mode. "
|
||||
"Use run_chunked_offload_prefill/decode instead."
|
||||
"get_gpu_block_tables should not be called in layer-wise offload mode. "
|
||||
"Use run_layerwise_offload_prefill/decode instead."
|
||||
)
|
||||
|
||||
def post_attention_cleanup(
|
||||
@@ -307,18 +320,18 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
"""
|
||||
Cleanup after attention.
|
||||
|
||||
In ring buffer mode, this is a no-op because offload is handled
|
||||
directly in the chunked prefill/decode paths.
|
||||
In layer-wise offload mode, this is a no-op because offload is handled
|
||||
directly in model_runner's layer-by-layer methods.
|
||||
"""
|
||||
pass
|
||||
|
||||
# ========== Ring Buffer CPU-primary Chunked Prefill Support ==========
|
||||
# ========== Layer-wise Offload Support ==========
|
||||
|
||||
def get_prefilled_cpu_blocks(self, seq: Sequence) -> List[int]:
|
||||
"""
|
||||
Get list of CPU block IDs for blocks that have been prefilled.
|
||||
|
||||
Used for loading previous KV during chunked prefill.
|
||||
Used for loading prefilled KV during decode.
|
||||
|
||||
Returns:
|
||||
List of CPU block IDs in sequence order
|
||||
@@ -335,11 +348,11 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
# )
|
||||
return cpu_blocks
|
||||
|
||||
# ========== Ring Buffer CPU-primary support ==========
|
||||
# ========== CPU Block Allocation ==========
|
||||
|
||||
def allocate_cpu_only(self, seq: Sequence) -> None:
|
||||
"""
|
||||
Allocate CPU blocks for sequence (for ring buffer mode).
|
||||
Allocate CPU blocks for sequence (for layer-wise offload mode).
|
||||
|
||||
Unlike allocate(), here all blocks are allocated to CPU,
|
||||
GPU is only used as a ring buffer for computation.
|
||||
@@ -468,20 +481,6 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
return block.cpu_block_id
|
||||
return -1
|
||||
|
||||
def get_write_slot_for_chunked_offload(self, seq: Sequence) -> int:
|
||||
"""
|
||||
Get GPU slot for writing new KV during chunked offload decode.
|
||||
|
||||
In ring buffer design, always use decode_slot (slot[0]) to write new KV.
|
||||
This avoids conflicts with loading operations which use slots[1:].
|
||||
|
||||
Args:
|
||||
seq: Sequence
|
||||
|
||||
Returns:
|
||||
GPU slot ID (always decode_slot = 0)
|
||||
"""
|
||||
return self.offload_engine.decode_slot
|
||||
|
||||
def get_decode_start_pos(self, seq: Sequence) -> int:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user