[WIP] Need to fix the model so it decodes correctly.
@@ -118,6 +118,24 @@ class OffloadEngine:
             dtype=dtype, device="cuda"
         )
 
+        # ========== Per-layer decode buffer ==========
+        # During decode, all layers share decode_slot (no layer dimension in GPU cache).
+        # This causes accumulated tokens to be overwritten by each layer.
+        # Solution: Maintain separate per-layer buffers for decode tokens.
+        # Shape: [num_layers, block_size, kv_heads, head_dim]
+        # Memory: num_layers * block_size * kv_heads * head_dim * dtype_size
+        # e.g., 28 * 1024 * 8 * 128 * 2 = 58.7 MB (acceptable)
+        self.decode_k_buffer = torch.zeros(
+            num_layers, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        self.decode_v_buffer = torch.zeros(
+            num_layers, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
+        logger.info(f"  Per-layer decode buffer: {decode_buf_mb:.1f} MB")
+
         # ========== Fixed-address CPU KV cache (pinned memory) ==========
         self.k_cache_cpu = torch.zeros(
             num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
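Note: a quick sanity check of the footprint quoted in the comment above ("28 * 1024 * 8 * 128 * 2 = 58.7 MB"). This is a standalone sketch, not part of the commit; the 28/1024/8/128 values are just the example numbers from the comment, and fp16 (2 bytes per element) is assumed:

    # Back-of-the-envelope check of the per-layer decode buffer size.
    num_layers, block_size, num_kv_heads, head_dim = 28, 1024, 8, 128
    dtype_size = 2  # bytes per element for fp16/bf16

    per_buffer_bytes = num_layers * block_size * num_kv_heads * head_dim * dtype_size
    print(per_buffer_bytes / 1e6)        # ~58.7 (decimal MB) for the K buffer alone, matching the comment
    print(2 * per_buffer_bytes / 2**20)  # ~112 MiB for K + V together, which is what the logger reports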
@@ -87,6 +87,15 @@ class Attention(nn.Module):
         else:  # decode
             if context.is_chunked_prefill:
                 # Chunked decode: need to load all KV from CPU+GPU
+                # Store current decode token to per-layer decode buffer
+                # This is needed because GPU cache has no layer dimension,
+                # so all layers would overwrite each other in decode_slot.
+                kvcache_manager = context.kvcache_manager
+                offload_engine = kvcache_manager.offload_engine
+                pos_in_block = context.decode_pos_in_block
+                # k, v shape: [1, kv_heads, head_dim]
+                offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
+                offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
                 o = self._chunked_decode_attention(q, k, v, context)
             else:
                 o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
@@ -390,25 +399,17 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute decode attention with double-buffering using decode_load_slots.
+        Compute decode attention using ring buffer pipeline (same as prefill).
 
-        Decode uses:
-        - decode_slot (slot[0]): writes new token's KV
-        - decode_load_slots (slots[1:]): load previous chunks from CPU
+        Uses the same loading mechanism as _chunked_prefill_attention:
+        - Load one block at a time from CPU to GPU slot
+        - Compute attention for each block
+        - Merge results using online softmax
+        - Finally merge with decode buffer (accumulated decode tokens)
 
-        Pipeline design:
-        - First half of decode_load_slots: 'compute' buffer
-        - Second half: 'prefetch' buffer
-        - Double-buffer between them for async overlap
-
-        Timeline:
-        ┌─────────────┐  ┌─────────────┐  ┌─────────────┐
-        │Load C0→buf0 │  │Load C1→buf1 │  │Load C2→buf0 │ ...
-        └─────────────┘  └─────────────┘  └─────────────┘
-               ↘                ↘                ↘
-        ┌─────────────┐  ┌─────────────┐  ┌─────────────┐
-        │  Attn(C0)   │  │  Attn(C1)   │  │  Attn(C2)   │
-        └─────────────┘  └─────────────┘  └─────────────┘
+        This approach is simpler and proven correct (prefill tests pass).
+        The only difference from prefill is the additional decode buffer
+        that stores new tokens generated during decode.
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
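Note: the "merge results using online softmax" step in the new docstring combines per-block partial attention outputs using their log-sum-exp (LSE) values. Below is a minimal sketch of that merge; it assumes the flash-attn convention of outputs shaped [batch, seq_q, heads, head_dim] and LSE shaped [batch, heads, seq_q], and it only illustrates the math behind merge_attention_outputs, not the actual helper in nanovllm.kvcache.chunked_attention:

    import torch

    def merge_partial_attention(o1, lse1, o2, lse2):
        # o*:   [batch, seq_q, heads, head_dim] partial attention outputs
        # lse*: [batch, heads, seq_q] log-sum-exp of the corresponding softmax denominators
        lse = torch.logaddexp(lse1, lse2)                         # combined normalizer
        w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # -> [batch, seq_q, heads, 1]
        w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
        return w1 * o1 + w2 * o2, lse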
@@ -419,7 +420,6 @@ class Attention(nn.Module):
         seq = context.chunked_seq
 
         # Get only PREFILLED CPU blocks (exclude the current decode block)
-        # The decode block's KV is still in GPU decode_slot, not yet offloaded to CPU
         cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
         if self.layer_id == 0:
             logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
@@ -427,12 +427,12 @@ class Attention(nn.Module):
             raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
 
         # Calculate valid tokens in the last block
-        # prefill_len = total prefilled tokens (current decode token not yet in CPU)
+        # Note: For chunked prefill, each block is exactly block_size tokens
+        # The cpu_block_table only contains full prefill blocks
         block_size = kvcache_manager.block_size
-        prefill_len = len(seq) - 1  # Exclude current decode token
-        last_block_valid_tokens = prefill_len % block_size
-        if last_block_valid_tokens == 0 and prefill_len > 0:
-            last_block_valid_tokens = block_size  # Last block is full
+        num_prefill_blocks = len(cpu_block_table)
+        # All prefill blocks are full (block_size tokens each)
+        last_block_valid_tokens = block_size
 
         # Apply sparse policy if enabled
         if kvcache_manager.sparse_policy is not None:
@@ -440,7 +440,7 @@ class Attention(nn.Module):
                 query_chunk_idx=0,
                 num_query_chunks=1,
                 layer_id=self.layer_id,
-                query=q_batched,  # Decode provides query for query-aware selection
+                query=q_batched,
                 is_prefill=False,
                 block_size=kvcache_manager.block_size,
                 total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
@@ -450,104 +450,28 @@ class Attention(nn.Module):
         )
 
         offload_engine = kvcache_manager.offload_engine
-        compute_stream = offload_engine.compute_stream
+        load_slots = offload_engine.decode_load_slots  # Available slots for loading
 
-        # Chunk size = capacity of each double buffer region (compute/prefetch)
-        # Each region uses half of decode_load_slots
-        chunk_size = max(1, len(offload_engine.decode_load_slots) // 2)
-        num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
-
-        # Check if double buffering is possible (need at least 2 separate regions)
-        # With only 1 load slot, compute and prefetch regions overlap -> can't double buffer
-        can_double_buffer = len(offload_engine.decode_load_slots) >= 2
-
-        o_acc = None
-        lse_acc = None
-
-        # Double buffering state: True = use Compute region, False = use Prefetch region
-        use_compute = True
-
-        # Pre-load first chunk to Compute region (async)
-        first_chunk_ids = cpu_block_table[:min(chunk_size, len(cpu_block_table))]
-        offload_engine.load_to_compute_layer(self.layer_id, first_chunk_ids)
-
-        for chunk_idx in range(num_chunks):
-            start = chunk_idx * chunk_size
-            end = min(start + chunk_size, len(cpu_block_table))
-            num_blocks_in_chunk = end - start
-
-            # Wait for current buffer to be ready on compute_stream
-            # The load runs on transfer_stream_main, compute runs on compute_stream
-            compute_stream.wait_stream(offload_engine.transfer_stream_main)
-
-            # All computation on explicit compute_stream
-            with torch.cuda.stream(compute_stream):
-                # Get KV from current buffer FIRST, before prefetching overwrites it
-                if use_compute:
-                    k_chunk, v_chunk = offload_engine.get_kv_for_compute(num_blocks_in_chunk)
-                else:
-                    k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(num_blocks_in_chunk)
-
-                # Handle partial last block: slice to only include valid tokens
-                # This is critical because the rest of the block contains stale data
-                is_last_chunk = (end == len(cpu_block_table))
-                if is_last_chunk and last_block_valid_tokens < block_size:
-                    # Calculate total valid tokens in this chunk
-                    # All blocks except the last are full, last block has last_block_valid_tokens
-                    full_blocks = num_blocks_in_chunk - 1
-                    valid_tokens = full_blocks * block_size + last_block_valid_tokens
-                    # Slice KV: [batch, seqlen, heads, dim] -> [batch, valid_tokens, heads, dim]
-                    k_chunk = k_chunk[:, :valid_tokens, :, :]
-                    v_chunk = v_chunk[:, :valid_tokens, :, :]
-
-                # Compute attention for this chunk
-                o_chunk, lse_chunk = flash_attn_with_lse(
-                    q_batched, k_chunk, v_chunk,
-                    softmax_scale=self.scale,
-                    causal=False,
-                )
-
-                # Merge with accumulated
-                if o_acc is None:
-                    o_acc, lse_acc = o_chunk, lse_chunk
-                else:
-                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
-
-                # Trigger async prefetch/load of next chunk to the OTHER buffer
-                # This happens AFTER attention completes, so the data is no longer needed
-                if chunk_idx + 1 < num_chunks:
-                    next_start = end
-                    next_end = min(next_start + chunk_size, len(cpu_block_table))
-                    next_chunk_ids = cpu_block_table[next_start:next_end]
-                    if can_double_buffer:
-                        if use_compute:
-                            # Current in Compute, prefetch next to Prefetch region
-                            offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
-                        else:
-                            # Current in Prefetch, prefetch next to Compute region
-                            offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
-                    else:
-                        # Sync fallback: load next chunk to same slot (always compute region)
-                        offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
-
-            # Swap buffers for next iteration (only matters if can_double_buffer)
-            use_compute = not use_compute
-
-        # Now attend to Decode region (contains accumulated decode tokens)
+        # Use ring buffer pipeline (same as prefill) to load prefilled blocks
+        o_acc, lse_acc = self._decode_ring_buffer_pipeline(
+            q_batched, cpu_block_table, load_slots, offload_engine,
+            block_size, last_block_valid_tokens
+        )
+
+        # Now attend to accumulated decode tokens from per-layer decode buffer
         pos_in_block = context.decode_pos_in_block
         start_pos = context.decode_start_pos_in_block
         num_accumulated = pos_in_block - start_pos + 1
 
-        # IMPORTANT: Sync compute_stream with default stream before reading decode_slot
-        # store_kvcache writes to decode_slot on default stream (before entering this function)
-        # We need to ensure that write is complete before reading on compute_stream
+        # Sync compute_stream with default stream before reading decode_buffer
+        compute_stream = offload_engine.compute_stream
         compute_stream.wait_stream(torch.cuda.default_stream())
 
         with torch.cuda.stream(compute_stream):
             if num_accumulated > 0:
-                # GPU cache has no layer dimension
-                decode_k = offload_engine.k_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
-                decode_v = offload_engine.v_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
+                # Read from per-layer decode buffer
+                decode_k = offload_engine.decode_k_buffer[self.layer_id, start_pos:pos_in_block+1]
+                decode_v = offload_engine.decode_v_buffer[self.layer_id, start_pos:pos_in_block+1]
                 decode_k = decode_k.unsqueeze(0)
                 decode_v = decode_v.unsqueeze(0)
 
@@ -566,7 +490,82 @@ class Attention(nn.Module):
             raise RuntimeError("Chunked decode attention failed: no KV available")
 
         # Sync back to default stream before returning
-        # Caller expects result to be ready on default stream
        torch.cuda.default_stream().wait_stream(compute_stream)
 
         return o_acc
+
+    def _decode_ring_buffer_pipeline(
+        self,
+        q_batched: torch.Tensor,
+        cpu_block_table: list,
+        load_slots: list,
+        offload_engine,
+        block_size: int,
+        last_block_valid_tokens: int,
+    ):
+        """
+        Ring buffer pipeline for decode prefill loading (same mechanism as prefill).
+
+        Loads one block at a time, computes attention, and merges results.
+        Uses the same load_to_slot_layer / wait_slot_layer / get_kv_for_slot
+        methods as prefill for proven correctness.
+        """
+        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+
+        num_blocks = len(cpu_block_table)
+        if num_blocks == 0:
+            return None, None
+
+        if not load_slots:
+            return None, None
+
+        o_acc, lse_acc = None, None
+        num_slots = len(load_slots)
+        compute_stream = offload_engine.compute_stream
+
+        # Phase 1: Pre-load up to num_slots blocks
+        num_preload = min(num_slots, num_blocks)
+        for i in range(num_preload):
+            offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
+
+        # Phase 2: Process blocks with pipeline
+        for block_idx in range(num_blocks):
+            current_slot = load_slots[block_idx % num_slots]
+            cpu_block_id = cpu_block_table[block_idx]
+
+            # Wait for current slot's transfer to complete
+            offload_engine.wait_slot_layer(current_slot)
+
+            with torch.cuda.stream(compute_stream):
+                # Get KV from slot
+                prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
+
+                # Handle partial last block
+                is_last_block = (block_idx == num_blocks - 1)
+                if is_last_block and last_block_valid_tokens < block_size:
+                    prev_k = prev_k[:, :last_block_valid_tokens, :, :]
+                    prev_v = prev_v[:, :last_block_valid_tokens, :, :]
+
+                # Compute attention
+                prev_o, prev_lse = flash_attn_with_lse(
+                    q_batched, prev_k, prev_v,
+                    softmax_scale=self.scale,
+                    causal=False,
+                )
+
+                # Record compute done for slot reuse
+                offload_engine.record_slot_compute_done(current_slot)
+
+            # Start loading next block (pipeline)
+            next_block_idx = block_idx + num_slots
+            if next_block_idx < num_blocks:
+                offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
+
+            # Merge with accumulated
+            with torch.cuda.stream(compute_stream):
+                if o_acc is None:
+                    o_acc, lse_acc = prev_o, prev_lse
+                else:
+                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+
+        return o_acc, lse_acc
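Note: the slot-reuse pattern in _decode_ring_buffer_pipeline can be hard to see among the stream calls. The toy script below only prints the schedule (which CPU block occupies which slot at each step) under an assumed 7 blocks and 3 load slots; it is purely illustrative and does not touch the real engine:

    # Illustrative ring-buffer schedule: 7 blocks cycling through 3 load slots.
    num_blocks, load_slots = 7, [1, 2, 3]
    num_slots = len(load_slots)

    # Phase 1: pre-load the first num_slots blocks, one per slot.
    for i in range(min(num_slots, num_blocks)):
        print(f"preload: block {i} -> slot {load_slots[i]}")

    # Phase 2: consume blocks in order; once a slot's block has been used,
    # immediately reuse that slot for the block num_slots positions ahead.
    for block_idx in range(num_blocks):
        slot = load_slots[block_idx % num_slots]
        print(f"compute: block {block_idx} in slot {slot}")
        next_block = block_idx + num_slots
        if next_block < num_blocks:
            print(f"load:    block {next_block} -> slot {slot}")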
@@ -92,13 +92,14 @@ def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
     q = decode_cap['q'].cuda()  # [1, num_heads, head_dim]
     q_batched = q.unsqueeze(0)  # [1, 1, num_heads, head_dim]
 
-    # Collect all K, V: prefill chunks from CPU cache + decode tokens from captures
+    # Collect all K, V: prefill chunks from captures + decode tokens from captures
+    # NOTE: We use prefill captures directly instead of CPU cache because
+    # the CPU block ID may not equal the chunk index.
     all_k = []
     all_v = []
 
-    # 1. Prefill chunks from CPU cache
+    # 1. Prefill chunks from captures (use captured K/V, not CPU cache)
     for cidx in range(num_prefill_chunks):
-        # Get prefill capture to know the sequence length for this chunk
         prefill_cap = None
         for c in prefill_captures:
             if c['layer_id'] == layer_id and c['chunk_idx'] == cidx:
@@ -106,11 +107,9 @@ def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
                 break
 
         if prefill_cap is not None:
-            seq_len = prefill_cap['q'].shape[0]
-            k = k_cache_cpu[layer_id, cidx, :seq_len].cuda()
-            v = v_cache_cpu[layer_id, cidx, :seq_len].cuda()
-            all_k.append(k)
-            all_v.append(v)
+            # Use captured K/V directly (guaranteed to be correct layer data)
+            all_k.append(prefill_cap['k'].cuda())
+            all_v.append(prefill_cap['v'].cuda())
 
     # 2. Decode tokens from captures (up to and including current step)
     for step in range(decode_step + 1):
@@ -184,6 +183,184 @@ v_cache_cpu = offload_engine.v_cache_cpu.clone()
 # Calculate number of prefill chunks
 num_prefill_chunks = INPUT_LEN // BLOCK_SIZE
 
+# Debug: Compare decode_buffer with captured K/V
+print("\n=== DEBUG: Comparing decode_buffer with captured K/V ===")
+decode_k_buffer = offload_engine.decode_k_buffer.clone().cpu()
+for step in range(NUM_DECODE_TOKENS):
+    for layer_id in [0, 17, 35]:  # Sample a few layers
+        # Find captured K for this step and layer
+        for c in decode_captures:
+            if c['layer_id'] == layer_id and c['decode_step'] == step:
+                captured_k = c['k'].squeeze(0)  # [kv_heads, head_dim]
+                buffer_k = decode_k_buffer[layer_id, step]  # [kv_heads, head_dim]
+                diff = (captured_k - buffer_k).abs().max().item()
+                print(f"Step {step}, Layer {layer_id}: captured vs buffer max_diff={diff:.6f}")
+                break
+
+# Debug: Verify that decode_buffer slices match concatenated captures
+print("\n=== DEBUG: Verifying decode_buffer slices ===")
+for layer_id in [0]:
+    for decode_step in [1, 2]:  # Check steps that use multiple tokens
+        # Build expected slice from captures
+        expected_k_list = []
+        for step in range(decode_step + 1):
+            for c in decode_captures:
+                if c['layer_id'] == layer_id and c['decode_step'] == step:
+                    expected_k_list.append(c['k'].squeeze(0))  # [kv_heads, head_dim]
+                    break
+        if expected_k_list:
+            expected_k = torch.stack(expected_k_list, dim=0)  # [num_tokens, kv_heads, head_dim]
+            buffer_slice = decode_k_buffer[layer_id, 0:decode_step+1]
+            diff = (expected_k - buffer_slice).abs().max().item()
+            print(f"Decode step {decode_step}, Layer {layer_id}: buffer slice vs expected max_diff={diff:.6f}")
+            # Print first values
+            print(f"  Buffer[0,0,0]={buffer_slice[0,0,0].item():.6f}, Expected[0,0,0]={expected_k[0,0,0].item():.6f}")
+            if decode_step >= 1:
+                print(f"  Buffer[1,0,0]={buffer_slice[1,0,0].item():.6f}, Expected[1,0,0]={expected_k[1,0,0].item():.6f}")
+
+# Debug: Print expected K value for block 0, layer 0 (to compare with actual loading)
+print("\n=== DEBUG: Expected K values for block 0, layer 0 ===")
+for c in prefill_captures:
+    if c['layer_id'] == 0 and c['chunk_idx'] == 0:
+        print(f"Captured K[0,0,0] for layer 0, chunk 0: {c['k'][0,0,0].item():.6f}")
+        break
+print(f"CPU cache K[0,0,0,0,0] for layer 0, block 0: {k_cache_cpu[0,0,0,0,0].item():.6f}")
+
+# Debug: Compare CPU cache with captured prefill K/V
+print("\n=== DEBUG: Comparing CPU cache with captured prefill K/V ===")
+for chunk_idx in [0, 7, 15]:  # Sample a few chunks
+    for layer_id in [0, 17, 35]:  # Sample a few layers
+        # Find captured K for this chunk and layer
+        for c in prefill_captures:
+            if c['layer_id'] == layer_id and c['chunk_idx'] == chunk_idx:
+                captured_k = c['k']  # [seq_len, kv_heads, head_dim]
+                cpu_cache_k = k_cache_cpu[layer_id, chunk_idx, :captured_k.shape[0]]
+                diff = (captured_k - cpu_cache_k).abs().max().item()
+                print(f"Chunk {chunk_idx}, Layer {layer_id}: captured vs CPU cache max_diff={diff:.6f}")
+                break
+
+# Debug: Get cpu_block_table to check order
+kvcache_manager = llm.model_runner.kvcache_manager
+# Find the sequence (it should still exist)
+from nanovllm.engine.sequence import Sequence
+for attr_name in ['sequences', '_sequences', 'active_sequences']:
+    if hasattr(kvcache_manager, attr_name):
+        print(f"Found {attr_name}")
+        break
+
+# Try to get cpu_block_table through a different way
+print(f"\n=== DEBUG: CPU block order ===")
+# For each prefill capture, check which CPU block it ended up in
+for chunk_idx in range(num_prefill_chunks):
+    for c in prefill_captures:
+        if c['layer_id'] == 0 and c['chunk_idx'] == chunk_idx:
+            # Check if this chunk's K matches any CPU block
+            captured_k_first = c['k'][0, 0, 0].item()
+            for block_id in range(num_prefill_chunks):
+                cpu_k_first = k_cache_cpu[0, block_id, 0, 0, 0].item()
+                if abs(captured_k_first - cpu_k_first) < 1e-6:
+                    print(f"Chunk {chunk_idx} -> CPU block {block_id}")
+                    break
+            break
+
+# Debug: Check reference vs actual for decode steps 0 and 1
+# Also compute partial references (prefill only, decode only) to isolate the bug
+from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+for decode_step in [0, 1]:
+    print(f"\n=== DEBUG: Reference vs Actual for layer 0, decode {decode_step} ===")
+    layer_id = 0
+    # Find the capture
+    for c in decode_captures:
+        if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
+            q = c['q'].cuda()  # [1, num_heads, head_dim]
+            q_batched = q.unsqueeze(0)  # [1, 1, num_heads, head_dim]
+
+            # Build prefill K/V per-block for block-by-block reference
+            prefill_k_blocks = []
+            prefill_v_blocks = []
+            for cidx in range(num_prefill_chunks):
+                for pc in prefill_captures:
+                    if pc['layer_id'] == layer_id and pc['chunk_idx'] == cidx:
+                        prefill_k_blocks.append(pc['k'].cuda().unsqueeze(0))  # [1, block_size, kv_heads, head_dim]
+                        prefill_v_blocks.append(pc['v'].cuda().unsqueeze(0))
+                        break
+
+            # Build decode K/V
+            decode_k_list = []
+            decode_v_list = []
+            for step in range(decode_step + 1):
+                for dc in decode_captures:
+                    if dc['layer_id'] == layer_id and dc['decode_step'] == step:
+                        decode_k_list.append(dc['k'].cuda())
+                        decode_v_list.append(dc['v'].cuda())
+                        break
+
+            full_prefill_k = torch.cat([kb.squeeze(0) for kb in prefill_k_blocks], dim=0).unsqueeze(0)
+            full_prefill_v = torch.cat([vb.squeeze(0) for vb in prefill_v_blocks], dim=0).unsqueeze(0)
+            full_decode_k = torch.cat(decode_k_list, dim=0).unsqueeze(0)
+            full_decode_v = torch.cat(decode_v_list, dim=0).unsqueeze(0)
+
+            full_k = torch.cat([full_prefill_k, full_decode_k], dim=1)
+            full_v = torch.cat([full_prefill_v, full_decode_v], dim=1)
+
+            print(f"Q shape: {q_batched.shape}")
+            print(f"Prefill K shape: {full_prefill_k.shape}")
+            print(f"Decode K shape: {full_decode_k.shape}")
+            print(f"Full K shape: {full_k.shape}")
+            print(f"Total tokens: prefill={num_prefill_chunks * BLOCK_SIZE}, decode={decode_step + 1}")
+
+            # Reference output (single attention over all)
+            ref_output = flash_attn_func(
+                q_batched, full_k, full_v,
+                softmax_scale=scale,
+                causal=False,
+            )
+
+            # Chunked reference: prefill attention + decode attention + merge
+            prefill_o, prefill_lse = flash_attn_with_lse(
+                q_batched, full_prefill_k, full_prefill_v,
+                softmax_scale=scale,
+                causal=False,
+            )
+            decode_o, decode_lse = flash_attn_with_lse(
+                q_batched, full_decode_k, full_decode_v,
+                softmax_scale=scale,
+                causal=False,
+            )
+            chunked_output, _ = merge_attention_outputs(prefill_o, prefill_lse, decode_o, decode_lse)
+
+            # Block-by-block reference (simulating ring buffer pipeline)
+            block_o_acc, block_lse_acc = None, None
+            for bidx, (kb, vb) in enumerate(zip(prefill_k_blocks, prefill_v_blocks)):
+                o_blk, lse_blk = flash_attn_with_lse(q_batched, kb, vb, softmax_scale=scale, causal=False)
+                if block_o_acc is None:
+                    block_o_acc, block_lse_acc = o_blk, lse_blk
+                else:
+                    block_o_acc, block_lse_acc = merge_attention_outputs(block_o_acc, block_lse_acc, o_blk, lse_blk)
+
+            # Compare block-by-block vs single
+            block_vs_single_diff = (block_o_acc - prefill_o).abs().max().item()
+            print(f"Block-by-block vs Single max_diff: {block_vs_single_diff:.6f}")
+
+            # Compare full reference vs chunked reference
+            ref_vs_chunked_diff = (ref_output - chunked_output).abs().max().item()
+            print(f"Reference vs Chunked-reference max_diff: {ref_vs_chunked_diff:.6f}")
+
+            ref_output = ref_output.squeeze(0).squeeze(0).cpu()
+            chunked_output_cpu = chunked_output.squeeze(0).squeeze(0).cpu()
+
+            # Actual output
+            actual_output = c['output'].squeeze(0)
+            if actual_output.dim() == 3:
+                actual_output = actual_output.squeeze(0)
+
+            diff_ref = (actual_output - ref_output).abs()
+            diff_chunked = (actual_output - chunked_output_cpu).abs()
+            print(f"Actual vs Reference max_diff: {diff_ref.max().item():.6f}")
+            print(f"Actual vs Chunked-ref max_diff: {diff_chunked.max().item():.6f}")
+            break
+    print()
+
 # Verify decode outputs
 all_passed = True
 
@@ -208,7 +385,7 @@ for c in decode_captures:
     passed = max_diff < 1e-1
     all_passed = all_passed and passed
 
-    # if not passed:
+    if not passed:
         print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")
 
 print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}")