[feat] Optimized with ASYNC offload.
@@ -38,8 +38,8 @@ def main():
     llm = LLM(
         path,
         enforce_eager=False,
-        max_model_len=128 * 1024,
-        max_num_batched_tokens=128 * 1024,
+        max_model_len=256 * 1024,
+        max_num_batched_tokens=256 * 1024,
         enable_cpu_offload=True,
         num_gpu_blocks=120,
         num_prefetch_blocks=4,
@@ -54,12 +54,12 @@ def main():
     # bench_prefill(llm, num_seqs=1, input_len=1024)
     # bench_prefill(llm, num_seqs=1, input_len=2048)
     # bench_prefill(llm, num_seqs=1, input_len=4096)
-    bench_prefill(llm, num_seqs=1, input_len=16 * 1024)
+    bench_prefill(llm, num_seqs=1, input_len=128 * 1024)

     print("=" * 60)
     print("Decode Benchmark (CPU Offload)")
     print("=" * 60)
-    bench_decode(llm, num_seqs=1, input_len=16 * 1024, max_output_len=128)
+    bench_decode(llm, num_seqs=1, input_len=128 * 1024, max_output_len=128)
     # bench_decode(llm, num_seqs=1, input_len=2048, max_output_len=128)
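The two hunks above only push the benchmark to 128K-token prompts with 256K context; bench_prefill and bench_decode themselves are not part of this diff. As a rough sketch of what such a prefill benchmark might look like, assuming nano-vllm's vLLM-style LLM.generate(prompts, sampling_params) interface and a SamplingParams(temperature=..., max_tokens=...) class; the helper below is illustrative, not the repository's actual bench_prefill:

# Rough sketch of a prefill benchmark in the spirit of bench_prefill above.
# Assumes a vLLM-style API (LLM.generate(prompts, sampling_params)); illustrative only.
import time
from random import randint

from nanovllm import LLM, SamplingParams  # assumed import path


def bench_prefill_sketch(llm: LLM, num_seqs: int, input_len: int) -> None:
    # Random token IDs stand in for real prompts; only prefill throughput matters here.
    prompts = [[randint(0, 10_000) for _ in range(input_len)] for _ in range(num_seqs)]
    sp = SamplingParams(temperature=0.6, max_tokens=1)  # max_tokens=1: prefill-dominated run

    start = time.perf_counter()
    llm.generate(prompts, sp)
    elapsed = time.perf_counter() - start

    total = num_seqs * input_len
    print(f"prefill: {total} tokens in {elapsed:.2f}s ({total / elapsed:.0f} tok/s)")

A decode benchmark would follow the same pattern with a larger max_tokens, ideally subtracting the prefill time so only the per-token decode cost is reported.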
@@ -152,6 +152,14 @@ class OffloadEngine:
         self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
         self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]

+        # ========== Per-slot Per-layer compute_done events for async pipeline ==========
+        # ring_slot_compute_done[slot_idx][layer_id] = CUDA Event for compute completion
+        # This is used to ensure we don't overwrite data before it's been read by attention
+        self.ring_slot_compute_done = [
+            [torch.cuda.Event() for _ in range(num_layers)]
+            for _ in range(self.num_ring_slots)
+        ]
+
         # ========== Event tracking for async transfers ==========
         self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
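ring_slot_compute_done gives every (slot, layer) pair its own CUDA event: the compute stream records it after attention has read the slot, and the transfer stream waits on it before overwriting that slot. A self-contained sketch of that handshake outside the engine (buffer names, sizes, and streams below are stand-ins, not OffloadEngine fields; requires a CUDA device):

# Standalone sketch of the per-(slot, layer) event handshake (illustrative only).
import torch

num_ring_slots, num_layers = 2, 4
compute_done = [
    [torch.cuda.Event() for _ in range(num_layers)]
    for _ in range(num_ring_slots)
]

compute_stream = torch.cuda.current_stream()
transfer_stream = torch.cuda.Stream()

gpu_slots = torch.zeros(num_ring_slots, 1024, device="cuda")  # ring buffer stand-in
cpu_block = torch.randn(1024, pin_memory=True)                # pinned CPU block stand-in

slot, layer = 0, 0

# Compute side: read the slot's data, then record that the read has finished.
_ = gpu_slots[slot].sum()
compute_done[slot][layer].record(compute_stream)

# Transfer side: the next load into this slot must not start before that read completes.
with torch.cuda.stream(transfer_stream):
    transfer_stream.wait_event(compute_done[slot][layer])
    gpu_slots[slot].copy_(cpu_block, non_blocking=True)

torch.cuda.synchronize()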
@@ -622,11 +630,26 @@ class OffloadEngine:

     # ----- Per-slot Per-layer loading methods -----

+    def record_slot_compute_done(self, slot_idx: int, layer_id: int) -> None:
+        """
+        Record that computation using this slot's data is done.
+
+        This event is used by load_to_slot_layer to ensure we don't overwrite
+        data before it's been read by attention computation.
+
+        Args:
+            slot_idx: GPU slot index that was just used for computation
+            layer_id: Layer index
+        """
+        self.ring_slot_compute_done[slot_idx][layer_id].record()
+
     def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
         """
         Async load a single CPU block to a ring buffer slot for one layer.

         This is the core building block for ring buffer pipelining.
+        Before starting the transfer, waits for any previous compute on this slot
+        to complete (using compute_done event).

         Args:
             slot_idx: Target GPU slot index
@@ -636,6 +659,10 @@ class OffloadEngine:
         logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")

         with torch.cuda.stream(self.transfer_stream_main):
+            # Wait for previous compute on this slot to complete before overwriting
+            # This prevents data race: transfer must not start until attention finishes reading
+            self.transfer_stream_main.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
+
             self.k_cache_gpu[layer_id, slot_idx].copy_(
                 self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
             )
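For the non_blocking=True copies issued on transfer_stream_main to actually overlap with compute, the CPU-side k_cache_cpu/v_cache_cpu tensors must live in pinned (page-locked) host memory; copies from pageable memory are not truly asynchronous. A sketch of how such a pinned CPU block pool could be allocated (layout, toy sizes, and names are assumptions for illustration, not the engine's actual fields):

# Illustrative allocation of a pinned CPU KV pool for async H2D staging.
import torch

num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim = 4, 64, 256, 8, 128

k_cache_cpu, v_cache_cpu = (
    torch.empty(
        num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
        dtype=torch.float16, device="cpu",
        pin_memory=True,  # page-locked memory: required for truly async H2D copies
    )
    for _ in range(2)
)

# Staging one block into a GPU ring slot then returns without blocking the host:
gpu_slot = torch.empty(block_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
gpu_slot.copy_(k_cache_cpu[0, 0], non_blocking=True)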
@@ -209,13 +209,26 @@ class Attention(nn.Module):
         offload_engine,
     ):
         """
-        Ring buffer synchronous loading for previous chunks.
+        Ring buffer async pipeline loading with double buffering.

-        For correctness, we use synchronous loading:
-        - Load one block at a time
-        - Wait for transfer, compute attention, then load next
+        Uses compute_done events to ensure safe buffer reuse:
+        - Before loading to slot X, wait for previous compute on slot X to finish
+        - Before computing on slot X, wait for load to slot X to finish

-        This ensures no data races between transfer and computation.
+        Timeline with 2 slots (A, B):
+        ┌──────────────┐
+        │ Load B0→A    │
+        └──────────────┘
+                 ┌──────────────┐  ┌──────────────┐
+                 │ Load B1→B    │  │ Load B2→A    │ ...
+                 └──────────────┘  └──────────────┘
+                        ↘                 ↘
+                 ┌──────────────┐  ┌──────────────┐
+                 │ Compute(A)   │  │ Compute(B)   │ ...
+                 └──────────────┘  └──────────────┘
+
+        The load_to_slot_layer internally waits for compute_done[slot] before
+        starting the transfer, ensuring no data race.
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
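The docstring above leans on flash_attn_with_lse and merge_attention_outputs from nanovllm.kvcache.chunked_attention, which this diff does not show. For orientation, a plain-PyTorch reference of what "attention with LSE" returns for one non-causal KV chunk (illustrative only, no GQA handling; the real helper presumably wraps a FlashAttention kernel and returns the same pair far more efficiently):

# Reference-only sketch of "attention with LSE" for one non-causal KV chunk.
import torch


def naive_attn_with_lse(q, k, v, softmax_scale):
    # q: [batch, q_len, heads, dim], k/v: [batch, kv_len, heads, dim]
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
    lse = torch.logsumexp(scores, dim=-1)                # [batch, heads, q_len]
    out = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)
    return out, lse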
@@ -224,29 +237,62 @@ class Attention(nn.Module):
             return None, None

         pipeline_depth = len(load_slots)
+        if pipeline_depth == 0:
+            return None, None

         o_acc, lse_acc = None, None

-        # Process blocks one by one (synchronous)
+        if pipeline_depth == 1:
+            # Only 1 slot available, cannot pipeline - use synchronous mode
+            slot = load_slots[0]
+            for block_idx in range(num_blocks):
+                offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
+                offload_engine.wait_slot_layer(slot, self.layer_id)
+                prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
+                prev_o, prev_lse = flash_attn_with_lse(
+                    q_batched, prev_k, prev_v,
+                    softmax_scale=self.scale,
+                    causal=False,
+                )
+                # Record compute done so next load can safely reuse this slot
+                offload_engine.record_slot_compute_done(slot, self.layer_id)
+                if o_acc is None:
+                    o_acc, lse_acc = prev_o, prev_lse
+                else:
+                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+            return o_acc, lse_acc
+
+        # Double buffering with 2 slots
+        slot_A = load_slots[0]
+        slot_B = load_slots[1]
+
+        # Pre-load first block to slot_A (async)
+        offload_engine.load_to_slot_layer(slot_A, self.layer_id, cpu_block_table[0])
+
         for block_idx in range(num_blocks):
-            # Determine which slot to use (cycle through load_slots)
-            slot_idx = load_slots[block_idx % pipeline_depth]
-            cpu_block_id = cpu_block_table[block_idx]
+            # Alternate between slot_A and slot_B
+            current_slot = slot_A if block_idx % 2 == 0 else slot_B
+            next_slot = slot_B if block_idx % 2 == 0 else slot_A

-            # Load block to slot (async)
-            offload_engine.load_to_slot_layer(slot_idx, self.layer_id, cpu_block_id)
+            # Wait for current slot's transfer to complete
+            offload_engine.wait_slot_layer(current_slot, self.layer_id)

-            # Wait for transfer to complete
-            offload_engine.wait_slot_layer(slot_idx, self.layer_id)
-            # Get KV from slot and compute attention
-            prev_k, prev_v = offload_engine.get_kv_for_slot(slot_idx, self.layer_id)
+            # Start async load of next block to the OTHER slot
+            # load_to_slot_layer internally waits for next_slot's compute_done
+            if block_idx + 1 < num_blocks:
+                offload_engine.load_to_slot_layer(next_slot, self.layer_id, cpu_block_table[block_idx + 1])

+            # Compute attention on current slot's data
+            prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
             prev_o, prev_lse = flash_attn_with_lse(
                 q_batched, prev_k, prev_v,
                 softmax_scale=self.scale,
                 causal=False,
             )

+            # Record compute done - this allows the next round to safely load into this slot
+            offload_engine.record_slot_compute_done(current_slot, self.layer_id)
+
             # Merge with accumulated
             if o_acc is None:
                 o_acc, lse_acc = prev_o, prev_lse
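The hunk ends inside the merge step; merge_attention_outputs combines each per-chunk partial result with the running accumulator. Assuming it implements the standard log-sum-exp merge (the diff does not show its body), the math looks like this sketch:

# Sketch of the standard log-sum-exp merge of two partial attention results.
# o*: [batch, q_len, heads, dim]; lse*: [batch, heads, q_len].
import torch


def merge_partial_attention(o1, lse1, o2, lse2):
    lse = torch.logaddexp(lse1, lse2)                          # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # [batch, q_len, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse

Because this merge is associative, folding in one offloaded KV chunk at a time, as the double-buffered loop above does, yields the same output as attending over the whole cached sequence at once.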