[fix] Fixed kvcache offload problem.

This commit is contained in:
Zijie Tian
2025-12-12 01:35:30 +08:00
parent 60d24f7c12
commit 9b8165af5a
3 changed files with 96 additions and 36 deletions

View File

@@ -136,36 +136,20 @@ class Attention(nn.Module):
# Use Prefetch region to load previous KV (won't conflict with current Compute region)
prefetch_size = offload_engine.num_prefetch_blocks
num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
use_compute = True # Alternate between Compute region and Prefetch region
# First load previous KV to Prefetch region
# Only layer 0 triggers the load (loads ALL layers at once)
first_chunk_end = min(prefetch_size, len(cpu_block_table))
first_chunk_ids = cpu_block_table[:first_chunk_end]
if self.layer_id == 0:
offload_engine.load_to_prefetch(first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * prefetch_size
end = min(start + prefetch_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
chunk_ids = cpu_block_table[start:end]
# Prefetch next chunk to other buffer (if exists)
# Only layer 0 triggers the load
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
next_start = end
next_end = min(next_start + prefetch_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if use_compute:
# Currently in Prefetch region, next load to Compute region (if space available)
# Note: Compute region already has current chunk's KV written, cannot overwrite
# So here we use simple sync strategy: wait for current to complete before loading
pass # Simplified version: no double buffering, only use Prefetch region
else:
offload_engine.load_to_prefetch(next_chunk_ids)
# Load this chunk to Prefetch region (per-layer loading)
# Each layer loads only its own KV, avoiding the bug where layer 0
# loads all layers and overwrites data before other layers can read it
offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids)
# Wait for Prefetch region and get KV
offload_engine.wait_prefetch()
# Wait for this layer's Prefetch region and get KV
offload_engine.wait_prefetch_layer(self.layer_id)
prev_k, prev_v = offload_engine.get_kv_for_prefetch(
self.layer_id, num_blocks_in_chunk
)
@@ -185,13 +169,6 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Load next chunk to Prefetch region (if exists)
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
next_start = end
next_end = min(next_start + prefetch_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
offload_engine.load_to_prefetch(next_chunk_ids)
# Compute attention against current chunk's KV (with causal mask)
current_o, current_lse = flash_attn_with_lse(
q_batched,
@@ -262,13 +239,13 @@ class Attention(nn.Module):
num_blocks_in_chunk = end - start
chunk_ids = cpu_block_table[start:end]
# Load this chunk to Compute region
# Only layer 0 triggers the load (loads ALL layers at once)
if self.layer_id == 0:
offload_engine.load_to_compute(chunk_ids)
# Load this chunk to Compute region (per-layer loading)
# Each layer loads only its own KV, avoiding the bug where layer 0
# loads all layers and overwrites data before other layers can read it
offload_engine.load_to_compute_layer(self.layer_id, chunk_ids)
# Wait for Compute region to be ready and get KV
offload_engine.wait_compute()
# Wait for this layer's Compute region to be ready and get KV
offload_engine.wait_compute_layer(self.layer_id)
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
self.layer_id, num_blocks_in_chunk
)