[refactor] Implement real chunked prefill mechanism.

Zijie Tian
2025-12-10 18:34:01 +08:00
parent 0b6f19242d
commit 87055cc5ce
4 changed files with 313 additions and 85 deletions


@@ -429,37 +429,26 @@ class ModelRunner:
         Run decode with chunked attention when sequence exceeds GPU capacity.
         For decode, we need attention over ALL previous tokens. With CPU offload,
-        we load KV chunks and compute attention incrementally.
-        """
-        import sys
+        we load KV chunks and compute attention incrementally per-layer.
+        Flow:
+        1. Ensure last block is on GPU (for writing new KV token)
+        2. Run model forward - each attention layer:
+           a. Compute attention on GPU blocks
+           b. Load CPU blocks in chunks, compute + merge
+        3. Sample from output
+        """
         # Currently only supporting single sequence for chunked decode
         assert len(seqs) == 1, "Chunked decode only supports single sequence"
         seq = seqs[0]
-        total_blocks = len(seq.block_table)
-        print(f"[Chunked Decode] Sequence has {total_blocks} blocks, "
-              f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)
         # Prepare inputs
         input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
-        # Compute slot mapping for the new token
-        # Get the last block's GPU slot if it's on GPU, otherwise we need to handle it
-        last_logical_id = seq.block_table[-1]
-        last_block = self.kvcache_manager.logical_blocks[last_logical_id]
-        if last_block.location.name == "GPU":
-            slot = last_block.gpu_slot * self.block_size + seq.last_block_num_tokens - 1
-        else:
-            # Last block is on CPU - we need to bring it to GPU for writing the new token
-            # This is a special case - allocate a temporary GPU slot
-            # For simplicity, use a fixed slot (this might conflict, but for decode
-            # we only write 1 token so it should be ok)
-            print(f"[Chunked Decode] Warning: last block on CPU, using temp slot", file=sys.stderr)
-            slot = 0 # Use first slot temporarily
+        # Ensure last block is on GPU for writing new KV token
+        last_gpu_slot = self.kvcache_manager.ensure_last_block_on_gpu(seq)
+        slot = last_gpu_slot * self.block_size + seq.last_block_num_tokens - 1
         slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
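
The new-token slot computed above follows the usual paged-KV flat-index convention: the sequence's last logical block is first made resident on GPU (ensure_last_block_on_gpu), and the freshly generated token's K/V is written at gpu_slot * block_size + last_block_num_tokens - 1. The short sketch below is not part of the commit; the block size and helper name are illustrative and only spell out that arithmetic.

    BLOCK_SIZE = 16  # assumed paged-KV block size, not necessarily the repository's value

    def new_token_slot(gpu_slot: int, last_block_num_tokens: int,
                       block_size: int = BLOCK_SIZE) -> int:
        """Flat index in the GPU KV cache where the new token's K/V is written.

        Assumes the sequence's last logical block is already resident on GPU
        (what ensure_last_block_on_gpu guarantees in the diff above); the new
        token occupies the last filled position of that block.
        """
        return gpu_slot * block_size + last_block_num_tokens - 1

    # Example: last block sits in GPU slot 3 and holds 5 tokens once the new
    # token is appended, so the K/V write lands at flat offset 3 * 16 + 5 - 1 = 52.
    assert new_token_slot(gpu_slot=3, last_block_num_tokens=5) == 52
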
@@ -468,12 +457,13 @@ class ModelRunner:
             is_prefill=False, # Decode mode
             slot_mapping=slot_mapping,
             context_lens=context_len,
-            is_chunked_prefill=True, # Use chunked attention
+            is_chunked_prefill=True, # Use chunked attention path
             offload_engine=self.kvcache_manager,
             chunked_seq=seq,
         )
         # Run model forward pass
+        # Each attention layer will handle chunked KV loading internally
         logits = self.run_model(input_ids, positions, is_prefill=False)
         reset_context()
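
With is_chunked_prefill=True, offload_engine, and chunked_seq placed in the context, each attention layer is expected to attend over the GPU-resident blocks first, then stream the CPU-resident blocks in chunks and merge the partial results (step 2 of the docstring's flow). The snippet below is a minimal, self-contained sketch of that merge math in plain PyTorch, not the repository's kernel: partial attention outputs computed over disjoint KV chunks are combined using their per-query log-sum-exp, which reproduces attention over the full KV exactly.

    import torch

    def attn_chunk(q, k, v, scale):
        """Attention of one query step against a single KV chunk.

        q: [heads, 1, dim], k/v: [heads, chunk_len, dim]
        Returns the chunk-local output and its log-sum-exp normalizer.
        """
        scores = (q @ k.transpose(-1, -2)) * scale           # [heads, 1, chunk_len]
        lse = torch.logsumexp(scores, dim=-1, keepdim=True)  # [heads, 1, 1]
        out = torch.softmax(scores, dim=-1) @ v              # [heads, 1, dim]
        return out, lse

    def merge(outs, lses):
        """Combine per-chunk partial outputs into attention over the full KV."""
        lse_all = torch.logsumexp(torch.stack(lses), dim=0)  # global normalizer
        return sum(o * torch.exp(l - lse_all) for o, l in zip(outs, lses))

    # One decode token attends to KV split across "GPU" and "CPU" chunks.
    heads, dim = 4, 64
    scale = dim ** -0.5
    q = torch.randn(heads, 1, dim)
    kv_chunks = [(torch.randn(heads, n, dim), torch.randn(heads, n, dim))
                 for n in (48, 32, 17)]

    outs, lses = zip(*(attn_chunk(q, k, v, scale) for k, v in kv_chunks))
    merged = merge(outs, lses)

    # Reference: attention over the concatenated KV gives the same result.
    k_full = torch.cat([k for k, _ in kv_chunks], dim=1)
    v_full = torch.cat([v for _, v in kv_chunks], dim=1)
    ref, _ = attn_chunk(q, k_full, v_full, scale)
    assert torch.allclose(merged, ref, atol=1e-5)
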