[fix] Fix KV cache offload bug: load KV per layer in chunked attention.
@@ -155,6 +155,11 @@ class OffloadEngine:
         self.ping_offload_done = torch.cuda.Event()
         self.pong_offload_done = torch.cuda.Event()
 
+        # ========== Per-layer events for chunked attention ==========
+        # Each layer has its own event for synchronization
+        self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
+        self.prefetch_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
+
         # ========== Event tracking for async transfers ==========
         self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
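These per-layer events sit alongside the existing shared compute_ready/prefetch_ready events: with one torch.cuda.Event per layer, a layer's compute can be ordered after only its own transfer instead of the most recent global load. A minimal, self-contained sketch of that pattern (placeholder buffer and stream names, assumes a CUDA device; these are not the engine's actual attributes):

import torch

# Per-layer event synchronization: record on the transfer stream after a
# layer's copies are issued, and only that layer's consumer waits on it.
num_layers = 4
transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()
layer_ready = [torch.cuda.Event() for _ in range(num_layers)]

gpu_buf = [torch.empty(1024, device="cuda") for _ in range(num_layers)]
cpu_buf = [torch.randn(1024, pin_memory=True) for _ in range(num_layers)]

for layer_id in range(num_layers):
    with torch.cuda.stream(transfer_stream):
        gpu_buf[layer_id].copy_(cpu_buf[layer_id], non_blocking=True)
        layer_ready[layer_id].record(transfer_stream)

# Later, a consumer waits only for its own layer's copy to finish:
layer_id = 2
compute_stream.wait_event(layer_ready[layer_id])
result = gpu_buf[layer_id].sum()  # safe: ordered after the layer-2 copy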
@@ -836,10 +841,80 @@ class OffloadEngine:
         """Wait for Compute region loading to complete."""
         self.compute_stream.wait_event(self.compute_ready)
 
+    def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
+        """
+        Load CPU blocks to Compute region for a single layer only.
+
+        This is used for per-layer chunked attention where each layer
+        independently loads its KV data.
+
+        Args:
+            layer_id: Layer index to load
+            cpu_block_ids: List of CPU block IDs to load
+        """
+        if not cpu_block_ids:
+            self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
+            return
+
+        num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
+        logger.debug(f"Compute load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")
+
+        with torch.cuda.stream(self.transfer_stream_main):
+            for i in range(num_to_load):
+                cpu_id = cpu_block_ids[i]
+                gpu_slot = self.compute_slots[i]
+                # Copy only this layer (not all layers)
+                self.k_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
+                )
+                self.v_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
+                )
+            self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
+
+    def wait_compute_layer(self, layer_id: int) -> None:
+        """Wait for specific layer's Compute region loading to complete."""
+        self.compute_stream.wait_event(self.compute_ready_per_layer[layer_id])
+
     def wait_prefetch(self) -> None:
         """Wait for Prefetch region loading to complete."""
         self.compute_stream.wait_event(self.prefetch_ready)
 
+    def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
+        """
+        Load CPU blocks to Prefetch region for a single layer only.
+
+        This is used for per-layer chunked attention where each layer
+        independently loads its KV data.
+
+        Args:
+            layer_id: Layer index to load
+            cpu_block_ids: List of CPU block IDs to load
+        """
+        if not cpu_block_ids:
+            self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
+            return
+
+        num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
+        logger.debug(f"Prefetch load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")
+
+        with torch.cuda.stream(self.transfer_stream_main):
+            for i in range(num_to_load):
+                cpu_id = cpu_block_ids[i]
+                gpu_slot = self.prefetch_slots[i]
+                # Copy only this layer (not all layers)
+                self.k_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
+                )
+                self.v_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
+                )
            self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
+
+    def wait_prefetch_layer(self, layer_id: int) -> None:
+        """Wait for specific layer's Prefetch region loading to complete."""
+        self.compute_stream.wait_event(self.prefetch_ready_per_layer[layer_id])
+
     def swap_compute_prefetch(self) -> None:
         """Swap roles of Compute region and Prefetch region."""
         self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
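Taken together, the new methods give each Attention layer its own load → wait → read sequence. A hypothetical caller-side sketch of how they compose (the chunked helper and the attend_over_offloaded_kv wrapper are illustrative only; compute_slots and get_kv_for_compute are the engine attributes used elsewhere in this commit):

from typing import Iterator, List

def chunked(ids: List[int], size: int) -> Iterator[List[int]]:
    # Hypothetical helper: split a CPU block table into region-sized chunks.
    for start in range(0, len(ids), size):
        yield ids[start:start + size]

def attend_over_offloaded_kv(engine, layer_id: int, cpu_block_table: List[int]) -> None:
    # Illustrative caller: every layer drives its own loads, so no layer depends
    # on layer 0 having loaded (and possibly already overwritten) the shared region.
    for chunk_ids in chunked(cpu_block_table, len(engine.compute_slots)):
        engine.load_to_compute_layer(layer_id, chunk_ids)  # async H2D on the transfer stream
        engine.wait_compute_layer(layer_id)                # compute stream waits on this layer's event
        k, v = engine.get_kv_for_compute(layer_id, len(chunk_ids))
        # ... run attention against (k, v) and merge the partial outputs ...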
@@ -136,36 +136,20 @@ class Attention(nn.Module):
         # Use Prefetch region to load previous KV (won't conflict with current Compute region)
         prefetch_size = offload_engine.num_prefetch_blocks
         num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
-        use_compute = True  # Alternate between Compute region and Prefetch region
-
-        # First load previous KV to Prefetch region
-        # Only layer 0 triggers the load (loads ALL layers at once)
-        first_chunk_end = min(prefetch_size, len(cpu_block_table))
-        first_chunk_ids = cpu_block_table[:first_chunk_end]
-        if self.layer_id == 0:
-            offload_engine.load_to_prefetch(first_chunk_ids)
 
         for chunk_idx in range(num_chunks):
             start = chunk_idx * prefetch_size
             end = min(start + prefetch_size, len(cpu_block_table))
             num_blocks_in_chunk = end - start
+            chunk_ids = cpu_block_table[start:end]
 
-            # Prefetch next chunk to other buffer (if exists)
-            # Only layer 0 triggers the load
-            if chunk_idx + 1 < num_chunks and self.layer_id == 0:
-                next_start = end
-                next_end = min(next_start + prefetch_size, len(cpu_block_table))
-                next_chunk_ids = cpu_block_table[next_start:next_end]
-                if use_compute:
-                    # Currently in Prefetch region, next load to Compute region (if space available)
-                    # Note: Compute region already has current chunk's KV written, cannot overwrite
-                    # So here we use simple sync strategy: wait for current to complete before loading
-                    pass  # Simplified version: no double buffering, only use Prefetch region
-                else:
-                    offload_engine.load_to_prefetch(next_chunk_ids)
+            # Load this chunk to Prefetch region (per-layer loading)
+            # Each layer loads only its own KV, avoiding the bug where layer 0
+            # loads all layers and overwrites data before other layers can read it
+            offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids)
 
-            # Wait for Prefetch region and get KV
-            offload_engine.wait_prefetch()
+            # Wait for this layer's Prefetch region and get KV
+            offload_engine.wait_prefetch_layer(self.layer_id)
             prev_k, prev_v = offload_engine.get_kv_for_prefetch(
                 self.layer_id, num_blocks_in_chunk
             )
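The new comment spells out the root cause: previously only layer 0 issued loads, each load filled the shared GPU region with every layer's KV for one chunk, and deeper layers then ran their own chunk loop against a region that had already been overwritten. The toy model below (plain Python, no CUDA, invented names) reproduces that ordering bug and the per-layer fix.

# Toy model: one shared GPU region holding the current chunk's KV for every layer.
# Each layer's forward iterates over all chunks; in the old scheme only layer 0
# triggers loads, so by the time layer 1 runs, the region holds the last chunk.
NUM_LAYERS, NUM_CHUNKS = 2, 3
cpu = {(l, c): f"KV[layer={l}, chunk={c}]" for l in range(NUM_LAYERS) for c in range(NUM_CHUNKS)}
region = {}  # layer_id -> KV currently resident in the shared GPU region

def run_layer(layer_id: int, per_layer_loads: bool) -> list:
    seen = []
    for chunk in range(NUM_CHUNKS):
        if per_layer_loads:
            region[layer_id] = cpu[(layer_id, chunk)]   # new: each layer loads its own KV
        elif layer_id == 0:
            for l in range(NUM_LAYERS):
                region[l] = cpu[(l, chunk)]             # old: layer 0 loads ALL layers
        seen.append(region[layer_id])
    return seen

for per_layer in (False, True):
    region.clear()
    outs = [run_layer(l, per_layer) for l in range(NUM_LAYERS)]
    print("per-layer loads" if per_layer else "layer-0 loads all", outs)
# Old scheme: layer 1 sees the last chunk's KV for every chunk it attends to;
# with per-layer loads each layer sees chunks 0, 1, 2 as intended.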
@@ -185,13 +169,6 @@ class Attention(nn.Module):
             else:
                 o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
 
-            # Load next chunk to Prefetch region (if exists)
-            if chunk_idx + 1 < num_chunks and self.layer_id == 0:
-                next_start = end
-                next_end = min(next_start + prefetch_size, len(cpu_block_table))
-                next_chunk_ids = cpu_block_table[next_start:next_end]
-                offload_engine.load_to_prefetch(next_chunk_ids)
-
             # Compute attention against current chunk's KV (with causal mask)
             current_o, current_lse = flash_attn_with_lse(
                 q_batched,
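The surrounding loop folds each chunk's result into the running accumulator with merge_attention_outputs, i.e. the standard log-sum-exp combination of partial attention outputs. A minimal reference version of that merge, under assumed shapes (o*: [batch, heads, q_len, head_dim], lse*: [batch, heads, q_len]; the repository's actual signature may differ):

import torch

def merge_attention_outputs_ref(o1: torch.Tensor, lse1: torch.Tensor,
                                o2: torch.Tensor, lse2: torch.Tensor):
    # Combine two partial attention results computed over disjoint KV chunks.
    # Weighting each output by exp(lse_i - lse_total) is exact because the
    # softmax normalizers add in log space: lse_total = logaddexp(lse1, lse2).
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse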
@@ -262,13 +239,13 @@ class Attention(nn.Module):
             num_blocks_in_chunk = end - start
             chunk_ids = cpu_block_table[start:end]
 
-            # Load this chunk to Compute region
-            # Only layer 0 triggers the load (loads ALL layers at once)
-            if self.layer_id == 0:
-                offload_engine.load_to_compute(chunk_ids)
+            # Load this chunk to Compute region (per-layer loading)
+            # Each layer loads only its own KV, avoiding the bug where layer 0
+            # loads all layers and overwrites data before other layers can read it
+            offload_engine.load_to_compute_layer(self.layer_id, chunk_ids)
 
-            # Wait for Compute region to be ready and get KV
-            offload_engine.wait_compute()
+            # Wait for this layer's Compute region to be ready and get KV
+            offload_engine.wait_compute_layer(self.layer_id)
             k_chunk, v_chunk = offload_engine.get_kv_for_compute(
                 self.layer_id, num_blocks_in_chunk
             )
@@ -33,6 +33,14 @@ class Context:
     # Used when batching decode offloads - we need to attend to all accumulated tokens
     decode_start_pos_in_block: int = 0
 
+    # ========== Per-layer chunked attention state ==========
+    # Whether chunked decode/prefill is currently active (for hooks to check)
+    chunked_decode_active: bool = False
+    # CPU block IDs for the current chunk being processed
+    chunked_decode_chunk_ids: List[int] = field(default_factory=list)
+    # Current chunk index being processed
+    chunked_decode_current_chunk: int = 0
+
 
 _CONTEXT = Context()
 
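These Context fields expose the chunked-attention state to forward-pass hooks. A reduced, self-contained sketch of how a hook might consult them (the get_context accessor and the hook body are assumptions for illustration):

from dataclasses import dataclass, field
from typing import List

@dataclass
class Context:
    # Reduced to the fields added in this commit
    chunked_decode_active: bool = False
    chunked_decode_chunk_ids: List[int] = field(default_factory=list)
    chunked_decode_current_chunk: int = 0

_CONTEXT = Context()

def get_context() -> Context:
    # Hypothetical accessor for the module-level singleton
    return _CONTEXT

def offload_hook() -> None:
    # Hypothetical hook: check whether chunked decode is mid-flight before
    # acting on the KV cache; the exact handling is project-specific.
    ctx = get_context()
    if ctx.chunked_decode_active:
        in_flight = set(ctx.chunked_decode_chunk_ids)
        _ = in_flight  # e.g. leave these CPU blocks untouched this step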