[claudesquad] update from 'perf_opt-1' on 07 Jan 26 05:58 CST
This commit is contained in:
@@ -590,14 +590,15 @@ class ModelRunner:
|
||||
|
||||
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
|
||||
"""
|
||||
Run decode with ring buffer (CPU is primary storage).
|
||||
Run decode with cross-layer pipeline (CPU is primary storage).
|
||||
|
||||
All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
|
||||
Other slots (slots[1:]) are used to load previous KV chunks via pipeline.
|
||||
New token's KV is written to decode_slot, then offloaded to CPU only when the block is full.
|
||||
Optimized with cross-layer pipeline: Layer N's data is loaded while
|
||||
Layer N-1 computes, achieving transfer/compute overlap.
|
||||
|
||||
Key: decode_slot is dedicated to writing new KV, never used for loading.
|
||||
Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens.
|
||||
Optimization: Cross-layer pipeline reduces effective latency by overlapping
|
||||
H2D transfers with attention computation across layers.
|
||||
"""
|
||||
assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
|
||||
seq = seqs[0]
|
||||
@@ -618,6 +619,12 @@ class ModelRunner:
|
||||
# Get decode start position for accumulated token tracking
|
||||
decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
|
||||
|
||||
# Get prefilled CPU blocks for pipeline initialization
|
||||
cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
|
||||
# Start cross-layer pipeline (preloads Layer 0's data)
|
||||
offload_engine.start_decode_pipeline(cpu_block_table)
|
||||
|
||||
# Set up context for chunked decode
|
||||
set_context(
|
||||
is_prefill=False,
|
||||
@@ -634,6 +641,9 @@ class ModelRunner:
|
||||
logits = self.run_model(input_ids, positions, is_prefill=False)
|
||||
reset_context()
|
||||
|
||||
# End cross-layer pipeline
|
||||
offload_engine.end_decode_pipeline()
|
||||
|
||||
# Only offload when block is full (pos_in_block == block_size - 1)
|
||||
# This avoids unnecessary offloading on every decode step
|
||||
if pos_in_block == self.block_size - 1:
|
||||
|
||||
Reference in New Issue
Block a user