[WIP] Before needle fix.

Zijie Tian
2025-12-31 23:35:25 +08:00
parent ccd1b3d4ab
commit 30462fe89a
5 changed files with 212 additions and 290 deletions


@@ -201,6 +201,18 @@ class Attention(nn.Module):
torch.cuda.nvtx.range_pop() # ChunkedPrefill
+ # Per-layer offload: In new GPU cache architecture (no layer dimension),
+ # each layer must offload its KV to CPU before next layer overwrites the GPU slot.
+ if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
+     offload_engine = kvcache_manager.offload_engine
+     write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
+     seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
+     if seq is not None:
+         cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
+         if current_chunk_idx < len(cpu_block_ids):
+             cpu_block_id = cpu_block_ids[current_chunk_idx]
+             offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
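
The block added above is the per-layer offload step: because the GPU slot carries no layer dimension, layer N must push its K/V out to the CPU block before layer N+1 reuses the same slot. A minimal sketch of what such an offload could look like, assuming a GPU ring buffer shaped [num_slots, block_size, kv_heads, head_dim], a pinned per-layer CPU cache, and a dedicated transfer stream (the class and its internals here are illustrative, not the repo's actual OffloadEngine):

    import torch

    class PerLayerOffloadSketch:
        def __init__(self, num_layers, num_slots, num_cpu_blocks,
                     block_size, kv_heads, head_dim, dtype=torch.float16):
            # GPU ring buffer: one slot set shared by every layer (no layer dimension).
            self.k_cache_gpu = torch.empty(num_slots, block_size, kv_heads, head_dim,
                                           dtype=dtype, device="cuda")
            self.v_cache_gpu = torch.empty_like(self.k_cache_gpu)
            # CPU cache keeps the layer dimension; pinned memory enables async D2H copies.
            self.k_cache_cpu = torch.empty(num_layers, num_cpu_blocks, block_size, kv_heads,
                                           head_dim, dtype=dtype, pin_memory=True)
            self.v_cache_cpu = torch.empty(num_layers, num_cpu_blocks, block_size, kv_heads,
                                           head_dim, dtype=dtype, pin_memory=True)
            self.transfer_stream = torch.cuda.Stream()

        def offload_slot_layer_to_cpu(self, slot, layer_id, cpu_block_id):
            # The copy must not start before the stream that wrote the slot has finished.
            self.transfer_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.transfer_stream):
                self.k_cache_cpu[layer_id, cpu_block_id].copy_(self.k_cache_gpu[slot],
                                                               non_blocking=True)
                self.v_cache_cpu[layer_id, cpu_block_id].copy_(self.v_cache_gpu[slot],
                                                               non_blocking=True)
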
@@ -219,11 +231,11 @@ class Attention(nn.Module):
for block_idx, cpu_block_id in enumerate(cpu_block_table):
# Load to slot 0 (single slot)
offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
- offload_engine.wait_slot_layer(0, self.layer_id)
+ offload_engine.wait_slot_layer(0)
# IMPORTANT: Must use compute_stream to match wait_slot_layer
with torch.cuda.stream(compute_stream):
- prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
+ prev_k, prev_v = offload_engine.get_kv_for_slot(0)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
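
The signature change in this hunk (wait_slot_layer(0) instead of wait_slot_layer(0, self.layer_id)) suggests transfers are now tracked per slot rather than per (slot, layer). One way such per-slot tracking could be implemented with CUDA events, as a sketch only (event names, streams, and tensor layout are assumptions):

    import torch

    class SlotSyncSketch:
        def __init__(self, k_gpu, v_gpu, k_cpu, v_cpu, num_slots):
            self.k_gpu, self.v_gpu = k_gpu, v_gpu  # [num_slots, block_size, heads, dim], CUDA
            self.k_cpu, self.v_cpu = k_cpu, v_cpu  # [layers, cpu_blocks, block_size, heads, dim], pinned
            self.transfer_stream = torch.cuda.Stream()
            self.compute_stream = torch.cuda.Stream()
            self.slot_ready = [torch.cuda.Event() for _ in range(num_slots)]  # H2D copy finished
            self.slot_free = [torch.cuda.Event() for _ in range(num_slots)]   # compute finished

        def load_to_slot_layer(self, slot, layer_id, cpu_block_id):
            with torch.cuda.stream(self.transfer_stream):
                # Never overwrite a slot the compute stream may still be reading.
                self.transfer_stream.wait_event(self.slot_free[slot])
                self.k_gpu[slot].copy_(self.k_cpu[layer_id, cpu_block_id], non_blocking=True)
                self.v_gpu[slot].copy_(self.v_cpu[layer_id, cpu_block_id], non_blocking=True)
                self.slot_ready[slot].record(self.transfer_stream)

        def wait_slot_layer(self, slot):
            # Per-slot, not per-(slot, layer): the compute stream waits on the latest copy.
            self.compute_stream.wait_event(self.slot_ready[slot])

        def record_slot_compute_done(self, slot):
            # Lets the next load_to_slot_layer safely reuse this slot.
            self.slot_free[slot].record(self.compute_stream)

Under this reading, the "IMPORTANT: Must use compute_stream to match wait_slot_layer" comment follows naturally: the wait is registered on compute_stream, so the attention call that consumes the slot has to run on that same stream.
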
@@ -289,21 +301,21 @@ class Attention(nn.Module):
for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
- offload_engine.wait_slot_layer(slot, self.layer_id)
+ offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream):
# Debug: call hooks on compute_stream (synchronized with transfer)
if offload_engine.debug_mode:
offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
- prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
+ prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
# Record compute done so next load can safely reuse this slot
- offload_engine.record_slot_compute_done(slot, self.layer_id)
+ offload_engine.record_slot_compute_done(slot)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
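
The accumulation at the end of this hunk ("if o_acc is None ... else ...", cut off at the hunk boundary) combines per-block partial attention results using their log-sum-exp values. The standard merge, written as a self-contained helper with assumed shapes (the repo's own merge routine may differ in name and layout):

    import torch

    def merge_attn_outputs(o1, lse1, o2, lse2):
        # o*:   [batch, seqlen_q, heads, head_dim]  partial outputs over disjoint key blocks
        # lse*: [batch, heads, seqlen_q]             log-sum-exp of the corresponding softmax rows
        lse = torch.logaddexp(lse1, lse2)                         # normalizer over the union of keys
        w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # -> [batch, seqlen_q, heads, 1]
        w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
        return w1 * o1 + w2 * o2, lse

Because each block attends to a disjoint set of keys with the same queries, rescaling each partial output by exp(lse_i - lse) and summing reproduces the softmax over the full key set.
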
@@ -332,7 +344,7 @@ class Attention(nn.Module):
cpu_block_id = cpu_block_table[block_idx]
# Wait for current slot's transfer to complete (on compute_stream)
- offload_engine.wait_slot_layer(current_slot, self.layer_id)
+ offload_engine.wait_slot_layer(current_slot)
# Compute attention on current slot's data
# IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
@@ -342,7 +354,7 @@ class Attention(nn.Module):
offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
- prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
+ prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
@@ -351,7 +363,7 @@ class Attention(nn.Module):
torch.cuda.nvtx.range_pop()
# Record compute done - this allows the next transfer to safely overwrite this slot
- offload_engine.record_slot_compute_done(current_slot, self.layer_id)
+ offload_engine.record_slot_compute_done(current_slot)
# Immediately start loading the NEXT block into this slot (if more blocks remain)
# Key insight: reuse current_slot immediately after compute is done!
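
Taken together, the last few hunks implement a slot pipeline: wait for the slot's transfer, run attention on the compute stream, mark the slot reusable, then immediately queue another block into it. A condensed restatement of that control flow (the slot count and the merge helper from the earlier sketch are stand-ins; the exact prefetch schedule in attention.py may differ):

    num_blocks = len(cpu_block_table)
    num_slots = 2  # assumed slot count for illustration

    # Preload one block per slot before the loop.
    for i in range(min(num_slots, num_blocks)):
        offload_engine.load_to_slot_layer(i, self.layer_id, cpu_block_table[i])

    o_acc = lse_acc = None
    for block_idx in range(num_blocks):
        current_slot = block_idx % num_slots
        offload_engine.wait_slot_layer(current_slot)          # H2D copy for this block done
        with torch.cuda.stream(compute_stream):
            prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
            prev_o, prev_lse = flash_attn_with_lse(q_batched, prev_k, prev_v,
                                                   softmax_scale=self.scale, causal=False)
            offload_engine.record_slot_compute_done(current_slot)
        # Reuse the slot right away: queue the block that lands in it next time around,
        # so its transfer overlaps with the attention launched above.
        next_block = block_idx + num_slots
        if next_block < num_blocks:
            offload_engine.load_to_slot_layer(current_slot, self.layer_id,
                                              cpu_block_table[next_block])
        if o_acc is None:
            o_acc, lse_acc = prev_o, prev_lse
        else:
            o_acc, lse_acc = merge_attn_outputs(o_acc, lse_acc, prev_o, prev_lse)
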
@@ -464,13 +476,9 @@ class Attention(nn.Module):
with torch.cuda.stream(compute_stream):
# Get KV from current buffer FIRST, before prefetching overwrites it
if use_compute:
- k_chunk, v_chunk = offload_engine.get_kv_for_compute(
-     self.layer_id, num_blocks_in_chunk
- )
+ k_chunk, v_chunk = offload_engine.get_kv_for_compute(num_blocks_in_chunk)
else:
- k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
-     self.layer_id, num_blocks_in_chunk
- )
+ k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(num_blocks_in_chunk)
# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
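
This hunk drops the layer argument from the chunk readers as well; the comment about reading the KV "FIRST, before prefetching overwrites it" points to a compute/prefetch double buffer. A rough sketch of such a two-buffer reader without the layer dimension (buffer layout, the swap step, and every name except the two getters are assumptions, not the repo's API):

    import torch

    class ChunkBuffersSketch:
        def __init__(self, blocks_per_buf, block_size, kv_heads, head_dim, dtype=torch.float16):
            shape = (2, blocks_per_buf, block_size, kv_heads, head_dim)
            self.k_bufs = torch.empty(shape, dtype=dtype, device="cuda")
            self.v_bufs = torch.empty(shape, dtype=dtype, device="cuda")
            self.compute_idx = 0  # the other buffer (1 - compute_idx) is the prefetch target

        def _view(self, idx, num_blocks):
            # [num_blocks, block_size, heads, dim] -> [1, num_blocks * block_size, heads, dim]
            k = self.k_bufs[idx, :num_blocks].flatten(0, 1).unsqueeze(0)
            v = self.v_bufs[idx, :num_blocks].flatten(0, 1).unsqueeze(0)
            return k, v

        def get_kv_for_compute(self, num_blocks):
            return self._view(self.compute_idx, num_blocks)

        def get_kv_for_prefetch(self, num_blocks):
            return self._view(1 - self.compute_idx, num_blocks)

        def swap(self):
            # Once a chunk is consumed, the prefetch buffer becomes the compute buffer.
            self.compute_idx = 1 - self.compute_idx
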
@@ -512,8 +520,9 @@ class Attention(nn.Module):
with torch.cuda.stream(compute_stream):
if num_accumulated > 0:
- decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
- decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
+ # GPU cache has no layer dimension
+ decode_k = offload_engine.k_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
+ decode_v = offload_engine.v_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_k = decode_k.unsqueeze(0)
decode_v = decode_v.unsqueeze(0)
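
The decode-path change in this final hunk makes the shape change concrete: k_cache_gpu and v_cache_gpu lose their leading layer dimension, so the decode slot is indexed the same way from every layer. A small worked example of the new indexing (sizes are placeholders):

    import torch

    num_slots, block_size, kv_heads, head_dim = 8, 256, 8, 128
    k_cache_gpu = torch.empty(num_slots, block_size, kv_heads, head_dim,
                              dtype=torch.float16, device="cuda")

    decode_slot, start_pos, pos_in_block = 0, 0, 17
    # The old layout would have been k_cache_gpu[layer_id, decode_slot, ...];
    # now every layer shares the same slot and must offload before the next layer reuses it.
    decode_k = k_cache_gpu[decode_slot, start_pos:pos_in_block + 1]  # [pos_in_block+1, kv_heads, head_dim]
    decode_k = decode_k.unsqueeze(0)                                 # add batch dim for attention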