[WIP] Before fixing bench_offload.py.

This commit is contained in:
Zijie Tian
2026-01-06 18:41:08 +08:00
parent c7ac39dfbd
commit 535f2037ab
7 changed files with 66 additions and 44 deletions

View File

@@ -37,7 +37,7 @@ class ModelRunner:
self.sampler = GreedySampler()
#> Disable warmup for debugging
# self.warmup_model()
self.warmup_model()
self.allocate_kv_cache()
if not self.enforce_eager:
@@ -62,7 +62,7 @@ class ModelRunner:
self.shm.unlink()
if not self.enforce_eager:
del self.graphs, self.graph_pool
torch.cuda.synchronize()
# torch.cuda.synchronize()
dist.destroy_process_group()
def loop(self):

View File

@@ -35,7 +35,29 @@ class Scheduler:
if Observer.ttft_start == 0:
Observer.ttft_start = perf_counter_ns()
seq = self.waiting[0]
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.kvcache_manager.can_allocate(seq):
# Check if sequence is too large
if not self.running and num_seqs == 0:
# First sequence, give clear error if it can't be scheduled
if len(seq) > self.max_num_batched_tokens:
raise RuntimeError(
f"Sequence too long: {len(seq)} tokens exceeds "
f"max_num_batched_tokens={self.max_num_batched_tokens}. "
f"Increase max_num_batched_tokens (set equal to max_model_len for long sequences)."
)
if not self.kvcache_manager.can_allocate(seq):
blocks_needed = seq.num_blocks
blocks_available = self.kvcache_manager.num_free_blocks
raise RuntimeError(
f"Cannot allocate KV cache for sequence: "
f"need {blocks_needed} blocks ({len(seq)} tokens), "
f"but only {blocks_available} blocks available. "
f"Increase max_model_len to allocate more blocks."
)
if num_batched_tokens + len(seq) > self.max_num_batched_tokens:
break
if not self.kvcache_manager.can_allocate(seq):
break
num_seqs += 1
self.kvcache_manager.allocate(seq)
@@ -60,7 +82,7 @@ class Scheduler:
num_seqs += 1
self.kvcache_manager.may_append(seq)
scheduled_seqs.append(seq)
assert scheduled_seqs
assert scheduled_seqs, "No sequences scheduled - this should not happen"
self.running.extendleft(reversed(scheduled_seqs))
return scheduled_seqs, False

View File

@@ -201,7 +201,7 @@ class OffloadEngine:
# This prevents undefined behavior on first load_to_slot_layer call
for slot_idx in range(self.num_ring_slots):
self.ring_slot_compute_done[slot_idx].record()
torch.cuda.synchronize() # Ensure all events are recorded
# torch.cuda.synchronize() # Ensure all events are recorded
# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}

View File

@@ -32,8 +32,12 @@ def store_kvcache(
"""
# Filter out invalid slots (slot == -1)
valid_mask = slot_mapping >= 0
if not valid_mask.any():
return
is_capturing = torch.cuda.is_current_stream_capturing()
if not is_capturing:
if not valid_mask.any():
return
valid_slots = slot_mapping[valid_mask]
valid_keys = key[valid_mask] # [M, num_kv_heads, head_dim]
@@ -51,6 +55,7 @@ def store_kvcache(
valid_values_flat = valid_values.reshape(-1, D)
# In-place scatter using index_copy_
# Even if valid_slots is an empty tensor, index_copy_ is safe (it does not modify any data).
k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)
@@ -86,7 +91,7 @@ class Attention(nn.Module):
)
#! Ensure synchronization before accessing k_cache/v_cache
torch.cuda.synchronize()
# torch.cuda.synchronize()
#! =======================================================
if is_chunked_offload: