[refactor] Implement real chunked prefill mechanism.
This commit is contained in:
@@ -174,45 +174,76 @@ class Attention(nn.Module):
|
||||
"""
|
||||
Compute decode attention with KV spread across CPU and GPU.
|
||||
|
||||
For decode with chunked KV:
|
||||
1. Load all KV for this layer from CPU+GPU
|
||||
2. Compute attention (1 query token vs all KV)
|
||||
3. Return output
|
||||
Uses chunked attention similar to chunked prefill:
|
||||
1. Process blocks on GPU first (if any)
|
||||
2. Load CPU blocks in chunks to GPU slots (per-layer)
|
||||
3. Compute attention for each chunk, merge with online softmax
|
||||
"""
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
||||
# We need to attend to ALL previous tokens
|
||||
# Need: [batch, seqlen, heads, dim]
|
||||
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||
|
||||
# Load all KV for this layer
|
||||
if context.offload_engine is not None and self.layer_id >= 0:
|
||||
kvcache_manager = context.offload_engine
|
||||
kvcache_manager = context.offload_engine
|
||||
seq = context.chunked_seq
|
||||
|
||||
if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
|
||||
# Load all KV from both GPU and CPU for this layer
|
||||
k_all, v_all = kvcache_manager.load_all_kv_for_layer(
|
||||
context.chunked_seq,
|
||||
o_acc = None
|
||||
lse_acc = None
|
||||
|
||||
# Step 1: Process blocks already on GPU (if any)
|
||||
gpu_slots, _ = kvcache_manager.get_gpu_blocks_for_decode(seq)
|
||||
if gpu_slots:
|
||||
k_gpu, v_gpu = kvcache_manager.get_kv_for_gpu_slots(self.layer_id, gpu_slots)
|
||||
o_gpu, lse_gpu = flash_attn_with_lse(
|
||||
q_batched, k_gpu, v_gpu,
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
)
|
||||
o_acc, lse_acc = o_gpu, lse_gpu
|
||||
|
||||
# Step 2: Process CPU blocks in chunks
|
||||
# Get chunk info from kvcache_manager
|
||||
cpu_block_ids, cpu_logical_ids, num_chunks = kvcache_manager.get_decode_chunk_info(seq)
|
||||
|
||||
if num_chunks > 0:
|
||||
# Use num_gpu_slots - 1 to avoid the reserved slot (used for write block)
|
||||
chunk_size = kvcache_manager.num_gpu_slots - 1
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
start = chunk_idx * chunk_size
|
||||
end = min(start + chunk_size, len(cpu_block_ids))
|
||||
chunk_cpu_ids = cpu_block_ids[start:end]
|
||||
|
||||
# Load this chunk to GPU slots 0, 1, 2, ... for THIS LAYER
|
||||
# (slot num_gpu_slots-1 is reserved for write block)
|
||||
gpu_slots_for_chunk = list(range(len(chunk_cpu_ids)))
|
||||
kvcache_manager.offload_engine.load_cpu_blocks_to_gpu_slots(
|
||||
self.layer_id,
|
||||
chunk_cpu_ids,
|
||||
gpu_slots_for_chunk,
|
||||
)
|
||||
|
||||
if k_all is not None and v_all is not None:
|
||||
# q shape: [batch_size, num_heads, head_dim]
|
||||
# Need: [batch, seqlen, heads, dim]
|
||||
# Insert seqlen dimension at position 1
|
||||
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||
# Get KV for this chunk
|
||||
k_chunk, v_chunk = kvcache_manager.get_kv_for_gpu_slots(
|
||||
self.layer_id, gpu_slots_for_chunk
|
||||
)
|
||||
|
||||
# k_all, v_all shape: [1, total_kv_tokens, kv_heads, head_dim]
|
||||
# Compute attention (no causal mask for decode - we want all KV)
|
||||
out, _ = flash_attn_with_lse(
|
||||
q_batched,
|
||||
k_all,
|
||||
v_all,
|
||||
softmax_scale=self.scale,
|
||||
causal=False, # No causal mask for decode
|
||||
)
|
||||
# Compute attention for this chunk
|
||||
o_chunk, lse_chunk = flash_attn_with_lse(
|
||||
q_batched, k_chunk, v_chunk,
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
)
|
||||
|
||||
# Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
|
||||
return out.squeeze(1)
|
||||
# Merge with accumulated
|
||||
if o_acc is None:
|
||||
o_acc, lse_acc = o_chunk, lse_chunk
|
||||
else:
|
||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
|
||||
|
||||
# Fallback: shouldn't reach here
|
||||
raise RuntimeError("Chunked decode attention failed: no KV available")
|
||||
if o_acc is None:
|
||||
raise RuntimeError("Chunked decode attention failed: no KV available")
|
||||
|
||||
# Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
|
||||
return o_acc.squeeze(1)
|
||||
|
||||
Reference in New Issue
Block a user