[WIP] Before fix needle.
@@ -201,6 +201,18 @@ class Attention(nn.Module):
         torch.cuda.nvtx.range_pop()  # ChunkedPrefill
 
+        # Per-layer offload: In new GPU cache architecture (no layer dimension),
+        # each layer must offload its KV to CPU before next layer overwrites the GPU slot.
+        if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
+            offload_engine = kvcache_manager.offload_engine
+            write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
+            seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
+            if seq is not None:
+                cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
+                if current_chunk_idx < len(cpu_block_ids):
+                    cpu_block_id = cpu_block_ids[current_chunk_idx]
+                    offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
+
         # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
         return final_o.squeeze(0)
 
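Note on the block added above: offload_slot_layer_to_cpu is called once per layer for each prefill chunk, because the GPU slot no longer carries a layer dimension and will be overwritten by the next layer. The engine itself is not part of this diff, so the following is only a minimal sketch of such a per-layer device-to-host copy, assuming a layer-less GPU slot tensor, a pinned per-layer CPU cache, and a dedicated transfer stream (gpu_slots, cpu_cache and transfer_stream are illustrative names, not the repo's actual fields):

import torch

def offload_slot_layer_to_cpu_sketch(gpu_slots, cpu_cache, transfer_stream,
                                     slot, layer_id, cpu_block_id):
    # gpu_slots: [num_slots, block_size, kv_heads, head_dim] on the GPU (no layer dim)
    # cpu_cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim], pinned host memory
    # Copy this layer's freshly written slot out to its CPU block on the transfer
    # stream, so the slot can be handed to the next layer without losing the KV.
    done = torch.cuda.Event()
    with torch.cuda.stream(transfer_stream):
        cpu_cache[layer_id, cpu_block_id].copy_(gpu_slots[slot], non_blocking=True)
        done.record(transfer_stream)
    return done  # wait on this before the next layer writes into the same slot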
@@ -219,11 +231,11 @@ class Attention(nn.Module):
         for block_idx, cpu_block_id in enumerate(cpu_block_table):
             # Load to slot 0 (single slot)
             offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
-            offload_engine.wait_slot_layer(0, self.layer_id)
+            offload_engine.wait_slot_layer(0)
 
             # IMPORTANT: Must use compute_stream to match wait_slot_layer
             with torch.cuda.stream(compute_stream):
-                prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
+                prev_k, prev_v = offload_engine.get_kv_for_slot(0)
 
                 prev_o, prev_lse = flash_attn_with_lse(
                     q_batched, prev_k, prev_v,
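Note on the "must use compute_stream to match wait_slot_layer" comment: a CUDA event wait only orders work that is subsequently enqueued on the stream that waited, so reading the slot from the default stream could still race with the host-to-device copy. A minimal illustration, assuming wait_slot_layer amounts to a stream-level event wait (the engine's internals are not shown in this diff):

import torch

transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()
slot_loaded = torch.cuda.Event()

with torch.cuda.stream(transfer_stream):
    # ... enqueue the H2D copy of the slot's KV here ...
    slot_loaded.record(transfer_stream)

# Presumed effect of wait_slot_layer(0): order compute_stream after the copy.
compute_stream.wait_event(slot_loaded)

with torch.cuda.stream(compute_stream):
    # Safe: this work runs on the stream that waited on the event.
    # The same reads issued on the default stream would not be ordered after the copy.
    pass  # flash_attn_with_lse(q_batched, prev_k, prev_v, ...) would go here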
@@ -289,21 +301,21 @@ class Attention(nn.Module):
         for block_idx in range(num_blocks):
             cpu_block_id = cpu_block_table[block_idx]
             offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
-            offload_engine.wait_slot_layer(slot, self.layer_id)
+            offload_engine.wait_slot_layer(slot)
 
             with torch.cuda.stream(compute_stream):
                 # Debug: call hooks on compute_stream (synchronized with transfer)
                 if offload_engine.debug_mode:
                     offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
 
-                prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
+                prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
                 prev_o, prev_lse = flash_attn_with_lse(
                     q_batched, prev_k, prev_v,
                     softmax_scale=self.scale,
                     causal=False,
                 )
                 # Record compute done so next load can safely reuse this slot
-                offload_engine.record_slot_compute_done(slot, self.layer_id)
+                offload_engine.record_slot_compute_done(slot)
                 if o_acc is None:
                     o_acc, lse_acc = prev_o, prev_lse
                 else:
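Note: the o_acc / lse_acc accumulation in the hunk above merges per-block attention results using the log-sum-exp returned by flash_attn_with_lse. For reference, the standard merge of two partial results over disjoint key blocks is sketched here, assuming o is laid out as [batch, q_len, heads, head_dim] and lse as [batch, heads, q_len] (adjust the transpose if the repo uses a different convention):

import torch

def merge_attn_outputs(o1, lse1, o2, lse2):
    # Combine two partial softmax-attention outputs so the result equals
    # attention over the union of their key blocks.
    lse = torch.logaddexp(lse1, lse2)                          # [batch, heads, q_len]
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # [batch, q_len, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse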
@@ -332,7 +344,7 @@ class Attention(nn.Module):
             cpu_block_id = cpu_block_table[block_idx]
 
             # Wait for current slot's transfer to complete (on compute_stream)
-            offload_engine.wait_slot_layer(current_slot, self.layer_id)
+            offload_engine.wait_slot_layer(current_slot)
 
             # Compute attention on current slot's data
             # IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
@@ -342,7 +354,7 @@ class Attention(nn.Module):
                     offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)
 
                 torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
-                prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
+                prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
                 prev_o, prev_lse = flash_attn_with_lse(
                     q_batched, prev_k, prev_v,
                     softmax_scale=self.scale,
@@ -351,7 +363,7 @@ class Attention(nn.Module):
                 torch.cuda.nvtx.range_pop()
 
                 # Record compute done - this allows the next transfer to safely overwrite this slot
-                offload_engine.record_slot_compute_done(current_slot, self.layer_id)
+                offload_engine.record_slot_compute_done(current_slot)
 
             # Immediately start loading the NEXT block into this slot (if more blocks remain)
             # Key insight: reuse current_slot immediately after compute is done!
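Note: wait_slot_layer / record_slot_compute_done read like a per-slot handshake between the transfer stream and the compute stream, which is what makes it safe to start the next load into current_slot as soon as its attention has been issued. A sketch of that handshake with hypothetical per-slot events (not the engine's real data structures):

import torch

class SlotEvents:
    # 'loaded' is recorded by the transfer stream after the H2D copy into a slot;
    # 'consumed' is recorded by the compute stream once attention on that slot
    # has been enqueued, so the slot may be overwritten again.
    def __init__(self, num_slots):
        self.loaded = [torch.cuda.Event() for _ in range(num_slots)]
        self.consumed = [torch.cuda.Event() for _ in range(num_slots)]

def pipeline_step(ev, slot, transfer_stream, compute_stream,
                  compute_on_slot, start_next_load):
    compute_stream.wait_event(ev.loaded[slot])      # like wait_slot_layer(slot)
    with torch.cuda.stream(compute_stream):
        compute_on_slot(slot)                       # flash attention on this slot's KV
        ev.consumed[slot].record(compute_stream)    # like record_slot_compute_done(slot)
    transfer_stream.wait_event(ev.consumed[slot])   # next load must not overtake the read
    with torch.cuda.stream(transfer_stream):
        start_next_load(slot)                       # refill the same slot with the next block
        ev.loaded[slot].record(transfer_stream)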
@@ -464,13 +476,9 @@ class Attention(nn.Module):
             with torch.cuda.stream(compute_stream):
                 # Get KV from current buffer FIRST, before prefetching overwrites it
                 if use_compute:
-                    k_chunk, v_chunk = offload_engine.get_kv_for_compute(
-                        self.layer_id, num_blocks_in_chunk
-                    )
+                    k_chunk, v_chunk = offload_engine.get_kv_for_compute(num_blocks_in_chunk)
                 else:
-                    k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
-                        self.layer_id, num_blocks_in_chunk
-                    )
+                    k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(num_blocks_in_chunk)
 
                 # Compute attention for this chunk
                 o_chunk, lse_chunk = flash_attn_with_lse(
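Note: dropping self.layer_id from get_kv_for_compute / get_kv_for_prefetch matches the layer-less compute buffer; only the number of valid blocks is needed to slice out this chunk's KV. One possible shape-level reading, with an assumed buffer layout (not taken from the repo):

import torch

def get_kv_for_compute_sketch(k_buf, v_buf, num_blocks_in_chunk):
    # k_buf / v_buf: [num_buffer_blocks, block_size, kv_heads, head_dim] on the GPU.
    # Take the valid blocks and flatten them into the [1, seq, heads, dim] form
    # that flash_attn_with_lse consumes.
    kv_heads, head_dim = k_buf.shape[-2], k_buf.shape[-1]
    k = k_buf[:num_blocks_in_chunk].reshape(1, -1, kv_heads, head_dim)
    v = v_buf[:num_blocks_in_chunk].reshape(1, -1, kv_heads, head_dim)
    return k, v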
@@ -512,8 +520,9 @@ class Attention(nn.Module):
 
         with torch.cuda.stream(compute_stream):
             if num_accumulated > 0:
-                decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
-                decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
+                # GPU cache has no layer dimension
+                decode_k = offload_engine.k_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
+                decode_v = offload_engine.v_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
                 decode_k = decode_k.unsqueeze(0)
                 decode_v = decode_v.unsqueeze(0)
 
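Note: the decode-path change is just the indexing consequence of removing the layer dimension from the GPU cache. With made-up sizes (CPU tensors standing in for the GPU cache), old and new indexing yield the same per-slot slice:

import torch

num_layers, num_slots, block_size, kv_heads, head_dim = 4, 2, 16, 8, 128

# Old layout: one cache shared across layers, indexed by layer first.
k_cache_old = torch.zeros(num_layers, num_slots, block_size, kv_heads, head_dim)
k_old = k_cache_old[0, 1, 5:9]      # layer 0, decode slot 1, positions 5..8

# New layout: no layer dimension; every layer reuses the same slots,
# so the decode slot is indexed directly.
k_cache_new = torch.zeros(num_slots, block_size, kv_heads, head_dim)
k_new = k_cache_new[1, 5:9]         # same [4, kv_heads, head_dim] slice

assert k_old.shape == k_new.shape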