[feat] Need to optimize with async prefetch.
@@ -100,16 +100,19 @@ class Attention(nn.Module):
        context,
    ) -> torch.Tensor:
        """
        Compute attention with three-region GPU buffer for chunked prefill.
        Compute attention with unified ring buffer for chunked prefill.

        For chunked prefill:
        1. Load previous KV from CPU using Compute/Prefetch region (if any previous chunks)
        2. Compute attention against previous KV chunks (no causal mask)
        3. Compute attention against current chunk's KV (causal)
        4. Merge all results using online softmax
        Ring buffer design:
        - Current chunk's KV is written to ring_slot[chunk_idx % N]
        - Previous chunks' KV are loaded from CPU using N-1 available slots
        - Pipeline: pre-fill slots, then process with overlapped load/compute

        Three-region design guarantees: current chunk's KV is in Compute region, previous KV is loaded
        from CPU to Prefetch region, so write and load regions never overlap.
        For each layer:
        1. Current chunk's KV is in k_batched, v_batched (just written by model)
        2. Load previous chunks from CPU using available slots (pipeline)
        3. Compute attention against previous KV (no causal mask)
        4. Compute attention against current KV (causal)
        5. Merge all results using online softmax
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

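The "merge all results using online softmax" step in the docstring above combines per-chunk partial outputs through their log-sum-exp (LSE) values. A minimal sketch of that merge, with tensor shapes assumed for illustration (the actual merge_attention_outputs in nanovllm.kvcache.chunked_attention may differ):

import torch

def merge_partial_attention(o_a, lse_a, o_b, lse_b):
    # Assumed shapes (illustration only):
    #   o_*:   [batch, tokens, heads, head_dim]
    #   lse_*: [batch, heads, tokens]
    # Combined log-normalizer log(exp(lse_a) + exp(lse_b)), computed stably.
    m = torch.maximum(lse_a, lse_b)
    lse = m + torch.log(torch.exp(lse_a - m) + torch.exp(lse_b - m))
    # Weight each partial by the share of softmax mass it carries, broadcast over head_dim.
    w_a = torch.exp(lse_a - lse).transpose(1, 2).unsqueeze(-1)  # [batch, tokens, heads, 1]
    w_b = torch.exp(lse_b - lse).transpose(1, 2).unsqueeze(-1)
    return w_a * o_a + w_b * o_b, lse

Because this merge is associative, previous chunks can be folded in one block at a time, which is what the running (o_acc, lse_acc) accumulation in the code below relies on.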
@@ -122,51 +125,33 @@ class Attention(nn.Module):
        o_acc = None
        lse_acc = None

        # Load previous KV from CPU using Compute/Prefetch region
        kvcache_manager = context.kvcache_manager
        seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
        current_chunk_idx = context.current_chunk_idx

        if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
            # Get prefilled CPU blocks (blocks already written in previous chunks)
            # Get prefilled CPU blocks (blocks from previous chunks)
            cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

            if cpu_block_table:
                offload_engine = kvcache_manager.offload_engine
                # For prefill: ONLY use Prefetch region to avoid conflict with
                # current chunk's KV being written to Compute region slots
                # Use synchronous per-layer loading (async would conflict with writes)
                chunk_size = offload_engine.num_prefetch_blocks
                num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size

                for chunk_idx in range(num_chunks):
                    start = chunk_idx * chunk_size
                    end = min(start + chunk_size, len(cpu_block_table))
                    num_blocks_in_chunk = end - start
                    chunk_ids = cpu_block_table[start:end]
                # Get write slot for current chunk and available load slots
                write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
                load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
                pipeline_depth = len(load_slots)

                    # Load to Prefetch region (per-layer, sync)
                    offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids)
                    offload_engine.wait_prefetch_layer(self.layer_id)

                    prev_k, prev_v = offload_engine.get_kv_for_prefetch(
                        self.layer_id, num_blocks_in_chunk
                if pipeline_depth == 0:
                    # Only 1 slot total, cannot pipeline - use sync loading
                    o_acc, lse_acc = self._sync_load_previous_chunks(
                        q_batched, cpu_block_table, offload_engine
                    )

                    # Compute attention against this chunk (no causal mask)
                    prev_o, prev_lse = flash_attn_with_lse(
                        q_batched,
                        prev_k,
                        prev_v,
                        softmax_scale=self.scale,
                        causal=False,
                else:
                    # Use ring buffer pipeline
                    o_acc, lse_acc = self._ring_buffer_pipeline_load(
                        q_batched, cpu_block_table, load_slots, offload_engine
                    )

                    # Merge with accumulated
                    if o_acc is None:
                        o_acc, lse_acc = prev_o, prev_lse
                    else:
                        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

        # Compute attention against current chunk's KV (with causal mask)
        current_o, current_lse = flash_attn_with_lse(
            q_batched,
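get_write_slot_for_prefill and get_load_slots_for_prefill are not shown in this diff. Going by the slot rule in the docstring (the current chunk writes to ring_slot[chunk_idx % N], the remaining N-1 slots serve loads), hypothetical equivalents would look like:

def get_write_slot_for_prefill(chunk_idx: int, num_slots: int) -> int:
    # Current chunk's KV goes to ring_slot[chunk_idx % N].
    return chunk_idx % num_slots

def get_load_slots_for_prefill(write_slot: int, num_slots: int) -> list:
    # The other N-1 slots are free to receive previous chunks from CPU.
    return [s for s in range(num_slots) if s != write_slot]

With a single slot the load list is empty, i.e. pipeline_depth == 0, which is exactly the case that falls back to _sync_load_previous_chunks.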
@@ -185,6 +170,91 @@ class Attention(nn.Module):
        # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
        return final_o.squeeze(0)

    def _sync_load_previous_chunks(
        self,
        q_batched: torch.Tensor,
        cpu_block_table: list,
        offload_engine,
    ):
        """Synchronous loading fallback when pipeline_depth=0."""
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        o_acc, lse_acc = None, None

        for block_idx, cpu_block_id in enumerate(cpu_block_table):
            # Load to slot 0 (single slot)
            offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
            offload_engine.wait_slot_layer(0, self.layer_id)

            prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)

            prev_o, prev_lse = flash_attn_with_lse(
                q_batched, prev_k, prev_v,
                softmax_scale=self.scale,
                causal=False,
            )

            if o_acc is None:
                o_acc, lse_acc = prev_o, prev_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

        return o_acc, lse_acc

    def _ring_buffer_pipeline_load(
        self,
        q_batched: torch.Tensor,
        cpu_block_table: list,
        load_slots: list,
        offload_engine,
    ):
        """
        Ring buffer synchronous loading for previous chunks.

        For correctness, we use synchronous loading:
        - Load one block at a time
        - Wait for transfer, compute attention, then load next

        This ensures no data races between transfer and computation.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        num_blocks = len(cpu_block_table)
        if num_blocks == 0:
            return None, None

        pipeline_depth = len(load_slots)
        o_acc, lse_acc = None, None

        # Process blocks one by one (synchronous)
        for block_idx in range(num_blocks):
            # Determine which slot to use (cycle through load_slots)
            slot_idx = load_slots[block_idx % pipeline_depth]
            cpu_block_id = cpu_block_table[block_idx]

            # Load block to slot (async)
            offload_engine.load_to_slot_layer(slot_idx, self.layer_id, cpu_block_id)

            # Wait for transfer to complete
            offload_engine.wait_slot_layer(slot_idx, self.layer_id)

            # Get KV from slot and compute attention
            prev_k, prev_v = offload_engine.get_kv_for_slot(slot_idx, self.layer_id)

            prev_o, prev_lse = flash_attn_with_lse(
                q_batched, prev_k, prev_v,
                softmax_scale=self.scale,
                causal=False,
            )

            # Merge with accumulated
            if o_acc is None:
                o_acc, lse_acc = prev_o, prev_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

        return o_acc, lse_acc
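_ring_buffer_pipeline_load above deliberately serializes every transfer and compute, and the commit title marks async prefetch as the outstanding optimization. A rough sketch of the overlapped variant, written as a free function for self-containment (layer_id and scale stand in for self.layer_id and self.scale) and assuming load_to_slot_layer only enqueues an asynchronous copy while wait_slot_layer blocks on a single slot's transfer:

from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

def ring_buffer_pipeline_load_async(q_batched, cpu_block_table, load_slots,
                                    offload_engine, layer_id, scale):
    num_blocks = len(cpu_block_table)
    if num_blocks == 0:
        return None, None

    depth = len(load_slots)
    o_acc, lse_acc = None, None

    # Pre-issue the first transfer so it overlaps with whatever precedes the loop.
    offload_engine.load_to_slot_layer(load_slots[0], layer_id, cpu_block_table[0])

    for i in range(num_blocks):
        slot = load_slots[i % depth]

        # Enqueue the next transfer into a different slot before blocking on this one.
        # NOTE: the engine must still order that copy after the last compute that read
        # the destination slot (e.g. via a CUDA event); the current synchronous loop
        # sidesteps exactly that ordering problem.
        if i + 1 < num_blocks and depth > 1:
            offload_engine.load_to_slot_layer(
                load_slots[(i + 1) % depth], layer_id, cpu_block_table[i + 1]
            )

        offload_engine.wait_slot_layer(slot, layer_id)
        prev_k, prev_v = offload_engine.get_kv_for_slot(slot, layer_id)

        prev_o, prev_lse = flash_attn_with_lse(
            q_batched, prev_k, prev_v, softmax_scale=scale, causal=False,
        )
        if o_acc is None:
            o_acc, lse_acc = prev_o, prev_lse
        else:
            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

    return o_acc, lse_acc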
    def _chunked_decode_attention(
        self,
        q: torch.Tensor,
@@ -193,20 +263,24 @@ class Attention(nn.Module):
        context,
    ) -> torch.Tensor:
        """
        Compute decode attention with async double-buffering using Compute and Prefetch regions.
        Compute decode attention with double-buffering using decode_load_slots.

        Decode uses:
        - decode_slot (slot[0]): writes new token's KV
        - decode_load_slots (slots[1:]): load previous chunks from CPU

        Pipeline design:
        - Compute region: holds current chunk being computed
        - Prefetch region: async loads next chunk while current is computing
        - After computation, swap roles of the two regions
        - First half of decode_load_slots: 'compute' buffer
        - Second half: 'prefetch' buffer
        - Double-buffer between them for async overlap

        Timeline:
        ┌─────────────┐ ┌─────────────┐ ┌─────────────┐
        │Load C0→Comp │ │Load C1→Pref │ │Load C2→Comp │ ...
        │Load C0→buf0 │ │Load C1→buf1 │ │Load C2→buf0 │ ...
        └─────────────┘ └─────────────┘ └─────────────┘
              ↘               ↘               ↘
        ┌─────────────┐ ┌─────────────┐ ┌─────────────┐
        │ Compute C0  │ │ Compute C1  │ │ Compute C2  │
        │  Attn(C0)   │ │  Attn(C1)   │ │  Attn(C2)   │
        └─────────────┘ └─────────────┘ └─────────────┘
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

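The timeline above maps onto a double-buffered loop over the previous chunks. A minimal sketch, assuming at least two decode_load_slots, the per-slot engine calls shown earlier in this diff (load_to_slot_layer / wait_slot_layer / get_kv_for_slot), K/V blocks that concatenate along dim=1, and layer_id/scale passed in place of self.layer_id/self.scale (the actual decode body is not shown in this hunk):

import torch
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

def decode_double_buffer_sketch(q_batched, cpu_block_table, decode_load_slots,
                                offload_engine, layer_id, scale):
    # Split the load slots into the two buffers described in the docstring.
    half = len(decode_load_slots) // 2
    buffers = [decode_load_slots[:half], decode_load_slots[half:]]
    # One "chunk" of previous KV = as many CPU blocks as one buffer holds.
    chunks = [cpu_block_table[i:i + half] for i in range(0, len(cpu_block_table), half)]

    def issue_load(buf, chunk):
        # Enqueue async H2D copies of one chunk into one buffer's slots.
        for slot, block_id in zip(buf, chunk):
            offload_engine.load_to_slot_layer(slot, layer_id, block_id)

    def wait_and_gather(buf, chunk):
        # Block on the copies, then concatenate K/V along the token axis (assumed dim=1).
        ks, vs = [], []
        for slot, _ in zip(buf, chunk):
            offload_engine.wait_slot_layer(slot, layer_id)
            k, v = offload_engine.get_kv_for_slot(slot, layer_id)
            ks.append(k)
            vs.append(v)
        return torch.cat(ks, dim=1), torch.cat(vs, dim=1)

    o_acc, lse_acc = None, None
    if chunks:
        issue_load(buffers[0], chunks[0])                    # Load C0 -> buf0
    for i, chunk in enumerate(chunks):
        if i + 1 < len(chunks):
            issue_load(buffers[(i + 1) % 2], chunks[i + 1])  # Load C(i+1) -> the other buffer
        k, v = wait_and_gather(buffers[i % 2], chunk)        # Wait only on this buffer
        o, lse = flash_attn_with_lse(q_batched, k, v, softmax_scale=scale, causal=False)
        if o_acc is None:
            o_acc, lse_acc = o, lse
        else:
            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)
    return o_acc, lse_acc

As in the prefill sketch, the engine still has to order each buffer's copies after the last attention kernel that read that buffer, and the merged (o_acc, lse_acc) is finally combined with the attention over the KV written to decode_slot.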