[refactor] Remove legacy mode path.

This commit is contained in:
Zijie Tian
2025-12-22 20:17:56 +08:00
parent 08d83185ce
commit 1907b625b6
4 changed files with 49 additions and 958 deletions


@@ -388,26 +388,6 @@ class ModelRunner:
else:
return self.run_chunked_offload_decode(seqs)
# Check if chunked prefill is needed (legacy path)
if is_prefill and hasattr(self, 'kvcache_manager'):
needs_chunked = any(
hasattr(self.kvcache_manager, 'needs_chunked_prefill') and
self.kvcache_manager.needs_chunked_prefill(seq)
for seq in seqs if seq.block_table
)
if needs_chunked:
return self.run_chunked_prefill(seqs)
# Check if chunked decode is needed (legacy path)
if not is_prefill and hasattr(self, 'kvcache_manager'):
needs_chunked = any(
hasattr(self.kvcache_manager, 'needs_chunked_decode') and
self.kvcache_manager.needs_chunked_decode(seq)
for seq in seqs if seq.block_table
)
if needs_chunked:
return self.run_chunked_decode(seqs)
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
logits = self.run_model(input_ids, positions, is_prefill)
@@ -445,194 +425,6 @@ class ModelRunner:
return False
def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill in chunks when sequences exceed GPU capacity.
For each chunk:
1. Process tokens through model forward pass
2. At each attention layer:
- Load previous KV from CPU (handled by attention layer)
- Compute attention with online softmax merging (see the sketch after this function)
- Store current KV to GPU cache
3. After chunk completes, offload KV to CPU
4. Load next chunk's blocks to GPU
"""
import sys
# Currently only supporting single sequence for chunked prefill
assert len(seqs) == 1, "Chunked prefill only supports single sequence"
seq = seqs[0]
total_blocks = seq.num_blocks
print(f"[Chunked Prefill] Starting: {total_blocks} total blocks, "
f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)
chunk_num = 0
logits = None
while True:
# Get chunk info (which blocks are on GPU and not yet prefilled)
chunk_info = self.kvcache_manager.get_gpu_block_tables_partial(seqs)
gpu_blocks, start_block_idx, end_block_idx = chunk_info[0]
if not gpu_blocks:
# No more blocks to process
break
chunk_num += 1
chunk_tokens = (end_block_idx - start_block_idx) * self.block_size
if end_block_idx == seq.num_blocks:
# Last block may be partial
chunk_tokens = len(seq) - start_block_idx * self.block_size
print(f"[Chunked Prefill] Chunk {chunk_num}: blocks {start_block_idx}-{end_block_idx-1}, "
f"~{chunk_tokens} tokens", file=sys.stderr)
# Prepare inputs for this chunk
input_ids, positions = self._prepare_chunked_prefill(seq, gpu_blocks, start_block_idx, end_block_idx)
if input_ids.numel() == 0:
print(f"[Chunked Prefill] No input tokens, breaking", file=sys.stderr)
break
print(f"[Chunked Prefill] Running model with {input_ids.numel()} tokens...", file=sys.stderr)
# Run model forward pass
logits = self.run_model(input_ids, positions, is_prefill=True)
reset_context()
print(f"[Chunked Prefill] Model forward complete", file=sys.stderr)
# Mark current chunk as prefilled and offload to CPU
self.kvcache_manager.complete_prefill_chunk(seq)
# Check if more chunks needed
if not self.kvcache_manager.needs_chunked_prefill(seq):
print(f"[Chunked Prefill] All chunks done, sampling", file=sys.stderr)
break
print(f"[Chunked Prefill] Chunk transfer complete, loading next...", file=sys.stderr)
# Sample from the last chunk's logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
if logits is not None:
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
else:
token_ids = [0] if self.rank == 0 else None
return token_ids
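# A minimal sketch of the "online softmax merging" step referenced in the
# run_chunked_prefill docstring above. Names are illustrative and not part of
# this codebase: each partial attention result carries its per-query
# log-sum-exp, so two partials over disjoint key ranges can be combined
# exactly without revisiting either key range.
import torch

def merge_partial_attention(o1, lse1, o2, lse2):
    # o*: [num_tokens, num_heads, head_dim]; lse*: [num_tokens, num_heads]
    lse = torch.logaddexp(lse1, lse2)           # combined log-sum-exp normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)    # rescale the first partial
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)    # rescale the second partial
    return o1 * w1 + o2 * w2, lse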
def run_chunked_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with chunked attention when the sequence exceeds GPU capacity.
For decode, we need attention over ALL previous tokens. With CPU offload,
we load KV chunks and compute attention incrementally per-layer.
Flow:
1. Ensure last block is on GPU (for writing new KV token)
2. Run model forward - each attention layer:
a. Compute attention on GPU blocks
b. Load CPU blocks in chunks, compute + merge (see the sketch after this function)
3. Sample from output
"""
# Currently only supporting single sequence for chunked decode
assert len(seqs) == 1, "Chunked decode only supports single sequence"
seq = seqs[0]
# Prepare inputs
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
# Ensure last block is on GPU for writing new KV token
last_gpu_slot = self.kvcache_manager.ensure_last_block_on_gpu(seq)
slot = last_gpu_slot * self.block_size + seq.last_block_num_tokens - 1
slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Set up context for chunked decode
set_context(
is_prefill=False, # Decode mode
slot_mapping=slot_mapping,
context_lens=context_len,
is_chunked_prefill=True, # Use chunked attention path
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
)
# Run model forward pass
# Each attention layer will handle chunked KV loading internally
logits = self.run_model(input_ids, positions, is_prefill=False)
reset_context()
# Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
return token_ids
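# A minimal sketch of the per-layer flow described in the run_chunked_decode
# docstring above: attend over GPU-resident KV first, then stream CPU-resident
# KV in chunks and fold each partial result in with the same log-sum-exp merge.
# The attend() callable and the chunk iterator are assumptions, not this
# repo's API.
def chunked_decode_attention(q, gpu_k, gpu_v, cpu_kv_chunks, attend):
    # attend(q, k, v) is assumed to return (output, log_sum_exp) for the query.
    out, lse = attend(q, gpu_k, gpu_v)               # GPU-resident blocks first
    for k_chunk, v_chunk in cpu_kv_chunks:           # e.g. blocks copied H2D on demand
        o_c, lse_c = attend(q, k_chunk, v_chunk)     # partial attention over this chunk
        out, lse = merge_partial_attention(out, lse, o_c, lse_c)
    return out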
def _prepare_chunked_prefill(
self,
seq: Sequence,
gpu_blocks: list[int],
start_block_idx: int,
end_block_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Prepare inputs for a single chunk in chunked prefill.
Sets up context with is_chunked_prefill=True so attention layers
know to load previous KV from CPU.
"""
# Calculate token range for this chunk
start_token = start_block_idx * self.block_size
end_token = min(end_block_idx * self.block_size, len(seq))
# Input tokens for this chunk
input_ids = seq[start_token:end_token]
positions = list(range(start_token, end_token))
# Slot mapping for storing KV cache (worked example after this function)
slot_mapping = []
for i, gpu_block_id in enumerate(gpu_blocks):
block_idx = start_block_idx + i
start = gpu_block_id * self.block_size
if block_idx != seq.num_blocks - 1:
end = start + self.block_size
else:
end = start + seq.last_block_num_tokens
slot_mapping.extend(list(range(start, end)))
# Trim slot_mapping to match actual token count
actual_tokens = end_token - start_token
slot_mapping = slot_mapping[:actual_tokens]
# Convert to tensors
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Set up context for chunked prefill
seqlen = actual_tokens
cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=seqlen,
max_seqlen_k=seqlen,
slot_mapping=slot_mapping,
is_chunked_prefill=True,
kvcache_manager=self.kvcache_manager, # Pass manager for loading previous KV
chunked_seq=seq, # Pass sequence for loading previous KV
)
return input_ids, positions
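# Worked example (made-up numbers) of the slot_mapping construction in
# _prepare_chunked_prefill above: block_size 4, an 18-token sequence
# (5 logical blocks, the last holding 18 - 4*4 = 2 tokens), and a chunk
# covering logical blocks 3-4, which are resident in GPU blocks [7, 2].
block_size = 4
seq_len = 18
gpu_blocks = [7, 2]               # GPU block ids holding logical blocks 3 and 4
start_block_idx = 3
last_block_idx = 4
slot_mapping = []
for i, gpu_block_id in enumerate(gpu_blocks):
    start = gpu_block_id * block_size
    if start_block_idx + i != last_block_idx:
        end = start + block_size                                # full block
    else:
        end = start + (seq_len - last_block_idx * block_size)   # partial last block
    slot_mapping.extend(range(start, end))
assert slot_mapping == [28, 29, 30, 31, 8, 9]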
def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with unified ring buffer (CPU is primary storage).