[feat] Need to optimize with async prefetch.

Zijie Tian
2025-12-15 06:58:40 +08:00
parent 1081ab51ea
commit b8b6478506
9 changed files with 556 additions and 404 deletions


@@ -630,29 +630,31 @@ class ModelRunner:
def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with three-region GPU buffer (CPU is primary storage).
Run prefill with unified ring buffer (CPU is primary storage).
Flow:
1. All blocks are allocated to CPU (primary storage)
2. Process tokens in chunks using Compute region GPU buffer
3. After each chunk, offload from Compute region to CPU
4. Prefetch region is used to load previous KV (if any)
2. Each chunk writes KV to ring buffer slot[chunk_idx % N]
3. After each chunk, offload from ring buffer slot to CPU
4. All N-1 other slots are used to load previous chunks for attention
"""
import sys
assert len(seqs) == 1, "Three-region prefill only supports single sequence"
assert len(seqs) == 1, "Ring buffer prefill only supports single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
compute_size = offload_engine.num_compute_blocks
tokens_per_chunk = compute_size * self.block_size
# Each chunk uses 1 ring buffer slot = 1 block
tokens_per_chunk = self.block_size
total_tokens = len(seq)
print(f"[Three-region Prefill] Starting: {total_tokens} tokens, "
f"compute_size={compute_size} blocks, chunk={tokens_per_chunk} tokens",
num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk
print(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
f"total_chunks={num_chunks}",
file=sys.stderr)
chunk_num = 0
chunk_idx = 0
logits = None
processed_tokens = 0
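
The chunking arithmetic in the new path is worth spelling out: one chunk now covers exactly one block, each chunk writes into ring slot chunk_idx % N (per the docstring above), and the chunk count is a ceiling division of the prompt length. A minimal sketch with illustrative numbers (num_ring_slots and the prompt length are hypothetical values, not read from the repository):

# Illustrative chunk/slot arithmetic for the ring-buffer prefill
block_size = 4096                    # tokens per block (see the Sequence change below)
tokens_per_chunk = block_size        # one chunk == one block == one ring slot
total_tokens = 10_000                # example prompt length
num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk   # -> 3
num_ring_slots = 4                   # hypothetical; the code reads offload_engine.num_ring_slots
write_slots = [i % num_ring_slots for i in range(num_chunks)]            # -> [0, 1, 2]
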
@@ -660,27 +662,22 @@ class ModelRunner:
cpu_block_ids, logical_ids = self.kvcache_manager.get_all_cpu_blocks(seq)
while processed_tokens < total_tokens:
chunk_num += 1
chunk_start = processed_tokens
chunk_end = min(processed_tokens + tokens_per_chunk, total_tokens)
chunk_tokens = chunk_end - chunk_start
# Calculate which CPU blocks this chunk covers
start_block_idx = chunk_start // self.block_size
end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
num_blocks = end_block_idx - start_block_idx
# Get ring buffer slot for this chunk
write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx)
print(f"[Three-region Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
f"blocks {start_block_idx}-{end_block_idx-1}, "
f"compute_slots={offload_engine.compute_slots[:num_blocks]}",
# CPU block index for this chunk
block_idx = chunk_idx
print(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
f"write_slot={write_slot}",
file=sys.stderr)
# Get GPU slots for this chunk (using Compute region)
gpu_slots = offload_engine.compute_slots[:num_blocks]
# Prepare inputs
input_ids, positions = self._prepare_chunked_offload_chunk(
seq, chunk_start, chunk_end, gpu_slots, start_block_idx
seq, chunk_start, chunk_end, write_slot, block_idx, chunk_idx
)
if input_ids.numel() == 0:
@@ -690,24 +687,27 @@ class ModelRunner:
logits = self.run_model(input_ids, positions, is_prefill=True)
reset_context()
# Mark blocks as prefilled
for i in range(start_block_idx, min(end_block_idx, len(seq.block_table))):
logical_id = seq.block_table[i]
# Mark block as prefilled
if block_idx < len(seq.block_table):
logical_id = seq.block_table[block_idx]
self.kvcache_manager.prefilled_blocks.add(logical_id)
# Offload this chunk from Compute region to CPU (async)
chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
offload_engine.offload_compute_to_cpu(chunk_cpu_blocks)
# Offload this chunk's ring buffer slot to CPU (async)
if block_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[block_idx]
offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)
# Wait for offload to complete before next chunk
offload_engine.wait_all_offload_done()
# (slot will be reused after N chunks)
offload_engine.wait_slot_offload(write_slot)
processed_tokens = chunk_end
chunk_idx += 1
# Wait for all offloads to complete
offload_engine.wait_all_offload_done()
print(f"[Three-region Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)
# Sample from last logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
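
The per-slot offload/wait pair above is what makes slot reuse safe: a slot may be rewritten only after its copy to the CPU block has completed. A rough sketch of how such an engine could be structured, assuming a dedicated D2H CUDA stream, one event per ring slot, and pinned CPU buffers; this is illustrative, not the repository's actual offload_engine implementation:

import torch

class SlotOffloader:
    def __init__(self, gpu_kv: torch.Tensor, cpu_kv: torch.Tensor, num_slots: int):
        # gpu_kv: [num_slots, block_size, ...] on the GPU
        # cpu_kv: [num_cpu_blocks, block_size, ...] in pinned host memory
        self.gpu_kv, self.cpu_kv = gpu_kv, cpu_kv
        self.stream = torch.cuda.Stream()                        # side stream for D2H copies
        self.events = [torch.cuda.Event() for _ in range(num_slots)]

    def offload_slot_to_cpu(self, slot: int, cpu_block_id: int) -> None:
        # Wait for the compute stream so the slot's KV writes are visible, then copy async.
        self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.cpu_kv[cpu_block_id].copy_(self.gpu_kv[slot], non_blocking=True)
            self.events[slot].record(self.stream)

    def wait_slot_offload(self, slot: int) -> None:
        # Block the host until this slot's copy is done, so the slot can be rewritten.
        self.events[slot].synchronize()
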
@@ -723,34 +723,24 @@ class ModelRunner:
seq: Sequence,
chunk_start: int,
chunk_end: int,
gpu_slots: list[int],
start_block_idx: int,
write_slot: int,
block_idx: int,
chunk_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare inputs for a chunked offload prefill chunk."""
"""Prepare inputs for a chunked offload prefill chunk (ring buffer design)."""
# Input tokens for this chunk
input_ids = seq[chunk_start:chunk_end]
positions = list(range(chunk_start, chunk_end))
# Create slot mapping pointing to GPU slots
# Create slot mapping pointing to the single write_slot
slot_mapping = []
num_tokens = chunk_end - chunk_start
token_idx = 0
for i, gpu_slot in enumerate(gpu_slots):
block_idx = start_block_idx + i
block_start = block_idx * self.block_size
block_end = min(block_start + self.block_size, len(seq))
# How many tokens in this block for this chunk
overlap_start = max(chunk_start, block_start)
overlap_end = min(chunk_end, block_end)
for pos in range(overlap_start, overlap_end):
pos_in_block = pos % self.block_size
slot = gpu_slot * self.block_size + pos_in_block
slot_mapping.append(slot)
for pos in range(chunk_start, chunk_end):
pos_in_block = pos % self.block_size
slot = write_slot * self.block_size + pos_in_block
slot_mapping.append(slot)
# Convert to tensors
num_tokens = chunk_end - chunk_start
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
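
To make the new slot_mapping concrete, here is a worked example of the values it produces, assuming block_size = 4096 and at least three ring slots so that chunk 2 writes into slot 2:

# Worked example of the ring-buffer slot mapping (illustrative values)
block_size = 4096
write_slot = 2                                # chunk_idx == 2 -> slot 2, assuming > 2 ring slots
chunk_start, chunk_end = 8192, 8196           # first 4 tokens of chunk 2, for brevity
slot_mapping = [write_slot * block_size + (pos % block_size)
                for pos in range(chunk_start, chunk_end)]
# -> [8192, 8193, 8194, 8195]: every token of the chunk lands inside ring slot 2
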
@@ -770,21 +760,23 @@ class ModelRunner:
is_chunked_prefill=True,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
current_chunk_idx=chunk_idx, # Pass chunk index for ring buffer pipeline
)
return input_ids, positions
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with three-region GPU buffer.
Run decode with ring buffer (CPU is primary storage).
All KV is on CPU. Uses Decode region to write new KV, Compute/Prefetch region to load KV chunks.
New token's KV is written to Decode region (slot 0) then offloaded to CPU only when block is full.
All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
Other slots (slots[1:]) are used to load previous KV chunks via pipeline.
New token's KV is written to decode_slot then offloaded to CPU only when block is full.
Key: Decode region is never overwritten by Compute/Prefetch, dedicated to writing new KV.
Key: decode_slot is dedicated to writing new KV, never used for loading.
Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens.
"""
assert len(seqs) == 1, "Three-region decode only supports single sequence"
assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
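
The decode path's bookkeeping follows the same convention: the new token's KV is written into decode_slot (slot 0), and that block is offloaded to its CPU block only once it fills up. A small sketch, extrapolating the prefill slot-mapping formula to decode; names and numbers are illustrative:

# Sketch of per-token decode bookkeeping (illustrative)
block_size = 4096
decode_slot = 0                                  # slots[0] is reserved for writing new KV
seq_len_before = 8193                            # tokens already in the sequence (example)
pos_in_block = seq_len_before % block_size       # where the new token's KV is written
slot_mapping = [decode_slot * block_size + pos_in_block]
block_full = (seq_len_before + 1) % block_size == 0
# Only when block_full would decode_slot be offloaded to CPU; otherwise new tokens
# keep accumulating in the slot, batching the offload as described above.
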


@@ -12,7 +12,7 @@ class SequenceStatus(Enum):
class Sequence:
block_size = 256
block_size = 4096
counter = count()
def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
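
Finally, the Sequence.block_size bump from 256 to 4096 means each block (and therefore each ring slot and each prefill chunk) covers 16x more tokens, which directly cuts the number of chunks and offloads per prompt. For example, with an assumed 131,072-token prompt:

# Effect of the block_size change on chunk count (prompt length is an example)
prompt_tokens = 131_072
chunks_at_256  = (prompt_tokens + 256  - 1) // 256     # -> 512 chunks
chunks_at_4096 = (prompt_tokens + 4096 - 1) // 4096    # -> 32 chunks
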