[refactor] Implement real chunked prefill mechanism.

Zijie Tian
2025-12-10 18:34:01 +08:00
parent 0b6f19242d
commit 87055cc5ce
4 changed files with 313 additions and 85 deletions

View File

@@ -429,37 +429,26 @@ class ModelRunner:
         Run decode with chunked attention when sequence exceeds GPU capacity.
         For decode, we need attention over ALL previous tokens. With CPU offload,
-        we load KV chunks and compute attention incrementally.
+        we load KV chunks and compute attention incrementally per-layer.
+
+        Flow:
+        1. Ensure last block is on GPU (for writing new KV token)
+        2. Run model forward - each attention layer:
+           a. Compute attention on GPU blocks
+           b. Load CPU blocks in chunks, compute + merge
+        3. Sample from output
         """
-        import sys
         # Currently only supporting single sequence for chunked decode
         assert len(seqs) == 1, "Chunked decode only supports single sequence"
         seq = seqs[0]
-        total_blocks = len(seq.block_table)
-        print(f"[Chunked Decode] Sequence has {total_blocks} blocks, "
-              f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)
 
         # Prepare inputs
         input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
 
-        # Compute slot mapping for the new token
-        # Get the last block's GPU slot if it's on GPU, otherwise we need to handle it
-        last_logical_id = seq.block_table[-1]
-        last_block = self.kvcache_manager.logical_blocks[last_logical_id]
-        if last_block.location.name == "GPU":
-            slot = last_block.gpu_slot * self.block_size + seq.last_block_num_tokens - 1
-        else:
-            # Last block is on CPU - we need to bring it to GPU for writing the new token
-            # This is a special case - allocate a temporary GPU slot
-            # For simplicity, use a fixed slot (this might conflict, but for decode
-            # we only write 1 token so it should be ok)
-            print(f"[Chunked Decode] Warning: last block on CPU, using temp slot", file=sys.stderr)
-            slot = 0  # Use first slot temporarily
+        # Ensure last block is on GPU for writing new KV token
+        last_gpu_slot = self.kvcache_manager.ensure_last_block_on_gpu(seq)
+        slot = last_gpu_slot * self.block_size + seq.last_block_num_tokens - 1
         slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
@@ -468,12 +457,13 @@ class ModelRunner:
             is_prefill=False,  # Decode mode
             slot_mapping=slot_mapping,
             context_lens=context_len,
-            is_chunked_prefill=True,  # Use chunked attention
+            is_chunked_prefill=True,  # Use chunked attention path
             offload_engine=self.kvcache_manager,
             chunked_seq=seq,
         )
 
         # Run model forward pass
+        # Each attention layer will handle chunked KV loading internally
         logits = self.run_model(input_ids, positions, is_prefill=False)
         reset_context()
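
Note: the slot index built above is a flat offset into the GPU KV cache, the last block's GPU slot times the block size plus the new token's position inside that block. A minimal sketch of the arithmetic (block_size and the concrete numbers are illustrative, not taken from this commit):

# Illustrative only: how the flat slot index for the new decode token is formed.
block_size = 16            # tokens per KV block (example value)
last_gpu_slot = 7          # GPU slot holding the sequence's last block (example)
last_block_num_tokens = 5  # tokens in the last block, counting the token being decoded

slot = last_gpu_slot * block_size + last_block_num_tokens - 1
assert slot == 116         # token 4 (0-based) within slot 7: 7 * 16 + 4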

View File

@@ -566,50 +566,151 @@ class HybridKVCacheManager(KVCacheManager):
                 cpu_blocks += 1
         return cpu_blocks > 0 and len(seq.block_table) > self.num_gpu_slots
 
-    def load_all_kv_for_layer(
-        self,
-        seq: Sequence,
-        layer_id: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Load ALL KV for a sequence from both GPU and CPU for a layer.
-        Used during chunked decode to compute full attention.
-
-        Returns:
-            (k, v) tensors with shape [1, total_tokens, kv_heads, head_dim]
-        """
-        k_chunks = []
-        v_chunks = []
-        for logical_id in seq.block_table:
-            block = self.logical_blocks[logical_id]
-            if block.location == BlockLocation.GPU:
-                # Get from GPU cache
-                k, v = self.offload_engine.get_layer_cache(layer_id)
-                # k, v shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
-                k_block = k[block.gpu_slot]  # [block_size, kv_heads, head_dim]
-                v_block = v[block.gpu_slot]
-                k_chunks.append(k_block)
-                v_chunks.append(v_block)
-            elif block.location == BlockLocation.CPU:
-                # Get from CPU cache
-                k_block, v_block = self.offload_engine.get_cpu_block(layer_id, block.cpu_block_id)
-                # Already [block_size, kv_heads, head_dim]
-                k_chunks.append(k_block.to("cuda", non_blocking=True))
-                v_chunks.append(v_block.to("cuda", non_blocking=True))
-
-        # Concatenate all chunks
-        k_all = torch.cat(k_chunks, dim=0)  # [total_tokens, kv_heads, head_dim]
-        v_all = torch.cat(v_chunks, dim=0)
-
-        # Add batch dimension
-        k_all = k_all.unsqueeze(0)  # [1, total_tokens, kv_heads, head_dim]
-        v_all = v_all.unsqueeze(0)
-        return k_all, v_all
+    # ========== Chunked Decode Support ==========
+
+    def get_decode_chunk_info(self, seq: Sequence) -> Tuple[List[int], List[int], int]:
+        """
+        Get information for chunked decode.
+
+        Returns:
+            (cpu_block_ids, cpu_logical_ids, num_chunks)
+            - cpu_block_ids: List of CPU block IDs in sequence order
+            - cpu_logical_ids: Corresponding logical block IDs
+            - num_chunks: Number of chunks needed
+        """
+        cpu_block_ids = []
+        cpu_logical_ids = []
+        for logical_id in seq.block_table:
+            block = self.logical_blocks[logical_id]
+            if block.location == BlockLocation.CPU:
+                cpu_block_ids.append(block.cpu_block_id)
+                cpu_logical_ids.append(logical_id)
+
+        # Each chunk uses available GPU slots minus 1 (reserved for write block)
+        usable_slots = self.num_gpu_slots - 1
+        num_chunks = (len(cpu_block_ids) + usable_slots - 1) // usable_slots if usable_slots > 0 else 0
+        return cpu_block_ids, cpu_logical_ids, num_chunks
+
+    def load_decode_chunk(
+        self,
+        seq: Sequence,
+        cpu_block_ids: List[int],
+        cpu_logical_ids: List[int],
+        chunk_idx: int,
+    ) -> List[int]:
+        """
+        Load one chunk of CPU blocks to GPU for chunked decode.
+        Similar to chunked prefill: uses GPU slots to hold a batch of blocks.
+
+        Args:
+            seq: Sequence being decoded
+            cpu_block_ids: All CPU block IDs for this sequence
+            cpu_logical_ids: Corresponding logical block IDs
+            chunk_idx: Which chunk to load (0-indexed)
+
+        Returns:
+            List of GPU slot IDs where the chunk was loaded
+        """
+        chunk_size = self.num_gpu_slots
+        start = chunk_idx * chunk_size
+        end = min(start + chunk_size, len(cpu_block_ids))
+        chunk_cpu_ids = cpu_block_ids[start:end]
+        chunk_logical_ids = cpu_logical_ids[start:end]
+
+        # Use GPU slots 0, 1, 2, ... for this chunk
+        gpu_slots = list(range(len(chunk_cpu_ids)))
+
+        # Load all layers at once using offload_engine
+        self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers(
+            chunk_cpu_ids, gpu_slots
+        )
+        return gpu_slots
+
+    def get_gpu_blocks_for_decode(self, seq: Sequence) -> Tuple[List[int], List[int]]:
+        """
+        Get blocks currently on GPU for this sequence.
+
+        Returns:
+            (gpu_slots, logical_ids) - GPU slot IDs and corresponding logical block IDs
+        """
+        gpu_slots = []
+        logical_ids = []
+        for logical_id in seq.block_table:
+            block = self.logical_blocks[logical_id]
+            if block.location == BlockLocation.GPU:
+                gpu_slots.append(block.gpu_slot)
+                logical_ids.append(logical_id)
+        return gpu_slots, logical_ids
+
+    def get_kv_for_gpu_slots(
+        self,
+        layer_id: int,
+        gpu_slots: List[int],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get KV tensors for specific GPU slots.
+
+        Args:
+            layer_id: Layer index
+            gpu_slots: List of GPU slot IDs
+
+        Returns:
+            (k, v) tensors with shape [1, num_tokens, kv_heads, head_dim]
+        """
+        k_cache, v_cache = self.offload_engine.get_layer_cache(layer_id)
+        # k_cache, v_cache shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
+        k_chunks = [k_cache[slot] for slot in gpu_slots]
+        v_chunks = [v_cache[slot] for slot in gpu_slots]
+
+        # Concatenate and add batch dimension
+        k = torch.cat(k_chunks, dim=0).unsqueeze(0)  # [1, tokens, heads, dim]
+        v = torch.cat(v_chunks, dim=0).unsqueeze(0)
+        return k, v
+
+    def ensure_last_block_on_gpu(self, seq: Sequence) -> int:
+        """
+        Ensure the last block is on GPU for writing new KV.
+
+        Uses a RESERVED slot (last slot) to avoid conflicts with chunked decode,
+        which uses slots 0, 1, 2, ... for loading CPU blocks.
+
+        Returns:
+            GPU slot ID for the last block
+        """
+        last_logical_id = seq.block_table[-1]
+        block = self.logical_blocks[last_logical_id]
+        if block.location == BlockLocation.GPU:
+            return block.gpu_slot
+
+        # Use last slot as reserved slot for write block
+        # This avoids conflicts with chunked decode which uses slots 0, 1, 2...
+        reserved_slot = self.num_gpu_slots - 1
+
+        # Load this block to GPU for all layers
+        self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers(
+            [block.cpu_block_id], [reserved_slot]
+        )
+
+        # Update block state
+        self.free_cpu_blocks.append(block.cpu_block_id)
+        del self.cpu_block_to_logical[block.cpu_block_id]
+        self.gpu_slot_to_logical[reserved_slot] = last_logical_id
+        block.location = BlockLocation.GPU
+        block.gpu_slot = reserved_slot
+        block.cpu_block_id = -1
+        return reserved_slot
 
     def get_gpu_block_tables(
         self,
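
Note: get_decode_chunk_info sizes decode chunks by ceil division over num_gpu_slots - 1, since the last slot is reserved by ensure_last_block_on_gpu for the block that receives the new token. A minimal worked example (the slot and block counts are illustrative):

# Illustrative only: chunk-count math used for chunked decode.
num_gpu_slots = 8
num_cpu_blocks = 20                # sequence blocks currently offloaded to CPU

usable_slots = num_gpu_slots - 1   # last slot is reserved for the write block
num_chunks = (num_cpu_blocks + usable_slots - 1) // usable_slots  # ceil(20 / 7)
assert usable_slots == 7 and num_chunks == 3                      # chunks of 7, 7, 6 blocks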

View File

@@ -308,6 +308,112 @@ class OffloadEngine:
             events.append(event)
         return events
 
+    # ========== Chunked Decode: Load CPU blocks to GPU slots ==========
+
+    def load_cpu_blocks_to_gpu_slots(
+        self,
+        layer_id: int,
+        cpu_block_ids: List[int],
+        gpu_slot_ids: List[int],
+    ) -> None:
+        """
+        Load CPU blocks to specific GPU slots for chunked decode.
+
+        Uses the main GPU KV cache slots, not a separate temp buffer.
+        This is the same mechanism as chunked prefill uses.
+
+        Args:
+            layer_id: Layer index
+            cpu_block_ids: List of CPU block IDs to load
+            gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
+        """
+        assert len(cpu_block_ids) == len(gpu_slot_ids)
+        stream = self._get_next_stream()
+        with torch.cuda.stream(stream):
+            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
+                # Copy from pinned CPU memory to GPU KV cache slot
+                self.k_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.k_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+                self.v_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.v_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+        # Wait for transfer to complete
+        stream.synchronize()
+
+    def load_cpu_blocks_to_gpu_slots_async(
+        self,
+        layer_id: int,
+        cpu_block_ids: List[int],
+        gpu_slot_ids: List[int],
+    ) -> torch.cuda.Event:
+        """
+        Async version: Load CPU blocks to GPU slots.
+
+        Args:
+            layer_id: Layer index
+            cpu_block_ids: List of CPU block IDs to load
+            gpu_slot_ids: List of GPU slot IDs to load into
+
+        Returns:
+            CUDA event to wait on
+        """
+        assert len(cpu_block_ids) == len(gpu_slot_ids)
+        stream = self._get_next_stream()
+        event = torch.cuda.Event()
+        with torch.cuda.stream(stream):
+            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
+                self.k_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.k_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+                self.v_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.v_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+            event.record()
+        return event
+
+    def load_cpu_blocks_to_gpu_slots_all_layers(
+        self,
+        cpu_block_ids: List[int],
+        gpu_slot_ids: List[int],
+    ) -> None:
+        """
+        Load CPU blocks to GPU slots for ALL layers at once.
+
+        More efficient than per-layer loading when we know the mapping upfront.
+
+        Args:
+            cpu_block_ids: List of CPU block IDs to load
+            gpu_slot_ids: List of GPU slot IDs to load into
+        """
+        assert len(cpu_block_ids) == len(gpu_slot_ids)
+        stream = self._get_next_stream()
+        with torch.cuda.stream(stream):
+            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
+                # Copy all layers at once
+                self.k_cache_gpu[:, gpu_slot].copy_(
+                    self.k_cache_cpu[:, cpu_block_id],
+                    non_blocking=True
+                )
+                self.v_cache_gpu[:, gpu_slot].copy_(
+                    self.v_cache_cpu[:, cpu_block_id],
+                    non_blocking=True
+                )
+        stream.synchronize()
 
     # ========== Synchronization methods ==========
     def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None:
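
Note: load_cpu_blocks_to_gpu_slots_all_layers relies on the pools being laid out with the layer dimension first, so a single copy_ on a [:, block] view moves one block's KV for every layer, and on pinned CPU memory so the non_blocking=True transfer can run asynchronously on a side stream. A self-contained sketch of that copy pattern (shapes and sizes are illustrative; only the K pool is shown):

import torch

# Illustrative pools laid out [layers, blocks, block_size, kv_heads, head_dim],
# mirroring the k_cache_cpu[layer, block] / k_cache_gpu[:, slot] indexing above.
num_layers, num_cpu_blocks, num_gpu_slots = 4, 8, 2
block_size, kv_heads, head_dim = 16, 2, 64

k_cache_cpu = torch.randn(num_layers, num_cpu_blocks, block_size, kv_heads, head_dim,
                          pin_memory=True)
k_cache_gpu = torch.zeros(num_layers, num_gpu_slots, block_size, kv_heads, head_dim,
                          device="cuda")

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    # One copy_ per (CPU block -> GPU slot) pair moves that block's K for ALL layers.
    k_cache_gpu[:, 0].copy_(k_cache_cpu[:, 5], non_blocking=True)
    k_cache_gpu[:, 1].copy_(k_cache_cpu[:, 6], non_blocking=True)
stream.synchronize()  # slots are safe to read only after the stream is synchronized

assert torch.equal(k_cache_gpu[:, 0].cpu(), k_cache_cpu[:, 5])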

View File

@@ -174,45 +174,76 @@ class Attention(nn.Module):
""" """
Compute decode attention with KV spread across CPU and GPU. Compute decode attention with KV spread across CPU and GPU.
For decode with chunked KV: Uses chunked attention similar to chunked prefill:
1. Load all KV for this layer from CPU+GPU 1. Process blocks on GPU first (if any)
2. Compute attention (1 query token vs all KV) 2. Load CPU blocks in chunks to GPU slots (per-layer)
3. Return output 3. Compute attention for each chunk, merge with online softmax
""" """
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence) # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
# We need to attend to ALL previous tokens # Need: [batch, seqlen, heads, dim]
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
# Load all KV for this layer kvcache_manager = context.offload_engine
if context.offload_engine is not None and self.layer_id >= 0: seq = context.chunked_seq
kvcache_manager = context.offload_engine
if hasattr(context, 'chunked_seq') and context.chunked_seq is not None: o_acc = None
# Load all KV from both GPU and CPU for this layer lse_acc = None
k_all, v_all = kvcache_manager.load_all_kv_for_layer(
context.chunked_seq, # Step 1: Process blocks already on GPU (if any)
gpu_slots, _ = kvcache_manager.get_gpu_blocks_for_decode(seq)
if gpu_slots:
k_gpu, v_gpu = kvcache_manager.get_kv_for_gpu_slots(self.layer_id, gpu_slots)
o_gpu, lse_gpu = flash_attn_with_lse(
q_batched, k_gpu, v_gpu,
softmax_scale=self.scale,
causal=False,
)
o_acc, lse_acc = o_gpu, lse_gpu
# Step 2: Process CPU blocks in chunks
# Get chunk info from kvcache_manager
cpu_block_ids, cpu_logical_ids, num_chunks = kvcache_manager.get_decode_chunk_info(seq)
if num_chunks > 0:
# Use num_gpu_slots - 1 to avoid the reserved slot (used for write block)
chunk_size = kvcache_manager.num_gpu_slots - 1
for chunk_idx in range(num_chunks):
start = chunk_idx * chunk_size
end = min(start + chunk_size, len(cpu_block_ids))
chunk_cpu_ids = cpu_block_ids[start:end]
# Load this chunk to GPU slots 0, 1, 2, ... for THIS LAYER
# (slot num_gpu_slots-1 is reserved for write block)
gpu_slots_for_chunk = list(range(len(chunk_cpu_ids)))
kvcache_manager.offload_engine.load_cpu_blocks_to_gpu_slots(
self.layer_id, self.layer_id,
chunk_cpu_ids,
gpu_slots_for_chunk,
) )
if k_all is not None and v_all is not None: # Get KV for this chunk
# q shape: [batch_size, num_heads, head_dim] k_chunk, v_chunk = kvcache_manager.get_kv_for_gpu_slots(
# Need: [batch, seqlen, heads, dim] self.layer_id, gpu_slots_for_chunk
# Insert seqlen dimension at position 1 )
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
# k_all, v_all shape: [1, total_kv_tokens, kv_heads, head_dim] # Compute attention for this chunk
# Compute attention (no causal mask for decode - we want all KV) o_chunk, lse_chunk = flash_attn_with_lse(
out, _ = flash_attn_with_lse( q_batched, k_chunk, v_chunk,
q_batched, softmax_scale=self.scale,
k_all, causal=False,
v_all, )
softmax_scale=self.scale,
causal=False, # No causal mask for decode
)
# Output shape: [batch, 1, heads, dim] -> [batch, heads, dim] # Merge with accumulated
return out.squeeze(1) if o_acc is None:
o_acc, lse_acc = o_chunk, lse_chunk
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Fallback: shouldn't reach here if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available") raise RuntimeError("Chunked decode attention failed: no KV available")
# Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
return o_acc.squeeze(1)
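
Note: merge_attention_outputs comes from nanovllm.kvcache.chunked_attention, which is not part of this diff. For reference, a minimal sketch of the standard log-sum-exp merge it presumably implements, assuming outputs shaped [batch, q_len, heads, head_dim] and LSE shaped [batch, heads, q_len] as returned by flash-attn:

import torch

def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
    # Combine two partial attention results computed over disjoint KV sets.
    # o*:   [batch, q_len, heads, head_dim], lse*: [batch, heads, q_len]
    lse = torch.logaddexp(lse1, lse2)                         # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # [batch, q_len, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse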