[WIP] Before fix needle.
@@ -489,24 +489,15 @@ class ModelRunner:
logical_id = seq.block_table[block_idx]
self.kvcache_manager.prefilled_blocks.add(logical_id)

- # Offload this chunk's ring buffer slot to CPU (async)
+ # NOTE: Per-layer offloading is now done in attention.forward
+ # Each layer offloads its KV to CPU immediately after computing attention.
+ # We just need to wait for the last offload to complete before reusing the slot.
if block_idx < len(cpu_block_ids):
- cpu_block_id = cpu_block_ids[block_idx]
- # Call sparse policy hook before offload (to capture metadata)
- sparse_policy = self.kvcache_manager.sparse_policy
- if sparse_policy is not None:
- num_tokens = chunk_end - chunk_start
- for layer_id in range(offload_engine.num_layers):
- k_cache = offload_engine.k_cache_gpu[layer_id, write_slot, :num_tokens]
- sparse_policy.on_block_offloaded(
- cpu_block_id=cpu_block_id,
- layer_id=layer_id,
- k_cache=k_cache,
- num_valid_tokens=num_tokens,
- )
- offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)
+ # TODO: Sparse policy hook needs update for new GPU cache architecture
+ # The GPU cache no longer has layer dimension, so we can't access
+ # k_cache_gpu[layer_id, write_slot]. Sparse policy should be called
+ # in attention.forward after per-layer offload.
+ pass

# Wait for offload to complete before next chunk
# (slot will be reused after N chunks)
@@ -628,7 +619,11 @@ class ModelRunner:
if pos_in_block == self.block_size - 1:
last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
if last_cpu_block >= 0:
- offload_engine.offload_decode_slot(last_cpu_block)
+ # TODO: In new GPU cache architecture (no layer dimension),
+ # decode offload should be done per-layer in attention.forward.
+ # For now, offload all layers sequentially.
+ for layer_id in range(offload_engine.num_layers):
+ offload_engine.offload_decode_slot_layer(layer_id, last_cpu_block)
offload_engine.wait_all_offload_done()
# Reset decode start position for next block
self.kvcache_manager.reset_decode_start_pos(seq)
@@ -67,14 +67,19 @@ class OffloadEngine:
self.block_numel = block_size * self.kv_dim

# ========== sgDMA pitch parameters for strided transfers ==========
+ # CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
+ # GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim] (no layer dim)
+ # For CPU-to-GPU transfer (H2D): copy single layer, single block at a time
+ # For all-layer CPU operations (D2H offload to all layers): use sgDMA
self.dtype_size = dtype.itemsize
+ # CPU pitch: stride between layers in CPU cache (for all-layer operations)
self.cpu_pitch = num_cpu_blocks * self.block_numel * self.dtype_size
- self.gpu_pitch = num_gpu_blocks * self.block_numel * self.dtype_size
- self.width = self.block_numel * self.dtype_size
- self.height = num_layers
+ # GPU has no layer dimension, so single block transfer is contiguous
+ self.gpu_block_bytes = self.block_numel * self.dtype_size
+ self.height = num_layers # For CPU all-layer operations

- logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, gpu_pitch={self.gpu_pitch}, "
- f"width={self.width}, height={self.height}")
+ logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, "
+ f"gpu_block_bytes={self.gpu_block_bytes}, height={self.height}")

# ========== Unified Ring Buffer configuration ==========
# Constraint checks
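To make the pitch bookkeeping concrete, here is a small arithmetic sketch of how the remaining sgDMA parameters relate to the cache shapes described above; the sizes below are made-up placeholders, not values from this repo's configuration.

# Illustrative sizes only (not this repo's actual configuration).
num_layers, num_cpu_blocks = 28, 256
block_size, kv_heads, head_dim = 256, 8, 128
dtype_size = 2  # fp16 / bf16

block_numel = block_size * kv_heads * head_dim
gpu_block_bytes = block_numel * dtype_size             # one GPU slot is contiguous (no layer dim)
cpu_pitch = num_cpu_blocks * block_numel * dtype_size  # byte stride between layers in the CPU cache
height = num_layers                                    # rows moved by an all-layer 2D (sgDMA) copy

def cpu_block_byte_offset(layer_id: int, cpu_block_id: int) -> int:
    """Byte offset of CPU[layer_id, cpu_block_id] inside the flat pinned buffer."""
    return layer_id * cpu_pitch + cpu_block_id * gpu_block_bytes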
@@ -100,14 +105,16 @@ class OffloadEngine:
logger.info(f" Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading")

# ========== Fixed-address GPU KV cache ==========
- # Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
- # Use zeros initialization to avoid uninitialized memory issues
+ # Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
+ # NOTE: No num_layers dimension! GPU slots are shared across layers.
+ # Each layer reuses the same slots (layers execute sequentially).
+ # This saves 28x GPU memory compared to per-layer allocation.
self.k_cache_gpu = torch.zeros(
- num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
+ num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.v_cache_gpu = torch.zeros(
- num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
+ num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)

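The "28x" figure in the comment is simply the layer count: dropping the layer dimension shrinks the resident GPU KV buffer by a factor of num_layers. A minimal sketch of that arithmetic, with illustrative sizes only:

def kv_cache_bytes(*dims: int, dtype_size: int = 2) -> int:
    """Total bytes for the K plus V caches of the given shape."""
    numel = 1
    for d in dims:
        numel *= d
    return 2 * numel * dtype_size  # x2 for K and V

# Illustrative sizes (not taken from the repo's config).
num_layers, num_gpu_blocks, block_size, kv_heads, head_dim = 28, 8, 256, 8, 128

per_layer_alloc = kv_cache_bytes(num_layers, num_gpu_blocks, block_size, kv_heads, head_dim)
shared_alloc = kv_cache_bytes(num_gpu_blocks, block_size, kv_heads, head_dim)
assert per_layer_alloc // shared_alloc == num_layers  # the "28x" saving for a 28-layer model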
@@ -159,35 +166,23 @@ class OffloadEngine:
# Decode offload event
self.decode_offload_done = torch.cuda.Event()

- # ========== Per-slot Per-layer events for ring buffer ==========
- # ring_slot_ready[slot_idx][layer_id] = CUDA Event for H2D completion
- # ring_slot_offload_done[slot_idx][layer_id] = CUDA Event for D2H completion
- self.ring_slot_ready = [
- [torch.cuda.Event() for _ in range(num_layers)]
- for _ in range(self.num_ring_slots)
- ]
- self.ring_slot_offload_done = [
- [torch.cuda.Event() for _ in range(num_layers)]
- for _ in range(self.num_ring_slots)
- ]
+ # ========== Per-slot events for ring buffer ==========
+ # Since GPU cache has no layer dimension and layers execute sequentially,
+ # we only need per-slot events (not per-slot per-layer).
+ # ring_slot_ready[slot_idx] = CUDA Event for H2D completion
+ # ring_slot_offload_done[slot_idx] = CUDA Event for D2H completion
+ self.ring_slot_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
+ self.ring_slot_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]

- # Per-slot events for all-layer operations (used in some legacy paths)
- self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
- self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
- # ========== Per-slot Per-layer compute_done events for async pipeline ==========
- # ring_slot_compute_done[slot_idx][layer_id] = CUDA Event for compute completion
- # This is used to ensure we don't overwrite data before it's been read by attention
- self.ring_slot_compute_done = [
- [torch.cuda.Event() for _ in range(num_layers)]
- for _ in range(self.num_ring_slots)
- ]
+ # ========== Per-slot compute_done events for async pipeline ==========
+ # ring_slot_compute_done[slot_idx] = CUDA Event for compute completion
+ # This ensures we don't overwrite data before it's been read by attention
+ self.ring_slot_compute_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]

# Initialize all compute_done events (record them once)
# This prevents undefined behavior on first load_to_slot_layer call
for slot_idx in range(self.num_ring_slots):
- for layer_id in range(num_layers):
- self.ring_slot_compute_done[slot_idx][layer_id].record()
+ self.ring_slot_compute_done[slot_idx].record()
torch.cuda.synchronize() # Ensure all events are recorded

# ========== Event tracking for async transfers ==========
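The per-slot events implement a small three-event protocol per ring slot. Below is a minimal, self-contained sketch of that protocol under the same assumptions (one transfer stream, one compute stream); the names and helper callables are illustrative rather than the engine's exact API.

import torch

num_slots = 4
transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()

ready = [torch.cuda.Event() for _ in range(num_slots)]         # H2D into the slot finished
offload_done = [torch.cuda.Event() for _ in range(num_slots)]  # D2H out of the slot finished
compute_done = [torch.cuda.Event() for _ in range(num_slots)]  # attention finished reading the slot

# Record compute_done once up front so the very first load never waits on an
# unrecorded event (this mirrors the initialization loop in the hunk above).
for event in compute_done:
    event.record()
torch.cuda.synchronize()

def load_slot(slot: int, h2d_copy) -> None:
    """Start an async H2D copy into a slot once it is safe to overwrite."""
    with torch.cuda.stream(transfer_stream):
        transfer_stream.wait_event(compute_done[slot])  # do not clobber data still being read
        transfer_stream.wait_event(offload_done[slot])  # do not race a pending D2H from this slot
        h2d_copy()                                      # non_blocking copy_() into the slot
        ready[slot].record(transfer_stream)

def use_slot(slot: int, attention) -> None:
    """Run attention on a slot's data, then mark the slot reusable."""
    compute_stream.wait_event(ready[slot])
    with torch.cuda.stream(compute_stream):
        attention()
        compute_done[slot].record(compute_stream)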
@@ -204,23 +199,24 @@ class OffloadEngine:
return stream

# ========== CUDA Graph compatible methods ==========
+ # NOTE: These methods need to be updated for the new GPU cache architecture.
+ # GPU cache no longer has layer dimension, so gathered copy semantics change.
+ # For now, these are kept for reference but should not be used without updating.

def gathered_h2d_layer(self, layer_id: int) -> None:
"""
Execute gathered H2D copy for a single layer.

- This method is CUDA Graph compatible - can be captured into a graph.
- Before calling, update_gather_indices() must be called to set up
- which CPU blocks to copy to which GPU slots.
+ WARNING: This method needs updating for new GPU cache architecture.
+ GPU cache no longer has layer dimension.

- Args:
- layer_id: Layer index to transfer
"""
+ # GPU cache has no layer dimension - use flat indexing
+ # Source is CPU[layer_id], dest is GPU (shared across layers)
gathered_copy_kv(
k_src=self.k_cache_cpu[layer_id],
v_src=self.v_cache_cpu[layer_id],
- k_dst=self.k_cache_gpu[layer_id],
- v_dst=self.v_cache_gpu[layer_id],
+ k_dst=self.k_cache_gpu, # No layer indexing
+ v_dst=self.v_cache_gpu, # No layer indexing
indices=self.gather_indices_gpu[layer_id],
)

@@ -228,7 +224,8 @@ class OffloadEngine:
"""
Execute gathered H2D copy for all layers.

- CUDA Graph compatible - can be captured into a single graph.
+ WARNING: In new architecture, GPU slots are shared across layers.
+ This method would overwrite slots multiple times. Not recommended.
"""
for layer_id in range(self.num_layers):
self.gathered_h2d_layer(layer_id)
@@ -297,10 +294,10 @@ class OffloadEngine:
"""
Async prefetch a single block from CPU to GPU.

- For use in prefill phase where CUDA graphs are not used.
+ GPU cache has no layer dimension - layer_id is for CPU cache indexing.

Args:
- layer_id: Layer index
+ layer_id: Layer index (for CPU cache)
cpu_block_id: Source block in CPU cache
gpu_block_id: Destination slot in GPU cache

@@ -313,13 +310,12 @@ class OffloadEngine:
logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")

with torch.cuda.stream(stream):
- # K cache
- self.k_cache_gpu[layer_id, gpu_block_id].copy_(
+ # GPU: no layer dimension, CPU: has layer dimension
+ self.k_cache_gpu[gpu_block_id].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
- # V cache
- self.v_cache_gpu[layer_id, gpu_block_id].copy_(
+ self.v_cache_gpu[gpu_block_id].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
@@ -356,8 +352,10 @@ class OffloadEngine:
"""
Async offload a block from GPU to CPU.

+ GPU cache has no layer dimension - layer_id is for CPU cache indexing.
+
Args:
- layer_id: Layer index
+ layer_id: Layer index (for CPU cache)
gpu_block_id: Source slot in GPU cache
cpu_block_id: Destination block in CPU cache

@@ -373,14 +371,13 @@ class OffloadEngine:
# Wait for any compute using this block
stream.wait_stream(self.compute_stream)

- # K cache
+ # GPU: no layer dimension, CPU: has layer dimension
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
- self.k_cache_gpu[layer_id, gpu_block_id],
+ self.k_cache_gpu[gpu_block_id],
non_blocking=True
)
- # V cache
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
- self.v_cache_gpu[layer_id, gpu_block_id],
+ self.v_cache_gpu[gpu_block_id],
non_blocking=True
)
event.record()
@@ -417,11 +414,10 @@ class OffloadEngine:
"""
Load CPU blocks to specific GPU slots for chunked decode.

- Uses the main GPU KV cache slots, not a separate temp buffer.
- This is the same mechanism as chunked prefill uses.
+ GPU cache has no layer dimension - layer_id is for CPU cache indexing.

Args:
- layer_id: Layer index
+ layer_id: Layer index (for CPU cache)
cpu_block_ids: List of CPU block IDs to load
gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
"""
@@ -434,12 +430,12 @@ class OffloadEngine:

with torch.cuda.stream(stream):
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
- # Copy from pinned CPU memory to GPU KV cache slot
- self.k_cache_gpu[layer_id, gpu_slot].copy_(
+ # GPU: no layer dimension, CPU: has layer dimension
+ self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
- self.v_cache_gpu[layer_id, gpu_slot].copy_(
+ self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
@@ -456,8 +452,10 @@ class OffloadEngine:
"""
Async version: Load CPU blocks to GPU slots.

+ GPU cache has no layer dimension - layer_id is for CPU cache indexing.
+
Args:
- layer_id: Layer index
+ layer_id: Layer index (for CPU cache)
cpu_block_ids: List of CPU block IDs to load
gpu_slot_ids: List of GPU slot IDs to load into

@@ -474,11 +472,12 @@ class OffloadEngine:

with torch.cuda.stream(stream):
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
- self.k_cache_gpu[layer_id, gpu_slot].copy_(
+ # GPU: no layer dimension, CPU: has layer dimension
+ self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
- self.v_cache_gpu[layer_id, gpu_slot].copy_(
+ self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
@@ -486,44 +485,8 @@ class OffloadEngine:

return event

- def load_cpu_blocks_to_gpu_slots_all_layers(
- self,
- cpu_block_ids: List[int],
- gpu_slot_ids: List[int],
- ) -> None:
- """
- Load CPU blocks to GPU slots for ALL layers at once.
-
- More efficient than per-layer loading when we know the mapping upfront.
-
- Args:
- cpu_block_ids: List of CPU block IDs to load
- gpu_slot_ids: List of GPU slot IDs to load into
- """
- assert len(cpu_block_ids) == len(gpu_slot_ids)
-
- if cpu_block_ids:
- logger.debug(f"H2D all layers: CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
-
- stream = self._get_next_stream()
-
- with torch.cuda.stream(stream):
- for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
- # Copy all layers at once using sgDMA
- memcpy_2d_async(
- self.k_cache_gpu[:, gpu_slot],
- self.k_cache_cpu[:, cpu_block_id],
- self.gpu_pitch, self.cpu_pitch, self.width, self.height,
- "h2d", stream=stream
- )
- memcpy_2d_async(
- self.v_cache_gpu[:, gpu_slot],
- self.v_cache_cpu[:, cpu_block_id],
- self.gpu_pitch, self.cpu_pitch, self.width, self.height,
- "h2d", stream=stream
- )
-
- stream.synchronize()
+ # NOTE: load_cpu_blocks_to_gpu_slots_all_layers removed - GPU cache no longer has
+ # layer dimension. Each GPU slot holds data for ONE layer at a time.

# ========== Synchronization methods ==========

@@ -548,21 +511,27 @@ class OffloadEngine:

def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
"""
- Get GPU K/V cache tensors for a specific layer.
+ Get GPU K/V cache tensors for attention layer.

+ NOTE: GPU cache has no layer dimension - all layers share the same slots.
+ The layer_id parameter is kept for API compatibility but not used.
+
Returns:
- (k_cache, v_cache) tensors for the layer
+ (k_cache, v_cache) tensors
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
"""
- return self.k_cache_gpu[layer_id], self.v_cache_gpu[layer_id]
+ # GPU cache is shared across all layers (no layer dimension)
+ return self.k_cache_gpu, self.v_cache_gpu

def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]:
"""
Get full GPU K/V cache tensors.

+ NOTE: GPU cache has no layer dimension in the new architecture.
+
Returns:
(k_cache, v_cache) tensors
- Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
+ Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
"""
return self.k_cache_gpu, self.v_cache_gpu

@@ -668,7 +637,7 @@ class OffloadEngine:

# ----- Per-slot Per-layer loading methods -----

- def record_slot_compute_done(self, slot_idx: int, layer_id: int) -> None:
+ def record_slot_compute_done(self, slot_idx: int) -> None:
"""
Record that computation using this slot's data is done.

@@ -677,22 +646,23 @@ class OffloadEngine:

Args:
slot_idx: GPU slot index that was just used for computation
- layer_id: Layer index
"""
- self.ring_slot_compute_done[slot_idx][layer_id].record()
+ self.ring_slot_compute_done[slot_idx].record()

def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async load a single CPU block to a ring buffer slot for one layer.

This is the core building block for ring buffer pipelining.
+ GPU cache has no layer dimension - slots are shared across all layers.
+ CPU cache still has layer dimension for persistent storage.

Before starting the transfer, waits for:
1. Any previous compute on this slot to complete
- 2. Any pending offload of this slot to complete

Args:
slot_idx: Target GPU slot index
- layer_id: Layer index to load
+ layer_id: Layer index to load (for CPU cache indexing)
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
@@ -704,150 +674,105 @@ class OffloadEngine:
with torch.cuda.stream(stream):
# Wait for previous compute on this slot to complete before overwriting
# This prevents data race: transfer must not start until attention finishes reading
- stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
+ stream.wait_event(self.ring_slot_compute_done[slot_idx])

# Also wait for any pending offload of this slot to complete
# This prevents race: load must not write GPU slot while offload is reading from it
- stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
+ stream.wait_event(self.ring_slot_offload_done[slot_idx])

- self.k_cache_gpu[layer_id, slot_idx].copy_(
+ # GPU: no layer dimension, CPU: has layer dimension
+ self.k_cache_gpu[slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
- self.v_cache_gpu[layer_id, slot_idx].copy_(
+ self.v_cache_gpu[slot_idx].copy_(
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
- self.ring_slot_ready[slot_idx][layer_id].record(stream)
+ self.ring_slot_ready[slot_idx].record(stream)
torch.cuda.nvtx.range_pop()

- def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
+ def wait_slot_layer(self, slot_idx: int) -> None:
"""
- Wait for a slot's loading to complete for a specific layer.
+ Wait for a slot's loading to complete.

Args:
slot_idx: GPU slot index to wait for
- layer_id: Layer index to wait for
"""
- self.compute_stream.wait_event(self.ring_slot_ready[slot_idx][layer_id])
+ self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])

- def load_to_slot_all_layers(self, slot_idx: int, cpu_block_id: int) -> None:
- """
- Async load a CPU block to a ring buffer slot for ALL layers.
-
- Args:
- slot_idx: Target GPU slot index
- cpu_block_id: Source CPU block ID
- """
- logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
-
- with torch.cuda.stream(self.transfer_stream_main):
- memcpy_2d_async(
- self.k_cache_gpu[:, slot_idx],
- self.k_cache_cpu[:, cpu_block_id],
- self.gpu_pitch, self.cpu_pitch, self.width, self.height,
- "h2d", stream=self.transfer_stream_main
- )
- memcpy_2d_async(
- self.v_cache_gpu[:, slot_idx],
- self.v_cache_cpu[:, cpu_block_id],
- self.gpu_pitch, self.cpu_pitch, self.width, self.height,
- "h2d", stream=self.transfer_stream_main
- )
- self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main)
-
- def wait_slot_all_layers(self, slot_idx: int) -> None:
- """Wait for a slot's loading to complete for ALL layers."""
- self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx])
+ # NOTE: load_to_slot_all_layers removed - GPU cache no longer has layer dimension.
+ # Each GPU slot holds data for ONE layer at a time. Layers execute sequentially,
+ # reusing the same GPU slots.

# ----- Slot offload methods -----

- def offload_slot_to_cpu(self, slot_idx: int, cpu_block_id: int) -> None:
- """
- Async offload a ring buffer slot to CPU (all layers).
-
- Args:
- slot_idx: Source GPU slot index
- cpu_block_id: Target CPU block ID
- """
- logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")
-
- torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
+ # NOTE: offload_slot_to_cpu (all-layers) removed - GPU cache no longer has layer dimension.
+ # Use offload_slot_layer_to_cpu for per-layer offloading.
+
+ def wait_slot_offload(self, slot_idx: int) -> None:
+ """Wait for slot offload to complete."""
+ self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
+
+ def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
+ """
+ Async offload a ring buffer slot to CPU for one layer.
+
+ GPU cache has no layer dimension, so we copy from GPU slot to the
+ specific layer in CPU cache.
+
+ Args:
+ slot_idx: Source GPU slot index
+ layer_id: Target layer in CPU cache
+ cpu_block_id: Target CPU block ID
+ """
+ logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
+
+ torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for both compute_stream and default stream
# - compute_stream: for flash attention operations
# - default_stream: for store_kvcache which runs on default stream
self.transfer_stream_main.wait_stream(self.compute_stream)
self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
- memcpy_2d_async(
- self.k_cache_cpu[:, cpu_block_id],
- self.k_cache_gpu[:, slot_idx],
- self.cpu_pitch, self.gpu_pitch, self.width, self.height,
- "d2h", stream=self.transfer_stream_main
- )
- memcpy_2d_async(
- self.v_cache_cpu[:, cpu_block_id],
- self.v_cache_gpu[:, slot_idx],
- self.cpu_pitch, self.gpu_pitch, self.width, self.height,
- "d2h", stream=self.transfer_stream_main
- )
- self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
- torch.cuda.nvtx.range_pop()
-
- def wait_slot_offload(self, slot_idx: int) -> None:
- """Wait for slot offload to complete."""
- self.compute_stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
-
- def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
- """
- Async offload a ring buffer slot to CPU for one layer.
-
- Args:
- slot_idx: Source GPU slot index
- layer_id: Layer index to offload
- cpu_block_id: Target CPU block ID
- """
- with torch.cuda.stream(self.transfer_stream_main):
- # Wait for both compute_stream and default stream
- self.transfer_stream_main.wait_stream(self.compute_stream)
- self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
+ # GPU: no layer dimension, CPU: has layer dimension
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
- self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
+ self.k_cache_gpu[slot_idx], non_blocking=True
)
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
- self.v_cache_gpu[layer_id, slot_idx], non_blocking=True
+ self.v_cache_gpu[slot_idx], non_blocking=True
)
- self.ring_slot_offload_done[slot_idx][layer_id].record(self.transfer_stream_main)
+ self.ring_slot_offload_done[slot_idx].record(self.transfer_stream_main)
+ torch.cuda.nvtx.range_pop()

- def wait_slot_offload_layer(self, slot_idx: int, layer_id: int) -> None:
- """Wait for slot offload to complete for a specific layer."""
- self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx][layer_id])

# ----- KV access methods for ring buffer -----

- def get_kv_for_slot(self, slot_idx: int, layer_id: int) -> Tuple[Tensor, Tensor]:
+ def get_kv_for_slot(self, slot_idx: int) -> Tuple[Tensor, Tensor]:
"""
Get KV for a single ring buffer slot.

+ GPU cache has no layer dimension - slots contain data for whatever
+ layer was most recently loaded.
+
Args:
slot_idx: GPU slot index
- layer_id: Layer ID

Returns:
(k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim]
"""
- k = self.k_cache_gpu[layer_id, slot_idx].unsqueeze(0) # [1, block_size, heads, dim]
- v = self.v_cache_gpu[layer_id, slot_idx].unsqueeze(0)
+ k = self.k_cache_gpu[slot_idx].unsqueeze(0) # [1, block_size, heads, dim]
+ v = self.v_cache_gpu[slot_idx].unsqueeze(0)
return k, v

def get_kv_for_slots(
self,
- layer_id: int,
slot_indices: List[int],
) -> Tuple[Tensor, Tensor]:
"""
Get KV for multiple ring buffer slots.

+ GPU cache has no layer dimension - returns data from specified slots.
+
Args:
- layer_id: Layer ID
slot_indices: List of GPU slot indices

Returns:
@@ -855,92 +780,86 @@ class OffloadEngine:
"""
if not slot_indices:
return None, None
- k = self.k_cache_gpu[layer_id, slot_indices]
- v = self.v_cache_gpu[layer_id, slot_indices]
+ k = self.k_cache_gpu[slot_indices]
+ v = self.v_cache_gpu[slot_indices]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v

# ----- Decode slot methods (kept for decode phase) -----
+ # NOTE: For decode with CPU offload, the flow is per-layer:
+ # 1. Each layer stores to decode_slot (same GPU memory, reused)
+ # 2. Each layer offloads its data to CPU[layer_id, block_id]
+ # 3. Each layer loads prev blocks from CPU[layer_id] when needed

- def offload_decode_slot(self, cpu_block_id: int) -> None:
+ def offload_decode_slot_layer(self, layer_id: int, cpu_block_id: int) -> None:
"""
- Offload KV from decode slot (slot[0]) to CPU.
+ Offload KV from decode slot (slot[0]) to CPU for one layer.

Args:
+ layer_id: Layer ID
cpu_block_id: Target CPU block ID
"""
- logger.debug(f"Decode offload: GPU slot[{self.decode_slot}] -> CPU[{cpu_block_id}]")
-
- with torch.cuda.stream(self.transfer_stream_main):
- self.transfer_stream_main.wait_stream(self.compute_stream)
- memcpy_2d_async(
- self.k_cache_cpu[:, cpu_block_id],
- self.k_cache_gpu[:, self.decode_slot],
- self.cpu_pitch, self.gpu_pitch, self.width, self.height,
- "d2h", stream=self.transfer_stream_main
- )
- memcpy_2d_async(
- self.v_cache_cpu[:, cpu_block_id],
- self.v_cache_gpu[:, self.decode_slot],
- self.cpu_pitch, self.gpu_pitch, self.width, self.height,
- "d2h", stream=self.transfer_stream_main
- )
- self.decode_offload_done.record(self.transfer_stream_main)
+ # Reuse the existing per-layer offload method
+ self.offload_slot_layer_to_cpu(self.decode_slot, layer_id, cpu_block_id)

def wait_decode_offload(self) -> None:
"""Wait for decode slot offload to complete."""
- self.compute_stream.wait_event(self.decode_offload_done)
+ self.wait_slot_offload(self.decode_slot)

def get_kv_for_decode_slot(
self,
- layer_id: int,
pos_in_block: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV at specified position in decode slot.

+ GPU cache has no layer dimension - decode slot contains data for
+ whatever layer was most recently stored.
+
Args:
- layer_id: Layer ID
pos_in_block: Token position within block (0 to block_size-1)

Returns:
(k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
"""
- k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
- v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
+ k = self.k_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1]
+ v = self.v_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v

def get_kv_for_decode_slot_accumulated(
self,
- layer_id: int,
num_tokens: int,
) -> Tuple[Tensor, Tensor]:
"""
Get accumulated KV in decode slot (positions 0 to num_tokens-1).

+ GPU cache has no layer dimension - decode slot contains data for
+ whatever layer was most recently stored.
+
Args:
- layer_id: Layer ID
num_tokens: Number of accumulated tokens (1 to block_size)

Returns:
(k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
"""
- k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]
- v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
+ k = self.k_cache_gpu[self.decode_slot, :num_tokens]
+ v = self.v_cache_gpu[self.decode_slot, :num_tokens]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v

# ----- Legacy compatibility methods (for decode double-buffering) -----
+ # NOTE: GPU cache has no layer dimension. Layer ID is used for CPU cache indexing only.

def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.

Uses first half of decode_load_slots as 'compute' region.
+ GPU cache has no layer dimension - layer_id is for CPU cache indexing.
"""
if not cpu_block_ids:
return
@@ -953,26 +872,27 @@ class OffloadEngine:
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
- self.k_cache_gpu[layer_id, gpu_slot].copy_(
+ # GPU: no layer dimension, CPU: has layer dimension
+ self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
- self.v_cache_gpu[layer_id, gpu_slot].copy_(
+ self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
- self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
+ self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)

- def wait_compute_layer(self, layer_id: int) -> None:
+ def wait_compute_layer(self) -> None:
"""Legacy: Wait for 'compute' region loading."""
- half = max(1, len(self.decode_load_slots) // 2)
if self.decode_load_slots:
- self.wait_slot_layer(self.decode_load_slots[0], layer_id)
+ self.wait_slot_layer(self.decode_load_slots[0])

def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.

Uses second half of decode_load_slots as 'prefetch' region.
+ GPU cache has no layer dimension - layer_id is for CPU cache indexing.
"""
if not cpu_block_ids:
return
@@ -987,37 +907,36 @@ class OffloadEngine:
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
- self.k_cache_gpu[layer_id, gpu_slot].copy_(
+ # GPU: no layer dimension, CPU: has layer dimension
+ self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
- self.v_cache_gpu[layer_id, gpu_slot].copy_(
+ self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
- self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
+ self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)

- def wait_prefetch_layer(self, layer_id: int) -> None:
+ def wait_prefetch_layer(self) -> None:
"""Legacy: Wait for 'prefetch' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if slots:
- self.wait_slot_layer(slots[0], layer_id)
+ self.wait_slot_layer(slots[0])
elif self.decode_load_slots:
- self.wait_slot_layer(self.decode_load_slots[0], layer_id)
+ self.wait_slot_layer(self.decode_load_slots[0])

def get_kv_for_compute(
self,
- layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half][:num_blocks]
- return self.get_kv_for_slots(layer_id, slots)
+ return self.get_kv_for_slots(slots)

def get_kv_for_prefetch(
self,
- layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
@@ -1026,7 +945,7 @@ class OffloadEngine:
if not slots:
slots = self.decode_load_slots
slots = slots[:num_blocks]
- return self.get_kv_for_slots(layer_id, slots)
+ return self.get_kv_for_slots(slots)

# ========== Debug Hook Interface ==========
#
@@ -1082,12 +1001,15 @@ class OffloadEngine:
Call all registered debug hooks with loaded tensor (internal use).

Called by attention.py after wait_slot_layer completes.
+ GPU cache has no layer dimension - slot contains data for the layer
+ that was just loaded.
"""
if not self._debug_mode or not self._debug_hooks:
return

- k = self.k_cache_gpu[layer_id, slot_idx]
- v = self.v_cache_gpu[layer_id, slot_idx]
+ # GPU cache has no layer dimension
+ k = self.k_cache_gpu[slot_idx]
+ v = self.v_cache_gpu[slot_idx]

for hook in self._debug_hooks:
try:
@@ -201,6 +201,18 @@ class Attention(nn.Module):

torch.cuda.nvtx.range_pop() # ChunkedPrefill

+ # Per-layer offload: In new GPU cache architecture (no layer dimension),
+ # each layer must offload its KV to CPU before next layer overwrites the GPU slot.
+ if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
+ offload_engine = kvcache_manager.offload_engine
+ write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
+ seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
+ if seq is not None:
+ cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
+ if current_chunk_idx < len(cpu_block_ids):
+ cpu_block_id = cpu_block_ids[current_chunk_idx]
+ offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
+
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)

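Because the ring-buffer slots are shared by every layer, the KV a layer just wrote has to reach the CPU cache before the next layer's forward pass reuses the slot. A condensed sketch of that per-layer hand-off, using the method names from this diff; the wrapping function and the run_attention callable are illustrative placeholders, not part of the codebase.

def forward_one_layer(layer_id, chunk_idx, write_slot, cpu_block_ids,
                      offload_engine, run_attention):
    out = run_attention()  # this layer's attention for the current prefill chunk
    # Persist this layer's KV for the chunk before the next layer, which shares
    # the same GPU slot, overwrites it.
    if chunk_idx < len(cpu_block_ids):
        offload_engine.offload_slot_layer_to_cpu(write_slot, layer_id, cpu_block_ids[chunk_idx])
    return out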
@@ -219,11 +231,11 @@ class Attention(nn.Module):
for block_idx, cpu_block_id in enumerate(cpu_block_table):
# Load to slot 0 (single slot)
offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
- offload_engine.wait_slot_layer(0, self.layer_id)
+ offload_engine.wait_slot_layer(0)

# IMPORTANT: Must use compute_stream to match wait_slot_layer
with torch.cuda.stream(compute_stream):
- prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
+ prev_k, prev_v = offload_engine.get_kv_for_slot(0)

prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
@@ -289,21 +301,21 @@ class Attention(nn.Module):
for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
- offload_engine.wait_slot_layer(slot, self.layer_id)
+ offload_engine.wait_slot_layer(slot)

with torch.cuda.stream(compute_stream):
# Debug: call hooks on compute_stream (synchronized with transfer)
if offload_engine.debug_mode:
offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)

- prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
+ prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
# Record compute done so next load can safely reuse this slot
- offload_engine.record_slot_compute_done(slot, self.layer_id)
+ offload_engine.record_slot_compute_done(slot)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
@@ -332,7 +344,7 @@ class Attention(nn.Module):
cpu_block_id = cpu_block_table[block_idx]

# Wait for current slot's transfer to complete (on compute_stream)
- offload_engine.wait_slot_layer(current_slot, self.layer_id)
+ offload_engine.wait_slot_layer(current_slot)

# Compute attention on current slot's data
# IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
@@ -342,7 +354,7 @@ class Attention(nn.Module):
offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)

torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
- prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
+ prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
@@ -351,7 +363,7 @@ class Attention(nn.Module):
torch.cuda.nvtx.range_pop()

# Record compute done - this allows the next transfer to safely overwrite this slot
- offload_engine.record_slot_compute_done(current_slot, self.layer_id)
+ offload_engine.record_slot_compute_done(current_slot)

# Immediately start loading the NEXT block into this slot (if more blocks remain)
# Key insight: reuse current_slot immediately after compute is done!
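Putting the pieces together, the wait / compute / record / reload order in this hunk forms a simple software pipeline over previously offloaded blocks. Below is a hedged sketch of that loop with two ring slots; load_to_slot_layer, wait_slot_layer, get_kv_for_slot and record_slot_compute_done are the engine methods from this diff, while attend and merge stand in for flash_attn_with_lse and the LSE-based accumulation.

def attend_prev_blocks(engine, layer_id, cpu_block_table, q, attend, merge):
    slots = [0, 1]  # two ring-buffer slots are enough to overlap transfer and compute
    for i, cpu_block_id in enumerate(cpu_block_table[:len(slots)]):
        engine.load_to_slot_layer(slots[i], layer_id, cpu_block_id)  # prime the pipeline

    acc = None
    for block_idx, cpu_block_id in enumerate(cpu_block_table):
        slot = slots[block_idx % len(slots)]
        engine.wait_slot_layer(slot)            # H2D for this block has finished
        k, v = engine.get_kv_for_slot(slot)
        out = attend(q, k, v)                   # attention over this block's KV
        engine.record_slot_compute_done(slot)   # slot may now be overwritten
        nxt = block_idx + len(slots)
        if nxt < len(cpu_block_table):          # immediately reuse this slot for a later block
            engine.load_to_slot_layer(slot, layer_id, cpu_block_table[nxt])
        acc = out if acc is None else merge(acc, out)
    return acc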
@@ -464,13 +476,9 @@ class Attention(nn.Module):
with torch.cuda.stream(compute_stream):
# Get KV from current buffer FIRST, before prefetching overwrites it
if use_compute:
- k_chunk, v_chunk = offload_engine.get_kv_for_compute(
- self.layer_id, num_blocks_in_chunk
- )
+ k_chunk, v_chunk = offload_engine.get_kv_for_compute(num_blocks_in_chunk)
else:
- k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
- self.layer_id, num_blocks_in_chunk
- )
+ k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(num_blocks_in_chunk)

# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
@@ -512,8 +520,9 @@ class Attention(nn.Module):

with torch.cuda.stream(compute_stream):
if num_accumulated > 0:
- decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
- decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
+ # GPU cache has no layer dimension
+ decode_k = offload_engine.k_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
+ decode_v = offload_engine.v_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_k = decode_k.unsqueeze(0)
decode_v = decode_v.unsqueeze(0)

@@ -30,9 +30,6 @@ def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor,
if layer_id != 0:
return

- if layer_id == 0:
- __import__('pdb').set_trace()
-
load_log.append({
"chunk_idx": current_chunk[0],
"cpu_block_id": cpu_block_id,
@@ -20,7 +20,6 @@ import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
- from nanovllm.kvcache.debug_utils import dump_block_state


# ============================================================
@@ -97,9 +96,9 @@ def make_verified_load_to_slot_layer(original_func, offload_engine):
# cpu_block_id == chunk_idx in our sequential test
expected_k, expected_v = get_expected_pattern(cpu_block_id)

- # Read GPU slot data
- gpu_k = offload_engine.k_cache_gpu[layer_id, slot_idx]
- gpu_v = offload_engine.v_cache_gpu[layer_id, slot_idx]
+ # Read GPU slot data (GPU cache has no layer dimension)
+ gpu_k = offload_engine.k_cache_gpu[slot_idx]
+ gpu_v = offload_engine.v_cache_gpu[slot_idx]

actual_k = gpu_k.float().mean().item()
actual_v = gpu_v.float().mean().item()
@@ -306,9 +305,9 @@ def make_gpu_write_verification_post_hook(layer_id: int):
# Get expected pattern for current chunk
expected_k, expected_v = get_expected_pattern(chunk_idx)

- # Verify write_slot contains current chunk's data
- gpu_k = oe.k_cache_gpu[layer_id, write_slot]
- gpu_v = oe.v_cache_gpu[layer_id, write_slot]
+ # Verify write_slot contains current chunk's data (GPU cache has no layer dimension)
+ gpu_k = oe.k_cache_gpu[write_slot]
+ gpu_v = oe.v_cache_gpu[write_slot]

actual_k_mean = gpu_k.float().mean().item()
actual_v_mean = gpu_v.float().mean().item()
@@ -419,9 +418,9 @@ def make_post_chunk_verification_hook(layer_id: int):

expected_k, expected_v = get_expected_pattern(chunk_idx)

- # Check GPU ring buffer
- gpu_k = oe.k_cache_gpu[layer_id, ring_slot]
- gpu_v = oe.v_cache_gpu[layer_id, ring_slot]
+ # Check GPU ring buffer (GPU cache has no layer dimension)
+ gpu_k = oe.k_cache_gpu[ring_slot]
+ gpu_v = oe.v_cache_gpu[ring_slot]

k_ok, k_err = check_pattern(gpu_k, expected_k, f"GPU K slot {ring_slot}")
v_ok, v_err = check_pattern(gpu_v, expected_v, f"GPU V slot {ring_slot}")