[feat] Needs to be optimized with async prefetch.

Zijie Tian
2025-12-15 06:58:40 +08:00
parent 1081ab51ea
commit b8b6478506
9 changed files with 556 additions and 404 deletions


@@ -100,16 +100,19 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
- Compute attention with three-region GPU buffer for chunked prefill.
+ Compute attention with unified ring buffer for chunked prefill.
- For chunked prefill:
- 1. Load previous KV from CPU using Compute/Prefetch region (if any previous chunks)
- 2. Compute attention against previous KV chunks (no causal mask)
- 3. Compute attention against current chunk's KV (causal)
- 4. Merge all results using online softmax
+ Ring buffer design:
+ - Current chunk's KV is written to ring_slot[chunk_idx % N]
+ - Previous chunks' KV are loaded from CPU using N-1 available slots
+ - Pipeline: pre-fill slots, then process with overlapped load/compute
- Three-region design guarantees: current chunk's KV is in Compute region, previous KV is loaded
- from CPU to Prefetch region, so write and load regions never overlap.
+ For each layer:
+ 1. Current chunk's KV is in k_batched, v_batched (just written by model)
+ 2. Load previous chunks from CPU using available slots (pipeline)
+ 3. Compute attention against previous KV (no causal mask)
+ 4. Compute attention against current KV (causal)
+ 5. Merge all results using online softmax
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
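As a rough illustration of the slot arithmetic the new docstring describes (write slot = chunk_idx % N, the other N-1 slots free for loading), the helpers used below could be sketched like this. The real get_write_slot_for_prefill / get_load_slots_for_prefill implementations are not part of this diff, and the standalone signatures here are assumptions:

def get_write_slot_for_prefill(chunk_idx: int, num_slots: int) -> int:
    # Current chunk's KV always lands in ring_slot[chunk_idx % N].
    return chunk_idx % num_slots

def get_load_slots_for_prefill(write_slot: int, num_slots: int) -> list[int]:
    # The remaining N-1 slots can receive previous chunks from CPU, so a
    # load never targets the slot the current chunk is being written to.
    return [s for s in range(num_slots) if s != write_slot]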
@@ -122,51 +125,33 @@ class Attention(nn.Module):
o_acc = None
lse_acc = None
- # Load previous KV from CPU using Compute/Prefetch region
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
+ current_chunk_idx = context.current_chunk_idx
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
- # Get prefilled CPU blocks (blocks already written in previous chunks)
+ # Get prefilled CPU blocks (blocks from previous chunks)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
- # For prefill: ONLY use Prefetch region to avoid conflict with
- # current chunk's KV being written to Compute region slots
- # Use synchronous per-layer loading (async would conflict with writes)
- chunk_size = offload_engine.num_prefetch_blocks
- num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
- for chunk_idx in range(num_chunks):
- start = chunk_idx * chunk_size
- end = min(start + chunk_size, len(cpu_block_table))
- num_blocks_in_chunk = end - start
- chunk_ids = cpu_block_table[start:end]
+ # Get write slot for current chunk and available load slots
+ write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
+ load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
+ pipeline_depth = len(load_slots)
- # Load to Prefetch region (per-layer, sync)
- offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids)
- offload_engine.wait_prefetch_layer(self.layer_id)
- prev_k, prev_v = offload_engine.get_kv_for_prefetch(
- self.layer_id, num_blocks_in_chunk
+ if pipeline_depth == 0:
+ # Only 1 slot total, cannot pipeline - use sync loading
+ o_acc, lse_acc = self._sync_load_previous_chunks(
+ q_batched, cpu_block_table, offload_engine
)
- # Compute attention against this chunk (no causal mask)
- prev_o, prev_lse = flash_attn_with_lse(
- q_batched,
- prev_k,
- prev_v,
- softmax_scale=self.scale,
- causal=False,
+ else:
+ # Use ring buffer pipeline
+ o_acc, lse_acc = self._ring_buffer_pipeline_load(
+ q_batched, cpu_block_table, load_slots, offload_engine
)
- # Merge with accumulated
- if o_acc is None:
- o_acc, lse_acc = prev_o, prev_lse
- else:
- o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Compute attention against current chunk's KV (with causal mask)
current_o, current_lse = flash_attn_with_lse(
q_batched,
@@ -185,6 +170,91 @@ class Attention(nn.Module):
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
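merge_attention_outputs is imported above but its body is not part of this diff. A typical log-sum-exp based merge of two partial attention results over disjoint key sets looks like the sketch below (shapes simplified to o: [tokens, heads, dim], lse: [tokens, heads]); treat it as an illustration, not the repo's implementation:

import torch

def merge_attention_outputs(o1, lse1, o2, lse2):
    # Merged log-sum-exp over both key sets.
    lse = torch.logaddexp(lse1, lse2)
    # Each partial output is reweighted by its share of the total softmax mass.
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse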
+ def _sync_load_previous_chunks(
+ self,
+ q_batched: torch.Tensor,
+ cpu_block_table: list,
+ offload_engine,
+ ):
+ """Synchronous loading fallback when pipeline_depth=0."""
+ from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+ o_acc, lse_acc = None, None
+ for block_idx, cpu_block_id in enumerate(cpu_block_table):
+ # Load to slot 0 (single slot)
+ offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
+ offload_engine.wait_slot_layer(0, self.layer_id)
+ prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
+ prev_o, prev_lse = flash_attn_with_lse(
+ q_batched, prev_k, prev_v,
+ softmax_scale=self.scale,
+ causal=False,
+ )
+ if o_acc is None:
+ o_acc, lse_acc = prev_o, prev_lse
+ else:
+ o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+ return o_acc, lse_acc
+ def _ring_buffer_pipeline_load(
+ self,
+ q_batched: torch.Tensor,
+ cpu_block_table: list,
+ load_slots: list,
+ offload_engine,
+ ):
+ """
+ Ring buffer synchronous loading for previous chunks.
+ For correctness, we use synchronous loading:
+ - Load one block at a time
+ - Wait for transfer, compute attention, then load next
+ This ensures no data races between transfer and computation.
+ """
+ from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+ num_blocks = len(cpu_block_table)
+ if num_blocks == 0:
+ return None, None
+ pipeline_depth = len(load_slots)
+ o_acc, lse_acc = None, None
+ # Process blocks one by one (synchronous)
+ for block_idx in range(num_blocks):
+ # Determine which slot to use (cycle through load_slots)
+ slot_idx = load_slots[block_idx % pipeline_depth]
+ cpu_block_id = cpu_block_table[block_idx]
+ # Load block to slot (async)
+ offload_engine.load_to_slot_layer(slot_idx, self.layer_id, cpu_block_id)
+ # Wait for transfer to complete
+ offload_engine.wait_slot_layer(slot_idx, self.layer_id)
+ # Get KV from slot and compute attention
+ prev_k, prev_v = offload_engine.get_kv_for_slot(slot_idx, self.layer_id)
+ prev_o, prev_lse = flash_attn_with_lse(
+ q_batched, prev_k, prev_v,
+ softmax_scale=self.scale,
+ causal=False,
+ )
+ # Merge with accumulated
+ if o_acc is None:
+ o_acc, lse_acc = prev_o, prev_lse
+ else:
+ o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+ return o_acc, lse_acc
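The commit message notes this path still needs async prefetch. One way to get overlap, assuming load_to_slot_layer issues an asynchronous copy that wait_slot_layer later synchronizes on and that at least two load slots are available, is to issue the next block's load before computing on the current one. The helper name below is hypothetical, a sketch rather than the repo's implementation:

def _pipelined_load_previous_chunks(self, q_batched, cpu_block_table, load_slots, offload_engine):
    from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
    depth = len(load_slots)  # assumed >= 2 so load and compute use different slots
    o_acc, lse_acc = None, None
    # Prime the pipeline with the first block.
    offload_engine.load_to_slot_layer(load_slots[0], self.layer_id, cpu_block_table[0])
    for i in range(len(cpu_block_table)):
        slot = load_slots[i % depth]
        # Start copying block i+1 into a different slot while block i is computed.
        # NOTE: a real implementation must also ensure the copy into a slot waits
        # for the previous attention kernel that read that slot (e.g. via events);
        # that ordering is assumed to be handled inside the offload engine.
        if i + 1 < len(cpu_block_table):
            offload_engine.load_to_slot_layer(load_slots[(i + 1) % depth], self.layer_id, cpu_block_table[i + 1])
        offload_engine.wait_slot_layer(slot, self.layer_id)
        prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
        prev_o, prev_lse = flash_attn_with_lse(q_batched, prev_k, prev_v, softmax_scale=self.scale, causal=False)
        if o_acc is None:
            o_acc, lse_acc = prev_o, prev_lse
        else:
            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
    return o_acc, lse_acc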
def _chunked_decode_attention(
self,
q: torch.Tensor,
@@ -193,20 +263,24 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
- Compute decode attention with async double-buffering using Compute and Prefetch regions.
+ Compute decode attention with double-buffering using decode_load_slots.
+ Decode uses:
+ - decode_slot (slot[0]): writes new token's KV
+ - decode_load_slots (slots[1:]): load previous chunks from CPU
Pipeline design:
- - Compute region: holds current chunk being computed
- - Prefetch region: async loads next chunk while current is computing
- - After computation, swap roles of the two regions
+ - First half of decode_load_slots: 'compute' buffer
+ - Second half: 'prefetch' buffer
+ - Double-buffer between them for async overlap
Timeline:
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
- │Load C0→Comp │ │Load C1→Pref │ │Load C2→Comp │ ...
+ │Load C0→buf0 │ │Load C1→buf1 │ │Load C2→buf0 │ ...
└─────────────┘ └─────────────┘ └─────────────┘
↘ ↘ ↘
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
- Compute C0 │ │ Compute C1 │ │ Compute C2
+ Attn(C0) │ │ Attn(C1) │ │ Attn(C2)
└─────────────┘ └─────────────┘ └─────────────┘
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
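A minimal sketch of the buf0/buf1 ping-pong in the timeline above; all of the callables here (load_chunk, wait_chunk, attend_chunk, merge) are placeholders standing in for the offload-engine and attention calls, not the repo's actual API:

def double_buffered_decode(q, chunks, buf0, buf1, load_chunk, wait_chunk, attend_chunk, merge):
    bufs = (buf0, buf1)
    o_acc, lse_acc = None, None
    # Prime: start loading chunk 0 into buf0.
    load_chunk(chunks[0], bufs[0])
    for i in range(len(chunks)):
        cur = bufs[i % 2]
        # Kick off the next chunk's copy into the other buffer so it overlaps
        # with this chunk's attention (Load C{i+1} while Attn(C{i}) runs).
        if i + 1 < len(chunks):
            load_chunk(chunks[i + 1], bufs[(i + 1) % 2])
        wait_chunk(cur)
        o, lse = attend_chunk(q, cur)
        o_acc, lse_acc = (o, lse) if o_acc is None else merge(o_acc, lse_acc, o, lse)
    return o_acc, lse_acc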