[claudesquad] update from 'multi-request-2' on 13 Jan 26 02:01 CST

This commit is contained in:
Zijie Tian
2026-01-13 02:01:07 +08:00
parent 49519c7ce7
commit 76af506956
7 changed files with 858 additions and 398 deletions

View File

@@ -851,12 +851,33 @@ class ModelRunner:
# Step 4: Compute logits for last token
logits = self.model.compute_logits(hidden_states[-1:])
# DEBUG: Check hidden_states and logits at end of prefill
hs_last = hidden_states[-1, :4].tolist()
top5_logits, top5_indices = torch.topk(logits[0], 5)
logger.debug(
f"[DEBUG] PREFILL END: hidden_states[-1, :4]={hs_last}, "
f"top5_tokens={top5_indices.tolist()}, top5_logits={top5_logits.tolist()}"
)
# Note: Using sync offload, no wait needed
# Mark all blocks as prefilled
for logical_id in logical_ids:
self.kvcache_manager.prefilled_blocks.add(logical_id)
# DEBUG: Verify CPU cache content after prefill
first_cpu_block = cpu_block_ids[0]
last_cpu_block = cpu_block_ids[-1]
last_block_valid = total_tokens % self.block_size or self.block_size
k_first = offload_engine.k_cache_cpu[0, first_cpu_block, 0, 0, :4].tolist()
k_last = offload_engine.k_cache_cpu[0, last_cpu_block, 0, 0, :4].tolist()
logger.debug(
f"[DEBUG] AFTER PREFILL: first_cpu_block={first_cpu_block}, last_cpu_block={last_cpu_block}, "
f"last_block_valid={last_block_valid}, "
f"k_cache_cpu[0, {first_cpu_block}, 0, 0, :4]={k_first}, "
f"k_cache_cpu[0, {last_cpu_block}, 0, 0, :4]={k_last}"
)
# Step 5: Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
@@ -926,6 +947,24 @@ class ModelRunner:
# New token will be stored at this position
context_len = total_prefill_tokens + num_prev_decode_tokens
# DEBUG: Log key values for first decode step
if num_prev_decode_tokens == 0:
first_cpu_block = cpu_block_table[0] if cpu_block_table else -1
last_cpu_block = cpu_block_table[-1] if cpu_block_table else -1
k_first = offload_engine.k_cache_cpu[0, first_cpu_block, 0, 0, :4].tolist() if first_cpu_block >= 0 else []
k_last = offload_engine.k_cache_cpu[0, last_cpu_block, 0, 0, :4].tolist() if last_cpu_block >= 0 else []
logger.debug(
f"[DEBUG] FIRST DECODE STEP: len(seq)={len(seq)}, "
f"total_prefill_tokens={total_prefill_tokens}, "
f"num_prefill_blocks={num_prefill_blocks}, "
f"valid_tokens_per_block[-1]={valid_tokens_per_block[-1] if valid_tokens_per_block else 'N/A'}, "
f"pos_in_block={pos_in_block}, decode_start_pos={decode_start_pos}, "
f"context_len={context_len}, "
f"first_cpu_block={first_cpu_block}, last_cpu_block={last_cpu_block}, "
f"k_cache_cpu[0, {first_cpu_block}, 0, ...]={k_first}, "
f"k_cache_cpu[0, {last_cpu_block}, 0, ...]={k_last}"
)
# Context setup for Attention.forward() - contiguous mode (no block tables)
if use_cuda_graph:
graph_vars["slot_mapping"][0] = context_len
@@ -943,15 +982,40 @@ class ModelRunner:
i, i, cpu_block_table, valid_tokens_per_block
)
# DEBUG: Check ring buffer content after preload (first decode step only)
if num_prev_decode_tokens == 0:
# Wait for all load streams to complete
torch.cuda.synchronize()
ring_k_0 = offload_engine.layer_k_cache[0, 0, 0, :4].tolist()
# Check the actual last valid position based on valid_tokens_per_block
sum_valid = sum(valid_tokens_per_block)
ring_k_last_valid = offload_engine.layer_k_cache[0, sum_valid - 1, 0, :4].tolist()
logger.debug(
f"[DEBUG] AFTER PRELOAD L0: sum_valid={sum_valid}, "
f"ring_k[0, 0, 0, :4]={ring_k_0}, "
f"ring_k[0, {sum_valid-1}, 0, :4]={ring_k_last_valid}"
)
# Step 1: Embedding (on compute stream)
with torch.cuda.stream(compute_stream):
# DEBUG: Log input token for first decode step
if num_prev_decode_tokens == 0:
embed_weight_sample = self.model.model.embed_tokens.weight[input_ids[0], :4].tolist()
logger.debug(f"[DEBUG] EMBEDDING INPUT: input_ids={input_ids.tolist()}, positions={positions.tolist()}, weight[{input_ids[0]},:4]={embed_weight_sample}")
if use_cuda_graph:
# Copy embedding output to graph's hidden_states
embedded = self.model.model.embed_tokens(input_ids)
# DEBUG: Log embedding output for first decode step
if num_prev_decode_tokens == 0:
logger.debug(f"[DEBUG] EMBEDDING OUTPUT: embedded[0, :4]={embedded[0, :4].tolist()}")
graph_vars["hidden_states"].copy_(embedded)
graph_vars["residual"].zero_() # Reset residual for first layer
else:
hidden_states = self.model.model.embed_tokens(input_ids)
# DEBUG: Log embedding output for first decode step
if num_prev_decode_tokens == 0:
logger.debug(f"[DEBUG] EMBEDDING OUTPUT: hidden_states[0, :4]={hidden_states[0, :4].tolist()}")
residual = None
# Phase 2: Layer-by-layer processing with ring buffer pipeline
@@ -963,6 +1027,14 @@ class ModelRunner:
# 2a. Wait for current buffer's load to complete
offload_engine.wait_buffer_load(current_buffer)
# DEBUG: Layer outputs (first decode step, layer 0 and last layer)
if num_prev_decode_tokens == 0 and (layer_id == 0 or layer_id == num_layers - 1):
if not use_cuda_graph:
hs_pre = hidden_states[0, :4].tolist()
else:
hs_pre = graph_vars["hidden_states"][0, :4].tolist()
logger.debug(f"[DEBUG] L{layer_id} BEFORE: hidden_states[0, :4]={hs_pre}")
# 2b. Copy previous decode KV from decode buffer to ring buffer
# Ring buffer already has prefill KV at [0:total_prefill_tokens]
# We need to add decode KV at [total_prefill_tokens:]
@@ -1005,6 +1077,14 @@ class ModelRunner:
# - Compute attention via flash_attn_with_kvcache
hidden_states, residual = layer(positions, hidden_states, residual)
# DEBUG: Layer outputs (first decode step, layer 0 and last layer)
if num_prev_decode_tokens == 0 and (layer_id == 0 or layer_id == num_layers - 1):
if not use_cuda_graph:
hs_post = hidden_states[0, :4].tolist()
else:
hs_post = graph_vars["layer_outputs"][0, :4].tolist()
logger.debug(f"[DEBUG] L{layer_id} AFTER: hidden_states[0, :4]={hs_post}")
# 2f. Copy new token's KV from ring buffer to decode buffer (for persistence)
# The new token was stored at position context_len in ring buffer
ring_k = offload_engine.layer_k_cache[current_buffer]
@@ -1054,6 +1134,16 @@ class ModelRunner:
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
# DEBUG: Log first decode token
if num_prev_decode_tokens == 0 and token_ids:
# Get top-5 logits for debugging
top_logits, top_indices = torch.topk(logits[0], 5)
logger.debug(
f"[DEBUG] FIRST DECODE TOKEN: token_id={token_ids[0]}, "
f"top5_indices={top_indices.tolist()}, "
f"top5_logits={top_logits.tolist()}"
)
return token_ids
@torch.inference_mode()