[claudesquad] update from 'multi-request-2' on 13 Jan 26 02:01 CST

This commit is contained in:
Zijie Tian
2026-01-13 02:01:07 +08:00
parent 49519c7ce7
commit 76af506956
7 changed files with 858 additions and 398 deletions

View File

@@ -851,12 +851,33 @@ class ModelRunner:
# Step 4: Compute logits for last token
logits = self.model.compute_logits(hidden_states[-1:])
# DEBUG: Check hidden_states and logits at end of prefill
hs_last = hidden_states[-1, :4].tolist()
top5_logits, top5_indices = torch.topk(logits[0], 5)
logger.debug(
f"[DEBUG] PREFILL END: hidden_states[-1, :4]={hs_last}, "
f"top5_tokens={top5_indices.tolist()}, top5_logits={top5_logits.tolist()}"
)
# Note: Using sync offload, no wait needed
# Mark all blocks as prefilled
for logical_id in logical_ids:
self.kvcache_manager.prefilled_blocks.add(logical_id)
# DEBUG: Verify CPU cache content after prefill
first_cpu_block = cpu_block_ids[0]
last_cpu_block = cpu_block_ids[-1]
last_block_valid = total_tokens % self.block_size or self.block_size
k_first = offload_engine.k_cache_cpu[0, first_cpu_block, 0, 0, :4].tolist()
k_last = offload_engine.k_cache_cpu[0, last_cpu_block, 0, 0, :4].tolist()
logger.debug(
f"[DEBUG] AFTER PREFILL: first_cpu_block={first_cpu_block}, last_cpu_block={last_cpu_block}, "
f"last_block_valid={last_block_valid}, "
f"k_cache_cpu[0, {first_cpu_block}, 0, 0, :4]={k_first}, "
f"k_cache_cpu[0, {last_cpu_block}, 0, 0, :4]={k_last}"
)
# Step 5: Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
@@ -926,6 +947,24 @@ class ModelRunner:
# New token will be stored at this position
context_len = total_prefill_tokens + num_prev_decode_tokens
# DEBUG: Log key values for first decode step
if num_prev_decode_tokens == 0:
first_cpu_block = cpu_block_table[0] if cpu_block_table else -1
last_cpu_block = cpu_block_table[-1] if cpu_block_table else -1
k_first = offload_engine.k_cache_cpu[0, first_cpu_block, 0, 0, :4].tolist() if first_cpu_block >= 0 else []
k_last = offload_engine.k_cache_cpu[0, last_cpu_block, 0, 0, :4].tolist() if last_cpu_block >= 0 else []
logger.debug(
f"[DEBUG] FIRST DECODE STEP: len(seq)={len(seq)}, "
f"total_prefill_tokens={total_prefill_tokens}, "
f"num_prefill_blocks={num_prefill_blocks}, "
f"valid_tokens_per_block[-1]={valid_tokens_per_block[-1] if valid_tokens_per_block else 'N/A'}, "
f"pos_in_block={pos_in_block}, decode_start_pos={decode_start_pos}, "
f"context_len={context_len}, "
f"first_cpu_block={first_cpu_block}, last_cpu_block={last_cpu_block}, "
f"k_cache_cpu[0, {first_cpu_block}, 0, ...]={k_first}, "
f"k_cache_cpu[0, {last_cpu_block}, 0, ...]={k_last}"
)
# Context setup for Attention.forward() - contiguous mode (no block tables)
if use_cuda_graph:
graph_vars["slot_mapping"][0] = context_len
@@ -943,15 +982,40 @@ class ModelRunner:
i, i, cpu_block_table, valid_tokens_per_block
)
# DEBUG: Check ring buffer content after preload (first decode step only)
if num_prev_decode_tokens == 0:
# Wait for all load streams to complete
torch.cuda.synchronize()
ring_k_0 = offload_engine.layer_k_cache[0, 0, 0, :4].tolist()
# Check the actual last valid position based on valid_tokens_per_block
sum_valid = sum(valid_tokens_per_block)
ring_k_last_valid = offload_engine.layer_k_cache[0, sum_valid - 1, 0, :4].tolist()
logger.debug(
f"[DEBUG] AFTER PRELOAD L0: sum_valid={sum_valid}, "
f"ring_k[0, 0, 0, :4]={ring_k_0}, "
f"ring_k[0, {sum_valid-1}, 0, :4]={ring_k_last_valid}"
)
# Step 1: Embedding (on compute stream)
with torch.cuda.stream(compute_stream):
# DEBUG: Log input token for first decode step
if num_prev_decode_tokens == 0:
embed_weight_sample = self.model.model.embed_tokens.weight[input_ids[0], :4].tolist()
logger.debug(f"[DEBUG] EMBEDDING INPUT: input_ids={input_ids.tolist()}, positions={positions.tolist()}, weight[{input_ids[0]},:4]={embed_weight_sample}")
if use_cuda_graph:
# Copy embedding output to graph's hidden_states
embedded = self.model.model.embed_tokens(input_ids)
# DEBUG: Log embedding output for first decode step
if num_prev_decode_tokens == 0:
logger.debug(f"[DEBUG] EMBEDDING OUTPUT: embedded[0, :4]={embedded[0, :4].tolist()}")
graph_vars["hidden_states"].copy_(embedded)
graph_vars["residual"].zero_() # Reset residual for first layer
else:
hidden_states = self.model.model.embed_tokens(input_ids)
# DEBUG: Log embedding output for first decode step
if num_prev_decode_tokens == 0:
logger.debug(f"[DEBUG] EMBEDDING OUTPUT: hidden_states[0, :4]={hidden_states[0, :4].tolist()}")
residual = None
# Phase 2: Layer-by-layer processing with ring buffer pipeline
@@ -963,6 +1027,14 @@ class ModelRunner:
# 2a. Wait for current buffer's load to complete
offload_engine.wait_buffer_load(current_buffer)
# DEBUG: Layer outputs (first decode step, layer 0 and last layer)
if num_prev_decode_tokens == 0 and (layer_id == 0 or layer_id == num_layers - 1):
if not use_cuda_graph:
hs_pre = hidden_states[0, :4].tolist()
else:
hs_pre = graph_vars["hidden_states"][0, :4].tolist()
logger.debug(f"[DEBUG] L{layer_id} BEFORE: hidden_states[0, :4]={hs_pre}")
# 2b. Copy previous decode KV from decode buffer to ring buffer
# Ring buffer already has prefill KV at [0:total_prefill_tokens]
# We need to add decode KV at [total_prefill_tokens:]
@@ -1005,6 +1077,14 @@ class ModelRunner:
# - Compute attention via flash_attn_with_kvcache
hidden_states, residual = layer(positions, hidden_states, residual)
# DEBUG: Layer outputs (first decode step, layer 0 and last layer)
if num_prev_decode_tokens == 0 and (layer_id == 0 or layer_id == num_layers - 1):
if not use_cuda_graph:
hs_post = hidden_states[0, :4].tolist()
else:
hs_post = graph_vars["layer_outputs"][0, :4].tolist()
logger.debug(f"[DEBUG] L{layer_id} AFTER: hidden_states[0, :4]={hs_post}")
# 2f. Copy new token's KV from ring buffer to decode buffer (for persistence)
# The new token was stored at position context_len in ring buffer
ring_k = offload_engine.layer_k_cache[current_buffer]
@@ -1054,6 +1134,16 @@ class ModelRunner:
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
# DEBUG: Log first decode token
if num_prev_decode_tokens == 0 and token_ids:
# Get top-5 logits for debugging
top_logits, top_indices = torch.topk(logits[0], 5)
logger.debug(
f"[DEBUG] FIRST DECODE TOKEN: token_id={token_ids[0]}, "
f"top5_indices={top_indices.tolist()}, "
f"top5_logits={top_logits.tolist()}"
)
return token_ids
@torch.inference_mode()

