[WIP] Before needle fix.

Zijie Tian
2025-12-31 23:35:25 +08:00
parent ccd1b3d4ab
commit 30462fe89a
5 changed files with 212 additions and 290 deletions

View File

@@ -489,24 +489,15 @@ class ModelRunner:
            logical_id = seq.block_table[block_idx]
            self.kvcache_manager.prefilled_blocks.add(logical_id)

-           # Offload this chunk's ring buffer slot to CPU (async)
+           # NOTE: Per-layer offloading is now done in attention.forward
+           # Each layer offloads its KV to CPU immediately after computing attention.
+           # We just need to wait for the last offload to complete before reusing the slot.
            if block_idx < len(cpu_block_ids):
-               cpu_block_id = cpu_block_ids[block_idx]
-
-               # Call sparse policy hook before offload (to capture metadata)
-               sparse_policy = self.kvcache_manager.sparse_policy
-               if sparse_policy is not None:
-                   num_tokens = chunk_end - chunk_start
-                   for layer_id in range(offload_engine.num_layers):
-                       k_cache = offload_engine.k_cache_gpu[layer_id, write_slot, :num_tokens]
-                       sparse_policy.on_block_offloaded(
-                           cpu_block_id=cpu_block_id,
-                           layer_id=layer_id,
-                           k_cache=k_cache,
-                           num_valid_tokens=num_tokens,
-                       )
-
-               offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)
+               # TODO: Sparse policy hook needs update for new GPU cache architecture
+               # The GPU cache no longer has layer dimension, so we can't access
+               # k_cache_gpu[layer_id, write_slot]. Sparse policy should be called
+               # in attention.forward after per-layer offload.
+               pass

            # Wait for offload to complete before next chunk
            # (slot will be reused after N chunks)
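A minimal sketch of where the sparse-policy hook could move, following the TODO above. It reuses the on_block_offloaded() signature from the removed code; the exact placement inside attention.forward is an assumption, not something this commit implements.

    # Hypothetical relocation into attention.forward, right after the per-layer offload.
    # (Placement is an assumption; the hook signature is taken from the removed ModelRunner code.)
    sparse_policy = kvcache_manager.sparse_policy
    if sparse_policy is not None:
        # GPU cache has no layer dimension: index the slot directly
        k_cache = offload_engine.k_cache_gpu[write_slot, :num_tokens]
        sparse_policy.on_block_offloaded(
            cpu_block_id=cpu_block_id,
            layer_id=self.layer_id,
            k_cache=k_cache,
            num_valid_tokens=num_tokens,
        )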
@@ -628,7 +619,11 @@ class ModelRunner:
        if pos_in_block == self.block_size - 1:
            last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
            if last_cpu_block >= 0:
-               offload_engine.offload_decode_slot(last_cpu_block)
+               # TODO: In new GPU cache architecture (no layer dimension),
+               # decode offload should be done per-layer in attention.forward.
+               # For now, offload all layers sequentially.
+               for layer_id in range(offload_engine.num_layers):
+                   offload_engine.offload_decode_slot_layer(layer_id, last_cpu_block)
                offload_engine.wait_all_offload_done()
                # Reset decode start position for next block
                self.kvcache_manager.reset_decode_start_pos(seq)

View File

@@ -67,14 +67,19 @@ class OffloadEngine:
        self.block_numel = block_size * self.kv_dim

        # ========== sgDMA pitch parameters for strided transfers ==========
+       # CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
+       # GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim] (no layer dim)
+       # For CPU-to-GPU transfer (H2D): copy single layer, single block at a time
+       # For all-layer CPU operations (D2H offload to all layers): use sgDMA
        self.dtype_size = dtype.itemsize
+       # CPU pitch: stride between layers in CPU cache (for all-layer operations)
        self.cpu_pitch = num_cpu_blocks * self.block_numel * self.dtype_size
-       self.gpu_pitch = num_gpu_blocks * self.block_numel * self.dtype_size
-       self.width = self.block_numel * self.dtype_size
-       self.height = num_layers
-       logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, gpu_pitch={self.gpu_pitch}, "
-                   f"width={self.width}, height={self.height}")
+       # GPU has no layer dimension, so single block transfer is contiguous
+       self.gpu_block_bytes = self.block_numel * self.dtype_size
+       self.height = num_layers  # For CPU all-layer operations
+       logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, "
+                   f"gpu_block_bytes={self.gpu_block_bytes}, height={self.height}")

        # ========== Unified Ring Buffer configuration ==========
        # Constraint checks
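To make the pitch math concrete, a small worked example under assumed dimensions (block_size=256, kv_heads=8, head_dim=128, fp16, num_cpu_blocks=1024, num_layers=28 are illustrative values, not taken from this commit):

    # Illustrative numbers only; the formulas mirror the constructor above.
    block_size, num_kv_heads, head_dim = 256, 8, 128
    num_cpu_blocks, num_layers, dtype_size = 1024, 28, 2  # fp16

    kv_dim = num_kv_heads * head_dim                        # 1024
    block_numel = block_size * kv_dim                       # 262,144 elements per block
    cpu_pitch = num_cpu_blocks * block_numel * dtype_size   # byte stride between layer l and l+1 of the same CPU block
    gpu_block_bytes = block_numel * dtype_size              # one contiguous GPU block: 512 KiB here
    height = num_layers                                     # rows copied by an all-layer sgDMA transfer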
@@ -100,14 +105,16 @@ class OffloadEngine:
logger.info(f" Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading") logger.info(f" Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading")
# ========== Fixed-address GPU KV cache ========== # ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim] # Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
# Use zeros initialization to avoid uninitialized memory issues # NOTE: No num_layers dimension! GPU slots are shared across layers.
# Each layer reuses the same slots (layers execute sequentially).
# This saves 28x GPU memory compared to per-layer allocation.
self.k_cache_gpu = torch.zeros( self.k_cache_gpu = torch.zeros(
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim, num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda" dtype=dtype, device="cuda"
) )
self.v_cache_gpu = torch.zeros( self.v_cache_gpu = torch.zeros(
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim, num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda" dtype=dtype, device="cuda"
) )
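The "28x" figure in the comment presumably corresponds to num_layers of the model under test; in general the saving factor is simply num_layers. A quick back-of-the-envelope check with the same illustrative sizes as above:

    # Old: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
    # New: [num_gpu_blocks, block_size, kv_heads, head_dim]
    num_layers, num_gpu_blocks = 28, 16                     # illustrative values
    per_block_bytes = 256 * 8 * 128 * 2                     # block_size * kv_heads * head_dim * fp16
    old_bytes = num_layers * num_gpu_blocks * per_block_bytes * 2   # K and V
    new_bytes = num_gpu_blocks * per_block_bytes * 2
    assert old_bytes // new_bytes == num_layers             # saving factor == num_layers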
@@ -159,35 +166,23 @@ class OffloadEngine:
        # Decode offload event
        self.decode_offload_done = torch.cuda.Event()

-       # ========== Per-slot Per-layer events for ring buffer ==========
-       # ring_slot_ready[slot_idx][layer_id] = CUDA Event for H2D completion
-       # ring_slot_offload_done[slot_idx][layer_id] = CUDA Event for D2H completion
-       self.ring_slot_ready = [
-           [torch.cuda.Event() for _ in range(num_layers)]
-           for _ in range(self.num_ring_slots)
-       ]
-       self.ring_slot_offload_done = [
-           [torch.cuda.Event() for _ in range(num_layers)]
-           for _ in range(self.num_ring_slots)
-       ]
-       # Per-slot events for all-layer operations (used in some legacy paths)
-       self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
-       self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
-
-       # ========== Per-slot Per-layer compute_done events for async pipeline ==========
-       # ring_slot_compute_done[slot_idx][layer_id] = CUDA Event for compute completion
-       # This is used to ensure we don't overwrite data before it's been read by attention
-       self.ring_slot_compute_done = [
-           [torch.cuda.Event() for _ in range(num_layers)]
-           for _ in range(self.num_ring_slots)
-       ]
+       # ========== Per-slot events for ring buffer ==========
+       # Since GPU cache has no layer dimension and layers execute sequentially,
+       # we only need per-slot events (not per-slot per-layer).
+       # ring_slot_ready[slot_idx] = CUDA Event for H2D completion
+       # ring_slot_offload_done[slot_idx] = CUDA Event for D2H completion
+       self.ring_slot_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
+       self.ring_slot_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
+
+       # ========== Per-slot compute_done events for async pipeline ==========
+       # ring_slot_compute_done[slot_idx] = CUDA Event for compute completion
+       # This ensures we don't overwrite data before it's been read by attention
+       self.ring_slot_compute_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]

        # Initialize all compute_done events (record them once)
        # This prevents undefined behavior on first load_to_slot_layer call
        for slot_idx in range(self.num_ring_slots):
-           for layer_id in range(num_layers):
-               self.ring_slot_compute_done[slot_idx][layer_id].record()
+           self.ring_slot_compute_done[slot_idx].record()
        torch.cuda.synchronize()  # Ensure all events are recorded

        # ========== Event tracking for async transfers ==========
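The three per-slot events implement a simple handshake around each GPU slot: a new H2D load must wait for the last compute and the last D2H offload on that slot, compute waits for the load, and offload waits for compute. A stripped-down, standalone illustration of that handshake with plain CUDA streams and events (a sketch of the pattern, not the engine's actual code; sizes are arbitrary):

    import torch

    transfer = torch.cuda.Stream()
    compute = torch.cuda.Stream()
    ready, compute_done, offload_done = (torch.cuda.Event() for _ in range(3))
    compute_done.record(); offload_done.record()      # pre-record so first-use waits are well defined
    gpu_slot = torch.empty(256, 8, 128, device="cuda", dtype=torch.float16)
    cpu_block = torch.randn(256, 8, 128, dtype=torch.float16).pin_memory()

    with torch.cuda.stream(transfer):                 # analogous to load_to_slot_layer
        transfer.wait_event(compute_done)
        transfer.wait_event(offload_done)
        gpu_slot.copy_(cpu_block, non_blocking=True)
        ready.record(transfer)

    with torch.cuda.stream(compute):                  # attention reads the slot
        compute.wait_event(ready)
        out = gpu_slot.float().sum()                  # stand-in for flash attention
        compute_done.record(compute)

    with torch.cuda.stream(transfer):                 # analogous to offload_slot_layer_to_cpu
        transfer.wait_event(compute_done)
        cpu_block.copy_(gpu_slot, non_blocking=True)
        offload_done.record(transfer)
    torch.cuda.synchronize()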
@@ -204,23 +199,24 @@ class OffloadEngine:
        return stream

    # ========== CUDA Graph compatible methods ==========
+   # NOTE: These methods need to be updated for the new GPU cache architecture.
+   # GPU cache no longer has layer dimension, so gathered copy semantics change.
+   # For now, these are kept for reference but should not be used without updating.
    def gathered_h2d_layer(self, layer_id: int) -> None:
        """
        Execute gathered H2D copy for a single layer.

-       This method is CUDA Graph compatible - can be captured into a graph.
-       Before calling, update_gather_indices() must be called to set up
-       which CPU blocks to copy to which GPU slots.
-
-       Args:
-           layer_id: Layer index to transfer
+       WARNING: This method needs updating for new GPU cache architecture.
+       GPU cache no longer has layer dimension.
        """
+       # GPU cache has no layer dimension - use flat indexing
+       # Source is CPU[layer_id], dest is GPU (shared across layers)
        gathered_copy_kv(
            k_src=self.k_cache_cpu[layer_id],
            v_src=self.v_cache_cpu[layer_id],
-           k_dst=self.k_cache_gpu[layer_id],
-           v_dst=self.v_cache_gpu[layer_id],
+           k_dst=self.k_cache_gpu,  # No layer indexing
+           v_dst=self.v_cache_gpu,  # No layer indexing
            indices=self.gather_indices_gpu[layer_id],
        )
@@ -228,7 +224,8 @@ class OffloadEngine:
""" """
Execute gathered H2D copy for all layers. Execute gathered H2D copy for all layers.
CUDA Graph compatible - can be captured into a single graph. WARNING: In new architecture, GPU slots are shared across layers.
This method would overwrite slots multiple times. Not recommended.
""" """
for layer_id in range(self.num_layers): for layer_id in range(self.num_layers):
self.gathered_h2d_layer(layer_id) self.gathered_h2d_layer(layer_id)
@@ -297,10 +294,10 @@ class OffloadEngine:
""" """
Async prefetch a single block from CPU to GPU. Async prefetch a single block from CPU to GPU.
For use in prefill phase where CUDA graphs are not used. GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args: Args:
layer_id: Layer index layer_id: Layer index (for CPU cache)
cpu_block_id: Source block in CPU cache cpu_block_id: Source block in CPU cache
gpu_block_id: Destination slot in GPU cache gpu_block_id: Destination slot in GPU cache
@@ -313,13 +310,12 @@ class OffloadEngine:
logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]") logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")
with torch.cuda.stream(stream): with torch.cuda.stream(stream):
# K cache # GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[layer_id, gpu_block_id].copy_( self.k_cache_gpu[gpu_block_id].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True non_blocking=True
) )
# V cache self.v_cache_gpu[gpu_block_id].copy_(
self.v_cache_gpu[layer_id, gpu_block_id].copy_(
self.v_cache_cpu[layer_id, cpu_block_id], self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True non_blocking=True
) )
@@ -356,8 +352,10 @@ class OffloadEngine:
""" """
Async offload a block from GPU to CPU. Async offload a block from GPU to CPU.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args: Args:
layer_id: Layer index layer_id: Layer index (for CPU cache)
gpu_block_id: Source slot in GPU cache gpu_block_id: Source slot in GPU cache
cpu_block_id: Destination block in CPU cache cpu_block_id: Destination block in CPU cache
@@ -373,14 +371,13 @@ class OffloadEngine:
            # Wait for any compute using this block
            stream.wait_stream(self.compute_stream)

-           # K cache
+           # GPU: no layer dimension, CPU: has layer dimension
            self.k_cache_cpu[layer_id, cpu_block_id].copy_(
-               self.k_cache_gpu[layer_id, gpu_block_id],
+               self.k_cache_gpu[gpu_block_id],
                non_blocking=True
            )
-           # V cache
            self.v_cache_cpu[layer_id, cpu_block_id].copy_(
-               self.v_cache_gpu[layer_id, gpu_block_id],
+               self.v_cache_gpu[gpu_block_id],
                non_blocking=True
            )
            event.record()
@@ -417,11 +414,10 @@ class OffloadEngine:
""" """
Load CPU blocks to specific GPU slots for chunked decode. Load CPU blocks to specific GPU slots for chunked decode.
Uses the main GPU KV cache slots, not a separate temp buffer. GPU cache has no layer dimension - layer_id is for CPU cache indexing.
This is the same mechanism as chunked prefill uses.
Args: Args:
layer_id: Layer index layer_id: Layer index (for CPU cache)
cpu_block_ids: List of CPU block IDs to load cpu_block_ids: List of CPU block IDs to load
gpu_slot_ids: List of GPU slot IDs to load into (must be same length) gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
""" """
@@ -434,12 +430,12 @@ class OffloadEngine:
        with torch.cuda.stream(stream):
            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
-               # Copy from pinned CPU memory to GPU KV cache slot
-               self.k_cache_gpu[layer_id, gpu_slot].copy_(
+               # GPU: no layer dimension, CPU: has layer dimension
+               self.k_cache_gpu[gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True
                )
-               self.v_cache_gpu[layer_id, gpu_slot].copy_(
+               self.v_cache_gpu[gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True
                )
@@ -456,8 +452,10 @@ class OffloadEngine:
""" """
Async version: Load CPU blocks to GPU slots. Async version: Load CPU blocks to GPU slots.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args: Args:
layer_id: Layer index layer_id: Layer index (for CPU cache)
cpu_block_ids: List of CPU block IDs to load cpu_block_ids: List of CPU block IDs to load
gpu_slot_ids: List of GPU slot IDs to load into gpu_slot_ids: List of GPU slot IDs to load into
@@ -474,11 +472,12 @@ class OffloadEngine:
        with torch.cuda.stream(stream):
            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
-               self.k_cache_gpu[layer_id, gpu_slot].copy_(
+               # GPU: no layer dimension, CPU: has layer dimension
+               self.k_cache_gpu[gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True
                )
-               self.v_cache_gpu[layer_id, gpu_slot].copy_(
+               self.v_cache_gpu[gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True
                )
@@ -486,44 +485,8 @@ class OffloadEngine:
        return event

-   def load_cpu_blocks_to_gpu_slots_all_layers(
-       self,
-       cpu_block_ids: List[int],
-       gpu_slot_ids: List[int],
-   ) -> None:
-       """
-       Load CPU blocks to GPU slots for ALL layers at once.
-       More efficient than per-layer loading when we know the mapping upfront.
-
-       Args:
-           cpu_block_ids: List of CPU block IDs to load
-           gpu_slot_ids: List of GPU slot IDs to load into
-       """
-       assert len(cpu_block_ids) == len(gpu_slot_ids)
-       if cpu_block_ids:
-           logger.debug(f"H2D all layers: CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
-       stream = self._get_next_stream()
-       with torch.cuda.stream(stream):
-           for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
-               # Copy all layers at once using sgDMA
-               memcpy_2d_async(
-                   self.k_cache_gpu[:, gpu_slot],
-                   self.k_cache_cpu[:, cpu_block_id],
-                   self.gpu_pitch, self.cpu_pitch, self.width, self.height,
-                   "h2d", stream=stream
-               )
-               memcpy_2d_async(
-                   self.v_cache_gpu[:, gpu_slot],
-                   self.v_cache_cpu[:, cpu_block_id],
-                   self.gpu_pitch, self.cpu_pitch, self.width, self.height,
-                   "h2d", stream=stream
-               )
-       stream.synchronize()
+   # NOTE: load_cpu_blocks_to_gpu_slots_all_layers removed - GPU cache no longer has
+   # layer dimension. Each GPU slot holds data for ONE layer at a time.

    # ========== Synchronization methods ==========
@@ -548,21 +511,27 @@ class OffloadEngine:
    def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
        """
-       Get GPU K/V cache tensors for a specific layer.
+       Get GPU K/V cache tensors for attention layer.
+
+       NOTE: GPU cache has no layer dimension - all layers share the same slots.
+       The layer_id parameter is kept for API compatibility but not used.

        Returns:
-           (k_cache, v_cache) tensors for the layer
+           (k_cache, v_cache) tensors
            Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
        """
-       return self.k_cache_gpu[layer_id], self.v_cache_gpu[layer_id]
+       # GPU cache is shared across all layers (no layer dimension)
+       return self.k_cache_gpu, self.v_cache_gpu

    def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]:
        """
        Get full GPU K/V cache tensors.

+       NOTE: GPU cache has no layer dimension in the new architecture.
+
        Returns:
            (k_cache, v_cache) tensors
-           Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
+           Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
        """
        return self.k_cache_gpu, self.v_cache_gpu
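From a caller's perspective nothing changes syntactically: every layer still asks for "its" cache, but all layers now receive the same tensors. A hedged sketch of a hypothetical call site (not code from this commit):

    # Hypothetical call site in a decoder layer's forward (names follow this diff):
    k_cache, v_cache = offload_engine.get_layer_cache(self.layer_id)     # layer_id ignored internally
    assert k_cache.data_ptr() == offload_engine.k_cache_gpu.data_ptr()   # every layer sees the same storage
    # Shapes: [num_gpu_blocks, block_size, kv_heads, head_dim]; this layer's data must therefore be
    # offloaded to CPU[layer_id] before the next layer reuses these slots.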
@@ -668,7 +637,7 @@ class OffloadEngine:
    # ----- Per-slot Per-layer loading methods -----

-   def record_slot_compute_done(self, slot_idx: int, layer_id: int) -> None:
+   def record_slot_compute_done(self, slot_idx: int) -> None:
        """
        Record that computation using this slot's data is done.
@@ -677,22 +646,23 @@ class OffloadEngine:
        Args:
            slot_idx: GPU slot index that was just used for computation
-           layer_id: Layer index
        """
-       self.ring_slot_compute_done[slot_idx][layer_id].record()
+       self.ring_slot_compute_done[slot_idx].record()

    def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
        """
        Async load a single CPU block to a ring buffer slot for one layer.
        This is the core building block for ring buffer pipelining.

+       GPU cache has no layer dimension - slots are shared across all layers.
+       CPU cache still has layer dimension for persistent storage.
+
        Before starting the transfer, waits for:
        1. Any previous compute on this slot to complete
        2. Any pending offload of this slot to complete

        Args:
            slot_idx: Target GPU slot index
-           layer_id: Layer index to load
+           layer_id: Layer index to load (for CPU cache indexing)
            cpu_block_id: Source CPU block ID
        """
        logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
@@ -704,150 +674,105 @@ class OffloadEngine:
        with torch.cuda.stream(stream):
            # Wait for previous compute on this slot to complete before overwriting
            # This prevents data race: transfer must not start until attention finishes reading
-           stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
+           stream.wait_event(self.ring_slot_compute_done[slot_idx])
            # Also wait for any pending offload of this slot to complete
            # This prevents race: load must not write GPU slot while offload is reading from it
-           stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
-           self.k_cache_gpu[layer_id, slot_idx].copy_(
+           stream.wait_event(self.ring_slot_offload_done[slot_idx])
+           # GPU: no layer dimension, CPU: has layer dimension
+           self.k_cache_gpu[slot_idx].copy_(
                self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
            )
-           self.v_cache_gpu[layer_id, slot_idx].copy_(
+           self.v_cache_gpu[slot_idx].copy_(
                self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
            )
-           self.ring_slot_ready[slot_idx][layer_id].record(stream)
+           self.ring_slot_ready[slot_idx].record(stream)
        torch.cuda.nvtx.range_pop()

-   def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
+   def wait_slot_layer(self, slot_idx: int) -> None:
        """
-       Wait for a slot's loading to complete for a specific layer.
+       Wait for a slot's loading to complete.

        Args:
            slot_idx: GPU slot index to wait for
-           layer_id: Layer index to wait for
        """
-       self.compute_stream.wait_event(self.ring_slot_ready[slot_idx][layer_id])
+       self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])

-   def load_to_slot_all_layers(self, slot_idx: int, cpu_block_id: int) -> None:
-       """
-       Async load a CPU block to a ring buffer slot for ALL layers.
-
-       Args:
-           slot_idx: Target GPU slot index
-           cpu_block_id: Source CPU block ID
-       """
-       logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
-       with torch.cuda.stream(self.transfer_stream_main):
-           memcpy_2d_async(
-               self.k_cache_gpu[:, slot_idx],
-               self.k_cache_cpu[:, cpu_block_id],
-               self.gpu_pitch, self.cpu_pitch, self.width, self.height,
-               "h2d", stream=self.transfer_stream_main
-           )
-           memcpy_2d_async(
-               self.v_cache_gpu[:, slot_idx],
-               self.v_cache_cpu[:, cpu_block_id],
-               self.gpu_pitch, self.cpu_pitch, self.width, self.height,
-               "h2d", stream=self.transfer_stream_main
-           )
-           self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main)
-
-   def wait_slot_all_layers(self, slot_idx: int) -> None:
-       """Wait for a slot's loading to complete for ALL layers."""
-       self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx])
+   # NOTE: load_to_slot_all_layers removed - GPU cache no longer has layer dimension.
+   # Each GPU slot holds data for ONE layer at a time. Layers execute sequentially,
+   # reusing the same GPU slots.

    # ----- Slot offload methods -----

-   def offload_slot_to_cpu(self, slot_idx: int, cpu_block_id: int) -> None:
-       """
-       Async offload a ring buffer slot to CPU (all layers).
-
-       Args:
-           slot_idx: Source GPU slot index
-           cpu_block_id: Target CPU block ID
-       """
-       logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")
-       torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
-       with torch.cuda.stream(self.transfer_stream_main):
-           # Wait for both compute_stream and default stream
-           # - compute_stream: for flash attention operations
-           # - default_stream: for store_kvcache which runs on default stream
-           self.transfer_stream_main.wait_stream(self.compute_stream)
-           self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
-           memcpy_2d_async(
-               self.k_cache_cpu[:, cpu_block_id],
-               self.k_cache_gpu[:, slot_idx],
-               self.cpu_pitch, self.gpu_pitch, self.width, self.height,
-               "d2h", stream=self.transfer_stream_main
-           )
-           memcpy_2d_async(
-               self.v_cache_cpu[:, cpu_block_id],
-               self.v_cache_gpu[:, slot_idx],
-               self.cpu_pitch, self.gpu_pitch, self.width, self.height,
-               "d2h", stream=self.transfer_stream_main
-           )
-           self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
-       torch.cuda.nvtx.range_pop()
-
-   def wait_slot_offload(self, slot_idx: int) -> None:
-       """Wait for slot offload to complete."""
-       self.compute_stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
-
-   def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
-       """
-       Async offload a ring buffer slot to CPU for one layer.
-
-       Args:
-           slot_idx: Source GPU slot index
-           layer_id: Layer index to offload
-           cpu_block_id: Target CPU block ID
-       """
-       with torch.cuda.stream(self.transfer_stream_main):
-           # Wait for both compute_stream and default stream
-           self.transfer_stream_main.wait_stream(self.compute_stream)
-           self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
-           self.k_cache_cpu[layer_id, cpu_block_id].copy_(
-               self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
-           )
-           self.v_cache_cpu[layer_id, cpu_block_id].copy_(
-               self.v_cache_gpu[layer_id, slot_idx], non_blocking=True
-           )
-           self.ring_slot_offload_done[slot_idx][layer_id].record(self.transfer_stream_main)
-
-   def wait_slot_offload_layer(self, slot_idx: int, layer_id: int) -> None:
-       """Wait for slot offload to complete for a specific layer."""
-       self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx][layer_id])
+   # NOTE: offload_slot_to_cpu (all-layers) removed - GPU cache no longer has layer dimension.
+   # Use offload_slot_layer_to_cpu for per-layer offloading.
+
+   def wait_slot_offload(self, slot_idx: int) -> None:
+       """Wait for slot offload to complete."""
+       self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
+
+   def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
+       """
+       Async offload a ring buffer slot to CPU for one layer.
+
+       GPU cache has no layer dimension, so we copy from GPU slot to the
+       specific layer in CPU cache.
+
+       Args:
+           slot_idx: Source GPU slot index
+           layer_id: Target layer in CPU cache
+           cpu_block_id: Target CPU block ID
+       """
+       logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
+       torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
+       with torch.cuda.stream(self.transfer_stream_main):
+           # Wait for both compute_stream and default stream
+           # - compute_stream: for flash attention operations
+           # - default_stream: for store_kvcache which runs on default stream
+           self.transfer_stream_main.wait_stream(self.compute_stream)
+           self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
+           # GPU: no layer dimension, CPU: has layer dimension
+           self.k_cache_cpu[layer_id, cpu_block_id].copy_(
+               self.k_cache_gpu[slot_idx], non_blocking=True
+           )
+           self.v_cache_cpu[layer_id, cpu_block_id].copy_(
+               self.v_cache_gpu[slot_idx], non_blocking=True
+           )
+           self.ring_slot_offload_done[slot_idx].record(self.transfer_stream_main)
+       torch.cuda.nvtx.range_pop()

    # ----- KV access methods for ring buffer -----

-   def get_kv_for_slot(self, slot_idx: int, layer_id: int) -> Tuple[Tensor, Tensor]:
+   def get_kv_for_slot(self, slot_idx: int) -> Tuple[Tensor, Tensor]:
        """
        Get KV for a single ring buffer slot.

+       GPU cache has no layer dimension - slots contain data for whatever
+       layer was most recently loaded.
+
        Args:
            slot_idx: GPU slot index
-           layer_id: Layer ID

        Returns:
            (k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim]
        """
-       k = self.k_cache_gpu[layer_id, slot_idx].unsqueeze(0)  # [1, block_size, heads, dim]
-       v = self.v_cache_gpu[layer_id, slot_idx].unsqueeze(0)
+       k = self.k_cache_gpu[slot_idx].unsqueeze(0)  # [1, block_size, heads, dim]
+       v = self.v_cache_gpu[slot_idx].unsqueeze(0)
        return k, v
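The per-slot API above is meant to be driven as an N-slot pipeline: while attention consumes the block in one slot, the next block's H2D copy runs into another. A simplified driver loop modeled on the pattern in attention.py (a sketch only: the slot list name prefill_load_slots is an assumption, and the flash attention call plus output merging are elided):

    # Simplified pipeline over the previous blocks of one layer.
    slots = offload_engine.prefill_load_slots          # assumed name for slots reserved for loading
    num_slots = len(slots)

    # Prime the pipeline: start loading the first blocks.
    for i, cpu_block_id in enumerate(cpu_block_table[:num_slots]):
        offload_engine.load_to_slot_layer(slots[i], self.layer_id, cpu_block_id)

    for block_idx, cpu_block_id in enumerate(cpu_block_table):
        slot = slots[block_idx % num_slots]
        offload_engine.wait_slot_layer(slot)           # H2D for this block finished
        with torch.cuda.stream(compute_stream):
            k, v = offload_engine.get_kv_for_slot(slot)
            # ... flash attention over (q, k, v), merge partial outputs ...
            offload_engine.record_slot_compute_done(slot)
        nxt = block_idx + num_slots                    # reuse this slot for a future block right away
        if nxt < len(cpu_block_table):
            offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[nxt])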
    def get_kv_for_slots(
        self,
-       layer_id: int,
        slot_indices: List[int],
    ) -> Tuple[Tensor, Tensor]:
        """
        Get KV for multiple ring buffer slots.

+       GPU cache has no layer dimension - returns data from specified slots.
+
        Args:
-           layer_id: Layer ID
            slot_indices: List of GPU slot indices

        Returns:
@@ -855,92 +780,86 @@ class OffloadEngine:
""" """
if not slot_indices: if not slot_indices:
return None, None return None, None
k = self.k_cache_gpu[layer_id, slot_indices] k = self.k_cache_gpu[slot_indices]
v = self.v_cache_gpu[layer_id, slot_indices] v = self.v_cache_gpu[slot_indices]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim) k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim) v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v return k, v
# ----- Decode slot methods (kept for decode phase) ----- # ----- Decode slot methods (kept for decode phase) -----
# NOTE: For decode with CPU offload, the flow is per-layer:
# 1. Each layer stores to decode_slot (same GPU memory, reused)
# 2. Each layer offloads its data to CPU[layer_id, block_id]
# 3. Each layer loads prev blocks from CPU[layer_id] when needed
def offload_decode_slot(self, cpu_block_id: int) -> None: def offload_decode_slot_layer(self, layer_id: int, cpu_block_id: int) -> None:
""" """
Offload KV from decode slot (slot[0]) to CPU. Offload KV from decode slot (slot[0]) to CPU for one layer.
Args: Args:
layer_id: Layer ID
cpu_block_id: Target CPU block ID cpu_block_id: Target CPU block ID
""" """
logger.debug(f"Decode offload: GPU slot[{self.decode_slot}] -> CPU[{cpu_block_id}]") # Reuse the existing per-layer offload method
self.offload_slot_layer_to_cpu(self.decode_slot, layer_id, cpu_block_id)
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
memcpy_2d_async(
self.k_cache_cpu[:, cpu_block_id],
self.k_cache_gpu[:, self.decode_slot],
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
"d2h", stream=self.transfer_stream_main
)
memcpy_2d_async(
self.v_cache_cpu[:, cpu_block_id],
self.v_cache_gpu[:, self.decode_slot],
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
"d2h", stream=self.transfer_stream_main
)
self.decode_offload_done.record(self.transfer_stream_main)
def wait_decode_offload(self) -> None: def wait_decode_offload(self) -> None:
"""Wait for decode slot offload to complete.""" """Wait for decode slot offload to complete."""
self.compute_stream.wait_event(self.decode_offload_done) self.wait_slot_offload(self.decode_slot)
def get_kv_for_decode_slot( def get_kv_for_decode_slot(
self, self,
layer_id: int,
pos_in_block: int, pos_in_block: int,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
""" """
Get KV at specified position in decode slot. Get KV at specified position in decode slot.
GPU cache has no layer dimension - decode slot contains data for
whatever layer was most recently stored.
Args: Args:
layer_id: Layer ID
pos_in_block: Token position within block (0 to block_size-1) pos_in_block: Token position within block (0 to block_size-1)
Returns: Returns:
(k_cache, v_cache), shape: [1, 1, kv_heads, head_dim] (k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
""" """
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] k = self.k_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1]
v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] v = self.v_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1]
k = k.unsqueeze(0) k = k.unsqueeze(0)
v = v.unsqueeze(0) v = v.unsqueeze(0)
return k, v return k, v
def get_kv_for_decode_slot_accumulated( def get_kv_for_decode_slot_accumulated(
self, self,
layer_id: int,
num_tokens: int, num_tokens: int,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
""" """
Get accumulated KV in decode slot (positions 0 to num_tokens-1). Get accumulated KV in decode slot (positions 0 to num_tokens-1).
GPU cache has no layer dimension - decode slot contains data for
whatever layer was most recently stored.
Args: Args:
layer_id: Layer ID
num_tokens: Number of accumulated tokens (1 to block_size) num_tokens: Number of accumulated tokens (1 to block_size)
Returns: Returns:
(k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim] (k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
""" """
k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens] k = self.k_cache_gpu[self.decode_slot, :num_tokens]
v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens] v = self.v_cache_gpu[self.decode_slot, :num_tokens]
k = k.unsqueeze(0) k = k.unsqueeze(0)
v = v.unsqueeze(0) v = v.unsqueeze(0)
return k, v return k, v
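Putting the decode-slot methods together, one decode step of one layer now looks roughly like this. This is a sketch under the per-layer flow described in the NOTE above; the surrounding control flow (where the new token's K/V gets written, how prev blocks are attended) is an assumption, not code from this commit:

    # Hypothetical per-layer decode step.
    # 1. The layer appends its new token's K/V into decode_slot (shared GPU memory).
    # 2. While the block is partially filled, attention reads the accumulated tokens back.
    k_acc, v_acc = offload_engine.get_kv_for_decode_slot_accumulated(num_tokens=pos_in_block + 1)
    # ... attend over k_acc/v_acc plus earlier blocks streamed in via load_to_slot_layer ...
    # 3. Once the block is full, this layer's slice is offloaded to CPU[layer_id, block_id].
    if pos_in_block == block_size - 1:
        offload_engine.offload_decode_slot_layer(self.layer_id, last_cpu_block)
        offload_engine.wait_decode_offload()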
    # ----- Legacy compatibility methods (for decode double-buffering) -----
+   # NOTE: GPU cache has no layer dimension. Layer ID is used for CPU cache indexing only.

    def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
        """
        Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
        Uses first half of decode_load_slots as 'compute' region.
+       GPU cache has no layer dimension - layer_id is for CPU cache indexing.
        """
        if not cpu_block_ids:
            return
@@ -953,26 +872,27 @@ class OffloadEngine:
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = slots[i]
-               self.k_cache_gpu[layer_id, gpu_slot].copy_(
+               # GPU: no layer dimension, CPU: has layer dimension
+               self.k_cache_gpu[gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
-               self.v_cache_gpu[layer_id, gpu_slot].copy_(
+               self.v_cache_gpu[gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
            if num_to_load > 0:
-               self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
+               self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)

-   def wait_compute_layer(self, layer_id: int) -> None:
+   def wait_compute_layer(self) -> None:
        """Legacy: Wait for 'compute' region loading."""
-       half = max(1, len(self.decode_load_slots) // 2)
        if self.decode_load_slots:
-           self.wait_slot_layer(self.decode_load_slots[0], layer_id)
+           self.wait_slot_layer(self.decode_load_slots[0])

    def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
        """
        Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
        Uses second half of decode_load_slots as 'prefetch' region.
+       GPU cache has no layer dimension - layer_id is for CPU cache indexing.
        """
        if not cpu_block_ids:
            return
@@ -987,37 +907,36 @@ class OffloadEngine:
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = slots[i]
-               self.k_cache_gpu[layer_id, gpu_slot].copy_(
+               # GPU: no layer dimension, CPU: has layer dimension
+               self.k_cache_gpu[gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
-               self.v_cache_gpu[layer_id, gpu_slot].copy_(
+               self.v_cache_gpu[gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
            if num_to_load > 0:
-               self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
+               self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)

-   def wait_prefetch_layer(self, layer_id: int) -> None:
+   def wait_prefetch_layer(self) -> None:
        """Legacy: Wait for 'prefetch' region loading."""
        half = max(1, len(self.decode_load_slots) // 2)
        slots = self.decode_load_slots[half:]
        if slots:
-           self.wait_slot_layer(slots[0], layer_id)
+           self.wait_slot_layer(slots[0])
        elif self.decode_load_slots:
-           self.wait_slot_layer(self.decode_load_slots[0], layer_id)
+           self.wait_slot_layer(self.decode_load_slots[0])

    def get_kv_for_compute(
        self,
-       layer_id: int,
        num_blocks: int,
    ) -> Tuple[Tensor, Tensor]:
        """Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
        half = max(1, len(self.decode_load_slots) // 2)
        slots = self.decode_load_slots[:half][:num_blocks]
-       return self.get_kv_for_slots(layer_id, slots)
+       return self.get_kv_for_slots(slots)

    def get_kv_for_prefetch(
        self,
-       layer_id: int,
        num_blocks: int,
    ) -> Tuple[Tensor, Tensor]:
        """Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
@@ -1026,7 +945,7 @@ class OffloadEngine:
        if not slots:
            slots = self.decode_load_slots
        slots = slots[:num_blocks]
-       return self.get_kv_for_slots(layer_id, slots)
+       return self.get_kv_for_slots(slots)

    # ========== Debug Hook Interface ==========
    #
@@ -1082,12 +1001,15 @@ class OffloadEngine:
        Call all registered debug hooks with loaded tensor (internal use).
        Called by attention.py after wait_slot_layer completes.

+       GPU cache has no layer dimension - slot contains data for the layer
+       that was just loaded.
        """
        if not self._debug_mode or not self._debug_hooks:
            return
-       k = self.k_cache_gpu[layer_id, slot_idx]
-       v = self.v_cache_gpu[layer_id, slot_idx]
+       # GPU cache has no layer dimension
+       k = self.k_cache_gpu[slot_idx]
+       v = self.v_cache_gpu[slot_idx]
        for hook in self._debug_hooks:
            try:

View File

@@ -201,6 +201,18 @@ class Attention(nn.Module):
        torch.cuda.nvtx.range_pop()  # ChunkedPrefill

+       # Per-layer offload: In new GPU cache architecture (no layer dimension),
+       # each layer must offload its KV to CPU before next layer overwrites the GPU slot.
+       if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
+           offload_engine = kvcache_manager.offload_engine
+           write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
+           seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
+           if seq is not None:
+               cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
+               if current_chunk_idx < len(cpu_block_ids):
+                   cpu_block_id = cpu_block_ids[current_chunk_idx]
+                   offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)

        # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
        return final_o.squeeze(0)
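The chunked paths in this file combine per-block partial attention outputs using their log-sum-exp values (the o_acc/lse_acc accumulation below). For reference, the standard merge of two partials looks like this; a generic sketch of the formula, not the repository's own merge helper, and the tensor layout (lse broadcastable against o after unsqueeze(-1)) is an assumption that may need a transpose for real flash-attention outputs:

    import torch

    def merge_attn_partials(o1, lse1, o2, lse2):
        # Assumed layout: o*: [..., dim], lse*: [...] matching o's leading dims.
        lse = torch.logaddexp(lse1, lse2)              # combined normalizer
        w1 = torch.exp(lse1 - lse).unsqueeze(-1)       # weight of partial 1
        w2 = torch.exp(lse2 - lse).unsqueeze(-1)       # weight of partial 2
        return o1 * w1 + o2 * w2, lse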
@@ -219,11 +231,11 @@ class Attention(nn.Module):
            for block_idx, cpu_block_id in enumerate(cpu_block_table):
                # Load to slot 0 (single slot)
                offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
-               offload_engine.wait_slot_layer(0, self.layer_id)
+               offload_engine.wait_slot_layer(0)

                # IMPORTANT: Must use compute_stream to match wait_slot_layer
                with torch.cuda.stream(compute_stream):
-                   prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
+                   prev_k, prev_v = offload_engine.get_kv_for_slot(0)
                    prev_o, prev_lse = flash_attn_with_lse(
                        q_batched, prev_k, prev_v,
@@ -289,21 +301,21 @@ class Attention(nn.Module):
            for block_idx in range(num_blocks):
                cpu_block_id = cpu_block_table[block_idx]
                offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
-               offload_engine.wait_slot_layer(slot, self.layer_id)
+               offload_engine.wait_slot_layer(slot)

                with torch.cuda.stream(compute_stream):
                    # Debug: call hooks on compute_stream (synchronized with transfer)
                    if offload_engine.debug_mode:
                        offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)

-                   prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
+                   prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
                    prev_o, prev_lse = flash_attn_with_lse(
                        q_batched, prev_k, prev_v,
                        softmax_scale=self.scale,
                        causal=False,
                    )
                    # Record compute done so next load can safely reuse this slot
-                   offload_engine.record_slot_compute_done(slot, self.layer_id)
+                   offload_engine.record_slot_compute_done(slot)

                    if o_acc is None:
                        o_acc, lse_acc = prev_o, prev_lse
                    else:
@@ -332,7 +344,7 @@ class Attention(nn.Module):
                cpu_block_id = cpu_block_table[block_idx]

                # Wait for current slot's transfer to complete (on compute_stream)
-               offload_engine.wait_slot_layer(current_slot, self.layer_id)
+               offload_engine.wait_slot_layer(current_slot)

                # Compute attention on current slot's data
                # IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
@@ -342,7 +354,7 @@ class Attention(nn.Module):
                        offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)

                    torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
-                   prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
+                   prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
                    prev_o, prev_lse = flash_attn_with_lse(
                        q_batched, prev_k, prev_v,
                        softmax_scale=self.scale,
@@ -351,7 +363,7 @@ class Attention(nn.Module):
                    torch.cuda.nvtx.range_pop()

                    # Record compute done - this allows the next transfer to safely overwrite this slot
-                   offload_engine.record_slot_compute_done(current_slot, self.layer_id)
+                   offload_engine.record_slot_compute_done(current_slot)

                # Immediately start loading the NEXT block into this slot (if more blocks remain)
                # Key insight: reuse current_slot immediately after compute is done!
@@ -464,13 +476,9 @@ class Attention(nn.Module):
            with torch.cuda.stream(compute_stream):
                # Get KV from current buffer FIRST, before prefetching overwrites it
                if use_compute:
-                   k_chunk, v_chunk = offload_engine.get_kv_for_compute(
-                       self.layer_id, num_blocks_in_chunk
-                   )
+                   k_chunk, v_chunk = offload_engine.get_kv_for_compute(num_blocks_in_chunk)
                else:
-                   k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
-                       self.layer_id, num_blocks_in_chunk
-                   )
+                   k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(num_blocks_in_chunk)

                # Compute attention for this chunk
                o_chunk, lse_chunk = flash_attn_with_lse(
@@ -512,8 +520,9 @@ class Attention(nn.Module):
            with torch.cuda.stream(compute_stream):
                if num_accumulated > 0:
-                   decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
-                   decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
+                   # GPU cache has no layer dimension
+                   decode_k = offload_engine.k_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
+                   decode_v = offload_engine.v_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
                    decode_k = decode_k.unsqueeze(0)
                    decode_v = decode_v.unsqueeze(0)
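For reference, the shape flow of the decode-slot slice above, written out as a standalone check (block_size=256, kv_heads=8, head_dim=128 are illustrative assumptions):

    import torch

    block_size, kv_heads, head_dim = 256, 8, 128
    decode_slot_kv = torch.zeros(block_size, kv_heads, head_dim, device="cuda", dtype=torch.float16)
    start_pos, pos_in_block = 3, 7                           # tokens 3..7 accumulated since the last offload
    decode_k = decode_slot_kv[start_pos:pos_in_block + 1]    # [5, kv_heads, head_dim]
    decode_k = decode_k.unsqueeze(0)                         # [1, 5, kv_heads, head_dim] for flash attention
    assert decode_k.shape == (1, pos_in_block - start_pos + 1, kv_heads, head_dim)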

View File

@@ -30,9 +30,6 @@ def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor,
    if layer_id != 0:
        return

-   if layer_id == 0:
-       __import__('pdb').set_trace()
-
    load_log.append({
        "chunk_idx": current_chunk[0],
        "cpu_block_id": cpu_block_id,

View File

@@ -20,7 +20,6 @@ import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
-from nanovllm.kvcache.debug_utils import dump_block_state

# ============================================================
@@ -97,9 +96,9 @@ def make_verified_load_to_slot_layer(original_func, offload_engine):
        # cpu_block_id == chunk_idx in our sequential test
        expected_k, expected_v = get_expected_pattern(cpu_block_id)

-       # Read GPU slot data
-       gpu_k = offload_engine.k_cache_gpu[layer_id, slot_idx]
-       gpu_v = offload_engine.v_cache_gpu[layer_id, slot_idx]
+       # Read GPU slot data (GPU cache has no layer dimension)
+       gpu_k = offload_engine.k_cache_gpu[slot_idx]
+       gpu_v = offload_engine.v_cache_gpu[slot_idx]
        actual_k = gpu_k.float().mean().item()
        actual_v = gpu_v.float().mean().item()
@@ -306,9 +305,9 @@ def make_gpu_write_verification_post_hook(layer_id: int):
        # Get expected pattern for current chunk
        expected_k, expected_v = get_expected_pattern(chunk_idx)

-       # Verify write_slot contains current chunk's data
-       gpu_k = oe.k_cache_gpu[layer_id, write_slot]
-       gpu_v = oe.v_cache_gpu[layer_id, write_slot]
+       # Verify write_slot contains current chunk's data (GPU cache has no layer dimension)
+       gpu_k = oe.k_cache_gpu[write_slot]
+       gpu_v = oe.v_cache_gpu[write_slot]
        actual_k_mean = gpu_k.float().mean().item()
        actual_v_mean = gpu_v.float().mean().item()
@@ -419,9 +418,9 @@ def make_post_chunk_verification_hook(layer_id: int):
        expected_k, expected_v = get_expected_pattern(chunk_idx)

-       # Check GPU ring buffer
-       gpu_k = oe.k_cache_gpu[layer_id, ring_slot]
-       gpu_v = oe.v_cache_gpu[layer_id, ring_slot]
+       # Check GPU ring buffer (GPU cache has no layer dimension)
+       gpu_k = oe.k_cache_gpu[ring_slot]
+       gpu_v = oe.v_cache_gpu[ring_slot]
        k_ok, k_err = check_pattern(gpu_k, expected_k, f"GPU K slot {ring_slot}")
        v_ok, v_err = check_pattern(gpu_v, expected_v, f"GPU V slot {ring_slot}")
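The hooks in this test compare slot contents against a per-chunk fill pattern by mean value. A minimal self-contained version of that style of check, for orientation only (the real get_expected_pattern and check_pattern live in this test file and may differ in details):

    import torch

    def expected_pattern(chunk_idx: int) -> tuple[float, float]:
        # e.g. K filled with chunk_idx + 0.1, V with chunk_idx + 0.2 (illustrative scheme)
        return chunk_idx + 0.1, chunk_idx + 0.2

    def check_pattern(tensor: torch.Tensor, expected: float, name: str, atol: float = 1e-2):
        actual = tensor.float().mean().item()
        ok = abs(actual - expected) < atol
        return ok, None if ok else f"{name}: expected mean {expected:.3f}, got {actual:.3f}"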