[WIP] need to fix the model to decode correctly.

Zijie Tian
2026-01-01 05:18:27 +08:00
parent 62b8a63314
commit 74ee6d0895
3 changed files with 317 additions and 123 deletions


@@ -92,13 +92,14 @@ def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
q = decode_cap['q'].cuda() # [1, num_heads, head_dim]
q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim]
# Collect all K, V: prefill chunks from CPU cache + decode tokens from captures
# Collect all K, V: prefill chunks from captures + decode tokens from captures
# NOTE: We use prefill captures directly instead of CPU cache because
# the CPU block ID may not equal the chunk index.
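# (The CPU block order debug check later in this file infers the actual
#  chunk_idx -> CPU block mapping empirically by matching first K values.)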
all_k = []
all_v = []
# 1. Prefill chunks from CPU cache
# 1. Prefill chunks from captures (use captured K/V, not CPU cache)
for cidx in range(num_prefill_chunks):
# Get prefill capture to know the sequence length for this chunk
prefill_cap = None
for c in prefill_captures:
if c['layer_id'] == layer_id and c['chunk_idx'] == cidx:
@@ -106,11 +107,9 @@ def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
break
if prefill_cap is not None:
seq_len = prefill_cap['q'].shape[0]
k = k_cache_cpu[layer_id, cidx, :seq_len].cuda()
v = v_cache_cpu[layer_id, cidx, :seq_len].cuda()
all_k.append(k)
all_v.append(v)
# Use captured K/V directly (guaranteed to be correct layer data)
all_k.append(prefill_cap['k'].cuda())
all_v.append(prefill_cap['v'].cuda())
# 2. Decode tokens from captures (up to and including current step)
for step in range(decode_step + 1):
@@ -184,6 +183,184 @@ v_cache_cpu = offload_engine.v_cache_cpu.clone()
# Calculate number of prefill chunks
num_prefill_chunks = INPUT_LEN // BLOCK_SIZE
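# e.g. with INPUT_LEN = 2048 and BLOCK_SIZE = 128 (hypothetical values, assuming
# INPUT_LEN is an exact multiple of BLOCK_SIZE) this gives 16 prefill chunks.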
# Debug: Compare decode_buffer with captured K/V
print("\n=== DEBUG: Comparing decode_buffer with captured K/V ===")
decode_k_buffer = offload_engine.decode_k_buffer.clone().cpu()
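# Inferred layout (from the indexing below): decode_k_buffer is
# [num_layers, decode_token_idx, kv_heads, head_dim].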
for step in range(NUM_DECODE_TOKENS):
for layer_id in [0, 17, 35]: # Sample a few layers
# Find captured K for this step and layer
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == step:
captured_k = c['k'].squeeze(0) # [kv_heads, head_dim]
buffer_k = decode_k_buffer[layer_id, step] # [kv_heads, head_dim]
diff = (captured_k - buffer_k).abs().max().item()
print(f"Step {step}, Layer {layer_id}: captured vs buffer max_diff={diff:.6f}")
break
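# A max_diff near zero here means the decode K written into the GPU decode buffer
# matches what the attention hook captured for that token; a large value would
# point at the buffer write path rather than the attention math.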
# Debug: Verify that decode_buffer slices match concatenated captures
print("\n=== DEBUG: Verifying decode_buffer slices ===")
for layer_id in [0]:
for decode_step in [1, 2]: # Check steps that use multiple tokens
# Build expected slice from captures
expected_k_list = []
for step in range(decode_step + 1):
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == step:
expected_k_list.append(c['k'].squeeze(0)) # [kv_heads, head_dim]
break
if expected_k_list:
expected_k = torch.stack(expected_k_list, dim=0) # [num_tokens, kv_heads, head_dim]
buffer_slice = decode_k_buffer[layer_id, 0:decode_step+1]
diff = (expected_k - buffer_slice).abs().max().item()
print(f"Decode step {decode_step}, Layer {layer_id}: buffer slice vs expected max_diff={diff:.6f}")
# Print first values
print(f" Buffer[0,0,0]={buffer_slice[0,0,0].item():.6f}, Expected[0,0,0]={expected_k[0,0,0].item():.6f}")
if decode_step >= 1:
print(f" Buffer[1,0,0]={buffer_slice[1,0,0].item():.6f}, Expected[1,0,0]={expected_k[1,0,0].item():.6f}")
# Debug: Print expected K value for block 0, layer 0 (to compare with actual loading)
print("\n=== DEBUG: Expected K values for block 0, layer 0 ===")
for c in prefill_captures:
if c['layer_id'] == 0 and c['chunk_idx'] == 0:
print(f"Captured K[0,0,0] for layer 0, chunk 0: {c['k'][0,0,0].item():.6f}")
break
print(f"CPU cache K[0,0,0,0,0] for layer 0, block 0: {k_cache_cpu[0,0,0,0,0].item():.6f}")
# Debug: Compare CPU cache with captured prefill K/V
print("\n=== DEBUG: Comparing CPU cache with captured prefill K/V ===")
for chunk_idx in [0, 7, 15]: # Sample a few chunks
for layer_id in [0, 17, 35]: # Sample a few layers
# Find captured K for this chunk and layer
for c in prefill_captures:
if c['layer_id'] == layer_id and c['chunk_idx'] == chunk_idx:
captured_k = c['k'] # [seq_len, kv_heads, head_dim]
cpu_cache_k = k_cache_cpu[layer_id, chunk_idx, :captured_k.shape[0]]
diff = (captured_k - cpu_cache_k).abs().max().item()
print(f"Chunk {chunk_idx}, Layer {layer_id}: captured vs CPU cache max_diff={diff:.6f}")
break
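# A nonzero diff here would mean CPU cache slot [layer_id, chunk_idx] does not hold
# that chunk's data, e.g. because chunks were offloaded to CPU blocks out of order.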
# Debug: Get cpu_block_table to check order
kvcache_manager = llm.model_runner.kvcache_manager
# Find the sequence (it should still exist)
from nanovllm.engine.sequence import Sequence
for attr_name in ['sequences', '_sequences', 'active_sequences']:
if hasattr(kvcache_manager, attr_name):
print(f"Found {attr_name}")
break
# Try to get cpu_block_table another way
print(f"\n=== DEBUG: CPU block order ===")
# For each prefill capture, check which CPU block it ended up in
for chunk_idx in range(num_prefill_chunks):
for c in prefill_captures:
if c['layer_id'] == 0 and c['chunk_idx'] == chunk_idx:
# Check if this chunk's K matches any CPU block
captured_k_first = c['k'][0, 0, 0].item()
for block_id in range(num_prefill_chunks):
cpu_k_first = k_cache_cpu[0, block_id, 0, 0, 0].item()
if abs(captured_k_first - cpu_k_first) < 1e-6:
print(f"Chunk {chunk_idx} -> CPU block {block_id}")
break
break
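# If the printed mapping is not the identity (chunk i -> CPU block i), then indexing
# k_cache_cpu by chunk_idx reads the wrong block, which is why compute_decode_reference
# now uses the captured K/V directly instead of the CPU cache.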
# Debug: Check reference vs actual for decode steps 0 and 1
# Also compute partial references (prefill only, decode only) to isolate the bug
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
for decode_step in [0, 1]:
print(f"\n=== DEBUG: Reference vs Actual for layer 0, decode {decode_step} ===")
layer_id = 0
# Find the capture
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
q = c['q'].cuda() # [1, num_heads, head_dim]
q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim]
# Build prefill K/V per-block for block-by-block reference
prefill_k_blocks = []
prefill_v_blocks = []
for cidx in range(num_prefill_chunks):
for pc in prefill_captures:
if pc['layer_id'] == layer_id and pc['chunk_idx'] == cidx:
prefill_k_blocks.append(pc['k'].cuda().unsqueeze(0)) # [1, block_size, kv_heads, head_dim]
prefill_v_blocks.append(pc['v'].cuda().unsqueeze(0))
break
# Build decode K/V
decode_k_list = []
decode_v_list = []
for step in range(decode_step + 1):
for dc in decode_captures:
if dc['layer_id'] == layer_id and dc['decode_step'] == step:
decode_k_list.append(dc['k'].cuda())
decode_v_list.append(dc['v'].cuda())
break
full_prefill_k = torch.cat([kb.squeeze(0) for kb in prefill_k_blocks], dim=0).unsqueeze(0)
full_prefill_v = torch.cat([vb.squeeze(0) for vb in prefill_v_blocks], dim=0).unsqueeze(0)
full_decode_k = torch.cat(decode_k_list, dim=0).unsqueeze(0)
full_decode_v = torch.cat(decode_v_list, dim=0).unsqueeze(0)
full_k = torch.cat([full_prefill_k, full_decode_k], dim=1)
full_v = torch.cat([full_prefill_v, full_decode_v], dim=1)
print(f"Q shape: {q_batched.shape}")
print(f"Prefill K shape: {full_prefill_k.shape}")
print(f"Decode K shape: {full_decode_k.shape}")
print(f"Full K shape: {full_k.shape}")
print(f"Total tokens: prefill={num_prefill_chunks * BLOCK_SIZE}, decode={decode_step + 1}")
# Reference output (single attention over all)
ref_output = flash_attn_func(
q_batched, full_k, full_v,
softmax_scale=scale,
causal=False,
)
# Chunked reference: prefill attention + decode attention + merge
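# flash_attn_with_lse is expected to return the attention output together with its
# per-query log-sum-exp (LSE), so partial results over disjoint KV sets can be
# combined exactly. Assumed merge math (a sketch, not necessarily the library code):
#   lse = logaddexp(lse_prefill, lse_decode)
#   out = exp(lse_prefill - lse) * out_prefill + exp(lse_decode - lse) * out_decode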
prefill_o, prefill_lse = flash_attn_with_lse(
q_batched, full_prefill_k, full_prefill_v,
softmax_scale=scale,
causal=False,
)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, full_decode_k, full_decode_v,
softmax_scale=scale,
causal=False,
)
chunked_output, _ = merge_attention_outputs(prefill_o, prefill_lse, decode_o, decode_lse)
# Block-by-block reference (simulating ring buffer pipeline)
block_o_acc, block_lse_acc = None, None
for bidx, (kb, vb) in enumerate(zip(prefill_k_blocks, prefill_v_blocks)):
o_blk, lse_blk = flash_attn_with_lse(q_batched, kb, vb, softmax_scale=scale, causal=False)
if block_o_acc is None:
block_o_acc, block_lse_acc = o_blk, lse_blk
else:
block_o_acc, block_lse_acc = merge_attention_outputs(block_o_acc, block_lse_acc, o_blk, lse_blk)
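# Each iteration folds one prefill block's partial output into the running result,
# mirroring how the ring-buffer pipeline would merge per-block outputs on the fly.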
# Compare block-by-block vs single
block_vs_single_diff = (block_o_acc - prefill_o).abs().max().item()
print(f"Block-by-block vs Single max_diff: {block_vs_single_diff:.6f}")
# Compare full reference vs chunked reference
ref_vs_chunked_diff = (ref_output - chunked_output).abs().max().item()
print(f"Reference vs Chunked-reference max_diff: {ref_vs_chunked_diff:.6f}")
ref_output = ref_output.squeeze(0).squeeze(0).cpu()
chunked_output_cpu = chunked_output.squeeze(0).squeeze(0).cpu()
# Actual output
actual_output = c['output'].squeeze(0)
if actual_output.dim() == 3:
actual_output = actual_output.squeeze(0)
diff_ref = (actual_output - ref_output).abs()
diff_chunked = (actual_output - chunked_output_cpu).abs()
print(f"Actual vs Reference max_diff: {diff_ref.max().item():.6f}")
print(f"Actual vs Chunked-ref max_diff: {diff_chunked.max().item():.6f}")
break
print()
# Verify decode outputs
all_passed = True
@@ -208,7 +385,7 @@ for c in decode_captures:
passed = max_diff < 1e-1
all_passed = all_passed and passed
# if not passed:
print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")
if not passed:
print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")
print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}")