[feat] Optimized with ASYNC offload.

Zijie Tian
2025-12-15 07:21:35 +08:00
parent b8b6478506
commit 91a0f09a24
3 changed files with 93 additions and 20 deletions


@@ -209,13 +209,26 @@ class Attention(nn.Module):
offload_engine,
):
"""
Ring buffer async pipeline loading with double buffering.

Uses compute_done events to ensure safe buffer reuse:
- Before loading into slot X, wait for the previous compute on slot X to finish
- Before computing on slot X, wait for the load into slot X to finish
Timeline with 2 slots (A, B):

┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│ Load B0→A    │ │ Load B1→B    │ │ Load B2→A    │ ...
└──────────────┘ └──────────────┘ └──────────────┘
                ↘                ↘
                 ┌──────────────┐ ┌──────────────┐
                 │ Compute(A)   │ │ Compute(B)   │ ...
                 └──────────────┘ └──────────────┘
The load_to_slot_layer internally waits for compute_done[slot] before
starting the transfer, ensuring no data race.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
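
Note: this hunk only shows the attention-side call sequence. The OffloadEngine methods it relies on (load_to_slot_layer, wait_slot_layer, record_slot_compute_done) are not part of this diff. As a rough, self-contained sketch of the load_done / compute_done event discipline described in the docstring, using plain torch.cuda streams and events with illustrative names rather than the project's actual API, double buffering of CPU-to-GPU block loads can look like this:

import torch

# Minimal sketch, assuming CUDA is available: double-buffered host-to-GPU loads of
# KV blocks into two reusable GPU slots, synchronized with per-slot CUDA events.
def double_buffered_loads(cpu_blocks, gpu_slots, compute_fn):
    copy_stream = torch.cuda.Stream()                        # dedicated transfer stream
    load_done = [torch.cuda.Event() for _ in gpu_slots]      # "load finished" per slot
    compute_done = [torch.cuda.Event() for _ in gpu_slots]   # "compute finished" per slot
    main_stream = torch.cuda.current_stream()

    def load(slot, block):
        # Do not overwrite a slot until the last compute that read it has finished.
        copy_stream.wait_event(compute_done[slot])
        with torch.cuda.stream(copy_stream):
            gpu_slots[slot].copy_(block, non_blocking=True)
            load_done[slot].record(copy_stream)

    load(0, cpu_blocks[0])                                    # pre-load the first block
    for i in range(len(cpu_blocks)):
        slot = i % 2
        main_stream.wait_event(load_done[slot])               # wait for this slot's load
        if i + 1 < len(cpu_blocks):
            load(1 - slot, cpu_blocks[i + 1])                 # prefetch into the other slot
        compute_fn(gpu_slots[slot])                           # attention over this slot's KV
        compute_done[slot].record(main_stream)                # slot may now be reused

if __name__ == "__main__":
    blocks = [torch.randn(4, 8, pin_memory=True) for _ in range(6)]
    slots = [torch.empty(4, 8, device="cuda") for _ in range(2)]
    double_buffered_loads(blocks, slots, lambda kv: kv.sum())
    torch.cuda.synchronize()

The invariant mirrors the docstring: a slot is overwritten only after its compute_done event, and read only after its load_done event.
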
@@ -224,29 +237,62 @@ class Attention(nn.Module):
return None, None
pipeline_depth = len(load_slots)
if pipeline_depth == 0:
    return None, None
o_acc, lse_acc = None, None
if pipeline_depth == 1:
    # Only 1 slot available, cannot pipeline - use synchronous mode
    slot = load_slots[0]
    for block_idx in range(num_blocks):
        offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
        offload_engine.wait_slot_layer(slot, self.layer_id)
        prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
        prev_o, prev_lse = flash_attn_with_lse(
            q_batched, prev_k, prev_v,
            softmax_scale=self.scale,
            causal=False,
        )
        # Record compute done so the next load can safely reuse this slot
        offload_engine.record_slot_compute_done(slot, self.layer_id)
        if o_acc is None:
            o_acc, lse_acc = prev_o, prev_lse
        else:
            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
    return o_acc, lse_acc
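
For context, flash_attn_with_lse and merge_attention_outputs are imported from nanovllm.kvcache.chunked_attention and their implementations are not shown in this diff. A minimal reference sketch of the log-sum-exp merge that this accumulation loop assumes (naive attention, illustrative shapes and names, not the project's API):

import torch

def attn_with_lse(q, k, v, scale):
    # Reference attention that also returns the per-query log-sum-exp of the
    # scaled scores, mirroring what flash_attn_with_lse is assumed to return.
    scores = (q @ k.transpose(-1, -2)) * scale     # [batch, q_len, k_len]
    lse = torch.logsumexp(scores, dim=-1)          # [batch, q_len]
    out = torch.softmax(scores, dim=-1) @ v        # [batch, q_len, head_dim]
    return out, lse

def merge_partial_attention(o1, lse1, o2, lse2):
    # Attention over the union of two disjoint key blocks is a convex
    # combination of the per-block outputs, weighted by exp(lse_i - lse_union).
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse

if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(1, 4, 64)
    k, v = torch.randn(1, 16, 64), torch.randn(1, 16, 64)
    scale = 64 ** -0.5
    full, _ = attn_with_lse(q, k, v, scale)                   # attention over all keys
    o1, l1 = attn_with_lse(q, k[:, :8], v[:, :8], scale)      # first key block
    o2, l2 = attn_with_lse(q, k[:, 8:], v[:, 8:], scale)      # second key block
    merged, _ = merge_partial_attention(o1, l1, o2, l2)
    assert torch.allclose(full, merged, atol=1e-5)

Because each offloaded CPU block covers a disjoint range of keys for the same queries, per-block outputs can be combined exactly by reweighting with exp(lse_block - lse_union), which is why the loop only needs to carry (o_acc, lse_acc) between blocks.
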
# Double buffering with 2 slots
slot_A = load_slots[0]
slot_B = load_slots[1]
# Pre-load first block to slot_A (async)
offload_engine.load_to_slot_layer(slot_A, self.layer_id, cpu_block_table[0])
for block_idx in range(num_blocks):
    # Alternate between slot_A and slot_B
    current_slot = slot_A if block_idx % 2 == 0 else slot_B
    next_slot = slot_B if block_idx % 2 == 0 else slot_A
    # Wait for the current slot's transfer to complete
    offload_engine.wait_slot_layer(current_slot, self.layer_id)
    # Start async load of the next block into the OTHER slot
    # (load_to_slot_layer internally waits for next_slot's compute_done)
    if block_idx + 1 < num_blocks:
        offload_engine.load_to_slot_layer(next_slot, self.layer_id, cpu_block_table[block_idx + 1])
    # Compute attention on the current slot's data
    prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
    prev_o, prev_lse = flash_attn_with_lse(
        q_batched, prev_k, prev_v,
        softmax_scale=self.scale,
        causal=False,
    )
    # Record compute done - this allows the next round to safely load into this slot
    offload_engine.record_slot_compute_done(current_slot, self.layer_id)
    # Merge with accumulated output
    if o_acc is None:
        o_acc, lse_acc = prev_o, prev_lse