[claudesquad] update from 'multi-request-2' on 13 Jan 26 02:01 CST
This commit is contained in:
@@ -851,12 +851,33 @@ class ModelRunner:
|
||||
# Step 4: Compute logits for last token
|
||||
logits = self.model.compute_logits(hidden_states[-1:])
|
||||
|
||||
# DEBUG: Check hidden_states and logits at end of prefill
|
||||
hs_last = hidden_states[-1, :4].tolist()
|
||||
top5_logits, top5_indices = torch.topk(logits[0], 5)
|
||||
logger.debug(
|
||||
f"[DEBUG] PREFILL END: hidden_states[-1, :4]={hs_last}, "
|
||||
f"top5_tokens={top5_indices.tolist()}, top5_logits={top5_logits.tolist()}"
|
||||
)
|
||||
|
||||
# Note: Using sync offload, no wait needed
|
||||
|
||||
# Mark all blocks as prefilled
|
||||
for logical_id in logical_ids:
|
||||
self.kvcache_manager.prefilled_blocks.add(logical_id)
|
||||
|
||||
# DEBUG: Verify CPU cache content after prefill
|
||||
first_cpu_block = cpu_block_ids[0]
|
||||
last_cpu_block = cpu_block_ids[-1]
|
||||
last_block_valid = total_tokens % self.block_size or self.block_size
|
||||
k_first = offload_engine.k_cache_cpu[0, first_cpu_block, 0, 0, :4].tolist()
|
||||
k_last = offload_engine.k_cache_cpu[0, last_cpu_block, 0, 0, :4].tolist()
|
||||
logger.debug(
|
||||
f"[DEBUG] AFTER PREFILL: first_cpu_block={first_cpu_block}, last_cpu_block={last_cpu_block}, "
|
||||
f"last_block_valid={last_block_valid}, "
|
||||
f"k_cache_cpu[0, {first_cpu_block}, 0, 0, :4]={k_first}, "
|
||||
f"k_cache_cpu[0, {last_cpu_block}, 0, 0, :4]={k_last}"
|
||||
)
|
||||
|
||||
# Step 5: Sample
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
|
||||
@@ -926,6 +947,24 @@ class ModelRunner:
|
||||
# New token will be stored at this position
|
||||
context_len = total_prefill_tokens + num_prev_decode_tokens
|
||||
|
||||
# DEBUG: Log key values for first decode step
|
||||
if num_prev_decode_tokens == 0:
|
||||
first_cpu_block = cpu_block_table[0] if cpu_block_table else -1
|
||||
last_cpu_block = cpu_block_table[-1] if cpu_block_table else -1
|
||||
k_first = offload_engine.k_cache_cpu[0, first_cpu_block, 0, 0, :4].tolist() if first_cpu_block >= 0 else []
|
||||
k_last = offload_engine.k_cache_cpu[0, last_cpu_block, 0, 0, :4].tolist() if last_cpu_block >= 0 else []
|
||||
logger.debug(
|
||||
f"[DEBUG] FIRST DECODE STEP: len(seq)={len(seq)}, "
|
||||
f"total_prefill_tokens={total_prefill_tokens}, "
|
||||
f"num_prefill_blocks={num_prefill_blocks}, "
|
||||
f"valid_tokens_per_block[-1]={valid_tokens_per_block[-1] if valid_tokens_per_block else 'N/A'}, "
|
||||
f"pos_in_block={pos_in_block}, decode_start_pos={decode_start_pos}, "
|
||||
f"context_len={context_len}, "
|
||||
f"first_cpu_block={first_cpu_block}, last_cpu_block={last_cpu_block}, "
|
||||
f"k_cache_cpu[0, {first_cpu_block}, 0, ...]={k_first}, "
|
||||
f"k_cache_cpu[0, {last_cpu_block}, 0, ...]={k_last}"
|
||||
)
|
||||
|
||||
# Context setup for Attention.forward() - contiguous mode (no block tables)
|
||||
if use_cuda_graph:
|
||||
graph_vars["slot_mapping"][0] = context_len
|
||||
@@ -943,15 +982,40 @@ class ModelRunner:
|
||||
i, i, cpu_block_table, valid_tokens_per_block
|
||||
)
|
||||
|
||||
# DEBUG: Check ring buffer content after preload (first decode step only)
|
||||
if num_prev_decode_tokens == 0:
|
||||
# Wait for all load streams to complete
|
||||
torch.cuda.synchronize()
|
||||
ring_k_0 = offload_engine.layer_k_cache[0, 0, 0, :4].tolist()
|
||||
# Check the actual last valid position based on valid_tokens_per_block
|
||||
sum_valid = sum(valid_tokens_per_block)
|
||||
ring_k_last_valid = offload_engine.layer_k_cache[0, sum_valid - 1, 0, :4].tolist()
|
||||
logger.debug(
|
||||
f"[DEBUG] AFTER PRELOAD L0: sum_valid={sum_valid}, "
|
||||
f"ring_k[0, 0, 0, :4]={ring_k_0}, "
|
||||
f"ring_k[0, {sum_valid-1}, 0, :4]={ring_k_last_valid}"
|
||||
)
|
||||
|
||||
# Step 1: Embedding (on compute stream)
|
||||
with torch.cuda.stream(compute_stream):
|
||||
# DEBUG: Log input token for first decode step
|
||||
if num_prev_decode_tokens == 0:
|
||||
embed_weight_sample = self.model.model.embed_tokens.weight[input_ids[0], :4].tolist()
|
||||
logger.debug(f"[DEBUG] EMBEDDING INPUT: input_ids={input_ids.tolist()}, positions={positions.tolist()}, weight[{input_ids[0]},:4]={embed_weight_sample}")
|
||||
|
||||
if use_cuda_graph:
|
||||
# Copy embedding output to graph's hidden_states
|
||||
embedded = self.model.model.embed_tokens(input_ids)
|
||||
# DEBUG: Log embedding output for first decode step
|
||||
if num_prev_decode_tokens == 0:
|
||||
logger.debug(f"[DEBUG] EMBEDDING OUTPUT: embedded[0, :4]={embedded[0, :4].tolist()}")
|
||||
graph_vars["hidden_states"].copy_(embedded)
|
||||
graph_vars["residual"].zero_() # Reset residual for first layer
|
||||
else:
|
||||
hidden_states = self.model.model.embed_tokens(input_ids)
|
||||
# DEBUG: Log embedding output for first decode step
|
||||
if num_prev_decode_tokens == 0:
|
||||
logger.debug(f"[DEBUG] EMBEDDING OUTPUT: hidden_states[0, :4]={hidden_states[0, :4].tolist()}")
|
||||
residual = None
|
||||
|
||||
# Phase 2: Layer-by-layer processing with ring buffer pipeline
|
||||
@@ -963,6 +1027,14 @@ class ModelRunner:
|
||||
# 2a. Wait for current buffer's load to complete
|
||||
offload_engine.wait_buffer_load(current_buffer)
|
||||
|
||||
# DEBUG: Layer outputs (first decode step, layer 0 and last layer)
|
||||
if num_prev_decode_tokens == 0 and (layer_id == 0 or layer_id == num_layers - 1):
|
||||
if not use_cuda_graph:
|
||||
hs_pre = hidden_states[0, :4].tolist()
|
||||
else:
|
||||
hs_pre = graph_vars["hidden_states"][0, :4].tolist()
|
||||
logger.debug(f"[DEBUG] L{layer_id} BEFORE: hidden_states[0, :4]={hs_pre}")
|
||||
|
||||
# 2b. Copy previous decode KV from decode buffer to ring buffer
|
||||
# Ring buffer already has prefill KV at [0:total_prefill_tokens]
|
||||
# We need to add decode KV at [total_prefill_tokens:]
|
||||
@@ -1005,6 +1077,14 @@ class ModelRunner:
|
||||
# - Compute attention via flash_attn_with_kvcache
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
# DEBUG: Layer outputs (first decode step, layer 0 and last layer)
|
||||
if num_prev_decode_tokens == 0 and (layer_id == 0 or layer_id == num_layers - 1):
|
||||
if not use_cuda_graph:
|
||||
hs_post = hidden_states[0, :4].tolist()
|
||||
else:
|
||||
hs_post = graph_vars["layer_outputs"][0, :4].tolist()
|
||||
logger.debug(f"[DEBUG] L{layer_id} AFTER: hidden_states[0, :4]={hs_post}")
|
||||
|
||||
# 2f. Copy new token's KV from ring buffer to decode buffer (for persistence)
|
||||
# The new token was stored at position context_len in ring buffer
|
||||
ring_k = offload_engine.layer_k_cache[current_buffer]
|
||||
@@ -1054,6 +1134,16 @@ class ModelRunner:
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
|
||||
|
||||
# DEBUG: Log first decode token
|
||||
if num_prev_decode_tokens == 0 and token_ids:
|
||||
# Get top-5 logits for debugging
|
||||
top_logits, top_indices = torch.topk(logits[0], 5)
|
||||
logger.debug(
|
||||
f"[DEBUG] FIRST DECODE TOKEN: token_id={token_ids[0]}, "
|
||||
f"top5_indices={top_indices.tolist()}, "
|
||||
f"top5_logits={top_logits.tolist()}"
|
||||
)
|
||||
|
||||
return token_ids
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
@@ -244,6 +244,13 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
seq.num_cached_tokens = 0
|
||||
seq.block_table.clear()
|
||||
|
||||
# Clear decode tracking to prevent state pollution between requests
|
||||
self.clear_decode_tracking(seq)
|
||||
|
||||
# Clear offload engine state (decode buffer, events)
|
||||
if self.offload_engine is not None:
|
||||
self.offload_engine.on_sequence_finished()
|
||||
|
||||
def can_append(self, seq: Sequence) -> bool:
|
||||
"""Check if we can append a token."""
|
||||
need_new_block = (len(seq) % self._block_size == 1)
|
||||
@@ -342,10 +349,12 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.CPU:
|
||||
cpu_blocks.append(block.cpu_block_id)
|
||||
# logger.debug(
|
||||
# f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
|
||||
# f"returned cpu_blocks={cpu_blocks}"
|
||||
# )
|
||||
# DEBUG: Log on first decode call
|
||||
logger.debug(
|
||||
f"[DEBUG] get_prefilled_cpu_blocks: block_table={list(seq.block_table)}, "
|
||||
f"prefilled_blocks={list(self.prefilled_blocks)}, "
|
||||
f"returned cpu_blocks={cpu_blocks}"
|
||||
)
|
||||
return cpu_blocks
|
||||
|
||||
# ========== CPU Block Allocation ==========
|
||||
@@ -383,6 +392,10 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
self.cpu_block_to_logical[cpu_block_id] = logical_id
|
||||
seq.block_table.append(logical_id)
|
||||
|
||||
# DEBUG: Log allocated CPU blocks
|
||||
cpu_blocks = [self.logical_blocks[lid].cpu_block_id for lid in seq.block_table]
|
||||
logger.debug(f"[DEBUG] allocate_cpu_only: allocated cpu_blocks={cpu_blocks}")
|
||||
|
||||
# NOTE: Prefix cache disabled in offload mode
|
||||
# If enabled, would compute hash and update:
|
||||
# h = self.compute_hash(seq.block(i), prefix_hash)
|
||||
@@ -430,6 +443,8 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
if block.location == BlockLocation.CPU:
|
||||
cpu_block_ids.append(block.cpu_block_id)
|
||||
logical_ids.append(logical_id)
|
||||
# DEBUG: Log during prefill
|
||||
logger.debug(f"[DEBUG] get_all_cpu_blocks: returned cpu_block_ids={cpu_block_ids}")
|
||||
return cpu_block_ids, logical_ids
|
||||
|
||||
def allocate_next_cpu_block(self, seq: Sequence) -> int:
|
||||
@@ -502,6 +517,12 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
# Decode starts at the next position
|
||||
prefill_len = len(seq) - 1 # Current len includes the new decode token
|
||||
self._decode_start_pos[seq_id] = prefill_len % self._block_size
|
||||
# DEBUG: Log first access
|
||||
logger.debug(
|
||||
f"[DEBUG] get_decode_start_pos FIRST ACCESS: seq_id={seq_id}, "
|
||||
f"len(seq)={len(seq)}, prefill_len={prefill_len}, "
|
||||
f"stored decode_start_pos={self._decode_start_pos[seq_id]}"
|
||||
)
|
||||
return self._decode_start_pos[seq_id]
|
||||
|
||||
def reset_decode_start_pos(self, seq: Sequence) -> None:
|
||||
@@ -534,6 +555,11 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
# First decode step - store the prefill length
|
||||
# len(seq) - 1 because current len includes the first decode token
|
||||
self._prefill_len[seq_id] = len(seq) - 1
|
||||
# DEBUG: Log first access
|
||||
logger.debug(
|
||||
f"[DEBUG] get_prefill_len FIRST ACCESS: seq_id={seq_id}, "
|
||||
f"len(seq)={len(seq)}, stored prefill_len={self._prefill_len[seq_id]}"
|
||||
)
|
||||
return self._prefill_len[seq_id]
|
||||
|
||||
def clear_decode_tracking(self, seq: Sequence) -> None:
|
||||
@@ -546,6 +572,15 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
seq: Sequence
|
||||
"""
|
||||
seq_id = id(seq)
|
||||
# DEBUG: Log clearing and CPU blocks
|
||||
cpu_blocks = [self.logical_blocks[lid].cpu_block_id for lid in seq.block_table
|
||||
if self.logical_blocks[lid].location == BlockLocation.CPU]
|
||||
logger.debug(
|
||||
f"[DEBUG] clear_decode_tracking: seq_id={seq_id}, "
|
||||
f"clearing decode_start_pos={self._decode_start_pos.get(seq_id, 'N/A')}, "
|
||||
f"prefill_len={self._prefill_len.get(seq_id, 'N/A')}, "
|
||||
f"cpu_blocks={cpu_blocks}"
|
||||
)
|
||||
self._decode_start_pos.pop(seq_id, None)
|
||||
self._prefill_len.pop(seq_id, None)
|
||||
|
||||
|
||||
@@ -179,6 +179,24 @@ class OffloadEngine:
|
||||
f")"
|
||||
)
|
||||
|
||||
# ========== State Reset ==========
|
||||
|
||||
def on_sequence_finished(self):
|
||||
"""
|
||||
Clear state after sequence completion to prevent pollution between requests.
|
||||
|
||||
Called by HybridKVCacheManager.deallocate() when a sequence finishes.
|
||||
"""
|
||||
# Clear decode buffer to prevent residual KV from affecting next request
|
||||
self.decode_k_buffer.zero_()
|
||||
self.decode_v_buffer.zero_()
|
||||
|
||||
# Re-record buffer_compute_done_events to mark all buffers as available
|
||||
for event in self.buffer_compute_done_events:
|
||||
event.record()
|
||||
|
||||
logger.debug("OffloadEngine: state cleared for next sequence")
|
||||
|
||||
# ========== Prefill: Async D2H Offload API ==========
|
||||
|
||||
def offload_layer_kv_async(
|
||||
|
||||
Reference in New Issue
Block a user