[feat] Needs to be optimized with async prefetch.
This commit is contained in:
@@ -630,29 +630,31 @@ class ModelRunner:
|
||||
|
||||
def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
|
||||
"""
|
||||
Run prefill with three-region GPU buffer (CPU is primary storage).
|
||||
Run prefill with unified ring buffer (CPU is primary storage).
|
||||
|
||||
Flow:
|
||||
1. All blocks are allocated to CPU (primary storage)
|
||||
2. Process tokens in chunks using Compute region GPU buffer
|
||||
3. After each chunk, offload from Compute region to CPU
|
||||
4. Prefetch region is used to load previous KV (if any)
|
||||
2. Each chunk writes KV to ring buffer slot[chunk_idx % N]
|
||||
3. After each chunk, offload from ring buffer slot to CPU
|
||||
4. All N-1 other slots are used to load previous chunks for attention
|
||||
"""
|
||||
import sys
|
||||
|
||||
assert len(seqs) == 1, "Three-region prefill only supports single sequence"
|
||||
assert len(seqs) == 1, "Ring buffer prefill only supports single sequence"
|
||||
seq = seqs[0]
|
||||
|
||||
offload_engine = self.kvcache_manager.offload_engine
|
||||
compute_size = offload_engine.num_compute_blocks
|
||||
tokens_per_chunk = compute_size * self.block_size
|
||||
# Each chunk uses 1 ring buffer slot = 1 block
|
||||
tokens_per_chunk = self.block_size
|
||||
|
||||
total_tokens = len(seq)
|
||||
print(f"[Three-region Prefill] Starting: {total_tokens} tokens, "
|
||||
f"compute_size={compute_size} blocks, chunk={tokens_per_chunk} tokens",
|
||||
num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk
|
||||
print(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
|
||||
f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
|
||||
f"total_chunks={num_chunks}",
|
||||
file=sys.stderr)
|
||||
|
||||
chunk_num = 0
|
||||
chunk_idx = 0
|
||||
logits = None
|
||||
processed_tokens = 0
|
||||
|
||||
@@ -660,27 +662,22 @@ class ModelRunner:
|
||||
cpu_block_ids, logical_ids = self.kvcache_manager.get_all_cpu_blocks(seq)
|
||||
|
||||
while processed_tokens < total_tokens:
|
||||
chunk_num += 1
|
||||
chunk_start = processed_tokens
|
||||
chunk_end = min(processed_tokens + tokens_per_chunk, total_tokens)
|
||||
chunk_tokens = chunk_end - chunk_start
|
||||
|
||||
# Calculate which CPU blocks this chunk covers
|
||||
start_block_idx = chunk_start // self.block_size
|
||||
end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
|
||||
num_blocks = end_block_idx - start_block_idx
|
||||
# Get ring buffer slot for this chunk
|
||||
write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx)
|
||||
|
||||
print(f"[Three-region Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
|
||||
f"blocks {start_block_idx}-{end_block_idx-1}, "
|
||||
f"compute_slots={offload_engine.compute_slots[:num_blocks]}",
|
||||
# CPU block index for this chunk
|
||||
block_idx = chunk_idx
|
||||
|
||||
print(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
|
||||
f"write_slot={write_slot}",
|
||||
file=sys.stderr)
|
||||
|
||||
# Get GPU slots for this chunk (using Compute region)
|
||||
gpu_slots = offload_engine.compute_slots[:num_blocks]
|
||||
|
||||
# Prepare inputs
|
||||
input_ids, positions = self._prepare_chunked_offload_chunk(
|
||||
seq, chunk_start, chunk_end, gpu_slots, start_block_idx
|
||||
seq, chunk_start, chunk_end, write_slot, block_idx, chunk_idx
|
||||
)
|
||||
|
||||
if input_ids.numel() == 0:
|
||||
@@ -690,24 +687,27 @@ class ModelRunner:
|
||||
logits = self.run_model(input_ids, positions, is_prefill=True)
|
||||
reset_context()
|
||||
|
||||
# Mark blocks as prefilled
|
||||
for i in range(start_block_idx, min(end_block_idx, len(seq.block_table))):
|
||||
logical_id = seq.block_table[i]
|
||||
# Mark block as prefilled
|
||||
if block_idx < len(seq.block_table):
|
||||
logical_id = seq.block_table[block_idx]
|
||||
self.kvcache_manager.prefilled_blocks.add(logical_id)
|
||||
|
||||
# Offload this chunk from Compute region to CPU (async)
|
||||
chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
|
||||
offload_engine.offload_compute_to_cpu(chunk_cpu_blocks)
|
||||
# Offload this chunk's ring buffer slot to CPU (async)
|
||||
if block_idx < len(cpu_block_ids):
|
||||
cpu_block_id = cpu_block_ids[block_idx]
|
||||
offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)
|
||||
|
||||
# Wait for offload to complete before next chunk
|
||||
offload_engine.wait_all_offload_done()
|
||||
# (slot will be reused after N chunks)
|
||||
offload_engine.wait_slot_offload(write_slot)
|
||||
|
||||
processed_tokens = chunk_end
|
||||
chunk_idx += 1
|
||||
|
||||
# Wait for all offloads to complete
|
||||
offload_engine.wait_all_offload_done()
|
||||
|
||||
print(f"[Three-region Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
|
||||
print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)
|
||||
|
||||
# Sample from last logits
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
@@ -723,34 +723,24 @@ class ModelRunner:
|
||||
seq: Sequence,
|
||||
chunk_start: int,
|
||||
chunk_end: int,
|
||||
gpu_slots: list[int],
|
||||
start_block_idx: int,
|
||||
write_slot: int,
|
||||
block_idx: int,
|
||||
chunk_idx: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Prepare inputs for a chunked offload prefill chunk."""
|
||||
"""Prepare inputs for a chunked offload prefill chunk (ring buffer design)."""
|
||||
# Input tokens for this chunk
|
||||
input_ids = seq[chunk_start:chunk_end]
|
||||
positions = list(range(chunk_start, chunk_end))
|
||||
|
||||
# Create slot mapping pointing to GPU slots
|
||||
# Create slot mapping pointing to the single write_slot
|
||||
slot_mapping = []
|
||||
num_tokens = chunk_end - chunk_start
|
||||
|
||||
token_idx = 0
|
||||
for i, gpu_slot in enumerate(gpu_slots):
|
||||
block_idx = start_block_idx + i
|
||||
block_start = block_idx * self.block_size
|
||||
block_end = min(block_start + self.block_size, len(seq))
|
||||
|
||||
# How many tokens in this block for this chunk
|
||||
overlap_start = max(chunk_start, block_start)
|
||||
overlap_end = min(chunk_end, block_end)
|
||||
|
||||
for pos in range(overlap_start, overlap_end):
|
||||
pos_in_block = pos % self.block_size
|
||||
slot = gpu_slot * self.block_size + pos_in_block
|
||||
slot_mapping.append(slot)
|
||||
for pos in range(chunk_start, chunk_end):
|
||||
pos_in_block = pos % self.block_size
|
||||
slot = write_slot * self.block_size + pos_in_block
|
||||
slot_mapping.append(slot)
|
||||
|
||||
# Convert to tensors
|
||||
num_tokens = chunk_end - chunk_start
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
@@ -770,21 +760,23 @@ class ModelRunner:
|
||||
is_chunked_prefill=True,
|
||||
kvcache_manager=self.kvcache_manager,
|
||||
chunked_seq=seq,
|
||||
current_chunk_idx=chunk_idx, # Pass chunk index for ring buffer pipeline
|
||||
)
|
||||
|
||||
return input_ids, positions
|
||||
|
||||
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
|
||||
"""
|
||||
Run decode with three-region GPU buffer.
|
||||
Run decode with ring buffer (CPU is primary storage).
|
||||
|
||||
All KV is on CPU. Uses Decode region to write new KV, Compute/Prefetch region to load KV chunks.
|
||||
New token's KV is written to Decode region (slot 0) then offloaded to CPU only when block is full.
|
||||
All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
|
||||
Other slots (slots[1:]) are used to load previous KV chunks via pipeline.
|
||||
New token's KV is written to decode_slot then offloaded to CPU only when block is full.
|
||||
|
||||
Key: Decode region is never overwritten by Compute/Prefetch, dedicated to writing new KV.
|
||||
Key: decode_slot is dedicated to writing new KV, never used for loading.
|
||||
Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens.
|
||||
"""
|
||||
assert len(seqs) == 1, "Three-region decode only supports single sequence"
|
||||
assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
|
||||
seq = seqs[0]
|
||||
|
||||
offload_engine = self.kvcache_manager.offload_engine
|
||||
|
||||
Reference in New Issue
Block a user