[fix] Fix KV cache offload: per-layer KV loads instead of layer-0 bulk loads

Layer 0 previously staged every layer's KV in one bulk load, so a later
load could overwrite staged data before deeper layers had read it. Each
layer now loads, waits on, and reads only its own KV.
@@ -136,36 +136,20 @@ class Attention(nn.Module):
         # Use Prefetch region to load previous KV (won't conflict with current Compute region)
         prefetch_size = offload_engine.num_prefetch_blocks
         num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
-        use_compute = True  # Alternate between Compute region and Prefetch region
-
-        # First load previous KV to Prefetch region
-        # Only layer 0 triggers the load (loads ALL layers at once)
-        first_chunk_end = min(prefetch_size, len(cpu_block_table))
-        first_chunk_ids = cpu_block_table[:first_chunk_end]
-        if self.layer_id == 0:
-            offload_engine.load_to_prefetch(first_chunk_ids)

         for chunk_idx in range(num_chunks):
             start = chunk_idx * prefetch_size
             end = min(start + prefetch_size, len(cpu_block_table))
             num_blocks_in_chunk = end - start
             chunk_ids = cpu_block_table[start:end]

-            # Prefetch next chunk to the other buffer (if one exists)
-            # Only layer 0 triggers the load
-            if chunk_idx + 1 < num_chunks and self.layer_id == 0:
-                next_start = end
-                next_end = min(next_start + prefetch_size, len(cpu_block_table))
-                next_chunk_ids = cpu_block_table[next_start:next_end]
-                if use_compute:
-                    # Currently in Prefetch region; next load would go to Compute region (if space available)
-                    # Note: Compute region already holds the current chunk's KV, which must not be overwritten
-                    # So we use a simple sync strategy here: wait for the current chunk to complete before loading
-                    pass  # Simplified version: no double buffering, only use the Prefetch region
-                else:
-                    offload_engine.load_to_prefetch(next_chunk_ids)
-            # Wait for Prefetch region and get KV
-            offload_engine.wait_prefetch()
+            # Load this chunk to Prefetch region (per-layer loading)
+            # Each layer loads only its own KV, avoiding the bug where layer 0
+            # loads all layers and overwrites data before other layers can read it
+            offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids)
+
+            # Wait for this layer's Prefetch region and get KV
+            offload_engine.wait_prefetch_layer(self.layer_id)
             prev_k, prev_v = offload_engine.get_kv_for_prefetch(
                 self.layer_id, num_blocks_in_chunk
             )
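For context on the per-layer API the new code calls (load_to_prefetch_layer / wait_prefetch_layer), here is a minimal sketch of an engine that could back it, assuming one CUDA event per layer. The class and attribute names below are illustrative assumptions, not the repository's actual offload engine:

import torch

class PerLayerPrefetcher:
    """Illustrative only: one copy stream plus one CUDA event per layer,
    so waiting on layer i blocks only on layer i's own H2D copy."""

    def __init__(self, num_layers: int):
        self.copy_stream = torch.cuda.Stream()
        self.events = [torch.cuda.Event() for _ in range(num_layers)]

    def load_layer(self, layer_id: int, dst: torch.Tensor, src: torch.Tensor):
        # Stage this layer's KV on a side stream and record completion.
        with torch.cuda.stream(self.copy_stream):
            dst.copy_(src, non_blocking=True)
            self.events[layer_id].record(self.copy_stream)

    def wait_layer(self, layer_id: int):
        # The compute stream waits only on THIS layer's copy; other layers'
        # loads proceed concurrently without clobbering unread data.
        torch.cuda.current_stream().wait_event(self.events[layer_id])

The removed code instead recorded completion for a single bulk load triggered by layer 0, which is precisely what let a subsequent chunk's load overwrite KV that deeper layers had not yet read.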
@@ -185,13 +169,6 @@ class Attention(nn.Module):
             else:
                 o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

-            # Load next chunk to Prefetch region (if one exists)
-            if chunk_idx + 1 < num_chunks and self.layer_id == 0:
-                next_start = end
-                next_end = min(next_start + prefetch_size, len(cpu_block_table))
-                next_chunk_ids = cpu_block_table[next_start:next_end]
-                offload_engine.load_to_prefetch(next_chunk_ids)
-
             # Compute attention against current chunk's KV (with causal mask)
             current_o, current_lse = flash_attn_with_lse(
                 q_batched,
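The context lines above combine partial results with merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse). The standard way to merge attention outputs computed over disjoint KV chunks is a log-sum-exp weighted average; a hedged sketch, assuming o has shape [batch, heads, seq, head_dim] and lse has shape [batch, heads, seq] (the repository's actual implementation and shapes may differ):

import torch

def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
    # Combined normalizer log(exp(lse1) + exp(lse2)), computed stably.
    lse = torch.logaddexp(lse1, lse2)
    # Reweight each partial output by its share of the total softmax mass.
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse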
@@ -262,13 +239,13 @@ class Attention(nn.Module):
             num_blocks_in_chunk = end - start
             chunk_ids = cpu_block_table[start:end]

-            # Load this chunk to Compute region
-            # Only layer 0 triggers the load (loads ALL layers at once)
-            if self.layer_id == 0:
-                offload_engine.load_to_compute(chunk_ids)
+            # Load this chunk to Compute region (per-layer loading)
+            # Each layer loads only its own KV, avoiding the bug where layer 0
+            # loads all layers and overwrites data before other layers can read it
+            offload_engine.load_to_compute_layer(self.layer_id, chunk_ids)

-            # Wait for Compute region to be ready and get KV
-            offload_engine.wait_compute()
+            # Wait for this layer's Compute region to be ready and get KV
+            offload_engine.wait_compute_layer(self.layer_id)
             k_chunk, v_chunk = offload_engine.get_kv_for_compute(
                 self.layer_id, num_blocks_in_chunk
             )
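The removed comments in the first hunk describe an abandoned double-buffering attempt ("Simplified version: no double buffering, only use the Prefetch region"). If lookahead prefetch is ever reintroduced on top of the per-layer scheme, a per-layer ping-pong buffer would restore copy/compute overlap without recreating the overwrite hazard. A sketch under the same illustrative assumptions as above (not the repository's API):

import torch

class PerLayerDoubleBuffer:
    """Illustrative only: two staging slots per layer, so chunk i+1 can be
    copied in while chunk i is still being read by attention."""

    def __init__(self, num_layers: int, make_slot):
        self.copy_stream = torch.cuda.Stream()
        self.slots = [[make_slot(), make_slot()] for _ in range(num_layers)]
        self.events = [[torch.cuda.Event(), torch.cuda.Event()]
                       for _ in range(num_layers)]

    def load(self, layer_id: int, chunk_idx: int, src: torch.Tensor):
        slot = chunk_idx % 2  # ping-pong between the two slots
        # A production version would also wait on a "consumed" event recorded
        # by the compute stream before reusing a slot; omitted for brevity.
        with torch.cuda.stream(self.copy_stream):
            self.slots[layer_id][slot].copy_(src, non_blocking=True)
            self.events[layer_id][slot].record(self.copy_stream)

    def get(self, layer_id: int, chunk_idx: int) -> torch.Tensor:
        slot = chunk_idx % 2
        torch.cuda.current_stream().wait_event(self.events[layer_id][slot])
        return self.slots[layer_id][slot]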