[refactor] Implement real chunked prefill mechanism.
@@ -429,37 +429,26 @@ class ModelRunner:
Run decode with chunked attention when sequence exceeds GPU capacity.

For decode, we need attention over ALL previous tokens. With CPU offload,
we load KV chunks and compute attention incrementally.
"""
import sys
we load KV chunks and compute attention incrementally per-layer.

Flow:
1. Ensure last block is on GPU (for writing new KV token)
2. Run model forward - each attention layer:
   a. Compute attention on GPU blocks
   b. Load CPU blocks in chunks, compute + merge
3. Sample from output
"""
# Currently only supporting single sequence for chunked decode
assert len(seqs) == 1, "Chunked decode only supports single sequence"
seq = seqs[0]

total_blocks = len(seq.block_table)
print(f"[Chunked Decode] Sequence has {total_blocks} blocks, "
      f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)

# Prepare inputs
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)

# Compute slot mapping for the new token
# Get the last block's GPU slot if it is on GPU; otherwise handle the CPU-resident case below
last_logical_id = seq.block_table[-1]
last_block = self.kvcache_manager.logical_blocks[last_logical_id]

if last_block.location.name == "GPU":
    slot = last_block.gpu_slot * self.block_size + seq.last_block_num_tokens - 1
else:
    # Last block is on CPU - we need to bring it to GPU for writing the new token
    # This is a special case - allocate a temporary GPU slot
    # For simplicity, use a fixed slot (this might conflict, but for decode
    # we only write 1 token so it should be OK)
    print(f"[Chunked Decode] Warning: last block on CPU, using temp slot", file=sys.stderr)
    slot = 0  # Use first slot temporarily

# Ensure last block is on GPU for writing new KV token
last_gpu_slot = self.kvcache_manager.ensure_last_block_on_gpu(seq)
slot = last_gpu_slot * self.block_size + seq.last_block_num_tokens - 1
slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
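# Illustration (assumed cache layout, not this repo's code): the slot arithmetic above
# flattens (GPU block slot, offset within block) into one index, assuming a paged KV
# cache stored as [num_gpu_slots * block_size, num_heads, head_dim].
# seq.last_block_num_tokens already counts the token being decoded, hence the -1.
import torch

block_size, num_gpu_slots, num_heads, head_dim = 16, 4, 8, 64
k_cache = torch.zeros(num_gpu_slots * block_size, num_heads, head_dim)

def write_kv(cache, key, gpu_slot, num_tokens_in_block):
    # The newest token sits at 0-based offset num_tokens_in_block - 1 inside its block
    flat_slot = gpu_slot * block_size + num_tokens_in_block - 1
    cache[flat_slot] = key
    return flat_slot

# e.g. the last block lives in GPU slot 2 and holds 6 tokens including the new one:
assert write_kv(k_cache, torch.randn(num_heads, head_dim), 2, 6) == 2 * 16 + 5
# End of illustration.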
@@ -468,12 +457,13 @@ class ModelRunner:
    is_prefill=False,  # Decode mode
    slot_mapping=slot_mapping,
    context_lens=context_len,
    is_chunked_prefill=True,  # Use chunked attention
    is_chunked_prefill=True,  # Use chunked attention path
    offload_engine=self.kvcache_manager,
    chunked_seq=seq,
)

# Run model forward pass
# Each attention layer will handle chunked KV loading internally
logits = self.run_model(input_ids, positions, is_prefill=False)
reset_context()
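# Illustration (assumed helper, not this repo's definition): set_context / reset_context
# as used above suggest a process-global forward context - the runner stashes per-step
# metadata before run_model, and each attention layer reads it instead of having these
# flags threaded through every forward() signature. Field names mirror the call above;
# the Context class itself is only a sketch.
from dataclasses import dataclass
from typing import Any

@dataclass
class Context:
    is_prefill: bool = False
    slot_mapping: Any = None
    context_lens: Any = None
    is_chunked_prefill: bool = False
    offload_engine: Any = None
    chunked_seq: Any = None

_CONTEXT = Context()

def set_context(**kwargs) -> None:
    global _CONTEXT
    _CONTEXT = Context(**kwargs)

def get_context() -> Context:
    return _CONTEXT

def reset_context() -> None:
    global _CONTEXT
    _CONTEXT = Context()

# Inside an attention layer, the chunked path would then be selected with:
#   ctx = get_context()
#   if ctx.is_chunked_prefill:
#       # stream CPU-resident KV blocks via ctx.offload_engine for ctx.chunked_seq
#       # and merge per-chunk attention as in the sketch after the docstring
# End of illustration.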