Compare commits


10 Commits

Author  SHA1  Message  Date
Zijie Tian  ff8b09cd35  [test] Added test_needle_ref.py.  2026-01-02 22:03:23 +08:00
Zijie Tian  74ee6d0895  [WIP] need to fix model to normally decode.  2026-01-01 05:18:27 +08:00
Zijie Tian  62b8a63314  [refactor] Refactor the test_chunked_prefill/decode.  2026-01-01 03:32:26 +08:00
Zijie Tian  965c8aff12  [WIP] need change flashattention to debug.  2026-01-01 00:58:22 +08:00
Zijie Tian  30462fe89a  [WIP] Before fix needle.  2025-12-31 23:35:25 +08:00
Zijie Tian  ccd1b3d4ab  [WIP] Before modify nanovllm CPU-GPU kvcache.  2025-12-31 22:41:07 +08:00
Zijie Tian  31e90a7268  [test] Added offload correct verify.  2025-12-31 20:59:53 +08:00
Zijie Tian  484d0de9f9  [feat] Added debug hook to offload_engine.py.  2025-12-31 19:44:39 +08:00
Zijie Tian  7af721c12c  [WIP] Before modify to FlashInfer.  2025-12-30 01:11:13 +08:00
Zijie Tian  89f8020d38  [WIP] fixing attention compute error.  2025-12-30 00:31:48 +08:00
20 changed files with 2986 additions and 896 deletions

.gitignore (vendored) — 1 line changed
View File

@@ -195,3 +195,4 @@ cython_debug/
.cursorindexingignore
results/
outputs/

View File

@@ -1,6 +1,7 @@
import os
from dataclasses import dataclass
from transformers import AutoConfig
import torch
@dataclass
@@ -16,6 +17,7 @@ class Config:
eos: int = -1
kvcache_block_size: int = 4096
num_kvcache_blocks: int = -1
dtype: str | None = None # "float16", "bfloat16", or None (use model default)
# CPU Offload configuration
enable_cpu_offload: bool = False
@@ -41,3 +43,17 @@ class Config:
self.hf_config = AutoConfig.from_pretrained(self.model)
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
assert self.max_num_batched_tokens >= self.max_model_len
# Override torch_dtype if user specified
if self.dtype is not None:
dtype_map = {
"float16": torch.float16,
"fp16": torch.float16,
"bfloat16": torch.bfloat16,
"bf16": torch.bfloat16,
"float32": torch.float32,
"fp32": torch.float32,
}
if self.dtype not in dtype_map:
raise ValueError(f"Invalid dtype: {self.dtype}. Choose from: {list(dtype_map.keys())}")
self.hf_config.torch_dtype = dtype_map[self.dtype]
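
Editor's note: as a usage sketch, the dtype-string resolution added above can be exercised on its own; the standalone helper below is illustrative and not part of the diff.

import torch

# Hypothetical helper mirroring the Config.__post_init__ mapping above.
_DTYPE_MAP = {
    "float16": torch.float16, "fp16": torch.float16,
    "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
    "float32": torch.float32, "fp32": torch.float32,
}

def resolve_dtype(name: str | None) -> torch.dtype | None:
    """Return the torch dtype for a user-supplied string, or None for the model default."""
    if name is None:
        return None
    if name not in _DTYPE_MAP:
        raise ValueError(f"Invalid dtype: {name}. Choose from: {list(_DTYPE_MAP)}")
    return _DTYPE_MAP[name]

# e.g. resolve_dtype("bf16") -> torch.bfloat16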

View File

@@ -31,6 +31,8 @@ class LLMEngine:
self.model_runner = ModelRunner(config, 0, self.events)
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
config.eos = self.tokenizer.eos_token_id
# Set Sequence.block_size to match the KV cache block size
Sequence.block_size = config.kvcache_block_size
self.scheduler = Scheduler(config, self.model_runner.kvcache_manager)
atexit.register(self.exit)
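
Editor's note: a small illustration of why the two block sizes must agree, assuming Sequence sizes its block table from this class attribute (the helper name is illustrative).

def num_blocks_needed(num_tokens: int, block_size: int) -> int:
    # A sequence occupies ceil(num_tokens / block_size) logical blocks,
    # so Sequence.block_size must match the KV cache block size.
    return (num_tokens + block_size - 1) // block_size

# With kvcache_block_size = 4096 (the default above), a 16K-token prompt
# needs num_blocks_needed(16 * 1024, 4096) == 4 logical blocks.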

View File

@@ -489,24 +489,15 @@ class ModelRunner:
logical_id = seq.block_table[block_idx]
self.kvcache_manager.prefilled_blocks.add(logical_id)
# Offload this chunk's ring buffer slot to CPU (async)
# NOTE: Per-layer offloading is now done in attention.forward
# Each layer offloads its KV to CPU immediately after computing attention.
# We just need to wait for the last offload to complete before reusing the slot.
if block_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[block_idx]
# Call sparse policy hook before offload (to capture metadata)
sparse_policy = self.kvcache_manager.sparse_policy
if sparse_policy is not None:
num_tokens = chunk_end - chunk_start
for layer_id in range(offload_engine.num_layers):
k_cache = offload_engine.k_cache_gpu[layer_id, write_slot, :num_tokens]
sparse_policy.on_block_offloaded(
cpu_block_id=cpu_block_id,
layer_id=layer_id,
k_cache=k_cache,
num_valid_tokens=num_tokens,
)
offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)
# TODO: Sparse policy hook needs update for new GPU cache architecture
# The GPU cache no longer has layer dimension, so we can't access
# k_cache_gpu[layer_id, write_slot]. Sparse policy should be called
# in attention.forward after per-layer offload.
pass
# Wait for offload to complete before next chunk
# (slot will be reused after N chunks)
@@ -521,6 +512,7 @@ class ModelRunner:
print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)
# Sample from last logits
# For chunked prefill, ParallelLMHead automatically selects last position's logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
if logits is not None:
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
@@ -627,7 +619,11 @@ class ModelRunner:
if pos_in_block == self.block_size - 1:
last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
if last_cpu_block >= 0:
offload_engine.offload_decode_slot(last_cpu_block)
# TODO: In new GPU cache architecture (no layer dimension),
# decode offload should be done per-layer in attention.forward.
# For now, offload all layers sequentially.
for layer_id in range(offload_engine.num_layers):
offload_engine.offload_decode_slot_layer(layer_id, last_cpu_block)
offload_engine.wait_all_offload_done()
# Reset decode start position for next block
self.kvcache_manager.reset_decode_start_pos(seq)
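
Editor's note: a hedged sketch of the block-boundary path above folded into one helper; the helper name and flat argument list are illustrative, while offload_decode_slot_layer and wait_all_offload_done are taken from the diff.

def offload_filled_decode_block(offload_engine, block_size: int,
                                pos_in_block: int, last_cpu_block: int) -> None:
    # Once the decode slot holds a full block, push every layer's copy of that
    # block out to its CPU block, then wait so the slot can be reused.
    if pos_in_block == block_size - 1 and last_cpu_block >= 0:
        for layer_id in range(offload_engine.num_layers):
            offload_engine.offload_decode_slot_layer(layer_id, last_cpu_block)
        offload_engine.wait_all_offload_done()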

View File

@@ -281,7 +281,11 @@ def _merge_lse_kernel(
num_elements: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging LSE values."""
"""Fused kernel for merging LSE values.
IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
"""
# Each program handles BLOCK_SIZE elements
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
@@ -289,21 +293,21 @@ def _merge_lse_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < num_elements
# Load lse values
lse1 = tl.load(lse1_ptr + offsets, mask=mask)
lse2 = tl.load(lse2_ptr + offsets, mask=mask)
# Load lse values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)
# Compute max for numerical stability
# Compute max for numerical stability (in fp32)
max_lse = tl.maximum(lse1, lse2)
# Compute exp(lse - max_lse)
# Compute exp(lse - max_lse) in fp32
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
# Compute merged LSE: max_lse + log(exp1 + exp2)
# Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
lse_merged = max_lse + tl.log(exp1 + exp2)
# Store result
# Store result (convert back to original dtype)
tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
@@ -313,7 +317,11 @@ def _merge_output_kernel(
batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging attention outputs."""
"""Fused kernel for merging attention outputs.
IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
This is critical for numerical accuracy in chunked attention.
"""
# Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
pid_batch = tl.program_id(0)
pid_seq = tl.program_id(1)
@@ -322,11 +330,11 @@ def _merge_output_kernel(
# Compute LSE index: [batch, nheads, seqlen_q]
lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq
# Load LSE values
lse1 = tl.load(lse1_ptr + lse_idx)
lse2 = tl.load(lse2_ptr + lse_idx)
# Load LSE values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)
# Compute max and scaling factors
# Compute max and scaling factors in fp32
max_lse = tl.maximum(lse1, lse2)
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
@@ -343,14 +351,14 @@ def _merge_output_kernel(
pid_head * headdim)
o_idx = base_idx + d_idx
# Load o1, o2
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0)
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0)
# Load o1, o2 and convert to fp32 for weighted sum
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
# Compute merged output: (o1 * exp1 + o2 * exp2) / sum_exp
# Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
# Store result
# Store result (Triton will convert back to original dtype)
tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
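
Editor's note: for checking numerics, a plain-PyTorch reference of the fp32 merge these two kernels implement. Shapes are inferred from the kernel indexing (outputs [batch, seqlen_q, nheads, headdim], LSE [batch, nheads, seqlen_q]); this is a sketch, not the actual chunked_attention module.

import torch

def merge_attention_outputs_ref(o1, lse1, o2, lse2):
    """Merge two partial attention results via log-sum-exp, computed in fp32."""
    lse1f, lse2f = lse1.float(), lse2.float()
    max_lse = torch.maximum(lse1f, lse2f)
    exp1 = torch.exp(lse1f - max_lse)              # [b, h, s]
    exp2 = torch.exp(lse2f - max_lse)
    lse_merged = max_lse + torch.log(exp1 + exp2)
    # Broadcast the weights over headdim: [b, h, s] -> [b, s, h, 1]
    w1 = exp1.permute(0, 2, 1).unsqueeze(-1)
    w2 = exp2.permute(0, 2, 1).unsqueeze(-1)
    o_merged = (o1.float() * w1 + o2.float() * w2) / (w1 + w2)
    return o_merged.to(o1.dtype), lse_merged.to(lse1.dtype)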

View File

@@ -69,15 +69,19 @@ class HybridKVCacheManager(KVCacheManager):
Architecture (CPU-primary mode):
- CPU pool: Primary storage for all KV cache (num_cpu_blocks)
- GPU buffer: Ring buffer for computation (num_gpu_slots)
- Logical blocks: What sequences reference (num_gpu_slots + num_cpu_blocks)
- GPU buffer: Ring buffer for computation only (num_gpu_slots)
- Logical blocks: What sequences reference (num_cpu_blocks)
Design:
- All KV cache is stored on CPU as primary storage
- GPU is used as a ring buffer for computation only
- GPU is used as a ring buffer for computation only (no persistent data)
- During prefill: KV is written to GPU ring slot, then offloaded to CPU
- During decode: Previous KV is loaded from CPU to GPU for attention
- Ring buffer enables pipelined H2D transfers overlapped with computation
Note:
- Logical blocks map 1:1 with CPU blocks (total_blocks = num_cpu_blocks)
- GPU slots are transient compute buffers, not tracked in logical blocks
"""
def __init__(
@@ -102,20 +106,22 @@ class HybridKVCacheManager(KVCacheManager):
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
self.num_cpu_blocks = num_cpu_blocks
self.total_blocks = num_gpu_slots + num_cpu_blocks
# In CPU-primary mode, logical blocks map 1:1 with CPU blocks
# GPU slots are transient compute buffers, not tracked as logical blocks
self.total_blocks = num_cpu_blocks
# Eviction policy
self.policy = policy or LRUPolicy()
# Logical blocks (what sequences reference)
# Logical blocks (what sequences reference) - one per CPU block
self.logical_blocks: List[LogicalBlock] = [
LogicalBlock(i) for i in range(self.total_blocks)
]
self.free_logical_ids: deque[int] = deque(range(self.total_blocks))
# GPU slot management (slots are fixed, mapping is variable)
# GPU slot management (kept for potential future use, but not used in CPU-primary mode)
self.free_gpu_slots: deque[int] = deque(range(num_gpu_slots))
self.gpu_slot_to_logical: Dict[int, int] = {} # gpu_slot -> logical_id
self.gpu_slot_to_logical: Dict[int, int] = {} # gpu_slot -> logical_id (unused in CPU-primary mode)
# CPU block management
self.free_cpu_blocks: deque[int] = deque(range(num_cpu_blocks))
@@ -212,7 +218,9 @@ class HybridKVCacheManager(KVCacheManager):
block.ref_count -= 1
if block.ref_count == 0:
# Free physical block
# Free physical block based on location
# Note: In CPU-primary mode, blocks are always on CPU.
# GPU branch kept for potential future hybrid mode support.
if block.location == BlockLocation.GPU:
self.free_gpu_slots.append(block.gpu_slot)
del self.gpu_slot_to_logical[block.gpu_slot]
@@ -337,10 +345,10 @@ class HybridKVCacheManager(KVCacheManager):
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
cpu_blocks.append(block.cpu_block_id)
logger.debug(
f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
f"returned cpu_blocks={cpu_blocks}"
)
# logger.debug(
# f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
# f"returned cpu_blocks={cpu_blocks}"
# )
return cpu_blocks
# ========== Ring Buffer CPU-primary support ==========
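
Editor's note: a minimal sketch of the CPU-primary bookkeeping described in the docstring above. Field names follow the diff; the pool class and its allocate/free methods are illustrative.

from collections import deque

class CpuPrimaryBlockPool:
    """Illustrative: logical blocks map 1:1 onto CPU blocks; GPU slots are
    transient compute buffers and are not tracked here."""

    def __init__(self, num_cpu_blocks: int):
        self.total_blocks = num_cpu_blocks
        self.free_logical_ids = deque(range(num_cpu_blocks))
        self.free_cpu_blocks = deque(range(num_cpu_blocks))

    def allocate(self) -> tuple[int, int]:
        # One logical id and one CPU block per allocation.
        return self.free_logical_ids.popleft(), self.free_cpu_blocks.popleft()

    def free(self, logical_id: int, cpu_block_id: int) -> None:
        self.free_logical_ids.append(logical_id)
        self.free_cpu_blocks.append(cpu_block_id)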

View File

@@ -67,14 +67,19 @@ class OffloadEngine:
self.block_numel = block_size * self.kv_dim
# ========== sgDMA pitch parameters for strided transfers ==========
# CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
# GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim] (no layer dim)
# For CPU-to-GPU transfer (H2D): copy single layer, single block at a time
# For all-layer CPU operations (D2H offload to all layers): use sgDMA
self.dtype_size = dtype.itemsize
# CPU pitch: stride between layers in CPU cache (for all-layer operations)
self.cpu_pitch = num_cpu_blocks * self.block_numel * self.dtype_size
self.gpu_pitch = num_gpu_blocks * self.block_numel * self.dtype_size
self.width = self.block_numel * self.dtype_size
self.height = num_layers
# GPU has no layer dimension, so single block transfer is contiguous
self.gpu_block_bytes = self.block_numel * self.dtype_size
self.height = num_layers # For CPU all-layer operations
logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, gpu_pitch={self.gpu_pitch}, "
f"width={self.width}, height={self.height}")
logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, "
f"gpu_block_bytes={self.gpu_block_bytes}, height={self.height}")
# ========== Unified Ring Buffer configuration ==========
# Constraint checks
@@ -100,17 +105,37 @@ class OffloadEngine:
logger.info(f" Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading")
# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
# Use zeros initialization to avoid uninitialized memory issues
# Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
# NOTE: No num_layers dimension! GPU slots are shared across layers.
# Each layer reuses the same slots (layers execute sequentially).
# This saves 28x GPU memory compared to per-layer allocation.
self.k_cache_gpu = torch.zeros(
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.v_cache_gpu = torch.zeros(
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
# ========== Per-layer decode buffer ==========
# During decode, all layers share decode_slot (no layer dimension in GPU cache).
# This causes accumulated tokens to be overwritten by each layer.
# Solution: Maintain separate per-layer buffers for decode tokens.
# Shape: [num_layers, block_size, kv_heads, head_dim]
# Memory: num_layers * block_size * kv_heads * head_dim * dtype_size
# e.g., 28 * 1024 * 8 * 128 * 2 = 58.7 MB (acceptable)
self.decode_k_buffer = torch.zeros(
num_layers, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.decode_v_buffer = torch.zeros(
num_layers, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")
# ========== Fixed-address CPU KV cache (pinned memory) ==========
self.k_cache_cpu = torch.zeros(
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
@@ -159,40 +184,32 @@ class OffloadEngine:
# Decode offload event
self.decode_offload_done = torch.cuda.Event()
# ========== Per-slot Per-layer events for ring buffer ==========
# ring_slot_ready[slot_idx][layer_id] = CUDA Event for H2D completion
# ring_slot_offload_done[slot_idx][layer_id] = CUDA Event for D2H completion
self.ring_slot_ready = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
self.ring_slot_offload_done = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
# ========== Per-slot events for ring buffer ==========
# Since GPU cache has no layer dimension and layers execute sequentially,
# we only need per-slot events (not per-slot per-layer).
# ring_slot_ready[slot_idx] = CUDA Event for H2D completion
# ring_slot_offload_done[slot_idx] = CUDA Event for D2H completion
self.ring_slot_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
self.ring_slot_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
# Per-slot events for all-layer operations (used in some legacy paths)
self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
# ========== Per-slot Per-layer compute_done events for async pipeline ==========
# ring_slot_compute_done[slot_idx][layer_id] = CUDA Event for compute completion
# This is used to ensure we don't overwrite data before it's been read by attention
self.ring_slot_compute_done = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
# ========== Per-slot compute_done events for async pipeline ==========
# ring_slot_compute_done[slot_idx] = CUDA Event for compute completion
# This ensures we don't overwrite data before it's been read by attention
self.ring_slot_compute_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
# Initialize all compute_done events (record them once)
# This prevents undefined behavior on first load_to_slot_layer call
for slot_idx in range(self.num_ring_slots):
for layer_id in range(num_layers):
self.ring_slot_compute_done[slot_idx][layer_id].record()
self.ring_slot_compute_done[slot_idx].record()
torch.cuda.synchronize() # Ensure all events are recorded
# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
# ========== Debug hook mode ==========
self._debug_mode = False
self._debug_hooks: List = [] # External hooks for debug events
def _get_next_stream(self) -> torch.cuda.Stream:
"""Round-robin stream selection for parallel transfers."""
stream = self.transfer_streams[self._stream_idx]
@@ -200,23 +217,24 @@ class OffloadEngine:
return stream
# ========== CUDA Graph compatible methods ==========
# NOTE: These methods need to be updated for the new GPU cache architecture.
# GPU cache no longer has layer dimension, so gathered copy semantics change.
# For now, these are kept for reference but should not be used without updating.
def gathered_h2d_layer(self, layer_id: int) -> None:
"""
Execute gathered H2D copy for a single layer.
This method is CUDA Graph compatible - can be captured into a graph.
Before calling, update_gather_indices() must be called to set up
which CPU blocks to copy to which GPU slots.
Args:
layer_id: Layer index to transfer
WARNING: This method needs updating for new GPU cache architecture.
GPU cache no longer has layer dimension.
"""
# GPU cache has no layer dimension - use flat indexing
# Source is CPU[layer_id], dest is GPU (shared across layers)
gathered_copy_kv(
k_src=self.k_cache_cpu[layer_id],
v_src=self.v_cache_cpu[layer_id],
k_dst=self.k_cache_gpu[layer_id],
v_dst=self.v_cache_gpu[layer_id],
k_dst=self.k_cache_gpu, # No layer indexing
v_dst=self.v_cache_gpu, # No layer indexing
indices=self.gather_indices_gpu[layer_id],
)
@@ -224,7 +242,8 @@ class OffloadEngine:
"""
Execute gathered H2D copy for all layers.
CUDA Graph compatible - can be captured into a single graph.
WARNING: In new architecture, GPU slots are shared across layers.
This method would overwrite slots multiple times. Not recommended.
"""
for layer_id in range(self.num_layers):
self.gathered_h2d_layer(layer_id)
@@ -293,10 +312,10 @@ class OffloadEngine:
"""
Async prefetch a single block from CPU to GPU.
For use in prefill phase where CUDA graphs are not used.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args:
layer_id: Layer index
layer_id: Layer index (for CPU cache)
cpu_block_id: Source block in CPU cache
gpu_block_id: Destination slot in GPU cache
@@ -309,13 +328,12 @@ class OffloadEngine:
logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")
with torch.cuda.stream(stream):
# K cache
self.k_cache_gpu[layer_id, gpu_block_id].copy_(
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_block_id].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
# V cache
self.v_cache_gpu[layer_id, gpu_block_id].copy_(
self.v_cache_gpu[gpu_block_id].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
@@ -352,8 +370,10 @@ class OffloadEngine:
"""
Async offload a block from GPU to CPU.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args:
layer_id: Layer index
layer_id: Layer index (for CPU cache)
gpu_block_id: Source slot in GPU cache
cpu_block_id: Destination block in CPU cache
@@ -369,14 +389,13 @@ class OffloadEngine:
# Wait for any compute using this block
stream.wait_stream(self.compute_stream)
# K cache
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.k_cache_gpu[layer_id, gpu_block_id],
self.k_cache_gpu[gpu_block_id],
non_blocking=True
)
# V cache
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
self.v_cache_gpu[layer_id, gpu_block_id],
self.v_cache_gpu[gpu_block_id],
non_blocking=True
)
event.record()
@@ -413,11 +432,10 @@ class OffloadEngine:
"""
Load CPU blocks to specific GPU slots for chunked decode.
Uses the main GPU KV cache slots, not a separate temp buffer.
This is the same mechanism as chunked prefill uses.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args:
layer_id: Layer index
layer_id: Layer index (for CPU cache)
cpu_block_ids: List of CPU block IDs to load
gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
"""
@@ -430,12 +448,12 @@ class OffloadEngine:
with torch.cuda.stream(stream):
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
# Copy from pinned CPU memory to GPU KV cache slot
self.k_cache_gpu[layer_id, gpu_slot].copy_(
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
@@ -452,8 +470,10 @@ class OffloadEngine:
"""
Async version: Load CPU blocks to GPU slots.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args:
layer_id: Layer index
layer_id: Layer index (for CPU cache)
cpu_block_ids: List of CPU block IDs to load
gpu_slot_ids: List of GPU slot IDs to load into
@@ -470,11 +490,12 @@ class OffloadEngine:
with torch.cuda.stream(stream):
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
self.k_cache_gpu[layer_id, gpu_slot].copy_(
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
@@ -482,44 +503,8 @@ class OffloadEngine:
return event
def load_cpu_blocks_to_gpu_slots_all_layers(
self,
cpu_block_ids: List[int],
gpu_slot_ids: List[int],
) -> None:
"""
Load CPU blocks to GPU slots for ALL layers at once.
More efficient than per-layer loading when we know the mapping upfront.
Args:
cpu_block_ids: List of CPU block IDs to load
gpu_slot_ids: List of GPU slot IDs to load into
"""
assert len(cpu_block_ids) == len(gpu_slot_ids)
if cpu_block_ids:
logger.debug(f"H2D all layers: CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
stream = self._get_next_stream()
with torch.cuda.stream(stream):
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
# Copy all layers at once using sgDMA
memcpy_2d_async(
self.k_cache_gpu[:, gpu_slot],
self.k_cache_cpu[:, cpu_block_id],
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
"h2d", stream=stream
)
memcpy_2d_async(
self.v_cache_gpu[:, gpu_slot],
self.v_cache_cpu[:, cpu_block_id],
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
"h2d", stream=stream
)
stream.synchronize()
# NOTE: load_cpu_blocks_to_gpu_slots_all_layers removed - GPU cache no longer has
# layer dimension. Each GPU slot holds data for ONE layer at a time.
# ========== Synchronization methods ==========
@@ -538,27 +523,33 @@ class OffloadEngine:
def sync_indices(self) -> None:
"""Synchronize to ensure all index updates are complete."""
torch.cuda.current_stream().synchronize()
torch.cuda.default_stream().synchronize()
# ========== Cache access methods ==========
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
"""
Get GPU K/V cache tensors for a specific layer.
Get GPU K/V cache tensors for attention layer.
NOTE: GPU cache has no layer dimension - all layers share the same slots.
The layer_id parameter is kept for API compatibility but not used.
Returns:
(k_cache, v_cache) tensors for the layer
(k_cache, v_cache) tensors
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
"""
return self.k_cache_gpu[layer_id], self.v_cache_gpu[layer_id]
# GPU cache is shared across all layers (no layer dimension)
return self.k_cache_gpu, self.v_cache_gpu
def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]:
"""
Get full GPU K/V cache tensors.
NOTE: GPU cache has no layer dimension in the new architecture.
Returns:
(k_cache, v_cache) tensors
Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
"""
return self.k_cache_gpu, self.v_cache_gpu
@@ -664,7 +655,7 @@ class OffloadEngine:
# ----- Per-slot Per-layer loading methods -----
def record_slot_compute_done(self, slot_idx: int, layer_id: int) -> None:
def record_slot_compute_done(self, slot_idx: int) -> None:
"""
Record that computation using this slot's data is done.
@@ -673,21 +664,23 @@ class OffloadEngine:
Args:
slot_idx: GPU slot index that was just used for computation
layer_id: Layer index
"""
self.ring_slot_compute_done[slot_idx][layer_id].record()
self.ring_slot_compute_done[slot_idx].record()
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async load a single CPU block to a ring buffer slot for one layer.
This is the core building block for ring buffer pipelining.
Before starting the transfer, waits for any previous compute on this slot
to complete (using compute_done event).
GPU cache has no layer dimension - slots are shared across all layers.
CPU cache still has layer dimension for persistent storage.
Before starting the transfer, waits for:
1. Any previous compute on this slot to complete
Args:
slot_idx: Target GPU slot index
layer_id: Layer index to load
layer_id: Layer index to load (for CPU cache indexing)
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
@@ -699,140 +692,105 @@ class OffloadEngine:
with torch.cuda.stream(stream):
# Wait for previous compute on this slot to complete before overwriting
# This prevents data race: transfer must not start until attention finishes reading
stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
stream.wait_event(self.ring_slot_compute_done[slot_idx])
self.k_cache_gpu[layer_id, slot_idx].copy_(
# Also wait for any pending offload of this slot to complete
# This prevents race: load must not write GPU slot while offload is reading from it
stream.wait_event(self.ring_slot_offload_done[slot_idx])
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.v_cache_gpu[layer_id, slot_idx].copy_(
self.v_cache_gpu[slot_idx].copy_(
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx][layer_id].record(stream)
self.ring_slot_ready[slot_idx].record(stream)
torch.cuda.nvtx.range_pop()
def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
def wait_slot_layer(self, slot_idx: int) -> None:
"""
Wait for a slot's loading to complete for a specific layer.
Wait for a slot's loading to complete.
Args:
slot_idx: GPU slot index to wait for
layer_id: Layer index to wait for
"""
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx][layer_id])
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])
def load_to_slot_all_layers(self, slot_idx: int, cpu_block_id: int) -> None:
"""
Async load a CPU block to a ring buffer slot for ALL layers.
Args:
slot_idx: Target GPU slot index
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main):
memcpy_2d_async(
self.k_cache_gpu[:, slot_idx],
self.k_cache_cpu[:, cpu_block_id],
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
"h2d", stream=self.transfer_stream_main
)
memcpy_2d_async(
self.v_cache_gpu[:, slot_idx],
self.v_cache_cpu[:, cpu_block_id],
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
"h2d", stream=self.transfer_stream_main
)
self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main)
def wait_slot_all_layers(self, slot_idx: int) -> None:
"""Wait for a slot's loading to complete for ALL layers."""
self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx])
# NOTE: load_to_slot_all_layers removed - GPU cache no longer has layer dimension.
# Each GPU slot holds data for ONE layer at a time. Layers execute sequentially,
# reusing the same GPU slots.
# ----- Slot offload methods -----
def offload_slot_to_cpu(self, slot_idx: int, cpu_block_id: int) -> None:
"""
Async offload a ring buffer slot to CPU (all layers).
Args:
slot_idx: Source GPU slot index
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
memcpy_2d_async(
self.k_cache_cpu[:, cpu_block_id],
self.k_cache_gpu[:, slot_idx],
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
"d2h", stream=self.transfer_stream_main
)
memcpy_2d_async(
self.v_cache_cpu[:, cpu_block_id],
self.v_cache_gpu[:, slot_idx],
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
"d2h", stream=self.transfer_stream_main
)
self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
torch.cuda.nvtx.range_pop()
# NOTE: offload_slot_to_cpu (all-layers) removed - GPU cache no longer has layer dimension.
# Use offload_slot_layer_to_cpu for per-layer offloading.
def wait_slot_offload(self, slot_idx: int) -> None:
"""Wait for slot offload to complete."""
self.compute_stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async offload a ring buffer slot to CPU for one layer.
GPU cache has no layer dimension, so we copy from GPU slot to the
specific layer in CPU cache.
Args:
slot_idx: Source GPU slot index
layer_id: Layer index to offload
layer_id: Target layer in CPU cache
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for both compute_stream and default stream
# - compute_stream: for flash attention operations
# - default_stream: for store_kvcache which runs on default stream
self.transfer_stream_main.wait_stream(self.compute_stream)
self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
self.k_cache_gpu[slot_idx], non_blocking=True
)
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
self.v_cache_gpu[layer_id, slot_idx], non_blocking=True
self.v_cache_gpu[slot_idx], non_blocking=True
)
self.ring_slot_offload_done[slot_idx][layer_id].record(self.transfer_stream_main)
def wait_slot_offload_layer(self, slot_idx: int, layer_id: int) -> None:
"""Wait for slot offload to complete for a specific layer."""
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx][layer_id])
self.ring_slot_offload_done[slot_idx].record(self.transfer_stream_main)
torch.cuda.nvtx.range_pop()
# ----- KV access methods for ring buffer -----
def get_kv_for_slot(self, slot_idx: int, layer_id: int) -> Tuple[Tensor, Tensor]:
def get_kv_for_slot(self, slot_idx: int) -> Tuple[Tensor, Tensor]:
"""
Get KV for a single ring buffer slot.
GPU cache has no layer dimension - slots contain data for whatever
layer was most recently loaded.
Args:
slot_idx: GPU slot index
layer_id: Layer ID
Returns:
(k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, slot_idx].unsqueeze(0) # [1, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slot_idx].unsqueeze(0)
k = self.k_cache_gpu[slot_idx].unsqueeze(0) # [1, block_size, heads, dim]
v = self.v_cache_gpu[slot_idx].unsqueeze(0)
return k, v
def get_kv_for_slots(
self,
layer_id: int,
slot_indices: List[int],
) -> Tuple[Tensor, Tensor]:
"""
Get KV for multiple ring buffer slots.
GPU cache has no layer dimension - returns data from specified slots.
Args:
layer_id: Layer ID
slot_indices: List of GPU slot indices
Returns:
@@ -840,92 +798,86 @@ class OffloadEngine:
"""
if not slot_indices:
return None, None
k = self.k_cache_gpu[layer_id, slot_indices]
v = self.v_cache_gpu[layer_id, slot_indices]
k = self.k_cache_gpu[slot_indices]
v = self.v_cache_gpu[slot_indices]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
# ----- Decode slot methods (kept for decode phase) -----
# NOTE: For decode with CPU offload, the flow is per-layer:
# 1. Each layer stores to decode_slot (same GPU memory, reused)
# 2. Each layer offloads its data to CPU[layer_id, block_id]
# 3. Each layer loads prev blocks from CPU[layer_id] when needed
def offload_decode_slot(self, cpu_block_id: int) -> None:
def offload_decode_slot_layer(self, layer_id: int, cpu_block_id: int) -> None:
"""
Offload KV from decode slot (slot[0]) to CPU.
Offload KV from decode slot (slot[0]) to CPU for one layer.
Args:
layer_id: Layer ID
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Decode offload: GPU slot[{self.decode_slot}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
memcpy_2d_async(
self.k_cache_cpu[:, cpu_block_id],
self.k_cache_gpu[:, self.decode_slot],
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
"d2h", stream=self.transfer_stream_main
)
memcpy_2d_async(
self.v_cache_cpu[:, cpu_block_id],
self.v_cache_gpu[:, self.decode_slot],
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
"d2h", stream=self.transfer_stream_main
)
self.decode_offload_done.record(self.transfer_stream_main)
# Reuse the existing per-layer offload method
self.offload_slot_layer_to_cpu(self.decode_slot, layer_id, cpu_block_id)
def wait_decode_offload(self) -> None:
"""Wait for decode slot offload to complete."""
self.compute_stream.wait_event(self.decode_offload_done)
self.wait_slot_offload(self.decode_slot)
def get_kv_for_decode_slot(
self,
layer_id: int,
pos_in_block: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV at specified position in decode slot.
GPU cache has no layer dimension - decode slot contains data for
whatever layer was most recently stored.
Args:
layer_id: Layer ID
pos_in_block: Token position within block (0 to block_size-1)
Returns:
(k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
k = self.k_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1]
v = self.v_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v
def get_kv_for_decode_slot_accumulated(
self,
layer_id: int,
num_tokens: int,
) -> Tuple[Tensor, Tensor]:
"""
Get accumulated KV in decode slot (positions 0 to num_tokens-1).
GPU cache has no layer dimension - decode slot contains data for
whatever layer was most recently stored.
Args:
layer_id: Layer ID
num_tokens: Number of accumulated tokens (1 to block_size)
Returns:
(k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]
v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
k = self.k_cache_gpu[self.decode_slot, :num_tokens]
v = self.v_cache_gpu[self.decode_slot, :num_tokens]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v
# ----- Legacy compatibility methods (for decode double-buffering) -----
# NOTE: GPU cache has no layer dimension. Layer ID is used for CPU cache indexing only.
def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
Uses first half of decode_load_slots as 'compute' region.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
"""
if not cpu_block_ids:
return
@@ -938,26 +890,27 @@ class OffloadEngine:
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
self.k_cache_gpu[layer_id, gpu_slot].copy_(
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
def wait_compute_layer(self, layer_id: int) -> None:
def wait_compute_layer(self) -> None:
"""Legacy: Wait for 'compute' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
if self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0], layer_id)
self.wait_slot_layer(self.decode_load_slots[0])
def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
Uses second half of decode_load_slots as 'prefetch' region.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
"""
if not cpu_block_ids:
return
@@ -972,37 +925,36 @@ class OffloadEngine:
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
self.k_cache_gpu[layer_id, gpu_slot].copy_(
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
def wait_prefetch_layer(self, layer_id: int) -> None:
def wait_prefetch_layer(self) -> None:
"""Legacy: Wait for 'prefetch' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if slots:
self.wait_slot_layer(slots[0], layer_id)
self.wait_slot_layer(slots[0])
elif self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0], layer_id)
self.wait_slot_layer(self.decode_load_slots[0])
def get_kv_for_compute(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half][:num_blocks]
return self.get_kv_for_slots(layer_id, slots)
return self.get_kv_for_slots(slots)
def get_kv_for_prefetch(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
@@ -1011,4 +963,76 @@ class OffloadEngine:
if not slots:
slots = self.decode_load_slots
slots = slots[:num_blocks]
return self.get_kv_for_slots(layer_id, slots)
return self.get_kv_for_slots(slots)
# ========== Debug Hook Interface ==========
#
# Minimal generic hook system for debugging.
# Framework only provides hook registration and tensor access.
# All verification logic is external.
def enable_debug_mode(self) -> None:
"""Enable debug mode."""
self._debug_mode = True
logger.info("OffloadEngine debug mode ENABLED")
def disable_debug_mode(self) -> None:
"""Disable debug mode and clear all hooks."""
self._debug_mode = False
self._debug_hooks.clear()
logger.info("OffloadEngine debug mode DISABLED")
@property
def debug_mode(self) -> bool:
"""Check if debug mode is enabled."""
return self._debug_mode
def register_debug_hook(self, hook_fn) -> None:
"""
Register a debug hook.
The hook is called after H2D load completes (after wait_slot_layer),
receiving the loaded tensor for inspection.
Args:
hook_fn: Callable with signature:
(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None
- k, v: GPU tensor views for the loaded slot
Example:
def my_hook(slot_idx, layer_id, cpu_block_id, k, v):
if layer_id == 0:
k_val = k.float().mean().item()
print(f"Loaded block {cpu_block_id}, K mean = {k_val}")
offload_engine.register_debug_hook(my_hook)
"""
self._debug_hooks.append(hook_fn)
def remove_debug_hook(self, hook_fn) -> None:
"""Remove a registered debug hook."""
if hook_fn in self._debug_hooks:
self._debug_hooks.remove(hook_fn)
def _call_debug_hooks(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Call all registered debug hooks with loaded tensor (internal use).
Called by attention.py after wait_slot_layer completes.
GPU cache has no layer dimension - slot contains data for the layer
that was just loaded.
"""
if not self._debug_mode or not self._debug_hooks:
return
# Use get_kv_for_slot for consistency with attention.py
k, v = self.get_kv_for_slot(slot_idx)
for hook in self._debug_hooks:
try:
hook(slot_idx, layer_id, cpu_block_id, k, v)
except Exception as e:
# Allow pdb quit to propagate
if e.__class__.__name__ == 'BdbQuit':
raise
logger.warning(f"Debug hook error: {e}")

View File

@@ -87,6 +87,15 @@ class Attention(nn.Module):
else: # decode
if context.is_chunked_prefill:
# Chunked decode: need to load all KV from CPU+GPU
# Store current decode token to per-layer decode buffer
# This is needed because GPU cache has no layer dimension,
# so all layers would overwrite each other in decode_slot.
kvcache_manager = context.kvcache_manager
offload_engine = kvcache_manager.offload_engine
pos_in_block = context.decode_pos_in_block
# k, v shape: [1, kv_heads, head_dim]
offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
o = self._chunked_decode_attention(q, k, v, context)
else:
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
@@ -169,9 +178,11 @@ class Attention(nn.Module):
else:
# Use ring buffer pipeline
o_acc, lse_acc = self._ring_buffer_pipeline_load(
q_batched, cpu_block_table, load_slots, offload_engine
q_batched, cpu_block_table, load_slots, offload_engine,
current_chunk_idx
)
# Compute attention against current chunk's KV (with causal mask)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
current_o, current_lse = flash_attn_with_lse(
@@ -187,11 +198,30 @@ class Attention(nn.Module):
if o_acc is None:
final_o = current_o
else:
# IMPORTANT: o_acc was computed on compute_stream. We need to sync before
# reading it on the default stream for the merge operation.
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
offload_engine = kvcache_manager.offload_engine
torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop() # ChunkedPrefill
# Per-layer offload: In new GPU cache architecture (no layer dimension),
# each layer must offload its KV to CPU before next layer overwrites the GPU slot.
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
offload_engine = kvcache_manager.offload_engine
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
if seq is not None:
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
if current_chunk_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[current_chunk_idx]
offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
@@ -205,24 +235,27 @@ class Attention(nn.Module):
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
o_acc, lse_acc = None, None
compute_stream = offload_engine.compute_stream
for block_idx, cpu_block_id in enumerate(cpu_block_table):
# Load to slot 0 (single slot)
offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
offload_engine.wait_slot_layer(0, self.layer_id)
offload_engine.wait_slot_layer(0)
prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
# IMPORTANT: Must use compute_stream to match wait_slot_layer
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(0)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
@@ -232,6 +265,7 @@ class Attention(nn.Module):
cpu_block_table: list,
load_slots: list,
offload_engine,
current_chunk_idx: int = -1,
):
"""
Ring buffer async pipeline loading with double buffering.
@@ -269,22 +303,32 @@ class Attention(nn.Module):
if pipeline_depth == 1:
# Only 1 slot available, cannot pipeline - use synchronous mode
# IMPORTANT: Must use compute_stream to match synchronization in
# load_to_slot_layer (waits for compute_done) and wait_slot_layer
slot = load_slots[0]
compute_stream = offload_engine.compute_stream
for block_idx in range(num_blocks):
offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
offload_engine.wait_slot_layer(slot, self.layer_id)
prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
# Record compute done so next load can safely reuse this slot
offload_engine.record_slot_compute_done(slot, self.layer_id)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream):
# Debug: call hooks on compute_stream (synchronized with transfer)
if offload_engine.debug_mode:
offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
# Record compute done so next load can safely reuse this slot
offload_engine.record_slot_compute_done(slot)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
# N-way pipeline: use ALL available slots for maximum overlap
@@ -306,15 +350,20 @@ class Attention(nn.Module):
# Cycle through slots: slot[block_idx % num_slots]
current_slot = load_slots[block_idx % num_slots]
cpu_block_id = cpu_block_table[block_idx]
# Wait for current slot's transfer to complete (on compute_stream)
offload_engine.wait_slot_layer(current_slot, self.layer_id)
offload_engine.wait_slot_layer(current_slot)
# Compute attention on current slot's data
# IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
with torch.cuda.stream(compute_stream):
# Debug: call hooks on compute_stream (synchronized with transfer)
if offload_engine.debug_mode:
offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
@@ -323,7 +372,7 @@ class Attention(nn.Module):
torch.cuda.nvtx.range_pop()
# Record compute done - this allows the next transfer to safely overwrite this slot
offload_engine.record_slot_compute_done(current_slot, self.layer_id)
offload_engine.record_slot_compute_done(current_slot)
# Immediately start loading the NEXT block into this slot (if more blocks remain)
# Key insight: reuse current_slot immediately after compute is done!
@@ -350,25 +399,17 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention with double-buffering using decode_load_slots.
Compute decode attention using ring buffer pipeline (same as prefill).
Decode uses:
- decode_slot (slot[0]): writes new token's KV
- decode_load_slots (slots[1:]): load previous chunks from CPU
Uses the same loading mechanism as _chunked_prefill_attention:
- Load one block at a time from CPU to GPU slot
- Compute attention for each block
- Merge results using online softmax
- Finally merge with decode buffer (accumulated decode tokens)
Pipeline design:
- First half of decode_load_slots: 'compute' buffer
- Second half: 'prefetch' buffer
- Double-buffer between them for async overlap
Timeline:
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│Load C0→buf0 │ │Load C1→buf1 │ │Load C2→buf0 │ ...
└─────────────┘ └─────────────┘ └─────────────┘
↘ ↘ ↘
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ Attn(C0) │ │ Attn(C1) │ │ Attn(C2) │
└─────────────┘ └─────────────┘ └─────────────┘
This approach is simpler and proven correct (prefill tests pass).
The only difference from prefill is the additional decode buffer
that stores new tokens generated during decode.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@@ -378,12 +419,20 @@ class Attention(nn.Module):
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq
# Get all CPU blocks for this sequence
cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
# Get only PREFILLED CPU blocks (exclude the current decode block)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
if self.layer_id == 0:
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Calculate valid tokens in the last block
# Note: For chunked prefill, each block is exactly block_size tokens
# The cpu_block_table only contains full prefill blocks
block_size = kvcache_manager.block_size
num_prefill_blocks = len(cpu_block_table)
# All prefill blocks are full (block_size tokens each)
last_block_valid_tokens = block_size
# Apply sparse policy if enabled
if kvcache_manager.sparse_policy is not None:
@@ -391,7 +440,7 @@ class Attention(nn.Module):
query_chunk_idx=0,
num_query_chunks=1,
layer_id=self.layer_id,
query=q_batched, # Decode provides query for query-aware selection
query=q_batched,
is_prefill=False,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
@@ -401,95 +450,122 @@ class Attention(nn.Module):
)
offload_engine = kvcache_manager.offload_engine
load_slots = offload_engine.decode_load_slots # Available slots for loading
# Chunk size = capacity of each double buffer region (compute/prefetch)
# Each region uses half of decode_load_slots
chunk_size = max(1, len(offload_engine.decode_load_slots) // 2)
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
# Use ring buffer pipeline (same as prefill) to load prefilled blocks
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
q_batched, cpu_block_table, load_slots, offload_engine,
block_size, last_block_valid_tokens
)
o_acc = None
lse_acc = None
# Double buffering state: True = use Compute region, False = use Prefetch region
use_compute = True
# Pre-load first chunk to Compute region (async)
first_chunk_ids = cpu_block_table[:min(chunk_size, len(cpu_block_table))]
offload_engine.load_to_compute_layer(self.layer_id, first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * chunk_size
end = min(start + chunk_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
# Wait for current buffer to be ready
if use_compute:
offload_engine.wait_compute_layer(self.layer_id)
else:
offload_engine.wait_prefetch_layer(self.layer_id)
# Trigger async prefetch of next chunk to the OTHER buffer
# This overlaps transfer with current chunk's computation
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + chunk_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if use_compute:
# Current in Compute, prefetch next to Prefetch region
offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
else:
# Current in Prefetch, prefetch next to Compute region
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
# Get KV from current buffer
if use_compute:
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
self.layer_id, num_blocks_in_chunk
)
else:
k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
self.layer_id, num_blocks_in_chunk
)
# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
q_batched, k_chunk, v_chunk,
softmax_scale=self.scale,
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = o_chunk, lse_chunk
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Swap buffers for next iteration
use_compute = not use_compute
# Now attend to Decode region (contains accumulated decode tokens)
# Now attend to accumulated decode tokens from per-layer decode buffer
pos_in_block = context.decode_pos_in_block
start_pos = context.decode_start_pos_in_block
num_accumulated = pos_in_block - start_pos + 1
if num_accumulated > 0:
decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_k = decode_k.unsqueeze(0)
decode_v = decode_v.unsqueeze(0)
# Sync compute_stream with default stream before reading decode_buffer
compute_stream = offload_engine.compute_stream
compute_stream.wait_stream(torch.cuda.default_stream())
decode_o, decode_lse = flash_attn_with_lse(
q_batched, decode_k, decode_v,
softmax_scale=self.scale,
causal=False,
)
with torch.cuda.stream(compute_stream):
if num_accumulated > 0:
# Read from per-layer decode buffer
decode_k = offload_engine.decode_k_buffer[self.layer_id, start_pos:pos_in_block+1]
decode_v = offload_engine.decode_v_buffer[self.layer_id, start_pos:pos_in_block+1]
decode_k = decode_k.unsqueeze(0)
decode_v = decode_v.unsqueeze(0)
if o_acc is None:
o_acc = decode_o
else:
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, decode_k, decode_v,
softmax_scale=self.scale,
causal=False,
)
if o_acc is None:
o_acc = decode_o
else:
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available")
# Sync back to default stream before returning
torch.cuda.default_stream().wait_stream(compute_stream)
return o_acc
def _decode_ring_buffer_pipeline(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
load_slots: list,
offload_engine,
block_size: int,
last_block_valid_tokens: int,
):
"""
Ring buffer pipeline for decode prefill loading (same mechanism as prefill).
Loads one block at a time, computes attention, and merges results.
Uses the same load_to_slot_layer / wait_slot_layer / get_kv_for_slot
methods as prefill for proven correctness.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
num_blocks = len(cpu_block_table)
if num_blocks == 0:
return None, None
if not load_slots:
return None, None
o_acc, lse_acc = None, None
num_slots = len(load_slots)
compute_stream = offload_engine.compute_stream
# Phase 1: Pre-load up to num_slots blocks
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
# Phase 2: Process blocks with pipeline
for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots]
cpu_block_id = cpu_block_table[block_idx]
# Wait for current slot's transfer to complete
offload_engine.wait_slot_layer(current_slot)
with torch.cuda.stream(compute_stream):
# Get KV from slot
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
# Handle partial last block
is_last_block = (block_idx == num_blocks - 1)
if is_last_block and last_block_valid_tokens < block_size:
prev_k = prev_k[:, :last_block_valid_tokens, :, :]
prev_v = prev_v[:, :last_block_valid_tokens, :, :]
# Compute attention
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
# Record compute done for slot reuse
offload_engine.record_slot_compute_done(current_slot)
# Start loading next block (pipeline)
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
# Merge with accumulated
with torch.cuda.stream(compute_stream):
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
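# ----------------------------------------------------------------------------------
# Editor sketch (not part of the original diff): flash_attn_with_lse is assumed to
# return the attention output plus the per-head log-sum-exp (LSE) of the softmax
# denominator, and merge_attention_outputs is assumed to combine two partial results
# computed over disjoint KV chunks. Under the usual flash-attn conventions
# (output [batch, seq_q, heads, head_dim], LSE [batch, heads, seq_q]), a minimal
# reference merge looks like the hypothetical helper below.
import torch  # already imported at module scope; repeated so the sketch stands alone

def merge_attention_outputs_sketch(o_a, lse_a, o_b, lse_b):
    # Combined normalizer: log(exp(lse_a) + exp(lse_b))
    lse = torch.logaddexp(lse_a, lse_b)
    # Per-(batch, head, query) weights of each partial result, broadcast over head_dim
    w_a = torch.exp(lse_a - lse).transpose(1, 2).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).transpose(1, 2).unsqueeze(-1)
    return o_a * w_a + o_b * w_b, lse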

View File

@@ -93,9 +93,9 @@ TEST_CASES = [
(1, 4, 256, 8, 128),
(1, 4, 512, 8, 128),
(1, 8, 512, 8, 128),
(1, 4, 1024, 8, 128),
(1, 4, 1024, 32, 128), # More heads
(1, 8, 256, 8, 64), # Smaller head dim
(1, 32, 1024, 8, 128),
(1, 32, 1024, 32, 128), # More heads
(1, 32, 256, 8, 64), # Smaller head dim
]
DTYPES = [torch.float16, torch.bfloat16]

View File

@@ -0,0 +1,391 @@
"""
Correctness test for chunked decode attention.
Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import torch
from random import randint, seed
from typing import Dict, List
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_func
# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
NUM_DECODE_TOKENS = 5
BLOCK_SIZE = 1024
# State
prefill_captures: List[Dict] = []
decode_captures: List[Dict] = []
def make_ones_injection_hook():
"""Inject Q=K=V=1.0 for deterministic testing."""
def hook(module, inputs):
q, k, v = inputs[0], inputs[1], inputs[2]
q_ones = torch.ones_like(q)
k_ones = torch.ones_like(k)
v_ones = torch.ones_like(v)
return (q_ones, k_ones, v_ones) + inputs[3:]
return hook
def make_capture_hook(layer_id: int):
"""Capture Q, K, V, output during inference."""
def hook(module, inputs, output):
ctx = get_context()
q, k, v = inputs
if ctx.is_prefill:
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
prefill_captures.append({
'layer_id': layer_id,
'chunk_idx': chunk_idx,
'q': q.clone().cpu(),
'k': k.clone().cpu(),
'v': v.clone().cpu(),
'output': output.clone().cpu(),
})
else:
decode_step = len([c for c in decode_captures if c['layer_id'] == layer_id])
decode_captures.append({
'layer_id': layer_id,
'decode_step': decode_step,
'q': q.clone().cpu(),
'k': k.clone().cpu(),
'v': v.clone().cpu(),
'output': output.clone().cpu(),
})
return hook
def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
block_size: int, num_prefill_chunks: int) -> torch.Tensor:
"""
Compute reference decode output using CPU KV cache and standard flash attention.
For decode, query attends to:
1. All prefill KV (from CPU cache)
2. All previous decode tokens (from captured decode k, v)
"""
# Get decode capture for this layer and step
decode_cap = None
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
decode_cap = c
break
if decode_cap is None:
return None
# Query: single decode token
q = decode_cap['q'].cuda() # [1, num_heads, head_dim]
q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim]
# Collect all K, V: prefill chunks from captures + decode tokens from captures
# NOTE: We use prefill captures directly instead of CPU cache because
# the CPU block ID may not equal the chunk index.
all_k = []
all_v = []
# 1. Prefill chunks from captures (use captured K/V, not CPU cache)
for cidx in range(num_prefill_chunks):
prefill_cap = None
for c in prefill_captures:
if c['layer_id'] == layer_id and c['chunk_idx'] == cidx:
prefill_cap = c
break
if prefill_cap is not None:
# Use captured K/V directly (guaranteed to be correct layer data)
all_k.append(prefill_cap['k'].cuda())
all_v.append(prefill_cap['v'].cuda())
# 2. Decode tokens from captures (up to and including current step)
for step in range(decode_step + 1):
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == step:
all_k.append(c['k'].cuda())
all_v.append(c['v'].cuda())
break
if not all_k:
return None
# Concatenate all K, V
full_k = torch.cat(all_k, dim=0).unsqueeze(0) # [1, total_len, kv_heads, head_dim]
full_v = torch.cat(all_v, dim=0).unsqueeze(0)
# Run flash attention (non-causal since we explicitly control what KV to include)
output = flash_attn_func(
q_batched, full_k, full_v,
softmax_scale=scale,
causal=False,
)
return output.squeeze(0).squeeze(0).cpu() # [num_heads, head_dim]
# ============================================================
# Main
# ============================================================
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
dtype="float16",
)
# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5
# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
# Pre-hook: inject all ones for Q, K, V
# pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
# hooks.append(pre_hook)
# Post-hook: capture Q, K, V, output
post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
hooks.append(post_hook)
# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS), use_tqdm=False)
# Remove hooks
for hook in hooks:
hook.remove()
# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()
# Calculate number of prefill chunks
num_prefill_chunks = INPUT_LEN // BLOCK_SIZE
# Debug: Compare decode_buffer with captured K/V
print("\n=== DEBUG: Comparing decode_buffer with captured K/V ===")
decode_k_buffer = offload_engine.decode_k_buffer.clone().cpu()
for step in range(NUM_DECODE_TOKENS):
for layer_id in [0, 17, 35]: # Sample a few layers
# Find captured K for this step and layer
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == step:
captured_k = c['k'].squeeze(0) # [kv_heads, head_dim]
buffer_k = decode_k_buffer[layer_id, step] # [kv_heads, head_dim]
diff = (captured_k - buffer_k).abs().max().item()
print(f"Step {step}, Layer {layer_id}: captured vs buffer max_diff={diff:.6f}")
break
# Debug: Verify that decode_buffer slices match concatenated captures
print("\n=== DEBUG: Verifying decode_buffer slices ===")
for layer_id in [0]:
for decode_step in [1, 2]: # Check steps that use multiple tokens
# Build expected slice from captures
expected_k_list = []
for step in range(decode_step + 1):
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == step:
expected_k_list.append(c['k'].squeeze(0)) # [kv_heads, head_dim]
break
if expected_k_list:
expected_k = torch.stack(expected_k_list, dim=0) # [num_tokens, kv_heads, head_dim]
buffer_slice = decode_k_buffer[layer_id, 0:decode_step+1]
diff = (expected_k - buffer_slice).abs().max().item()
print(f"Decode step {decode_step}, Layer {layer_id}: buffer slice vs expected max_diff={diff:.6f}")
# Print first values
print(f" Buffer[0,0,0]={buffer_slice[0,0,0].item():.6f}, Expected[0,0,0]={expected_k[0,0,0].item():.6f}")
if decode_step >= 1:
print(f" Buffer[1,0,0]={buffer_slice[1,0,0].item():.6f}, Expected[1,0,0]={expected_k[1,0,0].item():.6f}")
# Debug: Print expected K value for block 0, layer 0 (to compare with actual loading)
print("\n=== DEBUG: Expected K values for block 0, layer 0 ===")
for c in prefill_captures:
if c['layer_id'] == 0 and c['chunk_idx'] == 0:
print(f"Captured K[0,0,0] for layer 0, chunk 0: {c['k'][0,0,0].item():.6f}")
break
print(f"CPU cache K[0,0,0,0,0] for layer 0, block 0: {k_cache_cpu[0,0,0,0,0].item():.6f}")
# Debug: Compare CPU cache with captured prefill K/V
print("\n=== DEBUG: Comparing CPU cache with captured prefill K/V ===")
for chunk_idx in [0, 7, 15]: # Sample a few chunks
for layer_id in [0, 17, 35]: # Sample a few layers
# Find captured K for this chunk and layer
for c in prefill_captures:
if c['layer_id'] == layer_id and c['chunk_idx'] == chunk_idx:
captured_k = c['k'] # [seq_len, kv_heads, head_dim]
cpu_cache_k = k_cache_cpu[layer_id, chunk_idx, :captured_k.shape[0]]
diff = (captured_k - cpu_cache_k).abs().max().item()
print(f"Chunk {chunk_idx}, Layer {layer_id}: captured vs CPU cache max_diff={diff:.6f}")
break
# Debug: Get cpu_block_table to check order
kvcache_manager = llm.model_runner.kvcache_manager
# Find the sequence (it should still exist)
from nanovllm.engine.sequence import Sequence
for attr_name in ['sequences', '_sequences', 'active_sequences']:
if hasattr(kvcache_manager, attr_name):
print(f"Found {attr_name}")
break
# Try to get cpu_block_table through a different way
print(f"\n=== DEBUG: CPU block order ===")
# For each prefill capture, check which CPU block it ended up in
for chunk_idx in range(num_prefill_chunks):
for c in prefill_captures:
if c['layer_id'] == 0 and c['chunk_idx'] == chunk_idx:
# Check if this chunk's K matches any CPU block
captured_k_first = c['k'][0, 0, 0].item()
for block_id in range(num_prefill_chunks):
cpu_k_first = k_cache_cpu[0, block_id, 0, 0, 0].item()
if abs(captured_k_first - cpu_k_first) < 1e-6:
print(f"Chunk {chunk_idx} -> CPU block {block_id}")
break
break
# Debug: Check reference vs actual for decode steps 0 and 1
# Also compute partial references (prefill only, decode only) to isolate the bug
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
for decode_step in [0, 1]:
print(f"\n=== DEBUG: Reference vs Actual for layer 0, decode {decode_step} ===")
layer_id = 0
# Find the capture
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
q = c['q'].cuda() # [1, num_heads, head_dim]
q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim]
# Build prefill K/V per-block for block-by-block reference
prefill_k_blocks = []
prefill_v_blocks = []
for cidx in range(num_prefill_chunks):
for pc in prefill_captures:
if pc['layer_id'] == layer_id and pc['chunk_idx'] == cidx:
prefill_k_blocks.append(pc['k'].cuda().unsqueeze(0)) # [1, block_size, kv_heads, head_dim]
prefill_v_blocks.append(pc['v'].cuda().unsqueeze(0))
break
# Build decode K/V
decode_k_list = []
decode_v_list = []
for step in range(decode_step + 1):
for dc in decode_captures:
if dc['layer_id'] == layer_id and dc['decode_step'] == step:
decode_k_list.append(dc['k'].cuda())
decode_v_list.append(dc['v'].cuda())
break
full_prefill_k = torch.cat([kb.squeeze(0) for kb in prefill_k_blocks], dim=0).unsqueeze(0)
full_prefill_v = torch.cat([vb.squeeze(0) for vb in prefill_v_blocks], dim=0).unsqueeze(0)
full_decode_k = torch.cat(decode_k_list, dim=0).unsqueeze(0)
full_decode_v = torch.cat(decode_v_list, dim=0).unsqueeze(0)
full_k = torch.cat([full_prefill_k, full_decode_k], dim=1)
full_v = torch.cat([full_prefill_v, full_decode_v], dim=1)
print(f"Q shape: {q_batched.shape}")
print(f"Prefill K shape: {full_prefill_k.shape}")
print(f"Decode K shape: {full_decode_k.shape}")
print(f"Full K shape: {full_k.shape}")
print(f"Total tokens: prefill={num_prefill_chunks * BLOCK_SIZE}, decode={decode_step + 1}")
# Reference output (single attention over all)
ref_output = flash_attn_func(
q_batched, full_k, full_v,
softmax_scale=scale,
causal=False,
)
# Chunked reference: prefill attention + decode attention + merge
prefill_o, prefill_lse = flash_attn_with_lse(
q_batched, full_prefill_k, full_prefill_v,
softmax_scale=scale,
causal=False,
)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, full_decode_k, full_decode_v,
softmax_scale=scale,
causal=False,
)
chunked_output, _ = merge_attention_outputs(prefill_o, prefill_lse, decode_o, decode_lse)
# Block-by-block reference (simulating ring buffer pipeline)
block_o_acc, block_lse_acc = None, None
for bidx, (kb, vb) in enumerate(zip(prefill_k_blocks, prefill_v_blocks)):
o_blk, lse_blk = flash_attn_with_lse(q_batched, kb, vb, softmax_scale=scale, causal=False)
if block_o_acc is None:
block_o_acc, block_lse_acc = o_blk, lse_blk
else:
block_o_acc, block_lse_acc = merge_attention_outputs(block_o_acc, block_lse_acc, o_blk, lse_blk)
# Compare block-by-block vs single
block_vs_single_diff = (block_o_acc - prefill_o).abs().max().item()
print(f"Block-by-block vs Single max_diff: {block_vs_single_diff:.6f}")
# Compare full reference vs chunked reference
ref_vs_chunked_diff = (ref_output - chunked_output).abs().max().item()
print(f"Reference vs Chunked-reference max_diff: {ref_vs_chunked_diff:.6f}")
ref_output = ref_output.squeeze(0).squeeze(0).cpu()
chunked_output_cpu = chunked_output.squeeze(0).squeeze(0).cpu()
# Actual output
actual_output = c['output'].squeeze(0)
if actual_output.dim() == 3:
actual_output = actual_output.squeeze(0)
diff_ref = (actual_output - ref_output).abs()
diff_chunked = (actual_output - chunked_output_cpu).abs()
print(f"Actual vs Reference max_diff: {diff_ref.max().item():.6f}")
print(f"Actual vs Chunked-ref max_diff: {diff_chunked.max().item():.6f}")
break
print()
# Verify decode outputs
all_passed = True
for c in decode_captures:
layer_id = c['layer_id']
decode_step = c['decode_step']
ref_output = compute_decode_reference(
layer_id, decode_step, scale,
k_cache_cpu, v_cache_cpu, BLOCK_SIZE, num_prefill_chunks
)
if ref_output is None:
continue
actual_output = c['output'].squeeze(0)
if actual_output.dim() == 3:
actual_output = actual_output.squeeze(0)
diff = (actual_output - ref_output).abs()
max_diff = diff.max().item()
passed = max_diff < 1e-1
all_passed = all_passed and passed
if not passed:
print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")
print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}")

View File

@@ -0,0 +1,196 @@
"""
Correctness test for chunked prefill attention.
Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import torch
from random import randint, seed
from typing import Dict, List
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
BLOCK_SIZE = 1024
# State - capture Q and output for each (layer, chunk)
captures: List[Dict] = []
def make_ones_injection_hook():
"""Inject Q=K=V=1.0 for deterministic testing."""
def hook(module, inputs):
ctx = get_context()
if not ctx.is_prefill:
return inputs
q, k, v = inputs[0], inputs[1], inputs[2]
q_ones = torch.ones_like(q)
k_ones = torch.ones_like(k)
v_ones = torch.ones_like(v)
return (q_ones, k_ones, v_ones) + inputs[3:]
return hook
def make_capture_hook(layer_id: int):
"""Capture Q and output during prefill."""
def hook(module, inputs, output):
ctx = get_context()
if not ctx.is_prefill:
return
q, k, v = inputs
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
captures.append({
'layer_id': layer_id,
'chunk_idx': chunk_idx,
'q': q.clone().cpu(),
'k': k.clone().cpu(),
'v': v.clone().cpu(),
'output': output.clone().cpu(),
})
return hook
def compute_reference(layer_id: int, chunk_idx: int, scale: float,
k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
block_size: int) -> torch.Tensor:
"""
Compute reference output using CPU KV cache and standard flash attention.
Concatenates all Q, K, V from chunks 0..chunk_idx and runs causal attention,
then extracts output for the current chunk.
"""
# Get all captures for this layer up to chunk_idx
layer_captures = [c for c in captures
if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])
if not layer_captures:
return None
# Collect Q from captures, K/V from CPU cache
all_q = []
all_k = []
all_v = []
chunk_lengths = []
for c in layer_captures:
cidx = c['chunk_idx']
q = c['q'].cuda() # [seqlen, nheads, headdim]
all_q.append(q)
chunk_lengths.append(q.shape[0])
# Get K, V from CPU cache (already offloaded during prefill)
# CPU cache shape: [num_layers, num_blocks, block_size, kv_heads, head_dim]
k = k_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
v = v_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
all_k.append(k)
all_v.append(v)
# Concatenate
full_q = torch.cat(all_q, dim=0)
full_k = torch.cat(all_k, dim=0)
full_v = torch.cat(all_v, dim=0)
total_len = full_q.shape[0]
# Run standard causal flash attention
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')
full_o = flash_attn_varlen_func(
full_q, full_k, full_v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_len,
max_seqlen_k=total_len,
softmax_scale=scale,
causal=True,
)
# Extract output for current chunk
start_pos = sum(chunk_lengths[:-1])
end_pos = sum(chunk_lengths)
return full_o[start_pos:end_pos].cpu()
# ============================================================
# Main
# ============================================================
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
dtype="float16",
)
# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5
# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
# Pre-hook: inject all ones for Q, K, V
# pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
# hooks.append(pre_hook)
# Post-hook: capture Q, K, V, output
post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
hooks.append(post_hook)
# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False)
# Remove hooks
for hook in hooks:
hook.remove()
# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()
# Verify: compare actual output with reference computed from CPU cache
all_passed = True
num_chunks = INPUT_LEN // BLOCK_SIZE
for c in captures:
layer_id = c['layer_id']
chunk_idx = c['chunk_idx']
# Skip chunk 0 (no previous KV to load)
if chunk_idx == 0:
continue
ref_output = compute_reference(layer_id, chunk_idx, scale, k_cache_cpu, v_cache_cpu, BLOCK_SIZE)
if ref_output is None:
continue
actual_output = c['output']
diff = (actual_output - ref_output).abs()
max_diff = diff.max().item()
passed = max_diff < 1e-1 # float16 tolerance
all_passed = all_passed and passed
if not passed:
print(f"[FAIL] Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
__import__('pdb').set_trace()
print(f"test_chunked_prefill_hook: {'PASSED' if all_passed else 'FAILED'}")

View File

@@ -0,0 +1,137 @@
"""
Test KV cache offload correctness using debug hooks.
Injects distinctive K/V values, verifies loaded tensors match expected patterns.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import inspect
from random import randint, seed
from typing import Dict, List
import torch
from torch import Tensor
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 32 * 1024
NUM_GPU_BLOCKS = 4
INPUT_LEN = 32 * 1024
BLOCK_SIZE = 1024
# State
load_log: List[Dict] = []
current_chunk: List[int] = [0]
def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None:
"""Record loaded tensor values for layer 0."""
if layer_id != 0:
return
# Go up the stack to find kvcache_manager and print k_cache_gpu[*][0,0,0] for all slots
frame = inspect.currentframe()
try:
caller_frame = frame.f_back
if caller_frame is not None:
local_vars = caller_frame.f_locals
if 'self' in local_vars:
self_obj = local_vars['self']
if hasattr(self_obj, 'k_cache_gpu'):
num_slots = self_obj.k_cache_gpu.shape[0]
vals = []
for i in range(num_slots):
    slot_val = self_obj.k_cache_gpu[i][0, 0, 0].item()  # avoid shadowing the v argument
    if i == slot_idx:
        vals.append(f"[{slot_val}]")
    else:
        vals.append(str(slot_val))
print(f"[DEBUG] k_cache_gpu[0..{num_slots-1}][0,0,0] = [{', '.join(vals)}]")
finally:
del frame
load_log.append({
"chunk_idx": current_chunk[0],
"cpu_block_id": cpu_block_id,
"k_value": k.float().mean().item(),
})
def make_pattern_injection_hook(layer_id):
"""Inject K = chunk_idx + 1, V = -(chunk_idx + 1) for layer 0."""
def hook(module, inputs):
ctx = get_context()
if not ctx.is_prefill or layer_id != 0:
return inputs
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
current_chunk[0] = chunk_idx
if len(inputs) >= 3:
q, k, v = inputs[0], inputs[1], inputs[2]
k_new = torch.full_like(k, float(chunk_idx + 1))
v_new = torch.full_like(v, float(-(chunk_idx + 1)))
return (q, k_new, v_new) + inputs[3:]
return inputs
return hook
def verify() -> bool:
"""Verify blocks loaded in correct order with correct K values."""
chunk_loads: Dict[int, List[tuple]] = {}
for log in load_log:
chunk = log["chunk_idx"]
if chunk not in chunk_loads:
chunk_loads[chunk] = []
chunk_loads[chunk].append((log["cpu_block_id"], log["k_value"]))
for chunk, loads in chunk_loads.items():
expected_blocks = list(range(chunk))
actual_blocks = [b for b, _ in loads]
k_values = [k for _, k in loads]
expected_k = [float(b + 1) for b in expected_blocks]
if actual_blocks != expected_blocks:
return False
if not all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k)):
return False
return True
# Main
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
dtype="float16",
)
offload_engine = llm.model_runner.kvcache_manager.offload_engine
offload_engine.enable_debug_mode()
offload_engine.register_debug_hook(debug_load_hook)
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
hooks.append(decoder_layer.self_attn.attn.register_forward_pre_hook(
make_pattern_injection_hook(layer_idx)
))
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1), use_tqdm=False)
for hook in hooks:
hook.remove()
offload_engine.remove_debug_hook(debug_load_hook)
offload_engine.disable_debug_mode()
# Verify
num_chunks = INPUT_LEN // BLOCK_SIZE
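# Chunk N re-loads CPU blocks 0..N-1 during prefill, so across all chunks the number of
# verified loads is 0 + 1 + ... + (num_chunks - 1); with 32 chunks that is 32*31/2 = 496.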
expected_loads = num_chunks * (num_chunks - 1) // 2
passed = len(load_log) == expected_loads and verify()
print(f"test_debug_verification: {'PASSED' if passed else 'FAILED'}")

View File

@@ -0,0 +1,276 @@
"""
Test script for flash_attn_with_kvcache based chunked prefill.
Verifies that chunked prefill produces identical results to full attention.
"""
import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache
def chunk_prefill(q_full, k_full, v_full, k_cache, v_cache, cache_seqlens, chunk_size):
"""
Chunked prefill using flash_attn_with_kvcache.
Args:
q_full, k_full, v_full: [batch, total_seq_len, heads, head_dim]
k_cache, v_cache: [batch, max_seq_len, kv_heads, head_dim]
cache_seqlens: [batch] - current cache lengths
chunk_size: size of each chunk
Returns:
output: [batch, total_seq_len, heads, head_dim]
"""
total_len = q_full.shape[1]
outputs = []
for start in range(0, total_len, chunk_size):
end = min(start + chunk_size, total_len)
q_chunk = q_full[:, start:end]
k_chunk = k_full[:, start:end]
v_chunk = v_full[:, start:end]
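# flash_attn_with_kvcache writes k_chunk/v_chunk into k_cache/v_cache in place,
# starting at offset cache_seqlens, and attends q_chunk to the cached prefix plus the
# newly appended tokens (causal within the chunk), so no separate cache update is needed.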
out = flash_attn_with_kvcache(
q_chunk,
k_cache,
v_cache,
k=k_chunk,
v=v_chunk,
cache_seqlens=cache_seqlens,
causal=True,
)
outputs.append(out)
cache_seqlens += (end - start)
return torch.cat(outputs, dim=1)
def reference_attention(q, k, v):
"""Standard flash attention as reference."""
return flash_attn_func(q, k, v, causal=True)
def test_chunked_prefill_correctness():
"""Test that chunked prefill matches full attention."""
batch_size = 1
num_heads = 32
num_kv_heads = 8 # GQA
head_dim = 128
max_seq_len = 131072 # 128K
test_configs = [
(1024, 256), # 1K tokens, 256 chunk
(2048, 512), # 2K tokens, 512 chunk
(4096, 1024), # 4K tokens, 1K chunk
(4096, 2048), # 4K tokens, 2K chunk (2 chunks)
(8192, 2048), # 8K tokens, 2K chunk (4 chunks)
(16384, 4096), # 16K tokens, 4K chunk
(32768, 4096), # 32K tokens, 4K chunk
(65536, 8192), # 64K tokens, 8K chunk
(131072, 8192), # 128K tokens, 8K chunk (16 chunks)
]
for seq_len, chunk_size in test_configs:
print(f"\nTesting seq_len={seq_len}, chunk_size={chunk_size}...")
# Generate random input
torch.manual_seed(42)
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
# Expand K/V for non-GQA reference
k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
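# Each KV head is replicated num_heads // num_kv_heads times so the reference runs as
# plain MHA; flash_attn_func also supports GQA directly, but expanding keeps the
# reference independent of that code path.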
# Reference: full attention
ref_out = reference_attention(q, k_expanded, v_expanded)
# Chunked prefill with KV cache
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
chunked_out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)
# Compare
max_diff = (ref_out - chunked_out).abs().max().item()
mean_diff = (ref_out - chunked_out).abs().mean().item()
# Verify cache was filled correctly
assert cache_seqlens[0].item() == seq_len, f"Cache seqlen mismatch: {cache_seqlens[0].item()} != {seq_len}"
# Check K/V cache content
k_cache_diff = (k_cache[:, :seq_len] - k).abs().max().item()
v_cache_diff = (v_cache[:, :seq_len] - v).abs().max().item()
print(f" Output max_diff: {max_diff:.6f}, mean_diff: {mean_diff:.6f}")
print(f" KV cache diff: k={k_cache_diff:.6f}, v={v_cache_diff:.6f}")
# Tolerance for fp16
tolerance = 1e-2
if max_diff < tolerance:
print(f" PASSED")
else:
print(f" FAILED (max_diff {max_diff:.6f} >= {tolerance})")
return False
return True
def test_incremental_decode():
"""Test that decode after chunked prefill works correctly."""
batch_size = 1
num_heads = 32
num_kv_heads = 8
head_dim = 128
max_seq_len = 8192
prefill_len = 2048
chunk_size = 512
num_decode_steps = 10
print(f"\nTesting incremental decode after chunked prefill...")
print(f" Prefill: {prefill_len} tokens, chunk_size={chunk_size}")
print(f" Decode: {num_decode_steps} steps")
torch.manual_seed(42)
# Prefill phase
q_prefill = torch.randn(batch_size, prefill_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
# Run chunked prefill
prefill_out = chunk_prefill(q_prefill, k_prefill, v_prefill,
k_cache, v_cache, cache_seqlens, chunk_size)
print(f" After prefill: cache_seqlens={cache_seqlens[0].item()}")
# Decode phase - one token at a time
for step in range(num_decode_steps):
q_decode = torch.randn(batch_size, 1, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
decode_out = flash_attn_with_kvcache(
q_decode,
k_cache,
v_cache,
k=k_decode,
v=v_decode,
cache_seqlens=cache_seqlens,
causal=True,
)
cache_seqlens += 1
assert decode_out.shape == (batch_size, 1, num_heads, head_dim)
expected_len = prefill_len + num_decode_steps
actual_len = cache_seqlens[0].item()
print(f" After decode: cache_seqlens={actual_len}")
if actual_len == expected_len:
print(f" PASSED")
return True
else:
print(f" FAILED: expected {expected_len}, got {actual_len}")
return False
def test_batch_processing():
"""Test chunked prefill with batch > 1."""
batch_size = 4
num_heads = 32
num_kv_heads = 8
head_dim = 128
max_seq_len = 4096
seq_len = 2048
chunk_size = 512
print(f"\nTesting batch processing (batch_size={batch_size})...")
torch.manual_seed(42)
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)
# Verify all batches have correct cache length
assert (cache_seqlens == seq_len).all(), f"Cache seqlens mismatch: {cache_seqlens}"
assert out.shape == (batch_size, seq_len, num_heads, head_dim)
# Compare with reference for each batch item
k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
ref_out = reference_attention(q, k_expanded, v_expanded)
max_diff = (ref_out - out).abs().max().item()
print(f" Output shape: {out.shape}")
print(f" Max diff vs reference: {max_diff:.6f}")
if max_diff < 1e-2:
print(f" PASSED")
return True
else:
print(f" FAILED")
return False
# ============================================================
# Main Test Script
# ============================================================
if __name__ == "__main__":
print("=" * 60)
print("Testing flash_attn_with_kvcache chunked prefill")
print("=" * 60)
all_passed = True
all_passed &= test_chunked_prefill_correctness()
all_passed &= test_incremental_decode()
all_passed &= test_batch_processing()
print("\n" + "=" * 60)
if all_passed:
print("test_flash_attn_kvcache: ALL TESTS PASSED")
else:
print("test_flash_attn_kvcache: SOME TESTS FAILED")
print("=" * 60)

View File

@@ -0,0 +1,104 @@
"""
Test FlashInfer chunked attention with CPU offload.
Uses single_prefill_with_kv_cache + merge_state for chunked KV processing.
"""
import torch
import flashinfer
# ============================================================
# Core Functions
# ============================================================
def chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk_size, kv_chunk_size):
"""
Chunked causal attention with KV on CPU.
q: [seq_q, num_heads, head_dim] on GPU
k_cpu, v_cpu: [seq_kv, num_kv_heads, head_dim] on CPU
"""
seq_q = q.shape[0]
seq_kv = k_cpu.shape[0]
final_outputs = []
for q_start in range(0, seq_q, q_chunk_size):
q_end = min(q_start + q_chunk_size, seq_q)
q_chunk = q[q_start:q_end]
merged_output = None
merged_lse = None
for kv_start in range(0, seq_kv, kv_chunk_size):
kv_end = min(kv_start + kv_chunk_size, seq_kv)
if kv_start >= q_end:
continue
k_chunk = k_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
v_chunk = v_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
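# KV chunks entirely in the past (kv_end <= q_start) are fully visible, so only the
# chunk overlapping the current q chunk needs a causal mask. flashinfer aligns the
# causal mask to the ends of q and kv, so this relies on q_chunk_size == kv_chunk_size
# with aligned chunk boundaries (as in the configs below).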
causal = not (kv_end <= q_start)
partial_out, partial_lse = flashinfer.single_prefill_with_kv_cache(
q_chunk, k_chunk, v_chunk,
causal=causal,
return_lse=True,
)
if merged_output is None:
merged_output, merged_lse = partial_out, partial_lse
else:
merged_output, merged_lse = flashinfer.merge_state(
merged_output, merged_lse,
partial_out, partial_lse,
)
final_outputs.append(merged_output)
return torch.cat(final_outputs, dim=0)
# ============================================================
# Main Test Script
# ============================================================
print("=" * 60)
print("Testing FlashInfer chunked attention with CPU offload")
print("=" * 60)
num_heads = 32
num_kv_heads = 8
head_dim = 128
test_configs = [
(32768, 8192, 8192), # 32K tokens
(65536, 8192, 8192), # 64K tokens
(131072, 16384, 16384), # 128K tokens
# (262144, 16384, 16384), # 256K tokens (slow)
# (524288, 16384, 16384), # 512K tokens (slow)
]
for seq_len, q_chunk, kv_chunk in test_configs:
torch.manual_seed(42)
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device='cuda')
k_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
v_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
# Chunked result
chunked_out = chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk, kv_chunk)
# Reference
k_gpu = k_cpu.to('cuda')
v_gpu = v_cpu.to('cuda')
ref_out = flashinfer.single_prefill_with_kv_cache(q, k_gpu, v_gpu, causal=True)
max_diff = (ref_out - chunked_out).abs().max().item()
mean_diff = (ref_out - chunked_out).abs().mean().item()
num_chunks = (seq_len + q_chunk - 1) // q_chunk
assert max_diff < 1e-2, f"FAILED: max_diff={max_diff:.6f}"
print(f"seq={seq_len//1024}K, chunks={num_chunks}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
print("\ntest_flashinfer_merge: PASSED")

322
tests/test_needle.py Normal file
View File

@@ -0,0 +1,322 @@
"""
Needle-in-a-haystack test for LLM.
Tests: Long context retrieval capability with configurable sequence length.
NOTE: CPU offload mode has a known bug that causes incorrect outputs for
sequences longer than ~200 tokens. Leave --enable-offload unset (the default)
for correctness testing.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
import argparse
from nanovllm import LLM, SamplingParams
# ============================================================
# Needle Test Generator
# ============================================================
def generate_needle_prompt(
tokenizer,
target_length: int,
needle_position: float = 0.5,
needle_value: str = "7492",
use_chat_template: bool = True,
) -> tuple[str, str]:
"""
Generate a needle-in-haystack prompt of approximately target_length tokens.
Args:
tokenizer: HuggingFace tokenizer for length estimation
target_length: Target total sequence length in tokens
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
needle_value: The secret value to hide in the haystack
use_chat_template: Whether to use chat template for instruct models
Returns:
(prompt, expected_answer): The full prompt and the expected needle value
"""
# Haystack filler paragraphs (various topics to create realistic context)
haystack_paragraphs = [
"The weather today is quite pleasant with clear skies and moderate temperatures. "
"Many people are enjoying outdoor activities in the park. "
"Birds are singing in the trees and children are playing on the swings. ",
"In the world of technology, new innovations continue to emerge every day. "
"Researchers are working on advanced algorithms and computing systems. "
"The future of artificial intelligence looks promising with many breakthroughs. ",
"The history of human civilization spans thousands of years. "
"Ancient cultures developed writing, mathematics, and astronomy. "
"Trade routes connected distant lands and facilitated cultural exchange. ",
"Modern cooking combines traditional techniques with new ingredients. "
"Chefs around the world experiment with flavors and presentations. "
"Food brings people together and creates memorable experiences. ",
"The ocean covers more than seventy percent of Earth's surface. "
"Marine ecosystems support an incredible diversity of life forms. "
"Scientists continue to discover new species in the deep sea. ",
"Music has been a part of human culture since prehistoric times. "
"Different genres evolved across various regions and time periods. "
"Today, people can access millions of songs through digital platforms. ",
"Space exploration has revealed many secrets about our universe. "
"Telescopes can observe galaxies billions of light years away. "
"Future missions aim to establish human presence on other planets. ",
"The study of languages reveals patterns in human cognition. "
"Linguists analyze grammar, semantics, and phonetics across cultures. "
"Language continues to evolve with new words and expressions. ",
]
# The needle sentence
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
# Estimate tokens for fixed parts
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
# Buffer for chat template, special tokens, etc.
overhead_tokens = 100 if use_chat_template else 50
# Available tokens for haystack
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
if haystack_target_tokens < 100:
raise ValueError(f"target_length {target_length} is too short for needle test")
# Build haystack by repeating paragraphs
haystack_parts = []
current_tokens = 0
para_idx = 0
while current_tokens < haystack_target_tokens:
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
if current_tokens + para_tokens > haystack_target_tokens:
break
haystack_parts.append(para)
current_tokens += para_tokens
para_idx += 1
# Calculate needle insertion point
needle_idx = int(len(haystack_parts) * needle_position)
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
# Insert needle
haystack_parts.insert(needle_idx, needle)
# Assemble prompt
full_text = "".join(haystack_parts)
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
# Use chat template for instruct models
# For Qwen3, add /no_think to disable thinking mode
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
messages = [
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
else:
# Raw text format for base models
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
prompt = full_text + question
# Verify length
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
return prompt, needle_value
def check_needle_answer(output_text: str, expected: str) -> bool:
"""Check if the model output contains the expected needle value."""
import re
# Clean output - remove special tokens and whitespace
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
output_clean = ' '.join(output_clean.split()).lower()
expected_clean = expected.strip().lower()
# Check if expected value appears in output
# Also try to find it as a standalone number
if expected_clean in output_clean:
return True
# Try to extract numbers and check if expected is among them
numbers = re.findall(r'\d+', output_clean)
return expected_clean in numbers
# ============================================================
# Main Test
# ============================================================
def run_needle_test(
model_path: str,
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
needle_position: float = 0.5,
needle_value: str = "7492",
max_new_tokens: int = 32,
enable_cpu_offload: bool = False,
verbose: bool = True,
) -> bool:
"""
Run a needle-in-haystack test.
Args:
model_path: Path to model
max_model_len: Maximum model context length
input_len: Target input sequence length
num_gpu_blocks: Number of GPU blocks for offload
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
enable_cpu_offload: Enable CPU offload mode
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
if verbose:
print(f"\n{'='*60}")
print(f"Needle-in-Haystack Test")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"CPU offload: {enable_cpu_offload}")
print(f"{'='*60}\n")
# 1. Initialize LLM
llm_kwargs = {
"enforce_eager": True,
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm = LLM(model_path, **llm_kwargs)
# 2. Generate needle prompt
prompt, expected = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=needle_position,
needle_value=needle_value,
)
# 3. Generate output
sampling_params = SamplingParams(
temperature=0.6, # Moderate temperature
max_tokens=max_new_tokens,
)
outputs = llm.generate([prompt], sampling_params, use_tqdm=True)
# 4. Check result
output_text = outputs[0]["text"]
output_token_ids = outputs[0]["token_ids"]
passed = check_needle_answer(output_text, expected)
if verbose:
print(f"\n{'='*60}")
print(f"Result")
print(f"{'='*60}")
print(f"Expected: {expected}")
print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
print(f"Output: {output_text[:200]}...")
print(f"Status: {'PASSED' if passed else 'FAILED'}")
print(f"{'='*60}\n")
return passed
# ============================================================
# CLI Entry Point
# ============================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Needle-in-haystack test for long context LLM")
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
help="Path to model"
)
parser.add_argument(
"--max-model-len",
type=int,
default=32 * 1024,
help="Maximum model context length"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--num-gpu-blocks",
type=int,
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--needle-position",
type=float,
default=0.5,
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
)
parser.add_argument(
"--needle-value",
type=str,
default="7492",
help="The secret value to hide"
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=32,
help="Maximum tokens to generate"
)
parser.add_argument(
"--enable-offload",
action="store_true",
help="Enable CPU offload (has known bug for long sequences)"
)
args = parser.parse_args()
passed = run_needle_test(
model_path=args.model,
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
needle_position=args.needle_position,
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,
enable_cpu_offload=args.enable_offload,
verbose=True,
)
if passed:
print("test_needle: PASSED")
else:
print("test_needle: FAILED")
exit(1)

318
tests/test_needle_ref.py Normal file
View File

@@ -0,0 +1,318 @@
"""
Needle-in-a-haystack reference test using pure torch + transformers.
This is a reference implementation for comparison with nanovllm.
Uses standard HuggingFace inference (no custom KV cache, no offload).
"""
import os
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ============================================================
# Needle Test Generator
# ============================================================
def generate_needle_prompt(
tokenizer,
target_length: int,
needle_position: float = 0.5,
needle_value: str = "7492",
use_chat_template: bool = True,
) -> tuple[str, str]:
"""
Generate a needle-in-haystack prompt of approximately target_length tokens.
Args:
tokenizer: HuggingFace tokenizer for length estimation
target_length: Target total sequence length in tokens
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
needle_value: The secret value to hide in the haystack
use_chat_template: Whether to use chat template for instruct models
Returns:
(prompt, expected_answer): The full prompt and the expected needle value
"""
# Haystack filler paragraphs (various topics to create realistic context)
haystack_paragraphs = [
"The weather today is quite pleasant with clear skies and moderate temperatures. "
"Many people are enjoying outdoor activities in the park. "
"Birds are singing in the trees and children are playing on the swings. ",
"In the world of technology, new innovations continue to emerge every day. "
"Researchers are working on advanced algorithms and computing systems. "
"The future of artificial intelligence looks promising with many breakthroughs. ",
"The history of human civilization spans thousands of years. "
"Ancient cultures developed writing, mathematics, and astronomy. "
"Trade routes connected distant lands and facilitated cultural exchange. ",
"Modern cooking combines traditional techniques with new ingredients. "
"Chefs around the world experiment with flavors and presentations. "
"Food brings people together and creates memorable experiences. ",
"The ocean covers more than seventy percent of Earth's surface. "
"Marine ecosystems support an incredible diversity of life forms. "
"Scientists continue to discover new species in the deep sea. ",
"Music has been a part of human culture since prehistoric times. "
"Different genres evolved across various regions and time periods. "
"Today, people can access millions of songs through digital platforms. ",
"Space exploration has revealed many secrets about our universe. "
"Telescopes can observe galaxies billions of light years away. "
"Future missions aim to establish human presence on other planets. ",
"The study of languages reveals patterns in human cognition. "
"Linguists analyze grammar, semantics, and phonetics across cultures. "
"Language continues to evolve with new words and expressions. ",
]
# The needle sentence
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
# Estimate tokens for fixed parts
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
# Buffer for chat template, special tokens, etc.
overhead_tokens = 100 if use_chat_template else 50
# Available tokens for haystack
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
if haystack_target_tokens < 100:
raise ValueError(f"target_length {target_length} is too short for needle test")
# Build haystack by repeating paragraphs
haystack_parts = []
current_tokens = 0
para_idx = 0
while current_tokens < haystack_target_tokens:
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
if current_tokens + para_tokens > haystack_target_tokens:
break
haystack_parts.append(para)
current_tokens += para_tokens
para_idx += 1
# Calculate needle insertion point
needle_idx = int(len(haystack_parts) * needle_position)
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
# Insert needle
haystack_parts.insert(needle_idx, needle)
# Assemble prompt
full_text = "".join(haystack_parts)
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
# Use chat template for instruct models
# For Qwen3, add /no_think to disable thinking mode
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
messages = [
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
else:
# Raw text format for base models
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
prompt = full_text + question
# Verify length
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
return prompt, needle_value
def check_needle_answer(output_text: str, expected: str) -> bool:
"""Check if the model output contains the expected needle value."""
import re
# Clean output - remove special tokens and whitespace
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
output_clean = ' '.join(output_clean.split()).lower()
expected_clean = expected.strip().lower()
# Check if expected value appears in output
if expected_clean in output_clean:
return True
# Try to extract numbers and check if expected is among them
numbers = re.findall(r'\d+', output_clean)
return expected_clean in numbers
# ============================================================
# Main Test
# ============================================================
def run_needle_test(
model_path: str,
input_len: int,
needle_position: float = 0.5,
needle_value: str = "7492",
max_new_tokens: int = 32,
dtype: str = "auto",
verbose: bool = True,
) -> bool:
"""
Run a needle-in-haystack test using standard transformers inference.
Args:
model_path: Path to model
input_len: Target input sequence length
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
dtype: Model dtype ("auto", "float16", "bfloat16")
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
if verbose:
print(f"\n{'='*60}")
print(f"Needle-in-Haystack Reference Test (torch + transformers)")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Input length: {input_len}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"Dtype: {dtype}")
print(f"{'='*60}\n")
# 1. Load tokenizer
print("[1/4] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# 2. Generate needle prompt
print("[2/4] Generating needle prompt...")
prompt, expected = generate_needle_prompt(
tokenizer=tokenizer,
target_length=input_len,
needle_position=needle_position,
needle_value=needle_value,
)
# 3. Load model
print("[3/4] Loading model...")
torch_dtype = {
"auto": "auto",
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}.get(dtype, "auto")
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
device_map="auto",
trust_remote_code=True,
)
model.eval()
# 4. Generate output
print("[4/4] Running inference...")
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
print(f" Input shape: {input_ids.shape}")
with torch.no_grad():
output_ids = model.generate(
input_ids,
max_new_tokens=max_new_tokens,
temperature=0.6,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
# Decode only the new tokens
new_token_ids = output_ids[0, input_ids.shape[1]:]
output_text = tokenizer.decode(new_token_ids, skip_special_tokens=False)
# 5. Check result
passed = check_needle_answer(output_text, expected)
if verbose:
print(f"\n{'='*60}")
print(f"Result")
print(f"{'='*60}")
print(f"Expected: {expected}")
print(f"Output tokens ({len(new_token_ids)}): {new_token_ids[:20].tolist()}")
print(f"Output: {output_text[:200]}...")
print(f"Status: {'PASSED' if passed else 'FAILED'}")
print(f"{'='*60}\n")
return passed
# ============================================================
# CLI Entry Point
# ============================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Needle-in-haystack reference test (torch + transformers)"
)
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
help="Path to model"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--needle-position",
type=float,
default=0.5,
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
)
parser.add_argument(
"--needle-value",
type=str,
default="7492",
help="The secret value to hide"
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=32,
help="Maximum tokens to generate"
)
parser.add_argument(
"--dtype",
type=str,
default="auto",
choices=["auto", "float16", "bfloat16"],
help="Model dtype"
)
args = parser.parse_args()
passed = run_needle_test(
model_path=args.model,
input_len=args.input_len,
needle_position=args.needle_position,
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,
dtype=args.dtype,
verbose=True,
)
if passed:
print("test_needle_ref: PASSED")
else:
print("test_needle_ref: FAILED")
exit(1)

View File

@@ -0,0 +1,695 @@
"""
Test script to verify CPU offload correctness using distinctive KV patterns.
Strategy:
1. Hook into attention forward pass
2. Overwrite K/V with distinctive patterns based on chunk_idx (e.g., K=chunk_idx, V=-chunk_idx)
3. After offload to CPU, verify CPU cache contains correct patterns
4. On subsequent chunks, verify loaded KV from CPU has correct patterns
This catches bugs like:
- Wrong block being offloaded
- Wrong block being loaded
- Data corruption during transfer
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
# ============================================================
# Configuration
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 64 * 1024
NUM_GPU_BLOCKS = 4
INPUT_LEN = 32 * 1024 # 32K tokens = 32 chunks (fits in 40 CPU blocks)
BLOCK_SIZE = 1024
# Test state
errors = []
chunk_patterns = {} # chunk_idx -> (k_pattern, v_pattern)
block_coverage = {} # chunk_idx -> set of blocks that were actually computed
load_operations = [] # List of (chunk_idx, slot_id, cpu_block_id, k_ok, v_ok) tuples
current_chunk_for_load = [0] # Mutable container to track current chunk during loads
# ============================================================
# Pattern Helpers
# ============================================================
def get_expected_pattern(chunk_idx: int):
"""Get expected K/V pattern for a chunk."""
# Use float values that are easy to identify
k_val = float(chunk_idx + 1) # 1.0, 2.0, 3.0, ...
v_val = float(-(chunk_idx + 1)) # -1.0, -2.0, -3.0, ...
return k_val, v_val
def fill_with_pattern(tensor: torch.Tensor, value: float):
"""Fill tensor with a constant value."""
tensor.fill_(value)
def check_pattern(tensor: torch.Tensor, expected: float, name: str, tolerance: float = 1e-3):
"""Check if tensor contains expected pattern."""
actual_mean = tensor.float().mean().item()
if abs(actual_mean - expected) > tolerance:
return False, f"{name}: expected mean={expected}, got {actual_mean}"
return True, None
# ============================================================
# Load Verification Instrumentation
# ============================================================
_original_load_to_slot_layer = None
_offload_engine_ref = None
def make_verified_load_to_slot_layer(original_func, offload_engine):
"""
Create a wrapper around load_to_slot_layer that verifies each load operation.
After each H2D transfer, checks that the GPU slot contains the expected
pattern from the source CPU block.
"""
def verified_load(slot_idx: int, layer_id: int, cpu_block_id: int):
# Call original load
original_func(slot_idx, layer_id, cpu_block_id)
# Only verify layer 0 to reduce overhead
if layer_id != 0:
return
# IMPORTANT: Synchronize CUDA to ensure async transfer is complete
# The transfer happens on a per-slot stream, and wait_slot_layer only
# makes compute_stream wait. We need full sync to read on default stream.
torch.cuda.synchronize()
# Get the expected pattern for this CPU block
# cpu_block_id == chunk_idx in our sequential test
expected_k, expected_v = get_expected_pattern(cpu_block_id)
# Read GPU slot data (GPU cache has no layer dimension)
gpu_k = offload_engine.k_cache_gpu[slot_idx]
gpu_v = offload_engine.v_cache_gpu[slot_idx]
actual_k = gpu_k.float().mean().item()
actual_v = gpu_v.float().mean().item()
k_ok = abs(actual_k - expected_k) < 1e-3
v_ok = abs(actual_v - expected_v) < 1e-3
chunk_idx = current_chunk_for_load[0]
load_operations.append({
'chunk_idx': chunk_idx,
'slot_idx': slot_idx,
'cpu_block_id': cpu_block_id,
'expected_k': expected_k,
'expected_v': expected_v,
'actual_k': actual_k,
'actual_v': actual_v,
'k_ok': k_ok,
'v_ok': v_ok,
})
if not (k_ok and v_ok):
errors.append(f"Load verification failed: chunk {chunk_idx}, "
f"CPU block {cpu_block_id} -> GPU slot {slot_idx}: "
f"expected K={expected_k:.1f}/V={expected_v:.1f}, "
f"got K={actual_k:.4f}/V={actual_v:.4f}")
return verified_load
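def _stream_sync_demo():
    """Illustration only (toy tensors, not the offload engine): a copy issued on a
    side stream is not visible to a reader on the default stream until the device is
    synchronized or the reading stream waits on that side stream, which is why
    verified_load() calls torch.cuda.synchronize() before inspecting the GPU slot."""
    side = torch.cuda.Stream()
    src = torch.ones(1024, device="cuda")
    dst = torch.empty(1024, device="cuda")
    with torch.cuda.stream(side):
        dst.copy_(src, non_blocking=True)
    # Either a full device sync (what the wrapper above uses) ...
    torch.cuda.synchronize()
    # ... or a targeted wait on the producing stream before reading dst.
    torch.cuda.current_stream().wait_stream(side)
    return dst.mean().item()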
def install_load_verification(llm):
"""Install verification wrapper on load_to_slot_layer."""
global _original_load_to_slot_layer, _offload_engine_ref
oe = llm.model_runner.kvcache_manager.offload_engine
_offload_engine_ref = oe
_original_load_to_slot_layer = oe.load_to_slot_layer
oe.load_to_slot_layer = make_verified_load_to_slot_layer(
_original_load_to_slot_layer, oe
)
print("Installed load verification wrapper on load_to_slot_layer")
def uninstall_load_verification():
"""Restore original load_to_slot_layer."""
global _original_load_to_slot_layer, _offload_engine_ref
if _offload_engine_ref is not None and _original_load_to_slot_layer is not None:
_offload_engine_ref.load_to_slot_layer = _original_load_to_slot_layer
print("Restored original load_to_slot_layer")
_original_load_to_slot_layer = None
_offload_engine_ref = None
# ============================================================
# Attention Hook
# ============================================================
def make_kv_pattern_pre_hook(layer_id: int):
"""
Create a PRE-forward hook that overwrites K/V with distinctive patterns BEFORE
they are stored to cache. This is called before attention.forward().
register_forward_pre_hook receives (module, inputs) and can modify inputs in-place.
"""
def hook(module, inputs):
ctx = get_context()
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
# Only process layer 0 for cleaner output
if layer_id != 0:
return
q, k, v = inputs
k_pattern, v_pattern = get_expected_pattern(chunk_idx)
# === Overwrite current chunk's K/V with distinctive pattern ===
# This happens BEFORE forward(), so these values will be stored to cache
k.fill_(k_pattern)
v.fill_(v_pattern)
# Only print for first few and last few chunks to reduce noise
num_chunks = INPUT_LEN // BLOCK_SIZE
if chunk_idx < 3 or chunk_idx >= num_chunks - 2:
print(f"[Chunk {chunk_idx:3d}] Set K={k_pattern:.1f}, V={v_pattern:.1f}")
elif chunk_idx == 3:
print(f"... (chunks 3 to {num_chunks - 3} omitted) ...")
return hook
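def _prehook_demo():
    """Illustration only (toy nn.Linear, not part of the test): a forward pre-hook
    receives (module, inputs) and may mutate the input tensors in-place before
    forward() runs, which is how the K/V pattern injection above reaches the cache."""
    from torch import nn
    lin = nn.Linear(4, 4)
    def pre_hook(module, inputs):
        (x,) = inputs
        x.fill_(1.0)  # overwrite the input in-place, as the K/V hook does
    handle = lin.register_forward_pre_hook(pre_hook)
    out = lin(torch.zeros(2, 4))  # forward() sees the all-ones input
    handle.remove()
    return out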
def make_block_coverage_pre_hook(layer_id: int):
"""
Create a PRE-forward hook to verify that all previous blocks are included
in the cpu_block_table for chunked prefill attention.
This catches bugs where:
- Some blocks are missing from the computation
- Sparse policy incorrectly filters out blocks (when not intended)
- Block table construction has off-by-one errors
"""
def hook(module, inputs):
ctx = get_context()
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
# Only process layer 0 for cleaner output
if layer_id != 0:
return
# Update current chunk for load verification tracking
current_chunk_for_load[0] = chunk_idx
# No previous blocks for chunk 0
if chunk_idx == 0:
return
# Get the sequence and its block table (same logic as _chunked_prefill_attention)
seq = ctx.chunked_seq if hasattr(ctx, 'chunked_seq') else None
if seq is None:
return
# Get the CPU block table that will be used for attention
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Expected blocks: 0 to chunk_idx-1 (all previous chunks)
expected_blocks = set(range(chunk_idx))
actual_blocks = set(cpu_block_table) if cpu_block_table else set()
# Store for later summary
block_coverage[chunk_idx] = {
'expected': expected_blocks,
'actual': actual_blocks,
}
# Check for missing blocks
missing_blocks = expected_blocks - actual_blocks
extra_blocks = actual_blocks - expected_blocks
num_chunks = INPUT_LEN // BLOCK_SIZE
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or missing_blocks:
if not missing_blocks and not extra_blocks:
print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [OK]")
else:
status_parts = []
if missing_blocks:
status_parts.append(f"MISSING {sorted(missing_blocks)}")
if extra_blocks:
status_parts.append(f"EXTRA {sorted(extra_blocks)}")
print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [{', '.join(status_parts)}]")
elif chunk_idx == 4:
# Indicate that middle chunks are being verified silently
print(f" ... (verifying chunks 4-{num_chunks - 3} silently) ...")
if missing_blocks:
errors.append(f"Chunk {chunk_idx} missing blocks: {sorted(missing_blocks)}")
return hook
def make_gpu_write_verification_post_hook(layer_id: int):
"""
Create a POST-forward hook to verify the current chunk's KV was correctly
written to the GPU ring buffer write_slot.
This is a more reliable verification than checking load slots, because:
1. Post-hook runs AFTER forward() writes to GPU cache
2. write_slot mapping is deterministic: chunk_idx % num_ring_slots
3. We injected known patterns in pre-hook, now verify they're in GPU cache
"""
def hook(module, inputs, output):
ctx = get_context()
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
# Only process layer 0 for cleaner output
if layer_id != 0:
return
oe = kvcache_manager.offload_engine
num_ring_slots = oe.num_ring_slots
write_slot = chunk_idx % num_ring_slots
# Get expected pattern for current chunk
expected_k, expected_v = get_expected_pattern(chunk_idx)
# Verify write_slot contains current chunk's data (GPU cache has no layer dimension)
gpu_k = oe.k_cache_gpu[write_slot]
gpu_v = oe.v_cache_gpu[write_slot]
actual_k_mean = gpu_k.float().mean().item()
actual_v_mean = gpu_v.float().mean().item()
k_ok, _ = check_pattern(gpu_k, expected_k, f"GPU slot {write_slot}")
v_ok, _ = check_pattern(gpu_v, expected_v, f"GPU slot {write_slot}")
num_chunks = INPUT_LEN // BLOCK_SIZE
# Print for first/last chunks, or if there's an error
        if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or not (k_ok and v_ok):
if k_ok and v_ok:
print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: K={expected_k:.1f}, V={expected_v:.1f} [OK]")
else:
print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: expected K={expected_k:.1f}/V={expected_v:.1f}, "
f"got K={actual_k_mean:.2f}/V={actual_v_mean:.2f} [FAIL]")
elif chunk_idx == 4:
print(f" ... (GPU write verification for chunks 4-{num_chunks - 3} silently) ...")
if not (k_ok and v_ok):
errors.append(f"GPU write_slot {write_slot} at chunk {chunk_idx}: "
f"expected K={expected_k}, V={expected_v}, got K={actual_k_mean:.4f}, V={actual_v_mean:.4f}")
return hook
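# Worked example of the mapping assumed above (taking num_ring_slots to equal
# NUM_GPU_BLOCKS = 4 in this config): write_slot = chunk_idx % 4, so chunks
# 0, 4, 8, ... reuse slot 0 and chunks 1, 5, 9, ... reuse slot 1. A slot is only
# safe to overwrite once its previous chunk has been offloaded to its CPU block.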
def make_kv_verification_post_hook(layer_id: int):
"""
Create a POST-forward hook to verify CPU cache contains correct patterns
from previously offloaded blocks.
"""
def hook(module, inputs, output):
ctx = get_context()
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
# Only process layer 0 for cleaner output
if layer_id != 0:
return
# === Verify previously offloaded blocks in CPU cache ===
if chunk_idx >= 1:
oe = kvcache_manager.offload_engine
num_ok = 0
num_fail = 0
# Check all previously offloaded blocks
for prev_chunk in range(chunk_idx):
# CPU block ID = prev_chunk (in simple sequential case)
cpu_block_id = prev_chunk
# Get expected pattern for this block
expected_k, expected_v = get_expected_pattern(prev_chunk)
# Read from CPU cache (layer 0)
cpu_k = oe.k_cache_cpu[layer_id, cpu_block_id]
cpu_v = oe.v_cache_cpu[layer_id, cpu_block_id]
# Verify patterns
k_ok, k_err = check_pattern(cpu_k, expected_k, f"CPU K block {cpu_block_id}")
v_ok, v_err = check_pattern(cpu_v, expected_v, f"CPU V block {cpu_block_id}")
if k_ok and v_ok:
num_ok += 1
else:
num_fail += 1
if k_err:
errors.append(k_err)
if v_err:
errors.append(v_err)
# Only print summary for each chunk verification
num_chunks = INPUT_LEN // BLOCK_SIZE
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or num_fail > 0:
status = "OK" if num_fail == 0 else f"FAIL({num_fail})"
print(f" CPU verify chunk {chunk_idx:2d}: {num_ok} blocks OK [{status}]")
elif chunk_idx == 4:
print(f" ... (CPU cache verification for chunks 4-{num_chunks - 3} silently) ...")
return hook
def make_post_chunk_verification_hook(layer_id: int):
"""
Post-forward hook to verify GPU ring buffer state after attention.
"""
def hook(module, inputs, output):
ctx = get_context()
if not ctx.is_prefill or layer_id != 0:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
oe = kvcache_manager.offload_engine
# After attention, the current chunk's KV should be in the GPU ring buffer
# Ring slot = chunk_idx % num_ring_slots
ring_slot = chunk_idx % oe.num_ring_slots
expected_k, expected_v = get_expected_pattern(chunk_idx)
# Check GPU ring buffer (GPU cache has no layer dimension)
gpu_k = oe.k_cache_gpu[ring_slot]
gpu_v = oe.v_cache_gpu[ring_slot]
k_ok, k_err = check_pattern(gpu_k, expected_k, f"GPU K slot {ring_slot}")
v_ok, v_err = check_pattern(gpu_v, expected_v, f"GPU V slot {ring_slot}")
if k_ok and v_ok:
print(f" [OK] GPU slot {ring_slot} (chunk {chunk_idx}): K={expected_k}, V={expected_v}")
else:
if k_err:
print(f" [FAIL] {k_err}")
errors.append(k_err)
if v_err:
print(f" [FAIL] {v_err}")
errors.append(v_err)
return hook
def register_hooks(llm):
"""Register pre and post forward hooks."""
hooks = []
model = llm.model_runner.model
for layer_idx, decoder_layer in enumerate(model.model.layers):
attn_module = decoder_layer.self_attn.attn
# PRE-forward hook 1: Verify all previous blocks are in cpu_block_table
coverage_hook = attn_module.register_forward_pre_hook(make_block_coverage_pre_hook(layer_idx))
hooks.append(coverage_hook)
# PRE-forward hook 2: Inject K/V patterns before they're stored to cache
pattern_hook = attn_module.register_forward_pre_hook(make_kv_pattern_pre_hook(layer_idx))
hooks.append(pattern_hook)
# POST-forward hook 1: Verify GPU write_slot contains current chunk's data
gpu_verify_hook = attn_module.register_forward_hook(make_gpu_write_verification_post_hook(layer_idx))
hooks.append(gpu_verify_hook)
# POST-forward hook 2: Verify CPU cache contains correct patterns after offload
cpu_verify_hook = attn_module.register_forward_hook(make_kv_verification_post_hook(layer_idx))
hooks.append(cpu_verify_hook)
return hooks
# ============================================================
# Final Verification
# ============================================================
def verify_final_cpu_state(llm, num_chunks: int):
"""Verify all CPU blocks have correct patterns after prefill completes."""
print("\n" + "=" * 60)
print("Final CPU Cache Verification")
print("=" * 60)
kvcache_manager = llm.model_runner.kvcache_manager
oe = kvcache_manager.offload_engine
num_ok = 0
num_fail = 0
fail_details = []
# After prefill, all chunks should be in CPU
for chunk_idx in range(num_chunks):
cpu_block_id = chunk_idx
expected_k, expected_v = get_expected_pattern(chunk_idx)
# Check layer 0
cpu_k = oe.k_cache_cpu[0, cpu_block_id]
cpu_v = oe.v_cache_cpu[0, cpu_block_id]
k_ok, k_err = check_pattern(cpu_k, expected_k, f"Final CPU K block {cpu_block_id}")
v_ok, v_err = check_pattern(cpu_v, expected_v, f"Final CPU V block {cpu_block_id}")
if k_ok and v_ok:
num_ok += 1
# Only print first few and last few
if chunk_idx < 3 or chunk_idx >= num_chunks - 2:
actual_k_mean = cpu_k.float().mean().item()
actual_v_mean = cpu_v.float().mean().item()
print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
f"V={expected_v:.1f} ({actual_v_mean:.4f}) [OK]")
elif chunk_idx == 3:
print(f" ... (blocks 3 to {num_chunks - 3} verified OK) ...")
else:
num_fail += 1
actual_k_mean = cpu_k.float().mean().item()
actual_v_mean = cpu_v.float().mean().item()
print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
f"V={expected_v:.1f} ({actual_v_mean:.4f}) [FAIL]")
if k_err:
errors.append(k_err)
if v_err:
errors.append(v_err)
print(f"\nTotal: {num_ok} OK, {num_fail} FAIL out of {num_chunks} blocks")
def verify_block_coverage_summary(num_chunks: int):
"""Verify that all chunks had complete block coverage during prefill."""
print("\n" + "=" * 60)
print("Block Coverage Verification Summary")
print("=" * 60)
num_ok = 0
num_fail = 0
total_blocks_expected = 0
total_blocks_computed = 0
for chunk_idx in range(1, num_chunks): # Start from 1 (chunk 0 has no previous)
if chunk_idx not in block_coverage:
print(f" Chunk {chunk_idx}: NO COVERAGE DATA [FAIL]")
errors.append(f"Chunk {chunk_idx} has no block coverage data")
num_fail += 1
continue
coverage = block_coverage[chunk_idx]
expected = coverage['expected']
actual = coverage['actual']
missing = expected - actual
total_blocks_expected += len(expected)
total_blocks_computed += len(actual)
if not missing:
num_ok += 1
else:
num_fail += 1
# Print summary
if num_fail == 0:
print(f" All {num_ok} chunks had complete block coverage [OK]")
print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")
else:
print(f" {num_ok} chunks OK, {num_fail} chunks with missing blocks [FAIL]")
print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")
# Verify the total is correct: sum of 0+1+2+...+(n-1) = n*(n-1)/2
expected_total = num_chunks * (num_chunks - 1) // 2
if total_blocks_expected == expected_total:
print(f" Expected total blocks matches formula: {expected_total} [OK]")
else:
print(f" Expected total mismatch: got {total_blocks_expected}, formula gives {expected_total} [FAIL]")
errors.append(f"Block coverage total mismatch")
def verify_load_operations_summary(num_chunks: int):
"""Verify all H2D load operations transferred correct data."""
print("\n" + "=" * 60)
print("H2D Load Operations Verification Summary")
print("=" * 60)
if not load_operations:
print(" WARNING: No load operations recorded!")
print(" (This may indicate load verification was not installed)")
return
num_ok = 0
num_fail = 0
loads_per_chunk = {}
for op in load_operations:
chunk_idx = op['chunk_idx']
if chunk_idx not in loads_per_chunk:
loads_per_chunk[chunk_idx] = []
loads_per_chunk[chunk_idx].append(op)
if op['k_ok'] and op['v_ok']:
num_ok += 1
else:
num_fail += 1
# Print per-chunk summary for first/last chunks
for chunk_idx in sorted(loads_per_chunk.keys()):
ops = loads_per_chunk[chunk_idx]
chunk_ok = sum(1 for op in ops if op['k_ok'] and op['v_ok'])
chunk_fail = len(ops) - chunk_ok
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or chunk_fail > 0:
# Show loaded block IDs in order
block_ids = [op['cpu_block_id'] for op in ops]
if chunk_fail == 0:
print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks {block_ids} [OK]")
else:
print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks, {chunk_fail} FAILED [FAIL]")
for op in ops:
if not (op['k_ok'] and op['v_ok']):
print(f" CPU block {op['cpu_block_id']} -> slot {op['slot_idx']}: "
f"expected K={op['expected_k']:.1f}/V={op['expected_v']:.1f}, "
f"got K={op['actual_k']:.4f}/V={op['actual_v']:.4f}")
elif chunk_idx == 4:
print(f" ... (chunks 4-{num_chunks - 3} load verification running silently) ...")
# Print overall summary
print(f"\n Total load operations: {len(load_operations)}")
print(f" Successful: {num_ok}, Failed: {num_fail}")
if num_fail == 0:
print(f" All H2D transfers verified correct [OK]")
else:
print(f" {num_fail} H2D transfers had incorrect data [FAIL]")
# ============================================================
# Main Test Script
# ============================================================
if __name__ == "__main__":
print("=" * 60)
print("Test: CPU Offload Correctness with Distinctive KV Patterns")
print("=" * 60)
print(f"Input: {INPUT_LEN} tokens, {INPUT_LEN // BLOCK_SIZE} chunks")
print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}")
print(f"Pattern: K=chunk_idx+1, V=-(chunk_idx+1)")
print()
# 1. Initialize LLM
print("Initializing LLM...")
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
dtype="float16",
)
# 2. Register hooks
hooks = register_hooks(llm)
print(f"Registered {len(hooks)} hooks")
# 3. Install load verification (instrument load_to_slot_layer)
install_load_verification(llm)
# 4. Generate prompt
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
num_chunks = INPUT_LEN // BLOCK_SIZE
# 5. Run prefill
print("\n" + "=" * 60)
print("Running Prefill with KV Pattern Injection...")
print("=" * 60)
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
# 6. Remove hooks and uninstall load verification
for hook in hooks:
hook.remove()
uninstall_load_verification()
# 7. Final verification
verify_final_cpu_state(llm, num_chunks)
# 8. Block coverage summary
verify_block_coverage_summary(num_chunks)
# 9. H2D load operations summary
verify_load_operations_summary(num_chunks)
# 10. Report results
print("\n" + "=" * 60)
if errors:
print(f"test_offload_correctness: FAILED ({len(errors)} errors)")
for err in errors[:10]: # Show first 10 errors
print(f" - {err}")
exit(1)
else:
print("test_offload_correctness: PASSED")
print("=" * 60)

View File

@@ -1,70 +0,0 @@
"""
Test if slicing maintains pinned memory property.
"""
import torch
print("=" * 60)
print("Test: Pinned Memory Property with Slicing")
print("=" * 60)
# Create a pinned tensor with shape similar to k_cache_cpu
# [num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim]
tensor = torch.zeros(8, 16, 1024, 8, 64, dtype=torch.float16, device="cpu", pin_memory=True)
print(f"\n1. Original tensor:")
print(f" - Shape: {tensor.shape}")
print(f" - is_pinned(): {tensor.is_pinned()}")
print(f" - is_contiguous(): {tensor.is_contiguous()}")
# Test slicing operation (what we do in offload_slot_to_cpu)
slice_view = tensor[:, 0] # Same as k_cache_cpu[:, cpu_block_id]
print(f"\n2. Sliced tensor [:, 0]:")
print(f" - Shape: {slice_view.shape}")
print(f" - is_pinned(): {slice_view.is_pinned()}")
print(f" - is_contiguous(): {slice_view.is_contiguous()}")
# Test if contiguous() helps
contiguous_slice = tensor[:, 0].contiguous()
print(f"\n3. Contiguous slice [:, 0].contiguous():")
print(f" - Shape: {contiguous_slice.shape}")
print(f" - is_pinned(): {contiguous_slice.is_pinned()}")
print(f" - is_contiguous(): {contiguous_slice.is_contiguous()}")
# Test copy behavior
gpu_tensor = torch.zeros(8, 4, 1024, 8, 64, dtype=torch.float16, device="cuda")
gpu_slice = gpu_tensor[:, 0]
print(f"\n4. GPU tensor slice:")
print(f" - Shape: {gpu_slice.shape}")
print(f" - is_contiguous(): {gpu_slice.is_contiguous()}")
# Simulate the problematic copy operation
print(f"\n5. Testing copy operations:")
# Method 1: Direct slice copy (current approach - SLOW)
slice_dst = tensor[:, 1]
print(f" Method 1 (slice view): dst.is_pinned()={slice_dst.is_pinned()}")
# Method 2: Use contiguous destination
contiguous_dst = tensor[:, 2].contiguous()
print(f" Method 2 (contiguous): dst.is_pinned()={contiguous_dst.is_pinned()}")
print("\n" + "=" * 60)
print("Conclusion:")
print("=" * 60)
if not slice_view.is_pinned():
print("❌ Slicing LOSES pinned memory property!")
print(" This causes Device-to-Pageable transfers (SLOW)")
else:
print("✓ Slicing maintains pinned memory property")
if contiguous_slice.is_pinned():
print("✓ .contiguous() maintains pinned memory property")
else:
print("❌ .contiguous() also loses pinned memory property")
print("\n" + "=" * 60)

View File

@@ -1,124 +0,0 @@
"""
Test D2H transfer performance with pinned vs non-contiguous memory.
"""
import torch
import time
print("=" * 60)
print("Test: D2H Transfer Performance (for nsys profiling)")
print("=" * 60)
# Setup
num_layers = 8
num_blocks = 16
block_size = 1024
num_kv_heads = 8
head_dim = 64
# Allocate CPU cache (pinned)
k_cache_cpu = torch.zeros(
num_layers, num_blocks, block_size, num_kv_heads, head_dim,
dtype=torch.float16, device="cpu", pin_memory=True
)
# Allocate GPU cache
k_cache_gpu = torch.randn(
num_layers, 4, block_size, num_kv_heads, head_dim,
dtype=torch.float16, device="cuda"
)
# Warmup
print("\nWarmup...")
for _ in range(10):
k_cache_cpu[:, 0].copy_(k_cache_gpu[:, 0], non_blocking=True)
torch.cuda.synchronize()
print(f"\nTensor info:")
print(f" k_cache_cpu.is_pinned(): {k_cache_cpu.is_pinned()}")
print(f" k_cache_cpu.is_contiguous(): {k_cache_cpu.is_contiguous()}")
print(f" k_cache_cpu[:, 0].is_pinned(): {k_cache_cpu[:, 0].is_pinned()}")
print(f" k_cache_cpu[:, 0].is_contiguous(): {k_cache_cpu[:, 0].is_contiguous()}")
# Test 1: Non-contiguous slice (current approach)
print(f"\n" + "=" * 60)
print("Test 1: Non-contiguous slice copy (current approach)")
print("=" * 60)
NUM_ITERS = 50 # Reduced for profiling
torch.cuda.nvtx.range_push("Test1_NonContiguous")
times = []
for i in range(NUM_ITERS):
torch.cuda.nvtx.range_push(f"D2H_NonContig_{i}")
start = time.perf_counter()
k_cache_cpu[:, i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True)
torch.cuda.synchronize()
times.append(time.perf_counter() - start)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()
avg_time = sum(times) / len(times)
print(f"Average time: {avg_time * 1000:.3f} ms")
print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s")
# Test 2: Transpose to make dimension contiguous
print(f"\n" + "=" * 60)
print("Test 2: Transpose to contiguous dimension")
print("=" * 60)
# Reshape to [num_blocks, num_layers, block_size, num_kv_heads, head_dim]
k_cache_cpu_transposed = torch.zeros(
num_blocks, num_layers, block_size, num_kv_heads, head_dim,
dtype=torch.float16, device="cpu", pin_memory=True
)
print(f" k_cache_cpu_transposed[0].is_pinned(): {k_cache_cpu_transposed[0].is_pinned()}")
print(f" k_cache_cpu_transposed[0].is_contiguous(): {k_cache_cpu_transposed[0].is_contiguous()}")
torch.cuda.nvtx.range_push("Test2_Contiguous")
times = []
for i in range(NUM_ITERS):
torch.cuda.nvtx.range_push(f"D2H_Contig_{i}")
start = time.perf_counter()
k_cache_cpu_transposed[i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True)
torch.cuda.synchronize()
times.append(time.perf_counter() - start)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()
avg_time = sum(times) / len(times)
print(f"Average time: {avg_time * 1000:.3f} ms")
print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s")
# Test 3: Fully contiguous buffer
print(f"\n" + "=" * 60)
print("Test 3: Fully contiguous buffer")
print("=" * 60)
k_cache_cpu_flat = torch.zeros(
num_layers * block_size * num_kv_heads * head_dim,
dtype=torch.float16, device="cpu", pin_memory=True
)
print(f" k_cache_cpu_flat.is_pinned(): {k_cache_cpu_flat.is_pinned()}")
print(f" k_cache_cpu_flat.is_contiguous(): {k_cache_cpu_flat.is_contiguous()}")
torch.cuda.nvtx.range_push("Test3_FlatContiguous")
times = []
for i in range(NUM_ITERS):
torch.cuda.nvtx.range_push(f"D2H_Flat_{i}")
start = time.perf_counter()
k_cache_cpu_flat.copy_(k_cache_gpu[:, 0].flatten(), non_blocking=True)
torch.cuda.synchronize()
times.append(time.perf_counter() - start)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()
avg_time = sum(times) / len(times)
print(f"Average time: {avg_time * 1000:.3f} ms")
print(f"Bandwidth: {k_cache_cpu_flat.numel() * 2 / avg_time / 1e9:.2f} GB/s")
print("\n" + "=" * 60)
print("test_pinned_transfer: PASSED")
print("=" * 60)

View File

@@ -1,286 +0,0 @@
"""
Chunked Prefill + KV Cache Offload Simulation v2
Improvements over v1:
1. Simplified log output
2. Added reduce time to the model
3. Computation must wait for the corresponding KV loads to complete
"""
import threading
import time
from dataclasses import dataclass
from typing import Optional
from concurrent.futures import ThreadPoolExecutor, Future
# ============== Configuration ==============
NUM_CHUNKS = 8
GPU_SLOTS = 4
# Simulated durations (seconds)
TIME_COMPUTE_BLOCK = 0.10  # compute one attention block
TIME_REDUCE = 0.03         # one pairwise reduce of two partial results
TIME_TRANSFER = 0.08       # transfer one KV cache block
TIME_PROJ = 0.02           # projection to produce the current KV
# ============== Global time base ==============
START_TIME = None
def now() -> float:
"""返回相对开始的时间(ms)"""
return (time.time() - START_TIME) * 1000
def log_compute(msg: str):
"""计算队列日志(无缩进)"""
print(f"[{now():7.1f}ms] [COMPUTE] {msg}")
def log_transfer(msg: str):
"""传输队列日志(缩进)"""
print(f"[{now():7.1f}ms] [TRANSFER] {msg}")
def log_info(msg: str):
"""一般信息"""
print(f"[{now():7.1f}ms] {msg}")
# ============== GPU slot management ==============
class GPUSlots:
def __init__(self, num_slots: int):
self.slots = [None] * num_slots # slot_id -> kv_idx
self.kv_to_slot = {} # kv_idx -> slot_id
self.lock = threading.Lock()
# KV ready events: kv_idx -> Event
self.kv_ready = {}
def alloc(self, kv_idx: int) -> int:
with self.lock:
for sid, val in enumerate(self.slots):
if val is None:
self.slots[sid] = kv_idx
self.kv_to_slot[kv_idx] = sid
                    # Create the ready event
if kv_idx not in self.kv_ready:
self.kv_ready[kv_idx] = threading.Event()
return sid
raise RuntimeError(f"No free slot for KV{kv_idx}")
def free(self, slot_id: int):
with self.lock:
kv_idx = self.slots[slot_id]
if kv_idx is not None:
del self.kv_to_slot[kv_idx]
                # Clear the event
if kv_idx in self.kv_ready:
del self.kv_ready[kv_idx]
self.slots[slot_id] = None
def free_kv(self, kv_idx: int):
with self.lock:
if kv_idx in self.kv_to_slot:
sid = self.kv_to_slot[kv_idx]
self.slots[sid] = None
del self.kv_to_slot[kv_idx]
if kv_idx in self.kv_ready:
del self.kv_ready[kv_idx]
def mark_ready(self, kv_idx: int):
"""标记KV已就绪load完成或proj完成"""
with self.lock:
if kv_idx in self.kv_ready:
self.kv_ready[kv_idx].set()
    def wait_ready(self, kv_idx: int):
        """Wait until the KV is ready."""
        with self.lock:
            event = self.kv_ready.get(kv_idx)
        # Wait outside the lock so mark_ready() can acquire it and set the event.
        if event:
            event.wait()
def has_kv(self, kv_idx: int) -> bool:
with self.lock:
return kv_idx in self.kv_to_slot
def state(self) -> str:
with self.lock:
return "[" + "][".join(
f"KV{v}" if v is not None else "----"
for v in self.slots
) + "]"
# ============== Operation execution ==============
class Executor:
def __init__(self, gpu: GPUSlots):
self.gpu = gpu
self.compute_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Compute")
self.transfer_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Transfer")
def proj_kv(self, q_idx: int) -> Future:
"""Projection生成KV返回Future"""
def task():
log_compute(f"PROJ Q{q_idx}->KV{q_idx} START")
time.sleep(TIME_PROJ)
slot_id = self.gpu.alloc(q_idx)
            self.gpu.mark_ready(q_idx)  # KV is usable immediately after projection
log_compute(f"PROJ Q{q_idx}->KV{q_idx} END slot={slot_id} | {self.gpu.state()}")
return slot_id
return self.compute_pool.submit(task)
def compute_attn(self, q_idx: int, kv_indices: list) -> Future:
"""计算attention block会等待所有KV就绪"""
def task():
            # Wait for every required KV block to be ready
for kv_idx in kv_indices:
self.gpu.wait_ready(kv_idx)
kv_str = ",".join(map(str, kv_indices))
log_compute(f"ATTN Q{q_idx}*KV[{kv_str}] START")
time.sleep(TIME_COMPUTE_BLOCK * len(kv_indices))
log_compute(f"ATTN Q{q_idx}*KV[{kv_str}] END")
return (q_idx, kv_indices)
return self.compute_pool.submit(task)
def reduce(self, q_idx: int, num_partials: int) -> Future:
"""Online softmax reduce多个partial结果"""
def task():
if num_partials <= 1:
return
            # n partials require n - 1 pairwise reduces
num_reduces = num_partials - 1
log_compute(f"REDUCE Q{q_idx} ({num_partials} partials) START")
time.sleep(TIME_REDUCE * num_reduces)
log_compute(f"REDUCE Q{q_idx} END")
return self.compute_pool.submit(task)
def load_kv(self, kv_idx: int) -> Future:
"""从CPU load KV到GPU"""
def task():
if self.gpu.has_kv(kv_idx):
log_transfer(f"LOAD KV{kv_idx} SKIP (already on GPU)")
return kv_idx
slot_id = self.gpu.alloc(kv_idx)
log_transfer(f"LOAD KV{kv_idx} START -> slot{slot_id}")
time.sleep(TIME_TRANSFER)
            self.gpu.mark_ready(kv_idx)  # mark ready once the load completes
log_transfer(f"LOAD KV{kv_idx} END | {self.gpu.state()}")
return kv_idx
return self.transfer_pool.submit(task)
def offload_kv(self, kv_idx: int) -> Future:
"""从GPU offload KV到CPU"""
def task():
log_transfer(f"OFFLOAD KV{kv_idx} START")
time.sleep(TIME_TRANSFER)
self.gpu.free_kv(kv_idx)
log_transfer(f"OFFLOAD KV{kv_idx} END | {self.gpu.state()}")
return kv_idx
return self.transfer_pool.submit(task)
def shutdown(self):
self.compute_pool.shutdown(wait=True)
self.transfer_pool.shutdown(wait=True)
# ============== Scheduler ==============
def schedule_query(exe: Executor, q_idx: int):
"""调度单个Query的处理"""
print(f"\n{'='*50}")
log_info(f"===== Query {q_idx} START =====")
    hist_kv = list(range(q_idx))  # historical KV blocks: 0 .. q_idx - 1
num_partials = 0
    # Phase 1: projection produces the current chunk's KV
    proj_fut = exe.proj_kv(q_idx)
    proj_fut.result()  # wait for completion
    # Phase 2: diagonal-block compute, prefetching historical KV in parallel
    # Start the diagonal-block computation
diag_fut = exe.compute_attn(q_idx, [q_idx])
num_partials += 1
    # Prefetch historical KV in parallel (at most GPU_SLOTS - 1 slots are free)
prefetch_slots = min(len(hist_kv), GPU_SLOTS - 1)
prefetch_kv = hist_kv[:prefetch_slots]
prefetch_futs = [exe.load_kv(kv) for kv in prefetch_kv]
    # Wait for the diagonal block to finish
diag_fut.result()
    # Phase 3: offload the current KV to free its slot
    offload_fut = exe.offload_kv(q_idx)
    # Wait for the prefetch to finish, then compute over this batch of historical KV
for f in prefetch_futs:
f.result()
if prefetch_kv:
hist_fut = exe.compute_attn(q_idx, prefetch_kv)
num_partials += 1
else:
hist_fut = None
    # Wait for the offload to finish
    offload_fut.result()
    # Phase 4: process the remaining historical KV
remaining_kv = hist_kv[prefetch_slots:]
computed_kv = prefetch_kv.copy()
while remaining_kv:
        # Wait for the previous batch's computation to finish
        if hist_fut:
            hist_fut.result()
        # Free the KV blocks that have already been computed
        for kv in computed_kv:
            exe.gpu.free_kv(kv)
        # Load the next batch
batch_size = min(len(remaining_kv), GPU_SLOTS)
batch_kv = remaining_kv[:batch_size]
remaining_kv = remaining_kv[batch_size:]
load_futs = [exe.load_kv(kv) for kv in batch_kv]
for f in load_futs:
f.result()
        # Compute attention over this batch
hist_fut = exe.compute_attn(q_idx, batch_kv)
num_partials += 1
computed_kv = batch_kv
    # Wait for the last batch's computation to finish
    if hist_fut:
        hist_fut.result()
    # Clean up the GPU slots
    for kv in computed_kv:
        exe.gpu.free_kv(kv)
    # Phase 5: reduce all partial results
reduce_fut = exe.reduce(q_idx, num_partials)
reduce_fut.result()
log_info(f"===== Query {q_idx} END =====")
def main():
global START_TIME
START_TIME = time.time()
print("Chunked Prefill + KV Cache Offload Simulation v2")
print(f"Config: {NUM_CHUNKS} chunks, {GPU_SLOTS} GPU slots")
print(f"Time: compute={TIME_COMPUTE_BLOCK}s, transfer={TIME_TRANSFER}s, reduce={TIME_REDUCE}s")
gpu = GPUSlots(GPU_SLOTS)
exe = Executor(gpu)
try:
for q_idx in range(NUM_CHUNKS):
schedule_query(exe, q_idx)
print(f"\n{'='*50}")
log_info(f"ALL DONE! Total: {now():.1f}ms")
finally:
exe.shutdown()
if __name__ == "__main__":
main()