[WIP] Before fixing bench_offload.py.
@@ -37,7 +37,7 @@ class ModelRunner:
         self.sampler = GreedySampler()
 
         #> Disable warmup for debugging
-        # self.warmup_model()
+        self.warmup_model()
 
         self.allocate_kv_cache()
         if not self.enforce_eager:
@@ -62,7 +62,7 @@ class ModelRunner:
         self.shm.unlink()
         if not self.enforce_eager:
             del self.graphs, self.graph_pool
-        torch.cuda.synchronize()
+        # torch.cuda.synchronize()
         dist.destroy_process_group()
 
     def loop(self):
@@ -35,7 +35,29 @@ class Scheduler:
             if Observer.ttft_start == 0:
                 Observer.ttft_start = perf_counter_ns()
             seq = self.waiting[0]
-            if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.kvcache_manager.can_allocate(seq):
-                break
+
+            # Check if sequence is too large
+            if not self.running and num_seqs == 0:
+                # First sequence, give clear error if it can't be scheduled
+                if len(seq) > self.max_num_batched_tokens:
+                    raise RuntimeError(
+                        f"Sequence too long: {len(seq)} tokens exceeds "
+                        f"max_num_batched_tokens={self.max_num_batched_tokens}. "
+                        f"Increase max_num_batched_tokens (set equal to max_model_len for long sequences)."
+                    )
+                if not self.kvcache_manager.can_allocate(seq):
+                    blocks_needed = seq.num_blocks
+                    blocks_available = self.kvcache_manager.num_free_blocks
+                    raise RuntimeError(
+                        f"Cannot allocate KV cache for sequence: "
+                        f"need {blocks_needed} blocks ({len(seq)} tokens), "
+                        f"but only {blocks_available} blocks available. "
+                        f"Increase max_model_len to allocate more blocks."
+                    )
+
+            if num_batched_tokens + len(seq) > self.max_num_batched_tokens:
+                break
+            if not self.kvcache_manager.can_allocate(seq):
+                break
             num_seqs += 1
             self.kvcache_manager.allocate(seq)
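For reference, the admission check introduced here can be read in isolation. A minimal standalone sketch, assuming illustrative names (check_admission, block_size) rather than the scheduler's real Sequence/KVCacheManager API:

    # Sketch of the prefill admission check above; all names are illustrative.
    def check_admission(seq_len: int, max_num_batched_tokens: int,
                        free_blocks: int, block_size: int = 256) -> None:
        """Fail fast with a descriptive error if a first sequence can never fit."""
        if seq_len > max_num_batched_tokens:
            raise RuntimeError(
                f"Sequence too long: {seq_len} tokens exceeds "
                f"max_num_batched_tokens={max_num_batched_tokens}."
            )
        blocks_needed = -(-seq_len // block_size)  # ceiling division
        if blocks_needed > free_blocks:
            raise RuntimeError(
                f"Cannot allocate KV cache: need {blocks_needed} blocks, "
                f"only {free_blocks} available."
            )

    # A 40960-token prompt against a 36864-token batch budget fails immediately.
    try:
        check_admission(seq_len=40 * 1024, max_num_batched_tokens=36 * 1024, free_blocks=512)
    except RuntimeError as err:
        print(err)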
@@ -60,7 +82,7 @@ class Scheduler:
             num_seqs += 1
             self.kvcache_manager.may_append(seq)
             scheduled_seqs.append(seq)
-        assert scheduled_seqs
+        assert scheduled_seqs, "No sequences scheduled - this should not happen"
         self.running.extendleft(reversed(scheduled_seqs))
         return scheduled_seqs, False
@@ -201,7 +201,7 @@ class OffloadEngine:
         # This prevents undefined behavior on first load_to_slot_layer call
         for slot_idx in range(self.num_ring_slots):
             self.ring_slot_compute_done[slot_idx].record()
-        torch.cuda.synchronize()  # Ensure all events are recorded
+        # torch.cuda.synchronize()  # Ensure all events are recorded
 
         # ========== Event tracking for async transfers ==========
         self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
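The hunk keeps the per-slot record() calls and only drops the trailing synchronize. A minimal sketch of the pattern, with made-up slot counts and variable names (plain torch, not the project's OffloadEngine):

    import torch

    if torch.cuda.is_available():
        num_ring_slots = 4
        ring_slot_compute_done = [torch.cuda.Event() for _ in range(num_ring_slots)]

        # Record each event once so the first wait/query sees a recorded event
        # rather than a fresh one (the source comment calls this preventing
        # undefined behavior on the first load_to_slot_layer call).
        for ev in ring_slot_compute_done:
            ev.record()

        copy_stream = torch.cuda.Stream()
        # A transfer stream can now wait on slot 0 before any compute has used it.
        copy_stream.wait_event(ring_slot_compute_done[0])
        print(ring_slot_compute_done[0].query())  # True once the record has executed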
@@ -32,6 +32,10 @@ def store_kvcache(
     """
     # Filter out invalid slots (slot == -1)
     valid_mask = slot_mapping >= 0
-    if not valid_mask.any():
-        return
+
+    is_capturing = torch.cuda.is_current_stream_capturing()
+
+    if not is_capturing:
+        if not valid_mask.any():
+            return
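The new is_capturing guard exists because a data-dependent early return is not possible while a CUDA graph is being captured: "if not valid_mask.any()" converts a GPU tensor to a Python bool, which forces a device sync that capture forbids. A small sketch of the same guard pattern, with invented inputs:

    import torch

    def store_sketch(slot_mapping: torch.Tensor) -> str:
        # Mirrors the guard above: only take the data-dependent early return
        # when no CUDA graph capture is in progress.
        valid_mask = slot_mapping >= 0
        if not torch.cuda.is_current_stream_capturing():
            if not valid_mask.any():
                return "nothing to store"
        return "stored (or deferred to graph replay)"

    print(store_sketch(torch.tensor([-1, -1, -1])))  # -> "nothing to store"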
@@ -51,6 +55,7 @@ def store_kvcache(
     valid_values_flat = valid_values.reshape(-1, D)
 
     # In-place scatter using index_copy_
+    # Even if valid_slots is an empty tensor, index_copy_ is safe (it does not modify the data).
     k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
     v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)
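The translated comment is easy to verify: index_copy_ with an empty index tensor is a no-op. A tiny self-contained check (shapes are arbitrary):

    import torch

    k_cache_flat = torch.zeros(8, 4)
    empty_slots = torch.empty(0, dtype=torch.long)
    empty_keys = torch.empty(0, 4)

    # With zero indices, index_copy_ modifies nothing and raises no error.
    k_cache_flat.index_copy_(0, empty_slots, empty_keys)
    assert torch.all(k_cache_flat == 0)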
@@ -86,7 +91,7 @@ class Attention(nn.Module):
         )
 
         #! Ensure synchronization before accessing k_cache/v_cache
-        torch.cuda.synchronize()
+        # torch.cuda.synchronize()
         #! =======================================================
 
         if is_chunked_offload:
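Dropping the blanket synchronize presumably relies on finer-grained ordering between the offload copies and attention. As a hedged illustration only (plain torch streams and events, not this repository's code), the per-copy alternative looks like:

    import torch

    if torch.cuda.is_available():
        copy_stream = torch.cuda.Stream()
        copy_done = torch.cuda.Event()

        k_cache = torch.empty(1024, 128, device="cuda")
        host_k = torch.randn(1024, 128, pin_memory=True)

        with torch.cuda.stream(copy_stream):
            k_cache.copy_(host_k, non_blocking=True)  # H2D copy on a side stream
            copy_done.record(copy_stream)

        # Instead of torch.cuda.synchronize(), only the compute stream waits
        # for the copy that produced k_cache.
        torch.cuda.current_stream().wait_event(copy_done)
        out = k_cache.sum()  # ordered after the copy on this stream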
@@ -123,7 +123,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--max-model-len",
         type=int,
-        default=36 * 1024,
+        default=128 * 1024,
         help="Maximum model context length"
     )
     parser.add_argument(
@@ -148,7 +148,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--model", "-m",
         type=str,
-        default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
+        default=os.path.expanduser("~/models/Qwen3-0.6B/"),
         help="Path to model"
     )
     parser.add_argument(
@@ -96,46 +96,41 @@ def generate_needle_prompt(
         idx = int(len(parts) * needle_position)
         return max(0, min(idx, len(parts)))
 
-    # Phase 1: Build haystack with full paragraphs until we exceed target
-    haystack_parts = []
-    para_idx = 0
-
-    while True:
-        para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
-        test_parts = haystack_parts + [para]
-        prompt = build_prompt(test_parts, get_needle_idx(test_parts))
-
-        if count_tokens(prompt) > target_length:
-            break
-
-        haystack_parts.append(para)
-        para_idx += 1
-
-        if para_idx > 10000:  # Safety limit
-            break
-
-    # Phase 2: Fine-tune by adding words from next paragraph
-    next_para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
-    words = next_para.split()
-
-    best_parts = haystack_parts.copy()
-    best_diff = abs(target_length - count_tokens(build_prompt(haystack_parts, get_needle_idx(haystack_parts))))
-
-    for i in range(1, len(words) + 1):
-        partial = " ".join(words[:i]) + " "
-        test_parts = haystack_parts + [partial]
-        prompt = build_prompt(test_parts, get_needle_idx(test_parts))
-        token_count = count_tokens(prompt)
-        diff = abs(target_length - token_count)
-
-        if diff < best_diff:
-            best_diff = diff
-            best_parts = test_parts.copy()
-
-        if token_count >= target_length:
-            break
-
-    haystack_parts = best_parts
+    # Pre-compute tokens per paragraph for efficiency (avoid O(n²) tokenization)
+    para_tokens = []
+    for para in HAYSTACK_PARAGRAPHS:
+        para_tokens.append(len(tokenizer.encode(para, add_special_tokens=False)))
+    avg_tokens_per_para = sum(para_tokens) / len(para_tokens)
+
+    # Estimate overhead (needle + question + chat template)
+    overhead_prompt = build_prompt([HAYSTACK_PARAGRAPHS[0]], 0)
+    overhead_tokens = count_tokens(overhead_prompt) - para_tokens[0]
+
+    # Phase 1: Estimate number of paragraphs needed
+    estimated_paras = int((target_length - overhead_tokens) / avg_tokens_per_para) + 1
+
+    # Build haystack with estimated paragraphs
+    haystack_parts = []
+    for i in range(estimated_paras):
+        haystack_parts.append(HAYSTACK_PARAGRAPHS[i % len(HAYSTACK_PARAGRAPHS)])
+
+    # Phase 2: Adjust to get closer to target
+    prompt = build_prompt(haystack_parts, get_needle_idx(haystack_parts))
+    current_tokens = count_tokens(prompt)
+
+    # Add more if under target
+    para_idx = len(haystack_parts)
+    while current_tokens < target_length and para_idx < 100000:
+        para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
+        haystack_parts.append(para)
+        current_tokens += para_tokens[para_idx % len(HAYSTACK_PARAGRAPHS)]
+        para_idx += 1
+
+    # Remove if over target
+    while current_tokens > target_length + 100 and len(haystack_parts) > 1:
+        removed_para_idx = (len(haystack_parts) - 1) % len(HAYSTACK_PARAGRAPHS)
+        haystack_parts.pop()
+        current_tokens -= para_tokens[removed_para_idx]
 
     # Final build
     needle_idx = get_needle_idx(haystack_parts)
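The rewritten generator replaces repeated re-tokenization with a one-shot estimate, estimated_paras ≈ (target_length - overhead_tokens) / avg_tokens_per_para, followed by small add/remove adjustment loops. A toy illustration with invented numbers:

    # Invented numbers, just to show the estimate used above.
    target_length = 32 * 1024           # desired prompt length in tokens
    overhead_tokens = 180               # needle + question + chat template
    para_tokens = [210, 195, 240, 225]  # pre-computed tokens per haystack paragraph

    avg_tokens_per_para = sum(para_tokens) / len(para_tokens)
    estimated_paras = int((target_length - overhead_tokens) / avg_tokens_per_para) + 1
    print(estimated_paras)  # 150 paragraphs before the adjustment loops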