Compare commits: `bf4c63c7ec ... tzj/minfer` (14 commits)

| SHA1 |
|---|
| 9b52d25866 |
| 8c3418725b |
| b3685c9190 |
| 6927a75ac3 |
| ff8b09cd35 |
| 74ee6d0895 |
| 62b8a63314 |
| 965c8aff12 |
| 30462fe89a |
| ccd1b3d4ab |
| 31e90a7268 |
| 484d0de9f9 |
| 7af721c12c |
| 89f8020d38 |
.gitignore (vendored): 1 change

@@ -195,3 +195,4 @@ cython_debug/
.cursorindexingignore

results/
outputs/

CLAUDE.md: 74 changes

@@ -20,6 +20,80 @@ For sparse attention related content (block sparse attention, MInference, FlexPr
- **BlockManager** (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
- **Attention** (`layers/attention.py`): FlashAttention with chunked methods for CPU offload

## PyTorch Hooks for Debugging

### Hook Positions in Qwen3

```
decoder_layer
├── input_layernorm (RMSNorm)
├── self_attn (Qwen3Attention)      ← Hook here for attention I/O after o_proj
│   ├── q_proj → q_norm → RoPE
│   ├── k_proj → k_norm → RoPE
│   ├── v_proj
│   ├── attn (Attention)            ← Hook here for Q/K/V tensors
│   │   └── FlashAttention / SDPA
│   └── o_proj
├── post_attention_layernorm (RMSNorm)
└── mlp (Qwen3MLP)
```

### Hook Types & Data Shapes

| Hook Position | Type | Captured Data |
|---------------|------|---------------|
| `self_attn` | post | `[batch, seq_len, hidden_size]` - after o_proj |
| `self_attn.attn` | pre | Q,K,V: `[seq_len, num_heads, head_dim]` - after RoPE |
| `self_attn.attn` | post | `[seq_len, num_heads, head_dim]` - before o_proj |

### Example: Capture Attention Outputs

```python
storage = {}

def make_hook(layer_id: int, storage: dict):
    def hook(module, inputs, output):
        if isinstance(output, tuple):
            attn_output = output[0]
        else:
            attn_output = output
        # nanovllm shape: [num_tokens, hidden_size] -> add batch dim
        if attn_output.dim() == 2:
            attn_output = attn_output.unsqueeze(0)
        storage[layer_id] = attn_output.detach().clone()
    return hook

# Register hooks
hooks = []
for layer_idx, layer in enumerate(model.model.layers):
    hooks.append(layer.self_attn.register_forward_hook(make_hook(layer_idx, storage)))

# Run inference...

# Cleanup
for hook in hooks:
    hook.remove()
```

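The table also lists a *pre* hook on `self_attn.attn` for grabbing Q/K/V after RoPE. A minimal companion sketch to the post-hook example above; it assumes Q, K and V arrive as the first three positional inputs with the shapes given in the table (adjust the unpacking if the actual `attn` signature differs):

```python
qkv_storage = {}

def make_qkv_hook(layer_id: int, storage: dict):
    def pre_hook(module, inputs):
        # Assumption: inputs = (q, k, v, ...), each [seq_len, num_heads, head_dim] after RoPE
        q, k, v = inputs[0], inputs[1], inputs[2]
        storage[layer_id] = tuple(t.detach().clone() for t in (q, k, v))
    return pre_hook

qkv_hooks = [
    layer.self_attn.attn.register_forward_pre_hook(make_qkv_hook(i, qkv_storage))
    for i, layer in enumerate(model.model.layers)
]

# Run inference, then clean up
for hook in qkv_hooks:
    hook.remove()
```
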
### Alignment Testing

Use `tests/test_align.py` to compare nanovllm with the reference torch implementation:

```bash
python tests/test_align.py
```

Key files:
- `tests/modeling_qwen3.py`: Reference Qwen3 implementation (torch + transformers only)
- `tests/test_align.py`: Compares attention outputs between nanovllm and the reference (sketched below)
- `tests/test_needle_ref.py`: Reference needle test using custom Qwen3

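The alignment check ultimately reduces to comparing the per-layer tensors captured by hooks on both models. A minimal sketch of that kind of comparison (illustrative only, not the actual contents of `test_align.py`):

```python
import torch

def compare_layer_outputs(nano_outputs: dict, ref_outputs: dict, atol: float = 1e-2) -> None:
    """Print max abs diff and cosine similarity per layer."""
    for layer_id in sorted(nano_outputs):
        a = nano_outputs[layer_id].float().flatten()
        b = ref_outputs[layer_id].float().flatten()
        max_diff = (a - b).abs().max().item()
        cos = torch.nn.functional.cosine_similarity(a, b, dim=0).item()
        status = "OK" if max_diff < atol else "MISMATCH"
        print(f"layer {layer_id:2d}: max_diff={max_diff:.4e}  cos={cos:.6f}  [{status}]")
```
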
### Common Pitfalls

1. **Shape mismatch**: nanovllm uses `[num_tokens, ...]` while torch uses `[batch, seq_len, ...]`
2. **Hook position**: `self_attn` captures after o_proj, `self_attn.attn` captures before o_proj
3. **Output format**: nanovllm returns the tuple `(attn_output, None)`; handle it with `output[0]`

## CPU Offload System

### Ring Buffer Design

@@ -1,6 +1,7 @@
import os
from dataclasses import dataclass
from transformers import AutoConfig
import torch


@dataclass
@@ -16,6 +17,7 @@ class Config:
    eos: int = -1
    kvcache_block_size: int = 4096
    num_kvcache_blocks: int = -1
    dtype: str | None = None  # "float16", "bfloat16", or None (use model default)

    # CPU Offload configuration
    enable_cpu_offload: bool = False
@@ -41,3 +43,17 @@ class Config:
        self.hf_config = AutoConfig.from_pretrained(self.model)
        self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
        assert self.max_num_batched_tokens >= self.max_model_len

        # Override torch_dtype if user specified
        if self.dtype is not None:
            dtype_map = {
                "float16": torch.float16,
                "fp16": torch.float16,
                "bfloat16": torch.bfloat16,
                "bf16": torch.bfloat16,
                "float32": torch.float32,
                "fp32": torch.float32,
            }
            if self.dtype not in dtype_map:
                raise ValueError(f"Invalid dtype: {self.dtype}. Choose from: {list(dtype_map.keys())}")
            self.hf_config.torch_dtype = dtype_map[self.dtype]

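For reference, a hedged usage sketch of the new `dtype` field; the import path and model path are placeholders, and only the behavior added in this hunk is assumed (the real `Config` may require additional fields):

```python
import torch
from nanovllm.config import Config  # import path assumed

# Force bf16 regardless of the checkpoint's default dtype
config = Config(model="/path/to/Qwen3-0.6B", dtype="bf16")
assert config.hf_config.torch_dtype == torch.bfloat16

# Unknown strings fail fast in __post_init__
try:
    Config(model="/path/to/Qwen3-0.6B", dtype="int8")
except ValueError as err:
    print(err)  # Invalid dtype: int8. Choose from: ['float16', 'fp16', ...]
```
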
@@ -31,6 +31,8 @@ class LLMEngine:
        self.model_runner = ModelRunner(config, 0, self.events)
        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
        config.eos = self.tokenizer.eos_token_id
        # Set Sequence.block_size to match the KV cache block size
        Sequence.block_size = config.kvcache_block_size
        self.scheduler = Scheduler(config, self.model_runner.kvcache_manager)
        atexit.register(self.exit)


@@ -489,24 +489,15 @@ class ModelRunner:
|
||||
logical_id = seq.block_table[block_idx]
|
||||
self.kvcache_manager.prefilled_blocks.add(logical_id)
|
||||
|
||||
# Offload this chunk's ring buffer slot to CPU (async)
|
||||
# NOTE: Per-layer offloading is now done in attention.forward
|
||||
# Each layer offloads its KV to CPU immediately after computing attention.
|
||||
# We just need to wait for the last offload to complete before reusing the slot.
|
||||
if block_idx < len(cpu_block_ids):
|
||||
cpu_block_id = cpu_block_ids[block_idx]
|
||||
|
||||
# Call sparse policy hook before offload (to capture metadata)
|
||||
sparse_policy = self.kvcache_manager.sparse_policy
|
||||
if sparse_policy is not None:
|
||||
num_tokens = chunk_end - chunk_start
|
||||
for layer_id in range(offload_engine.num_layers):
|
||||
k_cache = offload_engine.k_cache_gpu[layer_id, write_slot, :num_tokens]
|
||||
sparse_policy.on_block_offloaded(
|
||||
cpu_block_id=cpu_block_id,
|
||||
layer_id=layer_id,
|
||||
k_cache=k_cache,
|
||||
num_valid_tokens=num_tokens,
|
||||
)
|
||||
|
||||
offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)
|
||||
# TODO: Sparse policy hook needs update for new GPU cache architecture
|
||||
# The GPU cache no longer has layer dimension, so we can't access
|
||||
# k_cache_gpu[layer_id, write_slot]. Sparse policy should be called
|
||||
# in attention.forward after per-layer offload.
|
||||
pass
|
||||
|
||||
# Wait for offload to complete before next chunk
|
||||
# (slot will be reused after N chunks)
|
||||
@@ -521,6 +512,7 @@ class ModelRunner:
|
||||
print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)
|
||||
|
||||
# Sample from last logits
|
||||
# For chunked prefill, ParallelLMHead automatically selects last position's logits
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
if logits is not None:
|
||||
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
|
||||
@@ -627,7 +619,11 @@ class ModelRunner:
|
||||
if pos_in_block == self.block_size - 1:
|
||||
last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
|
||||
if last_cpu_block >= 0:
|
||||
offload_engine.offload_decode_slot(last_cpu_block)
|
||||
# TODO: In new GPU cache architecture (no layer dimension),
|
||||
# decode offload should be done per-layer in attention.forward.
|
||||
# For now, offload all layers sequentially.
|
||||
for layer_id in range(offload_engine.num_layers):
|
||||
offload_engine.offload_decode_slot_layer(layer_id, last_cpu_block)
|
||||
offload_engine.wait_all_offload_done()
|
||||
# Reset decode start position for next block
|
||||
self.kvcache_manager.reset_decode_start_pos(seq)
|
||||
|
||||
@@ -281,7 +281,11 @@ def _merge_lse_kernel(
|
||||
num_elements: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""Fused kernel for merging LSE values."""
|
||||
"""Fused kernel for merging LSE values.
|
||||
|
||||
IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
|
||||
bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
|
||||
"""
|
||||
# Each program handles BLOCK_SIZE elements
|
||||
pid = tl.program_id(0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
@@ -289,21 +293,21 @@ def _merge_lse_kernel(
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < num_elements
|
||||
|
||||
# Load lse values
|
||||
lse1 = tl.load(lse1_ptr + offsets, mask=mask)
|
||||
lse2 = tl.load(lse2_ptr + offsets, mask=mask)
|
||||
# Load lse values and convert to fp32 for precision
|
||||
lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
|
||||
lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)
|
||||
|
||||
# Compute max for numerical stability
|
||||
# Compute max for numerical stability (in fp32)
|
||||
max_lse = tl.maximum(lse1, lse2)
|
||||
|
||||
# Compute exp(lse - max_lse)
|
||||
# Compute exp(lse - max_lse) in fp32
|
||||
exp1 = tl.exp(lse1 - max_lse)
|
||||
exp2 = tl.exp(lse2 - max_lse)
|
||||
|
||||
# Compute merged LSE: max_lse + log(exp1 + exp2)
|
||||
# Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
|
||||
lse_merged = max_lse + tl.log(exp1 + exp2)
|
||||
|
||||
# Store result
|
||||
# Store result (convert back to original dtype)
|
||||
tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
|
||||
|
||||
|
||||
@@ -313,7 +317,11 @@ def _merge_output_kernel(
|
||||
batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""Fused kernel for merging attention outputs."""
|
||||
"""Fused kernel for merging attention outputs.
|
||||
|
||||
IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
|
||||
This is critical for numerical accuracy in chunked attention.
|
||||
"""
|
||||
# Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
|
||||
pid_batch = tl.program_id(0)
|
||||
pid_seq = tl.program_id(1)
|
||||
@@ -322,11 +330,11 @@ def _merge_output_kernel(
|
||||
# Compute LSE index: [batch, nheads, seqlen_q]
|
||||
lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq
|
||||
|
||||
# Load LSE values
|
||||
lse1 = tl.load(lse1_ptr + lse_idx)
|
||||
lse2 = tl.load(lse2_ptr + lse_idx)
|
||||
# Load LSE values and convert to fp32 for precision
|
||||
lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
|
||||
lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)
|
||||
|
||||
# Compute max and scaling factors
|
||||
# Compute max and scaling factors in fp32
|
||||
max_lse = tl.maximum(lse1, lse2)
|
||||
exp1 = tl.exp(lse1 - max_lse)
|
||||
exp2 = tl.exp(lse2 - max_lse)
|
||||
@@ -343,14 +351,14 @@ def _merge_output_kernel(
|
||||
pid_head * headdim)
|
||||
o_idx = base_idx + d_idx
|
||||
|
||||
# Load o1, o2
|
||||
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0)
|
||||
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0)
|
||||
# Load o1, o2 and convert to fp32 for weighted sum
|
||||
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
|
||||
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
|
||||
|
||||
# Compute merged output: (o1 * exp1 + o2 * exp2) / sum_exp
|
||||
# Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
|
||||
o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
|
||||
|
||||
# Store result
|
||||
# Store result (Triton will convert back to original dtype)
|
||||
tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
|
||||
|
||||
|
||||
|
||||
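The two kernels above implement the standard log-sum-exp merge that combines attention computed over disjoint KV chunks: `lse_merged = m + log(exp(lse1 - m) + exp(lse2 - m))` and `o_merged = (o1*exp(lse1 - m) + o2*exp(lse2 - m)) / (exp(lse1 - m) + exp(lse2 - m))`, with `m = max(lse1, lse2)`. A plain PyTorch reference of the same math, kept in fp32 like the precision fix above (a sketch using the layouts implied by the kernel indexing, not the project's own helper):

```python
import torch

def merge_attention_chunks(o1, lse1, o2, lse2):
    """Merge two partial attention results computed over disjoint KV chunks.

    o:   [batch, seqlen_q, nheads, headdim]
    lse: [batch, nheads, seqlen_q]
    """
    lse1_f, lse2_f = lse1.float(), lse2.float()
    max_lse = torch.maximum(lse1_f, lse2_f)
    exp1 = torch.exp(lse1_f - max_lse)              # per-chunk weights
    exp2 = torch.exp(lse2_f - max_lse)
    lse_merged = max_lse + torch.log(exp1 + exp2)

    # Broadcast weights over [batch, seqlen_q, nheads, headdim]
    w1 = exp1.transpose(1, 2).unsqueeze(-1)
    w2 = exp2.transpose(1, 2).unsqueeze(-1)
    o_merged = (o1.float() * w1 + o2.float() * w2) / (w1 + w2)
    return o_merged.to(o1.dtype), lse_merged.to(lse1.dtype)
```
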
@@ -69,15 +69,19 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
|
||||
Architecture (CPU-primary mode):
|
||||
- CPU pool: Primary storage for all KV cache (num_cpu_blocks)
|
||||
- GPU buffer: Ring buffer for computation (num_gpu_slots)
|
||||
- Logical blocks: What sequences reference (num_gpu_slots + num_cpu_blocks)
|
||||
- GPU buffer: Ring buffer for computation only (num_gpu_slots)
|
||||
- Logical blocks: What sequences reference (num_cpu_blocks)
|
||||
|
||||
Design:
|
||||
- All KV cache is stored on CPU as primary storage
|
||||
- GPU is used as a ring buffer for computation only
|
||||
- GPU is used as a ring buffer for computation only (no persistent data)
|
||||
- During prefill: KV is written to GPU ring slot, then offloaded to CPU
|
||||
- During decode: Previous KV is loaded from CPU to GPU for attention
|
||||
- Ring buffer enables pipelined H2D transfers overlapped with computation
|
||||
|
||||
Note:
|
||||
- Logical blocks map 1:1 with CPU blocks (total_blocks = num_cpu_blocks)
|
||||
- GPU slots are transient compute buffers, not tracked in logical blocks
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -102,20 +106,22 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
self._block_size = block_size
|
||||
self.num_gpu_slots = num_gpu_slots
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
self.total_blocks = num_gpu_slots + num_cpu_blocks
|
||||
# In CPU-primary mode, logical blocks map 1:1 with CPU blocks
|
||||
# GPU slots are transient compute buffers, not tracked as logical blocks
|
||||
self.total_blocks = num_cpu_blocks
|
||||
|
||||
# Eviction policy
|
||||
self.policy = policy or LRUPolicy()
|
||||
|
||||
# Logical blocks (what sequences reference)
|
||||
# Logical blocks (what sequences reference) - one per CPU block
|
||||
self.logical_blocks: List[LogicalBlock] = [
|
||||
LogicalBlock(i) for i in range(self.total_blocks)
|
||||
]
|
||||
self.free_logical_ids: deque[int] = deque(range(self.total_blocks))
|
||||
|
||||
# GPU slot management (slots are fixed, mapping is variable)
|
||||
# GPU slot management (kept for potential future use, but not used in CPU-primary mode)
|
||||
self.free_gpu_slots: deque[int] = deque(range(num_gpu_slots))
|
||||
self.gpu_slot_to_logical: Dict[int, int] = {} # gpu_slot -> logical_id
|
||||
self.gpu_slot_to_logical: Dict[int, int] = {} # gpu_slot -> logical_id (unused in CPU-primary mode)
|
||||
|
||||
# CPU block management
|
||||
self.free_cpu_blocks: deque[int] = deque(range(num_cpu_blocks))
|
||||
@@ -212,7 +218,9 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
block.ref_count -= 1
|
||||
|
||||
if block.ref_count == 0:
|
||||
# Free physical block
|
||||
# Free physical block based on location
|
||||
# Note: In CPU-primary mode, blocks are always on CPU.
|
||||
# GPU branch kept for potential future hybrid mode support.
|
||||
if block.location == BlockLocation.GPU:
|
||||
self.free_gpu_slots.append(block.gpu_slot)
|
||||
del self.gpu_slot_to_logical[block.gpu_slot]
|
||||
@@ -337,10 +345,10 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.CPU:
|
||||
cpu_blocks.append(block.cpu_block_id)
|
||||
logger.debug(
|
||||
f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
|
||||
f"returned cpu_blocks={cpu_blocks}"
|
||||
)
|
||||
# logger.debug(
|
||||
# f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
|
||||
# f"returned cpu_blocks={cpu_blocks}"
|
||||
# )
|
||||
return cpu_blocks
|
||||
|
||||
# ========== Ring Buffer CPU-primary support ==========
|
||||
|
||||
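To make the CPU-primary bookkeeping described above concrete, a stripped-down sketch of the allocate/fork/free path (simplified names, not the actual `HybridKVCacheManager`):

```python
from collections import deque

class CpuPrimaryBlockPool:
    """Logical blocks map 1:1 to CPU blocks; GPU slots are transient compute buffers."""

    def __init__(self, num_cpu_blocks: int):
        self.total_blocks = num_cpu_blocks                # GPU slots are NOT counted here
        self.free_cpu_blocks = deque(range(num_cpu_blocks))
        self.ref_count = [0] * num_cpu_blocks

    def allocate(self) -> int:
        cpu_block_id = self.free_cpu_blocks.popleft()     # IndexError when the pool is exhausted
        self.ref_count[cpu_block_id] = 1
        return cpu_block_id

    def fork(self, cpu_block_id: int) -> None:
        # Prefix caching: another sequence starts referencing the same block
        self.ref_count[cpu_block_id] += 1

    def free(self, cpu_block_id: int) -> None:
        self.ref_count[cpu_block_id] -= 1
        if self.ref_count[cpu_block_id] == 0:
            self.free_cpu_blocks.append(cpu_block_id)
```
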
@@ -67,14 +67,19 @@ class OffloadEngine:
|
||||
self.block_numel = block_size * self.kv_dim
|
||||
|
||||
# ========== sgDMA pitch parameters for strided transfers ==========
|
||||
# CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
|
||||
# GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim] (no layer dim)
|
||||
# For CPU-to-GPU transfer (H2D): copy single layer, single block at a time
|
||||
# For all-layer CPU operations (D2H offload to all layers): use sgDMA
|
||||
self.dtype_size = dtype.itemsize
|
||||
# CPU pitch: stride between layers in CPU cache (for all-layer operations)
|
||||
self.cpu_pitch = num_cpu_blocks * self.block_numel * self.dtype_size
|
||||
self.gpu_pitch = num_gpu_blocks * self.block_numel * self.dtype_size
|
||||
self.width = self.block_numel * self.dtype_size
|
||||
self.height = num_layers
|
||||
# GPU has no layer dimension, so single block transfer is contiguous
|
||||
self.gpu_block_bytes = self.block_numel * self.dtype_size
|
||||
self.height = num_layers # For CPU all-layer operations
|
||||
|
||||
logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, gpu_pitch={self.gpu_pitch}, "
|
||||
f"width={self.width}, height={self.height}")
|
||||
logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, "
|
||||
f"gpu_block_bytes={self.gpu_block_bytes}, height={self.height}")
|
||||
|
||||
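Plugging in the example dimensions used later in this diff (28 layers, block_size 1024, 8 KV heads, head_dim 128, bf16) makes the pitch parameters concrete; `num_cpu_blocks` is an arbitrary example value and `kv_dim = num_kv_heads * head_dim` is assumed:

```python
num_layers, num_cpu_blocks = 28, 256              # num_cpu_blocks chosen arbitrarily
block_size, num_kv_heads, head_dim = 1024, 8, 128
dtype_size = 2                                    # bf16

kv_dim = num_kv_heads * head_dim                  # 1024
block_numel = block_size * kv_dim                 # 1,048,576 elements per (layer, block)
width = block_numel * dtype_size                  # 2 MiB copied per 2D-copy row
cpu_pitch = num_cpu_blocks * width                # byte stride between layers in the CPU cache
gpu_block_bytes = width                           # a GPU block is contiguous (no layer dim)
height = num_layers                               # one row per layer for all-layer copies
print(f"width={width:,} B  cpu_pitch={cpu_pitch:,} B  height={height}")
```
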
# ========== Unified Ring Buffer configuration ==========
|
||||
# Constraint checks
|
||||
@@ -100,17 +105,37 @@ class OffloadEngine:
|
||||
logger.info(f" Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading")
|
||||
|
||||
# ========== Fixed-address GPU KV cache ==========
|
||||
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
|
||||
# Use zeros initialization to avoid uninitialized memory issues
|
||||
# Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
|
||||
# NOTE: No num_layers dimension! GPU slots are shared across layers.
|
||||
# Each layer reuses the same slots (layers execute sequentially).
|
||||
# This saves 28x GPU memory compared to per-layer allocation.
|
||||
self.k_cache_gpu = torch.zeros(
|
||||
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
num_gpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cuda"
|
||||
)
|
||||
self.v_cache_gpu = torch.zeros(
|
||||
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
num_gpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
# ========== Per-layer decode buffer ==========
|
||||
# During decode, all layers share decode_slot (no layer dimension in GPU cache).
|
||||
# This causes accumulated tokens to be overwritten by each layer.
|
||||
# Solution: Maintain separate per-layer buffers for decode tokens.
|
||||
# Shape: [num_layers, block_size, kv_heads, head_dim]
|
||||
# Memory: num_layers * block_size * kv_heads * head_dim * dtype_size
|
||||
# e.g., 28 * 1024 * 8 * 128 * 2 = 58.7 MB (acceptable)
|
||||
self.decode_k_buffer = torch.zeros(
|
||||
num_layers, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cuda"
|
||||
)
|
||||
self.decode_v_buffer = torch.zeros(
|
||||
num_layers, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cuda"
|
||||
)
|
||||
decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
|
||||
logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")
|
||||
|
||||
# ========== Fixed-address CPU KV cache (pinned memory) ==========
|
||||
self.k_cache_cpu = torch.zeros(
|
||||
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
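The CPU pool above is allocated in pinned (page-locked) memory; that is what lets the `non_blocking=True` copies elsewhere in this file overlap with compute instead of silently degrading to synchronous transfers. A minimal allocation sketch with the shapes from this diff (sizes here are small, arbitrary examples):

```python
import torch

num_layers, num_cpu_blocks = 28, 4                # small example sizes
block_size, num_kv_heads, head_dim = 1024, 8, 128

# Pinned host memory is required for truly asynchronous H2D/D2H copies
k_cache_cpu = torch.zeros(
    num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
    dtype=torch.bfloat16, pin_memory=True,
)
assert k_cache_cpu.is_pinned()

gpu_slot = torch.empty(block_size, num_kv_heads, head_dim,
                       dtype=torch.bfloat16, device="cuda")
# Returns immediately on the calling stream; synchronize (or use an event) before reading
gpu_slot.copy_(k_cache_cpu[0, 0], non_blocking=True)
torch.cuda.current_stream().synchronize()
```
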
@@ -159,40 +184,32 @@ class OffloadEngine:
|
||||
# Decode offload event
|
||||
self.decode_offload_done = torch.cuda.Event()
|
||||
|
||||
# ========== Per-slot Per-layer events for ring buffer ==========
|
||||
# ring_slot_ready[slot_idx][layer_id] = CUDA Event for H2D completion
|
||||
# ring_slot_offload_done[slot_idx][layer_id] = CUDA Event for D2H completion
|
||||
self.ring_slot_ready = [
|
||||
[torch.cuda.Event() for _ in range(num_layers)]
|
||||
for _ in range(self.num_ring_slots)
|
||||
]
|
||||
self.ring_slot_offload_done = [
|
||||
[torch.cuda.Event() for _ in range(num_layers)]
|
||||
for _ in range(self.num_ring_slots)
|
||||
]
|
||||
# ========== Per-slot events for ring buffer ==========
|
||||
# Since GPU cache has no layer dimension and layers execute sequentially,
|
||||
# we only need per-slot events (not per-slot per-layer).
|
||||
# ring_slot_ready[slot_idx] = CUDA Event for H2D completion
|
||||
# ring_slot_offload_done[slot_idx] = CUDA Event for D2H completion
|
||||
self.ring_slot_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
|
||||
self.ring_slot_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
|
||||
|
||||
# Per-slot events for all-layer operations (used in some legacy paths)
|
||||
self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
|
||||
self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
|
||||
|
||||
# ========== Per-slot Per-layer compute_done events for async pipeline ==========
|
||||
# ring_slot_compute_done[slot_idx][layer_id] = CUDA Event for compute completion
|
||||
# This is used to ensure we don't overwrite data before it's been read by attention
|
||||
self.ring_slot_compute_done = [
|
||||
[torch.cuda.Event() for _ in range(num_layers)]
|
||||
for _ in range(self.num_ring_slots)
|
||||
]
|
||||
# ========== Per-slot compute_done events for async pipeline ==========
|
||||
# ring_slot_compute_done[slot_idx] = CUDA Event for compute completion
|
||||
# This ensures we don't overwrite data before it's been read by attention
|
||||
self.ring_slot_compute_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
|
||||
|
||||
# Initialize all compute_done events (record them once)
|
||||
# This prevents undefined behavior on first load_to_slot_layer call
|
||||
for slot_idx in range(self.num_ring_slots):
|
||||
for layer_id in range(num_layers):
|
||||
self.ring_slot_compute_done[slot_idx][layer_id].record()
|
||||
self.ring_slot_compute_done[slot_idx].record()
|
||||
torch.cuda.synchronize() # Ensure all events are recorded
|
||||
|
||||
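The per-slot events above form a small handshake: a transfer must not overwrite a slot until the previous compute that read it has finished, and compute must not read a slot until its H2D copy has landed. A generic sketch of that pattern with `torch.cuda.Event` (illustrative, not the engine's own methods):

```python
import torch

transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()
ready = torch.cuda.Event()          # H2D copy finished; slot is readable
compute_done = torch.cuda.Event()   # compute finished; slot may be overwritten
compute_done.record()               # the slot starts out reusable

def load_slot(gpu_slot: torch.Tensor, cpu_block: torch.Tensor) -> None:
    with torch.cuda.stream(transfer_stream):
        transfer_stream.wait_event(compute_done)      # don't clobber data still being read
        gpu_slot.copy_(cpu_block, non_blocking=True)  # cpu_block must be pinned
        ready.record(transfer_stream)

def use_slot(gpu_slot: torch.Tensor, compute_fn) -> None:
    with torch.cuda.stream(compute_stream):
        compute_stream.wait_event(ready)              # wait for the copy to land
        compute_fn(gpu_slot)
        compute_done.record(compute_stream)           # transfer side may now reuse the slot
```
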
# ========== Event tracking for async transfers ==========
|
||||
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
|
||||
|
||||
# ========== Debug hook mode ==========
|
||||
self._debug_mode = False
|
||||
self._debug_hooks: List = [] # External hooks for debug events
|
||||
|
||||
def _get_next_stream(self) -> torch.cuda.Stream:
|
||||
"""Round-robin stream selection for parallel transfers."""
|
||||
stream = self.transfer_streams[self._stream_idx]
|
||||
@@ -200,23 +217,24 @@ class OffloadEngine:
|
||||
return stream
|
||||
|
||||
# ========== CUDA Graph compatible methods ==========
|
||||
# NOTE: These methods need to be updated for the new GPU cache architecture.
|
||||
# GPU cache no longer has layer dimension, so gathered copy semantics change.
|
||||
# For now, these are kept for reference but should not be used without updating.
|
||||
|
||||
def gathered_h2d_layer(self, layer_id: int) -> None:
|
||||
"""
|
||||
Execute gathered H2D copy for a single layer.
|
||||
|
||||
This method is CUDA Graph compatible - can be captured into a graph.
|
||||
Before calling, update_gather_indices() must be called to set up
|
||||
which CPU blocks to copy to which GPU slots.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index to transfer
|
||||
WARNING: This method needs updating for new GPU cache architecture.
|
||||
GPU cache no longer has layer dimension.
|
||||
"""
|
||||
# GPU cache has no layer dimension - use flat indexing
|
||||
# Source is CPU[layer_id], dest is GPU (shared across layers)
|
||||
gathered_copy_kv(
|
||||
k_src=self.k_cache_cpu[layer_id],
|
||||
v_src=self.v_cache_cpu[layer_id],
|
||||
k_dst=self.k_cache_gpu[layer_id],
|
||||
v_dst=self.v_cache_gpu[layer_id],
|
||||
k_dst=self.k_cache_gpu, # No layer indexing
|
||||
v_dst=self.v_cache_gpu, # No layer indexing
|
||||
indices=self.gather_indices_gpu[layer_id],
|
||||
)
|
||||
|
||||
@@ -224,7 +242,8 @@ class OffloadEngine:
|
||||
"""
|
||||
Execute gathered H2D copy for all layers.
|
||||
|
||||
CUDA Graph compatible - can be captured into a single graph.
|
||||
WARNING: In new architecture, GPU slots are shared across layers.
|
||||
This method would overwrite slots multiple times. Not recommended.
|
||||
"""
|
||||
for layer_id in range(self.num_layers):
|
||||
self.gathered_h2d_layer(layer_id)
|
||||
@@ -293,10 +312,10 @@ class OffloadEngine:
|
||||
"""
|
||||
Async prefetch a single block from CPU to GPU.
|
||||
|
||||
For use in prefill phase where CUDA graphs are not used.
|
||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
layer_id: Layer index (for CPU cache)
|
||||
cpu_block_id: Source block in CPU cache
|
||||
gpu_block_id: Destination slot in GPU cache
|
||||
|
||||
@@ -309,13 +328,12 @@ class OffloadEngine:
|
||||
logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
# K cache
|
||||
self.k_cache_gpu[layer_id, gpu_block_id].copy_(
|
||||
# GPU: no layer dimension, CPU: has layer dimension
|
||||
self.k_cache_gpu[gpu_block_id].copy_(
|
||||
self.k_cache_cpu[layer_id, cpu_block_id],
|
||||
non_blocking=True
|
||||
)
|
||||
# V cache
|
||||
self.v_cache_gpu[layer_id, gpu_block_id].copy_(
|
||||
self.v_cache_gpu[gpu_block_id].copy_(
|
||||
self.v_cache_cpu[layer_id, cpu_block_id],
|
||||
non_blocking=True
|
||||
)
|
||||
@@ -352,8 +370,10 @@ class OffloadEngine:
|
||||
"""
|
||||
Async offload a block from GPU to CPU.
|
||||
|
||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
layer_id: Layer index (for CPU cache)
|
||||
gpu_block_id: Source slot in GPU cache
|
||||
cpu_block_id: Destination block in CPU cache
|
||||
|
||||
@@ -369,14 +389,13 @@ class OffloadEngine:
|
||||
# Wait for any compute using this block
|
||||
stream.wait_stream(self.compute_stream)
|
||||
|
||||
# K cache
|
||||
# GPU: no layer dimension, CPU: has layer dimension
|
||||
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
|
||||
self.k_cache_gpu[layer_id, gpu_block_id],
|
||||
self.k_cache_gpu[gpu_block_id],
|
||||
non_blocking=True
|
||||
)
|
||||
# V cache
|
||||
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
|
||||
self.v_cache_gpu[layer_id, gpu_block_id],
|
||||
self.v_cache_gpu[gpu_block_id],
|
||||
non_blocking=True
|
||||
)
|
||||
event.record()
|
||||
@@ -413,11 +432,10 @@ class OffloadEngine:
|
||||
"""
|
||||
Load CPU blocks to specific GPU slots for chunked decode.
|
||||
|
||||
Uses the main GPU KV cache slots, not a separate temp buffer.
|
||||
This is the same mechanism as chunked prefill uses.
|
||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
layer_id: Layer index (for CPU cache)
|
||||
cpu_block_ids: List of CPU block IDs to load
|
||||
gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
|
||||
"""
|
||||
@@ -430,12 +448,12 @@ class OffloadEngine:
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
|
||||
# Copy from pinned CPU memory to GPU KV cache slot
|
||||
self.k_cache_gpu[layer_id, gpu_slot].copy_(
|
||||
# GPU: no layer dimension, CPU: has layer dimension
|
||||
self.k_cache_gpu[gpu_slot].copy_(
|
||||
self.k_cache_cpu[layer_id, cpu_block_id],
|
||||
non_blocking=True
|
||||
)
|
||||
self.v_cache_gpu[layer_id, gpu_slot].copy_(
|
||||
self.v_cache_gpu[gpu_slot].copy_(
|
||||
self.v_cache_cpu[layer_id, cpu_block_id],
|
||||
non_blocking=True
|
||||
)
|
||||
@@ -452,8 +470,10 @@ class OffloadEngine:
|
||||
"""
|
||||
Async version: Load CPU blocks to GPU slots.
|
||||
|
||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
layer_id: Layer index (for CPU cache)
|
||||
cpu_block_ids: List of CPU block IDs to load
|
||||
gpu_slot_ids: List of GPU slot IDs to load into
|
||||
|
||||
@@ -470,11 +490,12 @@ class OffloadEngine:
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
|
||||
self.k_cache_gpu[layer_id, gpu_slot].copy_(
|
||||
# GPU: no layer dimension, CPU: has layer dimension
|
||||
self.k_cache_gpu[gpu_slot].copy_(
|
||||
self.k_cache_cpu[layer_id, cpu_block_id],
|
||||
non_blocking=True
|
||||
)
|
||||
self.v_cache_gpu[layer_id, gpu_slot].copy_(
|
||||
self.v_cache_gpu[gpu_slot].copy_(
|
||||
self.v_cache_cpu[layer_id, cpu_block_id],
|
||||
non_blocking=True
|
||||
)
|
||||
@@ -482,44 +503,8 @@ class OffloadEngine:
|
||||
|
||||
return event
|
||||
|
||||
def load_cpu_blocks_to_gpu_slots_all_layers(
|
||||
self,
|
||||
cpu_block_ids: List[int],
|
||||
gpu_slot_ids: List[int],
|
||||
) -> None:
|
||||
"""
|
||||
Load CPU blocks to GPU slots for ALL layers at once.
|
||||
|
||||
More efficient than per-layer loading when we know the mapping upfront.
|
||||
|
||||
Args:
|
||||
cpu_block_ids: List of CPU block IDs to load
|
||||
gpu_slot_ids: List of GPU slot IDs to load into
|
||||
"""
|
||||
assert len(cpu_block_ids) == len(gpu_slot_ids)
|
||||
|
||||
if cpu_block_ids:
|
||||
logger.debug(f"H2D all layers: CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
|
||||
|
||||
stream = self._get_next_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
|
||||
# Copy all layers at once using sgDMA
|
||||
memcpy_2d_async(
|
||||
self.k_cache_gpu[:, gpu_slot],
|
||||
self.k_cache_cpu[:, cpu_block_id],
|
||||
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
|
||||
"h2d", stream=stream
|
||||
)
|
||||
memcpy_2d_async(
|
||||
self.v_cache_gpu[:, gpu_slot],
|
||||
self.v_cache_cpu[:, cpu_block_id],
|
||||
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
|
||||
"h2d", stream=stream
|
||||
)
|
||||
|
||||
stream.synchronize()
|
||||
# NOTE: load_cpu_blocks_to_gpu_slots_all_layers removed - GPU cache no longer has
|
||||
# layer dimension. Each GPU slot holds data for ONE layer at a time.
|
||||
|
||||
# ========== Synchronization methods ==========
|
||||
|
||||
@@ -538,27 +523,33 @@ class OffloadEngine:
|
||||
|
||||
def sync_indices(self) -> None:
|
||||
"""Synchronize to ensure all index updates are complete."""
|
||||
torch.cuda.current_stream().synchronize()
|
||||
torch.cuda.default_stream().synchronize()
|
||||
|
||||
# ========== Cache access methods ==========
|
||||
|
||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Get GPU K/V cache tensors for a specific layer.
|
||||
Get GPU K/V cache tensors for attention layer.
|
||||
|
||||
NOTE: GPU cache has no layer dimension - all layers share the same slots.
|
||||
The layer_id parameter is kept for API compatibility but not used.
|
||||
|
||||
Returns:
|
||||
(k_cache, v_cache) tensors for the layer
|
||||
(k_cache, v_cache) tensors
|
||||
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
|
||||
"""
|
||||
return self.k_cache_gpu[layer_id], self.v_cache_gpu[layer_id]
|
||||
# GPU cache is shared across all layers (no layer dimension)
|
||||
return self.k_cache_gpu, self.v_cache_gpu
|
||||
|
||||
def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Get full GPU K/V cache tensors.
|
||||
|
||||
NOTE: GPU cache has no layer dimension in the new architecture.
|
||||
|
||||
Returns:
|
||||
(k_cache, v_cache) tensors
|
||||
Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
|
||||
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
|
||||
"""
|
||||
return self.k_cache_gpu, self.v_cache_gpu
|
||||
|
||||
@@ -664,7 +655,7 @@ class OffloadEngine:
|
||||
|
||||
# ----- Per-slot Per-layer loading methods -----
|
||||
|
||||
def record_slot_compute_done(self, slot_idx: int, layer_id: int) -> None:
|
||||
def record_slot_compute_done(self, slot_idx: int) -> None:
|
||||
"""
|
||||
Record that computation using this slot's data is done.
|
||||
|
||||
@@ -673,21 +664,23 @@ class OffloadEngine:
|
||||
|
||||
Args:
|
||||
slot_idx: GPU slot index that was just used for computation
|
||||
layer_id: Layer index
|
||||
"""
|
||||
self.ring_slot_compute_done[slot_idx][layer_id].record()
|
||||
self.ring_slot_compute_done[slot_idx].record()
|
||||
|
||||
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
||||
"""
|
||||
Async load a single CPU block to a ring buffer slot for one layer.
|
||||
|
||||
This is the core building block for ring buffer pipelining.
|
||||
Before starting the transfer, waits for any previous compute on this slot
|
||||
to complete (using compute_done event).
|
||||
GPU cache has no layer dimension - slots are shared across all layers.
|
||||
CPU cache still has layer dimension for persistent storage.
|
||||
|
||||
Before starting the transfer, waits for:
|
||||
1. Any previous compute on this slot to complete
|
||||
|
||||
Args:
|
||||
slot_idx: Target GPU slot index
|
||||
layer_id: Layer index to load
|
||||
layer_id: Layer index to load (for CPU cache indexing)
|
||||
cpu_block_id: Source CPU block ID
|
||||
"""
|
||||
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
|
||||
@@ -699,140 +692,105 @@ class OffloadEngine:
|
||||
with torch.cuda.stream(stream):
|
||||
# Wait for previous compute on this slot to complete before overwriting
|
||||
# This prevents data race: transfer must not start until attention finishes reading
|
||||
stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
|
||||
stream.wait_event(self.ring_slot_compute_done[slot_idx])
|
||||
|
||||
self.k_cache_gpu[layer_id, slot_idx].copy_(
|
||||
# Also wait for any pending offload of this slot to complete
|
||||
# This prevents race: load must not write GPU slot while offload is reading from it
|
||||
stream.wait_event(self.ring_slot_offload_done[slot_idx])
|
||||
|
||||
# GPU: no layer dimension, CPU: has layer dimension
|
||||
self.k_cache_gpu[slot_idx].copy_(
|
||||
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
|
||||
)
|
||||
self.v_cache_gpu[layer_id, slot_idx].copy_(
|
||||
self.v_cache_gpu[slot_idx].copy_(
|
||||
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
|
||||
)
|
||||
self.ring_slot_ready[slot_idx][layer_id].record(stream)
|
||||
self.ring_slot_ready[slot_idx].record(stream)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
|
||||
def wait_slot_layer(self, slot_idx: int) -> None:
|
||||
"""
|
||||
Wait for a slot's loading to complete for a specific layer.
|
||||
Wait for a slot's loading to complete.
|
||||
|
||||
Args:
|
||||
slot_idx: GPU slot index to wait for
|
||||
layer_id: Layer index to wait for
|
||||
"""
|
||||
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx][layer_id])
|
||||
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])
|
||||
|
||||
def load_to_slot_all_layers(self, slot_idx: int, cpu_block_id: int) -> None:
|
||||
"""
|
||||
Async load a CPU block to a ring buffer slot for ALL layers.
|
||||
|
||||
Args:
|
||||
slot_idx: Target GPU slot index
|
||||
cpu_block_id: Source CPU block ID
|
||||
"""
|
||||
logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
|
||||
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
memcpy_2d_async(
|
||||
self.k_cache_gpu[:, slot_idx],
|
||||
self.k_cache_cpu[:, cpu_block_id],
|
||||
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
|
||||
"h2d", stream=self.transfer_stream_main
|
||||
)
|
||||
memcpy_2d_async(
|
||||
self.v_cache_gpu[:, slot_idx],
|
||||
self.v_cache_cpu[:, cpu_block_id],
|
||||
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
|
||||
"h2d", stream=self.transfer_stream_main
|
||||
)
|
||||
self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main)
|
||||
|
||||
def wait_slot_all_layers(self, slot_idx: int) -> None:
|
||||
"""Wait for a slot's loading to complete for ALL layers."""
|
||||
self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx])
|
||||
# NOTE: load_to_slot_all_layers removed - GPU cache no longer has layer dimension.
|
||||
# Each GPU slot holds data for ONE layer at a time. Layers execute sequentially,
|
||||
# reusing the same GPU slots.
|
||||
|
||||
# ----- Slot offload methods -----
|
||||
|
||||
def offload_slot_to_cpu(self, slot_idx: int, cpu_block_id: int) -> None:
|
||||
"""
|
||||
Async offload a ring buffer slot to CPU (all layers).
|
||||
|
||||
Args:
|
||||
slot_idx: Source GPU slot index
|
||||
cpu_block_id: Target CPU block ID
|
||||
"""
|
||||
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")
|
||||
|
||||
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
self.transfer_stream_main.wait_stream(self.compute_stream)
|
||||
memcpy_2d_async(
|
||||
self.k_cache_cpu[:, cpu_block_id],
|
||||
self.k_cache_gpu[:, slot_idx],
|
||||
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
|
||||
"d2h", stream=self.transfer_stream_main
|
||||
)
|
||||
memcpy_2d_async(
|
||||
self.v_cache_cpu[:, cpu_block_id],
|
||||
self.v_cache_gpu[:, slot_idx],
|
||||
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
|
||||
"d2h", stream=self.transfer_stream_main
|
||||
)
|
||||
self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
# NOTE: offload_slot_to_cpu (all-layers) removed - GPU cache no longer has layer dimension.
|
||||
# Use offload_slot_layer_to_cpu for per-layer offloading.
|
||||
|
||||
def wait_slot_offload(self, slot_idx: int) -> None:
|
||||
"""Wait for slot offload to complete."""
|
||||
self.compute_stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
|
||||
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
|
||||
|
||||
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
||||
"""
|
||||
Async offload a ring buffer slot to CPU for one layer.
|
||||
|
||||
GPU cache has no layer dimension, so we copy from GPU slot to the
|
||||
specific layer in CPU cache.
|
||||
|
||||
Args:
|
||||
slot_idx: Source GPU slot index
|
||||
layer_id: Layer index to offload
|
||||
layer_id: Target layer in CPU cache
|
||||
cpu_block_id: Target CPU block ID
|
||||
"""
|
||||
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
|
||||
|
||||
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
# Wait for both compute_stream and default stream
|
||||
# - compute_stream: for flash attention operations
|
||||
# - default_stream: for store_kvcache which runs on default stream
|
||||
self.transfer_stream_main.wait_stream(self.compute_stream)
|
||||
self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
|
||||
|
||||
# GPU: no layer dimension, CPU: has layer dimension
|
||||
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
|
||||
self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
|
||||
self.k_cache_gpu[slot_idx], non_blocking=True
|
||||
)
|
||||
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
|
||||
self.v_cache_gpu[layer_id, slot_idx], non_blocking=True
|
||||
self.v_cache_gpu[slot_idx], non_blocking=True
|
||||
)
|
||||
self.ring_slot_offload_done[slot_idx][layer_id].record(self.transfer_stream_main)
|
||||
|
||||
def wait_slot_offload_layer(self, slot_idx: int, layer_id: int) -> None:
|
||||
"""Wait for slot offload to complete for a specific layer."""
|
||||
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx][layer_id])
|
||||
self.ring_slot_offload_done[slot_idx].record(self.transfer_stream_main)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
# ----- KV access methods for ring buffer -----
|
||||
|
||||
def get_kv_for_slot(self, slot_idx: int, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
def get_kv_for_slot(self, slot_idx: int) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Get KV for a single ring buffer slot.
|
||||
|
||||
GPU cache has no layer dimension - slots contain data for whatever
|
||||
layer was most recently loaded.
|
||||
|
||||
Args:
|
||||
slot_idx: GPU slot index
|
||||
layer_id: Layer ID
|
||||
|
||||
Returns:
|
||||
(k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim]
|
||||
"""
|
||||
k = self.k_cache_gpu[layer_id, slot_idx].unsqueeze(0) # [1, block_size, heads, dim]
|
||||
v = self.v_cache_gpu[layer_id, slot_idx].unsqueeze(0)
|
||||
k = self.k_cache_gpu[slot_idx].unsqueeze(0) # [1, block_size, heads, dim]
|
||||
v = self.v_cache_gpu[slot_idx].unsqueeze(0)
|
||||
return k, v
|
||||
|
||||
def get_kv_for_slots(
|
||||
self,
|
||||
layer_id: int,
|
||||
slot_indices: List[int],
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Get KV for multiple ring buffer slots.
|
||||
|
||||
GPU cache has no layer dimension - returns data from specified slots.
|
||||
|
||||
Args:
|
||||
layer_id: Layer ID
|
||||
slot_indices: List of GPU slot indices
|
||||
|
||||
Returns:
|
||||
@@ -840,92 +798,86 @@ class OffloadEngine:
|
||||
"""
|
||||
if not slot_indices:
|
||||
return None, None
|
||||
k = self.k_cache_gpu[layer_id, slot_indices]
|
||||
v = self.v_cache_gpu[layer_id, slot_indices]
|
||||
k = self.k_cache_gpu[slot_indices]
|
||||
v = self.v_cache_gpu[slot_indices]
|
||||
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
|
||||
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
|
||||
return k, v
|
||||
|
||||
# ----- Decode slot methods (kept for decode phase) -----
|
||||
# NOTE: For decode with CPU offload, the flow is per-layer:
|
||||
# 1. Each layer stores to decode_slot (same GPU memory, reused)
|
||||
# 2. Each layer offloads its data to CPU[layer_id, block_id]
|
||||
# 3. Each layer loads prev blocks from CPU[layer_id] when needed
|
||||
|
||||
def offload_decode_slot(self, cpu_block_id: int) -> None:
|
||||
def offload_decode_slot_layer(self, layer_id: int, cpu_block_id: int) -> None:
|
||||
"""
|
||||
Offload KV from decode slot (slot[0]) to CPU.
|
||||
Offload KV from decode slot (slot[0]) to CPU for one layer.
|
||||
|
||||
Args:
|
||||
layer_id: Layer ID
|
||||
cpu_block_id: Target CPU block ID
|
||||
"""
|
||||
logger.debug(f"Decode offload: GPU slot[{self.decode_slot}] -> CPU[{cpu_block_id}]")
|
||||
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
self.transfer_stream_main.wait_stream(self.compute_stream)
|
||||
memcpy_2d_async(
|
||||
self.k_cache_cpu[:, cpu_block_id],
|
||||
self.k_cache_gpu[:, self.decode_slot],
|
||||
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
|
||||
"d2h", stream=self.transfer_stream_main
|
||||
)
|
||||
memcpy_2d_async(
|
||||
self.v_cache_cpu[:, cpu_block_id],
|
||||
self.v_cache_gpu[:, self.decode_slot],
|
||||
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
|
||||
"d2h", stream=self.transfer_stream_main
|
||||
)
|
||||
self.decode_offload_done.record(self.transfer_stream_main)
|
||||
# Reuse the existing per-layer offload method
|
||||
self.offload_slot_layer_to_cpu(self.decode_slot, layer_id, cpu_block_id)
|
||||
|
||||
def wait_decode_offload(self) -> None:
|
||||
"""Wait for decode slot offload to complete."""
|
||||
self.compute_stream.wait_event(self.decode_offload_done)
|
||||
self.wait_slot_offload(self.decode_slot)
|
||||
|
||||
def get_kv_for_decode_slot(
|
||||
self,
|
||||
layer_id: int,
|
||||
pos_in_block: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Get KV at specified position in decode slot.
|
||||
|
||||
GPU cache has no layer dimension - decode slot contains data for
|
||||
whatever layer was most recently stored.
|
||||
|
||||
Args:
|
||||
layer_id: Layer ID
|
||||
pos_in_block: Token position within block (0 to block_size-1)
|
||||
|
||||
Returns:
|
||||
(k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
|
||||
"""
|
||||
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
|
||||
v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
|
||||
k = self.k_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1]
|
||||
v = self.v_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1]
|
||||
k = k.unsqueeze(0)
|
||||
v = v.unsqueeze(0)
|
||||
return k, v
|
||||
|
||||
def get_kv_for_decode_slot_accumulated(
|
||||
self,
|
||||
layer_id: int,
|
||||
num_tokens: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Get accumulated KV in decode slot (positions 0 to num_tokens-1).
|
||||
|
||||
GPU cache has no layer dimension - decode slot contains data for
|
||||
whatever layer was most recently stored.
|
||||
|
||||
Args:
|
||||
layer_id: Layer ID
|
||||
num_tokens: Number of accumulated tokens (1 to block_size)
|
||||
|
||||
Returns:
|
||||
(k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
|
||||
"""
|
||||
k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]
|
||||
v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
|
||||
k = self.k_cache_gpu[self.decode_slot, :num_tokens]
|
||||
v = self.v_cache_gpu[self.decode_slot, :num_tokens]
|
||||
k = k.unsqueeze(0)
|
||||
v = v.unsqueeze(0)
|
||||
return k, v
|
||||
|
||||
# ----- Legacy compatibility methods (for decode double-buffering) -----
|
||||
# NOTE: GPU cache has no layer dimension. Layer ID is used for CPU cache indexing only.
|
||||
|
||||
def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
|
||||
"""
|
||||
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
|
||||
|
||||
Uses first half of decode_load_slots as 'compute' region.
|
||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
||||
"""
|
||||
if not cpu_block_ids:
|
||||
return
|
||||
@@ -938,26 +890,27 @@ class OffloadEngine:
|
||||
for i in range(num_to_load):
|
||||
cpu_id = cpu_block_ids[i]
|
||||
gpu_slot = slots[i]
|
||||
self.k_cache_gpu[layer_id, gpu_slot].copy_(
|
||||
# GPU: no layer dimension, CPU: has layer dimension
|
||||
self.k_cache_gpu[gpu_slot].copy_(
|
||||
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
|
||||
)
|
||||
self.v_cache_gpu[layer_id, gpu_slot].copy_(
|
||||
self.v_cache_gpu[gpu_slot].copy_(
|
||||
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
|
||||
)
|
||||
if num_to_load > 0:
|
||||
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
|
||||
self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
|
||||
|
||||
def wait_compute_layer(self, layer_id: int) -> None:
|
||||
def wait_compute_layer(self) -> None:
|
||||
"""Legacy: Wait for 'compute' region loading."""
|
||||
half = max(1, len(self.decode_load_slots) // 2)
|
||||
if self.decode_load_slots:
|
||||
self.wait_slot_layer(self.decode_load_slots[0], layer_id)
|
||||
self.wait_slot_layer(self.decode_load_slots[0])
|
||||
|
||||
def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
|
||||
"""
|
||||
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
|
||||
|
||||
Uses second half of decode_load_slots as 'prefetch' region.
|
||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
||||
"""
|
||||
if not cpu_block_ids:
|
||||
return
|
||||
@@ -972,37 +925,36 @@ class OffloadEngine:
|
||||
for i in range(num_to_load):
|
||||
cpu_id = cpu_block_ids[i]
|
||||
gpu_slot = slots[i]
|
||||
self.k_cache_gpu[layer_id, gpu_slot].copy_(
|
||||
# GPU: no layer dimension, CPU: has layer dimension
|
||||
self.k_cache_gpu[gpu_slot].copy_(
|
||||
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
|
||||
)
|
||||
self.v_cache_gpu[layer_id, gpu_slot].copy_(
|
||||
self.v_cache_gpu[gpu_slot].copy_(
|
||||
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
|
||||
)
|
||||
if num_to_load > 0:
|
||||
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
|
||||
self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
|
||||
|
||||
def wait_prefetch_layer(self, layer_id: int) -> None:
|
||||
def wait_prefetch_layer(self) -> None:
|
||||
"""Legacy: Wait for 'prefetch' region loading."""
|
||||
half = max(1, len(self.decode_load_slots) // 2)
|
||||
slots = self.decode_load_slots[half:]
|
||||
if slots:
|
||||
self.wait_slot_layer(slots[0], layer_id)
|
||||
self.wait_slot_layer(slots[0])
|
||||
elif self.decode_load_slots:
|
||||
self.wait_slot_layer(self.decode_load_slots[0], layer_id)
|
||||
self.wait_slot_layer(self.decode_load_slots[0])
|
||||
|
||||
def get_kv_for_compute(
|
||||
self,
|
||||
layer_id: int,
|
||||
num_blocks: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
|
||||
half = max(1, len(self.decode_load_slots) // 2)
|
||||
slots = self.decode_load_slots[:half][:num_blocks]
|
||||
return self.get_kv_for_slots(layer_id, slots)
|
||||
return self.get_kv_for_slots(slots)
|
||||
|
||||
def get_kv_for_prefetch(
|
||||
self,
|
||||
layer_id: int,
|
||||
num_blocks: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
|
||||
@@ -1011,4 +963,76 @@ class OffloadEngine:
|
||||
if not slots:
|
||||
slots = self.decode_load_slots
|
||||
slots = slots[:num_blocks]
|
||||
return self.get_kv_for_slots(layer_id, slots)
|
||||
return self.get_kv_for_slots(slots)
|
||||
|
||||
# ========== Debug Hook Interface ==========
|
||||
#
|
||||
# Minimal generic hook system for debugging.
|
||||
# Framework only provides hook registration and tensor access.
|
||||
# All verification logic is external.
|
||||
|
||||
def enable_debug_mode(self) -> None:
|
||||
"""Enable debug mode."""
|
||||
self._debug_mode = True
|
||||
logger.info("OffloadEngine debug mode ENABLED")
|
||||
|
||||
def disable_debug_mode(self) -> None:
|
||||
"""Disable debug mode and clear all hooks."""
|
||||
self._debug_mode = False
|
||||
self._debug_hooks.clear()
|
||||
logger.info("OffloadEngine debug mode DISABLED")
|
||||
|
||||
@property
|
||||
def debug_mode(self) -> bool:
|
||||
"""Check if debug mode is enabled."""
|
||||
return self._debug_mode
|
||||
|
||||
def register_debug_hook(self, hook_fn) -> None:
|
||||
"""
|
||||
Register a debug hook.
|
||||
|
||||
The hook is called after H2D load completes (after wait_slot_layer),
|
||||
receiving the loaded tensor for inspection.
|
||||
|
||||
Args:
|
||||
hook_fn: Callable with signature:
|
||||
(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None
|
||||
- k, v: GPU tensor views for the loaded slot
|
||||
|
||||
Example:
|
||||
def my_hook(slot_idx, layer_id, cpu_block_id, k, v):
|
||||
if layer_id == 0:
|
||||
k_val = k.float().mean().item()
|
||||
print(f"Loaded block {cpu_block_id}, K mean = {k_val}")
|
||||
|
||||
offload_engine.register_debug_hook(my_hook)
|
||||
"""
|
||||
self._debug_hooks.append(hook_fn)
|
||||
|
||||
def remove_debug_hook(self, hook_fn) -> None:
|
||||
"""Remove a registered debug hook."""
|
||||
if hook_fn in self._debug_hooks:
|
||||
self._debug_hooks.remove(hook_fn)
|
||||
|
||||
def _call_debug_hooks(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
||||
"""
|
||||
Call all registered debug hooks with loaded tensor (internal use).
|
||||
|
||||
Called by attention.py after wait_slot_layer completes.
|
||||
GPU cache has no layer dimension - slot contains data for the layer
|
||||
that was just loaded.
|
||||
"""
|
||||
if not self._debug_mode or not self._debug_hooks:
|
||||
return
|
||||
|
||||
# Use get_kv_for_slot for consistency with attention.py
|
||||
k, v = self.get_kv_for_slot(slot_idx)
|
||||
|
||||
for hook in self._debug_hooks:
|
||||
try:
|
||||
hook(slot_idx, layer_id, cpu_block_id, k, v)
|
||||
except Exception as e:
|
||||
# Allow pdb quit to propagate
|
||||
if e.__class__.__name__ == 'BdbQuit':
|
||||
raise
|
||||
logger.warning(f"Debug hook error: {e}")
|
||||
@@ -87,6 +87,15 @@ class Attention(nn.Module):
|
||||
else: # decode
|
||||
if context.is_chunked_prefill:
|
||||
# Chunked decode: need to load all KV from CPU+GPU
|
||||
# Store current decode token to per-layer decode buffer
|
||||
# This is needed because GPU cache has no layer dimension,
|
||||
# so all layers would overwrite each other in decode_slot.
|
||||
kvcache_manager = context.kvcache_manager
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
pos_in_block = context.decode_pos_in_block
|
||||
# k, v shape: [1, kv_heads, head_dim]
|
||||
offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
|
||||
offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
|
||||
o = self._chunked_decode_attention(q, k, v, context)
|
||||
else:
|
||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
||||
@@ -169,9 +178,11 @@ class Attention(nn.Module):
|
||||
else:
|
||||
# Use ring buffer pipeline
|
||||
o_acc, lse_acc = self._ring_buffer_pipeline_load(
|
||||
q_batched, cpu_block_table, load_slots, offload_engine
|
||||
q_batched, cpu_block_table, load_slots, offload_engine,
|
||||
current_chunk_idx
|
||||
)
|
||||
|
||||
|
||||
# Compute attention against current chunk's KV (with causal mask)
|
||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
@@ -187,11 +198,30 @@ class Attention(nn.Module):
|
||||
if o_acc is None:
|
||||
final_o = current_o
|
||||
else:
|
||||
# IMPORTANT: o_acc was computed on compute_stream. We need to sync before
|
||||
# reading it on the default stream for the merge operation.
|
||||
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
|
||||
|
||||
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
||||
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
||||
|
||||
# Per-layer offload: In new GPU cache architecture (no layer dimension),
|
||||
# each layer must offload its KV to CPU before next layer overwrites the GPU slot.
|
||||
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
|
||||
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
|
||||
if seq is not None:
|
||||
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
||||
if current_chunk_idx < len(cpu_block_ids):
|
||||
cpu_block_id = cpu_block_ids[current_chunk_idx]
|
||||
offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
|
||||
|
||||
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
|
||||
return final_o.squeeze(0)
|
||||
|
||||
@@ -205,13 +235,16 @@ class Attention(nn.Module):
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
o_acc, lse_acc = None, None
|
||||
compute_stream = offload_engine.compute_stream
|
||||
|
||||
for block_idx, cpu_block_id in enumerate(cpu_block_table):
|
||||
# Load to slot 0 (single slot)
|
||||
offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
|
||||
offload_engine.wait_slot_layer(0, self.layer_id)
|
||||
offload_engine.wait_slot_layer(0)
|
||||
|
||||
prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
|
||||
# IMPORTANT: Must use compute_stream to match wait_slot_layer
|
||||
with torch.cuda.stream(compute_stream):
|
||||
prev_k, prev_v = offload_engine.get_kv_for_slot(0)
|
||||
|
||||
prev_o, prev_lse = flash_attn_with_lse(
|
||||
q_batched, prev_k, prev_v,
|
||||
@@ -232,6 +265,7 @@ class Attention(nn.Module):
|
||||
cpu_block_table: list,
|
||||
load_slots: list,
|
||||
offload_engine,
|
||||
current_chunk_idx: int = -1,
|
||||
):
|
||||
"""
|
||||
Ring buffer async pipeline loading with double buffering.
|
||||
@@ -269,18 +303,28 @@ class Attention(nn.Module):
|
||||
|
||||
if pipeline_depth == 1:
|
||||
# Only 1 slot available, cannot pipeline - use synchronous mode
|
||||
# IMPORTANT: Must use compute_stream to match synchronization in
|
||||
# load_to_slot_layer (waits for compute_done) and wait_slot_layer
|
||||
slot = load_slots[0]
|
||||
compute_stream = offload_engine.compute_stream
|
||||
for block_idx in range(num_blocks):
|
||||
offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
|
||||
offload_engine.wait_slot_layer(slot, self.layer_id)
|
||||
prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
|
||||
cpu_block_id = cpu_block_table[block_idx]
|
||||
offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
|
||||
offload_engine.wait_slot_layer(slot)
|
||||
|
||||
with torch.cuda.stream(compute_stream):
|
||||
# Debug: call hooks on compute_stream (synchronized with transfer)
|
||||
if offload_engine.debug_mode:
|
||||
offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
|
||||
|
||||
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
|
||||
prev_o, prev_lse = flash_attn_with_lse(
|
||||
q_batched, prev_k, prev_v,
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
)
|
||||
# Record compute done so next load can safely reuse this slot
|
||||
offload_engine.record_slot_compute_done(slot, self.layer_id)
|
||||
offload_engine.record_slot_compute_done(slot)
|
||||
if o_acc is None:
|
||||
o_acc, lse_acc = prev_o, prev_lse
|
||||
else:
|
||||
@@ -306,15 +350,20 @@ class Attention(nn.Module):
|
||||
|
||||
# Cycle through slots: slot[block_idx % num_slots]
|
||||
current_slot = load_slots[block_idx % num_slots]
|
||||
cpu_block_id = cpu_block_table[block_idx]
|
||||
|
||||
# Wait for current slot's transfer to complete (on compute_stream)
|
||||
offload_engine.wait_slot_layer(current_slot, self.layer_id)
|
||||
offload_engine.wait_slot_layer(current_slot)
|
||||
|
||||
# Compute attention on current slot's data
|
||||
# IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
|
||||
with torch.cuda.stream(compute_stream):
|
||||
# Debug: call hooks on compute_stream (synchronized with transfer)
|
||||
if offload_engine.debug_mode:
|
||||
offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)
|
||||
|
||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
|
||||
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
|
||||
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
|
||||
prev_o, prev_lse = flash_attn_with_lse(
|
||||
q_batched, prev_k, prev_v,
|
||||
softmax_scale=self.scale,
|
||||
@@ -323,7 +372,7 @@ class Attention(nn.Module):
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
# Record compute done - this allows the next transfer to safely overwrite this slot
|
||||
offload_engine.record_slot_compute_done(current_slot, self.layer_id)
|
||||
offload_engine.record_slot_compute_done(current_slot)
|
||||
|
||||
# Immediately start loading the NEXT block into this slot (if more blocks remain)
|
||||
# Key insight: reuse current_slot immediately after compute is done!
|
||||
@@ -350,25 +399,17 @@ class Attention(nn.Module):
|
||||
context,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute decode attention with double-buffering using decode_load_slots.
|
||||
Compute decode attention using ring buffer pipeline (same as prefill).
|
||||
|
||||
Decode uses:
|
||||
- decode_slot (slot[0]): writes new token's KV
|
||||
- decode_load_slots (slots[1:]): load previous chunks from CPU
|
||||
Uses the same loading mechanism as _chunked_prefill_attention:
|
||||
- Load one block at a time from CPU to GPU slot
|
||||
- Compute attention for each block
|
||||
- Merge results using online softmax
|
||||
- Finally merge with decode buffer (accumulated decode tokens)
|
||||
|
||||
Pipeline design:
|
||||
- First half of decode_load_slots: 'compute' buffer
|
||||
- Second half: 'prefetch' buffer
|
||||
- Double-buffer between them for async overlap
|
||||
|
||||
Timeline:
|
||||
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
|
||||
│Load C0→buf0 │ │Load C1→buf1 │ │Load C2→buf0 │ ...
|
||||
└─────────────┘ └─────────────┘ └─────────────┘
|
||||
↘ ↘ ↘
|
||||
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
|
||||
│ Attn(C0) │ │ Attn(C1) │ │ Attn(C2) │
|
||||
└─────────────┘ └─────────────┘ └─────────────┘
|
||||
This approach is simpler and proven correct (prefill tests pass).
|
||||
The only difference from prefill is the additional decode buffer
|
||||
that stores new tokens generated during decode.
|
||||
"""
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
@@ -378,12 +419,20 @@ class Attention(nn.Module):
|
||||
kvcache_manager = context.kvcache_manager
|
||||
seq = context.chunked_seq
|
||||
|
||||
# Get all CPU blocks for this sequence
|
||||
cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
||||
# Get only PREFILLED CPU blocks (exclude the current decode block)
|
||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
if self.layer_id == 0:
|
||||
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
|
||||
if not cpu_block_table:
|
||||
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
|
||||
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
||||
|
||||
# Calculate valid tokens in the last block
|
||||
# Note: For chunked prefill, each block is exactly block_size tokens
|
||||
# The cpu_block_table only contains full prefill blocks
|
||||
block_size = kvcache_manager.block_size
|
||||
num_prefill_blocks = len(cpu_block_table)
|
||||
# All prefill blocks are full (block_size tokens each)
|
||||
last_block_valid_tokens = block_size
|
||||
|
||||
# Apply sparse policy if enabled
|
||||
if kvcache_manager.sparse_policy is not None:
|
||||
@@ -391,7 +440,7 @@ class Attention(nn.Module):
|
||||
query_chunk_idx=0,
|
||||
num_query_chunks=1,
|
||||
layer_id=self.layer_id,
|
||||
query=q_batched, # Decode provides query for query-aware selection
|
||||
query=q_batched,
|
||||
is_prefill=False,
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
@@ -401,80 +450,28 @@ class Attention(nn.Module):
|
||||
)
|
||||
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
load_slots = offload_engine.decode_load_slots # Available slots for loading
|
||||
|
||||
# Chunk size = capacity of each double buffer region (compute/prefetch)
|
||||
# Each region uses half of decode_load_slots
|
||||
chunk_size = max(1, len(offload_engine.decode_load_slots) // 2)
|
||||
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
|
||||
|
||||
o_acc = None
|
||||
lse_acc = None
|
||||
|
||||
# Double buffering state: True = use Compute region, False = use Prefetch region
|
||||
use_compute = True
|
||||
|
||||
# Pre-load first chunk to Compute region (async)
|
||||
first_chunk_ids = cpu_block_table[:min(chunk_size, len(cpu_block_table))]
|
||||
offload_engine.load_to_compute_layer(self.layer_id, first_chunk_ids)
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
start = chunk_idx * chunk_size
|
||||
end = min(start + chunk_size, len(cpu_block_table))
|
||||
num_blocks_in_chunk = end - start
|
||||
|
||||
# Wait for current buffer to be ready
|
||||
if use_compute:
|
||||
offload_engine.wait_compute_layer(self.layer_id)
|
||||
else:
|
||||
offload_engine.wait_prefetch_layer(self.layer_id)
|
||||
|
||||
# Trigger async prefetch of next chunk to the OTHER buffer
|
||||
# This overlaps transfer with current chunk's computation
|
||||
if chunk_idx + 1 < num_chunks:
|
||||
next_start = end
|
||||
next_end = min(next_start + chunk_size, len(cpu_block_table))
|
||||
next_chunk_ids = cpu_block_table[next_start:next_end]
|
||||
if use_compute:
|
||||
# Current in Compute, prefetch next to Prefetch region
|
||||
offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
|
||||
else:
|
||||
# Current in Prefetch, prefetch next to Compute region
|
||||
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
|
||||
|
||||
# Get KV from current buffer
|
||||
if use_compute:
|
||||
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
else:
|
||||
k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
# Use ring buffer pipeline (same as prefill) to load prefilled blocks
|
||||
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
||||
q_batched, cpu_block_table, load_slots, offload_engine,
|
||||
block_size, last_block_valid_tokens
|
||||
)
|
||||
|
||||
# Compute attention for this chunk
|
||||
o_chunk, lse_chunk = flash_attn_with_lse(
|
||||
q_batched, k_chunk, v_chunk,
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
)
|
||||
|
||||
# Merge with accumulated
|
||||
if o_acc is None:
|
||||
o_acc, lse_acc = o_chunk, lse_chunk
|
||||
else:
|
||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
|
||||
|
||||
# Swap buffers for next iteration
|
||||
use_compute = not use_compute
|
||||
|
||||
# Now attend to Decode region (contains accumulated decode tokens)
|
||||
# Now attend to accumulated decode tokens from per-layer decode buffer
|
||||
pos_in_block = context.decode_pos_in_block
|
||||
start_pos = context.decode_start_pos_in_block
|
||||
num_accumulated = pos_in_block - start_pos + 1
|
||||
|
||||
# Sync compute_stream with default stream before reading decode_buffer
|
||||
compute_stream = offload_engine.compute_stream
|
||||
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||
|
||||
with torch.cuda.stream(compute_stream):
|
||||
if num_accumulated > 0:
|
||||
decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
|
||||
decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
|
||||
# Read from per-layer decode buffer
|
||||
decode_k = offload_engine.decode_k_buffer[self.layer_id, start_pos:pos_in_block+1]
|
||||
decode_v = offload_engine.decode_v_buffer[self.layer_id, start_pos:pos_in_block+1]
|
||||
decode_k = decode_k.unsqueeze(0)
|
||||
decode_v = decode_v.unsqueeze(0)
|
||||
|
||||
@@ -492,4 +489,83 @@ class Attention(nn.Module):
        if o_acc is None:
            raise RuntimeError("Chunked decode attention failed: no KV available")

        # Sync back to default stream before returning
        torch.cuda.default_stream().wait_stream(compute_stream)

        return o_acc

    def _decode_ring_buffer_pipeline(
        self,
        q_batched: torch.Tensor,
        cpu_block_table: list,
        load_slots: list,
        offload_engine,
        block_size: int,
        last_block_valid_tokens: int,
    ):
        """
        Ring buffer pipeline for decode prefill loading (same mechanism as prefill).

        Loads one block at a time, computes attention, and merges results.
        Uses the same load_to_slot_layer / wait_slot_layer / get_kv_for_slot
        methods as prefill for proven correctness.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        num_blocks = len(cpu_block_table)
        if num_blocks == 0:
            return None, None

        if not load_slots:
            return None, None

        o_acc, lse_acc = None, None
        num_slots = len(load_slots)
        compute_stream = offload_engine.compute_stream

        # Phase 1: Pre-load up to num_slots blocks
        num_preload = min(num_slots, num_blocks)
        for i in range(num_preload):
            offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])

        # Phase 2: Process blocks with pipeline
        for block_idx in range(num_blocks):
            current_slot = load_slots[block_idx % num_slots]
            cpu_block_id = cpu_block_table[block_idx]

            # Wait for current slot's transfer to complete
            offload_engine.wait_slot_layer(current_slot)

            with torch.cuda.stream(compute_stream):
                # Get KV from slot
                prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)

                # Handle partial last block
                is_last_block = (block_idx == num_blocks - 1)
                if is_last_block and last_block_valid_tokens < block_size:
                    prev_k = prev_k[:, :last_block_valid_tokens, :, :]
                    prev_v = prev_v[:, :last_block_valid_tokens, :, :]

                # Compute attention
                prev_o, prev_lse = flash_attn_with_lse(
                    q_batched, prev_k, prev_v,
                    softmax_scale=self.scale,
                    causal=False,
                )

                # Record compute done for slot reuse
                offload_engine.record_slot_compute_done(current_slot)

            # Start loading next block (pipeline)
            next_block_idx = block_idx + num_slots
            if next_block_idx < num_blocks:
                offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])

            # Merge with accumulated
            with torch.cuda.stream(compute_stream):
                if o_acc is None:
                    o_acc, lse_acc = prev_o, prev_lse
                else:
                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

        return o_acc, lse_acc
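
Both the prefill and decode paths above accumulate partial results with `merge_attention_outputs`, an online-softmax merge of two (output, log-sum-exp) pairs. Below is a minimal sketch of that merge, assuming outputs shaped `[batch, seq, heads, dim]` and LSEs shaped `[batch, heads, seq]` as FlashAttention returns them; the actual implementation in `nanovllm.kvcache.chunked_attention` may differ in layout and details.

```python
import torch

def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
    """Merge two partial attention results computed over disjoint KV sets.

    o*:   [batch, seq, heads, dim]  partial attention outputs
    lse*: [batch, heads, seq]       per-query log-sum-exp of the attention scores
    """
    # Bring LSEs to [batch, seq, heads, 1] so they broadcast against the outputs
    l1 = lse1.transpose(1, 2).unsqueeze(-1)
    l2 = lse2.transpose(1, 2).unsqueeze(-1)

    # Numerically stable softmax-weighted combination (online softmax)
    m = torch.maximum(l1, l2)
    w1 = torch.exp(l1 - m)
    w2 = torch.exp(l2 - m)
    o = (o1 * w1 + o2 * w2) / (w1 + w2)

    # Combined log-sum-exp, back to [batch, heads, seq]
    lse = (m + torch.log(w1 + w2)).squeeze(-1).transpose(1, 2)
    return o, lse
```

Chaining this merge block by block is what lets the ring buffer keep only a slot's worth of KV on the GPU at a time while still producing exact attention over the full context.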
tests/modeling_qwen3.py (new file, 757 lines)
@@ -0,0 +1,757 @@
|
||||
"""
|
||||
Custom Qwen3 implementation using only torch and transformers.
|
||||
This file provides a clean reference implementation for understanding the model computation graph.
|
||||
|
||||
Computation Graph:
|
||||
==================
|
||||
|
||||
Input: token_ids [batch, seq_len]
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ Embedding │ embed_tokens: [vocab_size, hidden_size]
|
||||
└─────────────┘
|
||||
│
|
||||
▼
|
||||
hidden_states [batch, seq_len, hidden_size]
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Decoder Layer (x N) │
|
||||
│ ┌───────────────────────────────────────────────────┐ │
|
||||
│ │ Self Attention Block │ │
|
||||
│ │ │ │
|
||||
│ │ input_layernorm (RMSNorm) │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ ┌─────────────────────────────────────────────┐ │ │
|
||||
│ │ │ Qwen3Attention │ │ │
|
||||
│ │ │ Q = q_proj(x) → q_norm → reshape │ │ │
|
||||
│ │ │ K = k_proj(x) → k_norm → reshape │ │ │
|
||||
│ │ │ V = v_proj(x) → reshape │ │ │
|
||||
│ │ │ │ │ │ │
|
||||
│ │ │ ▼ │ │ │
|
||||
│ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │
|
||||
│ │ │ │ │ │ │
|
||||
│ │ │ ▼ │ │ │
|
||||
│ │ │ attn_output = attention(Q, K, V) │ │ │
|
||||
│ │ │ │ │ │ │
|
||||
│ │ │ ▼ │ │ │
|
||||
│ │ │ output = o_proj(attn_output) │ │ │
|
||||
│ │ └─────────────────────────────────────────────┘ │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ hidden_states = residual + attn_output │ │
|
||||
│ └───────────────────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌───────────────────────────────────────────────────┐ │
|
||||
│ │ MLP Block │ │
|
||||
│ │ │ │
|
||||
│ │ post_attention_layernorm (RMSNorm) │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ ┌─────────────────────────────────────────────┐ │ │
|
||||
│ │ │ Qwen3MLP │ │ │
|
||||
│ │ │ gate = gate_proj(x) │ │ │
|
||||
│ │ │ up = up_proj(x) │ │ │
|
||||
│ │ │ output = down_proj(silu(gate) * up) │ │ │
|
||||
│ │ └─────────────────────────────────────────────┘ │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ hidden_states = residual + mlp_output │ │
|
||||
│ └───────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ norm │ final RMSNorm
|
||||
└─────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ lm_head │ [hidden_size, vocab_size]
|
||||
└─────────────┘
|
||||
│
|
||||
▼
|
||||
logits [batch, seq_len, vocab_size]
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Qwen3RMSNorm(nn.Module):
|
||||
"""RMSNorm implementation."""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
input_dtype = x.dtype
|
||||
x = x.float()
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.eps)
|
||||
return self.weight * x.to(input_dtype)
|
||||
|
||||
|
||||
class Qwen3RotaryEmbedding(nn.Module):
|
||||
"""Rotary Position Embedding (RoPE)."""
|
||||
|
||||
def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
|
||||
# Compute inverse frequencies
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x: Input tensor [batch, seq_len, num_heads, head_dim] or similar
|
||||
position_ids: Position indices [batch, seq_len]
|
||||
|
||||
Returns:
|
||||
cos, sin: [batch, seq_len, head_dim]
|
||||
"""
|
||||
# inv_freq: [dim/2]
|
||||
# position_ids: [batch, seq_len]
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1]
|
||||
position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len]
|
||||
|
||||
# freqs: [batch, dim/2, seq_len]
|
||||
freqs = inv_freq_expanded @ position_ids_expanded
|
||||
# freqs: [batch, seq_len, dim/2]
|
||||
freqs = freqs.transpose(1, 2)
|
||||
|
||||
# Duplicate for full head_dim: [batch, seq_len, dim]
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
cos = emb.cos().to(x.dtype)
|
||||
sin = emb.sin().to(x.dtype)
|
||||
|
||||
return cos, sin
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Rotate half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary position embeddings to Q and K.
|
||||
|
||||
Args:
|
||||
q: [batch, num_heads, seq_len, head_dim]
|
||||
k: [batch, num_kv_heads, seq_len, head_dim]
|
||||
cos: [batch, seq_len, head_dim]
|
||||
sin: [batch, seq_len, head_dim]
|
||||
|
||||
Returns:
|
||||
q_embed, k_embed with same shapes as inputs
|
||||
"""
|
||||
# Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class Qwen3Attention(nn.Module):
|
||||
"""
|
||||
Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.
|
||||
|
||||
Data Flow:
|
||||
---------
|
||||
hidden_states [batch, seq_len, hidden_size]
|
||||
│
|
||||
├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
|
||||
├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
|
||||
└──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
|
||||
│
|
||||
▼
|
||||
apply_rotary_pos_emb(Q, K)
|
||||
│
|
||||
▼
|
||||
attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim]
|
||||
│
|
||||
▼
|
||||
reshape ──► o_proj ──► output [batch, seq_len, hidden_size]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
attention_bias: bool = False,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
layer_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.num_kv_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_kv_groups = num_attention_heads // num_key_value_heads
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# Scaling factor
|
||||
self.scaling = head_dim ** -0.5
|
||||
|
||||
# QKV projections
|
||||
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
|
||||
|
||||
# QK normalization (Qwen3 specific)
|
||||
self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
|
||||
|
||||
# Rotary embeddings
|
||||
self.rotary_emb = Qwen3RotaryEmbedding(
|
||||
head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: [batch, seq_len, hidden_size]
|
||||
position_ids: [batch, seq_len]
|
||||
attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask)
|
||||
past_key_value: (k_cache, v_cache) from previous steps
|
||||
use_cache: Whether to return updated cache
|
||||
output_qkv: Whether to output Q, K, V tensors for debugging
|
||||
|
||||
Returns:
|
||||
output: [batch, seq_len, hidden_size]
|
||||
past_key_value: Updated cache (if use_cache=True)
|
||||
qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
|
||||
"""
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
# === QKV Projections ===
|
||||
q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim]
|
||||
k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
|
||||
v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
|
||||
|
||||
# Reshape to [batch, seq_len, num_heads, head_dim]
|
||||
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
||||
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
||||
|
||||
# === QK Normalization (Qwen3 specific) ===
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Transpose to [batch, num_heads, seq_len, head_dim]
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
# === Rotary Position Embeddings ===
|
||||
cos, sin = self.rotary_emb(v, position_ids)
|
||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
||||
|
||||
# === KV Cache Update ===
|
||||
if past_key_value is not None:
|
||||
k_cache, v_cache = past_key_value
|
||||
k = torch.cat([k_cache, k], dim=2)
|
||||
v = torch.cat([v_cache, v], dim=2)
|
||||
|
||||
new_past_key_value = (k, v) if use_cache else None
|
||||
|
||||
# === Grouped Query Attention (expand KV heads if needed) ===
|
||||
if self.num_kv_groups > 1:
|
||||
# Repeat KV for each query group
|
||||
k = k.repeat_interleave(self.num_kv_groups, dim=1)
|
||||
v = v.repeat_interleave(self.num_kv_groups, dim=1)
|
||||
|
||||
# === Attention Computation (using SDPA for memory efficiency) ===
|
||||
# Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
|
||||
# is_causal only works when q_len == kv_len (prefill), not during decode
|
||||
q_len, kv_len = q.shape[2], k.shape[2]
|
||||
is_causal = (q_len == kv_len) and (q_len > 1)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=is_causal,
|
||||
scale=self.scaling,
|
||||
) # [batch, num_heads, seq_len, head_dim]
|
||||
|
||||
# === Output Projection ===
|
||||
# Transpose back and reshape
|
||||
attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim]
|
||||
attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, hidden_size]
|
||||
output = self.o_proj(attn_output)
|
||||
|
||||
# Optional QKV output for debugging
|
||||
qkv_dict = None
|
||||
if output_qkv:
|
||||
qkv_dict = {
|
||||
"q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE)
|
||||
"k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
|
||||
"v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded)
|
||||
}
|
||||
|
||||
return output, new_past_key_value, qkv_dict
|
||||
|
||||
|
||||
class Qwen3MLP(nn.Module):
|
||||
"""
|
||||
Qwen3 MLP with SwiGLU activation.
|
||||
|
||||
Data Flow:
|
||||
---------
|
||||
hidden_states [batch, seq_len, hidden_size]
|
||||
│
|
||||
├──► gate_proj ──► gate [batch, seq_len, intermediate_size]
|
||||
│
|
||||
└──► up_proj ──► up [batch, seq_len, intermediate_size]
|
||||
│
|
||||
▼
|
||||
silu(gate) * up
|
||||
│
|
||||
▼
|
||||
down_proj ──► output [batch, seq_len, hidden_size]
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
|
||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
|
||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
gate = self.gate_proj(x)
|
||||
up = self.up_proj(x)
|
||||
return self.down_proj(F.silu(gate) * up)
|
||||
|
||||
|
||||
class Qwen3DecoderLayer(nn.Module):
|
||||
"""Single Qwen3 Decoder Layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
layer_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# Pre-attention LayerNorm
|
||||
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
# Self-attention
|
||||
self.self_attn = Qwen3Attention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
attention_bias=attention_bias,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
# Post-attention LayerNorm
|
||||
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
# MLP
|
||||
self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: [batch, seq_len, hidden_size]
|
||||
position_ids: [batch, seq_len]
|
||||
attention_mask: Causal attention mask
|
||||
past_key_value: KV cache for this layer
|
||||
use_cache: Whether to return updated cache
|
||||
output_qkv: Whether to output Q, K, V for debugging
|
||||
|
||||
Returns:
|
||||
hidden_states: [batch, seq_len, hidden_size]
|
||||
past_key_value: Updated cache
|
||||
qkv_dict: QKV tensors (if output_qkv=True)
|
||||
"""
|
||||
# === Self Attention Block ===
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
attn_output, new_past_key_value, qkv_dict = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
output_qkv=output_qkv,
|
||||
)
|
||||
|
||||
hidden_states = residual + attn_output
|
||||
|
||||
# === MLP Block ===
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states, new_past_key_value, qkv_dict
|
||||
|
||||
|
||||
class Qwen3Model(nn.Module):
|
||||
"""Qwen3 Transformer Model (without LM head)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_hidden_layers: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
|
||||
# Token embeddings
|
||||
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
|
||||
|
||||
# Decoder layers
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen3DecoderLayer(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
mlp_bias=mlp_bias,
|
||||
layer_idx=i,
|
||||
)
|
||||
for i in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
# Final LayerNorm
|
||||
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv_layers: Optional[List[int]] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
position_ids: [batch, seq_len]
|
||||
attention_mask: [batch, seq_len] or pre-computed 4D mask
|
||||
past_key_values: List of (k, v) tuples for each layer
|
||||
use_cache: Whether to return new cache
|
||||
output_qkv_layers: List of layer indices to output QKV for
|
||||
|
||||
Returns:
|
||||
hidden_states: [batch, seq_len, hidden_size]
|
||||
new_past_key_values: Updated cache
|
||||
qkv_outputs: {layer_idx: qkv_dict}
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
||||
# Embedding
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Position IDs
|
||||
if position_ids is None:
|
||||
past_len = past_key_values[0][0].shape[2] if past_key_values else 0
|
||||
position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
# Attention mask (create causal mask if not provided)
|
||||
if attention_mask is None or attention_mask.dim() == 2:
|
||||
kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
|
||||
causal_mask = torch.triu(
|
||||
torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
|
||||
diagonal=kv_seq_len - seq_len + 1,
|
||||
)
|
||||
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len]
|
||||
|
||||
# Initialize cache list
|
||||
new_past_key_values = [] if use_cache else None
|
||||
qkv_outputs = {} if output_qkv_layers else None
|
||||
|
||||
# Decoder layers
|
||||
for i, layer in enumerate(self.layers):
|
||||
past_kv = past_key_values[i] if past_key_values else None
|
||||
output_qkv = output_qkv_layers is not None and i in output_qkv_layers
|
||||
|
||||
hidden_states, new_kv, qkv_dict = layer(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_kv,
|
||||
use_cache=use_cache,
|
||||
output_qkv=output_qkv,
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
new_past_key_values.append(new_kv)
|
||||
if qkv_dict is not None:
|
||||
qkv_outputs[i] = qkv_dict
|
||||
|
||||
# Final norm
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states, new_past_key_values, qkv_outputs
|
||||
|
||||
|
||||
class Qwen3ForCausalLM(nn.Module):
|
||||
"""Qwen3 Model with Language Modeling head."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_hidden_layers: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
tie_word_embeddings: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
|
||||
# Transformer model
|
||||
self.model = Qwen3Model(
|
||||
vocab_size=vocab_size,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
mlp_bias=mlp_bias,
|
||||
)
|
||||
|
||||
# LM head
|
||||
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv_layers: Optional[List[int]] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
... (same as Qwen3Model)
|
||||
|
||||
Returns:
|
||||
logits: [batch, seq_len, vocab_size]
|
||||
past_key_values: Updated KV cache
|
||||
qkv_outputs: QKV tensors for specified layers
|
||||
"""
|
||||
hidden_states, new_past_key_values, qkv_outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_qkv_layers=output_qkv_layers,
|
||||
)
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
return logits, new_past_key_values, qkv_outputs
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
|
||||
"""
|
||||
Load weights from a pretrained Qwen3 model.
|
||||
|
||||
Args:
|
||||
model_path: Path to model directory containing config.json and model weights
|
||||
dtype: Data type for model weights
|
||||
|
||||
Returns:
|
||||
Initialized Qwen3ForCausalLM model
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from safetensors.torch import load_file
|
||||
|
||||
# Load config
|
||||
config_path = os.path.join(model_path, "config.json")
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Create model
|
||||
model = cls(
|
||||
vocab_size=config["vocab_size"],
|
||||
hidden_size=config["hidden_size"],
|
||||
intermediate_size=config["intermediate_size"],
|
||||
num_hidden_layers=config["num_hidden_layers"],
|
||||
num_attention_heads=config["num_attention_heads"],
|
||||
num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
|
||||
head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
|
||||
max_position_embeddings=config.get("max_position_embeddings", 32768),
|
||||
rope_theta=config.get("rope_theta", 10000.0),
|
||||
rms_norm_eps=config.get("rms_norm_eps", 1e-6),
|
||||
attention_bias=config.get("attention_bias", False),
|
||||
mlp_bias=config.get("mlp_bias", False),
|
||||
tie_word_embeddings=config.get("tie_word_embeddings", True),
|
||||
)
|
||||
|
||||
# Load weights
|
||||
weight_files = sorted([
|
||||
f for f in os.listdir(model_path)
|
||||
if f.endswith(".safetensors")
|
||||
])
|
||||
|
||||
state_dict = {}
|
||||
for wf in weight_files:
|
||||
state_dict.update(load_file(os.path.join(model_path, wf)))
|
||||
|
||||
# Load into model
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Tie lm_head weights to embed_tokens if configured
|
||||
if model.tie_word_embeddings:
|
||||
model.lm_head.weight = model.model.embed_tokens.weight
|
||||
|
||||
model = model.to(dtype)
|
||||
|
||||
return model
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
max_new_tokens: int = 32,
|
||||
temperature: float = 1.0,
|
||||
do_sample: bool = True,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Simple autoregressive generation."""
|
||||
device = input_ids.device
|
||||
batch_size, seq_len = input_ids.shape
|
||||
past_key_values = None
|
||||
generated = input_ids.clone()
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
if past_key_values is None:
|
||||
current_input = generated
|
||||
else:
|
||||
current_input = generated[:, -1:]
|
||||
|
||||
logits, past_key_values, _ = self(
|
||||
input_ids=current_input,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
next_token_logits = logits[:, -1, :]
|
||||
if temperature > 0 and do_sample:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
probs = torch.softmax(next_token_logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
|
||||
|
||||
generated = torch.cat([generated, next_token], dim=1)
|
||||
|
||||
if eos_token_id is not None and (next_token == eos_token_id).all():
|
||||
break
|
||||
|
||||
return generated
|
||||
|
||||
|
||||
def print_computation_graph():
|
||||
"""Print the computation graph for reference."""
|
||||
print(__doc__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_computation_graph()
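
For a quick standalone check, the reference model above can be loaded and run directly via its own `from_pretrained` and `generate`. A minimal sketch follows; the model path mirrors the one used in `tests/test_align.py` and, like the prompt, is only illustrative.

```python
import os
import torch
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM

MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")  # illustrative path

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=torch.float16).to("cuda").eval()

input_ids = tokenizer.encode("The capital of France is", return_tensors="pt").to("cuda")
output_ids = model.generate(input_ids, max_new_tokens=16, do_sample=False)  # greedy decoding
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Because `generate` runs with `use_cache=True`, this also exercises the reference KV-cache path in `Qwen3Attention.forward`.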
tests/test_align.py (new file, 212 lines)
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Test alignment between nanovllm and custom torch Qwen3 implementation.
|
||||
Compares attention layer outputs to verify correctness.
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from modeling_qwen3 import Qwen3ForCausalLM
|
||||
from utils import generate_needle_prompt
|
||||
|
||||
# Config
|
||||
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
|
||||
INPUT_LEN = 512 # Use shorter length for alignment test
|
||||
DTYPE = torch.float16
|
||||
|
||||
# Storage for captured tensors
|
||||
nanovllm_outputs = {}
|
||||
torch_outputs = {}
|
||||
|
||||
|
||||
def make_nanovllm_hook(layer_id: int, storage: dict):
|
||||
"""Capture nanovllm self_attn outputs (after o_proj)."""
|
||||
def hook(module, inputs, output):
|
||||
# Qwen3Attention output is a tuple (attn_output, None)
|
||||
if isinstance(output, tuple):
|
||||
attn_output = output[0]
|
||||
else:
|
||||
attn_output = output
|
||||
# nanovllm shape: [num_tokens, hidden_size] -> add batch dim
|
||||
if attn_output.dim() == 2:
|
||||
attn_output = attn_output.unsqueeze(0)
|
||||
storage[layer_id] = attn_output.detach().clone()
|
||||
return hook
|
||||
|
||||
|
||||
def make_torch_hook(layer_id: int, storage: dict):
|
||||
"""Capture torch model self_attn outputs (after o_proj)."""
|
||||
def hook(module, inputs, output):
|
||||
# Qwen3Attention output is (attn_output, past_kv, qkv_dict)
|
||||
attn_output, _, _ = output
|
||||
storage[layer_id] = attn_output.detach().clone()
|
||||
return hook
|
||||
|
||||
|
||||
def compare_tensors(name: str, t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-2):
|
||||
"""Compare two tensors and print statistics."""
|
||||
# Handle shape differences
|
||||
if t1.shape != t2.shape:
|
||||
print(f"[{name}] Shape mismatch: {t1.shape} vs {t2.shape}")
|
||||
# Try to reshape for comparison if possible
|
||||
if t1.numel() == t2.numel():
|
||||
t2 = t2.view(t1.shape)
|
||||
else:
|
||||
return False
|
||||
|
||||
diff = (t1.float() - t2.float()).abs()
|
||||
max_diff = diff.max().item()
|
||||
mean_diff = diff.mean().item()
|
||||
|
||||
passed = max_diff < atol
|
||||
status = "PASS" if passed else "FAIL"
|
||||
|
||||
print(f"[{name}] {status}")
|
||||
print(f" Shape: {list(t1.shape)}")
|
||||
print(f" t1 mean: {t1.float().mean():.6f}, std: {t1.float().std():.6f}")
|
||||
print(f" t2 mean: {t2.float().mean():.6f}, std: {t2.float().std():.6f}")
|
||||
print(f" Max diff: {max_diff:.6f}, Mean diff: {mean_diff:.6f}")
|
||||
|
||||
return passed
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Load nanovllm model
|
||||
# ============================================================
|
||||
print("=" * 60)
|
||||
print("Loading nanovllm model...")
|
||||
print("=" * 60)
|
||||
|
||||
llm = LLM(
|
||||
MODEL_PATH,
|
||||
enforce_eager=True,
|
||||
max_model_len=4096,
|
||||
max_num_batched_tokens=4096,
|
||||
enable_cpu_offload=False, # Disable offload for alignment test
|
||||
dtype="float16",
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# Load torch model
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("Loading custom torch model...")
|
||||
print("=" * 60)
|
||||
|
||||
torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
|
||||
torch_model = torch_model.to("cuda")
|
||||
torch_model.eval()
|
||||
|
||||
# ============================================================
|
||||
# Generate test input
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("Generating test input...")
|
||||
print("=" * 60)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
||||
prompt, _ = generate_needle_prompt(
|
||||
tokenizer=tokenizer,
|
||||
target_length=INPUT_LEN,
|
||||
verbose=True,
|
||||
)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
|
||||
print(f"Input shape: {input_ids.shape}")
|
||||
|
||||
# ============================================================
|
||||
# Register hooks on both models
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("Registering hooks...")
|
||||
print("=" * 60)
|
||||
|
||||
# Hook on nanovllm (self_attn is Qwen3Attention, captures output after o_proj)
|
||||
nanovllm_hooks = []
|
||||
for layer_idx, layer in enumerate(llm.model_runner.model.model.layers):
|
||||
if layer_idx >= 2: # Only first 2 layers
|
||||
break
|
||||
nanovllm_hooks.append(
|
||||
layer.self_attn.register_forward_hook(
|
||||
make_nanovllm_hook(layer_idx, nanovllm_outputs)
|
||||
)
|
||||
)
|
||||
print(f" Registered nanovllm hook on layer {layer_idx} self_attn")
|
||||
|
||||
# Hook on torch model (self_attn is Qwen3Attention, captures output after o_proj)
|
||||
torch_hooks = []
|
||||
for layer_idx, layer in enumerate(torch_model.model.layers):
|
||||
if layer_idx >= 2: # Only first 2 layers
|
||||
break
|
||||
torch_hooks.append(
|
||||
layer.self_attn.register_forward_hook(
|
||||
make_torch_hook(layer_idx, torch_outputs)
|
||||
)
|
||||
)
|
||||
print(f" Registered torch hook on layer {layer_idx} self_attn")
|
||||
|
||||
# ============================================================
|
||||
# Run nanovllm inference
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("Running nanovllm inference...")
|
||||
print("=" * 60)
|
||||
|
||||
# Use prompt_token_ids to ensure same input
|
||||
prompt_token_ids = input_ids[0].tolist()
|
||||
nanovllm_result = llm.generate(
|
||||
[prompt_token_ids],
|
||||
SamplingParams(temperature=0.01, max_tokens=1), # Near-greedy for determinism
|
||||
use_tqdm=False,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# Run torch inference
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("Running torch inference...")
|
||||
print("=" * 60)
|
||||
|
||||
with torch.no_grad():
|
||||
torch_logits, _, _ = torch_model(input_ids)
|
||||
|
||||
# ============================================================
|
||||
# Compare outputs
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("Comparing attention outputs...")
|
||||
print("=" * 60)
|
||||
|
||||
all_passed = True
|
||||
for layer_idx in sorted(nanovllm_outputs.keys()):
|
||||
if layer_idx not in torch_outputs:
|
||||
print(f"[Layer {layer_idx}] Missing torch output")
|
||||
all_passed = False
|
||||
continue
|
||||
|
||||
nano_out = nanovllm_outputs[layer_idx]
|
||||
torch_out = torch_outputs[layer_idx]
|
||||
|
||||
print(f"\n--- Layer {layer_idx} ---")
|
||||
passed = compare_tensors(f"Layer {layer_idx} attn_output", nano_out, torch_out, atol=0.1)
|
||||
all_passed = all_passed and passed
|
||||
|
||||
# ============================================================
|
||||
# Cleanup
|
||||
# ============================================================
|
||||
for hook in nanovllm_hooks:
|
||||
hook.remove()
|
||||
for hook in torch_hooks:
|
||||
hook.remove()
|
||||
|
||||
# ============================================================
|
||||
# Result
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
if all_passed:
|
||||
print("test_align: PASSED - nanovllm and torch outputs aligned!")
|
||||
else:
|
||||
print("test_align: FAILED - outputs differ!")
|
||||
print("=" * 60)
|
||||
@@ -93,9 +93,9 @@ TEST_CASES = [
    (1, 4, 256, 8, 128),
    (1, 4, 512, 8, 128),
    (1, 8, 512, 8, 128),
    (1, 4, 1024, 8, 128),
    (1, 4, 1024, 32, 128), # More heads
    (1, 8, 256, 8, 64), # Smaller head dim
    (1, 32, 1024, 8, 128),
    (1, 32, 1024, 32, 128), # More heads
    (1, 32, 256, 8, 64), # Smaller head dim
]

DTYPES = [torch.float16, torch.bfloat16]
tests/test_chunked_decode_hook.py (new file, 391 lines)
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
Correctness test for chunked decode attention.
|
||||
|
||||
Captures Q and output during inference, then computes reference using
|
||||
CPU KV cache with standard flash attention.
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
|
||||
|
||||
import torch
|
||||
from random import randint, seed
|
||||
from typing import Dict, List
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from nanovllm.utils.context import get_context
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
|
||||
# Config
|
||||
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
||||
MAX_MODEL_LEN = 128 * 1024
|
||||
NUM_GPU_BLOCKS = 2
|
||||
INPUT_LEN = 16 * 1024
|
||||
NUM_DECODE_TOKENS = 5
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
# State
|
||||
prefill_captures: List[Dict] = []
|
||||
decode_captures: List[Dict] = []
|
||||
|
||||
|
||||
def make_ones_injection_hook():
|
||||
"""Inject Q=K=V=1.0 for deterministic testing."""
|
||||
def hook(module, inputs):
|
||||
q, k, v = inputs[0], inputs[1], inputs[2]
|
||||
q_ones = torch.ones_like(q)
|
||||
k_ones = torch.ones_like(k)
|
||||
v_ones = torch.ones_like(v)
|
||||
return (q_ones, k_ones, v_ones) + inputs[3:]
|
||||
return hook
|
||||
|
||||
|
||||
def make_capture_hook(layer_id: int):
|
||||
"""Capture Q, K, V, output during inference."""
|
||||
def hook(module, inputs, output):
|
||||
ctx = get_context()
|
||||
q, k, v = inputs
|
||||
|
||||
if ctx.is_prefill:
|
||||
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
|
||||
prefill_captures.append({
|
||||
'layer_id': layer_id,
|
||||
'chunk_idx': chunk_idx,
|
||||
'q': q.clone().cpu(),
|
||||
'k': k.clone().cpu(),
|
||||
'v': v.clone().cpu(),
|
||||
'output': output.clone().cpu(),
|
||||
})
|
||||
else:
|
||||
decode_step = len([c for c in decode_captures if c['layer_id'] == layer_id])
|
||||
decode_captures.append({
|
||||
'layer_id': layer_id,
|
||||
'decode_step': decode_step,
|
||||
'q': q.clone().cpu(),
|
||||
'k': k.clone().cpu(),
|
||||
'v': v.clone().cpu(),
|
||||
'output': output.clone().cpu(),
|
||||
})
|
||||
return hook
|
||||
|
||||
|
||||
def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
|
||||
k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
|
||||
block_size: int, num_prefill_chunks: int) -> torch.Tensor:
|
||||
"""
|
||||
Compute reference decode output using CPU KV cache and standard flash attention.
|
||||
|
||||
For decode, query attends to:
|
||||
1. All prefill KV (from CPU cache)
|
||||
2. All previous decode tokens (from captured decode k, v)
|
||||
"""
|
||||
# Get decode capture for this layer and step
|
||||
decode_cap = None
|
||||
for c in decode_captures:
|
||||
if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
|
||||
decode_cap = c
|
||||
break
|
||||
|
||||
if decode_cap is None:
|
||||
return None
|
||||
|
||||
# Query: single decode token
|
||||
q = decode_cap['q'].cuda() # [1, num_heads, head_dim]
|
||||
q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim]
|
||||
|
||||
# Collect all K, V: prefill chunks from captures + decode tokens from captures
|
||||
# NOTE: We use prefill captures directly instead of CPU cache because
|
||||
# the CPU block ID may not equal the chunk index.
|
||||
all_k = []
|
||||
all_v = []
|
||||
|
||||
# 1. Prefill chunks from captures (use captured K/V, not CPU cache)
|
||||
for cidx in range(num_prefill_chunks):
|
||||
prefill_cap = None
|
||||
for c in prefill_captures:
|
||||
if c['layer_id'] == layer_id and c['chunk_idx'] == cidx:
|
||||
prefill_cap = c
|
||||
break
|
||||
|
||||
if prefill_cap is not None:
|
||||
# Use captured K/V directly (guaranteed to be correct layer data)
|
||||
all_k.append(prefill_cap['k'].cuda())
|
||||
all_v.append(prefill_cap['v'].cuda())
|
||||
|
||||
# 2. Decode tokens from captures (up to and including current step)
|
||||
for step in range(decode_step + 1):
|
||||
for c in decode_captures:
|
||||
if c['layer_id'] == layer_id and c['decode_step'] == step:
|
||||
all_k.append(c['k'].cuda())
|
||||
all_v.append(c['v'].cuda())
|
||||
break
|
||||
|
||||
if not all_k:
|
||||
return None
|
||||
|
||||
# Concatenate all K, V
|
||||
full_k = torch.cat(all_k, dim=0).unsqueeze(0) # [1, total_len, kv_heads, head_dim]
|
||||
full_v = torch.cat(all_v, dim=0).unsqueeze(0)
|
||||
|
||||
# Run flash attention (non-causal since we explicitly control what KV to include)
|
||||
output = flash_attn_func(
|
||||
q_batched, full_k, full_v,
|
||||
softmax_scale=scale,
|
||||
causal=False,
|
||||
)
|
||||
|
||||
return output.squeeze(0).squeeze(0).cpu() # [num_heads, head_dim]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main
|
||||
# ============================================================
|
||||
|
||||
llm = LLM(
|
||||
MODEL_PATH,
|
||||
enforce_eager=True,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
max_num_batched_tokens=MAX_MODEL_LEN,
|
||||
enable_cpu_offload=True,
|
||||
kvcache_block_size=BLOCK_SIZE,
|
||||
num_gpu_blocks=NUM_GPU_BLOCKS,
|
||||
dtype="float16",
|
||||
)
|
||||
|
||||
# Get model info
|
||||
num_layers = len(llm.model_runner.model.model.layers)
|
||||
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
|
||||
scale = head_dim ** -0.5
|
||||
|
||||
# Register hooks
|
||||
hooks = []
|
||||
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
|
||||
# Pre-hook: inject all ones for Q, K, V
|
||||
# pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
|
||||
# hooks.append(pre_hook)
|
||||
# Post-hook: capture Q, K, V, output
|
||||
post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
|
||||
hooks.append(post_hook)
|
||||
|
||||
# Run inference
|
||||
seed(42)
|
||||
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
|
||||
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS), use_tqdm=False)
|
||||
|
||||
# Remove hooks
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# Get CPU cache reference
|
||||
offload_engine = llm.model_runner.kvcache_manager.offload_engine
|
||||
k_cache_cpu = offload_engine.k_cache_cpu.clone()
|
||||
v_cache_cpu = offload_engine.v_cache_cpu.clone()
|
||||
|
||||
# Calculate number of prefill chunks
|
||||
num_prefill_chunks = INPUT_LEN // BLOCK_SIZE
|
||||
|
||||
# Debug: Compare decode_buffer with captured K/V
|
||||
print("\n=== DEBUG: Comparing decode_buffer with captured K/V ===")
|
||||
decode_k_buffer = offload_engine.decode_k_buffer.clone().cpu()
|
||||
for step in range(NUM_DECODE_TOKENS):
|
||||
for layer_id in [0, 17, 35]: # Sample a few layers
|
||||
# Find captured K for this step and layer
|
||||
for c in decode_captures:
|
||||
if c['layer_id'] == layer_id and c['decode_step'] == step:
|
||||
captured_k = c['k'].squeeze(0) # [kv_heads, head_dim]
|
||||
buffer_k = decode_k_buffer[layer_id, step] # [kv_heads, head_dim]
|
||||
diff = (captured_k - buffer_k).abs().max().item()
|
||||
print(f"Step {step}, Layer {layer_id}: captured vs buffer max_diff={diff:.6f}")
|
||||
break
|
||||
|
||||
# Debug: Verify that decode_buffer slices match concatenated captures
|
||||
print("\n=== DEBUG: Verifying decode_buffer slices ===")
|
||||
for layer_id in [0]:
|
||||
for decode_step in [1, 2]: # Check steps that use multiple tokens
|
||||
# Build expected slice from captures
|
||||
expected_k_list = []
|
||||
for step in range(decode_step + 1):
|
||||
for c in decode_captures:
|
||||
if c['layer_id'] == layer_id and c['decode_step'] == step:
|
||||
expected_k_list.append(c['k'].squeeze(0)) # [kv_heads, head_dim]
|
||||
break
|
||||
if expected_k_list:
|
||||
expected_k = torch.stack(expected_k_list, dim=0) # [num_tokens, kv_heads, head_dim]
|
||||
buffer_slice = decode_k_buffer[layer_id, 0:decode_step+1]
|
||||
diff = (expected_k - buffer_slice).abs().max().item()
|
||||
print(f"Decode step {decode_step}, Layer {layer_id}: buffer slice vs expected max_diff={diff:.6f}")
|
||||
# Print first values
|
||||
print(f" Buffer[0,0,0]={buffer_slice[0,0,0].item():.6f}, Expected[0,0,0]={expected_k[0,0,0].item():.6f}")
|
||||
if decode_step >= 1:
|
||||
print(f" Buffer[1,0,0]={buffer_slice[1,0,0].item():.6f}, Expected[1,0,0]={expected_k[1,0,0].item():.6f}")
|
||||
|
||||
# Debug: Print expected K value for block 0, layer 0 (to compare with actual loading)
|
||||
print("\n=== DEBUG: Expected K values for block 0, layer 0 ===")
|
||||
for c in prefill_captures:
|
||||
if c['layer_id'] == 0 and c['chunk_idx'] == 0:
|
||||
print(f"Captured K[0,0,0] for layer 0, chunk 0: {c['k'][0,0,0].item():.6f}")
|
||||
break
|
||||
print(f"CPU cache K[0,0,0,0,0] for layer 0, block 0: {k_cache_cpu[0,0,0,0,0].item():.6f}")
|
||||
|
||||
# Debug: Compare CPU cache with captured prefill K/V
|
||||
print("\n=== DEBUG: Comparing CPU cache with captured prefill K/V ===")
|
||||
for chunk_idx in [0, 7, 15]: # Sample a few chunks
|
||||
for layer_id in [0, 17, 35]: # Sample a few layers
|
||||
# Find captured K for this chunk and layer
|
||||
for c in prefill_captures:
|
||||
if c['layer_id'] == layer_id and c['chunk_idx'] == chunk_idx:
|
||||
captured_k = c['k'] # [seq_len, kv_heads, head_dim]
|
||||
cpu_cache_k = k_cache_cpu[layer_id, chunk_idx, :captured_k.shape[0]]
|
||||
diff = (captured_k - cpu_cache_k).abs().max().item()
|
||||
print(f"Chunk {chunk_idx}, Layer {layer_id}: captured vs CPU cache max_diff={diff:.6f}")
|
||||
break
|
||||
|
||||
# Debug: Get cpu_block_table to check order
|
||||
kvcache_manager = llm.model_runner.kvcache_manager
|
||||
# Find the sequence (it should still exist)
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
for attr_name in ['sequences', '_sequences', 'active_sequences']:
|
||||
if hasattr(kvcache_manager, attr_name):
|
||||
print(f"Found {attr_name}")
|
||||
break
|
||||
|
||||
# Try to get cpu_block_table through a different way
|
||||
print(f"\n=== DEBUG: CPU block order ===")
|
||||
# For each prefill capture, check which CPU block it ended up in
|
||||
for chunk_idx in range(num_prefill_chunks):
|
||||
for c in prefill_captures:
|
||||
if c['layer_id'] == 0 and c['chunk_idx'] == chunk_idx:
|
||||
# Check if this chunk's K matches any CPU block
|
||||
captured_k_first = c['k'][0, 0, 0].item()
|
||||
for block_id in range(num_prefill_chunks):
|
||||
cpu_k_first = k_cache_cpu[0, block_id, 0, 0, 0].item()
|
||||
if abs(captured_k_first - cpu_k_first) < 1e-6:
|
||||
print(f"Chunk {chunk_idx} -> CPU block {block_id}")
|
||||
break
|
||||
break
|
||||
|
||||
# Debug: Check reference vs actual for decode steps 0 and 1
|
||||
# Also compute partial references (prefill only, decode only) to isolate the bug
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
for decode_step in [0, 1]:
|
||||
print(f"\n=== DEBUG: Reference vs Actual for layer 0, decode {decode_step} ===")
|
||||
layer_id = 0
|
||||
# Find the capture
|
||||
for c in decode_captures:
|
||||
if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
|
||||
q = c['q'].cuda() # [1, num_heads, head_dim]
|
||||
q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim]
|
||||
|
||||
# Build prefill K/V per-block for block-by-block reference
|
||||
prefill_k_blocks = []
|
||||
prefill_v_blocks = []
|
||||
for cidx in range(num_prefill_chunks):
|
||||
for pc in prefill_captures:
|
||||
if pc['layer_id'] == layer_id and pc['chunk_idx'] == cidx:
|
||||
prefill_k_blocks.append(pc['k'].cuda().unsqueeze(0)) # [1, block_size, kv_heads, head_dim]
|
||||
prefill_v_blocks.append(pc['v'].cuda().unsqueeze(0))
|
||||
break
|
||||
|
||||
# Build decode K/V
|
||||
decode_k_list = []
|
||||
decode_v_list = []
|
||||
for step in range(decode_step + 1):
|
||||
for dc in decode_captures:
|
||||
if dc['layer_id'] == layer_id and dc['decode_step'] == step:
|
||||
decode_k_list.append(dc['k'].cuda())
|
||||
decode_v_list.append(dc['v'].cuda())
|
||||
break
|
||||
|
||||
full_prefill_k = torch.cat([kb.squeeze(0) for kb in prefill_k_blocks], dim=0).unsqueeze(0)
|
||||
full_prefill_v = torch.cat([vb.squeeze(0) for vb in prefill_v_blocks], dim=0).unsqueeze(0)
|
||||
full_decode_k = torch.cat(decode_k_list, dim=0).unsqueeze(0)
|
||||
full_decode_v = torch.cat(decode_v_list, dim=0).unsqueeze(0)
|
||||
|
||||
full_k = torch.cat([full_prefill_k, full_decode_k], dim=1)
|
||||
full_v = torch.cat([full_prefill_v, full_decode_v], dim=1)
|
||||
|
||||
print(f"Q shape: {q_batched.shape}")
|
||||
print(f"Prefill K shape: {full_prefill_k.shape}")
|
||||
print(f"Decode K shape: {full_decode_k.shape}")
|
||||
print(f"Full K shape: {full_k.shape}")
|
||||
print(f"Total tokens: prefill={num_prefill_chunks * BLOCK_SIZE}, decode={decode_step + 1}")
|
||||
|
||||
# Reference output (single attention over all)
|
||||
ref_output = flash_attn_func(
|
||||
q_batched, full_k, full_v,
|
||||
softmax_scale=scale,
|
||||
causal=False,
|
||||
)
|
||||
|
||||
# Chunked reference: prefill attention + decode attention + merge
|
||||
prefill_o, prefill_lse = flash_attn_with_lse(
|
||||
q_batched, full_prefill_k, full_prefill_v,
|
||||
softmax_scale=scale,
|
||||
causal=False,
|
||||
)
|
||||
decode_o, decode_lse = flash_attn_with_lse(
|
||||
q_batched, full_decode_k, full_decode_v,
|
||||
softmax_scale=scale,
|
||||
causal=False,
|
||||
)
|
||||
chunked_output, _ = merge_attention_outputs(prefill_o, prefill_lse, decode_o, decode_lse)
|
||||
|
||||
# Block-by-block reference (simulating ring buffer pipeline)
|
||||
block_o_acc, block_lse_acc = None, None
|
||||
for bidx, (kb, vb) in enumerate(zip(prefill_k_blocks, prefill_v_blocks)):
|
||||
o_blk, lse_blk = flash_attn_with_lse(q_batched, kb, vb, softmax_scale=scale, causal=False)
|
||||
if block_o_acc is None:
|
||||
block_o_acc, block_lse_acc = o_blk, lse_blk
|
||||
else:
|
||||
block_o_acc, block_lse_acc = merge_attention_outputs(block_o_acc, block_lse_acc, o_blk, lse_blk)
|
||||
|
||||
# Compare block-by-block vs single
|
||||
block_vs_single_diff = (block_o_acc - prefill_o).abs().max().item()
|
||||
print(f"Block-by-block vs Single max_diff: {block_vs_single_diff:.6f}")
|
||||
|
||||
# Compare full reference vs chunked reference
|
||||
ref_vs_chunked_diff = (ref_output - chunked_output).abs().max().item()
|
||||
print(f"Reference vs Chunked-reference max_diff: {ref_vs_chunked_diff:.6f}")
|
||||
|
||||
ref_output = ref_output.squeeze(0).squeeze(0).cpu()
|
||||
chunked_output_cpu = chunked_output.squeeze(0).squeeze(0).cpu()
|
||||
|
||||
# Actual output
|
||||
actual_output = c['output'].squeeze(0)
|
||||
if actual_output.dim() == 3:
|
||||
actual_output = actual_output.squeeze(0)
|
||||
|
||||
diff_ref = (actual_output - ref_output).abs()
|
||||
diff_chunked = (actual_output - chunked_output_cpu).abs()
|
||||
print(f"Actual vs Reference max_diff: {diff_ref.max().item():.6f}")
|
||||
print(f"Actual vs Chunked-ref max_diff: {diff_chunked.max().item():.6f}")
|
||||
break
|
||||
print()
|
||||
|
||||
# Verify decode outputs
|
||||
all_passed = True
|
||||
|
||||
for c in decode_captures:
|
||||
layer_id = c['layer_id']
|
||||
decode_step = c['decode_step']
|
||||
|
||||
ref_output = compute_decode_reference(
|
||||
layer_id, decode_step, scale,
|
||||
k_cache_cpu, v_cache_cpu, BLOCK_SIZE, num_prefill_chunks
|
||||
)
|
||||
if ref_output is None:
|
||||
continue
|
||||
|
||||
actual_output = c['output'].squeeze(0)
|
||||
if actual_output.dim() == 3:
|
||||
actual_output = actual_output.squeeze(0)
|
||||
|
||||
diff = (actual_output - ref_output).abs()
|
||||
max_diff = diff.max().item()
|
||||
|
||||
passed = max_diff < 1e-1
|
||||
all_passed = all_passed and passed
|
||||
|
||||
if not passed:
|
||||
print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")
|
||||
|
||||
print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}")
|
||||
196
tests/test_chunked_prefill_hook.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Correctness test for chunked prefill attention.
|
||||
|
||||
Captures Q and output during inference, then computes reference using
|
||||
CPU KV cache with standard flash attention.
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
|
||||
|
||||
import torch
|
||||
from random import randint, seed
|
||||
from typing import Dict, List
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from nanovllm.utils.context import get_context
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
|
||||
# Config
|
||||
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
||||
MAX_MODEL_LEN = 128 * 1024
|
||||
NUM_GPU_BLOCKS = 2
|
||||
INPUT_LEN = 16 * 1024
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
# State - capture Q and output for each (layer, chunk)
|
||||
captures: List[Dict] = []
|
||||
|
||||
|
||||
def make_ones_injection_hook():
|
||||
"""Inject Q=K=V=1.0 for deterministic testing."""
|
||||
def hook(module, inputs):
|
||||
ctx = get_context()
|
||||
if not ctx.is_prefill:
|
||||
return inputs
|
||||
|
||||
q, k, v = inputs[0], inputs[1], inputs[2]
|
||||
q_ones = torch.ones_like(q)
|
||||
k_ones = torch.ones_like(k)
|
||||
v_ones = torch.ones_like(v)
|
||||
return (q_ones, k_ones, v_ones) + inputs[3:]
|
||||
return hook
|
||||
|
||||
|
||||
def make_capture_hook(layer_id: int):
|
||||
"""Capture Q and output during prefill."""
|
||||
def hook(module, inputs, output):
|
||||
ctx = get_context()
|
||||
if not ctx.is_prefill:
|
||||
return
|
||||
|
||||
q, k, v = inputs
|
||||
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
|
||||
|
||||
captures.append({
|
||||
'layer_id': layer_id,
|
||||
'chunk_idx': chunk_idx,
|
||||
'q': q.clone().cpu(),
|
||||
'k': k.clone().cpu(),
|
||||
'v': v.clone().cpu(),
|
||||
'output': output.clone().cpu(),
|
||||
})
|
||||
return hook
|
||||
|
||||
|
||||
def compute_reference(layer_id: int, chunk_idx: int, scale: float,
|
||||
k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
|
||||
block_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Compute reference output using CPU KV cache and standard flash attention.
|
||||
|
||||
Concatenates all Q, K, V from chunks 0..chunk_idx and runs causal attention,
|
||||
then extracts output for the current chunk.
|
||||
"""
|
||||
# Get all captures for this layer up to chunk_idx
|
||||
layer_captures = [c for c in captures
|
||||
if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
|
||||
layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])
|
||||
|
||||
if not layer_captures:
|
||||
return None
|
||||
|
||||
# Collect Q from captures, K/V from CPU cache
|
||||
all_q = []
|
||||
all_k = []
|
||||
all_v = []
|
||||
chunk_lengths = []
|
||||
|
||||
for c in layer_captures:
|
||||
cidx = c['chunk_idx']
|
||||
q = c['q'].cuda() # [seqlen, nheads, headdim]
|
||||
all_q.append(q)
|
||||
chunk_lengths.append(q.shape[0])
|
||||
|
||||
# Get K, V from CPU cache (already offloaded during prefill)
|
||||
# CPU cache shape: [num_layers, num_blocks, block_size, kv_heads, head_dim]
|
||||
k = k_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
|
||||
v = v_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
|
||||
all_k.append(k)
|
||||
all_v.append(v)
|
||||
|
||||
# Concatenate
|
||||
full_q = torch.cat(all_q, dim=0)
|
||||
full_k = torch.cat(all_k, dim=0)
|
||||
full_v = torch.cat(all_v, dim=0)
|
||||
total_len = full_q.shape[0]
|
||||
|
||||
# Run standard causal flash attention
|
||||
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')
|
||||
full_o = flash_attn_varlen_func(
|
||||
full_q, full_k, full_v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=total_len,
|
||||
max_seqlen_k=total_len,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
# Extract output for current chunk
|
||||
start_pos = sum(chunk_lengths[:-1])
|
||||
end_pos = sum(chunk_lengths)
|
||||
return full_o[start_pos:end_pos].cpu()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main
|
||||
# ============================================================
|
||||
|
||||
llm = LLM(
|
||||
MODEL_PATH,
|
||||
enforce_eager=True,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
max_num_batched_tokens=MAX_MODEL_LEN,
|
||||
enable_cpu_offload=True,
|
||||
kvcache_block_size=BLOCK_SIZE,
|
||||
num_gpu_blocks=NUM_GPU_BLOCKS,
|
||||
dtype="float16",
|
||||
)
|
||||
|
||||
# Get model info
|
||||
num_layers = len(llm.model_runner.model.model.layers)
|
||||
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
|
||||
scale = head_dim ** -0.5
|
||||
|
||||
# Register hooks
|
||||
hooks = []
|
||||
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
|
||||
# Pre-hook: inject all ones for Q, K, V
|
||||
# pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
|
||||
# hooks.append(pre_hook)
|
||||
# Post-hook: capture Q, K, V, output
|
||||
post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
|
||||
hooks.append(post_hook)
|
||||
|
||||
# Run inference
|
||||
seed(42)
|
||||
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
|
||||
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False)
|
||||
|
||||
# Remove hooks
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# Get CPU cache reference
|
||||
offload_engine = llm.model_runner.kvcache_manager.offload_engine
|
||||
k_cache_cpu = offload_engine.k_cache_cpu.clone()
|
||||
v_cache_cpu = offload_engine.v_cache_cpu.clone()
|
||||
|
||||
# Verify: compare actual output with reference computed from CPU cache
|
||||
all_passed = True
|
||||
num_chunks = INPUT_LEN // BLOCK_SIZE
|
||||
|
||||
for idx,c in enumerate(captures):
|
||||
layer_id = c['layer_id']
|
||||
chunk_idx = c['chunk_idx']
|
||||
|
||||
# Skip chunk 0 (no previous KV to load)
|
||||
if chunk_idx == 0:
|
||||
continue
|
||||
|
||||
ref_output = compute_reference(layer_id, chunk_idx, scale, k_cache_cpu, v_cache_cpu, BLOCK_SIZE)
|
||||
if ref_output is None:
|
||||
continue
|
||||
|
||||
actual_output = c['output']
|
||||
diff = (actual_output - ref_output).abs()
|
||||
max_diff = diff.max().item()
|
||||
|
||||
passed = max_diff < 1e-1 # float16 tolerance
|
||||
all_passed = all_passed and passed
|
||||
|
||||
if not passed:
|
||||
print(f"[FAIL] Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
|
||||
__import__('pdb').set_trace()
|
||||
|
||||
print(f"test_chunked_prefill_hook: {'PASSED' if all_passed else 'FAILED'}")
|
||||
137
tests/test_debug_verification.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Test KV cache offload correctness using debug hooks.
|
||||
Injects distinctive K/V values, verifies loaded tensors match expected patterns.
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
|
||||
|
||||
import inspect
|
||||
from random import randint, seed
|
||||
from typing import Dict, List
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from nanovllm.utils.context import get_context
|
||||
|
||||
# Config
|
||||
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
|
||||
MAX_MODEL_LEN = 32 * 1024
|
||||
NUM_GPU_BLOCKS = 4
|
||||
INPUT_LEN = 32 * 1024
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
# State
|
||||
load_log: List[Dict] = []
|
||||
current_chunk: List[int] = [0]
|
||||
|
||||
|
||||
def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None:
|
||||
"""Record loaded tensor values for layer 0."""
|
||||
if layer_id != 0:
|
||||
return
|
||||
|
||||
# Go up the stack to find kvcache_manager and print k_cache_gpu[*][0,0,0] for all slots
|
||||
frame = inspect.currentframe()
|
||||
try:
|
||||
caller_frame = frame.f_back
|
||||
if caller_frame is not None:
|
||||
local_vars = caller_frame.f_locals
|
||||
if 'self' in local_vars:
|
||||
self_obj = local_vars['self']
|
||||
if hasattr(self_obj, 'k_cache_gpu'):
|
||||
num_slots = self_obj.k_cache_gpu.shape[0]
|
||||
vals = []
|
||||
for i in range(num_slots):
|
||||
val = self_obj.k_cache_gpu[i][0,0,0].item()
|
||||
if i == slot_idx:
|
||||
vals.append(f"[{v}]")
|
||||
else:
|
||||
vals.append(str(val))
|
||||
print(f"[DEBUG] k_cache_gpu[0..{num_slots-1}][0,0,0] = [{', '.join(vals)}]")
|
||||
finally:
|
||||
del frame
|
||||
|
||||
load_log.append({
|
||||
"chunk_idx": current_chunk[0],
|
||||
"cpu_block_id": cpu_block_id,
|
||||
"k_value": k.float().mean().item(),
|
||||
})
|
||||
|
||||
|
||||
def make_pattern_injection_hook(layer_id):
|
||||
"""Inject K = chunk_idx + 1, V = -(chunk_idx + 1) for layer 0."""
|
||||
def hook(module, inputs):
|
||||
ctx = get_context()
|
||||
if not ctx.is_prefill or layer_id != 0:
|
||||
return inputs
|
||||
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
|
||||
current_chunk[0] = chunk_idx
|
||||
if len(inputs) >= 3:
|
||||
q, k, v = inputs[0], inputs[1], inputs[2]
|
||||
k_new = torch.full_like(k, float(chunk_idx + 1))
|
||||
v_new = torch.full_like(v, float(-(chunk_idx + 1)))
|
||||
return (q, k_new, v_new) + inputs[3:]
|
||||
return inputs
|
||||
return hook
|
||||
|
||||
|
||||
def verify() -> bool:
|
||||
"""Verify blocks loaded in correct order with correct K values."""
|
||||
chunk_loads: Dict[int, List[tuple]] = {}
|
||||
for log in load_log:
|
||||
chunk = log["chunk_idx"]
|
||||
if chunk not in chunk_loads:
|
||||
chunk_loads[chunk] = []
|
||||
chunk_loads[chunk].append((log["cpu_block_id"], log["k_value"]))
|
||||
|
||||
for chunk, loads in chunk_loads.items():
|
||||
expected_blocks = list(range(chunk))
|
||||
actual_blocks = [b for b, _ in loads]
|
||||
k_values = [k for _, k in loads]
|
||||
expected_k = [float(b + 1) for b in expected_blocks]
|
||||
|
||||
if actual_blocks != expected_blocks:
|
||||
return False
|
||||
if not all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k)):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# Main
|
||||
llm = LLM(
|
||||
MODEL_PATH,
|
||||
enforce_eager=True,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
max_num_batched_tokens=MAX_MODEL_LEN,
|
||||
enable_cpu_offload=True,
|
||||
kvcache_block_size=BLOCK_SIZE,
|
||||
num_gpu_blocks=NUM_GPU_BLOCKS,
|
||||
dtype="float16",
|
||||
)
|
||||
|
||||
offload_engine = llm.model_runner.kvcache_manager.offload_engine
|
||||
offload_engine.enable_debug_mode()
|
||||
offload_engine.register_debug_hook(debug_load_hook)
|
||||
|
||||
hooks = []
|
||||
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
|
||||
hooks.append(decoder_layer.self_attn.attn.register_forward_pre_hook(
|
||||
make_pattern_injection_hook(layer_idx)
|
||||
))
|
||||
|
||||
seed(42)
|
||||
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
|
||||
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1), use_tqdm=False)
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
offload_engine.remove_debug_hook(debug_load_hook)
|
||||
offload_engine.disable_debug_mode()
|
||||
|
||||
# Verify
|
||||
num_chunks = INPUT_LEN // BLOCK_SIZE
|
||||
expected_loads = num_chunks * (num_chunks - 1) // 2
|
||||
passed = len(load_log) == expected_loads and verify()
|
||||
|
||||
print(f"test_debug_verification: {'PASSED' if passed else 'FAILED'}")
|
||||
276
tests/test_flash_attn_kvcache.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
Test script for flash_attn_with_kvcache based chunked prefill.
|
||||
|
||||
Verifies that chunked prefill produces identical results to full attention.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from flash_attn import flash_attn_func, flash_attn_with_kvcache
|
||||
|
||||
|
||||
def chunk_prefill(q_full, k_full, v_full, k_cache, v_cache, cache_seqlens, chunk_size):
|
||||
"""
|
||||
Chunked prefill using flash_attn_with_kvcache.
|
||||
|
||||
Args:
|
||||
q_full, k_full, v_full: [batch, total_seq_len, heads, head_dim]
|
||||
k_cache, v_cache: [batch, max_seq_len, kv_heads, head_dim]
|
||||
cache_seqlens: [batch] - current cache lengths
|
||||
chunk_size: size of each chunk
|
||||
|
||||
Returns:
|
||||
output: [batch, total_seq_len, heads, head_dim]
|
||||
"""
|
||||
total_len = q_full.shape[1]
|
||||
outputs = []
|
||||
|
||||
for start in range(0, total_len, chunk_size):
|
||||
end = min(start + chunk_size, total_len)
|
||||
|
||||
q_chunk = q_full[:, start:end]
|
||||
k_chunk = k_full[:, start:end]
|
||||
v_chunk = v_full[:, start:end]
|
||||
|
||||
out = flash_attn_with_kvcache(
|
||||
q_chunk,
|
||||
k_cache,
|
||||
v_cache,
|
||||
k=k_chunk,
|
||||
v=v_chunk,
|
||||
cache_seqlens=cache_seqlens,
|
||||
causal=True,
|
||||
)
|
||||
outputs.append(out)
|
||||
|
||||
cache_seqlens += (end - start)
|
||||
|
||||
return torch.cat(outputs, dim=1)
|
||||
|
||||
|
||||
def reference_attention(q, k, v):
|
||||
"""Standard flash attention as reference."""
|
||||
return flash_attn_func(q, k, v, causal=True)
|
||||
|
||||
|
||||
def test_chunked_prefill_correctness():
|
||||
"""Test that chunked prefill matches full attention."""
|
||||
|
||||
batch_size = 1
|
||||
num_heads = 32
|
||||
num_kv_heads = 8 # GQA
|
||||
head_dim = 128
|
||||
max_seq_len = 131072 # 128K
|
||||
|
||||
test_configs = [
|
||||
(1024, 256), # 1K tokens, 256 chunk
|
||||
(2048, 512), # 2K tokens, 512 chunk
|
||||
(4096, 1024), # 4K tokens, 1K chunk
|
||||
(4096, 2048), # 4K tokens, 2K chunk (2 chunks)
|
||||
(8192, 2048), # 8K tokens, 2K chunk (4 chunks)
|
||||
(16384, 4096), # 16K tokens, 4K chunk
|
||||
(32768, 4096), # 32K tokens, 4K chunk
|
||||
(65536, 8192), # 64K tokens, 8K chunk
|
||||
(131072, 8192), # 128K tokens, 8K chunk (16 chunks)
|
||||
]
|
||||
|
||||
for seq_len, chunk_size in test_configs:
|
||||
print(f"\nTesting seq_len={seq_len}, chunk_size={chunk_size}...")
|
||||
|
||||
# Generate random input
|
||||
torch.manual_seed(42)
|
||||
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
|
||||
# Expand K/V for non-GQA reference
|
||||
k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
|
||||
v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
|
||||
|
||||
# Reference: full attention
|
||||
ref_out = reference_attention(q, k_expanded, v_expanded)
|
||||
|
||||
# Chunked prefill with KV cache
|
||||
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
|
||||
|
||||
chunked_out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)
|
||||
|
||||
# Compare
|
||||
max_diff = (ref_out - chunked_out).abs().max().item()
|
||||
mean_diff = (ref_out - chunked_out).abs().mean().item()
|
||||
|
||||
# Verify cache was filled correctly
|
||||
assert cache_seqlens[0].item() == seq_len, f"Cache seqlen mismatch: {cache_seqlens[0].item()} != {seq_len}"
|
||||
|
||||
# Check K/V cache content
|
||||
k_cache_diff = (k_cache[:, :seq_len] - k).abs().max().item()
|
||||
v_cache_diff = (v_cache[:, :seq_len] - v).abs().max().item()
|
||||
|
||||
print(f" Output max_diff: {max_diff:.6f}, mean_diff: {mean_diff:.6f}")
|
||||
print(f" KV cache diff: k={k_cache_diff:.6f}, v={v_cache_diff:.6f}")
|
||||
|
||||
# Tolerance for fp16
|
||||
tolerance = 1e-2
|
||||
if max_diff < tolerance:
|
||||
print(f" PASSED")
|
||||
else:
|
||||
print(f" FAILED (max_diff {max_diff:.6f} >= {tolerance})")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_incremental_decode():
|
||||
"""Test that decode after chunked prefill works correctly."""
|
||||
|
||||
batch_size = 1
|
||||
num_heads = 32
|
||||
num_kv_heads = 8
|
||||
head_dim = 128
|
||||
max_seq_len = 8192
|
||||
|
||||
prefill_len = 2048
|
||||
chunk_size = 512
|
||||
num_decode_steps = 10
|
||||
|
||||
print(f"\nTesting incremental decode after chunked prefill...")
|
||||
print(f" Prefill: {prefill_len} tokens, chunk_size={chunk_size}")
|
||||
print(f" Decode: {num_decode_steps} steps")
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Prefill phase
|
||||
q_prefill = torch.randn(batch_size, prefill_len, num_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
k_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
v_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
|
||||
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
|
||||
|
||||
# Run chunked prefill
|
||||
prefill_out = chunk_prefill(q_prefill, k_prefill, v_prefill,
|
||||
k_cache, v_cache, cache_seqlens, chunk_size)
|
||||
|
||||
print(f" After prefill: cache_seqlens={cache_seqlens[0].item()}")
|
||||
|
||||
# Decode phase - one token at a time
|
||||
for step in range(num_decode_steps):
|
||||
q_decode = torch.randn(batch_size, 1, num_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
k_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
v_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
|
||||
decode_out = flash_attn_with_kvcache(
|
||||
q_decode,
|
||||
k_cache,
|
||||
v_cache,
|
||||
k=k_decode,
|
||||
v=v_decode,
|
||||
cache_seqlens=cache_seqlens,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
cache_seqlens += 1
|
||||
|
||||
assert decode_out.shape == (batch_size, 1, num_heads, head_dim)
|
||||
|
||||
expected_len = prefill_len + num_decode_steps
|
||||
actual_len = cache_seqlens[0].item()
|
||||
|
||||
print(f" After decode: cache_seqlens={actual_len}")
|
||||
|
||||
if actual_len == expected_len:
|
||||
print(f" PASSED")
|
||||
return True
|
||||
else:
|
||||
print(f" FAILED: expected {expected_len}, got {actual_len}")
|
||||
return False
|
||||
|
||||
|
||||
def test_batch_processing():
|
||||
"""Test chunked prefill with batch > 1."""
|
||||
|
||||
batch_size = 4
|
||||
num_heads = 32
|
||||
num_kv_heads = 8
|
||||
head_dim = 128
|
||||
max_seq_len = 4096
|
||||
seq_len = 2048
|
||||
chunk_size = 512
|
||||
|
||||
print(f"\nTesting batch processing (batch_size={batch_size})...")
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
|
||||
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device='cuda')
|
||||
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
|
||||
|
||||
out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)
|
||||
|
||||
# Verify all batches have correct cache length
|
||||
assert (cache_seqlens == seq_len).all(), f"Cache seqlens mismatch: {cache_seqlens}"
|
||||
assert out.shape == (batch_size, seq_len, num_heads, head_dim)
|
||||
|
||||
# Compare with reference for each batch item
|
||||
k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
|
||||
v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
|
||||
ref_out = reference_attention(q, k_expanded, v_expanded)
|
||||
|
||||
max_diff = (ref_out - out).abs().max().item()
|
||||
|
||||
print(f" Output shape: {out.shape}")
|
||||
print(f" Max diff vs reference: {max_diff:.6f}")
|
||||
|
||||
if max_diff < 1e-2:
|
||||
print(f" PASSED")
|
||||
return True
|
||||
else:
|
||||
print(f" FAILED")
|
||||
return False
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test Script
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("Testing flash_attn_with_kvcache chunked prefill")
|
||||
print("=" * 60)
|
||||
|
||||
all_passed = True
|
||||
|
||||
all_passed &= test_chunked_prefill_correctness()
|
||||
all_passed &= test_incremental_decode()
|
||||
all_passed &= test_batch_processing()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
if all_passed:
|
||||
print("test_flash_attn_kvcache: ALL TESTS PASSED")
|
||||
else:
|
||||
print("test_flash_attn_kvcache: SOME TESTS FAILED")
|
||||
print("=" * 60)
|
||||
104
tests/test_flashinfer_merge.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Test FlashInfer chunked attention with CPU offload.
|
||||
|
||||
Uses single_prefill_with_kv_cache + merge_state for chunked KV processing.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import flashinfer
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Core Functions
|
||||
# ============================================================
|
||||
|
||||
def chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk_size, kv_chunk_size):
|
||||
"""
|
||||
Chunked causal attention with KV on CPU.
|
||||
|
||||
q: [seq_q, num_heads, head_dim] on GPU
|
||||
k_cpu, v_cpu: [seq_kv, num_kv_heads, head_dim] on CPU
|
||||
"""
|
||||
seq_q = q.shape[0]
|
||||
seq_kv = k_cpu.shape[0]
|
||||
final_outputs = []
|
||||
|
||||
for q_start in range(0, seq_q, q_chunk_size):
|
||||
q_end = min(q_start + q_chunk_size, seq_q)
|
||||
q_chunk = q[q_start:q_end]
|
||||
|
||||
merged_output = None
|
||||
merged_lse = None
|
||||
|
||||
for kv_start in range(0, seq_kv, kv_chunk_size):
|
||||
kv_end = min(kv_start + kv_chunk_size, seq_kv)
|
||||
|
||||
if kv_start >= q_end:
|
||||
continue
|
||||
|
||||
k_chunk = k_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
|
||||
v_chunk = v_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
|
||||
|
||||
causal = not (kv_end <= q_start)
|
||||
partial_out, partial_lse = flashinfer.single_prefill_with_kv_cache(
|
||||
q_chunk, k_chunk, v_chunk,
|
||||
causal=causal,
|
||||
return_lse=True,
|
||||
)
|
||||
|
||||
if merged_output is None:
|
||||
merged_output, merged_lse = partial_out, partial_lse
|
||||
else:
|
||||
merged_output, merged_lse = flashinfer.merge_state(
|
||||
merged_output, merged_lse,
|
||||
partial_out, partial_lse,
|
||||
)
|
||||
|
||||
final_outputs.append(merged_output)
|
||||
|
||||
return torch.cat(final_outputs, dim=0)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test Script
|
||||
# ============================================================
|
||||
|
||||
print("=" * 60)
|
||||
print("Testing FlashInfer chunked attention with CPU offload")
|
||||
print("=" * 60)
|
||||
|
||||
num_heads = 32
|
||||
num_kv_heads = 8
|
||||
head_dim = 128
|
||||
|
||||
test_configs = [
|
||||
(32768, 8192, 8192), # 32K tokens
|
||||
(65536, 8192, 8192), # 64K tokens
|
||||
(131072, 16384, 16384), # 128K tokens
|
||||
# (262144, 16384, 16384), # 256K tokens (slow)
|
||||
# (524288, 16384, 16384), # 512K tokens (slow)
|
||||
]
|
||||
|
||||
for seq_len, q_chunk, kv_chunk in test_configs:
|
||||
torch.manual_seed(42)
|
||||
|
||||
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device='cuda')
|
||||
k_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
|
||||
v_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
|
||||
|
||||
# Chunked result
|
||||
chunked_out = chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk, kv_chunk)
|
||||
|
||||
# Reference
|
||||
k_gpu = k_cpu.to('cuda')
|
||||
v_gpu = v_cpu.to('cuda')
|
||||
ref_out = flashinfer.single_prefill_with_kv_cache(q, k_gpu, v_gpu, causal=True)
|
||||
|
||||
max_diff = (ref_out - chunked_out).abs().max().item()
|
||||
mean_diff = (ref_out - chunked_out).abs().mean().item()
|
||||
|
||||
num_chunks = (seq_len + q_chunk - 1) // q_chunk
|
||||
assert max_diff < 1e-2, f"FAILED: max_diff={max_diff:.6f}"
|
||||
print(f"seq={seq_len//1024}K, chunks={num_chunks}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
print("\ntest_flashinfer_merge: PASSED")
|
||||
178
tests/test_needle.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Needle-in-a-haystack test for LLM.
|
||||
|
||||
Tests: Long context retrieval capability with configurable sequence length.
|
||||
|
||||
NOTE: CPU offload mode has a known bug that causes incorrect outputs for
|
||||
sequences longer than ~200 tokens. Leave --enable-offload unset (the default) for correctness testing.
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
|
||||
|
||||
import argparse
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from utils import generate_needle_prompt, check_needle_answer
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test
|
||||
# ============================================================
|
||||
|
||||
def run_needle_test(
|
||||
model_path: str,
|
||||
max_model_len: int,
|
||||
input_len: int,
|
||||
num_gpu_blocks: int = 4,
|
||||
needle_position: float = 0.5,
|
||||
needle_value: str = "7492",
|
||||
max_new_tokens: int = 32,
|
||||
enable_cpu_offload: bool = False,
|
||||
verbose: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
Run a needle-in-haystack test.
|
||||
|
||||
Args:
|
||||
model_path: Path to model
|
||||
max_model_len: Maximum model context length
|
||||
input_len: Target input sequence length
|
||||
num_gpu_blocks: Number of GPU blocks for offload
|
||||
needle_position: Where to place needle (0.0-1.0)
|
||||
needle_value: The secret value to find
|
||||
max_new_tokens: Maximum tokens to generate
|
||||
enable_cpu_offload: Enable CPU offload mode
|
||||
verbose: Print detailed output
|
||||
|
||||
Returns:
|
||||
True if test passed, False otherwise
|
||||
"""
|
||||
if verbose:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Needle-in-Haystack Test")
|
||||
print(f"{'='*60}")
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Max model len: {max_model_len}")
|
||||
print(f"Input length: {input_len}")
|
||||
print(f"Needle position: {needle_position:.0%}")
|
||||
print(f"Needle value: {needle_value}")
|
||||
print(f"CPU offload: {enable_cpu_offload}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# 1. Initialize LLM
|
||||
llm_kwargs = {
|
||||
"enforce_eager": True,
|
||||
"max_model_len": max_model_len,
|
||||
"max_num_batched_tokens": max_model_len,
|
||||
"enable_cpu_offload": enable_cpu_offload,
|
||||
}
|
||||
if enable_cpu_offload:
|
||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
||||
|
||||
llm = LLM(model_path, **llm_kwargs)
|
||||
|
||||
# 2. Generate needle prompt
|
||||
prompt, expected = generate_needle_prompt(
|
||||
tokenizer=llm.tokenizer,
|
||||
target_length=input_len,
|
||||
needle_position=needle_position,
|
||||
needle_value=needle_value,
|
||||
)
|
||||
|
||||
# 3. Generate output
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.6, # Moderate temperature
|
||||
max_tokens=max_new_tokens,
|
||||
)
|
||||
outputs = llm.generate([prompt], sampling_params, use_tqdm=True)
|
||||
|
||||
# 4. Check result
|
||||
output_text = outputs[0]["text"]
|
||||
output_token_ids = outputs[0]["token_ids"]
|
||||
passed = check_needle_answer(output_text, expected)
|
||||
|
||||
if verbose:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Result")
|
||||
print(f"{'='*60}")
|
||||
print(f"Expected: {expected}")
|
||||
print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
|
||||
print(f"Output: {output_text[:200]}...")
|
||||
print(f"Status: {'PASSED' if passed else 'FAILED'}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
return passed
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CLI Entry Point
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Needle-in-haystack test for long context LLM")
|
||||
parser.add_argument(
|
||||
"--model", "-m",
|
||||
type=str,
|
||||
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
|
||||
help="Path to model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-model-len",
|
||||
type=int,
|
||||
default=32 * 1024,
|
||||
help="Maximum model context length"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-len",
|
||||
type=int,
|
||||
default=8 * 1024,
|
||||
help="Target input sequence length"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-gpu-blocks",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of GPU blocks for CPU offload"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--needle-position",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--needle-value",
|
||||
type=str,
|
||||
default="7492",
|
||||
help="The secret value to hide"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-new-tokens",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Maximum tokens to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-offload",
|
||||
action="store_true",
|
||||
help="Enable CPU offload (has known bug for long sequences)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
passed = run_needle_test(
|
||||
model_path=args.model,
|
||||
max_model_len=args.max_model_len,
|
||||
input_len=args.input_len,
|
||||
num_gpu_blocks=args.num_gpu_blocks,
|
||||
needle_position=args.needle_position,
|
||||
needle_value=args.needle_value,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
enable_cpu_offload=args.enable_offload,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
if passed:
|
||||
print("test_needle: PASSED")
|
||||
else:
|
||||
print("test_needle: FAILED")
|
||||
exit(1)
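Besides the CLI, `run_needle_test` can be driven directly from Python, for example from a notebook. A sketch (the import path, model path, and lengths are placeholders):

```python
import os
from test_needle import run_needle_test   # assumes tests/ is on sys.path

ok = run_needle_test(
    model_path=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
    max_model_len=32 * 1024,
    input_len=8 * 1024,
    enable_cpu_offload=False,   # offload path has a known correctness bug
)
print("PASSED" if ok else "FAILED")
```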
|
||||
176
tests/test_needle_ref.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Needle-in-a-haystack reference test using pure torch + transformers.
|
||||
|
||||
This is a reference implementation for comparison with nanovllm.
|
||||
Uses standard HuggingFace inference (no custom KV cache, no offload).
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from modeling_qwen3 import Qwen3ForCausalLM
|
||||
from utils import generate_needle_prompt, check_needle_answer
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test
|
||||
# ============================================================
|
||||
|
||||
def run_needle_test(
|
||||
model_path: str,
|
||||
input_len: int,
|
||||
needle_position: float = 0.5,
|
||||
needle_value: str = "7492",
|
||||
max_new_tokens: int = 32,
|
||||
dtype: str = "auto",
|
||||
verbose: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
Run a needle-in-haystack test using standard transformers inference.
|
||||
|
||||
Args:
|
||||
model_path: Path to model
|
||||
input_len: Target input sequence length
|
||||
needle_position: Where to place needle (0.0-1.0)
|
||||
needle_value: The secret value to find
|
||||
max_new_tokens: Maximum tokens to generate
|
||||
dtype: Model dtype ("auto", "float16", "bfloat16")
|
||||
verbose: Print detailed output
|
||||
|
||||
Returns:
|
||||
True if test passed, False otherwise
|
||||
"""
|
||||
if verbose:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Needle-in-Haystack Reference Test (torch + transformers)")
|
||||
print(f"{'='*60}")
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Input length: {input_len}")
|
||||
print(f"Needle position: {needle_position:.0%}")
|
||||
print(f"Needle value: {needle_value}")
|
||||
print(f"Dtype: {dtype}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# 1. Load tokenizer
|
||||
print("[1/4] Loading tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
# 2. Generate needle prompt
|
||||
print("[2/4] Generating needle prompt...")
|
||||
prompt, expected = generate_needle_prompt(
|
||||
tokenizer=tokenizer,
|
||||
target_length=input_len,
|
||||
needle_position=needle_position,
|
||||
needle_value=needle_value,
|
||||
)
|
||||
|
||||
# 3. Load model
|
||||
print("[3/4] Loading model...")
|
||||
torch_dtype = {
|
||||
"auto": torch.float16, # default to float16 for custom model
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}.get(dtype, torch.float16)
|
||||
|
||||
model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
|
||||
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.eval()
|
||||
|
||||
# 4. Generate output
|
||||
print("[4/4] Running inference...")
|
||||
device = next(model.parameters()).device
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
||||
print(f" Input shape: {input_ids.shape}")
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=0.6,
|
||||
do_sample=True,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
# Decode only the new tokens
|
||||
new_token_ids = output_ids[0, input_ids.shape[1]:]
|
||||
output_text = tokenizer.decode(new_token_ids, skip_special_tokens=False)
|
||||
|
||||
# 5. Check result
|
||||
passed = check_needle_answer(output_text, expected)
|
||||
|
||||
if verbose:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Result")
|
||||
print(f"{'='*60}")
|
||||
print(f"Expected: {expected}")
|
||||
print(f"Output tokens ({len(new_token_ids)}): {new_token_ids[:20].tolist()}")
|
||||
print(f"Output: {output_text[:200]}...")
|
||||
print(f"Status: {'PASSED' if passed else 'FAILED'}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
return passed
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CLI Entry Point
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Needle-in-haystack reference test (torch + transformers)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", "-m",
|
||||
type=str,
|
||||
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
|
||||
help="Path to model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-len",
|
||||
type=int,
|
||||
default=8 * 1024,
|
||||
help="Target input sequence length"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--needle-position",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--needle-value",
|
||||
type=str,
|
||||
default="7492",
|
||||
help="The secret value to hide"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-new-tokens",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Maximum tokens to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "float16", "bfloat16"],
|
||||
help="Model dtype"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
passed = run_needle_test(
|
||||
model_path=args.model,
|
||||
input_len=args.input_len,
|
||||
needle_position=args.needle_position,
|
||||
needle_value=args.needle_value,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
dtype=args.dtype,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
if passed:
|
||||
print("test_needle_ref: PASSED")
|
||||
else:
|
||||
print("test_needle_ref: FAILED")
|
||||
exit(1)
|
||||
695
tests/test_offload_correctness.py
Normal file
@@ -0,0 +1,695 @@
|
||||
"""
|
||||
Test script to verify CPU offload correctness using distinctive KV patterns.
|
||||
|
||||
Strategy:
|
||||
1. Hook into attention forward pass
|
||||
2. Overwrite K/V with distinctive patterns based on chunk_idx (e.g., K=chunk_idx, V=-chunk_idx)
|
||||
3. After offload to CPU, verify CPU cache contains correct patterns
|
||||
4. On subsequent chunks, verify loaded KV from CPU has correct patterns
|
||||
|
||||
This catches bugs like:
|
||||
- Wrong block being offloaded
|
||||
- Wrong block being loaded
|
||||
- Data corruption during transfer
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
||||
|
||||
import torch
|
||||
from random import randint, seed
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from nanovllm.utils.context import get_context
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Configuration
|
||||
# ============================================================
|
||||
|
||||
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
|
||||
MAX_MODEL_LEN = 64 * 1024
|
||||
NUM_GPU_BLOCKS = 4
|
||||
INPUT_LEN = 32 * 1024 # 32K tokens = 32 chunks (fits in 40 CPU blocks)
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
# Test state
|
||||
errors = []
|
||||
chunk_patterns = {} # chunk_idx -> (k_pattern, v_pattern)
|
||||
block_coverage = {} # chunk_idx -> set of blocks that were actually computed
|
||||
load_operations = [] # List of (chunk_idx, slot_id, cpu_block_id, k_ok, v_ok) tuples
|
||||
current_chunk_for_load = [0] # Mutable container to track current chunk during loads
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Pattern Helpers
|
||||
# ============================================================
|
||||
|
||||
def get_expected_pattern(chunk_idx: int):
|
||||
"""Get expected K/V pattern for a chunk."""
|
||||
# Use float values that are easy to identify
|
||||
k_val = float(chunk_idx + 1) # 1.0, 2.0, 3.0, ...
|
||||
v_val = float(-(chunk_idx + 1)) # -1.0, -2.0, -3.0, ...
|
||||
return k_val, v_val
|
||||
|
||||
|
||||
def fill_with_pattern(tensor: torch.Tensor, value: float):
|
||||
"""Fill tensor with a constant value."""
|
||||
tensor.fill_(value)
|
||||
|
||||
|
||||
def check_pattern(tensor: torch.Tensor, expected: float, name: str, tolerance: float = 1e-3):
|
||||
"""Check if tensor contains expected pattern."""
|
||||
actual_mean = tensor.float().mean().item()
|
||||
if abs(actual_mean - expected) > tolerance:
|
||||
return False, f"{name}: expected mean={expected}, got {actual_mean}"
|
||||
return True, None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Load Verification Instrumentation
|
||||
# ============================================================
|
||||
|
||||
_original_load_to_slot_layer = None
|
||||
_offload_engine_ref = None
|
||||
|
||||
def make_verified_load_to_slot_layer(original_func, offload_engine):
|
||||
"""
|
||||
Create a wrapper around load_to_slot_layer that verifies each load operation.
|
||||
|
||||
After each H2D transfer, checks that the GPU slot contains the expected
|
||||
pattern from the source CPU block.
|
||||
"""
|
||||
def verified_load(slot_idx: int, layer_id: int, cpu_block_id: int):
|
||||
# Call original load
|
||||
original_func(slot_idx, layer_id, cpu_block_id)
|
||||
|
||||
# Only verify layer 0 to reduce overhead
|
||||
if layer_id != 0:
|
||||
return
|
||||
|
||||
# IMPORTANT: Synchronize CUDA to ensure async transfer is complete
|
||||
# The transfer happens on a per-slot stream, and wait_slot_layer only
|
||||
# makes compute_stream wait. We need full sync to read on default stream.
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Get the expected pattern for this CPU block
|
||||
# cpu_block_id == chunk_idx in our sequential test
|
||||
expected_k, expected_v = get_expected_pattern(cpu_block_id)
|
||||
|
||||
# Read GPU slot data (GPU cache has no layer dimension)
|
||||
gpu_k = offload_engine.k_cache_gpu[slot_idx]
|
||||
gpu_v = offload_engine.v_cache_gpu[slot_idx]
|
||||
|
||||
actual_k = gpu_k.float().mean().item()
|
||||
actual_v = gpu_v.float().mean().item()
|
||||
|
||||
k_ok = abs(actual_k - expected_k) < 1e-3
|
||||
v_ok = abs(actual_v - expected_v) < 1e-3
|
||||
|
||||
chunk_idx = current_chunk_for_load[0]
|
||||
load_operations.append({
|
||||
'chunk_idx': chunk_idx,
|
||||
'slot_idx': slot_idx,
|
||||
'cpu_block_id': cpu_block_id,
|
||||
'expected_k': expected_k,
|
||||
'expected_v': expected_v,
|
||||
'actual_k': actual_k,
|
||||
'actual_v': actual_v,
|
||||
'k_ok': k_ok,
|
||||
'v_ok': v_ok,
|
||||
})
|
||||
|
||||
if not (k_ok and v_ok):
|
||||
errors.append(f"Load verification failed: chunk {chunk_idx}, "
|
||||
f"CPU block {cpu_block_id} -> GPU slot {slot_idx}: "
|
||||
f"expected K={expected_k:.1f}/V={expected_v:.1f}, "
|
||||
f"got K={actual_k:.4f}/V={actual_v:.4f}")
|
||||
|
||||
return verified_load
|
||||
|
||||
|
||||
def install_load_verification(llm):
|
||||
"""Install verification wrapper on load_to_slot_layer."""
|
||||
global _original_load_to_slot_layer, _offload_engine_ref
|
||||
|
||||
oe = llm.model_runner.kvcache_manager.offload_engine
|
||||
_offload_engine_ref = oe
|
||||
_original_load_to_slot_layer = oe.load_to_slot_layer
|
||||
|
||||
oe.load_to_slot_layer = make_verified_load_to_slot_layer(
|
||||
_original_load_to_slot_layer, oe
|
||||
)
|
||||
print("Installed load verification wrapper on load_to_slot_layer")
|
||||
|
||||
|
||||
def uninstall_load_verification():
|
||||
"""Restore original load_to_slot_layer."""
|
||||
global _original_load_to_slot_layer, _offload_engine_ref
|
||||
|
||||
if _offload_engine_ref is not None and _original_load_to_slot_layer is not None:
|
||||
_offload_engine_ref.load_to_slot_layer = _original_load_to_slot_layer
|
||||
print("Restored original load_to_slot_layer")
|
||||
|
||||
_original_load_to_slot_layer = None
|
||||
_offload_engine_ref = None
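# Illustrative only: how the install/uninstall pair above is meant to bracket a run.
# `llm` and `prompt_token_ids` are placeholders; the real driver code is in the
# main section of this test (SamplingParams is already imported at the top).
def _example_load_verification_usage(llm, prompt_token_ids):
    install_load_verification(llm)
    try:
        llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False)
    finally:
        uninstall_load_verification()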
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Attention Hook
|
||||
# ============================================================
|
||||
|
||||
def make_kv_pattern_pre_hook(layer_id: int):
|
||||
"""
|
||||
Create a PRE-forward hook that overwrites K/V with distinctive patterns BEFORE
|
||||
they are stored to cache. This is called before attention.forward().
|
||||
|
||||
register_forward_pre_hook receives (module, inputs) and can modify inputs in-place.
|
||||
"""
|
||||
def hook(module, inputs):
|
||||
ctx = get_context()
|
||||
if not ctx.is_prefill:
|
||||
return
|
||||
|
||||
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
|
||||
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
|
||||
|
||||
if kvcache_manager is None:
|
||||
return
|
||||
|
||||
# Only process layer 0 for cleaner output
|
||||
if layer_id != 0:
|
||||
return
|
||||
|
||||
q, k, v = inputs
|
||||
k_pattern, v_pattern = get_expected_pattern(chunk_idx)
|
||||
|
||||
# === Overwrite current chunk's K/V with distinctive pattern ===
|
||||
# This happens BEFORE forward(), so these values will be stored to cache
|
||||
k.fill_(k_pattern)
|
||||
v.fill_(v_pattern)
|
||||
|
||||
# Only print for first few and last few chunks to reduce noise
|
||||
num_chunks = INPUT_LEN // BLOCK_SIZE
|
||||
if chunk_idx < 3 or chunk_idx >= num_chunks - 2:
|
||||
print(f"[Chunk {chunk_idx:3d}] Set K={k_pattern:.1f}, V={v_pattern:.1f}")
|
||||
elif chunk_idx == 3:
|
||||
print(f"... (chunks 3 to {num_chunks - 3} omitted) ...")
|
||||
|
||||
return hook
|
||||
|
||||
|
||||
def make_block_coverage_pre_hook(layer_id: int):
|
||||
"""
|
||||
Create a PRE-forward hook to verify that all previous blocks are included
|
||||
in the cpu_block_table for chunked prefill attention.
|
||||
|
||||
This catches bugs where:
|
||||
- Some blocks are missing from the computation
|
||||
- Sparse policy incorrectly filters out blocks (when not intended)
|
||||
- Block table construction has off-by-one errors
|
||||
"""
|
||||
def hook(module, inputs):
|
||||
ctx = get_context()
|
||||
if not ctx.is_prefill:
|
||||
return
|
||||
|
||||
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
|
||||
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
|
||||
|
||||
if kvcache_manager is None:
|
||||
return
|
||||
|
||||
# Only process layer 0 for cleaner output
|
||||
if layer_id != 0:
|
||||
return
|
||||
|
||||
# Update current chunk for load verification tracking
|
||||
current_chunk_for_load[0] = chunk_idx
|
||||
|
||||
# No previous blocks for chunk 0
|
||||
if chunk_idx == 0:
|
||||
return
|
||||
|
||||
# Get the sequence and its block table (same logic as _chunked_prefill_attention)
|
||||
seq = ctx.chunked_seq if hasattr(ctx, 'chunked_seq') else None
|
||||
if seq is None:
|
||||
return
|
||||
|
||||
# Get the CPU block table that will be used for attention
|
||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
|
||||
# Expected blocks: 0 to chunk_idx-1 (all previous chunks)
|
||||
expected_blocks = set(range(chunk_idx))
|
||||
actual_blocks = set(cpu_block_table) if cpu_block_table else set()
|
||||
|
||||
# Store for later summary
|
||||
block_coverage[chunk_idx] = {
|
||||
'expected': expected_blocks,
|
||||
'actual': actual_blocks,
|
||||
}
|
||||
|
||||
# Check for missing blocks
|
||||
missing_blocks = expected_blocks - actual_blocks
|
||||
extra_blocks = actual_blocks - expected_blocks
|
||||
|
||||
num_chunks = INPUT_LEN // BLOCK_SIZE
|
||||
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or missing_blocks:
|
||||
if not missing_blocks and not extra_blocks:
|
||||
print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [OK]")
|
||||
else:
|
||||
status_parts = []
|
||||
if missing_blocks:
|
||||
status_parts.append(f"MISSING {sorted(missing_blocks)}")
|
||||
if extra_blocks:
|
||||
status_parts.append(f"EXTRA {sorted(extra_blocks)}")
|
||||
print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [{', '.join(status_parts)}]")
|
||||
elif chunk_idx == 4:
|
||||
# Indicate that middle chunks are being verified silently
|
||||
print(f" ... (verifying chunks 4-{num_chunks - 3} silently) ...")
|
||||
|
||||
if missing_blocks:
|
||||
errors.append(f"Chunk {chunk_idx} missing blocks: {sorted(missing_blocks)}")
|
||||
|
||||
return hook
|
||||
|
||||
|
||||
def make_gpu_write_verification_post_hook(layer_id: int):
|
||||
"""
|
||||
Create a POST-forward hook to verify the current chunk's KV was correctly
|
||||
written to the GPU ring buffer write_slot.
|
||||
|
||||
This is a more reliable verification than checking load slots, because:
|
||||
1. Post-hook runs AFTER forward() writes to GPU cache
|
||||
2. write_slot mapping is deterministic: chunk_idx % num_ring_slots
|
||||
3. We injected known patterns in pre-hook, now verify they're in GPU cache
|
||||
"""
|
||||
def hook(module, inputs, output):
|
||||
ctx = get_context()
|
||||
if not ctx.is_prefill:
|
||||
return
|
||||
|
||||
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
|
||||
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
|
||||
|
||||
if kvcache_manager is None:
|
||||
return
|
||||
|
||||
# Only process layer 0 for cleaner output
|
||||
if layer_id != 0:
|
||||
return
|
||||
|
||||
oe = kvcache_manager.offload_engine
|
||||
num_ring_slots = oe.num_ring_slots
|
||||
write_slot = chunk_idx % num_ring_slots
|
||||
|
||||
# Get expected pattern for current chunk
|
||||
expected_k, expected_v = get_expected_pattern(chunk_idx)
|
||||
|
||||
# Verify write_slot contains current chunk's data (GPU cache has no layer dimension)
|
||||
gpu_k = oe.k_cache_gpu[write_slot]
|
||||
gpu_v = oe.v_cache_gpu[write_slot]
|
||||
|
||||
actual_k_mean = gpu_k.float().mean().item()
|
||||
actual_v_mean = gpu_v.float().mean().item()
|
||||
|
||||
k_ok, _ = check_pattern(gpu_k, expected_k, f"GPU slot {write_slot}")
|
||||
v_ok, _ = check_pattern(gpu_v, expected_v, f"GPU slot {write_slot}")
|
||||
|
||||
num_chunks = INPUT_LEN // BLOCK_SIZE
|
||||
# Print for first/last chunks, or if there's an error
|
||||
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or not (k_ok and v_ok):
|
||||
if k_ok and v_ok:
|
||||
print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: K={expected_k:.1f}, V={expected_v:.1f} [OK]")
|
||||
else:
|
||||
print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: expected K={expected_k:.1f}/V={expected_v:.1f}, "
|
||||
f"got K={actual_k_mean:.2f}/V={actual_v_mean:.2f} [FAIL]")
|
||||
elif chunk_idx == 4:
|
||||
print(f" ... (GPU write verification for chunks 4-{num_chunks - 3} silently) ...")
|
||||
|
||||
if not (k_ok and v_ok):
|
||||
errors.append(f"GPU write_slot {write_slot} at chunk {chunk_idx}: "
|
||||
f"expected K={expected_k}, V={expected_v}, got K={actual_k_mean:.4f}, V={actual_v_mean:.4f}")
|
||||
|
||||
return hook
|
||||
|
||||
|
||||
def make_kv_verification_post_hook(layer_id: int):
|
||||
"""
|
||||
Create a POST-forward hook to verify CPU cache contains correct patterns
|
||||
from previously offloaded blocks.
|
||||
"""
|
||||
def hook(module, inputs, output):
|
||||
ctx = get_context()
|
||||
if not ctx.is_prefill:
|
||||
return
|
||||
|
||||
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
|
||||
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
|
||||
|
||||
if kvcache_manager is None:
|
||||
return
|
||||
|
||||
# Only process layer 0 for cleaner output
|
||||
if layer_id != 0:
|
||||
return
|
||||
|
||||
# === Verify previously offloaded blocks in CPU cache ===
|
||||
if chunk_idx >= 1:
|
||||
oe = kvcache_manager.offload_engine
|
||||
num_ok = 0
|
||||
num_fail = 0
|
||||
|
||||
# Check all previously offloaded blocks
|
||||
for prev_chunk in range(chunk_idx):
|
||||
# CPU block ID = prev_chunk (in simple sequential case)
|
||||
cpu_block_id = prev_chunk
|
||||
|
||||
# Get expected pattern for this block
|
||||
expected_k, expected_v = get_expected_pattern(prev_chunk)
|
||||
|
||||
# Read from CPU cache (layer 0)
|
||||
cpu_k = oe.k_cache_cpu[layer_id, cpu_block_id]
|
||||
cpu_v = oe.v_cache_cpu[layer_id, cpu_block_id]
|
||||
|
||||
# Verify patterns
|
||||
k_ok, k_err = check_pattern(cpu_k, expected_k, f"CPU K block {cpu_block_id}")
|
||||
v_ok, v_err = check_pattern(cpu_v, expected_v, f"CPU V block {cpu_block_id}")
|
||||
|
||||
if k_ok and v_ok:
|
||||
num_ok += 1
|
||||
else:
|
||||
num_fail += 1
|
||||
if k_err:
|
||||
errors.append(k_err)
|
||||
if v_err:
|
||||
errors.append(v_err)
|
||||
|
||||
# Only print summary for each chunk verification
|
||||
num_chunks = INPUT_LEN // BLOCK_SIZE
|
||||
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or num_fail > 0:
|
||||
status = "OK" if num_fail == 0 else f"FAIL({num_fail})"
|
||||
print(f" CPU verify chunk {chunk_idx:2d}: {num_ok} blocks OK [{status}]")
|
||||
elif chunk_idx == 4:
|
||||
print(f" ... (CPU cache verification for chunks 4-{num_chunks - 3} silently) ...")
|
||||
|
||||
return hook
|
||||
|
||||
|
||||
def make_post_chunk_verification_hook(layer_id: int):
|
||||
"""
|
||||
Post-forward hook to verify GPU ring buffer state after attention.
|
||||
"""
|
||||
def hook(module, inputs, output):
|
||||
ctx = get_context()
|
||||
if not ctx.is_prefill or layer_id != 0:
|
||||
return
|
||||
|
||||
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
|
||||
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
|
||||
|
||||
if kvcache_manager is None:
|
||||
return
|
||||
|
||||
oe = kvcache_manager.offload_engine
|
||||
|
||||
# After attention, the current chunk's KV should be in the GPU ring buffer
|
||||
# Ring slot = chunk_idx % num_ring_slots
|
||||
ring_slot = chunk_idx % oe.num_ring_slots
|
||||
|
||||
expected_k, expected_v = get_expected_pattern(chunk_idx)
|
||||
|
||||
# Check GPU ring buffer (GPU cache has no layer dimension)
|
||||
gpu_k = oe.k_cache_gpu[ring_slot]
|
||||
gpu_v = oe.v_cache_gpu[ring_slot]
|
||||
|
||||
k_ok, k_err = check_pattern(gpu_k, expected_k, f"GPU K slot {ring_slot}")
|
||||
v_ok, v_err = check_pattern(gpu_v, expected_v, f"GPU V slot {ring_slot}")
|
||||
|
||||
if k_ok and v_ok:
|
||||
print(f" [OK] GPU slot {ring_slot} (chunk {chunk_idx}): K={expected_k}, V={expected_v}")
|
||||
else:
|
||||
if k_err:
|
||||
print(f" [FAIL] {k_err}")
|
||||
errors.append(k_err)
|
||||
if v_err:
|
||||
print(f" [FAIL] {v_err}")
|
||||
errors.append(v_err)
|
||||
|
||||
return hook
|
||||
|
||||
|
||||
def register_hooks(llm):
|
||||
"""Register pre and post forward hooks."""
|
||||
hooks = []
|
||||
model = llm.model_runner.model
|
||||
|
||||
for layer_idx, decoder_layer in enumerate(model.model.layers):
|
||||
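# Register on the inner attention module so the pre-hooks see K/V right before they are written to the cache.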
attn_module = decoder_layer.self_attn.attn
|
||||
|
||||
# PRE-forward hook 1: Verify all previous blocks are in cpu_block_table
|
||||
coverage_hook = attn_module.register_forward_pre_hook(make_block_coverage_pre_hook(layer_idx))
|
||||
hooks.append(coverage_hook)
|
||||
|
||||
# PRE-forward hook 2: Inject K/V patterns before they're stored to cache
|
||||
pattern_hook = attn_module.register_forward_pre_hook(make_kv_pattern_pre_hook(layer_idx))
|
||||
hooks.append(pattern_hook)
|
||||
|
||||
# POST-forward hook 1: Verify GPU write_slot contains current chunk's data
|
||||
gpu_verify_hook = attn_module.register_forward_hook(make_gpu_write_verification_post_hook(layer_idx))
|
||||
hooks.append(gpu_verify_hook)
|
||||
|
||||
# POST-forward hook 2: Verify CPU cache contains correct patterns after offload
|
||||
cpu_verify_hook = attn_module.register_forward_hook(make_kv_verification_post_hook(layer_idx))
|
||||
hooks.append(cpu_verify_hook)
|
||||
|
||||
return hooks
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Final Verification
|
||||
# ============================================================
|
||||
|
||||
def verify_final_cpu_state(llm, num_chunks: int):
|
||||
"""Verify all CPU blocks have correct patterns after prefill completes."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Final CPU Cache Verification")
|
||||
print("=" * 60)
|
||||
|
||||
kvcache_manager = llm.model_runner.kvcache_manager
|
||||
oe = kvcache_manager.offload_engine
|
||||
|
||||
num_ok = 0
|
||||
num_fail = 0
|
||||
fail_details = []
|
||||
|
||||
# After prefill, all chunks should be in CPU
|
||||
for chunk_idx in range(num_chunks):
|
||||
cpu_block_id = chunk_idx
|
||||
expected_k, expected_v = get_expected_pattern(chunk_idx)
|
||||
|
||||
# Check layer 0
|
||||
cpu_k = oe.k_cache_cpu[0, cpu_block_id]
|
||||
cpu_v = oe.v_cache_cpu[0, cpu_block_id]
|
||||
|
||||
k_ok, k_err = check_pattern(cpu_k, expected_k, f"Final CPU K block {cpu_block_id}")
|
||||
v_ok, v_err = check_pattern(cpu_v, expected_v, f"Final CPU V block {cpu_block_id}")
|
||||
|
||||
if k_ok and v_ok:
|
||||
num_ok += 1
|
||||
# Only print first few and last few
|
||||
if chunk_idx < 3 or chunk_idx >= num_chunks - 2:
|
||||
actual_k_mean = cpu_k.float().mean().item()
|
||||
actual_v_mean = cpu_v.float().mean().item()
|
||||
print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
|
||||
f"V={expected_v:.1f} ({actual_v_mean:.4f}) [OK]")
|
||||
elif chunk_idx == 3:
|
||||
print(f" ... (blocks 3 to {num_chunks - 3} verified OK) ...")
|
||||
else:
|
||||
num_fail += 1
|
||||
actual_k_mean = cpu_k.float().mean().item()
|
||||
actual_v_mean = cpu_v.float().mean().item()
|
||||
print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
|
||||
f"V={expected_v:.1f} ({actual_v_mean:.4f}) [FAIL]")
|
||||
if k_err:
|
||||
errors.append(k_err)
|
||||
if v_err:
|
||||
errors.append(v_err)
|
||||
|
||||
print(f"\nTotal: {num_ok} OK, {num_fail} FAIL out of {num_chunks} blocks")
|
||||
|
||||
|
||||
def verify_block_coverage_summary(num_chunks: int):
|
||||
"""Verify that all chunks had complete block coverage during prefill."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Block Coverage Verification Summary")
|
||||
print("=" * 60)
|
||||
|
||||
num_ok = 0
|
||||
num_fail = 0
|
||||
total_blocks_expected = 0
|
||||
total_blocks_computed = 0
|
||||
|
||||
for chunk_idx in range(1, num_chunks): # Start from 1 (chunk 0 has no previous)
|
||||
if chunk_idx not in block_coverage:
|
||||
print(f" Chunk {chunk_idx}: NO COVERAGE DATA [FAIL]")
|
||||
errors.append(f"Chunk {chunk_idx} has no block coverage data")
|
||||
num_fail += 1
|
||||
continue
|
||||
|
||||
coverage = block_coverage[chunk_idx]
|
||||
expected = coverage['expected']
|
||||
actual = coverage['actual']
|
||||
missing = expected - actual
|
||||
|
||||
total_blocks_expected += len(expected)
|
||||
total_blocks_computed += len(actual)
|
||||
|
||||
if not missing:
|
||||
num_ok += 1
|
||||
else:
|
||||
num_fail += 1
|
||||
|
||||
# Print summary
|
||||
if num_fail == 0:
|
||||
print(f" All {num_ok} chunks had complete block coverage [OK]")
|
||||
print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")
|
||||
else:
|
||||
print(f" {num_ok} chunks OK, {num_fail} chunks with missing blocks [FAIL]")
|
||||
print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")
|
||||
|
||||
# Verify the total is correct: sum of 0+1+2+...+(n-1) = n*(n-1)/2
|
||||
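# Chunk i must attend to the i chunks before it, so the expected totals sum to 0 + 1 + ... + (n - 1).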
expected_total = num_chunks * (num_chunks - 1) // 2
|
||||
if total_blocks_expected == expected_total:
|
||||
print(f" Expected total blocks matches formula: {expected_total} [OK]")
|
||||
else:
|
||||
print(f" Expected total mismatch: got {total_blocks_expected}, formula gives {expected_total} [FAIL]")
|
||||
errors.append(f"Block coverage total mismatch")
|
||||
|
||||
|
||||
def verify_load_operations_summary(num_chunks: int):
|
||||
"""Verify all H2D load operations transferred correct data."""
|
||||
print("\n" + "=" * 60)
|
||||
print("H2D Load Operations Verification Summary")
|
||||
print("=" * 60)
|
||||
|
||||
if not load_operations:
|
||||
print(" WARNING: No load operations recorded!")
|
||||
print(" (This may indicate load verification was not installed)")
|
||||
return
|
||||
|
||||
num_ok = 0
|
||||
num_fail = 0
|
||||
loads_per_chunk = {}
|
||||
|
||||
for op in load_operations:
|
||||
chunk_idx = op['chunk_idx']
|
||||
if chunk_idx not in loads_per_chunk:
|
||||
loads_per_chunk[chunk_idx] = []
|
||||
loads_per_chunk[chunk_idx].append(op)
|
||||
|
||||
if op['k_ok'] and op['v_ok']:
|
||||
num_ok += 1
|
||||
else:
|
||||
num_fail += 1
|
||||
|
||||
# Print per-chunk summary for first/last chunks
|
||||
for chunk_idx in sorted(loads_per_chunk.keys()):
|
||||
ops = loads_per_chunk[chunk_idx]
|
||||
chunk_ok = sum(1 for op in ops if op['k_ok'] and op['v_ok'])
|
||||
chunk_fail = len(ops) - chunk_ok
|
||||
|
||||
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or chunk_fail > 0:
|
||||
# Show loaded block IDs in order
|
||||
block_ids = [op['cpu_block_id'] for op in ops]
|
||||
if chunk_fail == 0:
|
||||
print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks {block_ids} [OK]")
|
||||
else:
|
||||
print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks, {chunk_fail} FAILED [FAIL]")
|
||||
for op in ops:
|
||||
if not (op['k_ok'] and op['v_ok']):
|
||||
print(f" CPU block {op['cpu_block_id']} -> slot {op['slot_idx']}: "
|
||||
f"expected K={op['expected_k']:.1f}/V={op['expected_v']:.1f}, "
|
||||
f"got K={op['actual_k']:.4f}/V={op['actual_v']:.4f}")
|
||||
elif chunk_idx == 4:
|
||||
print(f" ... (chunks 4-{num_chunks - 3} load verification running silently) ...")
|
||||
|
||||
# Print overall summary
|
||||
print(f"\n Total load operations: {len(load_operations)}")
|
||||
print(f" Successful: {num_ok}, Failed: {num_fail}")
|
||||
|
||||
if num_fail == 0:
|
||||
print(f" All H2D transfers verified correct [OK]")
|
||||
else:
|
||||
print(f" {num_fail} H2D transfers had incorrect data [FAIL]")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test Script
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("Test: CPU Offload Correctness with Distinctive KV Patterns")
|
||||
print("=" * 60)
|
||||
print(f"Input: {INPUT_LEN} tokens, {INPUT_LEN // BLOCK_SIZE} chunks")
|
||||
print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}")
|
||||
print(f"Pattern: K=chunk_idx+1, V=-(chunk_idx+1)")
|
||||
print()
|
||||
|
||||
# 1. Initialize LLM
|
||||
print("Initializing LLM...")
|
||||
llm = LLM(
|
||||
MODEL_PATH,
|
||||
enforce_eager=True,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
max_num_batched_tokens=MAX_MODEL_LEN,
|
||||
enable_cpu_offload=True,
|
||||
kvcache_block_size=BLOCK_SIZE,
|
||||
num_gpu_blocks=NUM_GPU_BLOCKS,
|
||||
dtype="float16",
|
||||
)
|
||||
|
||||
# 2. Register hooks
|
||||
hooks = register_hooks(llm)
|
||||
print(f"Registered {len(hooks)} hooks")
|
||||
|
||||
# 3. Install load verification (instrument load_to_slot_layer)
|
||||
install_load_verification(llm)
|
||||
|
||||
# 4. Generate prompt
|
||||
seed(42)
|
||||
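# Prompt content does not matter for this test: the pre-forward hooks overwrite K/V with known per-chunk patterns.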
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
|
||||
num_chunks = INPUT_LEN // BLOCK_SIZE
|
||||
|
||||
# 5. Run prefill
|
||||
print("\n" + "=" * 60)
|
||||
print("Running Prefill with KV Pattern Injection...")
|
||||
print("=" * 60)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
|
||||
outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
||||
|
||||
# 6. Remove hooks and uninstall load verification
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
uninstall_load_verification()
|
||||
|
||||
# 7. Final verification
|
||||
verify_final_cpu_state(llm, num_chunks)
|
||||
|
||||
# 8. Block coverage summary
|
||||
verify_block_coverage_summary(num_chunks)
|
||||
|
||||
# 9. H2D load operations summary
|
||||
verify_load_operations_summary(num_chunks)
|
||||
|
||||
# 10. Report results
|
||||
print("\n" + "=" * 60)
|
||||
if errors:
|
||||
print(f"test_offload_correctness: FAILED ({len(errors)} errors)")
|
||||
for err in errors[:10]: # Show first 10 errors
|
||||
print(f" - {err}")
|
||||
exit(1)
|
||||
else:
|
||||
print("test_offload_correctness: PASSED")
|
||||
print("=" * 60)
|
||||
@@ -1,70 +0,0 @@
|
||||
"""
|
||||
Test if slicing maintains pinned memory property.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
print("=" * 60)
|
||||
print("Test: Pinned Memory Property with Slicing")
|
||||
print("=" * 60)
|
||||
|
||||
# Create a pinned tensor with shape similar to k_cache_cpu
|
||||
# [num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim]
|
||||
tensor = torch.zeros(8, 16, 1024, 8, 64, dtype=torch.float16, device="cpu", pin_memory=True)
|
||||
|
||||
print(f"\n1. Original tensor:")
|
||||
print(f" - Shape: {tensor.shape}")
|
||||
print(f" - is_pinned(): {tensor.is_pinned()}")
|
||||
print(f" - is_contiguous(): {tensor.is_contiguous()}")
|
||||
|
||||
# Test slicing operation (what we do in offload_slot_to_cpu)
|
||||
slice_view = tensor[:, 0] # Same as k_cache_cpu[:, cpu_block_id]
|
||||
|
||||
print(f"\n2. Sliced tensor [:, 0]:")
|
||||
print(f" - Shape: {slice_view.shape}")
|
||||
print(f" - is_pinned(): {slice_view.is_pinned()}")
|
||||
print(f" - is_contiguous(): {slice_view.is_contiguous()}")
|
||||
|
||||
# Test if contiguous() helps
|
||||
contiguous_slice = tensor[:, 0].contiguous()
|
||||
|
||||
print(f"\n3. Contiguous slice [:, 0].contiguous():")
|
||||
print(f" - Shape: {contiguous_slice.shape}")
|
||||
print(f" - is_pinned(): {contiguous_slice.is_pinned()}")
|
||||
print(f" - is_contiguous(): {contiguous_slice.is_contiguous()}")
|
||||
|
||||
# Test copy behavior
|
||||
gpu_tensor = torch.zeros(8, 4, 1024, 8, 64, dtype=torch.float16, device="cuda")
|
||||
gpu_slice = gpu_tensor[:, 0]
|
||||
|
||||
print(f"\n4. GPU tensor slice:")
|
||||
print(f" - Shape: {gpu_slice.shape}")
|
||||
print(f" - is_contiguous(): {gpu_slice.is_contiguous()}")
|
||||
|
||||
# Simulate the problematic copy operation
|
||||
print(f"\n5. Testing copy operations:")
|
||||
|
||||
# Method 1: Direct slice copy (current approach - SLOW)
|
||||
slice_dst = tensor[:, 1]
|
||||
print(f" Method 1 (slice view): dst.is_pinned()={slice_dst.is_pinned()}")
|
||||
|
||||
# Method 2: Use contiguous destination
|
||||
contiguous_dst = tensor[:, 2].contiguous()
|
||||
print(f" Method 2 (contiguous): dst.is_pinned()={contiguous_dst.is_pinned()}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Conclusion:")
|
||||
print("=" * 60)
|
||||
|
||||
if not slice_view.is_pinned():
|
||||
print("❌ Slicing LOSES pinned memory property!")
|
||||
print(" This causes Device-to-Pageable transfers (SLOW)")
|
||||
else:
|
||||
print("✓ Slicing maintains pinned memory property")
|
||||
|
||||
if contiguous_slice.is_pinned():
|
||||
print("✓ .contiguous() maintains pinned memory property")
|
||||
else:
|
||||
print("❌ .contiguous() also loses pinned memory property")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
@@ -1,124 +0,0 @@
|
||||
"""
|
||||
Test D2H transfer performance with pinned vs non-contiguous memory.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import time
|
||||
|
||||
print("=" * 60)
|
||||
print("Test: D2H Transfer Performance (for nsys profiling)")
|
||||
print("=" * 60)
|
||||
|
||||
# Setup
|
||||
num_layers = 8
|
||||
num_blocks = 16
|
||||
block_size = 1024
|
||||
num_kv_heads = 8
|
||||
head_dim = 64
|
||||
|
||||
# Allocate CPU cache (pinned)
|
||||
k_cache_cpu = torch.zeros(
|
||||
num_layers, num_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device="cpu", pin_memory=True
|
||||
)
|
||||
|
||||
# Allocate GPU cache
|
||||
k_cache_gpu = torch.randn(
|
||||
num_layers, 4, block_size, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device="cuda"
|
||||
)
|
||||
|
||||
# Warmup
|
||||
print("\nWarmup...")
|
||||
for _ in range(10):
|
||||
k_cache_cpu[:, 0].copy_(k_cache_gpu[:, 0], non_blocking=True)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
print(f"\nTensor info:")
|
||||
print(f" k_cache_cpu.is_pinned(): {k_cache_cpu.is_pinned()}")
|
||||
print(f" k_cache_cpu.is_contiguous(): {k_cache_cpu.is_contiguous()}")
|
||||
print(f" k_cache_cpu[:, 0].is_pinned(): {k_cache_cpu[:, 0].is_pinned()}")
|
||||
print(f" k_cache_cpu[:, 0].is_contiguous(): {k_cache_cpu[:, 0].is_contiguous()}")
|
||||
|
||||
# Test 1: Non-contiguous slice (current approach)
|
||||
print(f"\n" + "=" * 60)
|
||||
print("Test 1: Non-contiguous slice copy (current approach)")
|
||||
print("=" * 60)
|
||||
|
||||
NUM_ITERS = 50 # Reduced for profiling
|
||||
|
||||
torch.cuda.nvtx.range_push("Test1_NonContiguous")
|
||||
times = []
|
||||
for i in range(NUM_ITERS):
|
||||
torch.cuda.nvtx.range_push(f"D2H_NonContig_{i}")
|
||||
start = time.perf_counter()
|
||||
k_cache_cpu[:, i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True)
|
||||
torch.cuda.synchronize()
|
||||
times.append(time.perf_counter() - start)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
print(f"Average time: {avg_time * 1000:.3f} ms")
|
||||
print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s")
|
||||
|
||||
# Test 2: Transpose to make dimension contiguous
|
||||
print(f"\n" + "=" * 60)
|
||||
print("Test 2: Transpose to contiguous dimension")
|
||||
print("=" * 60)
|
||||
|
||||
# Reshape to [num_blocks, num_layers, block_size, num_kv_heads, head_dim]
|
||||
k_cache_cpu_transposed = torch.zeros(
|
||||
num_blocks, num_layers, block_size, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device="cpu", pin_memory=True
|
||||
)
|
||||
|
||||
print(f" k_cache_cpu_transposed[0].is_pinned(): {k_cache_cpu_transposed[0].is_pinned()}")
|
||||
print(f" k_cache_cpu_transposed[0].is_contiguous(): {k_cache_cpu_transposed[0].is_contiguous()}")
|
||||
|
||||
torch.cuda.nvtx.range_push("Test2_Contiguous")
|
||||
times = []
|
||||
for i in range(NUM_ITERS):
|
||||
torch.cuda.nvtx.range_push(f"D2H_Contig_{i}")
|
||||
start = time.perf_counter()
|
||||
k_cache_cpu_transposed[i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True)
|
||||
torch.cuda.synchronize()
|
||||
times.append(time.perf_counter() - start)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
print(f"Average time: {avg_time * 1000:.3f} ms")
|
||||
print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s")
|
||||
|
||||
# Test 3: Fully contiguous buffer
|
||||
print(f"\n" + "=" * 60)
|
||||
print("Test 3: Fully contiguous buffer")
|
||||
print("=" * 60)
|
||||
|
||||
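# One flat pinned buffer: the whole block transfers as a single contiguous copy, giving a best-case bandwidth reference for comparison.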
k_cache_cpu_flat = torch.zeros(
|
||||
num_layers * block_size * num_kv_heads * head_dim,
|
||||
dtype=torch.float16, device="cpu", pin_memory=True
|
||||
)
|
||||
|
||||
print(f" k_cache_cpu_flat.is_pinned(): {k_cache_cpu_flat.is_pinned()}")
|
||||
print(f" k_cache_cpu_flat.is_contiguous(): {k_cache_cpu_flat.is_contiguous()}")
|
||||
|
||||
torch.cuda.nvtx.range_push("Test3_FlatContiguous")
|
||||
times = []
|
||||
for i in range(NUM_ITERS):
|
||||
torch.cuda.nvtx.range_push(f"D2H_Flat_{i}")
|
||||
start = time.perf_counter()
|
||||
k_cache_cpu_flat.copy_(k_cache_gpu[:, 0].flatten(), non_blocking=True)
|
||||
torch.cuda.synchronize()
|
||||
times.append(time.perf_counter() - start)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
print(f"Average time: {avg_time * 1000:.3f} ms")
|
||||
print(f"Bandwidth: {k_cache_cpu_flat.numel() * 2 / avg_time / 1e9:.2f} GB/s")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("test_pinned_transfer: PASSED")
|
||||
print("=" * 60)
|
||||
@@ -1,286 +0,0 @@
|
||||
"""
|
||||
Chunked Prefill + KV Cache Offload Simulation v2
|
||||
|
||||
Improvements:
|
||||
1. Simplified log output
|
||||
2. Added reduce timing
|
||||
3. Compute must wait for the KV load to finish
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from concurrent.futures import ThreadPoolExecutor, Future
|
||||
|
||||
# ============== Configuration ==============
|
||||
NUM_CHUNKS = 8
|
||||
GPU_SLOTS = 4
|
||||
|
||||
# Simulated durations (seconds)
|
||||
TIME_COMPUTE_BLOCK = 0.10  # compute one attention block
|
||||
TIME_REDUCE = 0.03  # one pairwise reduce of two partial results
|
||||
TIME_TRANSFER = 0.08  # transfer one KV cache block
|
||||
TIME_PROJ = 0.02  # projection that produces KV
|
||||
|
||||
# ============== Global time base ==============
|
||||
START_TIME = None
|
||||
|
||||
def now() -> float:
|
||||
"""返回相对开始的时间(ms)"""
|
||||
return (time.time() - START_TIME) * 1000
|
||||
|
||||
def log_compute(msg: str):
|
||||
"""计算队列日志(无缩进)"""
|
||||
print(f"[{now():7.1f}ms] [COMPUTE] {msg}")
|
||||
|
||||
def log_transfer(msg: str):
|
||||
"""传输队列日志(缩进)"""
|
||||
print(f"[{now():7.1f}ms] [TRANSFER] {msg}")
|
||||
|
||||
def log_info(msg: str):
|
||||
"""一般信息"""
|
||||
print(f"[{now():7.1f}ms] {msg}")
|
||||
|
||||
# ============== GPU slot management ==============
|
||||
class GPUSlots:
|
||||
def __init__(self, num_slots: int):
|
||||
self.slots = [None] * num_slots # slot_id -> kv_idx
|
||||
self.kv_to_slot = {} # kv_idx -> slot_id
|
||||
self.lock = threading.Lock()
|
||||
# KV ready events: kv_idx -> Event
|
||||
self.kv_ready = {}
|
||||
|
||||
def alloc(self, kv_idx: int) -> int:
|
||||
with self.lock:
|
||||
for sid, val in enumerate(self.slots):
|
||||
if val is None:
|
||||
self.slots[sid] = kv_idx
|
||||
self.kv_to_slot[kv_idx] = sid
|
||||
# Create the ready event
|
||||
if kv_idx not in self.kv_ready:
|
||||
self.kv_ready[kv_idx] = threading.Event()
|
||||
return sid
|
||||
raise RuntimeError(f"No free slot for KV{kv_idx}")
|
||||
|
||||
def free(self, slot_id: int):
|
||||
with self.lock:
|
||||
kv_idx = self.slots[slot_id]
|
||||
if kv_idx is not None:
|
||||
del self.kv_to_slot[kv_idx]
|
||||
# Clear the ready event
|
||||
if kv_idx in self.kv_ready:
|
||||
del self.kv_ready[kv_idx]
|
||||
self.slots[slot_id] = None
|
||||
|
||||
def free_kv(self, kv_idx: int):
|
||||
with self.lock:
|
||||
if kv_idx in self.kv_to_slot:
|
||||
sid = self.kv_to_slot[kv_idx]
|
||||
self.slots[sid] = None
|
||||
del self.kv_to_slot[kv_idx]
|
||||
if kv_idx in self.kv_ready:
|
||||
del self.kv_ready[kv_idx]
|
||||
|
||||
def mark_ready(self, kv_idx: int):
|
||||
"""标记KV已就绪(load完成或proj完成)"""
|
||||
with self.lock:
|
||||
if kv_idx in self.kv_ready:
|
||||
self.kv_ready[kv_idx].set()
|
||||
|
||||
def wait_ready(self, kv_idx: int):
|
||||
"""等待KV就绪"""
|
||||
with self.lock:
|
||||
event = self.kv_ready.get(kv_idx)
|
||||
if event:
|
||||
event.wait()
|
||||
|
||||
def has_kv(self, kv_idx: int) -> bool:
|
||||
with self.lock:
|
||||
return kv_idx in self.kv_to_slot
|
||||
|
||||
def state(self) -> str:
|
||||
with self.lock:
|
||||
return "[" + "][".join(
|
||||
f"KV{v}" if v is not None else "----"
|
||||
for v in self.slots
|
||||
) + "]"
|
||||
|
||||
# ============== Operation execution ==============
|
||||
class Executor:
|
||||
def __init__(self, gpu: GPUSlots):
|
||||
self.gpu = gpu
|
||||
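# One worker per pool: work inside each queue runs serially, mimicking a single compute stream and a single copy engine.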
self.compute_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Compute")
|
||||
self.transfer_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Transfer")
|
||||
|
||||
def proj_kv(self, q_idx: int) -> Future:
|
||||
"""Projection生成KV,返回Future"""
|
||||
def task():
|
||||
log_compute(f"PROJ Q{q_idx}->KV{q_idx} START")
|
||||
time.sleep(TIME_PROJ)
|
||||
slot_id = self.gpu.alloc(q_idx)
|
||||
self.gpu.mark_ready(q_idx)  # projection done, KV immediately usable
|
||||
log_compute(f"PROJ Q{q_idx}->KV{q_idx} END slot={slot_id} | {self.gpu.state()}")
|
||||
return slot_id
|
||||
return self.compute_pool.submit(task)
|
||||
|
||||
def compute_attn(self, q_idx: int, kv_indices: list) -> Future:
|
||||
"""计算attention block,会等待所有KV就绪"""
|
||||
def task():
|
||||
# Wait for all required KV to become ready
|
||||
for kv_idx in kv_indices:
|
||||
self.gpu.wait_ready(kv_idx)
|
||||
|
||||
kv_str = ",".join(map(str, kv_indices))
|
||||
log_compute(f"ATTN Q{q_idx}*KV[{kv_str}] START")
|
||||
time.sleep(TIME_COMPUTE_BLOCK * len(kv_indices))
|
||||
log_compute(f"ATTN Q{q_idx}*KV[{kv_str}] END")
|
||||
return (q_idx, kv_indices)
|
||||
return self.compute_pool.submit(task)
|
||||
|
||||
def reduce(self, q_idx: int, num_partials: int) -> Future:
|
||||
"""Online softmax reduce多个partial结果"""
|
||||
def task():
|
||||
if num_partials <= 1:
|
||||
return
|
||||
# n partials need n-1 pairwise reduces
|
||||
num_reduces = num_partials - 1
|
||||
log_compute(f"REDUCE Q{q_idx} ({num_partials} partials) START")
|
||||
time.sleep(TIME_REDUCE * num_reduces)
|
||||
log_compute(f"REDUCE Q{q_idx} END")
|
||||
return self.compute_pool.submit(task)
|
||||
|
||||
def load_kv(self, kv_idx: int) -> Future:
|
||||
"""从CPU load KV到GPU"""
|
||||
def task():
|
||||
if self.gpu.has_kv(kv_idx):
|
||||
log_transfer(f"LOAD KV{kv_idx} SKIP (already on GPU)")
|
||||
return kv_idx
|
||||
|
||||
slot_id = self.gpu.alloc(kv_idx)
|
||||
log_transfer(f"LOAD KV{kv_idx} START -> slot{slot_id}")
|
||||
time.sleep(TIME_TRANSFER)
|
||||
self.gpu.mark_ready(kv_idx)  # load finished, mark ready
|
||||
log_transfer(f"LOAD KV{kv_idx} END | {self.gpu.state()}")
|
||||
return kv_idx
|
||||
return self.transfer_pool.submit(task)
|
||||
|
||||
def offload_kv(self, kv_idx: int) -> Future:
|
||||
"""从GPU offload KV到CPU"""
|
||||
def task():
|
||||
log_transfer(f"OFFLOAD KV{kv_idx} START")
|
||||
time.sleep(TIME_TRANSFER)
|
||||
self.gpu.free_kv(kv_idx)
|
||||
log_transfer(f"OFFLOAD KV{kv_idx} END | {self.gpu.state()}")
|
||||
return kv_idx
|
||||
return self.transfer_pool.submit(task)
|
||||
|
||||
def shutdown(self):
|
||||
self.compute_pool.shutdown(wait=True)
|
||||
self.transfer_pool.shutdown(wait=True)
|
||||
|
||||
# ============== Scheduler ==============
|
||||
def schedule_query(exe: Executor, q_idx: int):
|
||||
"""调度单个Query的处理"""
|
||||
print(f"\n{'='*50}")
|
||||
log_info(f"===== Query {q_idx} START =====")
|
||||
|
||||
hist_kv = list(range(q_idx))  # history KV: 0 .. q_idx-1
|
||||
num_partials = 0
|
||||
|
||||
# Phase 1: projection produces the current KV
|
||||
proj_fut = exe.proj_kv(q_idx)
|
||||
proj_fut.result()  # wait for completion
|
||||
|
||||
# Phase 2: diagonal-block compute + concurrent prefetch of history KV
|
||||
# Launch the diagonal-block computation
|
||||
diag_fut = exe.compute_attn(q_idx, [q_idx])
|
||||
num_partials += 1
|
||||
|
||||
# Concurrently prefetch history KV (at most 3 slots available)
|
||||
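# One slot is still held by the current chunk's freshly projected KV, hence GPU_SLOTS - 1.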
prefetch_slots = min(len(hist_kv), GPU_SLOTS - 1)
|
||||
prefetch_kv = hist_kv[:prefetch_slots]
|
||||
prefetch_futs = [exe.load_kv(kv) for kv in prefetch_kv]
|
||||
|
||||
# Wait for the diagonal block to finish
|
||||
diag_fut.result()
|
||||
|
||||
# Phase 3: offload the current KV to free its slot
|
||||
offload_fut = exe.offload_kv(q_idx)
|
||||
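# The D2H offload overlaps with compute: its result() is not awaited until after the first history batch is launched.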
|
||||
# Wait for the prefetch to finish, then compute over this batch of history KV
|
||||
for f in prefetch_futs:
|
||||
f.result()
|
||||
|
||||
if prefetch_kv:
|
||||
hist_fut = exe.compute_attn(q_idx, prefetch_kv)
|
||||
num_partials += 1
|
||||
else:
|
||||
hist_fut = None
|
||||
|
||||
# Wait for the offload to finish
|
||||
offload_fut.result()
|
||||
|
||||
# Phase 4: process the remaining history KV
|
||||
remaining_kv = hist_kv[prefetch_slots:]
|
||||
computed_kv = prefetch_kv.copy()
|
||||
|
||||
while remaining_kv:
|
||||
# Wait for the previous batch to finish computing
|
||||
if hist_fut:
|
||||
hist_fut.result()
|
||||
|
||||
# Free the KV that has already been computed
|
||||
for kv in computed_kv:
|
||||
exe.gpu.free_kv(kv)
|
||||
|
||||
# Load the next batch
|
||||
batch_size = min(len(remaining_kv), GPU_SLOTS)
|
||||
batch_kv = remaining_kv[:batch_size]
|
||||
remaining_kv = remaining_kv[batch_size:]
|
||||
|
||||
load_futs = [exe.load_kv(kv) for kv in batch_kv]
|
||||
for f in load_futs:
|
||||
f.result()
|
||||
|
||||
# Compute this batch
|
||||
hist_fut = exe.compute_attn(q_idx, batch_kv)
|
||||
num_partials += 1
|
||||
computed_kv = batch_kv
|
||||
|
||||
# Wait for the last batch to finish computing
|
||||
if hist_fut:
|
||||
hist_fut.result()
|
||||
|
||||
# Clean up the GPU slots
|
||||
for kv in computed_kv:
|
||||
exe.gpu.free_kv(kv)
|
||||
|
||||
# Phase 5: reduce all partial results
|
||||
reduce_fut = exe.reduce(q_idx, num_partials)
|
||||
reduce_fut.result()
|
||||
|
||||
log_info(f"===== Query {q_idx} END =====")
|
||||
|
||||
def main():
|
||||
global START_TIME
|
||||
START_TIME = time.time()
|
||||
|
||||
print("Chunked Prefill + KV Cache Offload Simulation v2")
|
||||
print(f"Config: {NUM_CHUNKS} chunks, {GPU_SLOTS} GPU slots")
|
||||
print(f"Time: compute={TIME_COMPUTE_BLOCK}s, transfer={TIME_TRANSFER}s, reduce={TIME_REDUCE}s")
|
||||
|
||||
gpu = GPUSlots(GPU_SLOTS)
|
||||
exe = Executor(gpu)
|
||||
|
||||
try:
|
||||
for q_idx in range(NUM_CHUNKS):
|
||||
schedule_query(exe, q_idx)
|
||||
|
||||
print(f"\n{'='*50}")
|
||||
log_info(f"ALL DONE! Total: {now():.1f}ms")
|
||||
finally:
|
||||
exe.shutdown()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
175
tests/utils.py
Normal file
175
tests/utils.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Test utilities for nano-vllm.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Needle-in-Haystack Test Utilities
|
||||
# ============================================================
|
||||
|
||||
# Haystack filler paragraphs (various topics to create realistic context)
|
||||
HAYSTACK_PARAGRAPHS = [
|
||||
"The weather today is quite pleasant with clear skies and moderate temperatures. "
|
||||
"Many people are enjoying outdoor activities in the park. "
|
||||
"Birds are singing in the trees and children are playing on the swings. ",
|
||||
|
||||
"In the world of technology, new innovations continue to emerge every day. "
|
||||
"Researchers are working on advanced algorithms and computing systems. "
|
||||
"The future of artificial intelligence looks promising with many breakthroughs. ",
|
||||
|
||||
"The history of human civilization spans thousands of years. "
|
||||
"Ancient cultures developed writing, mathematics, and astronomy. "
|
||||
"Trade routes connected distant lands and facilitated cultural exchange. ",
|
||||
|
||||
"Modern cooking combines traditional techniques with new ingredients. "
|
||||
"Chefs around the world experiment with flavors and presentations. "
|
||||
"Food brings people together and creates memorable experiences. ",
|
||||
|
||||
"The ocean covers more than seventy percent of Earth's surface. "
|
||||
"Marine ecosystems support an incredible diversity of life forms. "
|
||||
"Scientists continue to discover new species in the deep sea. ",
|
||||
|
||||
"Music has been a part of human culture since prehistoric times. "
|
||||
"Different genres evolved across various regions and time periods. "
|
||||
"Today, people can access millions of songs through digital platforms. ",
|
||||
|
||||
"Space exploration has revealed many secrets about our universe. "
|
||||
"Telescopes can observe galaxies billions of light years away. "
|
||||
"Future missions aim to establish human presence on other planets. ",
|
||||
|
||||
"The study of languages reveals patterns in human cognition. "
|
||||
"Linguists analyze grammar, semantics, and phonetics across cultures. "
|
||||
"Language continues to evolve with new words and expressions. ",
|
||||
]
|
||||
|
||||
|
||||
def generate_needle_prompt(
|
||||
tokenizer,
|
||||
target_length: int,
|
||||
needle_position: float = 0.5,
|
||||
needle_value: str = "7492",
|
||||
use_chat_template: bool = True,
|
||||
verbose: bool = True,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Generate a needle-in-haystack prompt of approximately target_length tokens.
|
||||
|
||||
Args:
|
||||
tokenizer: HuggingFace tokenizer for length estimation
|
||||
target_length: Target total sequence length in tokens
|
||||
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
|
||||
needle_value: The secret value to hide in the haystack
|
||||
use_chat_template: Whether to use chat template for instruct models
|
||||
verbose: Whether to print generation info
|
||||
|
||||
Returns:
|
||||
(prompt, expected_answer): The full prompt and the expected needle value
|
||||
"""
|
||||
# The needle sentence
|
||||
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
|
||||
|
||||
# Question at the end
|
||||
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||
|
||||
# Estimate tokens for fixed parts
|
||||
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
|
||||
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
|
||||
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
|
||||
# Buffer for chat template, special tokens, etc.
|
||||
overhead_tokens = 100 if use_chat_template else 50
|
||||
|
||||
# Available tokens for haystack
|
||||
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
|
||||
if haystack_target_tokens < 100:
|
||||
raise ValueError(f"target_length {target_length} is too short for needle test")
|
||||
|
||||
# Build haystack by repeating paragraphs
|
||||
haystack_parts = []
|
||||
current_tokens = 0
|
||||
para_idx = 0
|
||||
|
||||
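# Greedily append filler paragraphs; stop before the next one would overshoot the haystack token budget.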
while current_tokens < haystack_target_tokens:
|
||||
para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
|
||||
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
|
||||
if current_tokens + para_tokens > haystack_target_tokens:
|
||||
break
|
||||
haystack_parts.append(para)
|
||||
current_tokens += para_tokens
|
||||
para_idx += 1
|
||||
|
||||
# Calculate needle insertion point
|
||||
needle_idx = int(len(haystack_parts) * needle_position)
|
||||
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
|
||||
|
||||
# Insert needle
|
||||
haystack_parts.insert(needle_idx, needle)
|
||||
|
||||
# Assemble prompt
|
||||
full_text = "".join(haystack_parts)
|
||||
|
||||
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
|
||||
# Use chat template for instruct models
|
||||
# For Qwen3, add /no_think to disable thinking mode
|
||||
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
|
||||
messages = [
|
||||
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
|
||||
]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
else:
|
||||
# Raw text format for base models
|
||||
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||
prompt = full_text + question
|
||||
|
||||
# Verify length
|
||||
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
|
||||
if verbose:
|
||||
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
|
||||
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
|
||||
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
|
||||
|
||||
return prompt, needle_value
|
||||
|
||||
|
||||
def check_needle_answer(output_text: str, expected: str) -> bool:
|
||||
"""Check if the model output contains the expected needle value."""
|
||||
# Clean output - remove special tokens and whitespace
|
||||
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
|
||||
output_clean = ' '.join(output_clean.split()).lower()
|
||||
expected_clean = expected.strip().lower()
|
||||
|
||||
# Check if expected value appears in output
|
||||
# Also try to find it as a standalone number
|
||||
if expected_clean in output_clean:
|
||||
return True
|
||||
|
||||
# Try to extract numbers and check if expected is among them
|
||||
numbers = re.findall(r'\d+', output_clean)
|
||||
return expected_clean in numbers
|
||||
|
||||
|
||||
def generate_random_token_ids(
|
||||
length: int,
|
||||
vocab_size: int = 10000,
|
||||
seed: int = 42,
|
||||
) -> list:
|
||||
"""
|
||||
Generate random token IDs for testing.
|
||||
|
||||
Args:
|
||||
length: Number of tokens to generate
|
||||
vocab_size: Maximum token ID (exclusive)
|
||||
seed: Random seed for reproducibility
|
||||
|
||||
Returns:
|
||||
List of random token IDs
|
||||
"""
|
||||
from random import randint, seed as set_seed
|
||||
set_seed(seed)
|
||||
return [randint(0, vocab_size - 1) for _ in range(length)]
|
||||