[WIP] fixing attention compute error.
@@ -31,6 +31,8 @@ class LLMEngine:
 self.model_runner = ModelRunner(config, 0, self.events)
 self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
 config.eos = self.tokenizer.eos_token_id
+# Set Sequence.block_size to match the KV cache block size
+Sequence.block_size = config.kvcache_block_size
 self.scheduler = Scheduler(config, self.model_runner.kvcache_manager)
 atexit.register(self.exit)

@@ -521,6 +521,7 @@ class ModelRunner:
 print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)

 # Sample from last logits
+# For chunked prefill, ParallelLMHead automatically selects last position's logits
 temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
 if logits is not None:
 token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None

@@ -281,7 +281,11 @@ def _merge_lse_kernel(
 num_elements: tl.constexpr,
 BLOCK_SIZE: tl.constexpr,
 ):
-"""Fused kernel for merging LSE values."""
+"""Fused kernel for merging LSE values.
+
+IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
+bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
+"""
 # Each program handles BLOCK_SIZE elements
 pid = tl.program_id(0)
 block_start = pid * BLOCK_SIZE
@@ -289,21 +293,21 @@ def _merge_lse_kernel(
 offsets = block_start + tl.arange(0, BLOCK_SIZE)
 mask = offsets < num_elements

-# Load lse values
-lse1 = tl.load(lse1_ptr + offsets, mask=mask)
-lse2 = tl.load(lse2_ptr + offsets, mask=mask)
+# Load lse values and convert to fp32 for precision
+lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
+lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)

-# Compute max for numerical stability
+# Compute max for numerical stability (in fp32)
 max_lse = tl.maximum(lse1, lse2)

-# Compute exp(lse - max_lse)
+# Compute exp(lse - max_lse) in fp32
 exp1 = tl.exp(lse1 - max_lse)
 exp2 = tl.exp(lse2 - max_lse)

-# Compute merged LSE: max_lse + log(exp1 + exp2)
+# Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
 lse_merged = max_lse + tl.log(exp1 + exp2)

-# Store result
+# Store result (convert back to original dtype)
 tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)

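Note: the merge this kernel implements is the standard log-sum-exp combination of two partial softmax normalizers; stated as a formula (notation mine, not from the diff):

\[ m = \max(l_1, l_2), \qquad l_{\mathrm{merged}} = m + \log\!\big(e^{\,l_1 - m} + e^{\,l_2 - m}\big) = \log\!\big(e^{\,l_1} + e^{\,l_2}\big), \]

which is exactly the `max_lse + tl.log(exp1 + exp2)` line above. Doing the exp/log in fp32 matters because bf16 keeps roughly 2-3 significant decimal digits of mantissa.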
@@ -313,7 +317,11 @@ def _merge_output_kernel(
 batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
 BLOCK_SIZE: tl.constexpr,
 ):
-"""Fused kernel for merging attention outputs."""
+"""Fused kernel for merging attention outputs.
+
+IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
+This is critical for numerical accuracy in chunked attention.
+"""
 # Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
 pid_batch = tl.program_id(0)
 pid_seq = tl.program_id(1)
@@ -322,11 +330,11 @@ def _merge_output_kernel(
 # Compute LSE index: [batch, nheads, seqlen_q]
 lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq

-# Load LSE values
-lse1 = tl.load(lse1_ptr + lse_idx)
-lse2 = tl.load(lse2_ptr + lse_idx)
+# Load LSE values and convert to fp32 for precision
+lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
+lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)

-# Compute max and scaling factors
+# Compute max and scaling factors in fp32
 max_lse = tl.maximum(lse1, lse2)
 exp1 = tl.exp(lse1 - max_lse)
 exp2 = tl.exp(lse2 - max_lse)
@@ -343,14 +351,14 @@ def _merge_output_kernel(
 pid_head * headdim)
 o_idx = base_idx + d_idx

-# Load o1, o2
-o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0)
-o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0)
+# Load o1, o2 and convert to fp32 for weighted sum
+o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
+o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)

-# Compute merged output: (o1 * exp1 + o2 * exp2) / sum_exp
+# Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
 o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp

-# Store result
+# Store result (Triton will convert back to original dtype)
 tl.store(o_out_ptr + o_idx, o_merged, mask=mask)

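For readers following the fp32 change, a minimal PyTorch sketch of the same two-way merge the kernels perform (function name and tensor shapes are assumptions read off the index math above, not the project's API):

import torch

def merge_two_chunks(o1, lse1, o2, lse2):
    # o1, o2: [batch, seqlen_q, nheads, headdim]; lse1, lse2: [batch, nheads, seqlen_q]
    l1, l2 = lse1.float(), lse2.float()
    m = torch.maximum(l1, l2)                      # max for numerical stability
    w1, w2 = torch.exp(l1 - m), torch.exp(l2 - m)  # per-position softmax weights
    lse = m + torch.log(w1 + w2)                   # merged log-sum-exp
    # Broadcast weights over headdim: [b, h, s] -> [b, s, h, 1]
    w1 = w1.transpose(1, 2).unsqueeze(-1)
    w2 = w2.transpose(1, 2).unsqueeze(-1)
    o = (o1.float() * w1 + o2.float() * w2) / (w1 + w2)
    return o.to(o1.dtype), lse.to(lse1.dtype)

Doing the weighted sum in fp32 and casting back at the end mirrors what the fused kernels now do.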
@@ -337,10 +337,10 @@ class HybridKVCacheManager(KVCacheManager):
 block = self.logical_blocks[logical_id]
 if block.location == BlockLocation.CPU:
 cpu_blocks.append(block.cpu_block_id)
-logger.debug(
-f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
-f"returned cpu_blocks={cpu_blocks}"
-)
+# logger.debug(
+# f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
+# f"returned cpu_blocks={cpu_blocks}"
+# )
 return cpu_blocks

 # ========== Ring Buffer CPU-primary support ==========

@@ -538,7 +538,7 @@ class OffloadEngine:

 def sync_indices(self) -> None:
 """Synchronize to ensure all index updates are complete."""
-torch.cuda.current_stream().synchronize()
+torch.cuda.default_stream().synchronize()

 # ========== Cache access methods ==========

@@ -682,8 +682,9 @@ class OffloadEngine:
 Async load a single CPU block to a ring buffer slot for one layer.

 This is the core building block for ring buffer pipelining.
-Before starting the transfer, waits for any previous compute on this slot
-to complete (using compute_done event).
+Before starting the transfer, waits for:
+1. Any previous compute on this slot to complete
+2. Any pending offload of this slot to complete

 Args:
 slot_idx: Target GPU slot index
@@ -701,6 +702,10 @@ class OffloadEngine:
 # This prevents data race: transfer must not start until attention finishes reading
 stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])

+# Also wait for any pending offload of this slot to complete
+# This prevents race: load must not write GPU slot while offload is reading from it
+stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
+
 self.k_cache_gpu[layer_id, slot_idx].copy_(
 self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
 )
@@ -763,7 +768,11 @@ class OffloadEngine:

 torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
 with torch.cuda.stream(self.transfer_stream_main):
+# Wait for both compute_stream and default stream
+# - compute_stream: for flash attention operations
+# - default_stream: for store_kvcache which runs on default stream
 self.transfer_stream_main.wait_stream(self.compute_stream)
+self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
 memcpy_2d_async(
 self.k_cache_cpu[:, cpu_block_id],
 self.k_cache_gpu[:, slot_idx],
@@ -793,7 +802,9 @@ class OffloadEngine:
 cpu_block_id: Target CPU block ID
 """
 with torch.cuda.stream(self.transfer_stream_main):
+# Wait for both compute_stream and default stream
 self.transfer_stream_main.wait_stream(self.compute_stream)
+self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
 self.k_cache_cpu[layer_id, cpu_block_id].copy_(
 self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
 )

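The OffloadEngine changes above all follow the same CUDA stream discipline: a transfer stream must wait on whichever stream last touched a slot before copying, and readers must wait on the transfer before touching the data. A generic sketch of that pattern (stream, event, and function names here are illustrative, not the OffloadEngine API):

import torch

compute_stream = torch.cuda.Stream()
transfer_stream = torch.cuda.Stream()
compute_done = torch.cuda.Event()
transfer_done = torch.cuda.Event()

def load_slot(dst_gpu, src_cpu):
    with torch.cuda.stream(transfer_stream):
        transfer_stream.wait_event(compute_done)   # previous reader of the slot has finished
        dst_gpu.copy_(src_cpu, non_blocking=True)  # async H2D copy on the transfer stream
        transfer_done.record(transfer_stream)

def use_slot(gpu_block):
    with torch.cuda.stream(compute_stream):
        compute_stream.wait_event(transfer_done)   # data is resident before compute reads it
        out = gpu_block.sum()                      # stand-in for the attention kernel
        compute_done.record(compute_stream)        # lets the next load safely reuse the slot
    return out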
@@ -169,9 +169,11 @@ class Attention(nn.Module):
 else:
 # Use ring buffer pipeline
 o_acc, lse_acc = self._ring_buffer_pipeline_load(
-q_batched, cpu_block_table, load_slots, offload_engine
+q_batched, cpu_block_table, load_slots, offload_engine,
+current_chunk_idx
 )


 # Compute attention against current chunk's KV (with causal mask)
 torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
 current_o, current_lse = flash_attn_with_lse(
@@ -187,11 +189,18 @@ class Attention(nn.Module):
 if o_acc is None:
 final_o = current_o
 else:
+# IMPORTANT: o_acc was computed on compute_stream. We need to sync before
+# reading it on the default stream for the merge operation.
+if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
+offload_engine = kvcache_manager.offload_engine
+torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
+
 torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
 final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
 torch.cuda.nvtx.range_pop()

 torch.cuda.nvtx.range_pop() # ChunkedPrefill

 # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
 return final_o.squeeze(0)

@@ -205,24 +214,27 @@ class Attention(nn.Module):
 from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

 o_acc, lse_acc = None, None
+compute_stream = offload_engine.compute_stream

 for block_idx, cpu_block_id in enumerate(cpu_block_table):
 # Load to slot 0 (single slot)
 offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
 offload_engine.wait_slot_layer(0, self.layer_id)

-prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
+# IMPORTANT: Must use compute_stream to match wait_slot_layer
+with torch.cuda.stream(compute_stream):
+prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)

 prev_o, prev_lse = flash_attn_with_lse(
 q_batched, prev_k, prev_v,
 softmax_scale=self.scale,
 causal=False,
 )

 if o_acc is None:
 o_acc, lse_acc = prev_o, prev_lse
 else:
 o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

 return o_acc, lse_acc

@@ -232,6 +244,7 @@ class Attention(nn.Module):
 cpu_block_table: list,
 load_slots: list,
 offload_engine,
+current_chunk_idx: int = -1,
 ):
 """
 Ring buffer async pipeline loading with double buffering.
@@ -269,22 +282,26 @@ class Attention(nn.Module):

 if pipeline_depth == 1:
 # Only 1 slot available, cannot pipeline - use synchronous mode
+# IMPORTANT: Must use compute_stream to match synchronization in
+# load_to_slot_layer (waits for compute_done) and wait_slot_layer
 slot = load_slots[0]
+compute_stream = offload_engine.compute_stream
 for block_idx in range(num_blocks):
 offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
 offload_engine.wait_slot_layer(slot, self.layer_id)
-prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
-prev_o, prev_lse = flash_attn_with_lse(
-q_batched, prev_k, prev_v,
-softmax_scale=self.scale,
-causal=False,
-)
-# Record compute done so next load can safely reuse this slot
-offload_engine.record_slot_compute_done(slot, self.layer_id)
-if o_acc is None:
-o_acc, lse_acc = prev_o, prev_lse
-else:
-o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+with torch.cuda.stream(compute_stream):
+prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
+prev_o, prev_lse = flash_attn_with_lse(
+q_batched, prev_k, prev_v,
+softmax_scale=self.scale,
+causal=False,
+)
+# Record compute done so next load can safely reuse this slot
+offload_engine.record_slot_compute_done(slot, self.layer_id)
+if o_acc is None:
+o_acc, lse_acc = prev_o, prev_lse
+else:
+o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
 return o_acc, lse_acc

 # N-way pipeline: use ALL available slots for maximum overlap
@@ -378,12 +395,13 @@ class Attention(nn.Module):
 kvcache_manager = context.kvcache_manager
 seq = context.chunked_seq

-# Get all CPU blocks for this sequence
-cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
+# Get only PREFILLED CPU blocks (exclude the current decode block)
+# The decode block's KV is still in GPU decode_slot, not yet offloaded to CPU
+cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
 if self.layer_id == 0:
 logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
 if not cpu_block_table:
-raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
+raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")

 # Apply sparse policy if enabled
 if kvcache_manager.sparse_policy is not None:
@@ -401,12 +419,17 @@ class Attention(nn.Module):
 )

 offload_engine = kvcache_manager.offload_engine
+compute_stream = offload_engine.compute_stream

 # Chunk size = capacity of each double buffer region (compute/prefetch)
 # Each region uses half of decode_load_slots
 chunk_size = max(1, len(offload_engine.decode_load_slots) // 2)
 num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size

+# Check if double buffering is possible (need at least 2 separate regions)
+# With only 1 load slot, compute and prefetch regions overlap -> can't double buffer
+can_double_buffer = len(offload_engine.decode_load_slots) >= 2
+
 o_acc = None
 lse_acc = None

@@ -422,49 +445,53 @@ class Attention(nn.Module):
 end = min(start + chunk_size, len(cpu_block_table))
 num_blocks_in_chunk = end - start

-# Wait for current buffer to be ready
-if use_compute:
-offload_engine.wait_compute_layer(self.layer_id)
-else:
-offload_engine.wait_prefetch_layer(self.layer_id)
+# Wait for current buffer to be ready on compute_stream
+# The load runs on transfer_stream_main, compute runs on compute_stream
+compute_stream.wait_stream(offload_engine.transfer_stream_main)

-# Trigger async prefetch of next chunk to the OTHER buffer
-# This overlaps transfer with current chunk's computation
+# All computation on explicit compute_stream
+with torch.cuda.stream(compute_stream):
+# Get KV from current buffer FIRST, before prefetching overwrites it
+if use_compute:
+k_chunk, v_chunk = offload_engine.get_kv_for_compute(
+self.layer_id, num_blocks_in_chunk
+)
+else:
+k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
+self.layer_id, num_blocks_in_chunk
+)
+
+# Compute attention for this chunk
+o_chunk, lse_chunk = flash_attn_with_lse(
+q_batched, k_chunk, v_chunk,
+softmax_scale=self.scale,
+causal=False,
+)
+
+# Merge with accumulated
+if o_acc is None:
+o_acc, lse_acc = o_chunk, lse_chunk
+else:
+o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
+
+# Trigger async prefetch/load of next chunk to the OTHER buffer
+# This happens AFTER attention completes, so the data is no longer needed
 if chunk_idx + 1 < num_chunks:
 next_start = end
 next_end = min(next_start + chunk_size, len(cpu_block_table))
 next_chunk_ids = cpu_block_table[next_start:next_end]
-if use_compute:
-# Current in Compute, prefetch next to Prefetch region
-offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
+if can_double_buffer:
+if use_compute:
+# Current in Compute, prefetch next to Prefetch region
+offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
+else:
+# Current in Prefetch, prefetch next to Compute region
+offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
 else:
-# Current in Prefetch, prefetch next to Compute region
+# Sync fallback: load next chunk to same slot (always compute region)
 offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)

-# Get KV from current buffer
-if use_compute:
-k_chunk, v_chunk = offload_engine.get_kv_for_compute(
-self.layer_id, num_blocks_in_chunk
-)
-else:
-k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
-self.layer_id, num_blocks_in_chunk
-)
-
-# Compute attention for this chunk
-o_chunk, lse_chunk = flash_attn_with_lse(
-q_batched, k_chunk, v_chunk,
-softmax_scale=self.scale,
-causal=False,
-)
-
-# Merge with accumulated
-if o_acc is None:
-o_acc, lse_acc = o_chunk, lse_chunk
-else:
-o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
-
-# Swap buffers for next iteration
+# Swap buffers for next iteration (only matters if can_double_buffer)
 use_compute = not use_compute

 # Now attend to Decode region (contains accumulated decode tokens)
@@ -472,24 +499,29 @@ class Attention(nn.Module):
 start_pos = context.decode_start_pos_in_block
 num_accumulated = pos_in_block - start_pos + 1

-if num_accumulated > 0:
-decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
-decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
-decode_k = decode_k.unsqueeze(0)
-decode_v = decode_v.unsqueeze(0)
+with torch.cuda.stream(compute_stream):
+if num_accumulated > 0:
+decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
+decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
+decode_k = decode_k.unsqueeze(0)
+decode_v = decode_v.unsqueeze(0)

 decode_o, decode_lse = flash_attn_with_lse(
 q_batched, decode_k, decode_v,
 softmax_scale=self.scale,
 causal=False,
 )

 if o_acc is None:
 o_acc = decode_o
 else:
 o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)

 if o_acc is None:
 raise RuntimeError("Chunked decode attention failed: no KV available")

+# Sync back to default stream before returning
+# Caller expects result to be ready on default stream
+torch.cuda.default_stream().wait_stream(compute_stream)
+
 return o_acc

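The prefill and decode restructuring above relies on the chunked-attention identity: for queries in chunk c, keys from earlier chunks are fully visible (causal=False) and only the current chunk needs the causal mask, so per-chunk partial results can be combined with merge_attention_outputs. In formula form (notation mine):

\[ \mathrm{Attn}(Q_c, K_{1:c}) = \mathrm{merge}\big(\mathrm{Attn}_{\text{non-causal}}(Q_c, K_{<c}),\; \mathrm{Attn}_{\text{causal}}(Q_c, K_c)\big), \]

where merge is the LSE-weighted combination shown earlier. The new tests below check exactly this decomposition against a single-pass causal reference.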
@@ -93,9 +93,9 @@ TEST_CASES = [
 (1, 4, 256, 8, 128),
 (1, 4, 512, 8, 128),
 (1, 8, 512, 8, 128),
-(1, 4, 1024, 8, 128),
-(1, 4, 1024, 32, 128), # More heads
-(1, 8, 256, 8, 64), # Smaller head dim
+(1, 32, 1024, 8, 128),
+(1, 32, 1024, 32, 128), # More heads
+(1, 32, 256, 8, 64), # Smaller head dim
 ]

 DTYPES = [torch.float16, torch.bfloat16]

new file: tests/test_chunked_decode_hook.py (374 lines)
@@ -0,0 +1,374 @@
"""
Hook-based correctness test for chunked decode attention.

Uses PyTorch register_forward_hook() to capture real inference I/O,
then compares against reference computation to locate bugs.

This test targets the decode phase with CPU offload - after prefill,
the model generates tokens one by one while attending to all previous context.
"""

import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"

import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs


# ============================================================
# Configuration
# ============================================================

MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 8 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 2 * 1024  # 2K tokens for prefill
NUM_DECODE_TOKENS = 5  # Generate 5 tokens to test decode
BLOCK_SIZE = 1024


# ============================================================
# Global capture storage
# ============================================================

captures = []
prefill_kv = {}  # Store prefill k,v for reference computation


# ============================================================
# Hook Functions
# ============================================================

def make_hook(layer_id):
    """Create a forward hook for a specific layer."""
    def hook(module, inputs, output):
        q, k, v = inputs
        ctx = get_context()

        is_prefill = ctx.is_prefill

        capture_entry = {
            'layer_id': layer_id,
            'is_prefill': is_prefill,
            'q': q.clone().cpu(),
            'k': k.clone().cpu(),
            'v': v.clone().cpu(),
            'output': output.clone().cpu(),
            'is_chunked_prefill': ctx.is_chunked_prefill,
        }

        if is_prefill:
            # Store prefill k,v for reference computation
            chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
            capture_entry['chunk_idx'] = chunk_idx
            if layer_id not in prefill_kv:
                prefill_kv[layer_id] = []
            prefill_kv[layer_id].append({
                'chunk_idx': chunk_idx,
                'k': k.clone().cpu(),
                'v': v.clone().cpu(),
            })
        else:
            # Decode phase - capture decode token info
            capture_entry['decode_step'] = len([c for c in captures
                                                if c['layer_id'] == layer_id and not c['is_prefill']])

        captures.append(capture_entry)
    return hook


def register_hooks(llm):
    """Register forward hooks on all Attention modules."""
    hooks = []
    model = llm.model_runner.model

    for layer_idx, decoder_layer in enumerate(model.model.layers):
        attn_module = decoder_layer.self_attn.attn
        hook = attn_module.register_forward_hook(make_hook(layer_idx))
        hooks.append(hook)

    return hooks


# ============================================================
# Reference Computation
# ============================================================

def compute_decode_reference(layer_id, decode_step, scale, debug=False):
    """
    Compute reference decode attention output for a specific layer.

    For decode, the query is a single token that attends to:
    1. All prefill KV (from CPU cache)
    2. All previous decode tokens (stored in GPU decode slot)
    """
    # Get the decode capture
    decode_captures = [c for c in captures
                       if c['layer_id'] == layer_id and not c['is_prefill']]
    if decode_step >= len(decode_captures):
        return None

    decode_capture = decode_captures[decode_step]
    q = decode_capture['q'].cuda()  # [1, num_heads, head_dim]
    q_batched = q.unsqueeze(1)  # [1, 1, num_heads, head_dim]

    if debug:
        print(f"  Reference for L{layer_id} D{decode_step}:")
        print(f"    q shape: {q_batched.shape}, mean={q_batched.mean().item():.4f}")

    o_acc, lse_acc = None, None

    # Attend to all prefill chunks
    if layer_id in prefill_kv:
        for chunk_data in sorted(prefill_kv[layer_id], key=lambda x: x['chunk_idx']):
            k = chunk_data['k'].cuda().unsqueeze(0)  # [1, seqlen, kv_heads, head_dim]
            v = chunk_data['v'].cuda().unsqueeze(0)

            o, lse = flash_attn_with_lse(q_batched, k, v, softmax_scale=scale, causal=False)

            if debug:
                print(f"    Prefill chunk {chunk_data['chunk_idx']}: o.mean={o.mean().item():.6f}")

            if o_acc is None:
                o_acc, lse_acc = o, lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)

    # Attend to previous decode tokens (including current)
    # In decode, the current token's k,v are stored, and we need to attend to all previous decode tokens
    # For step 0, we just have the current token's k,v
    # For step 1, we have tokens 0 and 1's k,v
    # etc.

    # Collect k,v from all decode steps up to and including current
    decode_kv = []
    for i in range(decode_step + 1):
        if i < len(decode_captures):
            decode_kv.append({
                'k': decode_captures[i]['k'].cuda(),
                'v': decode_captures[i]['v'].cuda(),
            })

    if decode_kv:
        # Stack decode k,v into a single tensor
        decode_k = torch.cat([d['k'] for d in decode_kv], dim=0).unsqueeze(0)  # [1, num_decode, kv_heads, head_dim]
        decode_v = torch.cat([d['v'] for d in decode_kv], dim=0).unsqueeze(0)

        if debug:
            print(f"    Decode tokens: {len(decode_kv)}, k.shape={decode_k.shape}")

        # For decode, we use causal=False since we're attending to all decode tokens
        # (the causal masking was already handled by only including tokens up to current)
        o_decode, lse_decode = flash_attn_with_lse(q_batched, decode_k, decode_v,
                                                   softmax_scale=scale, causal=False)

        if debug:
            print(f"    Decode attention: o.mean={o_decode.mean().item():.6f}")

        if o_acc is None:
            o_acc, lse_acc = o_decode, lse_decode
        else:
            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_decode, lse_decode)

    if o_acc is None:
        return None

    if debug:
        print(f"    Final: o.mean={o_acc.mean().item():.6f}")

    return o_acc.squeeze(0).squeeze(0).cpu()  # [num_heads, head_dim]


# ============================================================
# Test Runner
# ============================================================

def run_test(verbose=True):
    """Run the hook-based chunked decode correctness test."""
    global captures, prefill_kv
    captures = []
    prefill_kv = {}

    if verbose:
        print("=" * 70)
        print("Test: Hook-Based Chunked Decode Correctness")
        print("=" * 70)
        print(f"Model: {MODEL_PATH}")
        print(f"Input length: {INPUT_LEN} tokens")
        print(f"Decode tokens: {NUM_DECODE_TOKENS}")
        print(f"Block size: {BLOCK_SIZE}")
        print()

    # Initialize LLM with CPU offload
    llm = LLM(
        MODEL_PATH,
        enforce_eager=True,
        max_model_len=MAX_MODEL_LEN,
        max_num_batched_tokens=MAX_MODEL_LEN,
        enable_cpu_offload=True,
        kvcache_block_size=BLOCK_SIZE,
        num_gpu_blocks=NUM_GPU_BLOCKS,
    )

    # Get model info
    num_layers = len(llm.model_runner.model.model.layers)
    head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
    scale = head_dim ** -0.5

    if verbose:
        print(f"Num layers: {num_layers}")
        print(f"Head dim: {head_dim}")
        print()

    # Register hooks
    hooks = register_hooks(llm)
    if verbose:
        print(f"Registered {len(hooks)} hooks")

    # Generate random prompt
    seed(42)
    prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]

    # Run prefill and decode
    if verbose:
        print(f"Running inference with {NUM_DECODE_TOKENS} decode tokens...")
    sampling_params = SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS)
    outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)

    # Remove hooks
    for hook in hooks:
        hook.remove()

    # =========== VERIFICATION: Check CPU cache after prefill ===========
    # Verify that CPU cache data matches captured prefill k,v
    if verbose:
        print("\n--- CPU Cache Verification (After Prefill) ---")
    offload_engine = llm.model_runner.kvcache_manager.offload_engine

    # For each prefill capture, check if CPU cache matches
    for layer_id in [0]:  # Only check layer 0 for brevity
        if layer_id not in prefill_kv:
            continue

        for chunk_data in prefill_kv[layer_id]:
            chunk_idx = chunk_data['chunk_idx']
            captured_k = chunk_data['k']  # [block_size, kv_heads, head_dim]

            # CPU block ID should be chunk_idx (based on allocation order)
            cpu_block_id = chunk_idx
            cpu_k = offload_engine.k_cache_cpu[layer_id, cpu_block_id].cpu()

            diff = (captured_k - cpu_k).abs().max().item()
            print(f"Layer {layer_id}, Chunk {chunk_idx}: captured_k vs cpu_k max_diff={diff:.6f}")
            if diff > 1e-3:
                print(f"  WARNING: CPU cache doesn't match captured k!")
                print(f"  captured_k[0,0,:5] = {captured_k[0,0,:5].tolist()}")
                print(f"  cpu_k[0,0,:5] = {cpu_k[0,0,:5].tolist()}")
    print()

    # Analyze captures
    prefill_count = sum(1 for c in captures if c['is_prefill'])
    decode_count = sum(1 for c in captures if not c['is_prefill'])
    if verbose:
        print(f"\nCaptured {prefill_count} prefill calls, {decode_count} decode calls")

    # Count decode steps per layer
    decode_per_layer = {}
    for c in captures:
        if not c['is_prefill']:
            layer_id = c['layer_id']
            if layer_id not in decode_per_layer:
                decode_per_layer[layer_id] = 0
            decode_per_layer[layer_id] += 1

    if verbose:
        print(f"Decode calls per layer: {decode_per_layer}")
        print()

    # Verify decode correctness
    all_passed = True
    results = []
    first_fail_debug = True

    for c in captures:
        if c['is_prefill']:
            continue  # Skip prefill (already tested in test_chunked_prefill_hook.py)

        layer_id = c['layer_id']
        decode_step = c['decode_step']

        # Only test first decode step for now (simpler reference computation)
        if decode_step > 0:
            continue

        # Compute reference (debug first failure)
        debug_this = (layer_id == 0 and first_fail_debug)
        ref_output = compute_decode_reference(layer_id, decode_step, scale, debug=debug_this)
        if ref_output is None:
            continue

        # Compare
        actual_output = c['output'].squeeze(0)  # Remove seq dim for decode
        if actual_output.dim() == 3:
            actual_output = actual_output.squeeze(0)  # Handle [1, heads, dim] case

        diff = (actual_output - ref_output).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()

        tol = 1e-2
        passed = max_diff < tol
        all_passed = all_passed and passed

        status = "PASS" if passed else "FAIL"
        results.append((layer_id, decode_step, passed, max_diff, mean_diff))

        if verbose:
            print(f"[{status}] Layer {layer_id:2d}, Decode {decode_step}: "
                  f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")

        # Debug first failure
        if not passed and first_fail_debug:
            first_fail_debug = False
            print(f"  Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
            print(f"  Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}")
            # Find where max diff is
            max_idx = diff.argmax()
            flat_actual = actual_output.flatten()
            flat_ref = ref_output.flatten()
            print(f"  Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")

    print()
    print("=" * 70)

    # Summary
    total_tests = len(results)
    passed_count = sum(1 for r in results if r[2])

    print(f"Results: {passed_count}/{total_tests} tests passed")

    if not all_passed:
        print("\nFailed tests:")
        for layer_id, decode_step, passed, max_diff, mean_diff in results:
            if not passed:
                print(f"  - Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")

    print()
    return all_passed


# ============================================================
# Main
# ============================================================

if __name__ == "__main__":
    passed = run_test(verbose=True)

    if passed:
        print("test_chunked_decode_hook: PASSED")
    else:
        print("test_chunked_decode_hook: FAILED")
        exit(1)

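Both hook tests are written to be run directly as scripts (see the __main__ block above), e.g. python tests/test_chunked_decode_hook.py; the MODEL_PATH values at the top point at local checkpoints and will need to be adjusted to a downloaded model.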
new file: tests/test_chunked_prefill_hook.py (473 lines)
@@ -0,0 +1,473 @@
"""
Hook-based correctness test for chunked prefill attention.

Uses PyTorch register_forward_hook() to capture real inference I/O,
then compares against reference computation to locate bugs.

This test targets the integration layer (context setup, cpu_block_table management)
which is where the needle test fails despite isolated attention tests passing.
"""

import os
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"

import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
from flash_attn.flash_attn_interface import flash_attn_varlen_func


# ============================================================
# Configuration
# ============================================================

MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 32 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024  # 4K tokens = 4 chunks with 1K block size
BLOCK_SIZE = 1024


# ============================================================
# Global capture storage
# ============================================================

captures = []


# ============================================================
# Hook Functions
# ============================================================

def make_hook(layer_id):
    """Create a forward hook for a specific layer."""
    def hook(module, inputs, output):
        q, k, v = inputs
        ctx = get_context()

        # Only capture prefill phase
        if not ctx.is_prefill:
            return

        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0

        capture_entry = {
            'layer_id': layer_id,
            'chunk_idx': chunk_idx,
            'q': q.clone().cpu(),
            'k': k.clone().cpu(),
            'v': v.clone().cpu(),
            'output': output.clone().cpu(),
            'is_chunked_prefill': ctx.is_chunked_prefill,
        }

        # For debugging: also capture CPU cache state for layer 0
        if layer_id == 0 and chunk_idx >= 2:
            kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
            if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
                oe = kvcache_manager.offload_engine
                # Get what should have been loaded from CPU
                cpu_k0 = oe.k_cache_cpu[0, 0].clone().cpu()  # Layer 0, CPU block 0
                cpu_k1 = oe.k_cache_cpu[0, 1].clone().cpu()  # Layer 0, CPU block 1
                capture_entry['cpu_k0'] = cpu_k0
                capture_entry['cpu_k1'] = cpu_k1

        captures.append(capture_entry)
    return hook


def register_hooks(llm):
    """Register forward hooks on all Attention modules."""
    hooks = []
    model = llm.model_runner.model

    for layer_idx, decoder_layer in enumerate(model.model.layers):
        attn_module = decoder_layer.self_attn.attn
        hook = attn_module.register_forward_hook(make_hook(layer_idx))
        hooks.append(hook)

    return hooks


# ============================================================
# Reference Computation
# ============================================================

def compute_reference(layer_id, chunk_idx, scale, debug=False):
    """
    Compute reference attention output for a specific layer and chunk.

    Uses the captured k, v from all chunks up to and including chunk_idx.
    """
    # Filter captures for this layer
    layer_captures = [c for c in captures
                      if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]

    if not layer_captures:
        return None

    # Get current chunk's q
    current_capture = [c for c in layer_captures if c['chunk_idx'] == chunk_idx][0]
    q = current_capture['q'].cuda().unsqueeze(0)  # [1, seqlen, nheads, headdim]

    # Collect all k, v up to current chunk
    kv_list = []
    for c in sorted(layer_captures, key=lambda x: x['chunk_idx']):
        k = c['k'].cuda().unsqueeze(0)  # [1, seqlen, nheads, headdim]
        v = c['v'].cuda().unsqueeze(0)
        kv_list.append((k, v, c['chunk_idx']))

    if debug:
        print(f"  Reference for L{layer_id} C{chunk_idx}:")
        print(f"    q shape: {q.shape}, mean={q.mean().item():.4f}")
        print(f"    kv_list: {len(kv_list)} chunks")
        for i, (k, v, cidx) in enumerate(kv_list):
            print(f"      chunk {cidx}: k.mean={k.mean().item():.4f}, v.mean={v.mean().item():.4f}")

    o_acc, lse_acc = None, None

    # Previous chunks: non-causal attention
    for i in range(len(kv_list) - 1):
        k, v, _ = kv_list[i]
        o, lse = flash_attn_with_lse(q, k, v, softmax_scale=scale, causal=False)
        if o_acc is None:
            o_acc, lse_acc = o, lse
        else:
            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)

    # Current chunk: causal attention
    k_cur, v_cur, _ = kv_list[-1]
    o_cur, lse_cur = flash_attn_with_lse(q, k_cur, v_cur, softmax_scale=scale, causal=True)

    if o_acc is None:
        return o_cur.squeeze(0).cpu()

    final_o, _ = merge_attention_outputs(o_acc, lse_acc, o_cur, lse_cur)
    return final_o.squeeze(0).cpu()


def compute_standard_reference(layer_id, chunk_idx, scale, debug=False):
    """
    Compute reference using standard flash attention (single pass with all K, V).

    This simulates what standard (non-chunked) prefill would produce.
    Concatenates all Q, K, V from chunks 0 to chunk_idx and runs a single
    causal attention pass, then extracts the output for the current chunk.
    """
    # Filter captures for this layer
    layer_captures = [c for c in captures
                      if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]

    if not layer_captures:
        return None

    # Sort by chunk index
    layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])

    # Concatenate all Q, K, V
    all_q = []
    all_k = []
    all_v = []
    chunk_lengths = []

    for c in layer_captures:
        q = c['q'].cuda()  # [seqlen, nheads, headdim]
        k = c['k'].cuda()
        v = c['v'].cuda()
        all_q.append(q)
        all_k.append(k)
        all_v.append(v)
        chunk_lengths.append(q.shape[0])

    # Concatenate along sequence dimension
    full_q = torch.cat(all_q, dim=0)  # [total_seqlen, nheads, headdim]
    full_k = torch.cat(all_k, dim=0)
    full_v = torch.cat(all_v, dim=0)

    total_len = full_q.shape[0]

    if debug:
        print(f"  Standard Reference for L{layer_id} C{chunk_idx}:")
        print(f"    full_q shape: {full_q.shape}, mean={full_q.mean().item():.4f}")
        print(f"    full_k shape: {full_k.shape}, mean={full_k.mean().item():.4f}")
        print(f"    chunk_lengths: {chunk_lengths}")

    # Run standard causal flash attention
    # flash_attn_varlen_func expects: q, k, v with shape [total_seqlen, nheads, headdim]
    cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')

    full_o = flash_attn_varlen_func(
        full_q, full_k, full_v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=total_len,
        max_seqlen_k=total_len,
        softmax_scale=scale,
        causal=True,
    )

    # Extract output for current chunk only
    start_pos = sum(chunk_lengths[:-1])
    end_pos = sum(chunk_lengths)
    chunk_output = full_o[start_pos:end_pos]

    if debug:
        print(f"    full_o shape: {full_o.shape}")
        print(f"    extracting positions [{start_pos}:{end_pos}]")
        print(f"    chunk_output shape: {chunk_output.shape}, mean={chunk_output.mean().item():.4f}")

    return chunk_output.cpu()


# ============================================================
# Test Runner
# ============================================================

def run_test(verbose=True):
    """Run the hook-based chunked prefill correctness test."""
    global captures
    captures = []

    if verbose:
        print("=" * 70)
        print("Test: Hook-Based Chunked Prefill Correctness")
        print("=" * 70)
        print(f"Model: {MODEL_PATH}")
        print(f"Input length: {INPUT_LEN} tokens")
        print(f"Block size: {BLOCK_SIZE}")
        print(f"Expected chunks: {INPUT_LEN // BLOCK_SIZE}")
        print()

    # Initialize LLM with CPU offload
    llm = LLM(
        MODEL_PATH,
        enforce_eager=True,
        max_model_len=MAX_MODEL_LEN,
        max_num_batched_tokens=MAX_MODEL_LEN,
        enable_cpu_offload=True,
        kvcache_block_size=BLOCK_SIZE,
        num_gpu_blocks=NUM_GPU_BLOCKS,
    )

    # Get model info
    num_layers = len(llm.model_runner.model.model.layers)
    head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
    scale = head_dim ** -0.5

    if verbose:
        print(f"Num layers: {num_layers}")
        print(f"Head dim: {head_dim}")
        print()

    # Register hooks
    hooks = register_hooks(llm)
    if verbose:
        print(f"Registered {len(hooks)} hooks")

    # Generate random prompt
    seed(42)
    prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]

    # Run prefill only (max_tokens=1)
    if verbose:
        print("Running inference...")
    sampling_params = SamplingParams(temperature=0.6, max_tokens=1)
    outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)

    # Remove hooks
    for hook in hooks:
        hook.remove()

    # Analyze captures
    if verbose:
        print(f"\nCaptured {len(captures)} attention calls")

    # Group by layer and chunk
    chunks_per_layer = {}
    for c in captures:
        layer_id = c['layer_id']
        chunk_idx = c['chunk_idx']
        if layer_id not in chunks_per_layer:
            chunks_per_layer[layer_id] = set()
        chunks_per_layer[layer_id].add(chunk_idx)

    if verbose:
        print("Chunks per layer:", {k: sorted(v) for k, v in chunks_per_layer.items()})
        print()

    # First, verify CPU cache data integrity
    if verbose:
        print("\n--- CPU Cache Verification (Layer 0) ---")
    # Get original k from chunk 0 and chunk 1 captures
    chunk0_k = None
    chunk1_k = None
    chunk2_capture = None
    for c in captures:
        if c['layer_id'] == 0:
            if c['chunk_idx'] == 0:
                chunk0_k = c['k']
            elif c['chunk_idx'] == 1:
                chunk1_k = c['k']
            elif c['chunk_idx'] == 2:
                chunk2_capture = c

    if chunk0_k is not None and chunk2_capture is not None and 'cpu_k0' in chunk2_capture:
        cpu_k0 = chunk2_capture['cpu_k0']
        diff_k0 = (chunk0_k - cpu_k0).abs().max().item()
        print(f"Chunk 0 k vs CPU cache block 0: max_diff={diff_k0:.6f}")
        if diff_k0 > 1e-3:
            print(f"  WARNING: CPU cache block 0 differs from original chunk 0 k!")
            print(f"  Original k[0,0,:5] = {chunk0_k[0,0,:5].tolist()}")
            print(f"  CPU k0[0,0,:5] = {cpu_k0[0,0,:5].tolist()}")

    if chunk1_k is not None and chunk2_capture is not None and 'cpu_k1' in chunk2_capture:
        cpu_k1 = chunk2_capture['cpu_k1']
        diff_k1 = (chunk1_k - cpu_k1).abs().max().item()
        print(f"Chunk 1 k vs CPU cache block 1: max_diff={diff_k1:.6f}")
        if diff_k1 > 1e-3:
            print(f"  WARNING: CPU cache block 1 differs from original chunk 1 k!")
            print(f"  Original k[0,0,:5] = {chunk1_k[0,0,:5].tolist()}")
            print(f"  CPU k1[0,0,:5] = {cpu_k1[0,0,:5].tolist()}")

    print()

    # ================================================================
    # Test 1: Verify against merge-based reference (same algorithm)
    # ================================================================
    if verbose:
        print("--- Test 1: Merge-based Reference (verifies merge algorithm) ---")

    all_passed_merge = True
    results_merge = []
    first_fail_debug = True

    for c in captures:
        layer_id = c['layer_id']
        chunk_idx = c['chunk_idx']

        if chunk_idx == 0:
            continue

        debug_this = (chunk_idx >= 2 and layer_id == 0 and first_fail_debug)
        ref_output = compute_reference(layer_id, chunk_idx, scale, debug=debug_this)
        if ref_output is None:
            continue

        actual_output = c['output']
        diff = (actual_output - ref_output).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()

        tol = 1e-2
        passed = max_diff < tol
        all_passed_merge = all_passed_merge and passed

        status = "PASS" if passed else "FAIL"
        results_merge.append((layer_id, chunk_idx, passed, max_diff, mean_diff))

        if verbose:
            print(f"[{status}] Layer {layer_id:2d}, Chunk {chunk_idx}: "
                  f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")

        if not passed and first_fail_debug:
            first_fail_debug = False
            print(f"  Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
            print(f"  Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}")
            max_idx = diff.argmax()
            flat_actual = actual_output.flatten()
            flat_ref = ref_output.flatten()
            print(f"  Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")

    print()

    # ================================================================
    # Test 2: Verify against standard flash attention (single pass)
    # ================================================================
    if verbose:
        print("--- Test 2: Standard FlashAttn Reference (verifies correctness vs non-chunked) ---")

    all_passed_standard = True
    results_standard = []
    first_fail_debug = True

    for c in captures:
        layer_id = c['layer_id']
        chunk_idx = c['chunk_idx']

        if chunk_idx == 0:
            continue

        debug_this = (chunk_idx >= 2 and layer_id == 0 and first_fail_debug)
        std_ref_output = compute_standard_reference(layer_id, chunk_idx, scale, debug=debug_this)
        if std_ref_output is None:
            continue

        actual_output = c['output']
        diff = (actual_output - std_ref_output).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()

        tol = 1e-2
        passed = max_diff < tol
        all_passed_standard = all_passed_standard and passed

        status = "PASS" if passed else "FAIL"
        results_standard.append((layer_id, chunk_idx, passed, max_diff, mean_diff))

        if verbose:
            print(f"[{status}] Layer {layer_id:2d}, Chunk {chunk_idx}: "
                  f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")

        if not passed and first_fail_debug:
            first_fail_debug = False
            print(f"  Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
            print(f"  Debug: std_ref_output shape={std_ref_output.shape}, mean={std_ref_output.mean().item():.4f}")
            max_idx = diff.argmax()
            flat_actual = actual_output.flatten()
            flat_ref = std_ref_output.flatten()
            print(f"  Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")

    print()
    print("=" * 70)

    # Summary
    total_merge = len(results_merge)
|
||||||
|
passed_merge = sum(1 for r in results_merge if r[2])
|
||||||
|
total_standard = len(results_standard)
|
||||||
|
passed_standard = sum(1 for r in results_standard if r[2])
|
||||||
|
|
||||||
|
print(f"Merge-based reference: {passed_merge}/{total_merge} tests passed")
|
||||||
|
print(f"Standard FlashAttn ref: {passed_standard}/{total_standard} tests passed")
|
||||||
|
|
||||||
|
all_passed = all_passed_merge and all_passed_standard
|
||||||
|
|
||||||
|
if not all_passed_merge:
|
||||||
|
print("\nFailed merge-based tests:")
|
||||||
|
for layer_id, chunk_idx, passed, max_diff, mean_diff in results_merge:
|
||||||
|
if not passed:
|
||||||
|
print(f" - Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
|
||||||
|
|
||||||
|
if not all_passed_standard:
|
||||||
|
print("\nFailed standard FlashAttn tests:")
|
||||||
|
for layer_id, chunk_idx, passed, max_diff, mean_diff in results_standard:
|
||||||
|
if not passed:
|
||||||
|
print(f" - Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
return all_passed
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Main
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
passed = run_test(verbose=True)
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
print("test_chunked_prefill_hook: PASSED")
|
||||||
|
else:
|
||||||
|
print("test_chunked_prefill_hook: FAILED")
|
||||||
|
exit(1)
|
||||||
276 tests/test_flash_attn_kvcache.py Normal file
@@ -0,0 +1,276 @@
"""
Test script for flash_attn_with_kvcache based chunked prefill.

Verifies that chunked prefill produces identical results to full attention.
"""

import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache


def chunk_prefill(q_full, k_full, v_full, k_cache, v_cache, cache_seqlens, chunk_size):
    """
    Chunked prefill using flash_attn_with_kvcache.

    Args:
        q_full, k_full, v_full: [batch, total_seq_len, heads, head_dim]
        k_cache, v_cache: [batch, max_seq_len, kv_heads, head_dim]
        cache_seqlens: [batch] - current cache lengths
        chunk_size: size of each chunk

    Returns:
        output: [batch, total_seq_len, heads, head_dim]
    """
    total_len = q_full.shape[1]
    outputs = []

    for start in range(0, total_len, chunk_size):
        end = min(start + chunk_size, total_len)

        q_chunk = q_full[:, start:end]
        k_chunk = k_full[:, start:end]
        v_chunk = v_full[:, start:end]

        out = flash_attn_with_kvcache(
            q_chunk,
            k_cache,
            v_cache,
            k=k_chunk,
            v=v_chunk,
            cache_seqlens=cache_seqlens,
            causal=True,
        )
        outputs.append(out)

        cache_seqlens += (end - start)

    return torch.cat(outputs, dim=1)


def reference_attention(q, k, v):
    """Standard flash attention as reference."""
    return flash_attn_func(q, k, v, causal=True)


def test_chunked_prefill_correctness():
    """Test that chunked prefill matches full attention."""

    batch_size = 1
    num_heads = 32
    num_kv_heads = 8  # GQA
    head_dim = 128
    max_seq_len = 131072  # 128K

    test_configs = [
        (1024, 256),     # 1K tokens, 256 chunk
        (2048, 512),     # 2K tokens, 512 chunk
        (4096, 1024),    # 4K tokens, 1K chunk
        (4096, 2048),    # 4K tokens, 2K chunk (2 chunks)
        (8192, 2048),    # 8K tokens, 2K chunk (4 chunks)
        (16384, 4096),   # 16K tokens, 4K chunk
        (32768, 4096),   # 32K tokens, 4K chunk
        (65536, 8192),   # 64K tokens, 8K chunk
        (131072, 8192),  # 128K tokens, 8K chunk (16 chunks)
    ]

    for seq_len, chunk_size in test_configs:
        print(f"\nTesting seq_len={seq_len}, chunk_size={chunk_size}...")

        # Generate random input
        torch.manual_seed(42)
        q = torch.randn(batch_size, seq_len, num_heads, head_dim,
                        dtype=torch.float16, device='cuda')
        k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
                        dtype=torch.float16, device='cuda')
        v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
                        dtype=torch.float16, device='cuda')

        # Expand K/V for non-GQA reference
        k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
        v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)

        # Reference: full attention
        ref_out = reference_attention(q, k_expanded, v_expanded)

        # Chunked prefill with KV cache
        k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
                              dtype=torch.float16, device='cuda')
        v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
                              dtype=torch.float16, device='cuda')
        cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')

        chunked_out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)

        # Compare
        max_diff = (ref_out - chunked_out).abs().max().item()
        mean_diff = (ref_out - chunked_out).abs().mean().item()

        # Verify cache was filled correctly
        assert cache_seqlens[0].item() == seq_len, f"Cache seqlen mismatch: {cache_seqlens[0].item()} != {seq_len}"

        # Check K/V cache content
        k_cache_diff = (k_cache[:, :seq_len] - k).abs().max().item()
        v_cache_diff = (v_cache[:, :seq_len] - v).abs().max().item()

        print(f" Output max_diff: {max_diff:.6f}, mean_diff: {mean_diff:.6f}")
        print(f" KV cache diff: k={k_cache_diff:.6f}, v={v_cache_diff:.6f}")

        # Tolerance for fp16
        tolerance = 1e-2
        if max_diff < tolerance:
            print(f" PASSED")
        else:
            print(f" FAILED (max_diff {max_diff:.6f} >= {tolerance})")
            return False

    return True


def test_incremental_decode():
    """Test that decode after chunked prefill works correctly."""

    batch_size = 1
    num_heads = 32
    num_kv_heads = 8
    head_dim = 128
    max_seq_len = 8192

    prefill_len = 2048
    chunk_size = 512
    num_decode_steps = 10

    print(f"\nTesting incremental decode after chunked prefill...")
    print(f" Prefill: {prefill_len} tokens, chunk_size={chunk_size}")
    print(f" Decode: {num_decode_steps} steps")

    torch.manual_seed(42)

    # Prefill phase
    q_prefill = torch.randn(batch_size, prefill_len, num_heads, head_dim,
                            dtype=torch.float16, device='cuda')
    k_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
                            dtype=torch.float16, device='cuda')
    v_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
                            dtype=torch.float16, device='cuda')

    k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
                          dtype=torch.float16, device='cuda')
    v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
                          dtype=torch.float16, device='cuda')
    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')

    # Run chunked prefill
    prefill_out = chunk_prefill(q_prefill, k_prefill, v_prefill,
                                k_cache, v_cache, cache_seqlens, chunk_size)

    print(f" After prefill: cache_seqlens={cache_seqlens[0].item()}")

    # Decode phase - one token at a time
    for step in range(num_decode_steps):
        q_decode = torch.randn(batch_size, 1, num_heads, head_dim,
                               dtype=torch.float16, device='cuda')
        k_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
                               dtype=torch.float16, device='cuda')
        v_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
                               dtype=torch.float16, device='cuda')

        decode_out = flash_attn_with_kvcache(
            q_decode,
            k_cache,
            v_cache,
            k=k_decode,
            v=v_decode,
            cache_seqlens=cache_seqlens,
            causal=True,
        )

        cache_seqlens += 1

        assert decode_out.shape == (batch_size, 1, num_heads, head_dim)

    expected_len = prefill_len + num_decode_steps
    actual_len = cache_seqlens[0].item()

    print(f" After decode: cache_seqlens={actual_len}")

    if actual_len == expected_len:
        print(f" PASSED")
        return True
    else:
        print(f" FAILED: expected {expected_len}, got {actual_len}")
        return False


def test_batch_processing():
    """Test chunked prefill with batch > 1."""

    batch_size = 4
    num_heads = 32
    num_kv_heads = 8
    head_dim = 128
    max_seq_len = 4096
    seq_len = 2048
    chunk_size = 512

    print(f"\nTesting batch processing (batch_size={batch_size})...")

    torch.manual_seed(42)

    q = torch.randn(batch_size, seq_len, num_heads, head_dim,
                    dtype=torch.float16, device='cuda')
    k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
                    dtype=torch.float16, device='cuda')
    v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
                    dtype=torch.float16, device='cuda')

    k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
                          dtype=torch.float16, device='cuda')
    v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
                          dtype=torch.float16, device='cuda')
    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')

    out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)

    # Verify all batches have correct cache length
    assert (cache_seqlens == seq_len).all(), f"Cache seqlens mismatch: {cache_seqlens}"
    assert out.shape == (batch_size, seq_len, num_heads, head_dim)

    # Compare with reference for each batch item
    k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
    v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
    ref_out = reference_attention(q, k_expanded, v_expanded)

    max_diff = (ref_out - out).abs().max().item()

    print(f" Output shape: {out.shape}")
    print(f" Max diff vs reference: {max_diff:.6f}")

    if max_diff < 1e-2:
        print(f" PASSED")
        return True
    else:
        print(f" FAILED")
        return False


# ============================================================
# Main Test Script
# ============================================================

if __name__ == "__main__":
    print("=" * 60)
    print("Testing flash_attn_with_kvcache chunked prefill")
    print("=" * 60)

    all_passed = True

    all_passed &= test_chunked_prefill_correctness()
    all_passed &= test_incremental_decode()
    all_passed &= test_batch_processing()

    print("\n" + "=" * 60)
    if all_passed:
        print("test_flash_attn_kvcache: ALL TESTS PASSED")
    else:
        print("test_flash_attn_kvcache: SOME TESTS FAILED")
    print("=" * 60)
322 tests/test_needle.py Normal file
@@ -0,0 +1,322 @@
"""
Needle-in-a-haystack test for LLM.

Tests: Long context retrieval capability with configurable sequence length.

NOTE: CPU offload mode has a known bug that causes incorrect outputs for
sequences longer than ~200 tokens. Leave --enable-offload unset (the default)
for correctness testing.
"""

import os
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"

import argparse
from nanovllm import LLM, SamplingParams


# ============================================================
# Needle Test Generator
# ============================================================

def generate_needle_prompt(
    tokenizer,
    target_length: int,
    needle_position: float = 0.5,
    needle_value: str = "7492",
    use_chat_template: bool = True,
) -> tuple[str, str]:
    """
    Generate a needle-in-haystack prompt of approximately target_length tokens.

    Args:
        tokenizer: HuggingFace tokenizer for length estimation
        target_length: Target total sequence length in tokens
        needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
        needle_value: The secret value to hide in the haystack
        use_chat_template: Whether to use chat template for instruct models

    Returns:
        (prompt, expected_answer): The full prompt and the expected needle value
    """
    # Haystack filler paragraphs (various topics to create realistic context)
    haystack_paragraphs = [
        "The weather today is quite pleasant with clear skies and moderate temperatures. "
        "Many people are enjoying outdoor activities in the park. "
        "Birds are singing in the trees and children are playing on the swings. ",

        "In the world of technology, new innovations continue to emerge every day. "
        "Researchers are working on advanced algorithms and computing systems. "
        "The future of artificial intelligence looks promising with many breakthroughs. ",

        "The history of human civilization spans thousands of years. "
        "Ancient cultures developed writing, mathematics, and astronomy. "
        "Trade routes connected distant lands and facilitated cultural exchange. ",

        "Modern cooking combines traditional techniques with new ingredients. "
        "Chefs around the world experiment with flavors and presentations. "
        "Food brings people together and creates memorable experiences. ",

        "The ocean covers more than seventy percent of Earth's surface. "
        "Marine ecosystems support an incredible diversity of life forms. "
        "Scientists continue to discover new species in the deep sea. ",

        "Music has been a part of human culture since prehistoric times. "
        "Different genres evolved across various regions and time periods. "
        "Today, people can access millions of songs through digital platforms. ",

        "Space exploration has revealed many secrets about our universe. "
        "Telescopes can observe galaxies billions of light years away. "
        "Future missions aim to establish human presence on other planets. ",

        "The study of languages reveals patterns in human cognition. "
        "Linguists analyze grammar, semantics, and phonetics across cultures. "
        "Language continues to evolve with new words and expressions. ",
    ]

    # The needle sentence
    needle = f"The secret number you need to remember is {needle_value}. This is very important. "

    # Question at the end
    question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"

    # Estimate tokens for fixed parts
    needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
    question_text = "What is the secret number mentioned in the text above? Answer with just the number."
    question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
    # Buffer for chat template, special tokens, etc.
    overhead_tokens = 100 if use_chat_template else 50

    # Available tokens for haystack
    haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
    if haystack_target_tokens < 100:
        raise ValueError(f"target_length {target_length} is too short for needle test")

    # Build haystack by repeating paragraphs
    haystack_parts = []
    current_tokens = 0
    para_idx = 0

    while current_tokens < haystack_target_tokens:
        para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
        para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
        if current_tokens + para_tokens > haystack_target_tokens:
            break
        haystack_parts.append(para)
        current_tokens += para_tokens
        para_idx += 1

    # Calculate needle insertion point
    needle_idx = int(len(haystack_parts) * needle_position)
    needle_idx = max(0, min(needle_idx, len(haystack_parts)))

    # Insert needle
    haystack_parts.insert(needle_idx, needle)

    # Assemble prompt
    full_text = "".join(haystack_parts)

    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
        # Use chat template for instruct models
        # For Qwen3, add /no_think to disable thinking mode
        question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
        messages = [
            {"role": "user", "content": f"{full_text}\n\n{question_text}"}
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    else:
        # Raw text format for base models
        question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
        prompt = full_text + question

    # Verify length
    actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
    print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
    print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
    print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")

    return prompt, needle_value


def check_needle_answer(output_text: str, expected: str) -> bool:
    """Check if the model output contains the expected needle value."""
    import re
    # Clean output - remove special tokens and whitespace
    output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
    output_clean = ' '.join(output_clean.split()).lower()
    expected_clean = expected.strip().lower()

    # Check if expected value appears in output
    # Also try to find it as a standalone number
    if expected_clean in output_clean:
        return True

    # Try to extract numbers and check if expected is among them
    numbers = re.findall(r'\d+', output_clean)
    return expected_clean in numbers


# ============================================================
# Main Test
# ============================================================

def run_needle_test(
    model_path: str,
    max_model_len: int,
    input_len: int,
    num_gpu_blocks: int = 4,
    needle_position: float = 0.5,
    needle_value: str = "7492",
    max_new_tokens: int = 32,
    enable_cpu_offload: bool = False,
    verbose: bool = True,
) -> bool:
    """
    Run a needle-in-haystack test.

    Args:
        model_path: Path to model
        max_model_len: Maximum model context length
        input_len: Target input sequence length
        num_gpu_blocks: Number of GPU blocks for offload
        needle_position: Where to place needle (0.0-1.0)
        needle_value: The secret value to find
        max_new_tokens: Maximum tokens to generate
        enable_cpu_offload: Enable CPU offload mode
        verbose: Print detailed output

    Returns:
        True if test passed, False otherwise
    """
    if verbose:
        print(f"\n{'='*60}")
        print(f"Needle-in-Haystack Test")
        print(f"{'='*60}")
        print(f"Model: {model_path}")
        print(f"Max model len: {max_model_len}")
        print(f"Input length: {input_len}")
        print(f"Needle position: {needle_position:.0%}")
        print(f"Needle value: {needle_value}")
        print(f"CPU offload: {enable_cpu_offload}")
        print(f"{'='*60}\n")

    # 1. Initialize LLM
    llm_kwargs = {
        "enforce_eager": True,
        "max_model_len": max_model_len,
        "max_num_batched_tokens": max_model_len,
        "enable_cpu_offload": enable_cpu_offload,
    }
    if enable_cpu_offload:
        llm_kwargs["num_gpu_blocks"] = num_gpu_blocks

    llm = LLM(model_path, **llm_kwargs)

    # 2. Generate needle prompt
    prompt, expected = generate_needle_prompt(
        tokenizer=llm.tokenizer,
        target_length=input_len,
        needle_position=needle_position,
        needle_value=needle_value,
    )

    # 3. Generate output
    sampling_params = SamplingParams(
        temperature=0.6,  # Moderate temperature
        max_tokens=max_new_tokens,
    )
    outputs = llm.generate([prompt], sampling_params, use_tqdm=True)

    # 4. Check result
    output_text = outputs[0]["text"]
    output_token_ids = outputs[0]["token_ids"]
    passed = check_needle_answer(output_text, expected)

    if verbose:
        print(f"\n{'='*60}")
        print(f"Result")
        print(f"{'='*60}")
        print(f"Expected: {expected}")
        print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
        print(f"Output: {output_text[:200]}...")
        print(f"Status: {'PASSED' if passed else 'FAILED'}")
        print(f"{'='*60}\n")

    return passed


# ============================================================
# CLI Entry Point
# ============================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Needle-in-haystack test for long context LLM")
    parser.add_argument(
        "--model", "-m",
        type=str,
        default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
        help="Path to model"
    )
    parser.add_argument(
        "--max-model-len",
        type=int,
        default=32 * 1024,
        help="Maximum model context length"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=8 * 1024,
        help="Target input sequence length"
    )
    parser.add_argument(
        "--num-gpu-blocks",
        type=int,
        default=2,
        help="Number of GPU blocks for CPU offload"
    )
    parser.add_argument(
        "--needle-position",
        type=float,
        default=0.5,
        help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
    )
    parser.add_argument(
        "--needle-value",
        type=str,
        default="7492",
        help="The secret value to hide"
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=32,
        help="Maximum tokens to generate"
    )
    parser.add_argument(
        "--enable-offload",
        action="store_true",
        help="Enable CPU offload (has known bug for long sequences)"
    )
    args = parser.parse_args()

    passed = run_needle_test(
        model_path=args.model,
        max_model_len=args.max_model_len,
        input_len=args.input_len,
        num_gpu_blocks=args.num_gpu_blocks,
        needle_position=args.needle_position,
        needle_value=args.needle_value,
        max_new_tokens=args.max_new_tokens,
        enable_cpu_offload=args.enable_offload,
        verbose=True,
    )

    if passed:
        print("test_needle: PASSED")
    else:
        print("test_needle: FAILED")
        exit(1)
573 tests/test_offload_correctness.py Normal file
@@ -0,0 +1,573 @@
"""
Correctness test for chunked attention with CPU offload.

Validates that the offload pipeline (CPU -> GPU transfer + chunked attention)
produces the same result as direct GPU computation.

Test scenario:
1. Generate Q, K, V data
2. Reference: Compute full causal attention on GPU
3. Offload: Store K, V in CPU cache, load via pipeline, compute chunked attention
4. Compare results

This test is designed to identify bugs in:
- CPU <-> GPU data transfer (sgDMA)
- Ring buffer slot management
- N-way pipeline ordering
- Triton merge kernel correctness
"""

import torch
from flash_attn.flash_attn_interface import flash_attn_func
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs


# ============================================================
# Configuration
# ============================================================

NUM_LAYERS = 4
NUM_HEADS = 8
NUM_KV_HEADS = 8
HEAD_DIM = 64
BLOCK_SIZE = 256  # Smaller for faster testing
DTYPE = torch.bfloat16
DEVICE = "cuda"


# ============================================================
# Reference Implementation (GPU only, no offload)
# ============================================================

def compute_reference_causal(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Compute reference causal attention using flash_attn_func.

    Args:
        q, k, v: [batch, seqlen, nheads, headdim]

    Returns:
        out: [batch, seqlen, nheads, headdim]
    """
    return flash_attn_func(q, k, v, causal=True)


def compute_reference_chunked(
    q_chunks: list,
    kv_chunks: list,
    scale: float,
) -> torch.Tensor:
    """
    Compute chunked prefill attention directly on GPU (no offload).

    This is the "gold standard" for chunked attention correctness.

    Args:
        q_chunks: List of [batch, chunk_size, nheads, headdim]
        kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim]
        scale: Softmax scale

    Returns:
        out: [batch, total_seqlen, nheads, headdim]
    """
    out_chunks = []

    for chunk_idx, q_chunk in enumerate(q_chunks):
        o_acc, lse_acc = None, None

        # Attend to all previous chunks (no causal mask)
        for i in range(chunk_idx):
            k_chunk, v_chunk = kv_chunks[i]
            chunk_o, chunk_lse = flash_attn_with_lse(
                q_chunk, k_chunk, v_chunk,
                softmax_scale=scale,
                causal=False,
            )
            if o_acc is None:
                o_acc, lse_acc = chunk_o, chunk_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)

        # Attend to current chunk (with causal mask)
        k_chunk, v_chunk = kv_chunks[chunk_idx]
        current_o, current_lse = flash_attn_with_lse(
            q_chunk, k_chunk, v_chunk,
            softmax_scale=scale,
            causal=True,
        )

        if o_acc is None:
            final_o = current_o
        else:
            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)

        out_chunks.append(final_o)

    return torch.cat(out_chunks, dim=1)


# ============================================================
# Offload Implementation
# ============================================================

def create_manager(num_gpu_slots: int, num_cpu_blocks: int):
    """Create HybridKVCacheManager with specified configuration."""
    manager = HybridKVCacheManager(
        num_gpu_slots=num_gpu_slots,
        num_cpu_blocks=num_cpu_blocks,
        block_size=BLOCK_SIZE,
    )
    manager.allocate_cache(
        num_layers=NUM_LAYERS,
        num_kv_heads=NUM_KV_HEADS,
        head_dim=HEAD_DIM,
        dtype=DTYPE,
    )
    return manager


def store_kv_to_cpu_cache(manager, kv_chunks: list, layer_id: int):
    """
    Store K, V chunks to CPU cache.

    Args:
        manager: HybridKVCacheManager
        kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim]
        layer_id: Layer index

    Returns:
        cpu_block_ids: List of CPU block IDs
    """
    offload_engine = manager.offload_engine
    cpu_block_ids = []

    for block_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
        # k_chunk, v_chunk: [batch, chunk_size, nheads, headdim]
        # CPU cache layout: [num_layers, num_blocks, block_size, nheads, headdim]
        k_data = k_chunk.squeeze(0)  # [chunk_size, nheads, headdim]
        v_data = v_chunk.squeeze(0)

        offload_engine.k_cache_cpu[layer_id, block_idx, :k_data.shape[0]].copy_(k_data)
        offload_engine.v_cache_cpu[layer_id, block_idx, :v_data.shape[0]].copy_(v_data)

        cpu_block_ids.append(block_idx)

    return cpu_block_ids


def compute_offload_chunked_single_layer(
    manager,
    q_chunks: list,
    cpu_block_ids: list,
    layer_id: int,
    scale: float,
) -> torch.Tensor:
    """
    Compute chunked attention for a single layer using offload pipeline.

    This mimics the behavior of Attention._ring_buffer_pipeline_load().

    Args:
        manager: HybridKVCacheManager
        q_chunks: List of [batch, chunk_size, nheads, headdim]
        cpu_block_ids: List of CPU block IDs containing K, V data
        layer_id: Layer index
        scale: Softmax scale

    Returns:
        out: [batch, total_seqlen, nheads, headdim]
    """
    offload_engine = manager.offload_engine
    out_chunks = []

    for chunk_idx, q_chunk in enumerate(q_chunks):
        # CPU blocks to load: all blocks before current chunk
        blocks_to_load = cpu_block_ids[:chunk_idx]

        # Get slots for this chunk
        write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx)
        load_slots = offload_engine.get_load_slots_for_prefill(write_slot)

        # Load and compute attention for previous chunks
        o_acc, lse_acc = None, None

        if len(blocks_to_load) > 0 and len(load_slots) > 0:
            o_acc, lse_acc = _pipeline_load_and_compute(
                offload_engine,
                q_chunk,
                blocks_to_load,
                load_slots,
                layer_id,
                scale,
            )

        # Current chunk's K, V (load from CPU to GPU slot)
        current_cpu_block = cpu_block_ids[chunk_idx]
        offload_engine.load_to_slot_layer(write_slot, layer_id, current_cpu_block)
        offload_engine.wait_slot_layer(write_slot, layer_id)

        current_k, current_v = offload_engine.get_kv_for_slot(write_slot, layer_id)

        # Compute attention with causal mask
        current_o, current_lse = flash_attn_with_lse(
            q_chunk, current_k, current_v,
            softmax_scale=scale,
            causal=True,
        )

        # Merge
        if o_acc is None:
            final_o = current_o
        else:
            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)

        out_chunks.append(final_o)

    return torch.cat(out_chunks, dim=1)


def _pipeline_load_and_compute(
    offload_engine,
    q_chunk: torch.Tensor,
    cpu_block_table: list,
    load_slots: list,
    layer_id: int,
    scale: float,
):
    """
    Pipeline loading from CPU and computing attention.

    Mirrors Attention._ring_buffer_pipeline_load() logic.
    """
    num_blocks = len(cpu_block_table)
    num_slots = len(load_slots)

    o_acc, lse_acc = None, None

    # Phase 1: Pre-load up to num_slots blocks
    num_preload = min(num_slots, num_blocks)
    for i in range(num_preload):
        offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])

    # Phase 2: Main loop
    compute_stream = offload_engine.compute_stream

    for block_idx in range(num_blocks):
        current_slot = load_slots[block_idx % num_slots]

        # Wait for transfer
        offload_engine.wait_slot_layer(current_slot, layer_id)

        # Compute on dedicated stream
        with torch.cuda.stream(compute_stream):
            prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, layer_id)
            prev_o, prev_lse = flash_attn_with_lse(
                q_chunk, prev_k, prev_v,
                softmax_scale=scale,
                causal=False,
            )
            offload_engine.record_slot_compute_done(current_slot, layer_id)

        # Start next transfer
        next_block_idx = block_idx + num_slots
        if next_block_idx < num_blocks:
            offload_engine.load_to_slot_layer(
                current_slot, layer_id, cpu_block_table[next_block_idx]
            )

        # Merge
        with torch.cuda.stream(compute_stream):
            if o_acc is None:
                o_acc, lse_acc = prev_o, prev_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

    # Sync compute stream
    compute_stream.synchronize()

    return o_acc, lse_acc


# ============================================================
# Test Runner
# ============================================================

def run_correctness_test(
    num_chunks: int,
    num_gpu_slots: int,
    verbose: bool = True,
) -> tuple[bool, float, float]:
    """
    Run a single correctness test.

    Args:
        num_chunks: Number of chunks (= number of CPU blocks)
        num_gpu_slots: Number of GPU ring buffer slots
        verbose: Print detailed info

    Returns:
        (passed, max_diff, mean_diff)
    """
    torch.manual_seed(42)

    seqlen = num_chunks * BLOCK_SIZE
    scale = HEAD_DIM ** -0.5

    # Generate Q, K, V
    q_full = torch.randn(1, seqlen, NUM_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
    k_full = torch.randn(1, seqlen, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
    v_full = torch.randn(1, seqlen, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)

    # Split into chunks
    q_chunks = [q_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE] for i in range(num_chunks)]
    kv_chunks = [
        (k_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE],
         v_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE])
        for i in range(num_chunks)
    ]

    # Reference: chunked attention on GPU (no offload)
    out_ref = compute_reference_chunked(q_chunks, kv_chunks, scale)

    # Create manager with enough CPU blocks
    manager = create_manager(num_gpu_slots, num_chunks)

    # Test each layer
    all_passed = True
    max_diff_all = 0.0
    mean_diff_all = 0.0

    for layer_id in range(NUM_LAYERS):
        # Store K, V to CPU cache
        cpu_block_ids = store_kv_to_cpu_cache(manager, kv_chunks, layer_id)

        # Compute with offload
        out_offload = compute_offload_chunked_single_layer(
            manager, q_chunks, cpu_block_ids, layer_id, scale
        )

        # Compare
        diff = (out_ref - out_offload).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()

        max_diff_all = max(max_diff_all, max_diff)
        mean_diff_all = max(mean_diff_all, mean_diff)

        tol = 1e-2
        passed = max_diff < tol
        all_passed = all_passed and passed

        if verbose and not passed:
            print(f" Layer {layer_id}: FAIL max_diff={max_diff:.6f}")

    return all_passed, max_diff_all, mean_diff_all


# ============================================================
# Decode Phase Test
# ============================================================

def run_decode_correctness_test(
    num_prefill_chunks: int,
    num_gpu_slots: int,
    num_decode_steps: int = 4,
    verbose: bool = True,
) -> tuple[bool, float, float]:
    """
    Test decode phase correctness with CPU offload.

    Simulates:
    1. Prefill: Store K, V for multiple chunks in CPU cache
    2. Decode: Single token queries against all prefilled K, V

    This tests the scenario in needle test where decode reads all previous KV.
    """
    torch.manual_seed(42)

    scale = HEAD_DIM ** -0.5
    prefill_len = num_prefill_chunks * BLOCK_SIZE

    # Generate prefill K, V (store in CPU)
    k_prefill = torch.randn(1, prefill_len, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)
    v_prefill = torch.randn(1, prefill_len, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)

    # Split into chunks for CPU storage
    kv_chunks = [
        (k_prefill[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE],
         v_prefill[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE])
        for i in range(num_prefill_chunks)
    ]

    # Create manager
    manager = create_manager(num_gpu_slots, num_prefill_chunks)
    offload_engine = manager.offload_engine

    all_passed = True
    max_diff_all = 0.0
    mean_diff_all = 0.0

    for layer_id in range(NUM_LAYERS):
        # Store prefilled K, V to CPU cache
        cpu_block_ids = store_kv_to_cpu_cache(manager, kv_chunks, layer_id)

        for decode_step in range(num_decode_steps):
            # Decode query: single token
            q_decode = torch.randn(1, 1, NUM_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE)

            # Reference: direct attention on GPU
            # Concat all prefilled K, V and compute attention
            out_ref = flash_attn_func(
                q_decode,
                k_prefill,
                v_prefill,
                causal=False,  # Decode query can attend to all prefilled tokens
            )

            # Offload: load from CPU and compute
            load_slots = offload_engine.get_load_slots_for_prefill(0)  # Use all slots except decode slot

            if len(load_slots) == 0 or len(cpu_block_ids) == 0:
                # No previous chunks to load
                out_offload = out_ref  # Trivially equal
            else:
                o_acc, lse_acc = _pipeline_load_and_compute(
                    offload_engine,
                    q_decode,
                    cpu_block_ids,
                    load_slots,
                    layer_id,
                    scale,
                )
                out_offload = o_acc

            # Compare
            diff = (out_ref - out_offload).abs()
            max_diff = diff.max().item()
            mean_diff = diff.mean().item()

            max_diff_all = max(max_diff_all, max_diff)
            mean_diff_all = max(mean_diff_all, mean_diff)

            tol = 1e-2
            passed = max_diff < tol
            all_passed = all_passed and passed

            if verbose and not passed:
                print(f" Layer {layer_id} Step {decode_step}: FAIL max_diff={max_diff:.6f}")

    return all_passed, max_diff_all, mean_diff_all


# ============================================================
# Main Test Script
# ============================================================

if __name__ == "__main__":
    print("=" * 70)
    print("Test: Offload Chunked Attention Correctness")
    print("=" * 70)
    print(f"Config: layers={NUM_LAYERS}, heads={NUM_HEADS}, kv_heads={NUM_KV_HEADS}, "
          f"head_dim={HEAD_DIM}, block_size={BLOCK_SIZE}, dtype={DTYPE}")
    print()
    print("Comparing: Reference (GPU chunked) vs Offload (CPU->GPU pipeline)")
    print()

    # Test configurations: (num_chunks, num_gpu_slots)
    TEST_CASES = [
        # Basic tests
        (2, 2),  # Minimal: 2 chunks, 2 slots (no pipeline)
        (2, 3),  # 2 chunks, 3 slots (1-slot pipeline)
        (4, 2),  # 4 chunks, 2 slots (heavy slot reuse)
        (4, 3),  # 4 chunks, 3 slots
        (4, 4),  # 4 chunks, 4 slots
        # Stress tests
        (8, 3),  # Many chunks, few slots
        (8, 4),  # Many chunks, moderate slots
        (8, 6),  # Many chunks, many slots (like bench_offload)
        # Edge cases
        (1, 2),  # Single chunk
        (3, 5),  # Fewer chunks than slots
    ]

    all_passed = True
    results = []

    for num_chunks, num_gpu_slots in TEST_CASES:
        seqlen = num_chunks * BLOCK_SIZE
        passed, max_diff, mean_diff = run_correctness_test(
            num_chunks, num_gpu_slots, verbose=False
        )

        all_passed = all_passed and passed
        status = "PASS" if passed else "FAIL"

        results.append((num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff))

        print(f"[{status}] chunks={num_chunks:2d} slots={num_gpu_slots:2d} "
              f"seqlen={seqlen:5d} max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")

    print()

    # ================================================================
    # Part 2: Decode Phase Tests
    # ================================================================
    print("=" * 70)
    print("Part 2: Decode Phase Correctness")
    print("=" * 70)
    print("Testing: Decode query (single token) against all prefilled K, V")
    print()

    DECODE_TEST_CASES = [
        # (num_prefill_chunks, num_gpu_slots)
        (2, 2),
        (4, 3),
        (4, 4),
        (8, 4),
        (8, 6),
    ]

    decode_results = []

    for num_prefill_chunks, num_gpu_slots in DECODE_TEST_CASES:
        prefill_len = num_prefill_chunks * BLOCK_SIZE
        passed, max_diff, mean_diff = run_decode_correctness_test(
            num_prefill_chunks, num_gpu_slots, num_decode_steps=4, verbose=False
        )

        all_passed = all_passed and passed
        status = "PASS" if passed else "FAIL"

        decode_results.append((num_prefill_chunks, num_gpu_slots, prefill_len, passed, max_diff, mean_diff))

        print(f"[{status}] prefill_chunks={num_prefill_chunks:2d} slots={num_gpu_slots:2d} "
              f"prefill_len={prefill_len:5d} max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")

    print()
    print("=" * 70)

    # Summary
    prefill_passed = sum(1 for r in results if r[3])
    decode_passed = sum(1 for r in decode_results if r[3])
    total_tests = len(results) + len(decode_results)
    total_passed = prefill_passed + decode_passed

    print(f"Results: {total_passed}/{total_tests} tests passed")
    print(f" - Prefill: {prefill_passed}/{len(results)}")
    print(f" - Decode: {decode_passed}/{len(decode_results)}")

    if not all_passed:
        print("\nFailed tests:")
        for num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff in results:
            if not passed:
                print(f" - [Prefill] chunks={num_chunks}, slots={num_gpu_slots}, "
                      f"seqlen={seqlen}, max_diff={max_diff:.6f}")
        for num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff in decode_results:
            if not passed:
                print(f" - [Decode] prefill_chunks={num_chunks}, slots={num_gpu_slots}, "
                      f"prefill_len={seqlen}, max_diff={max_diff:.6f}")

    print()
    assert all_passed, "Some correctness tests failed!"
    print("test_offload_correctness: PASSED")