[opt] optimize nanovllm performance compareable with vllm.

This commit is contained in:
Zijie Tian
2025-12-25 03:47:07 +08:00
parent 16fcf8350b
commit 82ed34fc2d
7 changed files with 450 additions and 208 deletions

View File

@@ -141,11 +141,20 @@ class OffloadEngine:
# ========== Transfer streams for async operations ==========
self.transfer_streams = [torch.cuda.Stream() for _ in range(num_streams)]
self.compute_stream = torch.cuda.current_stream()
# IMPORTANT: Create a dedicated compute stream (not default stream!)
# Default stream has implicit synchronization with other streams,
# which prevents overlap between transfer and compute.
self.compute_stream = torch.cuda.Stream()
self._stream_idx = 0
# ========== Per-slot transfer streams for parallel H2D ==========
# Each slot has its own stream to enable parallel transfers
# This allows multiple slots to load simultaneously
self.slot_transfer_streams = [torch.cuda.Stream() for _ in range(self.num_ring_slots)]
logger.info(f" Created {self.num_ring_slots} per-slot transfer streams")
# ========== Ring Buffer dedicated stream and events ==========
self.transfer_stream_main = torch.cuda.Stream() # Main transfer stream
self.transfer_stream_main = torch.cuda.Stream() # Main transfer stream (for legacy/batch ops)
# Decode offload event
self.decode_offload_done = torch.cuda.Event()
@@ -174,6 +183,13 @@ class OffloadEngine:
for _ in range(self.num_ring_slots)
]
# Initialize all compute_done events (record them once)
# This prevents undefined behavior on first load_to_slot_layer call
for slot_idx in range(self.num_ring_slots):
for layer_id in range(num_layers):
self.ring_slot_compute_done[slot_idx][layer_id].record()
torch.cuda.synchronize() # Ensure all events are recorded
# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
@@ -676,11 +692,14 @@ class OffloadEngine:
"""
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
# Use per-slot stream for parallel transfers across different slots
stream = self.slot_transfer_streams[slot_idx]
torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main):
with torch.cuda.stream(stream):
# Wait for previous compute on this slot to complete before overwriting
# This prevents data race: transfer must not start until attention finishes reading
self.transfer_stream_main.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
self.k_cache_gpu[layer_id, slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
@@ -688,7 +707,7 @@ class OffloadEngine:
self.v_cache_gpu[layer_id, slot_idx].copy_(
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main)
self.ring_slot_ready[slot_idx][layer_id].record(stream)
torch.cuda.nvtx.range_pop()
def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:

View File

@@ -287,46 +287,56 @@ class Attention(nn.Module):
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
# Double buffering with 2 slots
slot_A = load_slots[0]
slot_B = load_slots[1]
# N-way pipeline: use ALL available slots for maximum overlap
# Pipeline depth = num_slots - 1 (num_slots blocks in flight)
num_slots = len(load_slots)
# Pre-load first block to slot_A (async)
offload_engine.load_to_slot_layer(slot_A, self.layer_id, cpu_block_table[0])
# Phase 1: Pre-load up to num_slots blocks to fill the pipeline
# This starts all transfers in parallel, utilizing full PCIe bandwidth
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
# Phase 2: Main loop - compute and immediately reuse slot for next transfer
# Use dedicated compute_stream (not default stream) to enable overlap with transfers
compute_stream = offload_engine.compute_stream
for block_idx in range(num_blocks):
torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}")
# Alternate between slot_A and slot_B
current_slot = slot_A if block_idx % 2 == 0 else slot_B
next_slot = slot_B if block_idx % 2 == 0 else slot_A
# Cycle through slots: slot[block_idx % num_slots]
current_slot = load_slots[block_idx % num_slots]
# Wait for current slot's transfer to complete
# Wait for current slot's transfer to complete (on compute_stream)
offload_engine.wait_slot_layer(current_slot, self.layer_id)
# Start async load of next block to the OTHER slot
# load_to_slot_layer internally waits for next_slot's compute_done
if block_idx + 1 < num_blocks:
offload_engine.load_to_slot_layer(next_slot, self.layer_id, cpu_block_table[block_idx + 1])
# Compute attention on current slot's data
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
torch.cuda.nvtx.range_pop()
# IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
torch.cuda.nvtx.range_pop()
# Record compute done - this allows the next round to safely load into this slot
offload_engine.record_slot_compute_done(current_slot, self.layer_id)
# Record compute done - this allows the next transfer to safely overwrite this slot
offload_engine.record_slot_compute_done(current_slot, self.layer_id)
# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Immediately start loading the NEXT block into this slot (if more blocks remain)
# Key insight: reuse current_slot immediately after compute is done!
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
# Merge with accumulated (also on compute_stream for consistency)
with torch.cuda.stream(compute_stream):
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
torch.cuda.nvtx.range_pop() # PipelineBlock