View File

@@ -244,6 +244,13 @@ class HybridKVCacheManager(KVCacheManager):
seq.num_cached_tokens = 0
seq.block_table.clear()
# Clear decode tracking to prevent state pollution between requests
self.clear_decode_tracking(seq)
# Clear offload engine state (decode buffer, events)
if self.offload_engine is not None:
self.offload_engine.on_sequence_finished()
def can_append(self, seq: Sequence) -> bool:
"""Check if we can append a token."""
need_new_block = (len(seq) % self._block_size == 1)
@@ -342,10 +349,12 @@ class HybridKVCacheManager(KVCacheManager):
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
cpu_blocks.append(block.cpu_block_id)
# logger.debug(
# f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
# f"returned cpu_blocks={cpu_blocks}"
# )
# DEBUG: Log on first decode call
logger.debug(
f"[DEBUG] get_prefilled_cpu_blocks: block_table={list(seq.block_table)}, "
f"prefilled_blocks={list(self.prefilled_blocks)}, "
f"returned cpu_blocks={cpu_blocks}"
)
return cpu_blocks
# ========== CPU Block Allocation ==========
@@ -383,6 +392,10 @@ class HybridKVCacheManager(KVCacheManager):
self.cpu_block_to_logical[cpu_block_id] = logical_id
seq.block_table.append(logical_id)
# DEBUG: Log allocated CPU blocks
cpu_blocks = [self.logical_blocks[lid].cpu_block_id for lid in seq.block_table]
logger.debug(f"[DEBUG] allocate_cpu_only: allocated cpu_blocks={cpu_blocks}")
# NOTE: Prefix cache disabled in offload mode
# If enabled, would compute hash and update:
# h = self.compute_hash(seq.block(i), prefix_hash)
@@ -430,6 +443,8 @@ class HybridKVCacheManager(KVCacheManager):
if block.location == BlockLocation.CPU:
cpu_block_ids.append(block.cpu_block_id)
logical_ids.append(logical_id)
# DEBUG: Log during prefill
logger.debug(f"[DEBUG] get_all_cpu_blocks: returned cpu_block_ids={cpu_block_ids}")
return cpu_block_ids, logical_ids
def allocate_next_cpu_block(self, seq: Sequence) -> int:
@@ -502,6 +517,12 @@ class HybridKVCacheManager(KVCacheManager):
# Decode starts at the next position
prefill_len = len(seq) - 1 # Current len includes the new decode token
self._decode_start_pos[seq_id] = prefill_len % self._block_size
# DEBUG: Log first access
logger.debug(
f"[DEBUG] get_decode_start_pos FIRST ACCESS: seq_id={seq_id}, "
f"len(seq)={len(seq)}, prefill_len={prefill_len}, "
f"stored decode_start_pos={self._decode_start_pos[seq_id]}"
)
return self._decode_start_pos[seq_id]
def reset_decode_start_pos(self, seq: Sequence) -> None:
@@ -534,6 +555,11 @@ class HybridKVCacheManager(KVCacheManager):
# First decode step - store the prefill length
# len(seq) - 1 because current len includes the first decode token
self._prefill_len[seq_id] = len(seq) - 1
# DEBUG: Log first access
logger.debug(
f"[DEBUG] get_prefill_len FIRST ACCESS: seq_id={seq_id}, "
f"len(seq)={len(seq)}, stored prefill_len={self._prefill_len[seq_id]}"
)
return self._prefill_len[seq_id]
def clear_decode_tracking(self, seq: Sequence) -> None:
@@ -546,6 +572,15 @@ class HybridKVCacheManager(KVCacheManager):
seq: Sequence
"""
seq_id = id(seq)
# DEBUG: Log clearing and CPU blocks
cpu_blocks = [self.logical_blocks[lid].cpu_block_id for lid in seq.block_table
if self.logical_blocks[lid].location == BlockLocation.CPU]
logger.debug(
f"[DEBUG] clear_decode_tracking: seq_id={seq_id}, "
f"clearing decode_start_pos={self._decode_start_pos.get(seq_id, 'N/A')}, "
f"prefill_len={self._prefill_len.get(seq_id, 'N/A')}, "
f"cpu_blocks={cpu_blocks}"
)
self._decode_start_pos.pop(seq_id, None)
self._prefill_len.pop(seq_id, None)

View File

@@ -179,6 +179,24 @@ class OffloadEngine:
f")"
)
# ========== State Reset ==========
def on_sequence_finished(self):
    """
    Reset per-sequence state so a finished request cannot pollute the next one.

    Invoked by HybridKVCacheManager.deallocate() once a sequence completes.
    """
    # Wipe both halves of the decode buffer so stale KV entries from the
    # finished request can never bleed into the next request's attention.
    for decode_buffer in (self.decode_k_buffer, self.decode_v_buffer):
        decode_buffer.zero_()
    # Recording every compute-done event again marks each buffer slot as
    # immediately available when the next sequence starts.
    for done_event in self.buffer_compute_done_events:
        done_event.record()
    logger.debug("OffloadEngine: state cleared for next sequence")
# ========== Prefill: Async D2H Offload API ==========
def offload_layer_kv_async(