Compare commits

..

29 Commits

Author SHA1 Message Date
Zijie Tian
6575099a06 [refactor] Cleanup unused code after perf_opt merge
Removed ~460 lines of unused/redundant code from offload_engine.py:
- CUDA gather methods (gathered_h2d_*, update_gather_indices)
- Legacy async transfer methods (prefetch_block_async, offload_block_async)
- Legacy sync/wait methods (wait_for_block, wait_all_transfers, sync_indices)
- Legacy compatibility methods (load_to_compute_layer, wait_compute_layer)
- Unused gather_indices tensors and memory calculations

Updated class docstring to reflect current architecture.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-07 06:25:21 +08:00
Zijie Tian
8fd25d72d7 Merge perf_opt-1 and perf_opt-2 branches
Combines two performance optimization features:
- perf_opt-1: Cross-layer pipeline for decode (double-buffered layer cache)
- perf_opt-2: Per-layer prefill buffer for async offload

Both features are complementary and improve CPU offload performance.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-07 06:03:44 +08:00
Zijie Tian
ccf27d3a74 [claudesquad] update from 'perf_opt-1' on 07 Jan 26 05:58 CST 2026-01-07 05:58:23 +08:00
Zijie Tian
0ad86eb449 [claudesquad] update from 'perf_opt-2' on 07 Jan 26 05:58 CST 2026-01-07 05:58:10 +08:00
Zijie Tian
aa953ecb59 [refactor] Aligned the bench. 2026-01-07 04:25:06 +08:00
Zijie Tian
362f5e575f [fix] Fixed .gitignores. 2026-01-07 03:32:14 +08:00
Zijie Tian
58a06501c1 Merge branch 'zijie/debug_chunk-2' into tzj/minference 2026-01-07 03:30:38 +08:00
Zijie Tian
2a6e0a2c02 [feat] Added Quest Sparsity Policy. 2026-01-07 03:29:21 +08:00
Zijie Tian
2fe50bab50 [claudesquad] update from 'debug_chunk-2' on 07 Jan 26 03:27 CST 2026-01-07 03:27:27 +08:00
Zijie Tian
c99a6f3d3f [WIP] Before add Quest policy. 2026-01-07 02:32:30 +08:00
Zijie Tian
f240903013 [docs] Add GPU mutex instructions for multi-instance debugging
Add instructions for Claude instances to check GPU availability before
running CUDA operations, preventing conflicts when multiple instances
debug in parallel on a single GPU.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-07 01:42:59 +08:00
Zijie Tian
0e691f2d85 [WIP] move metadata to GPU. 2026-01-06 23:32:32 +08:00
Zijie Tian
edb5273e34 [WIP] Added basic test for quest. 2026-01-06 22:30:31 +08:00
Zijie Tian
690492e074 [WIP] Before refactor policies. 2026-01-06 20:47:55 +08:00
Zijie Tian
7cc8a394a5 [fix] Fixed bench_offload.py, BUT performance DEGRADED. 2026-01-06 18:46:48 +08:00
Zijie Tian
535f2037ab [WIP] Before fix bench_offload.py. 2026-01-06 18:41:08 +08:00
Zijie Tian
c7ac39dfbd [refactor] Before adding sparse policy. 2026-01-05 21:19:24 +08:00
Zijie Tian
e554d5482b [refactor] Deleted unnecessary test and refactored the offload prefix cache. 2026-01-05 20:31:42 +08:00
Zijie Tian
247c5312d9 [fix] Fixed decode misalign. 2026-01-05 19:00:44 +08:00
Zijie Tian
054aaff403 [fix] Fixed needle test bug. 2026-01-05 18:34:09 +08:00
Zijie Tian
d623043a3c [WIP] FIXED decode and prefill NEEDLE test. 2026-01-05 01:51:46 +08:00
Zijie Tian
e897380127 [test] Added test_align.py, before changing nanovllm attention. 2026-01-04 22:48:01 +08:00
Zijie Tian
24096431ed [refactor] refactor test_align.py. 2026-01-04 20:55:40 +08:00
Zijie Tian
772313db8f [refactor] Refactor the kvcache offload. 2026-01-04 19:37:03 +08:00
Zijie Tian
00ed17c640 [feat] Added debug tools. 2026-01-03 22:36:40 +08:00
Zijie Tian
9b52d25866 [docs] Update CLAUDE.md. 2026-01-03 20:46:00 +08:00
Zijie Tian
8c3418725b [refactor] Refactor needle test. 2026-01-03 19:19:37 +08:00
Zijie Tian
b3685c9190 [test] Added test_align.py 2026-01-03 18:55:58 +08:00
Zijie Tian
6927a75ac3 [refactor] refactor needle.py. 2026-01-03 18:33:48 +08:00
50 changed files with 3540 additions and 4407 deletions

1
.gitignore vendored
View File

@@ -196,3 +196,4 @@ cython_debug/
results/
outputs/
.local/

234
CLAUDE.md
View File

@@ -6,10 +6,119 @@ This file provides guidance to Claude Code when working with this repository.
Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports Qwen3 models with CPU offload for long-context inference.
## GPU Mutex for Multi-Instance Debugging
**IMPORTANT**: When running multiple Claude instances for parallel debugging, only one GPU (cuda:0) is available. Before executing ANY command that uses the GPU (python scripts, benchmarks, tests), Claude MUST:
1. **Check GPU availability** by running:
```bash
nvidia-smi --query-compute-apps=pid,name,used_memory --format=csv,noheader
```
2. **If processes are running on GPU**:
- Wait and retry every 10 seconds until GPU is free
- Use this polling loop:
```bash
while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do
echo "GPU busy, waiting 10s..."
sleep 10
done
```
3. **Only proceed** when `nvidia-smi --query-compute-apps=pid --format=csv,noheader` returns empty output
**Example workflow**:
```bash
# First check if GPU is in use
nvidia-smi --query-compute-apps=pid,name,used_memory --format=csv,noheader
# If output is empty, proceed with your command
python bench_offload.py
# If output shows processes, wait until they finish
```
**Note**: This applies to ALL GPU operations including:
- Running tests (`python tests/test_*.py`)
- Running benchmarks (`python bench*.py`)
- Running examples (`python example.py`)
- Any script that imports torch/cuda
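A hedged Python equivalent of the polling loop above, for cases where the wait needs to happen inside a test or benchmark script rather than in the shell (assumes `nvidia-smi` is on `PATH`; this helper is not part of the repository):
```python
import subprocess
import time

def wait_for_free_gpu(poll_s: int = 10) -> None:
    """Block until nvidia-smi reports no compute processes on the GPU."""
    while subprocess.run(
        ["nvidia-smi", "--query-compute-apps=pid", "--format=csv,noheader"],
        capture_output=True, text=True,
    ).stdout.strip():
        print(f"GPU busy, waiting {poll_s}s...")
        time.sleep(poll_s)

wait_for_free_gpu()
# Safe to launch the CUDA workload from here.
```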
## Local Package Installation for Multi-Instance
**CRITICAL**: After ANY code modification in the `nanovllm/` directory, you MUST reinstall the package before running tests or benchmarks:
```bash
pip install -e . --prefix=./.local --no-deps
```
Then run with PYTHONPATH:
```bash
PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python <script.py>
```
**IMPORTANT**: When running multiple Claude instances on different worktrees, do NOT use `pip install -e .` globally as it will affect other instances. Instead, use local installation:
1. **Install to worktree-local directory**:
```bash
pip install -e . --prefix=./.local --no-deps
```
2. **Set PYTHONPATH before running any Python command**:
```bash
export PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH
```
3. **Combined example**:
```bash
# One-liner for running tests with local package
PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python tests/test_needle.py
```
**Note**: The Python version in the path (python3.10) should match your environment.
**Reminder**: Without reinstalling after code changes (as above), Python will use the old cached version and your changes will NOT be reflected!
## Sparse Attention
For sparse attention related content (block sparse attention, MInference, FlexPrefill, XAttention, AvgPool, etc.), refer to [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md).
### Quest Sparse Policy
**Files**: `nanovllm/kvcache/sparse/quest.py`, `nanovllm/kvcache/sparse/policy.py`
Quest policy selects Top-K blocks based on query-key similarity bounds using min/max key metadata.
**Scoring Mechanism**:
```python
score_min = torch.einsum('hd,bhd->bh', q, key_min) # [num_blocks, kv_heads]
score_max = torch.einsum('hd,bhd->bh', q, key_max) # [num_blocks, kv_heads]
scores = torch.maximum(score_min, score_max).mean(dim=-1) # [num_blocks] ← averaged!
```
**Critical Limitation - No Per-Head Scheduling**:
The `.mean(dim=-1)` averages scores across all heads, making a **unified** block selection for all heads:
```
Block A: head0 needs (+4), head1 doesn't (-4) → avg = 0 → NOT selected
Block B: head0 doesn't (-4), head1 needs (+4) → avg = 0 → NOT selected
Block C: both heads moderately need (+2, +2) → avg = +2 → selected
```
**Why Per-Head Scheduling is Infeasible**:
1. **Memory Layout**: GPU cache stores all heads together `[block_size, kv_heads, head_dim]`
2. **FlashAttention**: Requires complete heads - partial heads cause dimension mismatch
3. **Block Granularity**: If any head needs a block, the entire block (all heads) must be loaded
**Policy Types**:
- `FullAttentionPolicy`: `supports_prefill=True, supports_decode=True` - loads all blocks
- `QuestPolicy`: `supports_prefill=False, supports_decode=True` - decode-only Top-K selection
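A minimal sketch of the selection step implied by the scoring code and the `sparse_topk_blocks` / `sparse_threshold_blocks` settings; `select_decode_blocks` is a hypothetical helper for illustration, not the actual `QuestPolicy` API:
```python
import torch

def select_decode_blocks(q, key_min, key_max, topk_blocks: int, threshold_blocks: int):
    """Unified (not per-head) Top-K block selection for a single decode query."""
    # q: [kv_heads, head_dim]; key_min/key_max: [num_blocks, kv_heads, head_dim]
    score_min = torch.einsum('hd,bhd->bh', q, key_min)
    score_max = torch.einsum('hd,bhd->bh', q, key_max)
    scores = torch.maximum(score_min, score_max).mean(dim=-1)  # [num_blocks]
    num_blocks = scores.shape[0]
    if num_blocks <= threshold_blocks:
        # Few blocks: fall back to loading everything (full attention).
        return torch.arange(num_blocks, device=scores.device)
    return torch.topk(scores, min(topk_blocks, num_blocks)).indices
```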
## Architecture
### Core Components
@@ -20,6 +129,74 @@ For sparse attention related content (block sparse attention, MInference, FlexPr
- **BlockManager** (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
- **Attention** (`layers/attention.py`): FlashAttention with chunked methods for CPU offload
## PyTorch Hooks for Debugging
### Hook Positions in Qwen3
```
decoder_layer
├── input_layernorm (RMSNorm)
├── self_attn (Qwen3Attention) ← Hook here for attention I/O after o_proj
│ ├── q_proj → q_norm → RoPE
│ ├── k_proj → k_norm → RoPE
│ ├── v_proj
│ ├── attn (Attention) ← Hook here for Q/K/V tensors
│ │ └── FlashAttention / SDPA
│ └── o_proj
├── post_attention_layernorm (RMSNorm)
└── mlp (Qwen3MLP)
```
### Hook Types & Data Shapes
| Hook Position | Type | Captured Data |
|---------------|------|---------------|
| `self_attn` | post | `[batch, seq_len, hidden_size]` - after o_proj |
| `self_attn.attn` | pre | Q,K,V: `[seq_len, num_heads, head_dim]` - after RoPE |
| `self_attn.attn` | post | `[seq_len, num_heads, head_dim]` - before o_proj |
### Example: Capture Attention Outputs
```python
storage = {}
def make_hook(layer_id: int, storage: dict):
def hook(module, inputs, output):
if isinstance(output, tuple):
attn_output = output[0]
else:
attn_output = output
# nanovllm shape: [num_tokens, hidden_size] -> add batch dim
if attn_output.dim() == 2:
attn_output = attn_output.unsqueeze(0)
storage[layer_id] = attn_output.detach().clone()
return hook
# Register hooks
hooks = []
for layer_idx, layer in enumerate(model.model.layers):
hooks.append(layer.self_attn.register_forward_hook(make_hook(layer_idx, storage)))
# Run inference...
# Cleanup
for hook in hooks:
hook.remove()
```
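For the `self_attn.attn` pre-hook row of the table, a complementary sketch (assumption: q, k, v are passed positionally to the inner attention module):
```python
qkv_storage = {}

def make_qkv_pre_hook(layer_id: int, storage: dict):
    def pre_hook(module, inputs):
        # Per the table above: [seq_len, num_heads, head_dim], after RoPE.
        q, k, v = inputs[0], inputs[1], inputs[2]
        storage[layer_id] = tuple(t.detach().clone() for t in (q, k, v))
    return pre_hook

pre_hooks = [
    layer.self_attn.attn.register_forward_pre_hook(make_qkv_pre_hook(i, qkv_storage))
    for i, layer in enumerate(model.model.layers)
]
# Run inference, then remove the hooks as shown above.
```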
### Reference Implementation
Key files:
- `tests/modeling_qwen3.py`: Reference Qwen3 implementation (torch + transformers only)
- `tests/test_needle_ref.py`: Reference needle test using custom Qwen3
- `tests/test_needle.py`: Needle-in-haystack test for nanovllm
### Common Pitfalls
1. **Shape mismatch**: nanovllm uses `[num_tokens, ...]` while torch uses `[batch, seq_len, ...]`
2. **Hook position**: `self_attn` captures after o_proj, `self_attn.attn` captures before o_proj
3. **Output format**: nanovllm returns tuple `(attn_output, None)`, handle with `output[0]`
## CPU Offload System
### Ring Buffer Design
@@ -105,7 +282,6 @@ memcpy_2d_async(gpu_buf, cpu_cache[:, block_id], dpitch, spitch, width, height,
**Files**:
- `csrc/sgdma_kernel.cu`, `csrc/sgdma.cpp`: CUDA extension
- `nanovllm/comm/sgdma.py`: Python API
- `tests/test_sgdma.py`: Standalone benchmark
- `kvcache/offload_engine.py`: Integration (4 methods updated)
### Integration Details
@@ -210,25 +386,59 @@ def _merge_output_kernel(...):
- Total GPU time: ~1,343 ms
- **Overall speedup with Triton**: 1.67x
### Correctness Verification
**Test**: `tests/test_chunked_attention.py`
- 12 test cases (6 configs × 2 dtypes)
- All tests PASS with max error < 0.01
- float16: max_diff=0.000488, mean_diff~0.00001
- bfloat16: max_diff=0.003906, mean_diff~0.0001
### Key Files
- `nanovllm/kvcache/chunked_attention.py`: Triton kernels + merge function
- `tests/test_chunked_attention.py`: Correctness tests
- `tests/test_attention_offload.py`: Performance profiling
## Known Issues and Fixes
### Partial Last Block Bug (FIXED ✓)
**Problem**: When prefill token count is not an exact multiple of `block_size`, decode outputs garbage.
**Root Cause**: `_chunked_decode_attention` calculated `last_block_valid_tokens` using `len(seq) - 1`, which increases during decode. But CPU blocks are fixed after prefill!
```python
# BUG: len(seq) increases each decode step
total_prefill_tokens = len(seq) - 1 # Wrong!
last_block_valid_tokens = total_prefill_tokens % block_size # Reads garbage from CPU
```
**Fix**: Cache original prefill length in `HybridKVCacheManager.get_prefill_len()`:
```python
# CORRECT: Use cached prefill length
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Fixed value
```
**Files Modified**:
- `nanovllm/kvcache/hybrid_manager.py`: Added `_prefill_len` dict and `get_prefill_len()` method
- `nanovllm/layers/attention.py`: Use `get_prefill_len()` instead of `len(seq) - 1`
### Block Size 4096 Race Condition (FIXED ✓)
**Problem**: `block_size=4096` with multiple chunks produced `index_copy_(): index out of bounds` CUDA error during Chunk 2 processing.
**Root Cause**: Race condition between default stream and compute stream. In `_prepare_chunked_offload_chunk()`, `slot_mapping` tensor was created with `non_blocking=True` H2D transfer on the default stream. However, `store_kvcache` runs on `compute_stream`. Without synchronization, `compute_stream` could use `slot_mapping` before its transfer completed, causing corrupted indices.
**Fix** (in `attention.py`):
```python
if is_chunked_offload:
compute_stream = context.kvcache_manager.offload_engine.compute_stream
if k_cache.numel() and v_cache.numel():
# CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
```
**Tested block sizes**: 512, 1024, 4096, 8192 - all pass.
## Configuration
| Parameter | Default | Notes |
|-----------|---------|-------|
| `kvcache_block_size` | 4096 | Tokens per block |
| `kvcache_block_size` | 1024 | Tokens per block (4096 now works after race condition fix) |
| `max_num_batched_tokens` | 16384 | Set = max_model_len for long context |
| `gpu_memory_utilization` | 0.9 | GPU memory fraction |
| `enable_cpu_offload` | False | Enable for long context |
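A usage sketch tying the table to the `LLM` constructor calls in the benchmark scripts further down this diff; whether every field (e.g. `kvcache_block_size`) is accepted directly as a keyword argument is an assumption based on the `Config` dataclass:
```python
import os
from nanovllm import LLM, SamplingParams

llm = LLM(
    os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
    enforce_eager=False,
    max_model_len=32 * 1024,
    max_num_batched_tokens=32 * 1024,  # set = max_model_len for long context
    kvcache_block_size=1024,           # tokens per block
    gpu_memory_utilization=0.9,
    enable_cpu_offload=True,           # enable for long context
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=10)))
```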

103
DEBUG_SUMMARY.md Normal file
View File

@@ -0,0 +1,103 @@
# Chunked Prefill Bug Debug Summary
## Problem
`test_needle.py --enable-offload --input-len 8192` fails with garbage output.
The model generates completely wrong tokens instead of the expected "7492".
## Investigation Progress
### 1. Stream Synchronization Fix (Completed)
- Replaced Triton `store_kvcache` kernel with pure PyTorch operations
- Moved `store_kvcache` to `compute_stream` in chunked prefill mode
- Added sync: `compute_stream.wait_event(offload_done)` after per-layer offload
- Added sync: `default_stream.wait_stream(compute_stream)` before return
### 2. KV Cache Alignment Verification (Completed)
Created alignment tests to compare K/V tensors between torch reference and nanovllm:
**RoPE Alignment:**
- RoPE implementations match perfectly (max_diff=0.002, cosine ~1.0)
- Confirmed RoPE is NOT the cause of the bug
**K/V Cache Alignment (Chunk 0):**
- Cosine similarity: ~1.0 for all layers
- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision)
- Mean diff: < 0.001
- **Conclusion: K/V cache offload is working correctly**
### 3. Layer Output Divergence Analysis (Completed)
Created per-chunk layer output comparison:
**Chunk 0 (tokens 0-4096):**
- All layers pass with excellent cosine similarity (0.999+)
- Max diff grows in later layers but within acceptable range
**Chunk 1 (tokens 4096-8192):**
- Layers 0-19: OK (cosine ~1.0)
- Layers 20-27: Diverge (cosine 0.83-0.96, max_diff up to 114)
- Divergence correlates with later transformer layers
### 4. Critical Discovery: Single-Chunk Offload Also Fails
**Key finding:** Even with input_len=2048 (single chunk, no chunked attention), the model produces garbage output with CPU offload enabled.
```
# Without offload: PASSES
python tests/test_needle.py --input-len 2048
# Output: "7492" (correct)
# With offload: FAILS
python tests/test_needle.py --enable-offload --input-len 2048
# Output: "The Ble White Th G Lopsiswin..." (garbage)
```
**This proves the bug is NOT in:**
- Chunked attention logic (merge_attention_outputs)
- Multi-chunk KV loading
- Ring buffer pipeline
**The bug IS in:**
- The decode path when CPU offload is enabled
- How prefilled KV is loaded/used during decode
### 5. Decode Path Analysis (In Progress)
The decode path in CPU offload mode:
1. Prefill writes KV to GPU, offloads to CPU
2. Decode loads prefilled KV from CPU via `_decode_ring_buffer_pipeline`
3. Attend to prefilled KV + accumulated decode tokens
4. Merge results
**Observations:**
- `prefilled_blocks` set is empty after decode (should contain block IDs)
- CPU cache has valid data (reasonable mean/std values)
- Decode buffer has zeros (decode tokens not being stored correctly?)
## Current Status
### Working
- Stream synchronization fixes
- K/V cache offload to CPU (verified alignment)
- RoPE implementation
- Chunked prefill attention for first chunk
### Not Working
- Decode with CPU offload (even for single-chunk inputs)
- Multi-chunk attention (divergence in later layers for chunk 1)
## Next Steps
1. Debug why `prefilled_blocks` is empty after decode
2. Check if decode path correctly loads KV from CPU
3. Verify decode buffer is being written correctly
4. Compare decode attention outputs between offload and non-offload modes
## Key Files
- `nanovllm/layers/attention.py` - Main attention implementation with chunked paths
- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine
- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks`
- `nanovllm/engine/model_runner.py` - Prefill/decode orchestration
## Hypothesis
The decode path fails because:
1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return empty
2. OR the decode attention is not correctly loading/using the prefilled KV from CPU
3. OR there's a stream synchronization issue specific to decode path
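A hedged sketch for checking hypothesis 1 after a short offload run; the attribute paths (`model_runner.kvcache_manager`, `offload_engine.k_cache_cpu`, `prefilled_blocks`) follow this summary and the benchmark script in this diff, but exact shapes and behavior are assumptions:
```python
import os
from nanovllm import LLM, SamplingParams

llm = LLM(
    os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
    max_model_len=4096,
    max_num_batched_tokens=4096,
    enable_cpu_offload=True,
)
llm.generate(["short needle prompt ..."], SamplingParams(max_tokens=8))

mgr = llm.model_runner.kvcache_manager
print("prefilled_blocks:", mgr.prefilled_blocks)  # expected: non-empty after prefill
sample = mgr.offload_engine.k_cache_cpu.flatten()[:1_000_000].float()
print("cpu K cache sample mean/std:", sample.mean().item(), sample.std().item())
```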

View File

@@ -5,7 +5,7 @@ from nanovllm import LLM, SamplingParams
def bench_decode(llm, num_seqs, input_len, output_len):
"""Benchmark decode performance (original test)"""
"""Benchmark decode performance"""
seed(0)
prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=output_len)
@@ -13,9 +13,14 @@ def bench_decode(llm, num_seqs, input_len, output_len):
t = time.time()
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
t = time.time() - t
total_output_tokens = num_seqs * output_len
throughput = total_output_tokens / t
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {total_output_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
# Calculate metrics
prefill_tokens = num_seqs * input_len
decode_tokens = num_seqs * output_len
decode_throughput = decode_tokens / t
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")
def bench_prefill(llm, num_seqs, input_len):
@@ -35,32 +40,49 @@ def bench_prefill(llm, num_seqs, input_len):
def main():
import argparse
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(description="Benchmark nanovllm GPU performance")
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
parser.add_argument("--output-len", type=int, default=128, help="Output length in tokens")
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
args = parser.parse_args()
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
# Note: Qwen3-4B-Instruct-2507 max_position_embeddings = 262144
max_len = 131072 # 128K tokens
llm = LLM(path, enforce_eager=False, max_model_len=max_len, max_num_batched_tokens=max_len)
max_len = args.max_len
print(f"\n[nanovllm GPU] max_len={max_len}")
llm = LLM(
path,
enforce_eager=False,
max_model_len=max_len,
max_num_batched_tokens=max_len,
)
# Warmup
llm.generate(["Benchmark: "], SamplingParams())
print("\nWarming up...")
llm.generate(["Benchmark warmup: "], SamplingParams(max_tokens=10))
# Default input lengths based on max_len
# Default input lengths
prefill_input_len = args.input_len if args.input_len else max_len - 1
decode_input_len = args.input_len if args.input_len else max_len - args.output_len
print("=" * 60)
print("Prefill Benchmark (GPU)")
print("=" * 60)
bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
# Determine which benchmarks to run
run_prefill = not args.bench_decode or args.bench_all
run_decode = args.bench_decode or args.bench_all
# print("=" * 60)
# print("Decode Benchmark (GPU)")
# print("=" * 60)
# bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
if run_prefill:
print("\n" + "=" * 60)
print("Prefill Benchmark (nanovllm GPU)")
print("=" * 60)
bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
if run_decode:
print("\n" + "=" * 60)
print("Decode Benchmark (nanovllm GPU)")
print("=" * 60)
bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
if __name__ == "__main__":

View File

@@ -3,14 +3,9 @@ import time
from random import randint, seed
from nanovllm import LLM, SamplingParams
# Import sparse policy classes
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
def bench_decode(llm, num_seqs, input_len, output_len):
"""Benchmark decode performance (original test)"""
"""Benchmark decode performance"""
seed(0)
prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=output_len)
@@ -18,9 +13,17 @@ def bench_decode(llm, num_seqs, input_len, output_len):
t = time.time()
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
t = time.time() - t
total_output_tokens = num_seqs * output_len
throughput = total_output_tokens / t
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {total_output_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
# Calculate metrics
prefill_tokens = num_seqs * input_len
decode_tokens = num_seqs * output_len
# Approximate: assume prefill takes ~input_len/prefill_speed, rest is decode
# For more accurate measurement, we'd need internal timing
decode_throughput = decode_tokens / t # This includes prefill time, so it's a lower bound
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")
def bench_prefill(llm, num_seqs, input_len):
@@ -38,102 +41,70 @@ def bench_prefill(llm, num_seqs, input_len):
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
def setup_quest_policy(llm, topk_blocks=8, threshold_blocks=4):
"""
Setup Quest sparse policy for decode phase.
Uses HybridPolicy: Full attention for prefill, Quest Top-K for decode.
"""
import torch
kvcache_manager = llm.model_runner.kvcache_manager
offload_engine = kvcache_manager.offload_engine
# Get model parameters from offload engine
num_layers = offload_engine.num_layers
num_kv_heads = offload_engine.num_kv_heads
head_dim = offload_engine.head_dim
num_cpu_blocks = kvcache_manager.num_cpu_blocks
dtype = offload_engine.k_cache_cpu.dtype
print(f"Setting up Quest policy:")
print(f" num_layers={num_layers}, num_kv_heads={num_kv_heads}, head_dim={head_dim}")
print(f" num_cpu_blocks={num_cpu_blocks}, dtype={dtype}")
print(f" topk_blocks={topk_blocks}, threshold_blocks={threshold_blocks}")
# Create BlockMetadataManager for storing min/max keys
metadata = BlockMetadataManager(
num_blocks=num_cpu_blocks,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
)
# Create Quest policy for decode
quest_config = QuestConfig(
topk_blocks=topk_blocks,
threshold_blocks=threshold_blocks,
)
quest_policy = QuestPolicy(quest_config, metadata)
# Create Hybrid policy: Full for prefill, Quest for decode
hybrid_policy = HybridPolicy(
prefill_policy=FullAttentionPolicy(),
decode_policy=quest_policy,
)
# Set the policy
kvcache_manager.set_sparse_policy(hybrid_policy)
print(f" Policy set: HybridPolicy(prefill=Full, decode=Quest)")
return hybrid_policy
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--no-sparse", action="store_true", help="Disable sparse attention (baseline)")
parser.add_argument("--topk", type=int, default=8, help="Top-K blocks for Quest")
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens (default: max_len - 1 for prefill, max_len - output_len for decode)")
parser.add_argument("--output-len", type=int, default=128, help="Output length in tokens")
from nanovllm.config import SparsePolicyType
parser = argparse.ArgumentParser(description="Benchmark CPU offload performance")
parser.add_argument("--enable-quest", action="store_true", help="Enable Quest sparse attention for decode")
parser.add_argument("--topk", type=int, default=16, help="Top-K blocks for Quest (default: 16)")
parser.add_argument("--threshold", type=int, default=4, help="Apply sparse only when blocks > threshold (default: 4)")
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
parser.add_argument("--num-gpu-blocks", type=int, default=6, help="Number of GPU blocks (default: 6)")
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
args = parser.parse_args()
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
# Note: Qwen3-4B-Instruct-2507 max_position_embeddings = 262144
max_len = 131072 # 128K tokens
max_len = args.max_len
# Setup policy configuration
if args.enable_quest:
sparse_policy = SparsePolicyType.QUEST
print(f"\n[Quest Sparse Attention] topk={args.topk}, threshold={args.threshold}")
else:
sparse_policy = SparsePolicyType.FULL
print("\n[Full Attention] baseline (no sparse)")
print(f"[Config] max_len={max_len}, num_gpu_blocks={args.num_gpu_blocks}")
llm = LLM(
path,
enforce_eager=False,
max_model_len=max_len,
max_num_batched_tokens=max_len,
enable_cpu_offload=True,
num_gpu_blocks=6, # Small GPU buffer for offload testing
num_gpu_blocks=args.num_gpu_blocks,
sparse_policy=sparse_policy,
sparse_topk_blocks=args.topk,
sparse_threshold_blocks=args.threshold,
)
if not args.no_sparse:
# Setup Quest policy for decode (Top-K blocks, apply when > 4 blocks)
setup_quest_policy(llm, topk_blocks=args.topk, threshold_blocks=4)
print(f"\n[Quest Sparse Attention] topk={args.topk}")
else:
print("\n[Full Attention] No sparse policy (baseline)")
# Warmup
llm.generate(["Benchmark: "], SamplingParams())
print("\nWarming up...")
llm.generate(["Benchmark warmup: "], SamplingParams(max_tokens=10))
# Default input lengths based on max_len
# Default input lengths
prefill_input_len = args.input_len if args.input_len else max_len - 1
decode_input_len = args.input_len if args.input_len else max_len - args.output_len
print("=" * 60)
print("Prefill Benchmark (CPU Offload)")
print("=" * 60)
bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
# Determine which benchmarks to run
run_prefill = not args.bench_decode or args.bench_all
run_decode = args.bench_decode or args.bench_all
# print("=" * 60)
# print("Decode Benchmark (CPU Offload)")
# print("=" * 60)
# bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
if run_prefill:
print("\n" + "=" * 60)
print("Prefill Benchmark (CPU Offload)")
print("=" * 60)
bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
if run_decode:
print("\n" + "=" * 60)
print("Decode Benchmark (CPU Offload)")
print("=" * 60)
bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
if __name__ == "__main__":

View File

@@ -6,7 +6,7 @@ from vllm import LLM, SamplingParams
def bench_decode(llm, num_seqs, input_len, output_len):
"""Benchmark decode performance (original test)"""
"""Benchmark decode performance"""
seed(0)
prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=output_len)
@@ -15,9 +15,14 @@ def bench_decode(llm, num_seqs, input_len, output_len):
t = time.time()
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
t = time.time() - t
total_output_tokens = num_seqs * output_len
throughput = total_output_tokens / t
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {total_output_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
# Calculate metrics
prefill_tokens = num_seqs * input_len
decode_tokens = num_seqs * output_len
decode_throughput = decode_tokens / t
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")
def bench_prefill(llm, num_seqs, input_len):
@@ -38,32 +43,50 @@ def bench_prefill(llm, num_seqs, input_len):
def main():
import argparse
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(description="Benchmark vLLM performance (for comparison)")
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
parser.add_argument("--output-len", type=int, default=128, help="Output length in tokens")
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
args = parser.parse_args()
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
# Note: Qwen3-4B-Instruct-2507 max_position_embeddings = 262144
max_len = 131072 # 128K tokens
llm = LLM(path, enforce_eager=False, max_model_len=max_len, max_num_seqs=128, gpu_memory_utilization=0.9)
max_len = args.max_len
print(f"\n[vLLM] max_len={max_len}")
llm = LLM(
path,
enforce_eager=False,
max_model_len=max_len,
max_num_seqs=128,
gpu_memory_utilization=0.9,
)
# Warmup
llm.generate([dict(prompt_token_ids=[0])], SamplingParams())
print("\nWarming up...")
llm.generate([dict(prompt_token_ids=[0, 1, 2])], SamplingParams(max_tokens=10))
# Default input lengths based on max_len
# Default input lengths
prefill_input_len = args.input_len if args.input_len else max_len - 1
decode_input_len = args.input_len if args.input_len else max_len - args.output_len
print("=" * 60)
print("Prefill Benchmark (vLLM)")
print("=" * 60)
bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
# Determine which benchmarks to run
run_prefill = not args.bench_decode or args.bench_all
run_decode = args.bench_decode or args.bench_all
# print("=" * 60)
# print("Decode Benchmark (vLLM)")
# print("=" * 60)
# bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
if run_prefill:
print("\n" + "=" * 60)
print("Prefill Benchmark (vLLM)")
print("=" * 60)
bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
if run_decode:
print("\n" + "=" * 60)
print("Decode Benchmark (vLLM)")
print("=" * 60)
bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
if __name__ == "__main__":

View File

@@ -1,9 +1,16 @@
import os
from dataclasses import dataclass
from enum import Enum, auto
from transformers import AutoConfig
import torch
class SparsePolicyType(Enum):
"""Sparse attention policy types."""
FULL = auto() # No sparse attention (load all blocks)
QUEST = auto() # Query-aware Top-K block selection (decode only)
@dataclass
class Config:
model: str
@@ -15,7 +22,7 @@ class Config:
enforce_eager: bool = False
hf_config: AutoConfig | None = None
eos: int = -1
kvcache_block_size: int = 4096
kvcache_block_size: int = 1024
num_kvcache_blocks: int = -1
dtype: str | None = None # "float16", "bfloat16", or None (use model default)
@@ -30,9 +37,9 @@ class Config:
num_cpu_kvcache_blocks: int = -1
# Sparse attention configuration
sparse_policy: str | None = None # "vertical_slash", "quest", "streaming_llm", or None
sparse_num_sink_blocks: int = 1 # Number of sink blocks for sparse patterns
sparse_local_window_blocks: int = 2 # Local window size for VerticalSlash
# Quest: decode-only sparse attention with Top-K block selection
# FULL: no sparse attention (load all blocks)
sparse_policy: SparsePolicyType = SparsePolicyType.FULL
sparse_topk_blocks: int = 8 # Top-K blocks for Quest
sparse_threshold_blocks: int = 4 # Apply sparse only when blocks > threshold

View File

@@ -0,0 +1,49 @@
"""
Breakpoint debugging tools for aligning nanovllm with reference implementations.
This module provides a generator-based breakpoint aligner that enables step-by-step
comparison between nanovllm and torch reference model outputs.
Example:
>>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
>>> from tests.modeling_qwen3 import Qwen3ForCausalLM
>>>
>>> # Load models
>>> torch_model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch.float16)
>>> nanovllm_model = ... # Your nanovllm model
>>>
>>> # Create adapters
>>> ref = TorchSteppable(torch_model)
>>> test = NanovllmSteppable(nanovllm_model)
>>>
>>> # Run alignment
>>> aligner = BreakpointAligner(ref, test)
>>> result = aligner.align(input_ids)
>>> print(result)
"""
from .breakpoints import BreakpointType, Breakpoint
from .comparator import TensorComparator, ComparisonResult
from .aligner import BreakpointAligner, AlignmentResult
from .adapters import SteppableModel, TorchSteppable, NanovllmSteppable
from .utils import setup_prefill_context, setup_decode_context, cleanup_context
__all__ = [
# Core classes
"BreakpointAligner",
"AlignmentResult",
# Breakpoints
"BreakpointType",
"Breakpoint",
# Comparator
"TensorComparator",
"ComparisonResult",
# Adapters
"SteppableModel",
"TorchSteppable",
"NanovllmSteppable",
# Utils
"setup_prefill_context",
"setup_decode_context",
"cleanup_context",
]

View File

@@ -0,0 +1,11 @@
"""Model adapters for breakpoint alignment."""
from .base import SteppableModel
from .torch_adapter import TorchSteppable
from .nanovllm_adapter import NanovllmSteppable
__all__ = [
"SteppableModel",
"TorchSteppable",
"NanovllmSteppable",
]

View File

@@ -0,0 +1,59 @@
"""Base class for steppable model adapters."""
from abc import ABC, abstractmethod
from typing import Generator, Set, Optional
import torch
from ..breakpoints import Breakpoint, BreakpointType
class SteppableModel(ABC):
"""
Abstract base class for models that can yield at breakpoints.
Subclasses implement the step() method as a generator that yields
Breakpoint objects at each enabled breakpoint during forward pass.
"""
def __init__(self, enabled_breakpoints: Optional[Set[BreakpointType]] = None):
"""
Args:
enabled_breakpoints: Set of breakpoint types to yield at.
If None, yields at all breakpoints.
"""
self.enabled_breakpoints = enabled_breakpoints
def is_enabled(self, bp_type: BreakpointType) -> bool:
"""Check if a breakpoint type is enabled."""
if self.enabled_breakpoints is None:
return True
return bp_type in self.enabled_breakpoints
@abstractmethod
def step(
self,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor] = None,
is_prefill: bool = True,
) -> Generator[Breakpoint, None, torch.Tensor]:
"""
Generator that yields Breakpoint objects at enabled breakpoints.
Args:
input_ids: Input token IDs
positions: Position IDs (optional, auto-generated if None)
is_prefill: True for prefill phase, False for decode
Yields:
Breakpoint objects at each enabled checkpoint
Returns:
Final output tensor (logits)
"""
pass
@property
@abstractmethod
def num_layers(self) -> int:
"""Return the number of decoder layers."""
pass

View File

@@ -0,0 +1,235 @@
"""Nanovllm model adapter for breakpoint alignment."""
from typing import Generator, Set, Optional, Dict, Any, List
import torch
from nanovllm.utils.context import set_context, reset_context
from ..breakpoints import Breakpoint, BreakpointType
from .base import SteppableModel
class NanovllmSteppable(SteppableModel):
"""
Steppable adapter for nanovllm Qwen3 implementation.
Uses PyTorch hooks to capture intermediate values during forward pass,
then yields them as breakpoints after execution completes.
Key challenges handled:
1. Shape difference: nanovllm uses [num_tokens, hidden] vs [batch, seq, hidden]
2. Context-based attention: must call set_context() before forward
3. Fused operations: decoder layer returns (hidden_states, residual) tuple
"""
def __init__(
self,
model,
enabled_breakpoints: Optional[Set[BreakpointType]] = None,
):
"""
Args:
model: Qwen3ForCausalLM from nanovllm
enabled_breakpoints: Set of breakpoint types to yield at
"""
super().__init__(enabled_breakpoints)
self.model = model
self.model.eval()
self._hooks: List[Any] = []
self._captured: Dict[str, torch.Tensor] = {}
@property
def num_layers(self) -> int:
return len(self.model.model.layers)
def _register_hooks(self):
"""Register forward hooks on all relevant modules."""
self._hooks = []
self._captured = {}
# Hook for embedding output
def embed_hook(module, input, output):
self._captured["embed"] = output.detach().clone()
self._hooks.append(
self.model.model.embed_tokens.register_forward_hook(embed_hook)
)
# Hooks for each decoder layer
for layer_idx in range(self.num_layers):
layer = self.model.model.layers[layer_idx]
def make_layer_hook(idx):
def hook(module, input, output):
# Decoder layer returns (hidden_states, residual)
# hidden_states is MLP output, residual is accumulated residual
# To match torch reference, we need hidden_states + residual
if isinstance(output, tuple) and len(output) >= 2:
hidden_states, residual = output[0], output[1]
full_output = hidden_states + residual
else:
full_output = output
self._captured[f"layer_{idx}"] = full_output.detach().clone()
return hook
self._hooks.append(
layer.register_forward_hook(make_layer_hook(layer_idx))
)
# Hook for final norm
def final_norm_hook(module, input, output):
# Final norm returns (hidden_states, _) for fused add
hidden_states = output[0] if isinstance(output, tuple) else output
self._captured["final_norm"] = hidden_states.detach().clone()
self._hooks.append(
self.model.model.norm.register_forward_hook(final_norm_hook)
)
# Hook for lm_head
def lm_head_hook(module, input, output):
self._captured["lm_head"] = output.detach().clone()
self._hooks.append(
self.model.lm_head.register_forward_hook(lm_head_hook)
)
def _remove_hooks(self):
"""Remove all registered hooks."""
for hook in self._hooks:
hook.remove()
self._hooks = []
def _setup_context(self, seq_len: int, device: torch.device, is_prefill: bool):
"""
Set up nanovllm context for forward pass.
For alignment testing, we use simple context without real KV cache.
"""
if is_prefill:
# Prefill: process all tokens at once
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
# Use -1 for slot_mapping to skip KV cache writes
slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
slot_mapping=slot_mapping,
is_chunked_prefill=False,
)
else:
# Decode: single token generation
# For decode, we need context_lens and block_tables
# For alignment testing without real KV cache, we use minimal setup
context_lens = torch.tensor([seq_len - 1], dtype=torch.int32, device=device)
# Single token slot
slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
# Empty block tables (no KV cache)
block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)
set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
)
def _normalize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Normalize nanovllm tensor shape to [batch, seq_len, ...].
nanovllm uses [num_tokens, ...] format without batch dimension.
We add batch dimension for comparison with torch model.
"""
if tensor.dim() == 2: # [num_tokens, hidden_size]
return tensor.unsqueeze(0)
return tensor
def step(
self,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor] = None,
is_prefill: bool = True,
) -> Generator[Breakpoint, None, torch.Tensor]:
"""
Execute nanovllm forward pass with hooks to capture breakpoints.
Unlike the torch adapter which manually steps through each component,
we run the full forward pass and collect captured values afterward.
"""
# Ensure 1D for nanovllm (it expects [num_tokens])
if input_ids.dim() == 2:
input_ids = input_ids.squeeze(0)
seq_len = input_ids.numel()
device = input_ids.device
# Generate position IDs if not provided
if positions is None:
positions = torch.arange(seq_len, device=device)
elif positions.dim() == 2:
positions = positions.squeeze(0)
# Register hooks
self._register_hooks()
try:
# Setup context for attention
self._setup_context(seq_len, device, is_prefill)
# Run forward pass (hooks capture everything)
with torch.no_grad():
hidden_states = self.model(input_ids, positions)
logits = self.model.compute_logits(hidden_states)
reset_context()
# Yield breakpoints in order from captured data
# EMBEDDING
if self.is_enabled(BreakpointType.EMBEDDING) and "embed" in self._captured:
yield Breakpoint(
bp_type=BreakpointType.EMBEDDING,
layer_idx=None,
tensor=self._normalize_tensor(self._captured["embed"]),
name="Embedding",
)
# LAYER_OUTPUT for each layer
if self.is_enabled(BreakpointType.LAYER_OUTPUT):
for layer_idx in range(self.num_layers):
key = f"layer_{layer_idx}"
if key in self._captured:
yield Breakpoint(
bp_type=BreakpointType.LAYER_OUTPUT,
layer_idx=layer_idx,
tensor=self._normalize_tensor(self._captured[key]),
name=f"Layer {layer_idx}",
)
# FINAL_NORM
if self.is_enabled(BreakpointType.FINAL_NORM) and "final_norm" in self._captured:
yield Breakpoint(
bp_type=BreakpointType.FINAL_NORM,
layer_idx=None,
tensor=self._normalize_tensor(self._captured["final_norm"]),
name="Final Norm",
)
# LM_HEAD
if self.is_enabled(BreakpointType.LM_HEAD) and "lm_head" in self._captured:
yield Breakpoint(
bp_type=BreakpointType.LM_HEAD,
layer_idx=None,
tensor=self._normalize_tensor(self._captured["lm_head"]),
name="LM Head",
)
return logits
finally:
self._remove_hooks()
self._captured = {}

View File

@@ -0,0 +1,119 @@
"""Torch reference model adapter for breakpoint alignment."""
from typing import Generator, Set, Optional
import torch
from ..breakpoints import Breakpoint, BreakpointType
from .base import SteppableModel
class TorchSteppable(SteppableModel):
"""
Steppable adapter for the torch reference Qwen3 implementation.
Wraps tests/modeling_qwen3.py Qwen3ForCausalLM model.
"""
def __init__(
self,
model,
enabled_breakpoints: Optional[Set[BreakpointType]] = None,
):
"""
Args:
model: Qwen3ForCausalLM from tests/modeling_qwen3.py
enabled_breakpoints: Set of breakpoint types to yield at
"""
super().__init__(enabled_breakpoints)
self.model = model
self.model.eval()
@property
def num_layers(self) -> int:
return len(self.model.model.layers)
def step(
self,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor] = None,
is_prefill: bool = True,
) -> Generator[Breakpoint, None, torch.Tensor]:
"""
Generator that manually steps through the torch model.
The torch model uses [batch, seq_len, hidden_size] shapes.
"""
# Ensure batch dimension
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Generate position IDs if not provided
if positions is None:
positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
elif positions.dim() == 1:
positions = positions.unsqueeze(0)
with torch.no_grad():
# === EMBEDDING ===
hidden_states = self.model.model.embed_tokens(input_ids)
if self.is_enabled(BreakpointType.EMBEDDING):
yield Breakpoint(
bp_type=BreakpointType.EMBEDDING,
layer_idx=None,
tensor=hidden_states.detach().clone(),
name="Embedding",
)
# Create causal attention mask
causal_mask = torch.triu(
torch.full((seq_len, seq_len), float("-inf"), device=device),
diagonal=1,
)
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)
# === DECODER LAYERS ===
for layer_idx, layer in enumerate(self.model.model.layers):
hidden_states, _, _ = layer(
hidden_states=hidden_states,
position_ids=positions,
attention_mask=attention_mask,
past_key_value=None,
use_cache=False,
output_qkv=False,
)
if self.is_enabled(BreakpointType.LAYER_OUTPUT):
yield Breakpoint(
bp_type=BreakpointType.LAYER_OUTPUT,
layer_idx=layer_idx,
tensor=hidden_states.detach().clone(),
name=f"Layer {layer_idx}",
)
# === FINAL NORM ===
hidden_states = self.model.model.norm(hidden_states)
if self.is_enabled(BreakpointType.FINAL_NORM):
yield Breakpoint(
bp_type=BreakpointType.FINAL_NORM,
layer_idx=None,
tensor=hidden_states.detach().clone(),
name="Final Norm",
)
# === LM HEAD ===
logits = self.model.lm_head(hidden_states)
if self.is_enabled(BreakpointType.LM_HEAD):
yield Breakpoint(
bp_type=BreakpointType.LM_HEAD,
layer_idx=None,
tensor=logits.detach().clone(),
name="LM Head",
)
return logits

211
nanovllm/debug/aligner.py Normal file
View File

@@ -0,0 +1,211 @@
"""Breakpoint aligner for comparing model outputs."""
from dataclasses import dataclass, field
from typing import List, Tuple, Optional
import torch
from .breakpoints import Breakpoint
from .comparator import TensorComparator, ComparisonResult
from .adapters.base import SteppableModel
@dataclass
class AlignmentResult:
"""Result of an alignment test."""
passed: bool
all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = field(default_factory=list)
failed_at: Optional[Breakpoint] = None
message: str = ""
def __repr__(self) -> str:
passed_count = sum(1 for _, _, c in self.all_comparisons if c.passed)
total = len(self.all_comparisons)
status = "PASSED" if self.passed else "FAILED"
return f"AlignmentResult({status}, {passed_count}/{total} breakpoints passed)"
class BreakpointAligner:
"""
Orchestrates alternating execution of reference and test models,
comparing outputs at each breakpoint.
Example:
>>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
>>> ref = TorchSteppable(torch_model)
>>> test = NanovllmSteppable(nanovllm_model)
>>> aligner = BreakpointAligner(ref, test)
>>> result = aligner.align(input_ids)
>>> print(result)
"""
def __init__(
self,
ref_model: SteppableModel,
test_model: SteppableModel,
comparator: Optional[TensorComparator] = None,
stop_on_error: bool = True,
verbose: bool = True,
):
"""
Args:
ref_model: Reference (torch) steppable model
test_model: Test (nanovllm) steppable model
comparator: Tensor comparator instance (uses default if None)
stop_on_error: If True, stop at first mismatch
verbose: If True, print comparison results
"""
self.ref_model = ref_model
self.test_model = test_model
self.comparator = comparator or TensorComparator()
self.stop_on_error = stop_on_error
self.verbose = verbose
def align(
self,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor] = None,
is_prefill: bool = True,
) -> AlignmentResult:
"""
Run both models with same input, comparing at each breakpoint.
Args:
input_ids: Input token IDs
positions: Position IDs (optional)
is_prefill: True for prefill phase, False for decode
Returns:
AlignmentResult with pass/fail status and details
"""
all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = []
if self.verbose:
phase = "prefill" if is_prefill else "decode"
print(f"\n{'='*60}")
print(f"Alignment Test ({phase})")
print(f"{'='*60}")
# Start both generators
ref_gen = self.ref_model.step(input_ids, positions, is_prefill)
test_gen = self.test_model.step(input_ids, positions, is_prefill)
try:
while True:
# Get next breakpoint from reference
try:
ref_bp = next(ref_gen)
except StopIteration:
break
# Get corresponding breakpoint from test
try:
test_bp = next(test_gen)
except StopIteration:
if self.verbose:
print(f"Test model ended early at {ref_bp.name}")
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=f"Test model ended early at {ref_bp.name}",
)
# Verify breakpoints match
if ref_bp.bp_type != test_bp.bp_type:
msg = f"Breakpoint type mismatch: {ref_bp.bp_type} vs {test_bp.bp_type}"
if self.verbose:
print(msg)
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=msg,
)
if ref_bp.layer_idx != test_bp.layer_idx:
msg = f"Layer index mismatch: {ref_bp.layer_idx} vs {test_bp.layer_idx}"
if self.verbose:
print(msg)
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=msg,
)
# Normalize shapes for comparison
ref_t = ref_bp.normalize_shape()
test_t = test_bp.normalize_shape()
# Handle shape mismatches
if ref_t.shape != test_t.shape:
if self.verbose:
print(f"[{ref_bp.name}] Shape mismatch: ref={ref_t.shape} vs test={test_t.shape}")
# Try to reshape if element count matches
if ref_t.numel() == test_t.numel():
test_t = test_t.view(ref_t.shape)
else:
msg = f"Shape mismatch at {ref_bp.name}: {ref_t.shape} vs {test_t.shape}"
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=msg,
)
# Compare tensors
result = self.comparator.compare(ref_t, test_t, ref_bp.name)
all_comparisons.append((ref_bp, test_bp, result))
if self.verbose:
status = "\u2713" if result.passed else "\u2717"
print(f"{status} [{ref_bp.name}] cos={result.cosine_similarity:.6f}, max_diff={result.max_abs_diff:.2e}")
if not result.passed and self.stop_on_error:
if self.verbose:
print(f"\nStopped at {ref_bp.name} (stop_on_error=True)")
print(result.message)
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=f"Alignment failed at {ref_bp.name}",
)
# Check for extra test breakpoints
try:
extra_bp = next(test_gen)
msg = f"Test model has extra breakpoints starting at {extra_bp.name}"
if self.verbose:
print(msg)
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
message=msg,
)
except StopIteration:
pass
except Exception as e:
msg = f"Exception during alignment: {str(e)}"
if self.verbose:
print(msg)
raise
# Summary
all_passed = all(comp[2].passed for comp in all_comparisons)
passed_count = sum(1 for _, _, c in all_comparisons if c.passed)
total = len(all_comparisons)
if self.verbose:
print(f"{'='*60}")
status = "PASSED" if all_passed else "FAILED"
print(f"Result: {status} ({passed_count}/{total} breakpoints)")
print(f"{'='*60}\n")
return AlignmentResult(
passed=all_passed,
all_comparisons=all_comparisons,
message="All breakpoints aligned" if all_passed else "Some breakpoints failed",
)

View File

@@ -0,0 +1,39 @@
"""Breakpoint types and data structures for alignment debugging."""
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional
import torch
class BreakpointType(Enum):
"""Types of breakpoints in the model forward pass."""
EMBEDDING = auto() # After embed_tokens
LAYER_OUTPUT = auto() # After each decoder layer
FINAL_NORM = auto() # After final RMSNorm
LM_HEAD = auto() # After lm_head (logits)
@dataclass
class Breakpoint:
"""A captured breakpoint with tensor data."""
bp_type: BreakpointType
layer_idx: Optional[int] # None for EMBEDDING, FINAL_NORM, LM_HEAD
tensor: torch.Tensor
name: str
def normalize_shape(self) -> torch.Tensor:
"""
Normalize tensor shape for comparison.
nanovllm uses [num_tokens, hidden_size] while torch uses
[batch, seq_len, hidden_size]. This adds a batch dimension
to 2D tensors for comparison.
"""
if self.tensor.dim() == 2:
return self.tensor.unsqueeze(0)
return self.tensor
def __repr__(self) -> str:
shape_str = "x".join(str(d) for d in self.tensor.shape)
return f"Breakpoint({self.name}, shape={shape_str}, dtype={self.tensor.dtype})"

View File

@@ -0,0 +1,94 @@
"""Tensor comparison utilities for alignment debugging."""
from dataclasses import dataclass
import torch
import torch.nn.functional as F
@dataclass
class ComparisonResult:
"""Result of comparing two tensors."""
passed: bool
cosine_similarity: float
max_abs_diff: float
mean_abs_diff: float
message: str
def __repr__(self) -> str:
status = "\u2713" if self.passed else "\u2717"
return f"{status} cos={self.cosine_similarity:.6f}, max_diff={self.max_abs_diff:.2e}"
class TensorComparator:
"""Compares tensors using cosine similarity and absolute differences."""
def __init__(
self,
cosine_threshold: float = 0.999,
max_diff_threshold: float = 0.1,
mean_diff_threshold: float = 0.01,
):
"""
Args:
cosine_threshold: Minimum cosine similarity to pass (0-1)
max_diff_threshold: Maximum allowed absolute difference
mean_diff_threshold: Maximum allowed mean absolute difference
"""
self.cosine_threshold = cosine_threshold
self.max_diff_threshold = max_diff_threshold
self.mean_diff_threshold = mean_diff_threshold
def compare(
self,
ref: torch.Tensor,
test: torch.Tensor,
name: str = "",
) -> ComparisonResult:
"""
Compare two tensors and return detailed result.
Args:
ref: Reference tensor
test: Test tensor
name: Name for the comparison (used in message)
Returns:
ComparisonResult with pass/fail status and metrics
"""
# Convert to float32 for comparison
ref_f = ref.float().flatten()
test_f = test.float().flatten()
# Cosine similarity
cos_sim = F.cosine_similarity(
ref_f.unsqueeze(0),
test_f.unsqueeze(0)
).item()
# Absolute differences
diff = (ref.float() - test.float()).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
# Check thresholds
passed = (
cos_sim >= self.cosine_threshold and
max_diff <= self.max_diff_threshold and
mean_diff <= self.mean_diff_threshold
)
status = "PASS" if passed else "FAIL"
message = (
f"[{name}] {status}\n"
f" Cosine Similarity: {cos_sim:.6f} (threshold: {self.cosine_threshold})\n"
f" Max Abs Diff: {max_diff:.6f} (threshold: {self.max_diff_threshold})\n"
f" Mean Abs Diff: {mean_diff:.6f} (threshold: {self.mean_diff_threshold})"
)
return ComparisonResult(
passed=passed,
cosine_similarity=cos_sim,
max_abs_diff=max_diff,
mean_abs_diff=mean_diff,
message=message,
)
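A short usage sketch for the comparator above (the random tensors are purely illustrative):
```python
import torch
from nanovllm.debug import TensorComparator

cmp = TensorComparator(cosine_threshold=0.999, max_diff_threshold=0.1)
ref = torch.randn(1, 8, 16)
test = ref + 1e-4 * torch.randn_like(ref)
result = cmp.compare(ref, test, name="layer_0")
print(result)                 # e.g. "✓ cos=1.000000, max_diff=3.2e-04"
print(result.passed, result.max_abs_diff)
```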

51
nanovllm/debug/utils.py Normal file
View File

@@ -0,0 +1,51 @@
"""Utility functions for breakpoint alignment debugging."""
import torch
from nanovllm.utils.context import set_context, reset_context
def setup_prefill_context(seq_len: int, device: torch.device):
"""
Set up nanovllm context for prefill alignment testing.
Args:
seq_len: Sequence length
device: Target device
"""
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
slot_mapping=slot_mapping,
is_chunked_prefill=False,
)
def setup_decode_context(context_len: int, device: torch.device):
"""
Set up nanovllm context for decode alignment testing.
Args:
context_len: Context length (number of previous tokens)
device: Target device
"""
context_lens = torch.tensor([context_len], dtype=torch.int32, device=device)
slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)
set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
)
def cleanup_context():
"""Reset nanovllm context after alignment testing."""
reset_context()

View File

@@ -62,6 +62,8 @@ class LLMEngine:
token_ids = self.model_runner.call("run", seqs, is_prefill)
self.scheduler.postprocess(seqs, token_ids)
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
#> Calculate number of tokens processed
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
return outputs, num_tokens

View File

@@ -35,7 +35,10 @@ class ModelRunner:
self.model = Qwen3ForCausalLM(hf_config)
load_model(self.model, config.model)
self.sampler = GreedySampler()
#> Disable warmup for debugging
self.warmup_model()
self.allocate_kv_cache()
if not self.enforce_eager:
self.capture_cudagraph()
@@ -59,7 +62,7 @@ class ModelRunner:
self.shm.unlink()
if not self.enforce_eager:
del self.graphs, self.graph_pool
torch.cuda.synchronize()
# torch.cuda.synchronize()
dist.destroy_process_group()
def loop(self):
@@ -153,6 +156,22 @@ class ModelRunner:
dtype=hf_config.torch_dtype,
)
# Initialize sparse policy if manager has one (CPU offload mode)
if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
self.kvcache_manager.sparse_policy.initialize(
num_layers=hf_config.num_hidden_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
num_cpu_blocks=config.num_cpu_kvcache_blocks,
dtype=hf_config.torch_dtype,
device=torch.device("cuda"),
)
logger.info(
f"Sparse policy initialized: {config.sparse_policy.name} "
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
)
# Log KV cache allocation info with detailed per-token breakdown
gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
@@ -194,7 +213,7 @@ class ModelRunner:
f"block_size={self.block_size}"
)
# Bind layer caches to attention modules and set layer_id
#> Bind layer caches to attention modules and set layer_id
layer_id = 0
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
@@ -379,7 +398,7 @@ class ModelRunner:
return self.model.compute_logits(graph_vars["outputs"][:bs])
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
# Check if Chunked Offload mode should be used (all blocks on CPU)
#> Check if Chunked Offload mode should be used (all blocks on CPU)
if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'):
use_chunked_offload = self._should_use_chunked_offload(seqs, is_prefill)
if use_chunked_offload:
@@ -388,6 +407,7 @@ class ModelRunner:
else:
return self.run_chunked_offload_decode(seqs)
#> The following code path does not use Chunked Offload mode
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
logits = self.run_model(input_ids, positions, is_prefill)
@@ -435,8 +455,6 @@ class ModelRunner:
3. After each chunk, offload from ring buffer slot to CPU
4. All N-1 other slots are used to load previous chunks for attention
"""
import sys
assert len(seqs) == 1, "Ring buffer prefill only supports single sequence"
seq = seqs[0]
@@ -446,10 +464,9 @@ class ModelRunner:
total_tokens = len(seq)
num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk
print(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
logger.debug(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
f"total_chunks={num_chunks}",
file=sys.stderr)
f"total_chunks={num_chunks}")
chunk_idx = 0
logits = None
@@ -468,9 +485,8 @@ class ModelRunner:
# CPU block index for this chunk
block_idx = chunk_idx
print(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
f"write_slot={write_slot}",
file=sys.stderr)
logger.debug(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
f"write_slot={write_slot}")
# Prepare inputs
input_ids, positions = self._prepare_chunked_offload_chunk(
@@ -480,7 +496,7 @@ class ModelRunner:
if input_ids.numel() == 0:
break
# Run model forward
#> Run model forward
logits = self.run_model(input_ids, positions, is_prefill=True)
reset_context()
@@ -489,27 +505,17 @@ class ModelRunner:
logical_id = seq.block_table[block_idx]
self.kvcache_manager.prefilled_blocks.add(logical_id)
# NOTE: Per-layer offloading is now done in attention.forward
# Each layer offloads its KV to CPU immediately after computing attention.
# We just need to wait for the last offload to complete before reusing the slot.
if block_idx < len(cpu_block_ids):
# TODO: Sparse policy hook needs update for new GPU cache architecture
# The GPU cache no longer has layer dimension, so we can't access
# k_cache_gpu[layer_id, write_slot]. Sparse policy should be called
# in attention.forward after per-layer offload.
pass
# Wait for offload to complete before next chunk
# (slot will be reused after N chunks)
offload_engine.wait_slot_offload(write_slot)
# NOTE: Per-layer async offloading is now done in attention.forward
# Each layer offloads from its own prefill buffer - no waiting required!
# The sparse policy hook is called in offload_prefill_buffer_async.
processed_tokens = chunk_end
chunk_idx += 1
# Wait for all offloads to complete
offload_engine.wait_all_offload_done()
# Wait for all async prefill offloads to complete
offload_engine.wait_all_prefill_offloads()
print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)
logger.debug(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks")
# Sample from last logits
# For chunked prefill, ParallelLMHead automatically selects last position's logits
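To make the chunk bookkeeping in this prefill path concrete, a small worked sketch (the modulo mapping of chunks onto ring slots is an assumption; the diff only shows write_slot being consumed):

```python
# Ring-buffer prefill bookkeeping, illustrated with made-up sizes.
total_tokens = 10_000
tokens_per_chunk = 1024          # one CPU block per chunk (block_size = 1024)
num_ring_slots = 4               # assumed ring size

num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk   # -> 10
for chunk_idx in range(num_chunks):
    chunk_start = chunk_idx * tokens_per_chunk
    chunk_end = min(chunk_start + tokens_per_chunk, total_tokens)
    write_slot = chunk_idx % num_ring_slots   # assumed slot rotation
    block_idx = chunk_idx                     # CPU block index for this chunk
```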
@@ -570,14 +576,15 @@ class ModelRunner:
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with ring buffer (CPU is primary storage).
Run decode with cross-layer pipeline (CPU is primary storage).
All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
Other slots (slots[1:]) are used to load previous KV chunks via pipeline.
New token's KV is written to decode_slot then offloaded to CPU only when block is full.
Optimized with cross-layer pipeline: Layer N's data is loaded while
Layer N-1 computes, achieving transfer/compute overlap.
Key: decode_slot is dedicated to writing new KV, never used for loading.
Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens.
Optimization: Cross-layer pipeline reduces effective latency by overlapping
H2D transfers with attention computation across layers.
"""
assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
seq = seqs[0]
@@ -598,6 +605,12 @@ class ModelRunner:
# Get decode start position for accumulated token tracking
decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
# Get prefilled CPU blocks for pipeline initialization
cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
# Start cross-layer pipeline (preloads Layer 0's data)
offload_engine.start_decode_pipeline(cpu_block_table)
# Set up context for chunked decode
set_context(
is_prefill=False,
@@ -614,6 +627,9 @@ class ModelRunner:
logits = self.run_model(input_ids, positions, is_prefill=False)
reset_context()
# End cross-layer pipeline
offload_engine.end_decode_pipeline()
# Only offload when block is full (pos_in_block == block_size - 1)
# This avoids unnecessary offloading on every decode step
if pos_in_block == self.block_size - 1:

View File

@@ -35,7 +35,29 @@ class Scheduler:
if Observer.ttft_start == 0:
Observer.ttft_start = perf_counter_ns()
seq = self.waiting[0]
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.kvcache_manager.can_allocate(seq):
# Check if sequence is too large
if not self.running and num_seqs == 0:
# First sequence, give clear error if it can't be scheduled
if len(seq) > self.max_num_batched_tokens:
raise RuntimeError(
f"Sequence too long: {len(seq)} tokens exceeds "
f"max_num_batched_tokens={self.max_num_batched_tokens}. "
f"Increase max_num_batched_tokens (set equal to max_model_len for long sequences)."
)
if not self.kvcache_manager.can_allocate(seq):
blocks_needed = seq.num_blocks
blocks_available = self.kvcache_manager.num_free_blocks
raise RuntimeError(
f"Cannot allocate KV cache for sequence: "
f"need {blocks_needed} blocks ({len(seq)} tokens), "
f"but only {blocks_available} blocks available. "
f"Increase max_model_len to allocate more blocks."
)
if num_batched_tokens + len(seq) > self.max_num_batched_tokens:
break
if not self.kvcache_manager.can_allocate(seq):
break
num_seqs += 1
self.kvcache_manager.allocate(seq)
@@ -60,7 +82,7 @@ class Scheduler:
num_seqs += 1
self.kvcache_manager.may_append(seq)
scheduled_seqs.append(seq)
assert scheduled_seqs
assert scheduled_seqs, "No sequences scheduled - this should not happen"
self.running.extendleft(reversed(scheduled_seqs))
return scheduled_seqs, False
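The block accounting behind the new error messages, as a sketch (ceil division over the block size; the numbers are illustrative):

```python
# How many KV-cache blocks a prompt needs at the block size used here.
block_size = 1024
prompt_tokens = 130_000
blocks_needed = (prompt_tokens + block_size - 1) // block_size   # -> 127

# The scheduler now raises if blocks_needed exceeds kvcache_manager.num_free_blocks,
# or if prompt_tokens exceeds max_num_batched_tokens for the first waiting sequence.
```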

View File

@@ -12,7 +12,7 @@ class SequenceStatus(Enum):
class Sequence:
block_size = 4096
block_size = 1024
counter = count()
def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
@@ -34,6 +34,14 @@ class Sequence:
def __getitem__(self, key):
return self.token_ids[key]
def __repr__(self):
ids = self.token_ids
if len(ids) > 20:
ids_str = "[" + ", ".join(map(str, ids[:10])) + ", ..., " + ", ".join(map(str, ids[-5:])) + "]"
else:
ids_str = str(ids)
return f"Seq(id={self.seq_id}, status={self.status.name}, tokens={self.num_tokens}, ids={ids_str})"
@property
def is_finished(self):
return self.status == SequenceStatus.FINISHED
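For illustration, the truncated repr produced by the new __repr__ for a long prompt (the WAITING status for a fresh sequence is an assumption):

```python
seq = Sequence(list(range(25)))
print(seq)
# e.g. Seq(id=0, status=WAITING, tokens=25,
#          ids=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 20, 21, 22, 23, 24])
```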

View File

@@ -56,14 +56,26 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
# Need CPU offload: use hybrid manager
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.policies import get_policy
from nanovllm.kvcache.sparse import create_sparse_policy
from nanovllm.config import SparsePolicyType
policy = get_policy(getattr(config, 'offload_policy', 'lru'))
eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru'))
# Create sparse policy from config enum
# Quest is decode-only: prefill returns all blocks (query=None), decode does Top-K
sparse_policy_type = getattr(config, 'sparse_policy', SparsePolicyType.FULL)
sparse_policy = create_sparse_policy(
sparse_policy_type,
topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
)
return HybridKVCacheManager(
num_gpu_slots=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=config.kvcache_block_size,
policy=policy,
policy=eviction_policy,
sparse_policy=sparse_policy,
)
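A configuration sketch for the new factory path (config is an existing Config instance; the field names match the getattr calls above, and the values are illustrative):

```python
from nanovllm.config import SparsePolicyType

config.sparse_policy = SparsePolicyType.QUEST   # decode-only Top-K block selection
config.sparse_topk_blocks = 8                   # keep the 8 highest-scoring blocks
config.sparse_threshold_blocks = 4              # below this many blocks, load all
config.offload_policy = "lru"                   # eviction policy for bookkeeping
manager = create_kvcache_manager(config)        # HybridKVCacheManager with sparse_policy set
```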

View File

@@ -90,6 +90,7 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: int,
block_size: int,
policy: Optional[EvictionPolicy] = None,
sparse_policy: "SparsePolicy" = None,
):
"""
Initialize hybrid manager with CPU-primary ring buffer design.
@@ -102,6 +103,7 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: Number of CPU pool blocks (primary storage)
block_size: Tokens per block
policy: Eviction policy (default: LRU, used for prefix cache management)
sparse_policy: Sparse attention policy (Quest for decode-only sparse)
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
@@ -113,6 +115,9 @@ class HybridKVCacheManager(KVCacheManager):
# Eviction policy
self.policy = policy or LRUPolicy()
# Sparse attention policy (set at construction time, immutable)
self.sparse_policy = sparse_policy
# Logical blocks (what sequences reference) - one per CPU block
self.logical_blocks: List[LogicalBlock] = [
LogicalBlock(i) for i in range(self.total_blocks)
@@ -128,6 +133,9 @@ class HybridKVCacheManager(KVCacheManager):
self.cpu_block_to_logical: Dict[int, int] = {} # cpu_block -> logical_id
# Prefix cache (uses logical block IDs)
# NOTE: Currently WRITE-ONLY in offload mode - hashes are stored but never
#> used for cache hit detection. This is intentional: offload mode always
#> allocates new blocks and doesn't reuse existing ones.
self.hash_to_logical_id: Dict[int, int] = {}
# Step counter for policy
@@ -146,8 +154,9 @@ class HybridKVCacheManager(KVCacheManager):
# Key: sequence id, Value: starting position where decode began in current block
self._decode_start_pos: Dict[int, int] = {}
# Sparse attention policy (optional)
self.sparse_policy: Optional["SparsePolicy"] = None
# Track original prefill length (for correct last_block_valid_tokens calculation)
# Key: sequence id, Value: number of tokens from prefill (before decode started)
self._prefill_len: Dict[int, int] = {}
@property
def block_size(self) -> int:
@@ -173,6 +182,7 @@ class HybridKVCacheManager(KVCacheManager):
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
sparse_policy=self.sparse_policy,
)
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -180,24 +190,6 @@ class HybridKVCacheManager(KVCacheManager):
assert self.offload_engine is not None
return self.offload_engine.get_layer_cache(layer_id)
def set_sparse_policy(self, policy: "SparsePolicy") -> None:
"""
Set sparse attention policy for block selection.
The sparse policy determines which KV blocks to load from CPU
for each query chunk during chunked attention computation.
Args:
policy: SparsePolicy instance (e.g., VerticalSlashPolicy, QuestPolicy)
Example:
from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig
policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=2))
manager.set_sparse_policy(policy)
"""
self.sparse_policy = policy
logger.info(f"Sparse attention policy set: {policy}")
def can_allocate(self, seq: Sequence) -> bool:
"""Check if we can allocate blocks for a new sequence."""
return len(self.free_logical_ids) >= seq.num_blocks
@@ -254,14 +246,10 @@ class HybridKVCacheManager(KVCacheManager):
pos_in_block = seq_len % self._block_size
if pos_in_block == 1:
# Need new block
assert last_block.hash != -1
# Need new block (previous block is full)
logical_id = self.free_logical_ids.popleft()
block = self.logical_blocks[logical_id]
block.ref_count = 1
block.hash = -1
block.token_ids = []
# Allocate new block to CPU (ring buffer mode)
if not self.free_cpu_blocks:
@@ -275,17 +263,13 @@ class HybridKVCacheManager(KVCacheManager):
block_table.append(logical_id)
elif pos_in_block == 0:
# Block is full, update hash for prefix cache
assert last_block.hash == -1
token_ids = seq.block(seq.num_blocks - 1)
prefix_hash = (
self.logical_blocks[block_table[-2]].hash
if len(block_table) > 1 else -1
)
h = self.compute_hash(token_ids, prefix_hash)
last_block.hash = h
last_block.token_ids = token_ids.copy()
self.hash_to_logical_id[h] = last_logical_id
# Block is full
# NOTE: Prefix cache disabled in offload mode
# If enabled, would compute hash and update:
# h = self.compute_hash(seq.block(seq.num_blocks - 1), prefix_hash)
# last_block.hash = h
# self.hash_to_logical_id[h] = last_logical_id
pass
def prepare_for_attention(
self,
@@ -365,8 +349,6 @@ class HybridKVCacheManager(KVCacheManager):
"""
assert not seq.block_table, "Sequence already has blocks"
h = -1 # Running hash for prefix cache
for i in range(seq.num_blocks):
# Allocate CPU block
if not self.free_cpu_blocks:
@@ -377,19 +359,10 @@ class HybridKVCacheManager(KVCacheManager):
cpu_block_id = self.free_cpu_blocks.popleft()
# Get token IDs for this block and compute hash
token_ids = seq.block(i)
if len(token_ids) == self._block_size:
h = self.compute_hash(token_ids, h)
else:
h = -1 # Incomplete block
# Allocate logical block
logical_id = self.free_logical_ids.popleft()
block = self.logical_blocks[logical_id]
block.ref_count = 1
block.hash = h
block.token_ids = token_ids.copy() if len(token_ids) == self._block_size else []
block.location = BlockLocation.CPU
block.cpu_block_id = cpu_block_id
block.gpu_slot = -1
@@ -397,9 +370,11 @@ class HybridKVCacheManager(KVCacheManager):
self.cpu_block_to_logical[cpu_block_id] = logical_id
seq.block_table.append(logical_id)
# Update prefix cache
if h != -1:
self.hash_to_logical_id[h] = logical_id
# NOTE: Prefix cache disabled in offload mode
# If enabled, would compute hash and update:
# h = self.compute_hash(seq.block(i), prefix_hash)
# block.hash = h
# self.hash_to_logical_id[h] = logical_id
def get_cpu_block_table(self, seq: Sequence) -> List[int]:
"""
@@ -542,6 +517,26 @@ class HybridKVCacheManager(KVCacheManager):
seq_id = id(seq)
self._decode_start_pos[seq_id] = 0
def get_prefill_len(self, seq: Sequence) -> int:
"""
Get the original prefill length for a sequence.
This is cached on first call to ensure correct last_block_valid_tokens
calculation during decode (the CPU blocks don't change after prefill).
Args:
seq: Sequence
Returns:
Number of tokens from prefill (before decode started)
"""
seq_id = id(seq)
if seq_id not in self._prefill_len:
# First decode step - store the prefill length
# len(seq) - 1 because current len includes the first decode token
self._prefill_len[seq_id] = len(seq) - 1
return self._prefill_len[seq_id]
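How the cached prefill length feeds the last-block token count during decode, as a sketch (the ceil/mod arithmetic here is the obvious derivation, not copied from the source):

```python
block_size = 1024
prefill_len = manager.get_prefill_len(seq)                          # e.g. 2600 prompt tokens
num_prefill_blocks = (prefill_len + block_size - 1) // block_size   # -> 3
last_block_valid_tokens = prefill_len - (num_prefill_blocks - 1) * block_size   # -> 552
```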
def clear_decode_tracking(self, seq: Sequence) -> None:
"""
Clear decode position tracking for sequence.
@@ -553,6 +548,7 @@ class HybridKVCacheManager(KVCacheManager):
"""
seq_id = id(seq)
self._decode_start_pos.pop(seq_id, None)
self._prefill_len.pop(seq_id, None)
def __repr__(self) -> str:
return (

View File

@@ -17,6 +17,11 @@ from nanovllm.kvcache.kernels import gathered_copy_kv
from nanovllm.comm import memcpy_2d_async
from nanovllm.utils.logger import get_logger
# Import for type hints only (avoid circular import)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from nanovllm.kvcache.sparse import SparsePolicy
logger = get_logger("offload_engine")
@@ -35,14 +40,13 @@ class OffloadEngine:
High-performance CPU-GPU async transfer engine for KV cache offloading.
Memory layout:
- GPU cache: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
- GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim] (no layer dimension)
- CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] (pinned)
- Gather indices: [num_layers, num_gpu_blocks] (fixed address, variable content)
CUDA Graph compatibility:
- gathered_h2d_layer() can be captured into CUDA graphs
- update_gather_indices() is called outside graphs to prepare indices
- All tensor addresses remain fixed across graph replays
Features:
- Unified ring buffer for chunked prefill/decode
- Per-layer prefill buffer for async offload
- Cross-layer pipeline for decode with double-buffering
"""
def __init__(
@@ -55,6 +59,7 @@ class OffloadEngine:
head_dim: int,
dtype: torch.dtype = torch.float16,
num_streams: int = 4,
sparse_policy: "SparsePolicy" = None,
):
self.num_layers = num_layers
self.num_gpu_blocks = num_gpu_blocks
@@ -136,6 +141,64 @@ class OffloadEngine:
decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")
# ========== Cross-layer pipeline buffers for decode ==========
# Double-buffered layer cache for pipelined decode:
# - Buffer A: Current layer's prefilled KV being computed
# - Buffer B: Next layer's prefilled KV being loaded
# Shape: [max_prefill_blocks, block_size, kv_heads, head_dim]
# Memory: 2 * max_prefill_blocks * block_size * kv_heads * head_dim * dtype_size
max_prefill_blocks = num_cpu_blocks # Can hold all prefill blocks
self.layer_k_buffer_a = torch.zeros(
max_prefill_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.layer_v_buffer_a = torch.zeros(
max_prefill_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.layer_k_buffer_b = torch.zeros(
max_prefill_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.layer_v_buffer_b = torch.zeros(
max_prefill_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
layer_buf_mb = 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f" Cross-layer pipeline buffers: {layer_buf_mb:.1f} MB ({max_prefill_blocks} blocks × 2)")
# Pipeline state tracking
self._pipeline_active = False
self._pipeline_current_buffer = 0 # 0 = buffer A, 1 = buffer B
self._pipeline_next_layer_event = torch.cuda.Event()
self._pipeline_cpu_blocks: list = [] # CPU block IDs to load
self._pipeline_num_blocks = 0
self._pipeline_layer_stream = torch.cuda.Stream() # Dedicated stream for layer loading
# ========== Per-layer prefill buffer for async offload ==========
# During chunked prefill, all layers share the same GPU slot. This means
# each layer must wait for offload to complete before the next layer can
# write to the same slot. This serializes offloads and hurts performance.
# Solution: Maintain separate per-layer buffers for prefill.
# Each layer writes to its own buffer, enabling fully async offloads.
# Shape: [num_layers, block_size, kv_heads, head_dim]
self.prefill_k_buffer = torch.zeros(
num_layers, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.prefill_v_buffer = torch.zeros(
num_layers, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
prefill_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f" Per-layer prefill buffer: {prefill_buf_mb:.1f} MB")
# Per-layer offload events for async prefill offload
# Each layer has its own event to track offload completion
self.prefill_offload_events = [torch.cuda.Event() for _ in range(num_layers)]
# Per-layer transfer streams for parallel offloads
self.prefill_offload_streams = [torch.cuda.Stream() for _ in range(num_layers)]
# ========== Fixed-address CPU KV cache (pinned memory) ==========
self.k_cache_cpu = torch.zeros(
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
@@ -146,19 +209,6 @@ class OffloadEngine:
dtype=dtype, device="cpu", pin_memory=True
)
# ========== Fixed-address gather indices (content is variable) ==========
# gather_indices[layer][i] = CPU block id to copy to GPU slot i
# -1 means no-op (skip this slot)
self.gather_indices_cpu = torch.empty(
num_layers, num_gpu_blocks,
dtype=torch.int64, device="cpu", pin_memory=True
)
self.gather_indices_cpu.fill_(-1)
self.gather_indices_gpu = torch.full(
(num_layers, num_gpu_blocks), -1,
dtype=torch.int64, device="cuda"
)
# Log memory allocation
gpu_mem_mb = self.gpu_memory_bytes() / (1024 * 1024)
cpu_mem_mb = self.cpu_memory_bytes() / (1024 * 1024)
@@ -201,7 +251,7 @@ class OffloadEngine:
# This prevents undefined behavior on first load_to_slot_layer call
for slot_idx in range(self.num_ring_slots):
self.ring_slot_compute_done[slot_idx].record()
torch.cuda.synchronize() # Ensure all events are recorded
# torch.cuda.synchronize() # Ensure all events are recorded
# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
@@ -210,320 +260,8 @@ class OffloadEngine:
self._debug_mode = False
self._debug_hooks: List = [] # External hooks for debug events
def _get_next_stream(self) -> torch.cuda.Stream:
"""Round-robin stream selection for parallel transfers."""
stream = self.transfer_streams[self._stream_idx]
self._stream_idx = (self._stream_idx + 1) % len(self.transfer_streams)
return stream
# ========== CUDA Graph compatible methods ==========
# NOTE: These methods need to be updated for the new GPU cache architecture.
# GPU cache no longer has layer dimension, so gathered copy semantics change.
# For now, these are kept for reference but should not be used without updating.
def gathered_h2d_layer(self, layer_id: int) -> None:
"""
Execute gathered H2D copy for a single layer.
WARNING: This method needs updating for new GPU cache architecture.
GPU cache no longer has layer dimension.
"""
# GPU cache has no layer dimension - use flat indexing
# Source is CPU[layer_id], dest is GPU (shared across layers)
gathered_copy_kv(
k_src=self.k_cache_cpu[layer_id],
v_src=self.v_cache_cpu[layer_id],
k_dst=self.k_cache_gpu, # No layer indexing
v_dst=self.v_cache_gpu, # No layer indexing
indices=self.gather_indices_gpu[layer_id],
)
def gathered_h2d_all_layers(self) -> None:
"""
Execute gathered H2D copy for all layers.
WARNING: In new architecture, GPU slots are shared across layers.
This method would overwrite slots multiple times. Not recommended.
"""
for layer_id in range(self.num_layers):
self.gathered_h2d_layer(layer_id)
def update_gather_indices(
self,
layer_id: int,
mappings: List[Tuple[int, int]],
) -> None:
"""
Update gather indices for a layer (call OUTSIDE CUDA graph).
Args:
layer_id: Layer index
mappings: List of (cpu_block_id, gpu_slot) tuples
Only these slots will be updated; others keep their values
"""
for cpu_block_id, gpu_slot in mappings:
self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
# Async copy to GPU
self.gather_indices_gpu[layer_id].copy_(
self.gather_indices_cpu[layer_id],
non_blocking=True
)
def update_gather_indices_all_layers(
self,
mappings_per_layer: List[List[Tuple[int, int]]],
) -> None:
"""
Update gather indices for all layers.
Args:
mappings_per_layer: mappings_per_layer[layer_id] = [(cpu_block_id, gpu_slot), ...]
"""
for layer_id, mappings in enumerate(mappings_per_layer):
for cpu_block_id, gpu_slot in mappings:
self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
# Batch copy all layers
self.gather_indices_gpu.copy_(self.gather_indices_cpu, non_blocking=True)
def clear_gather_indices(self, layer_id: Optional[int] = None) -> None:
"""
Clear gather indices (set all to -1, meaning no-op).
Args:
layer_id: If provided, clear only this layer; otherwise clear all
"""
if layer_id is not None:
self.gather_indices_cpu[layer_id].fill_(-1)
self.gather_indices_gpu[layer_id].fill_(-1)
else:
self.gather_indices_cpu.fill_(-1)
self.gather_indices_gpu.fill_(-1)
# ========== Async transfer methods (for prefill, outside CUDA graph) ==========
def prefetch_block_async(
self,
layer_id: int,
cpu_block_id: int,
gpu_block_id: int,
) -> torch.cuda.Event:
"""
Async prefetch a single block from CPU to GPU.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args:
layer_id: Layer index (for CPU cache)
cpu_block_id: Source block in CPU cache
gpu_block_id: Destination slot in GPU cache
Returns:
CUDA event that signals completion
"""
stream = self._get_next_stream()
event = torch.cuda.Event()
logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")
with torch.cuda.stream(stream):
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_block_id].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
self.v_cache_gpu[gpu_block_id].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
event.record()
self.pending_events[(layer_id, gpu_block_id)] = event
return event
def prefetch_blocks_batch_async(
self,
transfers: List[Tuple[int, int, int]], # [(layer_id, cpu_block_id, gpu_block_id), ...]
) -> List[torch.cuda.Event]:
"""
Batch async prefetch multiple blocks.
Args:
transfers: List of (layer_id, cpu_block_id, gpu_block_id) tuples
Returns:
List of CUDA events for each transfer
"""
events = []
for layer_id, cpu_block_id, gpu_block_id in transfers:
event = self.prefetch_block_async(layer_id, cpu_block_id, gpu_block_id)
events.append(event)
return events
def offload_block_async(
self,
layer_id: int,
gpu_block_id: int,
cpu_block_id: int,
) -> torch.cuda.Event:
"""
Async offload a block from GPU to CPU.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args:
layer_id: Layer index (for CPU cache)
gpu_block_id: Source slot in GPU cache
cpu_block_id: Destination block in CPU cache
Returns:
CUDA event that signals completion
"""
stream = self._get_next_stream()
event = torch.cuda.Event()
logger.debug(f"D2H offload: layer={layer_id}, GPU[{gpu_block_id}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(stream):
# Wait for any compute using this block
stream.wait_stream(self.compute_stream)
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.k_cache_gpu[gpu_block_id],
non_blocking=True
)
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
self.v_cache_gpu[gpu_block_id],
non_blocking=True
)
event.record()
return event
def offload_blocks_batch_async(
self,
transfers: List[Tuple[int, int, int]], # [(layer_id, gpu_block_id, cpu_block_id), ...]
) -> List[torch.cuda.Event]:
"""
Batch async offload multiple blocks.
Args:
transfers: List of (layer_id, gpu_block_id, cpu_block_id) tuples
Returns:
List of CUDA events
"""
events = []
for layer_id, gpu_block_id, cpu_block_id in transfers:
event = self.offload_block_async(layer_id, gpu_block_id, cpu_block_id)
events.append(event)
return events
# ========== Chunked Decode: Load CPU blocks to GPU slots ==========
def load_cpu_blocks_to_gpu_slots(
self,
layer_id: int,
cpu_block_ids: List[int],
gpu_slot_ids: List[int],
) -> None:
"""
Load CPU blocks to specific GPU slots for chunked decode.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args:
layer_id: Layer index (for CPU cache)
cpu_block_ids: List of CPU block IDs to load
gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
"""
assert len(cpu_block_ids) == len(gpu_slot_ids)
if cpu_block_ids:
logger.debug(f"H2D chunked load: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
stream = self._get_next_stream()
with torch.cuda.stream(stream):
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
# Wait for transfer to complete
stream.synchronize()
def load_cpu_blocks_to_gpu_slots_async(
self,
layer_id: int,
cpu_block_ids: List[int],
gpu_slot_ids: List[int],
) -> torch.cuda.Event:
"""
Async version: Load CPU blocks to GPU slots.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args:
layer_id: Layer index (for CPU cache)
cpu_block_ids: List of CPU block IDs to load
gpu_slot_ids: List of GPU slot IDs to load into
Returns:
CUDA event to wait on
"""
assert len(cpu_block_ids) == len(gpu_slot_ids)
if cpu_block_ids:
logger.debug(f"H2D chunked load async: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
stream = self._get_next_stream()
event = torch.cuda.Event()
with torch.cuda.stream(stream):
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
event.record()
return event
# NOTE: load_cpu_blocks_to_gpu_slots_all_layers removed - GPU cache no longer has
# layer dimension. Each GPU slot holds data for ONE layer at a time.
# ========== Synchronization methods ==========
def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None:
"""Wait for a specific block's transfer to complete."""
key = (layer_id, gpu_block_id)
if key in self.pending_events:
self.pending_events[key].synchronize()
del self.pending_events[key]
def wait_all_transfers(self) -> None:
"""Wait for all pending transfers to complete."""
for stream in self.transfer_streams:
stream.synchronize()
self.pending_events.clear()
def sync_indices(self) -> None:
"""Synchronize to ensure all index updates are complete."""
torch.cuda.default_stream().synchronize()
# ========== Sparse attention policy (set at construction time) ==========
self.sparse_policy = sparse_policy
# ========== Cache access methods ==========
@@ -538,54 +276,22 @@ class OffloadEngine:
(k_cache, v_cache) tensors
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
"""
# GPU cache is shared across all layers (no layer dimension)
return self.k_cache_gpu, self.v_cache_gpu
def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]:
"""
Get full GPU K/V cache tensors.
NOTE: GPU cache has no layer dimension in the new architecture.
Returns:
(k_cache, v_cache) tensors
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
"""
return self.k_cache_gpu, self.v_cache_gpu
def get_cpu_block(
self,
layer_id: int,
cpu_block_id: int,
) -> Tuple[Tensor, Tensor]:
"""
Get a specific CPU block's K/V cache.
Returns:
(k_cache, v_cache) for the block
Shape: [block_size, kv_heads, head_dim]
"""
return (
self.k_cache_cpu[layer_id, cpu_block_id],
self.v_cache_cpu[layer_id, cpu_block_id],
)
# ========== Memory info ==========
def gpu_memory_bytes(self) -> int:
"""Total GPU memory used by KV caches."""
return (
self.k_cache_gpu.numel() * self.k_cache_gpu.element_size() +
self.v_cache_gpu.numel() * self.v_cache_gpu.element_size() +
self.gather_indices_gpu.numel() * self.gather_indices_gpu.element_size()
self.v_cache_gpu.numel() * self.v_cache_gpu.element_size()
)
def cpu_memory_bytes(self) -> int:
"""Total CPU memory used by KV caches."""
return (
self.k_cache_cpu.numel() * self.k_cache_cpu.element_size() +
self.v_cache_cpu.numel() * self.v_cache_cpu.element_size() +
self.gather_indices_cpu.numel() * self.gather_indices_cpu.element_size()
self.v_cache_cpu.numel() * self.v_cache_cpu.element_size()
)
def __repr__(self) -> str:
@@ -730,7 +436,14 @@ class OffloadEngine:
"""Wait for slot offload to complete."""
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
def offload_slot_layer_to_cpu(
self,
slot_idx: int,
layer_id: int,
cpu_block_id: int,
num_valid_tokens: int = -1,
is_prefill: bool = True,
) -> None:
"""
Async offload a ring buffer slot to CPU for one layer.
@@ -741,9 +454,21 @@ class OffloadEngine:
slot_idx: Source GPU slot index
layer_id: Target layer in CPU cache
cpu_block_id: Target CPU block ID
num_valid_tokens: Number of valid tokens in this block (-1 = use block_size)
is_prefill: True if in prefill phase, False if in decode phase
"""
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
# Collect metadata BEFORE offload (while k_cache is still on GPU)
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
k_cache = self.k_cache_gpu[slot_idx]
if self.sparse_policy is not None:
if is_prefill:
self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
else:
self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for both compute_stream and default stream
@@ -869,102 +594,6 @@ class OffloadEngine:
v = v.unsqueeze(0)
return k, v
# ----- Legacy compatibility methods (for decode double-buffering) -----
# NOTE: GPU cache has no layer dimension. Layer ID is used for CPU cache indexing only.
def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
Uses first half of decode_load_slots as 'compute' region.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
"""
if not cpu_block_ids:
return
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half]
num_to_load = min(len(cpu_block_ids), len(slots))
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
def wait_compute_layer(self) -> None:
"""Legacy: Wait for 'compute' region loading."""
if self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0])
def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
Uses second half of decode_load_slots as 'prefetch' region.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
"""
if not cpu_block_ids:
return
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if not slots:
slots = self.decode_load_slots # Fallback if only 1-2 slots
num_to_load = min(len(cpu_block_ids), len(slots))
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
def wait_prefetch_layer(self) -> None:
"""Legacy: Wait for 'prefetch' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if slots:
self.wait_slot_layer(slots[0])
elif self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0])
def get_kv_for_compute(
self,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half][:num_blocks]
return self.get_kv_for_slots(slots)
def get_kv_for_prefetch(
self,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if not slots:
slots = self.decode_load_slots
slots = slots[:num_blocks]
return self.get_kv_for_slots(slots)
# ========== Debug Hook Interface ==========
#
# Minimal generic hook system for debugging.
@@ -1036,3 +665,207 @@ class OffloadEngine:
if e.__class__.__name__ == 'BdbQuit':
raise
logger.warning(f"Debug hook error: {e}")
# ========== Cross-layer Pipeline Methods for Decode ==========
def start_decode_pipeline(self, cpu_block_ids: List[int]) -> None:
"""
Start cross-layer pipeline for decode.
Called at the beginning of a decode step to initialize the pipeline.
Preloads Layer 0's data into buffer A.
Args:
cpu_block_ids: List of CPU block IDs for prefilled blocks
"""
if not cpu_block_ids:
self._pipeline_active = False
return
self._pipeline_active = True
self._pipeline_cpu_blocks = cpu_block_ids
self._pipeline_num_blocks = len(cpu_block_ids)
self._pipeline_current_buffer = 0
# Preload Layer 0 into buffer A
self._load_layer_to_buffer(0, 0) # layer_id=0, buffer_idx=0 (A)
def get_decode_layer_kv(self, layer_id: int, num_blocks: int) -> Tuple[Tensor, Tensor]:
"""
Get KV cache for a layer during decode.
If pipeline is active, returns data from the current buffer.
Also triggers preloading of the next layer (if not last layer).
Args:
layer_id: Current layer ID
num_blocks: Number of blocks to return
Returns:
(k_cache, v_cache) tensors, shape: [num_blocks, block_size, kv_heads, head_dim]
"""
if not self._pipeline_active:
raise RuntimeError("Decode pipeline not active. Call start_decode_pipeline first.")
# Wait for current layer's data to be ready
self.compute_stream.wait_event(self._pipeline_next_layer_event)
# Get current buffer
if self._pipeline_current_buffer == 0:
k = self.layer_k_buffer_a[:num_blocks]
v = self.layer_v_buffer_a[:num_blocks]
else:
k = self.layer_k_buffer_b[:num_blocks]
v = self.layer_v_buffer_b[:num_blocks]
# Trigger preloading of next layer (if not last layer)
next_layer_id = layer_id + 1
if next_layer_id < self.num_layers:
# Use the other buffer for next layer
next_buffer_idx = 1 - self._pipeline_current_buffer
self._load_layer_to_buffer(next_layer_id, next_buffer_idx)
# Switch to next buffer for next layer
self._pipeline_current_buffer = next_buffer_idx
return k, v
def _load_layer_to_buffer(self, layer_id: int, buffer_idx: int) -> None:
"""
Async load a layer's prefilled blocks to the specified buffer.
Uses sgDMA for efficient strided transfer from CPU cache.
Args:
layer_id: Layer index to load
buffer_idx: 0 for buffer A, 1 for buffer B
"""
num_blocks = self._pipeline_num_blocks
cpu_block_ids = self._pipeline_cpu_blocks
# Select target buffer
if buffer_idx == 0:
k_buffer = self.layer_k_buffer_a
v_buffer = self.layer_v_buffer_a
else:
k_buffer = self.layer_k_buffer_b
v_buffer = self.layer_v_buffer_b
# Load all blocks for this layer using dedicated stream
with torch.cuda.stream(self._pipeline_layer_stream):
for i, cpu_block_id in enumerate(cpu_block_ids):
# Copy from CPU cache (has layer dimension) to GPU buffer
k_buffer[i].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
v_buffer[i].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
# Record event when all transfers complete
self._pipeline_next_layer_event.record(self._pipeline_layer_stream)
def end_decode_pipeline(self) -> None:
"""
End the cross-layer pipeline.
Called at the end of a decode step to clean up pipeline state.
"""
if self._pipeline_active:
# Ensure all transfers complete before ending
self._pipeline_layer_stream.synchronize()
self._pipeline_active = False
self._pipeline_cpu_blocks = []
self._pipeline_num_blocks = 0
def is_pipeline_active(self) -> bool:
"""Check if decode pipeline is currently active."""
return self._pipeline_active
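Taken together, the intended call pattern for one decode step looks roughly like this (engine is the OffloadEngine instance and the attention call is a placeholder; only the three pipeline methods come from this class):

```python
engine.start_decode_pipeline(cpu_block_table)   # preload layer 0 into buffer A
for layer_id in range(engine.num_layers):
    k, v = engine.get_decode_layer_kv(layer_id, num_blocks=len(cpu_block_table))
    # ... attention for this layer runs against (k, v) while the next layer's
    # blocks are already streaming into the other buffer ...
engine.end_decode_pipeline()                    # sync the loader stream, clear state
```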
# ========== Per-layer Prefill Buffer Methods ==========
# These methods enable async offload during chunked prefill by using
# per-layer buffers instead of shared GPU slots.
def get_prefill_buffer(self, layer_id: int) -> Tuple[Tensor, Tensor]:
"""
Get prefill buffer for a layer.
Args:
layer_id: Layer index
Returns:
(k_buffer, v_buffer), shape: [block_size, kv_heads, head_dim]
"""
return self.prefill_k_buffer[layer_id], self.prefill_v_buffer[layer_id]
def get_prefill_buffer_slice(
self,
layer_id: int,
num_tokens: int,
) -> Tuple[Tensor, Tensor]:
"""
Get a slice of prefill buffer for attention computation.
Args:
layer_id: Layer index
num_tokens: Number of valid tokens in current chunk
Returns:
(k, v) with shape [1, num_tokens, kv_heads, head_dim]
"""
k = self.prefill_k_buffer[layer_id, :num_tokens].unsqueeze(0)
v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
return k, v
def offload_prefill_buffer_async(
self,
layer_id: int,
cpu_block_id: int,
num_valid_tokens: int = -1,
) -> None:
"""
Async offload prefill buffer to CPU (no waiting required).
This uses per-layer streams and events to enable fully async offloads.
Each layer can offload independently without blocking other layers.
Args:
layer_id: Layer index
cpu_block_id: Target CPU block ID
num_valid_tokens: Number of valid tokens (-1 = use block_size)
"""
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
# Collect sparse policy metadata before offload
if self.sparse_policy is not None:
k_cache = self.prefill_k_buffer[layer_id]
self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
# Use per-layer stream for parallel offloads
stream = self.prefill_offload_streams[layer_id]
torch.cuda.nvtx.range_push(f"AsyncPrefillOffload: L{layer_id}->CPU[{cpu_block_id}]")
with torch.cuda.stream(stream):
# Wait for compute to finish writing to prefill buffer
stream.wait_stream(self.compute_stream)
# Copy from prefill buffer to CPU
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.prefill_k_buffer[layer_id], non_blocking=True
)
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
self.prefill_v_buffer[layer_id], non_blocking=True
)
# Record completion event
self.prefill_offload_events[layer_id].record(stream)
torch.cuda.nvtx.range_pop()
def wait_all_prefill_offloads(self) -> None:
"""Wait for all prefill buffer offloads to complete."""
for stream in self.prefill_offload_streams:
stream.synchronize()
def wait_prefill_offload(self, layer_id: int) -> None:
"""Wait for a specific layer's prefill offload to complete."""
self.prefill_offload_events[layer_id].synchronize()
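And the per-layer prefill path, roughly as the comments above describe it (k_chunk, v_chunk and the attention step are placeholders; the buffer and offload calls are the methods defined here):

```python
# One chunk of chunked prefill, for a single layer.
k_buf, v_buf = engine.get_prefill_buffer(layer_id)        # [block_size, kv_heads, head_dim]
k_buf[:num_tokens].copy_(k_chunk)                         # write this chunk's keys
v_buf[:num_tokens].copy_(v_chunk)                         # ... and values
k, v = engine.get_prefill_buffer_slice(layer_id, num_tokens)   # [1, num_tokens, H, D]
# ... attention over (k, v) plus previously offloaded CPU blocks ...
engine.offload_prefill_buffer_async(layer_id, cpu_block_id, num_valid_tokens=num_tokens)
# Before sampling (or reusing CPU blocks), drain the per-layer streams:
engine.wait_all_prefill_offloads()
```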

View File

@@ -5,86 +5,67 @@ Provides pluggable policies for selecting which KV blocks to load
during chunked attention with CPU offload.
Usage:
from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse import VerticalSlashPolicy, QuestPolicy
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
# Use built-in policy
policy = VerticalSlashPolicy(VerticalSlashConfig())
# Create policy using factory function
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
# Or create custom policy
class MyPolicy(SparsePolicy):
supports_prefill = True
supports_decode = True
def select_blocks(self, available_blocks, ctx):
return available_blocks[:5] # Just first 5 blocks
"""
from nanovllm.config import SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
# Built-in policy registry
BUILTIN_SPARSE_POLICIES = {
"full": FullAttentionPolicy,
"vertical_slash": VerticalSlashPolicy,
"streaming_llm": StreamingLLMPolicy,
}
def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy:
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
"""
Get a sparse attention policy instance by name.
Create a sparse policy instance from an enum type.
The returned policy is not yet initialized. Call policy.initialize()
or let the framework call it during KV cache allocation.
Args:
policy_name: Policy name ("full", "vertical_slash", "streaming_llm", "quest")
**kwargs: Policy-specific configuration
policy_type: SparsePolicyType enum value
**kwargs: Policy-specific configuration options
Returns:
SparsePolicy instance
"""
policy_name = policy_name.lower()
SparsePolicy instance (not initialized)
if policy_name == "full":
Example:
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
policy.initialize(num_layers=28, num_kv_heads=8, ...)
"""
if policy_type == SparsePolicyType.FULL:
return FullAttentionPolicy()
elif policy_name == "vertical_slash":
config = VerticalSlashConfig(
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
local_window_blocks=kwargs.get("local_window_blocks", 2),
elif policy_type == SparsePolicyType.QUEST:
config = QuestConfig(
topk_blocks=kwargs.get("topk_blocks", 8),
threshold_blocks=kwargs.get("threshold_blocks", 4),
include_sink_blocks=kwargs.get("include_sink_blocks", 0),
include_recent_blocks=kwargs.get("include_recent_blocks", 0),
)
return VerticalSlashPolicy(config)
elif policy_name == "streaming_llm":
config = StreamingLLMConfig(
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
num_recent_blocks=kwargs.get("num_recent_blocks", 3),
)
return StreamingLLMPolicy(config)
elif policy_name == "quest":
# Quest requires metadata_manager to be passed separately
raise ValueError(
"Quest policy requires BlockMetadataManager. "
"Use QuestPolicy(config, metadata_manager) directly."
)
return QuestPolicy(config)
else:
raise ValueError(
f"Unknown sparse policy '{policy_name}'. "
f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())}"
)
raise ValueError(f"Unknown policy type: {policy_type}")
__all__ = [
"SparsePolicy",
"PolicyContext",
"SparsePolicyType",
"FullAttentionPolicy",
"VerticalSlashPolicy",
"VerticalSlashConfig",
"QuestPolicy",
"QuestConfig",
"BlockMetadataManager",
"StreamingLLMPolicy",
"StreamingLLMConfig",
"HybridPolicy",
"get_sparse_policy",
"BUILTIN_SPARSE_POLICIES",
"create_sparse_policy",
]

View File

@@ -22,6 +22,10 @@ class FullAttentionPolicy(SparsePolicy):
- For short sequences where sparsity isn't beneficial
"""
# Full attention supports both prefill and decode
supports_prefill = True
supports_decode = True
def select_blocks(
self,
available_blocks: List[int],

View File

@@ -1,93 +0,0 @@
"""
Hybrid sparse attention policy.
Allows using different policies for prefill vs decode phases.
This is useful because optimal sparsity patterns often differ:
- Prefill: fixed patterns work well (e.g., VerticalSlash)
- Decode: query-aware selection helps (e.g., Quest)
"""
from typing import List
import torch
from .policy import SparsePolicy, PolicyContext
class HybridPolicy(SparsePolicy):
"""
Hybrid policy that uses different policies for prefill and decode.
Example usage:
```python
from nanovllm.kvcache.sparse import (
HybridPolicy, VerticalSlashPolicy, QuestPolicy,
VerticalSlashConfig, QuestConfig, BlockMetadataManager
)
# Prefill: use fast fixed pattern
prefill_policy = VerticalSlashPolicy(VerticalSlashConfig(
num_sink_blocks=1,
local_window_blocks=3,
))
# Decode: use query-aware selection
metadata = BlockMetadataManager(num_blocks, num_layers, num_heads, head_dim)
decode_policy = QuestPolicy(QuestConfig(topk_blocks=8), metadata)
# Combine
policy = HybridPolicy(prefill_policy, decode_policy)
```
"""
def __init__(
self,
prefill_policy: SparsePolicy,
decode_policy: SparsePolicy,
):
"""
Initialize hybrid policy.
Args:
prefill_policy: Policy to use during prefill phase
decode_policy: Policy to use during decode phase
"""
self.prefill_policy = prefill_policy
self.decode_policy = decode_policy
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""Delegate to appropriate policy based on phase."""
if ctx.is_prefill:
return self.prefill_policy.select_blocks(available_blocks, ctx)
else:
return self.decode_policy.select_blocks(available_blocks, ctx)
def on_block_offloaded(
self,
cpu_block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""Forward to both policies (both may need metadata updates)."""
self.prefill_policy.on_block_offloaded(
cpu_block_id, layer_id, k_cache, num_valid_tokens
)
self.decode_policy.on_block_offloaded(
cpu_block_id, layer_id, k_cache, num_valid_tokens
)
def reset(self) -> None:
"""Reset both policies."""
self.prefill_policy.reset()
self.decode_policy.reset()
def __repr__(self) -> str:
return (
f"HybridPolicy(\n"
f" prefill={self.prefill_policy},\n"
f" decode={self.decode_policy}\n"
f")"
)

View File

@@ -10,6 +10,9 @@ from dataclasses import dataclass
from typing import List, Optional, Any
import torch
# Import SparsePolicyType from config to avoid circular imports
from nanovllm.config import SparsePolicyType
@dataclass
class PolicyContext:
@@ -39,7 +42,7 @@ class PolicyContext:
is_prefill: bool
"""True if in prefill phase, False if in decode phase."""
block_size: int = 4096
block_size: int = 1024
"""Number of tokens per block."""
total_kv_len: int = 0
@@ -54,8 +57,15 @@ class SparsePolicy(ABC):
sparse attention patterns. The policy receives context about
the current query chunk and returns which KV blocks to load.
Attributes:
supports_prefill: Whether this policy can be used for prefill phase.
supports_decode: Whether this policy can be used for decode phase.
Example:
class MySparsePolicy(SparsePolicy):
supports_prefill = False # decode-only policy
supports_decode = True
def select_blocks(self, available_blocks, ctx):
# Load first block and last 2 blocks
if len(available_blocks) <= 3:
@@ -63,6 +73,36 @@ class SparsePolicy(ABC):
return [available_blocks[0]] + available_blocks[-2:]
"""
# Compatibility flags - override in subclasses
supports_prefill: bool = True
supports_decode: bool = True
def initialize(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
num_cpu_blocks: int,
dtype: torch.dtype,
device: torch.device = None,
) -> None:
"""
Initialize policy resources.
Called by the framework after KV cache is allocated. Override this
to create metadata structures (e.g., BlockMetadataManager for Quest).
Default implementation does nothing.
Args:
num_layers: Number of transformer layers
num_kv_heads: Number of KV attention heads
head_dim: Dimension per head
num_cpu_blocks: Number of CPU blocks allocated
dtype: Data type for tensors
device: Device for metadata storage (GPU recommended for performance)
"""
pass
@abstractmethod
def select_blocks(
self,
@@ -90,7 +130,7 @@ class SparsePolicy(ABC):
"""
pass
def on_block_offloaded(
def on_prefill_offload(
self,
cpu_block_id: int,
layer_id: int,
@@ -98,15 +138,38 @@ class SparsePolicy(ABC):
num_valid_tokens: int,
) -> None:
"""
Hook called when a block is offloaded from GPU to CPU.
Hook called when a block is offloaded during prefill phase.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
Override this to collect metadata about blocks (e.g., min/max keys
for Quest-style selection). Default implementation does nothing.
Args:
cpu_block_id: The CPU block ID that was written
cpu_block_id: The CPU block ID that will be written
layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
"""
pass
def on_decode_offload(
self,
cpu_block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""
Hook called when a block is offloaded during decode phase.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
Override this to update metadata about blocks. Default implementation
does nothing.
Args:
cpu_block_id: The CPU block ID that will be written
layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
"""
pass
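A minimal custom policy against this interface, as a sketch (the sink-plus-recent pattern is illustrative and not a policy shipped in this package):

```python
from typing import List
import torch
from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext

class SinkAndRecentPolicy(SparsePolicy):
    """Toy decode-only policy: keep the first block plus the last N blocks."""
    supports_prefill = False
    supports_decode = True

    def __init__(self, num_recent: int = 2):
        self.num_recent = num_recent

    def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
        if len(available_blocks) <= 1 + self.num_recent:
            return available_blocks
        return [available_blocks[0]] + available_blocks[-self.num_recent:]

    def on_prefill_offload(self, cpu_block_id: int, layer_id: int,
                           k_cache: torch.Tensor, num_valid_tokens: int) -> None:
        pass   # fixed pattern: no per-block metadata needed
```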

View File

@@ -35,6 +35,7 @@ class BlockMetadataManager:
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype = torch.bfloat16,
device: torch.device = None,
):
"""
Initialize metadata storage.
@@ -45,20 +46,23 @@ class BlockMetadataManager:
num_kv_heads: Number of KV attention heads
head_dim: Dimension per head
dtype: Data type for metadata storage
device: Device for metadata storage (default: CUDA if available)
"""
self.num_blocks = num_blocks
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.dtype = dtype
self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-block min/max key values: [num_blocks, num_layers, num_heads, head_dim]
# Stored on GPU for efficient score computation during decode
shape = (num_blocks, num_layers, num_kv_heads, head_dim)
self.key_min = torch.zeros(shape, dtype=dtype, pin_memory=True)
self.key_max = torch.zeros(shape, dtype=dtype, pin_memory=True)
self.key_min = torch.zeros(shape, dtype=dtype, device=self.device)
self.key_max = torch.zeros(shape, dtype=dtype, device=self.device)
# Track which blocks have valid metadata
self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool)
self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool, device=self.device)
def update_metadata(
self,
@@ -70,21 +74,21 @@ class BlockMetadataManager:
"""
Update min/max key bounds for a block.
Called when a block is offloaded to CPU.
Called BEFORE offload to CPU, while k_cache is still on GPU.
Args:
block_id: CPU block ID
layer_id: Layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
"""
if num_valid_tokens == 0:
return
# Get valid keys only
k_valid = k_cache[:num_valid_tokens].cpu() # [num_tokens, heads, dim]
# Get valid keys only (k_cache is on GPU, metadata is on GPU)
k_valid = k_cache[:num_valid_tokens] # [num_tokens, heads, dim]
# Compute min/max across token dimension
# Compute min/max across token dimension (all on GPU)
self.key_min[block_id, layer_id] = k_valid.min(dim=0).values
self.key_max[block_id, layer_id] = k_valid.max(dim=0).values
self.valid_blocks[block_id] = True
@@ -147,22 +151,42 @@ class QuestPolicy(SparsePolicy):
This upper bound is derived from the fact that for any key k in
the block: min_k <= k <= max_k (element-wise), so the actual
attention score is bounded by the maximum of the two extremes.
Note: This is a decode-only policy. For prefill, use FullAttentionPolicy.
"""
def __init__(
self,
config: QuestConfig,
metadata_manager: BlockMetadataManager,
):
# Quest is decode-only
supports_prefill = False
supports_decode = True
def __init__(self, config: QuestConfig):
"""
Initialize Quest policy.
Args:
config: QuestConfig with selection parameters
metadata_manager: BlockMetadataManager for min/max key storage
"""
self.config = config
self.metadata = metadata_manager
self.metadata: Optional[BlockMetadataManager] = None
def initialize(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
num_cpu_blocks: int,
dtype: torch.dtype,
device: torch.device = None,
) -> None:
"""Create BlockMetadataManager for storing min/max keys on GPU."""
self.metadata = BlockMetadataManager(
num_blocks=num_cpu_blocks,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
device=device,
)
def select_blocks(
self,
@@ -175,6 +199,12 @@ class QuestPolicy(SparsePolicy):
If query is not available (some prefill scenarios), falls back
to loading all blocks.
"""
if self.metadata is None:
raise RuntimeError(
"QuestPolicy not initialized. Call initialize() first or "
"let the framework call it during KV cache allocation."
)
n = len(available_blocks)
# If below threshold or no query, load all
@@ -185,15 +215,13 @@ class QuestPolicy(SparsePolicy):
# No query available - cannot compute scores
return available_blocks
# Get metadata for available blocks
# Get metadata for available blocks (already on GPU)
key_min, key_max = self.metadata.get_block_metadata(
available_blocks, ctx.layer_id
)
# Move to query device for computation
# Metadata is already on GPU, same device as query
device = ctx.query.device
key_min = key_min.to(device, non_blocking=True)
key_max = key_max.to(device, non_blocking=True)
# Compute upper bound scores
# query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
@@ -261,19 +289,32 @@ class QuestPolicy(SparsePolicy):
return result
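As a standalone illustration of the upper bound described in the class docstring, here is a minimal sketch of the per-block score. The shapes and the head-aggregation choice (max over heads) are assumptions for the sketch, not necessarily the exact code above.

```
import torch

def quest_block_scores(query, key_min, key_max):
    """Upper-bound attention scores per block (sketch).

    query:   [num_heads, head_dim]             - current decode query
    key_min: [num_blocks, num_heads, head_dim] - element-wise min over block tokens
    key_max: [num_blocks, num_heads, head_dim] - element-wise max over block tokens
    Returns: [num_blocks] scores used for Top-K block selection.
    """
    q = query.unsqueeze(0)  # [1, num_heads, head_dim]
    # Per channel, q*k over all keys in the block is maximized at key_max
    # when q >= 0 and at key_min when q < 0, so the element-wise maximum
    # of the two products bounds every per-channel contribution.
    per_channel = torch.maximum(q * key_min, q * key_max)
    # Sum channels to bound the dot product, then aggregate over heads.
    return per_channel.sum(dim=-1).amax(dim=-1)
```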
def on_block_offloaded(
def on_prefill_offload(
self,
cpu_block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""Update min/max key metadata when block is offloaded."""
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
"""Update min/max key metadata during prefill offload."""
if self.metadata is not None:
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
def on_decode_offload(
self,
cpu_block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""Update min/max key metadata during decode offload (for new blocks)."""
if self.metadata is not None:
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
def reset(self) -> None:
"""Reset metadata."""
self.metadata.reset()
if self.metadata is not None:
self.metadata.reset()
def __repr__(self) -> str:
return (

View File

@@ -1,84 +0,0 @@
"""
StreamingLLM sparse attention policy.
Keeps only sink tokens (beginning) + recent tokens (end); everything
in between is discarded. This enables effectively unbounded generation
at the cost of losing that intermediate context.
Reference: StreamingLLM paper on attention sinks.
"""
from dataclasses import dataclass
from typing import List
from .policy import SparsePolicy, PolicyContext
@dataclass
class StreamingLLMConfig:
"""Configuration for StreamingLLMPolicy."""
num_sink_blocks: int = 1
"""Number of blocks at the beginning to always include (attention sinks)."""
num_recent_blocks: int = 3
"""Number of most recent blocks to include (sliding window)."""
class StreamingLLMPolicy(SparsePolicy):
"""
StreamingLLM pattern: sink tokens + recent tokens only.
This is the most aggressive sparsity pattern - only keeps a small
fixed window of context. Suitable for:
- Very long streaming generation
- When intermediate context can be safely discarded
- Maximizing throughput over accuracy
Pattern visualization:
```
Blocks: [0] [1] [2] [3] [4] [5] [6] [7] [8]
× × × ↑ ↑ ↑
sink (discarded) recent window
```
Warning: This loses information from intermediate blocks!
Use only when this trade-off is acceptable.
"""
def __init__(self, config: StreamingLLMConfig = None):
self.config = config or StreamingLLMConfig()
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select sink blocks + recent blocks only.
Intermediate blocks are not loaded (effectively discarded).
"""
n = len(available_blocks)
# If total blocks fit in sink + recent, load all
total_keep = self.config.num_sink_blocks + self.config.num_recent_blocks
if n <= total_keep:
return available_blocks
selected_indices = set()
# Sink blocks (first N)
for i in range(min(self.config.num_sink_blocks, n)):
selected_indices.add(i)
# Recent blocks (last M)
for i in range(max(0, n - self.config.num_recent_blocks), n):
selected_indices.add(i)
return [available_blocks[i] for i in sorted(selected_indices)]
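A worked example with the defaults above (num_sink_blocks=1, num_recent_blocks=3), purely to make the index arithmetic concrete:

```
# 9 CPU block ids from a prefilled sequence (ids are arbitrary here).
available_blocks = [10, 11, 12, 13, 14, 15, 16, 17, 18]
# Sink keeps index 0; recent window keeps indices 6, 7, 8.
# select_blocks(...) therefore returns [10, 16, 17, 18];
# blocks 11-15 are never loaded (their context is discarded).
```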
def __repr__(self) -> str:
return (
f"StreamingLLMPolicy(sink={self.config.num_sink_blocks}, "
f"recent={self.config.num_recent_blocks})"
)

View File

@@ -1,95 +0,0 @@
"""
Vertical-Slash sparse attention policy (MInference-style).
Selects sink blocks (beginning of sequence) + local window blocks
(near the current query position). This pattern captures:
- Important initial context (system prompt, instructions)
- Recent context (relevant for local dependencies)
"""
from dataclasses import dataclass
from typing import List
from .policy import SparsePolicy, PolicyContext
@dataclass
class VerticalSlashConfig:
"""Configuration for VerticalSlashPolicy."""
num_sink_blocks: int = 1
"""Number of blocks at the beginning to always include (sink tokens)."""
local_window_blocks: int = 2
"""Number of blocks in the local window near current query position."""
threshold_blocks: int = 4
"""If total blocks <= threshold, load all (no sparsity applied)."""
class VerticalSlashPolicy(SparsePolicy):
"""
Vertical-Slash pattern: sink tokens + local window.
This pattern is inspired by MInference and observations that:
1. Initial tokens (sink) often receive high attention
2. Local context (recent tokens) is important for dependencies
Pattern visualization:
```
Blocks: [0] [1] [2] [3] [4] [5] [6] [7] [8]
↑ ↑ ↑ ↑
sink local window (for query at block 9)
```
For prefill chunk K, the local window is blocks [K-window, K-1].
For decode, the local window is the last N blocks.
"""
def __init__(self, config: VerticalSlashConfig = None):
self.config = config or VerticalSlashConfig()
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select sink blocks + local window blocks.
For prefill: local window is relative to current chunk position.
For decode: local window is the most recent blocks.
"""
n = len(available_blocks)
# If below threshold, load all
if n <= self.config.threshold_blocks:
return available_blocks
selected_indices = set()
# Sink blocks (first N blocks)
for i in range(min(self.config.num_sink_blocks, n)):
selected_indices.add(i)
# Local window
if ctx.is_prefill:
# For prefill chunk K, local window is blocks [K-window, K-1]
# (blocks before current chunk, not including current)
window_end = min(ctx.query_chunk_idx, n)
window_start = max(0, window_end - self.config.local_window_blocks)
for i in range(window_start, window_end):
selected_indices.add(i)
else:
# For decode, local window is the last M blocks
for i in range(max(0, n - self.config.local_window_blocks), n):
selected_indices.add(i)
# Return blocks in order (maintains sequential access pattern)
return [available_blocks[i] for i in sorted(selected_indices)]
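A worked example with the defaults above (sink=1, window=2, threshold=4), again only to illustrate the index arithmetic:

```
available_blocks = list(range(9))   # 9 previously offloaded blocks
# Prefill, query_chunk_idx=6: window_end=6, window_start=4
#   -> selected indices {0, 4, 5}  (sink + two blocks before the current chunk)
# Decode: local window is the last 2 blocks
#   -> selected indices {0, 7, 8}
```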
def __repr__(self) -> str:
return (
f"VerticalSlashPolicy(sink={self.config.num_sink_blocks}, "
f"window={self.config.local_window_blocks}, "
f"threshold={self.config.threshold_blocks})"
)

View File

@@ -2,8 +2,6 @@ import logging
import torch
import torch.cuda.nvtx
from torch import nn
import triton
import triton.language as tl
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context
@@ -12,37 +10,59 @@ from nanovllm.kvcache.sparse.policy import PolicyContext
logger = logging.getLogger(__name__)
@triton.jit
def store_kvcache_kernel(
key_ptr,
key_stride,
value_ptr,
value_stride,
k_cache_ptr,
v_cache_ptr,
slot_mapping_ptr,
D: tl.constexpr,
def store_kvcache(
key: torch.Tensor,
value: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
idx = tl.program_id(0)
slot = tl.load(slot_mapping_ptr + idx)
if slot == -1: return
key_offsets = idx * key_stride + tl.arange(0, D)
value_offsets = idx * value_stride + tl.arange(0, D)
key = tl.load(key_ptr + key_offsets)
value = tl.load(value_ptr + value_offsets)
cache_offsets = slot * D + tl.arange(0, D)
tl.store(k_cache_ptr + cache_offsets, key)
tl.store(v_cache_ptr + cache_offsets, value)
"""
Store key/value tensors into KV cache using slot mapping.
This is a pure PyTorch implementation replacing the previous Triton kernel.
Uses index_copy_ for efficient in-place scatter operation.
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
N, num_heads, head_dim = key.shape
D = num_heads * head_dim
assert key.stride(-1) == 1 and value.stride(-1) == 1
assert key.stride(1) == head_dim and value.stride(1) == head_dim
assert k_cache.stride(1) == D and v_cache.stride(1) == D
assert slot_mapping.numel() == N
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
Args:
key: [N, num_kv_heads, head_dim]
value: [N, num_kv_heads, head_dim]
k_cache: [num_blocks, block_size, num_kv_heads, head_dim] or similar
v_cache: same shape as k_cache
slot_mapping: [N] with values as flat indices, -1 means skip
"""
is_capturing = torch.cuda.is_current_stream_capturing()
if is_capturing:
# During CUDA graph capture, assume all slots are valid.
# CUDA graphs don't support data-dependent operations like boolean indexing.
# This is safe because decode (captured) always has valid slots.
valid_slots = slot_mapping
valid_keys = key
valid_values = value
else:
# Normal execution: filter out invalid slots (slot == -1)
valid_mask = slot_mapping >= 0
if not valid_mask.any():
return
valid_slots = slot_mapping[valid_mask]
valid_keys = key[valid_mask] # [M, num_kv_heads, head_dim]
valid_values = value[valid_mask]
# Flatten cache and KV for scatter operation
# Cache is viewed as [total_slots, D] where D = num_kv_heads * head_dim
N, num_kv_heads, head_dim = key.shape
D = num_kv_heads * head_dim
total_slots = k_cache.numel() // D
k_cache_flat = k_cache.view(total_slots, D)
v_cache_flat = v_cache.view(total_slots, D)
valid_keys_flat = valid_keys.reshape(-1, D)
valid_values_flat = valid_values.reshape(-1, D)
# In-place scatter using index_copy_
# index_copy_ is safe even when valid_slots is an empty tensor (it simply writes nothing).
k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)
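A minimal usage sketch of the pure-PyTorch path above. Shapes, dtypes, and slot values are illustrative; a CUDA device is assumed only for realism.

```
import torch

# 4 new tokens, 2 KV heads, head_dim 8; cache of 2 blocks x 16 slots.
key = torch.randn(4, 2, 8, device="cuda", dtype=torch.bfloat16)
value = torch.randn(4, 2, 8, device="cuda", dtype=torch.bfloat16)
k_cache = torch.zeros(2, 16, 2, 8, device="cuda", dtype=torch.bfloat16)
v_cache = torch.zeros_like(k_cache)

# Flat slot indices into the [32, D] view of the cache; -1 marks a token
# whose KV should not be written (it is filtered out by valid_mask above).
slot_mapping = torch.tensor([5, 6, 7, -1], device="cuda")
store_kvcache(key, value, k_cache, v_cache, slot_mapping)
```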
class Attention(nn.Module):
@@ -66,8 +86,49 @@ class Attention(nn.Module):
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
# Determine if we're in chunked offload mode
is_chunked_offload = (
context.is_chunked_prefill and
hasattr(context, 'kvcache_manager') and
context.kvcache_manager is not None and
hasattr(context.kvcache_manager, 'offload_engine')
)
#! Ensure synchronization before accessing k_cache/v_cache
# torch.cuda.synchronize()
#! =======================================================
if is_chunked_offload and context.is_prefill:
# Chunked prefill mode: write KV to per-layer prefill buffer (not GPU slot)
# This enables fully async offloads since each layer has its own buffer.
offload_engine = context.kvcache_manager.offload_engine
compute_stream = offload_engine.compute_stream
# Wait for default stream to ensure slot_mapping tensor transfer is complete
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
# Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
# k, v shape: [num_tokens, kv_heads, head_dim]
num_tokens = k.shape[0]
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
elif is_chunked_offload:
# Chunked decode mode: use compute_stream for store_kvcache
# This ensures proper synchronization with per-layer offload
compute_stream = context.kvcache_manager.offload_engine.compute_stream
if k_cache.numel() and v_cache.numel():
# CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
# slot_mapping is created with non_blocking=True on default stream, but we use it
# on compute_stream. Without this sync, index_copy_ can get corrupted indices.
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
else:
# Normal mode: store on default stream
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
if context.is_chunked_prefill:
@@ -111,43 +172,44 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute attention with unified ring buffer for chunked prefill.
Compute attention with per-layer prefill buffer for async offload.
Ring buffer design:
- Current chunk's KV is written to ring_slot[chunk_idx % N]
- Previous chunks' KV are loaded from CPU using N-1 available slots
- Pipeline: pre-fill slots, then process with overlapped load/compute
Optimized design:
- Current chunk's KV is written to per-layer prefill buffer (not GPU slot)
- Previous chunks' KV are loaded from CPU using GPU slots
- Each layer offloads from its own buffer - no waiting required!
For each layer:
1. Current chunk's KV is in k_batched, v_batched (just written by model)
1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model)
2. Load previous chunks from CPU using available slots (pipeline)
3. Compute attention against previous KV (no causal mask)
4. Compute attention against current KV (causal)
4. Compute attention against current KV from prefill buffer (causal)
5. Merge all results using online softmax
6. Async offload prefill buffer to CPU (no waiting!)
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
current_chunk_idx = context.current_chunk_idx
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
# q, k, v shape: [total_tokens, num_heads, head_dim]
# Reshape for flash attention: [batch, seq, heads, dim]
# q shape: [total_tokens, num_heads, head_dim]
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
k_batched = k.unsqueeze(0)
v_batched = v.unsqueeze(0)
num_tokens = k.shape[0]
o_acc = None
lse_acc = None
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
# Get prefilled CPU blocks (blocks from previous chunks)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Apply sparse policy if enabled
if cpu_block_table and kvcache_manager.sparse_policy is not None:
# Apply sparse policy if enabled (Quest returns all blocks for prefill since query=None)
sparse_policy = kvcache_manager.sparse_policy
if cpu_block_table and sparse_policy is not None:
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
@@ -158,16 +220,13 @@ class Attention(nn.Module):
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
cpu_block_table = sparse_policy.select_blocks(
cpu_block_table, policy_ctx
)
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
# Get write slot for current chunk and available load slots
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
# Get available load slots (all slots can be used since we use prefill buffer)
load_slots = list(range(offload_engine.num_ring_slots))
pipeline_depth = len(load_slots)
if pipeline_depth == 0:
@@ -182,45 +241,67 @@ class Attention(nn.Module):
current_chunk_idx
)
# Get compute stream for all attention operations
compute_stream = offload_engine.compute_stream if offload_engine is not None else None
# Compute attention against current chunk's KV (with causal mask)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
# Compute attention against current chunk's KV from prefill buffer (with causal mask)
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
# Get KV from per-layer prefill buffer
k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
else:
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
k_batched = k.unsqueeze(0)
v_batched = v.unsqueeze(0)
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
# Merge with accumulated
# Merge with accumulated (all on compute_stream for consistency)
if o_acc is None:
final_o = current_o
else:
# IMPORTANT: o_acc was computed on compute_stream. We need to sync before
# reading it on the default stream for the merge operation.
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
offload_engine = kvcache_manager.offload_engine
torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
else:
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop() # ChunkedPrefill
# Per-layer offload: In new GPU cache architecture (no layer dimension),
# each layer must offload its KV to CPU before next layer overwrites the GPU slot.
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
offload_engine = kvcache_manager.offload_engine
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
if seq is not None:
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
if current_chunk_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[current_chunk_idx]
offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
# Per-layer ASYNC offload: offload prefill buffer to CPU
# No waiting required! Each layer has its own buffer and stream.
if offload_engine is not None and seq is not None:
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
if current_chunk_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[current_chunk_idx]
# Async offload - no waiting, fully parallel across layers
offload_engine.offload_prefill_buffer_async(
self.layer_id, cpu_block_id, num_tokens
)
# Sync default stream with compute_stream before returning
# This ensures the result is ready for the rest of the model (layernorm, MLP)
if compute_stream is not None:
torch.cuda.default_stream().wait_stream(compute_stream)
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
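The merge step above relies on the standard log-sum-exp identity. A self-contained sketch of the merge under assumed shapes (o: [1, seq, heads, dim], lse: [1, heads, seq]) is shown below; the real helper lives in nanovllm.kvcache.chunked_attention and may differ in layout.

```
import torch

def merge_two_attention_outputs(o1, lse1, o2, lse2):
    """Combine two partial attention results over disjoint KV sets (sketch).

    o1, o2:     [1, seq, heads, dim]  partial outputs
    lse1, lse2: [1, heads, seq]       log-sum-exp of the attention logits
    """
    lse = torch.logaddexp(lse1, lse2)                          # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # [1, seq, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse
```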
@@ -318,6 +399,7 @@ class Attention(nn.Module):
offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
@@ -364,6 +446,7 @@ class Attention(nn.Module):
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
@@ -399,17 +482,15 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention using ring buffer pipeline (same as prefill).
Compute decode attention using cross-layer pipeline.
Uses the same loading mechanism as _chunked_prefill_attention:
- Load one block at a time from CPU to GPU slot
- Compute attention for each block
- Merge results using online softmax
- Finally merge with decode buffer (accumulated decode tokens)
Optimization: Uses double-buffered layer cache to overlap H2D transfer
with computation across layers:
- Layer N computes while Layer N+1's data is being loaded
- Each layer only waits for its own data, not all layers' data
This approach is simpler and proven correct (prefill tests pass).
The only difference from prefill is the additional decode buffer
that stores new tokens generated during decode.
This reduces effective latency from O(num_layers * transfer_time) to
O(transfer_time + num_layers * compute_time) when transfer < compute.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@@ -426,16 +507,19 @@ class Attention(nn.Module):
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Calculate valid tokens in the last block
# Note: For chunked prefill, each block is exactly block_size tokens
# The cpu_block_table only contains full prefill blocks
# Calculate valid tokens in the last CPU block
# CRITICAL: Use original prefill length, not current seq length!
# CPU blocks are fixed after prefill, their content doesn't change during decode.
block_size = kvcache_manager.block_size
num_prefill_blocks = len(cpu_block_table)
# All prefill blocks are full (block_size tokens each)
last_block_valid_tokens = block_size
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
last_block_valid_tokens = total_prefill_tokens % block_size
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
last_block_valid_tokens = block_size # Last block was exactly full
# Apply sparse policy if enabled
if kvcache_manager.sparse_policy is not None:
# Apply sparse policy if enabled (Quest does Top-K selection for decode)
sparse_policy = kvcache_manager.sparse_policy
if sparse_policy is not None:
policy_ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
@@ -445,18 +529,25 @@ class Attention(nn.Module):
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
cpu_block_table = sparse_policy.select_blocks(
cpu_block_table, policy_ctx
)
offload_engine = kvcache_manager.offload_engine
load_slots = offload_engine.decode_load_slots # Available slots for loading
# Use ring buffer pipeline (same as prefill) to load prefilled blocks
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
q_batched, cpu_block_table, load_slots, offload_engine,
block_size, last_block_valid_tokens
)
# Use cross-layer pipeline if active (initialized in model_runner)
if offload_engine.is_pipeline_active():
o_acc, lse_acc = self._decode_with_layer_pipeline(
q_batched, cpu_block_table, offload_engine,
block_size, last_block_valid_tokens
)
else:
# Fallback to original ring buffer pipeline
load_slots = offload_engine.decode_load_slots
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
q_batched, cpu_block_table, load_slots, offload_engine,
block_size, last_block_valid_tokens
)
# Now attend to accumulated decode tokens from per-layer decode buffer
pos_in_block = context.decode_pos_in_block
@@ -569,3 +660,62 @@ class Attention(nn.Module):
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
def _decode_with_layer_pipeline(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
offload_engine,
block_size: int,
last_block_valid_tokens: int,
):
"""
Decode using cross-layer pipeline for optimized H2D transfer.
This method uses pre-loaded layer buffers instead of loading
blocks one by one. The pipeline loads the next layer's data
while the current layer computes, achieving transfer/compute overlap.
The key insight is that every layer needs the SAME set of blocks, just
read from a different layer slice of the CPU cache. By double-buffering
and pipelining the transfers across layers, we reduce total latency.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
num_blocks = len(cpu_block_table)
if num_blocks == 0:
return None, None
compute_stream = offload_engine.compute_stream
# Get KV from pre-loaded layer buffer (triggers next layer loading)
prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
# prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
# Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
total_tokens = num_blocks * block_size
# Handle partial last block
if last_block_valid_tokens < block_size:
# Only use valid tokens from last block
actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
# Flatten and truncate
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
else:
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
# Add batch dimension: [1, total_tokens, kv_heads, head_dim]
prev_k_batched = prev_k_flat.unsqueeze(0)
prev_v_batched = prev_v_flat.unsqueeze(0)
# Compute attention on all prefilled blocks at once
with torch.cuda.stream(compute_stream):
o_acc, lse_acc = flash_attn_with_lse(
q_batched, prev_k_batched, prev_v_batched,
softmax_scale=self.scale,
causal=False,
)
return o_acc, lse_acc
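To make the latency claim in the docstring concrete, a rough back-of-envelope with purely illustrative numbers (these are assumptions, not measurements):

```
num_layers = 32
transfer_ms = 1.0   # assumed H2D time for one layer's blocks
compute_ms = 2.0    # assumed attention time for one layer

# Serial: each layer waits for its own transfer, then computes.
serial = num_layers * (transfer_ms + compute_ms)                      # 96 ms
# Pipelined: only the first transfer is exposed; the rest hide behind compute.
pipelined = transfer_ms + num_layers * max(transfer_ms, compute_ms)   # 65 ms
```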

tests/modeling_qwen3.py (new file, 757 lines)
View File


@@ -0,0 +1,757 @@
"""
Custom Qwen3 implementation using only torch and transformers.
This file provides a clean reference implementation for understanding the model computation graph.
Computation Graph:
==================
Input: token_ids [batch, seq_len]
┌─────────────┐
│ Embedding │ embed_tokens: [vocab_size, hidden_size]
└─────────────┘
hidden_states [batch, seq_len, hidden_size]
┌─────────────────────────────────────────────────────────┐
│ Decoder Layer (x N) │
│ ┌───────────────────────────────────────────────────┐ │
│ │ Self Attention Block │ │
│ │ │ │
│ │ input_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3Attention │ │ │
│ │ │ Q = q_proj(x) → q_norm → reshape │ │ │
│ │ │ K = k_proj(x) → k_norm → reshape │ │ │
│ │ │ V = v_proj(x) → reshape │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ attn_output = attention(Q, K, V) │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ output = o_proj(attn_output) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + attn_output │ │
│ └───────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ MLP Block │ │
│ │ │ │
│ │ post_attention_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3MLP │ │ │
│ │ │ gate = gate_proj(x) │ │ │
│ │ │ up = up_proj(x) │ │ │
│ │ │ output = down_proj(silu(gate) * up) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + mlp_output │ │
│ └───────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
┌─────────────┐
│ norm │ final RMSNorm
└─────────────┘
┌─────────────┐
│ lm_head │ [hidden_size, vocab_size]
└─────────────┘
logits [batch, seq_len, vocab_size]
"""
import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
class Qwen3RMSNorm(nn.Module):
"""RMSNorm implementation."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
input_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
return self.weight * x.to(input_dtype)
class Qwen3RotaryEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE)."""
def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Compute inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor [batch, seq_len, num_heads, head_dim] or similar
position_ids: Position indices [batch, seq_len]
Returns:
cos, sin: [batch, seq_len, head_dim]
"""
# inv_freq: [dim/2]
# position_ids: [batch, seq_len]
inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len]
# freqs: [batch, dim/2, seq_len]
freqs = inv_freq_expanded @ position_ids_expanded
# freqs: [batch, seq_len, dim/2]
freqs = freqs.transpose(1, 2)
# Duplicate for full head_dim: [batch, seq_len, dim]
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(x.dtype)
sin = emb.sin().to(x.dtype)
return cos, sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary position embeddings to Q and K.
Args:
q: [batch, num_heads, seq_len, head_dim]
k: [batch, num_kv_heads, seq_len, head_dim]
cos: [batch, seq_len, head_dim]
sin: [batch, seq_len, head_dim]
Returns:
q_embed, k_embed with same shapes as inputs
"""
# Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Qwen3Attention(nn.Module):
"""
Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
└──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
apply_rotary_pos_emb(Q, K)
attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim]
reshape ──► o_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
attention_bias: bool = False,
rms_norm_eps: float = 1e-6,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.num_kv_heads = num_key_value_heads
self.head_dim = head_dim
self.num_kv_groups = num_attention_heads // num_key_value_heads
self.layer_idx = layer_idx
# Scaling factor
self.scaling = head_dim ** -0.5
# QKV projections
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
# QK normalization (Qwen3 specific)
self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
# Rotary embeddings
self.rotary_emb = Qwen3RotaryEmbedding(
head_dim,
max_position_embeddings=max_position_embeddings,
base=rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask)
past_key_value: (k_cache, v_cache) from previous steps
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V tensors for debugging
Returns:
output: [batch, seq_len, hidden_size]
past_key_value: Updated cache (if use_cache=True)
qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
"""
batch_size, seq_len, _ = hidden_states.shape
# === QKV Projections ===
q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim]
k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
# Reshape to [batch, seq_len, num_heads, head_dim]
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# === QK Normalization (Qwen3 specific) ===
q = self.q_norm(q)
k = self.k_norm(k)
# Transpose to [batch, num_heads, seq_len, head_dim]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# === Rotary Position Embeddings ===
cos, sin = self.rotary_emb(v, position_ids)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
# === KV Cache Update ===
if past_key_value is not None:
k_cache, v_cache = past_key_value
k = torch.cat([k_cache, k], dim=2)
v = torch.cat([v_cache, v], dim=2)
new_past_key_value = (k, v) if use_cache else None
# === Grouped Query Attention (expand KV heads if needed) ===
if self.num_kv_groups > 1:
# Repeat KV for each query group
k = k.repeat_interleave(self.num_kv_groups, dim=1)
v = v.repeat_interleave(self.num_kv_groups, dim=1)
# === Attention Computation (using SDPA for memory efficiency) ===
# Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
# is_causal only works when q_len == kv_len (prefill), not during decode
q_len, kv_len = q.shape[2], k.shape[2]
is_causal = (q_len == kv_len) and (q_len > 1)
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=0.0,
is_causal=is_causal,
scale=self.scaling,
) # [batch, num_heads, seq_len, head_dim]
# === Output Projection ===
# Transpose back and reshape
attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim]
attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, hidden_size]
output = self.o_proj(attn_output)
# Optional QKV output for debugging
qkv_dict = None
if output_qkv:
qkv_dict = {
"q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE)
"k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
"v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded)
}
return output, new_past_key_value, qkv_dict
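As a concrete shape example for the GQA expansion above (the head counts are illustrative):

```
# num_attention_heads=32, num_key_value_heads=8  ->  num_kv_groups=4
# K, V before expansion: [batch, 8, seq_len, head_dim]
# repeat_interleave(4, dim=1) replicates each KV head for its 4 query heads:
# K, V after expansion:  [batch, 32, seq_len, head_dim]
```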
class Qwen3MLP(nn.Module):
"""
Qwen3 MLP with SwiGLU activation.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► gate_proj ──► gate [batch, seq_len, intermediate_size]
└──► up_proj ──► up [batch, seq_len, intermediate_size]
silu(gate) * up
down_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.gate_proj(x)
up = self.up_proj(x)
return self.down_proj(F.silu(gate) * up)
class Qwen3DecoderLayer(nn.Module):
"""Single Qwen3 Decoder Layer."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
layer_idx: int = 0,
):
super().__init__()
self.layer_idx = layer_idx
# Pre-attention LayerNorm
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# Self-attention
self.self_attn = Qwen3Attention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
attention_bias=attention_bias,
rms_norm_eps=rms_norm_eps,
layer_idx=layer_idx,
)
# Post-attention LayerNorm
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# MLP
self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: Causal attention mask
past_key_value: KV cache for this layer
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V for debugging
Returns:
hidden_states: [batch, seq_len, hidden_size]
past_key_value: Updated cache
qkv_dict: QKV tensors (if output_qkv=True)
"""
# === Self Attention Block ===
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_output, new_past_key_value, qkv_dict = self.self_attn(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_qkv=output_qkv,
)
hidden_states = residual + attn_output
# === MLP Block ===
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, new_past_key_value, qkv_dict
class Qwen3Model(nn.Module):
"""Qwen3 Transformer Model (without LM head)."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
):
super().__init__()
self.vocab_size = vocab_size
self.num_hidden_layers = num_hidden_layers
# Token embeddings
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
# Decoder layers
self.layers = nn.ModuleList([
Qwen3DecoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
layer_idx=i,
)
for i in range(num_hidden_layers)
])
# Final LayerNorm
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
position_ids: [batch, seq_len]
attention_mask: [batch, seq_len] or pre-computed 4D mask
past_key_values: List of (k, v) tuples for each layer
use_cache: Whether to return new cache
output_qkv_layers: List of layer indices to output QKV for
Returns:
hidden_states: [batch, seq_len, hidden_size]
new_past_key_values: Updated cache
qkv_outputs: {layer_idx: qkv_dict}
"""
batch_size, seq_len = input_ids.shape
# Embedding
hidden_states = self.embed_tokens(input_ids)
# Position IDs
if position_ids is None:
past_len = past_key_values[0][0].shape[2] if past_key_values else 0
position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Attention mask (create causal mask if not provided)
if attention_mask is None or attention_mask.dim() == 2:
kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
causal_mask = torch.triu(
torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
diagonal=kv_seq_len - seq_len + 1,
)
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len]
# Initialize cache list
new_past_key_values = [] if use_cache else None
qkv_outputs = {} if output_qkv_layers else None
# Decoder layers
for i, layer in enumerate(self.layers):
past_kv = past_key_values[i] if past_key_values else None
output_qkv = output_qkv_layers is not None and i in output_qkv_layers
hidden_states, new_kv, qkv_dict = layer(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_kv,
use_cache=use_cache,
output_qkv=output_qkv,
)
if use_cache:
new_past_key_values.append(new_kv)
if qkv_dict is not None:
qkv_outputs[i] = qkv_dict
# Final norm
hidden_states = self.norm(hidden_states)
return hidden_states, new_past_key_values, qkv_outputs
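A small worked example of the mask construction above, to confirm the diagonal offset (the lengths are illustrative):

```
# seq_len=2 new tokens with past_len=3 cached tokens -> kv_seq_len=5,
# diagonal = 5 - 2 + 1 = 4.
# torch.triu on a 2x5 matrix of -inf with diagonal=4 masks only (0, 4):
#   query at absolute position 3 attends to kv positions 0..3,
#   query at absolute position 4 attends to kv positions 0..4.
```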
class Qwen3ForCausalLM(nn.Module):
"""Qwen3 Model with Language Modeling head."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
tie_word_embeddings: bool = True,
):
super().__init__()
self.vocab_size = vocab_size
self.tie_word_embeddings = tie_word_embeddings
# Transformer model
self.model = Qwen3Model(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
)
# LM head
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
... (same as Qwen3Model)
Returns:
logits: [batch, seq_len, vocab_size]
past_key_values: Updated KV cache
qkv_outputs: QKV tensors for specified layers
"""
hidden_states, new_past_key_values, qkv_outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_qkv_layers=output_qkv_layers,
)
logits = self.lm_head(hidden_states)
return logits, new_past_key_values, qkv_outputs
@classmethod
def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
"""
Load weights from a pretrained Qwen3 model.
Args:
model_path: Path to model directory containing config.json and model weights
dtype: Data type for model weights
Returns:
Initialized Qwen3ForCausalLM model
"""
import json
import os
from safetensors.torch import load_file
# Load config
config_path = os.path.join(model_path, "config.json")
with open(config_path) as f:
config = json.load(f)
# Create model
model = cls(
vocab_size=config["vocab_size"],
hidden_size=config["hidden_size"],
intermediate_size=config["intermediate_size"],
num_hidden_layers=config["num_hidden_layers"],
num_attention_heads=config["num_attention_heads"],
num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
max_position_embeddings=config.get("max_position_embeddings", 32768),
rope_theta=config.get("rope_theta", 10000.0),
rms_norm_eps=config.get("rms_norm_eps", 1e-6),
attention_bias=config.get("attention_bias", False),
mlp_bias=config.get("mlp_bias", False),
tie_word_embeddings=config.get("tie_word_embeddings", True),
)
# Load weights
weight_files = sorted([
f for f in os.listdir(model_path)
if f.endswith(".safetensors")
])
state_dict = {}
for wf in weight_files:
state_dict.update(load_file(os.path.join(model_path, wf)))
# Load into model
model.load_state_dict(state_dict, strict=False)
# Tie lm_head weights to embed_tokens if configured
if model.tie_word_embeddings:
model.lm_head.weight = model.model.embed_tokens.weight
model = model.to(dtype)
return model
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 32,
temperature: float = 1.0,
do_sample: bool = True,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> torch.Tensor:
"""Simple autoregressive generation."""
device = input_ids.device
batch_size, seq_len = input_ids.shape
past_key_values = None
generated = input_ids.clone()
for _ in range(max_new_tokens):
if past_key_values is None:
current_input = generated
else:
current_input = generated[:, -1:]
logits, past_key_values, _ = self(
input_ids=current_input,
past_key_values=past_key_values,
use_cache=True,
)
next_token_logits = logits[:, -1, :]
if temperature > 0 and do_sample:
next_token_logits = next_token_logits / temperature
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
if eos_token_id is not None and (next_token == eos_token_id).all():
break
return generated
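A hypothetical end-to-end usage of the reference model; the model path and token ids are placeholders, not values from the repository.

```
import torch

model = Qwen3ForCausalLM.from_pretrained("/path/to/Qwen3-0.6B", dtype=torch.bfloat16)
model = model.cuda().eval()

input_ids = torch.tensor([[151644, 872, 198]], device="cuda")  # placeholder token ids
output_ids = model.generate(input_ids, max_new_tokens=16, do_sample=False)
print(output_ids.shape)  # [1, input_len + up to 16 generated tokens]
```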
def print_computation_graph():
"""Print the computation graph for reference."""
print(__doc__)
if __name__ == "__main__":
print_computation_graph()

View File

@@ -1,23 +0,0 @@
cmake_minimum_required(VERSION 3.18)
project(sgdma_test CUDA CXX)
# Find CUDA
enable_language(CUDA)
find_package(CUDA REQUIRED)
# Set C++ standard
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
# CUDA flags
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 --use_fast_math")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
# Build test executable
add_executable(sgdma_test sgdma_test.cpp)
target_link_libraries(sgdma_test cudart)
# Set output directory
set_target_properties(sgdma_test PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
)

View File

@@ -1,326 +0,0 @@
#include <cuda_runtime.h>
#include <iostream>
#include <chrono>
#include <cstring>
#include <cstdlib>
#include <iomanip>
// CUDA error checking macro
#define CUDA_CHECK(call) do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
std::cerr << "CUDA Error in " << __FILE__ << " at line " << __LINE__ << ": " \
<< cudaGetErrorString(err) << std::endl; \
exit(EXIT_FAILURE); \
} \
} while (0)
// Configuration matching nano-vllm realistic parameters
struct Config {
int num_layers = 32;
int num_blocks = 10; // Reduced from 100 to avoid huge allocation
int block_size = 4096;
int num_kv_heads = 8;
int head_dim = 128;
int dtype_size = 2; // float16
// Derived parameters (use size_t to avoid overflow)
size_t features_per_block() const { return (size_t)block_size * num_kv_heads * head_dim; }
size_t bytes_per_block() const { return features_per_block() * dtype_size; }
int total_blocks_per_layer() const { return num_blocks; }
size_t bytes_per_layer() const { return (size_t)num_blocks * bytes_per_block(); }
size_t total_bytes() const { return (size_t)num_layers * bytes_per_layer(); }
};
// Timer utility
class Timer {
std::chrono::high_resolution_clock::time_point start_time;
public:
void start() { start_time = std::chrono::high_resolution_clock::now(); }
double elapsed_ms() {
auto end = std::chrono::high_resolution_clock::now();
return std::chrono::duration<double, std::milli>(end - start_time).count();
}
};
// Initialize CPU memory with test pattern
void init_test_data(void* data, size_t bytes, int seed) {
uint16_t* ptr = static_cast<uint16_t*>(data);
size_t num_elements = bytes / sizeof(uint16_t);
for (size_t i = 0; i < num_elements; i++) {
ptr[i] = static_cast<uint16_t>((seed + i) % 65536);
}
}
// Verify data correctness
bool verify_data(const void* data1, const void* data2, size_t bytes) {
const uint16_t* p1 = static_cast<const uint16_t*>(data1);
const uint16_t* p2 = static_cast<const uint16_t*>(data2);
size_t num_elements = bytes / sizeof(uint16_t);
for (size_t i = 0; i < num_elements; i++) {
if (p1[i] != p2[i]) {
std::cerr << "Mismatch at element " << i << ": "
<< p1[i] << " != " << p2[i] << std::endl;
return false;
}
}
return true;
}
// ============================================================
// Test 1: Basic Functionality Test
// ============================================================
bool test_basic_functionality(const Config& cfg) {
std::cout << "\n[Test 1] Basic Functionality Test" << std::endl;
std::cout << " Testing cudaMemcpy2D correctness with strided layout" << std::endl;
// Allocate strided CPU memory (pinned)
// Layout: [num_layers, num_blocks, block_features]
size_t total_bytes = cfg.total_bytes();
std::cout << " Allocating " << total_bytes / 1024.0 / 1024.0 / 1024.0 << " GB pinned memory..." << std::endl;
void* cpu_strided = nullptr;
CUDA_CHECK(cudaMallocHost(&cpu_strided, total_bytes));
std::cout << " CPU strided memory allocated at: " << cpu_strided << std::endl;
// Allocate GPU memory for one block (all layers)
size_t gpu_block_bytes = cfg.num_layers * cfg.bytes_per_block();
void* gpu_data = nullptr;
CUDA_CHECK(cudaMalloc(&gpu_data, gpu_block_bytes));
// Allocate CPU verify buffer
void* cpu_verify = nullptr;
CUDA_CHECK(cudaMallocHost(&cpu_verify, gpu_block_bytes));
// Initialize strided CPU memory
init_test_data(cpu_strided, total_bytes, 12345);
// Test: Copy block_id=5 from CPU to GPU using cudaMemcpy2D
int test_block_id = 5;
size_t spitch = cfg.bytes_per_layer(); // Source pitch (stride between layers)
size_t dpitch = cfg.bytes_per_block(); // Destination pitch (contiguous)
size_t width = cfg.bytes_per_block(); // Width to copy per row
size_t height = cfg.num_layers; // Number of rows (layers)
// Debug: print parameters
std::cout << " cudaMemcpy2D parameters:" << std::endl;
std::cout << " spitch: " << spitch << " bytes" << std::endl;
std::cout << " dpitch: " << dpitch << " bytes" << std::endl;
std::cout << " width: " << width << " bytes" << std::endl;
std::cout << " height: " << height << " rows" << std::endl;
std::cout << " dpitch >= width: " << (dpitch >= width ? "yes" : "no") << std::endl;
std::cout << " spitch >= width: " << (spitch >= width ? "yes" : "no") << std::endl;
// Calculate source pointer (first layer, block_id)
uint8_t* src_ptr = static_cast<uint8_t*>(cpu_strided) + test_block_id * cfg.bytes_per_block();
// H2D transfer
CUDA_CHECK(cudaMemcpy2D(
gpu_data, // dst
dpitch, // dpitch
src_ptr, // src
spitch, // spitch
width, // width
height, // height
cudaMemcpyHostToDevice
));
// D2H transfer back
CUDA_CHECK(cudaMemcpy2D(
cpu_verify, // dst
dpitch, // dpitch
gpu_data, // src
dpitch, // spitch
width, // width
height, // height
cudaMemcpyDeviceToHost
));
// Verify correctness
bool passed = true;
for (int layer = 0; layer < cfg.num_layers; layer++) {
uint8_t* expected_ptr = static_cast<uint8_t*>(cpu_strided) +
layer * cfg.bytes_per_layer() +
test_block_id * cfg.bytes_per_block();
uint8_t* actual_ptr = static_cast<uint8_t*>(cpu_verify) +
layer * cfg.bytes_per_block();
if (!verify_data(expected_ptr, actual_ptr, cfg.bytes_per_block())) {
std::cerr << " Verification failed at layer " << layer << std::endl;
passed = false;
break;
}
}
// Cleanup
CUDA_CHECK(cudaFreeHost(cpu_strided));
CUDA_CHECK(cudaFreeHost(cpu_verify));
CUDA_CHECK(cudaFree(gpu_data));
std::cout << " Result: " << (passed ? "PASSED ✓" : "FAILED ✗") << std::endl;
return passed;
}
// ============================================================
// Test 2: Performance Benchmark
// ============================================================
void test_performance_benchmark(const Config& cfg) {
std::cout << "\n[Test 2] Performance Benchmark" << std::endl;
std::cout << " Configuration:" << std::endl;
std::cout << " num_layers: " << cfg.num_layers << std::endl;
std::cout << " num_blocks: " << cfg.num_blocks << std::endl;
std::cout << " block_size: " << cfg.block_size << std::endl;
std::cout << " num_kv_heads: " << cfg.num_kv_heads << std::endl;
std::cout << " head_dim: " << cfg.head_dim << std::endl;
std::cout << " dtype_size: " << cfg.dtype_size << " bytes" << std::endl;
std::cout << " bytes_per_block: " << cfg.bytes_per_block() / 1024.0 << " KB" << std::endl;
std::cout << " total transfer size: " << cfg.num_layers * cfg.bytes_per_block() / 1024.0 / 1024.0 << " MB" << std::endl;
const int num_iterations = 100;
const int warmup = 10;
int test_block_id = 5;
// Allocate memory
size_t total_bytes = cfg.total_bytes();
void* cpu_strided = nullptr;
CUDA_CHECK(cudaMallocHost(&cpu_strided, total_bytes));
void* cpu_contiguous = nullptr;
size_t gpu_block_bytes = cfg.num_layers * cfg.bytes_per_block();
CUDA_CHECK(cudaMallocHost(&cpu_contiguous, gpu_block_bytes));
void* gpu_data = nullptr;
CUDA_CHECK(cudaMalloc(&gpu_data, gpu_block_bytes));
init_test_data(cpu_strided, total_bytes, 12345);
init_test_data(cpu_contiguous, gpu_block_bytes, 12345);
Timer timer;
double elapsed;
double bandwidth;
// ========================================
// Method A: cudaMemcpy2D with strided layout
// ========================================
size_t spitch = cfg.bytes_per_layer();
size_t dpitch = cfg.bytes_per_block();
size_t width = cfg.bytes_per_block();
size_t height = cfg.num_layers;
uint8_t* src_ptr = static_cast<uint8_t*>(cpu_strided) + test_block_id * cfg.bytes_per_block();
// Warmup
for (int i = 0; i < warmup; i++) {
CUDA_CHECK(cudaMemcpy2D(gpu_data, dpitch, src_ptr, spitch, width, height, cudaMemcpyHostToDevice));
}
CUDA_CHECK(cudaDeviceSynchronize());
// Benchmark
timer.start();
for (int i = 0; i < num_iterations; i++) {
CUDA_CHECK(cudaMemcpy2D(gpu_data, dpitch, src_ptr, spitch, width, height, cudaMemcpyHostToDevice));
}
CUDA_CHECK(cudaDeviceSynchronize());
elapsed = timer.elapsed_ms();
bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0);
std::cout << "\n Method A (cudaMemcpy2D strided):" << std::endl;
std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl;
std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl;
double method_a_bw = bandwidth;
// ========================================
// Method B: cudaMemcpy with contiguous layout (baseline)
// ========================================
// Warmup
for (int i = 0; i < warmup; i++) {
CUDA_CHECK(cudaMemcpy(gpu_data, cpu_contiguous, gpu_block_bytes, cudaMemcpyHostToDevice));
}
CUDA_CHECK(cudaDeviceSynchronize());
// Benchmark
timer.start();
for (int i = 0; i < num_iterations; i++) {
CUDA_CHECK(cudaMemcpy(gpu_data, cpu_contiguous, gpu_block_bytes, cudaMemcpyHostToDevice));
}
CUDA_CHECK(cudaDeviceSynchronize());
elapsed = timer.elapsed_ms();
bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0);
std::cout << "\n Method B (cudaMemcpy contiguous):" << std::endl;
std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl;
std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl;
double method_b_bw = bandwidth;
// ========================================
// Method C: Layer-by-layer copy (simulate PyTorch non-contiguous)
// ========================================
// Warmup
for (int i = 0; i < warmup; i++) {
for (int layer = 0; layer < cfg.num_layers; layer++) {
uint8_t* src_layer = static_cast<uint8_t*>(cpu_strided) +
layer * cfg.bytes_per_layer() +
test_block_id * cfg.bytes_per_block();
uint8_t* dst_layer = static_cast<uint8_t*>(gpu_data) + layer * cfg.bytes_per_block();
CUDA_CHECK(cudaMemcpy(dst_layer, src_layer, cfg.bytes_per_block(), cudaMemcpyHostToDevice));
}
}
CUDA_CHECK(cudaDeviceSynchronize());
// Benchmark
timer.start();
for (int i = 0; i < num_iterations; i++) {
for (int layer = 0; layer < cfg.num_layers; layer++) {
uint8_t* src_layer = static_cast<uint8_t*>(cpu_strided) +
layer * cfg.bytes_per_layer() +
test_block_id * cfg.bytes_per_block();
uint8_t* dst_layer = static_cast<uint8_t*>(gpu_data) + layer * cfg.bytes_per_block();
CUDA_CHECK(cudaMemcpy(dst_layer, src_layer, cfg.bytes_per_block(), cudaMemcpyHostToDevice));
}
}
CUDA_CHECK(cudaDeviceSynchronize());
elapsed = timer.elapsed_ms();
bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0);
std::cout << "\n Method C (layer-by-layer copy):" << std::endl;
std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl;
std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl;
double method_c_bw = bandwidth;
// Summary
std::cout << "\n ========================================" << std::endl;
std::cout << " Performance Summary:" << std::endl;
std::cout << " Method A vs Method B: " << std::setprecision(2) << (method_a_bw / method_b_bw * 100) << "%" << std::endl;
std::cout << " Method A vs Method C: " << std::setprecision(2) << (method_a_bw / method_c_bw) << "x speedup" << std::endl;
std::cout << " ========================================" << std::endl;
// Cleanup
CUDA_CHECK(cudaFreeHost(cpu_strided));
CUDA_CHECK(cudaFreeHost(cpu_contiguous));
CUDA_CHECK(cudaFree(gpu_data));
}
int main() {
std::cout << "=== cudaMemcpy2D Test ===" << std::endl;
// Print CUDA device info
int device;
CUDA_CHECK(cudaGetDevice(&device));
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
std::cout << "Using GPU: " << prop.name << std::endl;
std::cout << "Memory Clock Rate: " << prop.memoryClockRate / 1000 << " MHz" << std::endl;
std::cout << "Memory Bus Width: " << prop.memoryBusWidth << " bits" << std::endl;
std::cout << "Peak Memory Bandwidth: " <<
2.0 * prop.memoryClockRate * (prop.memoryBusWidth / 8) / 1.0e6 << " GB/s" << std::endl;
Config cfg;
// Run tests
bool test1_passed = test_basic_functionality(cfg);
test_performance_benchmark(cfg);
std::cout << "\n=== Test Complete ===" << std::endl;
std::cout << "All tests " << (test1_passed ? "PASSED ✓" : "FAILED ✗") << std::endl;
return test1_passed ? 0 : 1;
}

View File

@@ -1,297 +0,0 @@
"""
Test Attention layer with KV cache offload - N-way Pipeline.
This test demonstrates and verifies the N-way pipeline with:
- Per-slot transfer streams for parallel H2D
- Dedicated compute stream (avoids CUDA default stream implicit sync)
- Pre-load phase + main loop with immediate slot reuse
Key difference from previous test:
- We first pre-fill many chunks to CPU cache
- Then simulate processing a new chunk that loads ALL previous blocks
- This exercises the full N-way pipeline with many blocks in flight
"""
import torch
from nanovllm.layers.attention import Attention
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
from nanovllm.engine.sequence import Sequence
from nanovllm.utils.context import set_context, reset_context
# ============================================================
# Configuration
# ============================================================
NUM_LAYERS = 8
NUM_HEADS = 8
NUM_KV_HEADS = 8
HEAD_DIM = 64
BLOCK_SIZE = 1024
CHUNK_SIZE = 1024
NUM_GPU_SLOTS = 6 # N-way pipeline with 6 slots
NUM_CPU_BLOCKS = 16 # Many blocks to load from CPU
DTYPE = torch.bfloat16
DEVICE = "cuda"
# ============================================================
# Setup
# ============================================================
def create_manager():
manager = HybridKVCacheManager(
num_gpu_slots=NUM_GPU_SLOTS,
num_cpu_blocks=NUM_CPU_BLOCKS,
block_size=BLOCK_SIZE,
)
manager.allocate_cache(
num_layers=NUM_LAYERS,
num_kv_heads=NUM_KV_HEADS,
head_dim=HEAD_DIM,
dtype=DTYPE,
)
return manager
def create_attention_layers(manager):
layers = []
for layer_id in range(NUM_LAYERS):
attn = Attention(
num_heads=NUM_HEADS,
head_dim=HEAD_DIM,
scale=HEAD_DIM ** -0.5,
num_kv_heads=NUM_KV_HEADS,
)
attn.layer_id = layer_id
k_cache, v_cache = manager.get_layer_cache(layer_id)
attn.k_cache = k_cache
attn.v_cache = v_cache
layers.append(attn.to(DEVICE))
return layers
# ============================================================
# Pre-fill CPU cache with random data
# ============================================================
def prefill_cpu_cache(manager, num_blocks):
"""
Fill CPU cache with random KV data for num_blocks blocks.
This simulates having already processed many chunks.
"""
offload_engine = manager.offload_engine
for block_id in range(num_blocks):
# Generate random KV data for all layers
for layer_id in range(NUM_LAYERS):
k_data = torch.randn(
BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM,
dtype=DTYPE, device=DEVICE
)
v_data = torch.randn(
BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM,
dtype=DTYPE, device=DEVICE
)
# Copy to CPU cache
offload_engine.k_cache_cpu[layer_id, block_id].copy_(k_data)
offload_engine.v_cache_cpu[layer_id, block_id].copy_(v_data)
return list(range(num_blocks))
# ============================================================
# Simulate N-way Pipeline (mirrors attention.py logic)
# ============================================================
def simulate_nway_pipeline(
layer_id: int,
q_batched: torch.Tensor,
cpu_block_table: list,
load_slots: list,
offload_engine,
scale: float,
):
"""
Simulate N-way pipeline for a single layer.
This mirrors the logic in Attention._ring_buffer_pipeline_load().
"""
num_blocks = len(cpu_block_table)
num_slots = len(load_slots)
o_acc, lse_acc = None, None
# Phase 1: Pre-load up to num_slots blocks
num_preload = min(num_slots, num_blocks)
torch.cuda.nvtx.range_push(f"Phase1_Preload: L{layer_id}")
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
torch.cuda.nvtx.range_pop()
# Phase 2: Main loop with compute_stream
compute_stream = offload_engine.compute_stream
for block_idx in range(num_blocks):
torch.cuda.nvtx.range_push(f"Block: L{layer_id} B{block_idx}")
current_slot = load_slots[block_idx % num_slots]
# Wait for transfer
offload_engine.wait_slot_layer(current_slot, layer_id)
# Compute on dedicated stream
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"FlashAttn: L{layer_id} B{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=scale,
causal=False,
)
torch.cuda.nvtx.range_pop()
offload_engine.record_slot_compute_done(current_slot, layer_id)
# Start next transfer (reuse current_slot)
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(
current_slot, layer_id, cpu_block_table[next_block_idx]
)
# Merge
with torch.cuda.stream(compute_stream):
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
torch.cuda.nvtx.range_pop()
return o_acc, lse_acc
def simulate_full_forward(layers, manager, cpu_block_table, chunk_size):
"""
Simulate forward pass through all layers, loading previous blocks from CPU.
This is the key test: many blocks loaded via N-way pipeline.
"""
offload_engine = manager.offload_engine
# Current chunk index (we're processing the "next" chunk after all prefilled ones)
current_chunk_idx = len(cpu_block_table)
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
# Random query for attention
q = torch.randn(1, chunk_size, NUM_HEADS, HEAD_DIM, dtype=DTYPE, device=DEVICE)
outputs = []
for layer in layers:
torch.cuda.nvtx.range_push(f"Layer: {layer.layer_id}")
o_acc, lse_acc = simulate_nway_pipeline(
layer.layer_id,
q,
cpu_block_table,
load_slots,
offload_engine,
layer.scale,
)
outputs.append(o_acc)
torch.cuda.nvtx.range_pop()
return outputs
# ============================================================
# Main Test
# ============================================================
print("=" * 60)
print("Test: N-way Pipeline with CPU Offload")
print("=" * 60)
# 1. Setup
print("\n[1] Creating manager and attention layers...")
manager = create_manager()
layers = create_attention_layers(manager)
offload_engine = manager.offload_engine
print(f" - GPU slots: {NUM_GPU_SLOTS}")
print(f" - CPU blocks: {NUM_CPU_BLOCKS}")
print(f" - Per-slot streams: {len(offload_engine.slot_transfer_streams)}")
print(f" - Compute stream: {offload_engine.compute_stream}")
# 2. Pre-fill CPU cache
NUM_PREV_BLOCKS = 12 # Many blocks to load via N-way pipeline
print(f"\n[2] Pre-filling {NUM_PREV_BLOCKS} blocks to CPU cache...")
cpu_block_table = prefill_cpu_cache(manager, NUM_PREV_BLOCKS)
print(f" - CPU blocks filled: {cpu_block_table}")
# 3. Verify pipeline configuration
current_chunk_idx = NUM_PREV_BLOCKS
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
print(f"\n[3] Pipeline configuration for chunk {current_chunk_idx}:")
print(f" - Write slot: {write_slot}")
print(f" - Load slots: {load_slots}")
print(f" - Pipeline depth (N-way): {len(load_slots)}")
assert len(load_slots) == NUM_GPU_SLOTS - 1, f"Expected {NUM_GPU_SLOTS - 1} load slots"
# 4. Warmup
print("\n[4] Warmup (3 iterations)...")
for i in range(3):
outputs = simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE)
torch.cuda.synchronize()
print(f" - Warmup {i+1}/3 done")
# 5. Benchmark
NUM_ITERS = 10
print(f"\n[5] Benchmark ({NUM_ITERS} iterations)...")
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(NUM_ITERS):
torch.cuda.nvtx.range_push(f"Iteration_{i}")
outputs = simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE)
torch.cuda.nvtx.range_pop()
end_event.record()
torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)
# Stats
total_blocks_loaded = NUM_PREV_BLOCKS * NUM_LAYERS * NUM_ITERS
blocks_per_sec = total_blocks_loaded / (elapsed_ms / 1000)
total_tokens = NUM_PREV_BLOCKS * BLOCK_SIZE * NUM_LAYERS * NUM_ITERS
tokens_per_sec = total_tokens / (elapsed_ms / 1000)
print(f"\n[6] Results:")
print(f" - Total time: {elapsed_ms:.2f} ms")
print(f" - Per iteration: {elapsed_ms / NUM_ITERS:.2f} ms")
print(f" - Blocks loaded: {total_blocks_loaded} ({blocks_per_sec:.0f} blocks/s)")
print(f" - Tokens processed: {total_tokens} ({tokens_per_sec:.0f} tok/s)")
# 7. Verification
print("\n[7] Verification:")
assert len(outputs) == NUM_LAYERS, f"Expected {NUM_LAYERS} outputs"
for i, o in enumerate(outputs):
assert o is not None, f"Layer {i} output is None"
assert o.shape == (1, CHUNK_SIZE, NUM_HEADS, HEAD_DIM), f"Layer {i} shape mismatch"
print(" - All layer outputs valid ✓")
print(" - N-way pipeline executed correctly ✓")
# Cleanup
reset_context()
print("\n" + "=" * 60)
print("test_attention_offload: PASSED")
print("=" * 60)

View File

@@ -1,169 +0,0 @@
"""
Test script for chunked attention correctness.
Validates that chunked prefill using flash_attn_with_lse + merge_attention_outputs
produces the same result as full flash_attn_varlen_func.
Scenario: Simulating chunked prefill where we process query chunk by chunk.
For each query chunk i:
- KV contains all tokens from chunk 0 to chunk i
- Previous KV chunks (0 to i-1): full attention (no causal mask)
- Current KV chunk (i): causal attention (diagonal block)
"""
import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_func
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# ============================================================
# Utility Functions
# ============================================================
def compute_chunked_prefill_for_chunk(
q_chunk: torch.Tensor,
kv_chunks: list,
current_chunk_idx: int,
) -> torch.Tensor:
"""
Compute attention for a single query chunk against all KV chunks up to current.
This simulates chunked prefill for query chunk `current_chunk_idx`:
- KV chunks 0 to current_chunk_idx-1: full attention (all previous tokens visible)
- KV chunk current_chunk_idx: causal attention (diagonal block)
Args:
q_chunk: [batch, chunk_size, nheads, headdim] - current query chunk
kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim]
current_chunk_idx: Index of the current chunk being processed
Returns:
out: [batch, chunk_size, nheads, headdim]
"""
accumulated_o = None
accumulated_lse = None
for i in range(current_chunk_idx + 1):
k_chunk, v_chunk = kv_chunks[i]
# Previous chunks: no causal mask (all tokens visible)
# Current chunk (diagonal): causal mask
is_diagonal = (i == current_chunk_idx)
chunk_o, chunk_lse = flash_attn_with_lse(
q_chunk, k_chunk, v_chunk, causal=is_diagonal
)
if accumulated_o is None:
accumulated_o = chunk_o
accumulated_lse = chunk_lse
else:
accumulated_o, accumulated_lse = merge_attention_outputs(
accumulated_o, accumulated_lse,
chunk_o, chunk_lse
)
return accumulated_o
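# --- Illustrative sketch (editor addition): the merge identity assumed above ---
# Two partial attention outputs over disjoint KV sets can be combined exactly
# from their log-sum-exp statistics. This hypothetical helper is NOT the
# nanovllm merge_attention_outputs implementation; it only assumes flash-attn
# conventions: o is [batch, seqlen_q, nheads, headdim], lse is
# [batch, nheads, seqlen_q].
def _sketch_merge_with_lse(o1, lse1, o2, lse2):
    lse_max = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - lse_max)        # unnormalized weight of part 1
    w2 = torch.exp(lse2 - lse_max)        # unnormalized weight of part 2
    lse = lse_max + torch.log(w1 + w2)    # combined log-sum-exp
    # Broadcast the normalized weight over the head/feature dims of o.
    a1 = (w1 / (w1 + w2)).transpose(1, 2).unsqueeze(-1)  # [batch, seqlen_q, nheads, 1]
    return a1 * o1 + (1.0 - a1) * o2, lse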
def compute_reference_causal(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor:
"""
Compute reference causal attention using flash_attn_func.
Args:
q, k, v: [batch, seqlen, nheads, headdim]
Returns:
out: [batch, seqlen, nheads, headdim]
"""
return flash_attn_func(q, k, v, causal=True)
# ============================================================
# Main Test Script
# ============================================================
torch.manual_seed(42)
# Test configurations: (batch, num_chunks, chunk_size, nheads, headdim)
TEST_CASES = [
(1, 4, 256, 8, 128),
(1, 4, 512, 8, 128),
(1, 8, 512, 8, 128),
(1, 32, 1024, 8, 128),
(1, 32, 1024, 32, 128), # More heads
(1, 32, 256, 8, 64), # Smaller head dim
]
DTYPES = [torch.float16, torch.bfloat16]
print("=" * 80)
print("Test: Chunked Prefill Attention vs Reference (flash_attn_func causal)")
print("=" * 80)
print("Simulating chunked prefill: Q chunk attends to all KV chunks up to current")
print(" - Previous KV chunks: full attention (no causal mask)")
print(" - Current KV chunk (diagonal): causal attention")
print()
all_passed = True
for dtype in DTYPES:
print(f"--- dtype: {dtype} ---")
for batch, num_chunks, chunk_size, nheads, headdim in TEST_CASES:
seqlen = num_chunks * chunk_size
# Generate full Q, K, V
q_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
k_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
v_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
# Reference: full causal attention
out_ref = compute_reference_causal(q_full, k_full, v_full)
# Split into chunks
q_chunks = [q_full[:, i*chunk_size:(i+1)*chunk_size] for i in range(num_chunks)]
kv_chunks = [
(k_full[:, i*chunk_size:(i+1)*chunk_size],
v_full[:, i*chunk_size:(i+1)*chunk_size])
for i in range(num_chunks)
]
# Compute chunked prefill for each query chunk
out_chunks = []
for chunk_idx in range(num_chunks):
chunk_out = compute_chunked_prefill_for_chunk(
q_chunks[chunk_idx],
kv_chunks,
chunk_idx,
)
out_chunks.append(chunk_out)
# Concatenate chunked outputs
out_chunked = torch.cat(out_chunks, dim=1)
# Compare
diff = (out_ref - out_chunked).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
# Tolerance: fp16/bf16 have limited precision
tol = 1e-2
passed = max_diff < tol
all_passed = all_passed and passed
status = "PASS" if passed else "FAIL"
print(
f"[{status}] seqlen={seqlen:5d} chunks={num_chunks} "
f"chunk_size={chunk_size:4d} heads={nheads:2d} dim={headdim:3d} "
f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}"
)
print()
print("=" * 80)
assert all_passed, "Some tests failed!"
print("test_chunked_attention: PASSED")

View File

@@ -1,391 +0,0 @@
"""
Correctness test for chunked decode attention.
Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import torch
from random import randint, seed
from typing import Dict, List
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_func
# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
NUM_DECODE_TOKENS = 5
BLOCK_SIZE = 1024
# State
prefill_captures: List[Dict] = []
decode_captures: List[Dict] = []
def make_ones_injection_hook():
"""Inject Q=K=V=1.0 for deterministic testing."""
def hook(module, inputs):
q, k, v = inputs[0], inputs[1], inputs[2]
q_ones = torch.ones_like(q)
k_ones = torch.ones_like(k)
v_ones = torch.ones_like(v)
return (q_ones, k_ones, v_ones) + inputs[3:]
return hook
def make_capture_hook(layer_id: int):
"""Capture Q, K, V, output during inference."""
def hook(module, inputs, output):
ctx = get_context()
q, k, v = inputs
if ctx.is_prefill:
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
prefill_captures.append({
'layer_id': layer_id,
'chunk_idx': chunk_idx,
'q': q.clone().cpu(),
'k': k.clone().cpu(),
'v': v.clone().cpu(),
'output': output.clone().cpu(),
})
else:
decode_step = len([c for c in decode_captures if c['layer_id'] == layer_id])
decode_captures.append({
'layer_id': layer_id,
'decode_step': decode_step,
'q': q.clone().cpu(),
'k': k.clone().cpu(),
'v': v.clone().cpu(),
'output': output.clone().cpu(),
})
return hook
def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
block_size: int, num_prefill_chunks: int) -> torch.Tensor:
"""
Compute reference decode output using CPU KV cache and standard flash attention.
For decode, query attends to:
1. All prefill KV (from CPU cache)
2. All previous decode tokens (from captured decode k, v)
"""
# Get decode capture for this layer and step
decode_cap = None
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
decode_cap = c
break
if decode_cap is None:
return None
# Query: single decode token
q = decode_cap['q'].cuda() # [1, num_heads, head_dim]
q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim]
# Collect all K, V: prefill chunks from captures + decode tokens from captures
# NOTE: We use prefill captures directly instead of CPU cache because
# the CPU block ID may not equal the chunk index.
all_k = []
all_v = []
# 1. Prefill chunks from captures (use captured K/V, not CPU cache)
for cidx in range(num_prefill_chunks):
prefill_cap = None
for c in prefill_captures:
if c['layer_id'] == layer_id and c['chunk_idx'] == cidx:
prefill_cap = c
break
if prefill_cap is not None:
# Use captured K/V directly (guaranteed to be correct layer data)
all_k.append(prefill_cap['k'].cuda())
all_v.append(prefill_cap['v'].cuda())
# 2. Decode tokens from captures (up to and including current step)
for step in range(decode_step + 1):
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == step:
all_k.append(c['k'].cuda())
all_v.append(c['v'].cuda())
break
if not all_k:
return None
# Concatenate all K, V
full_k = torch.cat(all_k, dim=0).unsqueeze(0) # [1, total_len, kv_heads, head_dim]
full_v = torch.cat(all_v, dim=0).unsqueeze(0)
# Run flash attention (non-causal since we explicitly control what KV to include)
output = flash_attn_func(
q_batched, full_k, full_v,
softmax_scale=scale,
causal=False,
)
return output.squeeze(0).squeeze(0).cpu() # [num_heads, head_dim]
# ============================================================
# Main
# ============================================================
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
dtype="float16",
)
# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5
# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
# Pre-hook: inject all ones for Q, K, V
# pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
# hooks.append(pre_hook)
# Post-hook: capture Q, K, V, output
post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
hooks.append(post_hook)
# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS), use_tqdm=False)
# Remove hooks
for hook in hooks:
hook.remove()
# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()
# Calculate number of prefill chunks
num_prefill_chunks = INPUT_LEN // BLOCK_SIZE
# Debug: Compare decode_buffer with captured K/V
print("\n=== DEBUG: Comparing decode_buffer with captured K/V ===")
decode_k_buffer = offload_engine.decode_k_buffer.clone().cpu()
for step in range(NUM_DECODE_TOKENS):
for layer_id in [0, 17, 35]: # Sample a few layers
# Find captured K for this step and layer
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == step:
captured_k = c['k'].squeeze(0) # [kv_heads, head_dim]
buffer_k = decode_k_buffer[layer_id, step] # [kv_heads, head_dim]
diff = (captured_k - buffer_k).abs().max().item()
print(f"Step {step}, Layer {layer_id}: captured vs buffer max_diff={diff:.6f}")
break
# Debug: Verify that decode_buffer slices match concatenated captures
print("\n=== DEBUG: Verifying decode_buffer slices ===")
for layer_id in [0]:
for decode_step in [1, 2]: # Check steps that use multiple tokens
# Build expected slice from captures
expected_k_list = []
for step in range(decode_step + 1):
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == step:
expected_k_list.append(c['k'].squeeze(0)) # [kv_heads, head_dim]
break
if expected_k_list:
expected_k = torch.stack(expected_k_list, dim=0) # [num_tokens, kv_heads, head_dim]
buffer_slice = decode_k_buffer[layer_id, 0:decode_step+1]
diff = (expected_k - buffer_slice).abs().max().item()
print(f"Decode step {decode_step}, Layer {layer_id}: buffer slice vs expected max_diff={diff:.6f}")
# Print first values
print(f" Buffer[0,0,0]={buffer_slice[0,0,0].item():.6f}, Expected[0,0,0]={expected_k[0,0,0].item():.6f}")
if decode_step >= 1:
print(f" Buffer[1,0,0]={buffer_slice[1,0,0].item():.6f}, Expected[1,0,0]={expected_k[1,0,0].item():.6f}")
# Debug: Print expected K value for block 0, layer 0 (to compare with actual loading)
print("\n=== DEBUG: Expected K values for block 0, layer 0 ===")
for c in prefill_captures:
if c['layer_id'] == 0 and c['chunk_idx'] == 0:
print(f"Captured K[0,0,0] for layer 0, chunk 0: {c['k'][0,0,0].item():.6f}")
break
print(f"CPU cache K[0,0,0,0,0] for layer 0, block 0: {k_cache_cpu[0,0,0,0,0].item():.6f}")
# Debug: Compare CPU cache with captured prefill K/V
print("\n=== DEBUG: Comparing CPU cache with captured prefill K/V ===")
for chunk_idx in [0, 7, 15]: # Sample a few chunks
for layer_id in [0, 17, 35]: # Sample a few layers
# Find captured K for this chunk and layer
for c in prefill_captures:
if c['layer_id'] == layer_id and c['chunk_idx'] == chunk_idx:
captured_k = c['k'] # [seq_len, kv_heads, head_dim]
cpu_cache_k = k_cache_cpu[layer_id, chunk_idx, :captured_k.shape[0]]
diff = (captured_k - cpu_cache_k).abs().max().item()
print(f"Chunk {chunk_idx}, Layer {layer_id}: captured vs CPU cache max_diff={diff:.6f}")
break
# Debug: Get cpu_block_table to check order
kvcache_manager = llm.model_runner.kvcache_manager
# Find the sequence (it should still exist)
from nanovllm.engine.sequence import Sequence
for attr_name in ['sequences', '_sequences', 'active_sequences']:
if hasattr(kvcache_manager, attr_name):
print(f"Found {attr_name}")
break
# Try to get cpu_block_table through a different way
print(f"\n=== DEBUG: CPU block order ===")
# For each prefill capture, check which CPU block it ended up in
for chunk_idx in range(num_prefill_chunks):
for c in prefill_captures:
if c['layer_id'] == 0 and c['chunk_idx'] == chunk_idx:
# Check if this chunk's K matches any CPU block
captured_k_first = c['k'][0, 0, 0].item()
for block_id in range(num_prefill_chunks):
cpu_k_first = k_cache_cpu[0, block_id, 0, 0, 0].item()
if abs(captured_k_first - cpu_k_first) < 1e-6:
print(f"Chunk {chunk_idx} -> CPU block {block_id}")
break
break
# Debug: Check reference vs actual for decode steps 0 and 1
# Also compute partial references (prefill only, decode only) to isolate the bug
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
for decode_step in [0, 1]:
print(f"\n=== DEBUG: Reference vs Actual for layer 0, decode {decode_step} ===")
layer_id = 0
# Find the capture
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
q = c['q'].cuda() # [1, num_heads, head_dim]
q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim]
# Build prefill K/V per-block for block-by-block reference
prefill_k_blocks = []
prefill_v_blocks = []
for cidx in range(num_prefill_chunks):
for pc in prefill_captures:
if pc['layer_id'] == layer_id and pc['chunk_idx'] == cidx:
prefill_k_blocks.append(pc['k'].cuda().unsqueeze(0)) # [1, block_size, kv_heads, head_dim]
prefill_v_blocks.append(pc['v'].cuda().unsqueeze(0))
break
# Build decode K/V
decode_k_list = []
decode_v_list = []
for step in range(decode_step + 1):
for dc in decode_captures:
if dc['layer_id'] == layer_id and dc['decode_step'] == step:
decode_k_list.append(dc['k'].cuda())
decode_v_list.append(dc['v'].cuda())
break
full_prefill_k = torch.cat([kb.squeeze(0) for kb in prefill_k_blocks], dim=0).unsqueeze(0)
full_prefill_v = torch.cat([vb.squeeze(0) for vb in prefill_v_blocks], dim=0).unsqueeze(0)
full_decode_k = torch.cat(decode_k_list, dim=0).unsqueeze(0)
full_decode_v = torch.cat(decode_v_list, dim=0).unsqueeze(0)
full_k = torch.cat([full_prefill_k, full_decode_k], dim=1)
full_v = torch.cat([full_prefill_v, full_decode_v], dim=1)
print(f"Q shape: {q_batched.shape}")
print(f"Prefill K shape: {full_prefill_k.shape}")
print(f"Decode K shape: {full_decode_k.shape}")
print(f"Full K shape: {full_k.shape}")
print(f"Total tokens: prefill={num_prefill_chunks * BLOCK_SIZE}, decode={decode_step + 1}")
# Reference output (single attention over all)
ref_output = flash_attn_func(
q_batched, full_k, full_v,
softmax_scale=scale,
causal=False,
)
# Chunked reference: prefill attention + decode attention + merge
prefill_o, prefill_lse = flash_attn_with_lse(
q_batched, full_prefill_k, full_prefill_v,
softmax_scale=scale,
causal=False,
)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, full_decode_k, full_decode_v,
softmax_scale=scale,
causal=False,
)
chunked_output, _ = merge_attention_outputs(prefill_o, prefill_lse, decode_o, decode_lse)
# Block-by-block reference (simulating ring buffer pipeline)
block_o_acc, block_lse_acc = None, None
for bidx, (kb, vb) in enumerate(zip(prefill_k_blocks, prefill_v_blocks)):
o_blk, lse_blk = flash_attn_with_lse(q_batched, kb, vb, softmax_scale=scale, causal=False)
if block_o_acc is None:
block_o_acc, block_lse_acc = o_blk, lse_blk
else:
block_o_acc, block_lse_acc = merge_attention_outputs(block_o_acc, block_lse_acc, o_blk, lse_blk)
# Compare block-by-block vs single
block_vs_single_diff = (block_o_acc - prefill_o).abs().max().item()
print(f"Block-by-block vs Single max_diff: {block_vs_single_diff:.6f}")
# Compare full reference vs chunked reference
ref_vs_chunked_diff = (ref_output - chunked_output).abs().max().item()
print(f"Reference vs Chunked-reference max_diff: {ref_vs_chunked_diff:.6f}")
ref_output = ref_output.squeeze(0).squeeze(0).cpu()
chunked_output_cpu = chunked_output.squeeze(0).squeeze(0).cpu()
# Actual output
actual_output = c['output'].squeeze(0)
if actual_output.dim() == 3:
actual_output = actual_output.squeeze(0)
diff_ref = (actual_output - ref_output).abs()
diff_chunked = (actual_output - chunked_output_cpu).abs()
print(f"Actual vs Reference max_diff: {diff_ref.max().item():.6f}")
print(f"Actual vs Chunked-ref max_diff: {diff_chunked.max().item():.6f}")
break
print()
# Verify decode outputs
all_passed = True
for c in decode_captures:
layer_id = c['layer_id']
decode_step = c['decode_step']
ref_output = compute_decode_reference(
layer_id, decode_step, scale,
k_cache_cpu, v_cache_cpu, BLOCK_SIZE, num_prefill_chunks
)
if ref_output is None:
continue
actual_output = c['output'].squeeze(0)
if actual_output.dim() == 3:
actual_output = actual_output.squeeze(0)
diff = (actual_output - ref_output).abs()
max_diff = diff.max().item()
passed = max_diff < 1e-1
all_passed = all_passed and passed
if not passed:
print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")
print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}")

View File

@@ -1,196 +0,0 @@
"""
Correctness test for chunked prefill attention.
Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import torch
from random import randint, seed
from typing import Dict, List
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
BLOCK_SIZE = 1024
# State - capture Q and output for each (layer, chunk)
captures: List[Dict] = []
def make_ones_injection_hook():
"""Inject Q=K=V=1.0 for deterministic testing."""
def hook(module, inputs):
ctx = get_context()
if not ctx.is_prefill:
return inputs
q, k, v = inputs[0], inputs[1], inputs[2]
q_ones = torch.ones_like(q)
k_ones = torch.ones_like(k)
v_ones = torch.ones_like(v)
return (q_ones, k_ones, v_ones) + inputs[3:]
return hook
def make_capture_hook(layer_id: int):
"""Capture Q and output during prefill."""
def hook(module, inputs, output):
ctx = get_context()
if not ctx.is_prefill:
return
q, k, v = inputs
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
captures.append({
'layer_id': layer_id,
'chunk_idx': chunk_idx,
'q': q.clone().cpu(),
'k': k.clone().cpu(),
'v': v.clone().cpu(),
'output': output.clone().cpu(),
})
return hook
def compute_reference(layer_id: int, chunk_idx: int, scale: float,
k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
block_size: int) -> torch.Tensor:
"""
Compute reference output using CPU KV cache and standard flash attention.
Concatenates all Q, K, V from chunks 0..chunk_idx and runs causal attention,
then extracts output for the current chunk.
"""
# Get all captures for this layer up to chunk_idx
layer_captures = [c for c in captures
if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])
if not layer_captures:
return None
# Collect Q from captures, K/V from CPU cache
all_q = []
all_k = []
all_v = []
chunk_lengths = []
for c in layer_captures:
cidx = c['chunk_idx']
q = c['q'].cuda() # [seqlen, nheads, headdim]
all_q.append(q)
chunk_lengths.append(q.shape[0])
# Get K, V from CPU cache (already offloaded during prefill)
# CPU cache shape: [num_layers, num_blocks, block_size, kv_heads, head_dim]
k = k_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
v = v_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
all_k.append(k)
all_v.append(v)
# Concatenate
full_q = torch.cat(all_q, dim=0)
full_k = torch.cat(all_k, dim=0)
full_v = torch.cat(all_v, dim=0)
total_len = full_q.shape[0]
# Run standard causal flash attention
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')
full_o = flash_attn_varlen_func(
full_q, full_k, full_v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_len,
max_seqlen_k=total_len,
softmax_scale=scale,
causal=True,
)
# Extract output for current chunk
start_pos = sum(chunk_lengths[:-1])
end_pos = sum(chunk_lengths)
return full_o[start_pos:end_pos].cpu()
# ============================================================
# Main
# ============================================================
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
dtype="float16",
)
# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5
# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
# Pre-hook: inject all ones for Q, K, V
# pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
# hooks.append(pre_hook)
# Post-hook: capture Q, K, V, output
post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
hooks.append(post_hook)
# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False)
# Remove hooks
for hook in hooks:
hook.remove()
# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()
# Verify: compare actual output with reference computed from CPU cache
all_passed = True
num_chunks = INPUT_LEN // BLOCK_SIZE
for c in captures:
layer_id = c['layer_id']
chunk_idx = c['chunk_idx']
# Skip chunk 0 (no previous KV to load)
if chunk_idx == 0:
continue
ref_output = compute_reference(layer_id, chunk_idx, scale, k_cache_cpu, v_cache_cpu, BLOCK_SIZE)
if ref_output is None:
continue
actual_output = c['output']
diff = (actual_output - ref_output).abs()
max_diff = diff.max().item()
passed = max_diff < 1e-1 # float16 tolerance
all_passed = all_passed and passed
if not passed:
print(f"[FAIL] Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
__import__('pdb').set_trace()
print(f"test_chunked_prefill_hook: {'PASSED' if all_passed else 'FAILED'}")

View File

@@ -1,137 +0,0 @@
"""
Test KV cache offload correctness using debug hooks.
Injects distinctive K/V values, verifies loaded tensors match expected patterns.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import inspect
from random import randint, seed
from typing import Dict, List
import torch
from torch import Tensor
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 32 * 1024
NUM_GPU_BLOCKS = 4
INPUT_LEN = 32 * 1024
BLOCK_SIZE = 1024
# State
load_log: List[Dict] = []
current_chunk: List[int] = [0]
def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None:
"""Record loaded tensor values for layer 0."""
if layer_id != 0:
return
# Go up the stack to find kvcache_manager and print k_cache_gpu[*][0,0,0] for all slots
frame = inspect.currentframe()
try:
caller_frame = frame.f_back
if caller_frame is not None:
local_vars = caller_frame.f_locals
if 'self' in local_vars:
self_obj = local_vars['self']
if hasattr(self_obj, 'k_cache_gpu'):
num_slots = self_obj.k_cache_gpu.shape[0]
vals = []
for i in range(num_slots):
v = self_obj.k_cache_gpu[i][0,0,0].item()
if i == slot_idx:
vals.append(f"[{v}]")
else:
vals.append(str(v))
print(f"[DEBUG] k_cache_gpu[0..{num_slots-1}][0,0,0] = [{', '.join(vals)}]")
finally:
del frame
load_log.append({
"chunk_idx": current_chunk[0],
"cpu_block_id": cpu_block_id,
"k_value": k.float().mean().item(),
})
def make_pattern_injection_hook(layer_id):
"""Inject K = chunk_idx + 1, V = -(chunk_idx + 1) for layer 0."""
def hook(module, inputs):
ctx = get_context()
if not ctx.is_prefill or layer_id != 0:
return inputs
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
current_chunk[0] = chunk_idx
if len(inputs) >= 3:
q, k, v = inputs[0], inputs[1], inputs[2]
k_new = torch.full_like(k, float(chunk_idx + 1))
v_new = torch.full_like(v, float(-(chunk_idx + 1)))
return (q, k_new, v_new) + inputs[3:]
return inputs
return hook
def verify() -> bool:
"""Verify blocks loaded in correct order with correct K values."""
chunk_loads: Dict[int, List[tuple]] = {}
for log in load_log:
chunk = log["chunk_idx"]
if chunk not in chunk_loads:
chunk_loads[chunk] = []
chunk_loads[chunk].append((log["cpu_block_id"], log["k_value"]))
for chunk, loads in chunk_loads.items():
expected_blocks = list(range(chunk))
actual_blocks = [b for b, _ in loads]
k_values = [k for _, k in loads]
expected_k = [float(b + 1) for b in expected_blocks]
if actual_blocks != expected_blocks:
return False
if not all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k)):
return False
return True
# Main
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
dtype="float16",
)
offload_engine = llm.model_runner.kvcache_manager.offload_engine
offload_engine.enable_debug_mode()
offload_engine.register_debug_hook(debug_load_hook)
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
hooks.append(decoder_layer.self_attn.attn.register_forward_pre_hook(
make_pattern_injection_hook(layer_idx)
))
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1), use_tqdm=False)
for hook in hooks:
hook.remove()
offload_engine.remove_debug_hook(debug_load_hook)
offload_engine.disable_debug_mode()
# Verify
num_chunks = INPUT_LEN // BLOCK_SIZE
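# With INPUT_LEN = 32K tokens and BLOCK_SIZE = 1024 there are 32 chunks; chunk c
# loads its c preceding blocks, so layer 0 should record 0 + 1 + ... + 31 = 496 loads.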
expected_loads = num_chunks * (num_chunks - 1) // 2
passed = len(load_log) == expected_loads and verify()
print(f"test_debug_verification: {'PASSED' if passed else 'FAILED'}")

View File

@@ -1,276 +0,0 @@
"""
Test script for flash_attn_with_kvcache based chunked prefill.
Verifies that chunked prefill produces identical results to full attention.
"""
import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache
def chunk_prefill(q_full, k_full, v_full, k_cache, v_cache, cache_seqlens, chunk_size):
"""
Chunked prefill using flash_attn_with_kvcache.
Args:
q_full, k_full, v_full: [batch, total_seq_len, heads, head_dim]
k_cache, v_cache: [batch, max_seq_len, kv_heads, head_dim]
cache_seqlens: [batch] - current cache lengths
chunk_size: size of each chunk
Returns:
output: [batch, total_seq_len, heads, head_dim]
"""
total_len = q_full.shape[1]
outputs = []
for start in range(0, total_len, chunk_size):
end = min(start + chunk_size, total_len)
q_chunk = q_full[:, start:end]
k_chunk = k_full[:, start:end]
v_chunk = v_full[:, start:end]
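# flash_attn_with_kvcache writes k_chunk/v_chunk into k_cache/v_cache at offset
# cache_seqlens and attends over the cached prefix plus the new chunk, so
# advancing cache_seqlens after each chunk gives causal chunked prefill.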
out = flash_attn_with_kvcache(
q_chunk,
k_cache,
v_cache,
k=k_chunk,
v=v_chunk,
cache_seqlens=cache_seqlens,
causal=True,
)
outputs.append(out)
cache_seqlens += (end - start)
return torch.cat(outputs, dim=1)
def reference_attention(q, k, v):
"""Standard flash attention as reference."""
return flash_attn_func(q, k, v, causal=True)
def test_chunked_prefill_correctness():
"""Test that chunked prefill matches full attention."""
batch_size = 1
num_heads = 32
num_kv_heads = 8 # GQA
head_dim = 128
max_seq_len = 131072 # 128K
test_configs = [
(1024, 256), # 1K tokens, 256 chunk
(2048, 512), # 2K tokens, 512 chunk
(4096, 1024), # 4K tokens, 1K chunk
(4096, 2048), # 4K tokens, 2K chunk (2 chunks)
(8192, 2048), # 8K tokens, 2K chunk (4 chunks)
(16384, 4096), # 16K tokens, 4K chunk
(32768, 4096), # 32K tokens, 4K chunk
(65536, 8192), # 64K tokens, 8K chunk
(131072, 8192), # 128K tokens, 8K chunk (16 chunks)
]
for seq_len, chunk_size in test_configs:
print(f"\nTesting seq_len={seq_len}, chunk_size={chunk_size}...")
# Generate random input
torch.manual_seed(42)
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
# Expand K/V for non-GQA reference
k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
# Reference: full attention
ref_out = reference_attention(q, k_expanded, v_expanded)
# Chunked prefill with KV cache
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
chunked_out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)
# Compare
max_diff = (ref_out - chunked_out).abs().max().item()
mean_diff = (ref_out - chunked_out).abs().mean().item()
# Verify cache was filled correctly
assert cache_seqlens[0].item() == seq_len, f"Cache seqlen mismatch: {cache_seqlens[0].item()} != {seq_len}"
# Check K/V cache content
k_cache_diff = (k_cache[:, :seq_len] - k).abs().max().item()
v_cache_diff = (v_cache[:, :seq_len] - v).abs().max().item()
print(f" Output max_diff: {max_diff:.6f}, mean_diff: {mean_diff:.6f}")
print(f" KV cache diff: k={k_cache_diff:.6f}, v={v_cache_diff:.6f}")
# Tolerance for fp16
tolerance = 1e-2
if max_diff < tolerance:
print(f" PASSED")
else:
print(f" FAILED (max_diff {max_diff:.6f} >= {tolerance})")
return False
return True
def test_incremental_decode():
"""Test that decode after chunked prefill works correctly."""
batch_size = 1
num_heads = 32
num_kv_heads = 8
head_dim = 128
max_seq_len = 8192
prefill_len = 2048
chunk_size = 512
num_decode_steps = 10
print(f"\nTesting incremental decode after chunked prefill...")
print(f" Prefill: {prefill_len} tokens, chunk_size={chunk_size}")
print(f" Decode: {num_decode_steps} steps")
torch.manual_seed(42)
# Prefill phase
q_prefill = torch.randn(batch_size, prefill_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
# Run chunked prefill
prefill_out = chunk_prefill(q_prefill, k_prefill, v_prefill,
k_cache, v_cache, cache_seqlens, chunk_size)
print(f" After prefill: cache_seqlens={cache_seqlens[0].item()}")
# Decode phase - one token at a time
for step in range(num_decode_steps):
q_decode = torch.randn(batch_size, 1, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
decode_out = flash_attn_with_kvcache(
q_decode,
k_cache,
v_cache,
k=k_decode,
v=v_decode,
cache_seqlens=cache_seqlens,
causal=True,
)
cache_seqlens += 1
assert decode_out.shape == (batch_size, 1, num_heads, head_dim)
expected_len = prefill_len + num_decode_steps
actual_len = cache_seqlens[0].item()
print(f" After decode: cache_seqlens={actual_len}")
if actual_len == expected_len:
print(f" PASSED")
return True
else:
print(f" FAILED: expected {expected_len}, got {actual_len}")
return False
def test_batch_processing():
"""Test chunked prefill with batch > 1."""
batch_size = 4
num_heads = 32
num_kv_heads = 8
head_dim = 128
max_seq_len = 4096
seq_len = 2048
chunk_size = 512
print(f"\nTesting batch processing (batch_size={batch_size})...")
torch.manual_seed(42)
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)
# Verify all batches have correct cache length
assert (cache_seqlens == seq_len).all(), f"Cache seqlens mismatch: {cache_seqlens}"
assert out.shape == (batch_size, seq_len, num_heads, head_dim)
# Compare with reference for each batch item
k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
ref_out = reference_attention(q, k_expanded, v_expanded)
max_diff = (ref_out - out).abs().max().item()
print(f" Output shape: {out.shape}")
print(f" Max diff vs reference: {max_diff:.6f}")
if max_diff < 1e-2:
print(f" PASSED")
return True
else:
print(f" FAILED")
return False
# ============================================================
# Main Test Script
# ============================================================
if __name__ == "__main__":
print("=" * 60)
print("Testing flash_attn_with_kvcache chunked prefill")
print("=" * 60)
all_passed = True
all_passed &= test_chunked_prefill_correctness()
all_passed &= test_incremental_decode()
all_passed &= test_batch_processing()
print("\n" + "=" * 60)
if all_passed:
print("test_flash_attn_kvcache: ALL TESTS PASSED")
else:
print("test_flash_attn_kvcache: SOME TESTS FAILED")
print("=" * 60)

View File

@@ -1,104 +0,0 @@
"""
Test FlashInfer chunked attention with CPU offload.
Uses single_prefill_with_kv_cache + merge_state for chunked KV processing.
"""
import torch
import flashinfer
# ============================================================
# Core Functions
# ============================================================
def chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk_size, kv_chunk_size):
"""
Chunked causal attention with KV on CPU.
q: [seq_q, num_heads, head_dim] on GPU
k_cpu, v_cpu: [seq_kv, num_kv_heads, head_dim] on CPU
"""
seq_q = q.shape[0]
seq_kv = k_cpu.shape[0]
final_outputs = []
for q_start in range(0, seq_q, q_chunk_size):
q_end = min(q_start + q_chunk_size, seq_q)
q_chunk = q[q_start:q_end]
merged_output = None
merged_lse = None
for kv_start in range(0, seq_kv, kv_chunk_size):
kv_end = min(kv_start + kv_chunk_size, seq_kv)
if kv_start >= q_end:
continue
k_chunk = k_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
v_chunk = v_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
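# KV chunks starting at or after q_end were skipped above; chunks ending at or
# before q_start are fully visible to every query in this chunk (causal=False),
# and only the overlapping diagonal chunk keeps the causal mask.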
causal = not (kv_end <= q_start)
partial_out, partial_lse = flashinfer.single_prefill_with_kv_cache(
q_chunk, k_chunk, v_chunk,
causal=causal,
return_lse=True,
)
if merged_output is None:
merged_output, merged_lse = partial_out, partial_lse
else:
merged_output, merged_lse = flashinfer.merge_state(
merged_output, merged_lse,
partial_out, partial_lse,
)
final_outputs.append(merged_output)
return torch.cat(final_outputs, dim=0)
# ============================================================
# Main Test Script
# ============================================================
print("=" * 60)
print("Testing FlashInfer chunked attention with CPU offload")
print("=" * 60)
num_heads = 32
num_kv_heads = 8
head_dim = 128
test_configs = [
(32768, 8192, 8192), # 32K tokens
(65536, 8192, 8192), # 64K tokens
(131072, 16384, 16384), # 128K tokens
# (262144, 16384, 16384), # 256K tokens (slow)
# (524288, 16384, 16384), # 512K tokens (slow)
]
for seq_len, q_chunk, kv_chunk in test_configs:
torch.manual_seed(42)
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device='cuda')
k_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
v_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
# Chunked result
chunked_out = chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk, kv_chunk)
# Reference
k_gpu = k_cpu.to('cuda')
v_gpu = v_cpu.to('cuda')
ref_out = flashinfer.single_prefill_with_kv_cache(q, k_gpu, v_gpu, causal=True)
max_diff = (ref_out - chunked_out).abs().max().item()
mean_diff = (ref_out - chunked_out).abs().mean().item()
num_chunks = (seq_len + q_chunk - 1) // q_chunk
assert max_diff < 1e-2, f"FAILED: max_diff={max_diff:.6f}"
print(f"seq={seq_len//1024}K, chunks={num_chunks}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
print("\ntest_flashinfer_merge: PASSED")

View File

@@ -8,155 +8,12 @@ sequences longer than ~200 tokens. Use --no-offload for correctness testing.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
# ============================================================
# Needle Test Generator
# ============================================================
def generate_needle_prompt(
tokenizer,
target_length: int,
needle_position: float = 0.5,
needle_value: str = "7492",
use_chat_template: bool = True,
) -> tuple[str, str]:
"""
Generate a needle-in-haystack prompt of approximately target_length tokens.
Args:
tokenizer: HuggingFace tokenizer for length estimation
target_length: Target total sequence length in tokens
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
needle_value: The secret value to hide in the haystack
use_chat_template: Whether to use chat template for instruct models
Returns:
(prompt, expected_answer): The full prompt and the expected needle value
"""
# Haystack filler paragraphs (various topics to create realistic context)
haystack_paragraphs = [
"The weather today is quite pleasant with clear skies and moderate temperatures. "
"Many people are enjoying outdoor activities in the park. "
"Birds are singing in the trees and children are playing on the swings. ",
"In the world of technology, new innovations continue to emerge every day. "
"Researchers are working on advanced algorithms and computing systems. "
"The future of artificial intelligence looks promising with many breakthroughs. ",
"The history of human civilization spans thousands of years. "
"Ancient cultures developed writing, mathematics, and astronomy. "
"Trade routes connected distant lands and facilitated cultural exchange. ",
"Modern cooking combines traditional techniques with new ingredients. "
"Chefs around the world experiment with flavors and presentations. "
"Food brings people together and creates memorable experiences. ",
"The ocean covers more than seventy percent of Earth's surface. "
"Marine ecosystems support an incredible diversity of life forms. "
"Scientists continue to discover new species in the deep sea. ",
"Music has been a part of human culture since prehistoric times. "
"Different genres evolved across various regions and time periods. "
"Today, people can access millions of songs through digital platforms. ",
"Space exploration has revealed many secrets about our universe. "
"Telescopes can observe galaxies billions of light years away. "
"Future missions aim to establish human presence on other planets. ",
"The study of languages reveals patterns in human cognition. "
"Linguists analyze grammar, semantics, and phonetics across cultures. "
"Language continues to evolve with new words and expressions. ",
]
# The needle sentence
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
# Question at the end
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
# Estimate tokens for fixed parts
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
# Buffer for chat template, special tokens, etc.
overhead_tokens = 100 if use_chat_template else 50
# Available tokens for haystack
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
if haystack_target_tokens < 100:
raise ValueError(f"target_length {target_length} is too short for needle test")
# Build haystack by repeating paragraphs
haystack_parts = []
current_tokens = 0
para_idx = 0
while current_tokens < haystack_target_tokens:
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
if current_tokens + para_tokens > haystack_target_tokens:
break
haystack_parts.append(para)
current_tokens += para_tokens
para_idx += 1
# Calculate needle insertion point
needle_idx = int(len(haystack_parts) * needle_position)
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
# Insert needle
haystack_parts.insert(needle_idx, needle)
# Assemble prompt
full_text = "".join(haystack_parts)
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
# Use chat template for instruct models
# For Qwen3, add /no_think to disable thinking mode
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
messages = [
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
else:
# Raw text format for base models
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
prompt = full_text + question
# Verify length
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
return prompt, needle_value
def check_needle_answer(output_text: str, expected: str) -> bool:
"""Check if the model output contains the expected needle value."""
import re
# Clean output - remove special tokens and whitespace
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
output_clean = ' '.join(output_clean.split()).lower()
expected_clean = expected.strip().lower()
# Check if expected value appears in output
# Also try to find it as a standalone number
if expected_clean in output_clean:
return True
# Try to extract numbers and check if expected is among them
numbers = re.findall(r'\d+', output_clean)
return expected_clean in numbers
from nanovllm.config import SparsePolicyType
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
@@ -168,10 +25,14 @@ def run_needle_test(
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
block_size: int = 1024,
needle_position: float = 0.5,
needle_value: str = "7492",
max_new_tokens: int = 32,
enable_cpu_offload: bool = False,
enable_quest: bool = False,
sparse_topk: int = 8,
sparse_threshold: int = 4,
verbose: bool = True,
) -> bool:
"""
@@ -182,15 +43,21 @@ def run_needle_test(
max_model_len: Maximum model context length
input_len: Target input sequence length
num_gpu_blocks: Number of GPU blocks for offload
block_size: KV cache block size
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
enable_cpu_offload: Enable CPU offload mode
enable_quest: Enable Quest sparse attention (decode-only Top-K)
sparse_topk: Top-K blocks for Quest
sparse_threshold: Apply sparse only when blocks > threshold
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL
if verbose:
print(f"\n{'='*60}")
print(f"Needle-in-Haystack Test")
@@ -198,9 +65,12 @@ def run_needle_test(
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Block size: {block_size}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"CPU offload: {enable_cpu_offload}")
if enable_cpu_offload:
print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})")
print(f"{'='*60}\n")
# 1. Initialize LLM
@@ -209,9 +79,13 @@ def run_needle_test(
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm_kwargs["sparse_policy"] = sparse_policy
llm_kwargs["sparse_topk_blocks"] = sparse_topk
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
llm = LLM(model_path, **llm_kwargs)
@@ -263,7 +137,7 @@ if __name__ == "__main__":
parser.add_argument(
"--max-model-len",
type=int,
default=32 * 1024,
default=128 * 1024,
help="Maximum model context length"
)
parser.add_argument(
@@ -278,6 +152,12 @@ if __name__ == "__main__":
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--block-size",
type=int,
default=1024,
help="KV cache block size"
)
parser.add_argument(
"--needle-position",
type=float,
@@ -301,6 +181,23 @@ if __name__ == "__main__":
action="store_true",
help="Enable CPU offload (has known bug for long sequences)"
)
parser.add_argument(
"--enable-quest",
action="store_true",
help="Enable Quest sparse attention (decode-only Top-K selection)"
)
parser.add_argument(
"--sparse-topk",
type=int,
default=8,
help="Top-K blocks for Quest sparse attention"
)
parser.add_argument(
"--sparse-threshold",
type=int,
default=4,
help="Apply sparse only when blocks > threshold"
)
args = parser.parse_args()
passed = run_needle_test(
@@ -308,10 +205,14 @@ if __name__ == "__main__":
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size,
needle_position=args.needle_position,
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,
enable_cpu_offload=args.enable_offload,
enable_quest=args.enable_quest,
sparse_topk=args.sparse_topk,
sparse_threshold=args.sparse_threshold,
verbose=True,
)
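For reference, a hypothetical direct call that exercises the new Quest options could look like the sketch below (model path and sizes are placeholders; the argparse flags above remain the intended entry point, and the sparse kwargs are only forwarded to LLM when CPU offload is enabled):

# Sketch only (not part of the diff): calling run_needle_test with the new Quest options.
import os

passed = run_needle_test(
    model_path=os.path.expanduser("~/models/Qwen3-0.6B/"),  # placeholder model path
    max_model_len=128 * 1024,
    input_len=32 * 1024,
    num_gpu_blocks=4,
    block_size=1024,
    needle_position=0.5,
    needle_value="7492",
    max_new_tokens=32,
    enable_cpu_offload=True,   # Quest kwargs are only passed through in offload mode
    enable_quest=True,         # decode-only Top-K block selection
    sparse_topk=8,
    sparse_threshold=4,
    verbose=True,
)
assert passed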


@@ -8,148 +8,9 @@ Uses standard HuggingFace inference (no custom KV cache, no offload).
import os
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ============================================================
# Needle Test Generator
# ============================================================
def generate_needle_prompt(
tokenizer,
target_length: int,
needle_position: float = 0.5,
needle_value: str = "7492",
use_chat_template: bool = True,
) -> tuple[str, str]:
"""
Generate a needle-in-haystack prompt of approximately target_length tokens.
Args:
tokenizer: HuggingFace tokenizer for length estimation
target_length: Target total sequence length in tokens
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
needle_value: The secret value to hide in the haystack
use_chat_template: Whether to use chat template for instruct models
Returns:
(prompt, expected_answer): The full prompt and the expected needle value
"""
# Haystack filler paragraphs (various topics to create realistic context)
haystack_paragraphs = [
"The weather today is quite pleasant with clear skies and moderate temperatures. "
"Many people are enjoying outdoor activities in the park. "
"Birds are singing in the trees and children are playing on the swings. ",
"In the world of technology, new innovations continue to emerge every day. "
"Researchers are working on advanced algorithms and computing systems. "
"The future of artificial intelligence looks promising with many breakthroughs. ",
"The history of human civilization spans thousands of years. "
"Ancient cultures developed writing, mathematics, and astronomy. "
"Trade routes connected distant lands and facilitated cultural exchange. ",
"Modern cooking combines traditional techniques with new ingredients. "
"Chefs around the world experiment with flavors and presentations. "
"Food brings people together and creates memorable experiences. ",
"The ocean covers more than seventy percent of Earth's surface. "
"Marine ecosystems support an incredible diversity of life forms. "
"Scientists continue to discover new species in the deep sea. ",
"Music has been a part of human culture since prehistoric times. "
"Different genres evolved across various regions and time periods. "
"Today, people can access millions of songs through digital platforms. ",
"Space exploration has revealed many secrets about our universe. "
"Telescopes can observe galaxies billions of light years away. "
"Future missions aim to establish human presence on other planets. ",
"The study of languages reveals patterns in human cognition. "
"Linguists analyze grammar, semantics, and phonetics across cultures. "
"Language continues to evolve with new words and expressions. ",
]
# The needle sentence
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
# Estimate tokens for fixed parts
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
# Buffer for chat template, special tokens, etc.
overhead_tokens = 100 if use_chat_template else 50
# Available tokens for haystack
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
if haystack_target_tokens < 100:
raise ValueError(f"target_length {target_length} is too short for needle test")
# Build haystack by repeating paragraphs
haystack_parts = []
current_tokens = 0
para_idx = 0
while current_tokens < haystack_target_tokens:
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
if current_tokens + para_tokens > haystack_target_tokens:
break
haystack_parts.append(para)
current_tokens += para_tokens
para_idx += 1
# Calculate needle insertion point
needle_idx = int(len(haystack_parts) * needle_position)
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
# Insert needle
haystack_parts.insert(needle_idx, needle)
# Assemble prompt
full_text = "".join(haystack_parts)
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
# Use chat template for instruct models
# For Qwen3, add /no_think to disable thinking mode
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
messages = [
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
else:
# Raw text format for base models
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
prompt = full_text + question
# Verify length
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
return prompt, needle_value
def check_needle_answer(output_text: str, expected: str) -> bool:
"""Check if the model output contains the expected needle value."""
import re
# Clean output - remove special tokens and whitespace
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
output_clean = ' '.join(output_clean.split()).lower()
expected_clean = expected.strip().lower()
# Check if expected value appears in output
if expected_clean in output_clean:
return True
# Try to extract numbers and check if expected is among them
numbers = re.findall(r'\d+', output_clean)
return expected_clean in numbers
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
@@ -207,22 +68,19 @@ def run_needle_test(
# 3. Load model
print("[3/4] Loading model...")
torch_dtype = {
"auto": "auto",
"auto": torch.float16, # default to float16 for custom model
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}.get(dtype, "auto")
}.get(dtype, torch.float16)
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
device_map="auto",
trust_remote_code=True,
)
model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
# 4. Generate output
print("[4/4] Running inference...")
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
device = next(model.parameters()).device
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
print(f" Input shape: {input_ids.shape}")
with torch.no_grad():


@@ -1,695 +0,0 @@
"""
Test script to verify CPU offload correctness using distinctive KV patterns.
Strategy:
1. Hook into attention forward pass
2. Overwrite K/V with distinctive patterns based on chunk_idx (e.g., K=chunk_idx, V=-chunk_idx)
3. After offload to CPU, verify CPU cache contains correct patterns
4. On subsequent chunks, verify loaded KV from CPU has correct patterns
This catches bugs like:
- Wrong block being offloaded
- Wrong block being loaded
- Data corruption during transfer
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
# ============================================================
# Configuration
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 64 * 1024
NUM_GPU_BLOCKS = 4
INPUT_LEN = 32 * 1024 # 32K tokens = 32 chunks (fits in 40 CPU blocks)
BLOCK_SIZE = 1024
# Test state
errors = []
chunk_patterns = {} # chunk_idx -> (k_pattern, v_pattern)
block_coverage = {} # chunk_idx -> set of blocks that were actually computed
load_operations = [] # List of (chunk_idx, slot_id, cpu_block_id, k_ok, v_ok) tuples
current_chunk_for_load = [0] # Mutable container to track current chunk during loads
# ============================================================
# Pattern Helpers
# ============================================================
def get_expected_pattern(chunk_idx: int):
"""Get expected K/V pattern for a chunk."""
# Use float values that are easy to identify
k_val = float(chunk_idx + 1) # 1.0, 2.0, 3.0, ...
v_val = float(-(chunk_idx + 1)) # -1.0, -2.0, -3.0, ...
return k_val, v_val
def fill_with_pattern(tensor: torch.Tensor, value: float):
"""Fill tensor with a constant value."""
tensor.fill_(value)
def check_pattern(tensor: torch.Tensor, expected: float, name: str, tolerance: float = 1e-3):
"""Check if tensor contains expected pattern."""
actual_mean = tensor.float().mean().item()
if abs(actual_mean - expected) > tolerance:
return False, f"{name}: expected mean={expected}, got {actual_mean}"
return True, None
# ============================================================
# Load Verification Instrumentation
# ============================================================
_original_load_to_slot_layer = None
_offload_engine_ref = None
def make_verified_load_to_slot_layer(original_func, offload_engine):
"""
Create a wrapper around load_to_slot_layer that verifies each load operation.
After each H2D transfer, checks that the GPU slot contains the expected
pattern from the source CPU block.
"""
def verified_load(slot_idx: int, layer_id: int, cpu_block_id: int):
# Call original load
original_func(slot_idx, layer_id, cpu_block_id)
# Only verify layer 0 to reduce overhead
if layer_id != 0:
return
# IMPORTANT: Synchronize CUDA to ensure async transfer is complete
# The transfer happens on a per-slot stream, and wait_slot_layer only
# makes compute_stream wait. We need full sync to read on default stream.
torch.cuda.synchronize()
# Get the expected pattern for this CPU block
# cpu_block_id == chunk_idx in our sequential test
expected_k, expected_v = get_expected_pattern(cpu_block_id)
# Read GPU slot data (GPU cache has no layer dimension)
gpu_k = offload_engine.k_cache_gpu[slot_idx]
gpu_v = offload_engine.v_cache_gpu[slot_idx]
actual_k = gpu_k.float().mean().item()
actual_v = gpu_v.float().mean().item()
k_ok = abs(actual_k - expected_k) < 1e-3
v_ok = abs(actual_v - expected_v) < 1e-3
chunk_idx = current_chunk_for_load[0]
load_operations.append({
'chunk_idx': chunk_idx,
'slot_idx': slot_idx,
'cpu_block_id': cpu_block_id,
'expected_k': expected_k,
'expected_v': expected_v,
'actual_k': actual_k,
'actual_v': actual_v,
'k_ok': k_ok,
'v_ok': v_ok,
})
if not (k_ok and v_ok):
errors.append(f"Load verification failed: chunk {chunk_idx}, "
f"CPU block {cpu_block_id} -> GPU slot {slot_idx}: "
f"expected K={expected_k:.1f}/V={expected_v:.1f}, "
f"got K={actual_k:.4f}/V={actual_v:.4f}")
return verified_load
def install_load_verification(llm):
"""Install verification wrapper on load_to_slot_layer."""
global _original_load_to_slot_layer, _offload_engine_ref
oe = llm.model_runner.kvcache_manager.offload_engine
_offload_engine_ref = oe
_original_load_to_slot_layer = oe.load_to_slot_layer
oe.load_to_slot_layer = make_verified_load_to_slot_layer(
_original_load_to_slot_layer, oe
)
print("Installed load verification wrapper on load_to_slot_layer")
def uninstall_load_verification():
"""Restore original load_to_slot_layer."""
global _original_load_to_slot_layer, _offload_engine_ref
if _offload_engine_ref is not None and _original_load_to_slot_layer is not None:
_offload_engine_ref.load_to_slot_layer = _original_load_to_slot_layer
print("Restored original load_to_slot_layer")
_original_load_to_slot_layer = None
_offload_engine_ref = None
# ============================================================
# Attention Hook
# ============================================================
def make_kv_pattern_pre_hook(layer_id: int):
"""
Create a PRE-forward hook that overwrites K/V with distinctive patterns BEFORE
they are stored to cache. This is called before attention.forward().
register_forward_pre_hook receives (module, inputs) and can modify inputs in-place.
"""
def hook(module, inputs):
ctx = get_context()
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
# Only process layer 0 for cleaner output
if layer_id != 0:
return
q, k, v = inputs
k_pattern, v_pattern = get_expected_pattern(chunk_idx)
# === Overwrite current chunk's K/V with distinctive pattern ===
# This happens BEFORE forward(), so these values will be stored to cache
k.fill_(k_pattern)
v.fill_(v_pattern)
# Only print for first few and last few chunks to reduce noise
num_chunks = INPUT_LEN // BLOCK_SIZE
if chunk_idx < 3 or chunk_idx >= num_chunks - 2:
print(f"[Chunk {chunk_idx:3d}] Set K={k_pattern:.1f}, V={v_pattern:.1f}")
elif chunk_idx == 3:
print(f"... (chunks 3 to {num_chunks - 3} omitted) ...")
return hook
def make_block_coverage_pre_hook(layer_id: int):
"""
Create a PRE-forward hook to verify that all previous blocks are included
in the cpu_block_table for chunked prefill attention.
This catches bugs where:
- Some blocks are missing from the computation
- Sparse policy incorrectly filters out blocks (when not intended)
- Block table construction has off-by-one errors
"""
def hook(module, inputs):
ctx = get_context()
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
# Only process layer 0 for cleaner output
if layer_id != 0:
return
# Update current chunk for load verification tracking
current_chunk_for_load[0] = chunk_idx
# No previous blocks for chunk 0
if chunk_idx == 0:
return
# Get the sequence and its block table (same logic as _chunked_prefill_attention)
seq = ctx.chunked_seq if hasattr(ctx, 'chunked_seq') else None
if seq is None:
return
# Get the CPU block table that will be used for attention
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Expected blocks: 0 to chunk_idx-1 (all previous chunks)
expected_blocks = set(range(chunk_idx))
actual_blocks = set(cpu_block_table) if cpu_block_table else set()
# Store for later summary
block_coverage[chunk_idx] = {
'expected': expected_blocks,
'actual': actual_blocks,
}
# Check for missing blocks
missing_blocks = expected_blocks - actual_blocks
extra_blocks = actual_blocks - expected_blocks
num_chunks = INPUT_LEN // BLOCK_SIZE
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or missing_blocks:
if not missing_blocks and not extra_blocks:
print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [OK]")
else:
status_parts = []
if missing_blocks:
status_parts.append(f"MISSING {sorted(missing_blocks)}")
if extra_blocks:
status_parts.append(f"EXTRA {sorted(extra_blocks)}")
print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [{', '.join(status_parts)}]")
elif chunk_idx == 4:
# Indicate that middle chunks are being verified silently
print(f" ... (verifying chunks 4-{num_chunks - 3} silently) ...")
if missing_blocks:
errors.append(f"Chunk {chunk_idx} missing blocks: {sorted(missing_blocks)}")
return hook
def make_gpu_write_verification_post_hook(layer_id: int):
"""
Create a POST-forward hook to verify the current chunk's KV was correctly
written to the GPU ring buffer write_slot.
This is a more reliable verification than checking load slots, because:
1. Post-hook runs AFTER forward() writes to GPU cache
2. write_slot mapping is deterministic: chunk_idx % num_ring_slots
3. We injected known patterns in pre-hook, now verify they're in GPU cache
"""
def hook(module, inputs, output):
ctx = get_context()
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
# Only process layer 0 for cleaner output
if layer_id != 0:
return
oe = kvcache_manager.offload_engine
num_ring_slots = oe.num_ring_slots
write_slot = chunk_idx % num_ring_slots
# Get expected pattern for current chunk
expected_k, expected_v = get_expected_pattern(chunk_idx)
# Verify write_slot contains current chunk's data (GPU cache has no layer dimension)
gpu_k = oe.k_cache_gpu[write_slot]
gpu_v = oe.v_cache_gpu[write_slot]
actual_k_mean = gpu_k.float().mean().item()
actual_v_mean = gpu_v.float().mean().item()
k_ok, _ = check_pattern(gpu_k, expected_k, f"GPU slot {write_slot}")
v_ok, _ = check_pattern(gpu_v, expected_v, f"GPU slot {write_slot}")
num_chunks = INPUT_LEN // BLOCK_SIZE
# Print for first/last chunks, or if there's an error
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or not (k_ok and v_ok):
if k_ok and v_ok:
print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: K={expected_k:.1f}, V={expected_v:.1f} [OK]")
else:
print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: expected K={expected_k:.1f}/V={expected_v:.1f}, "
f"got K={actual_k_mean:.2f}/V={actual_v_mean:.2f} [FAIL]")
elif chunk_idx == 4:
print(f" ... (GPU write verification for chunks 4-{num_chunks - 3} silently) ...")
if not (k_ok and v_ok):
errors.append(f"GPU write_slot {write_slot} at chunk {chunk_idx}: "
f"expected K={expected_k}, V={expected_v}, got K={actual_k_mean:.4f}, V={actual_v_mean:.4f}")
return hook
def make_kv_verification_post_hook(layer_id: int):
"""
Create a POST-forward hook to verify CPU cache contains correct patterns
from previously offloaded blocks.
"""
def hook(module, inputs, output):
ctx = get_context()
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
# Only process layer 0 for cleaner output
if layer_id != 0:
return
# === Verify previously offloaded blocks in CPU cache ===
if chunk_idx >= 1:
oe = kvcache_manager.offload_engine
num_ok = 0
num_fail = 0
# Check all previously offloaded blocks
for prev_chunk in range(chunk_idx):
# CPU block ID = prev_chunk (in simple sequential case)
cpu_block_id = prev_chunk
# Get expected pattern for this block
expected_k, expected_v = get_expected_pattern(prev_chunk)
# Read from CPU cache (layer 0)
cpu_k = oe.k_cache_cpu[layer_id, cpu_block_id]
cpu_v = oe.v_cache_cpu[layer_id, cpu_block_id]
# Verify patterns
k_ok, k_err = check_pattern(cpu_k, expected_k, f"CPU K block {cpu_block_id}")
v_ok, v_err = check_pattern(cpu_v, expected_v, f"CPU V block {cpu_block_id}")
if k_ok and v_ok:
num_ok += 1
else:
num_fail += 1
if k_err:
errors.append(k_err)
if v_err:
errors.append(v_err)
# Only print summary for each chunk verification
num_chunks = INPUT_LEN // BLOCK_SIZE
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or num_fail > 0:
status = "OK" if num_fail == 0 else f"FAIL({num_fail})"
print(f" CPU verify chunk {chunk_idx:2d}: {num_ok} blocks OK [{status}]")
elif chunk_idx == 4:
print(f" ... (CPU cache verification for chunks 4-{num_chunks - 3} silently) ...")
return hook
def make_post_chunk_verification_hook(layer_id: int):
"""
Post-forward hook to verify GPU ring buffer state after attention.
"""
def hook(module, inputs, output):
ctx = get_context()
if not ctx.is_prefill or layer_id != 0:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
oe = kvcache_manager.offload_engine
# After attention, the current chunk's KV should be in the GPU ring buffer
# Ring slot = chunk_idx % num_ring_slots
ring_slot = chunk_idx % oe.num_ring_slots
expected_k, expected_v = get_expected_pattern(chunk_idx)
# Check GPU ring buffer (GPU cache has no layer dimension)
gpu_k = oe.k_cache_gpu[ring_slot]
gpu_v = oe.v_cache_gpu[ring_slot]
k_ok, k_err = check_pattern(gpu_k, expected_k, f"GPU K slot {ring_slot}")
v_ok, v_err = check_pattern(gpu_v, expected_v, f"GPU V slot {ring_slot}")
if k_ok and v_ok:
print(f" [OK] GPU slot {ring_slot} (chunk {chunk_idx}): K={expected_k}, V={expected_v}")
else:
if k_err:
print(f" [FAIL] {k_err}")
errors.append(k_err)
if v_err:
print(f" [FAIL] {v_err}")
errors.append(v_err)
return hook
def register_hooks(llm):
"""Register pre and post forward hooks."""
hooks = []
model = llm.model_runner.model
for layer_idx, decoder_layer in enumerate(model.model.layers):
attn_module = decoder_layer.self_attn.attn
# PRE-forward hook 1: Verify all previous blocks are in cpu_block_table
coverage_hook = attn_module.register_forward_pre_hook(make_block_coverage_pre_hook(layer_idx))
hooks.append(coverage_hook)
# PRE-forward hook 2: Inject K/V patterns before they're stored to cache
pattern_hook = attn_module.register_forward_pre_hook(make_kv_pattern_pre_hook(layer_idx))
hooks.append(pattern_hook)
# POST-forward hook 1: Verify GPU write_slot contains current chunk's data
gpu_verify_hook = attn_module.register_forward_hook(make_gpu_write_verification_post_hook(layer_idx))
hooks.append(gpu_verify_hook)
# POST-forward hook 2: Verify CPU cache contains correct patterns after offload
cpu_verify_hook = attn_module.register_forward_hook(make_kv_verification_post_hook(layer_idx))
hooks.append(cpu_verify_hook)
return hooks
# ============================================================
# Final Verification
# ============================================================
def verify_final_cpu_state(llm, num_chunks: int):
"""Verify all CPU blocks have correct patterns after prefill completes."""
print("\n" + "=" * 60)
print("Final CPU Cache Verification")
print("=" * 60)
kvcache_manager = llm.model_runner.kvcache_manager
oe = kvcache_manager.offload_engine
num_ok = 0
num_fail = 0
fail_details = []
# After prefill, all chunks should be in CPU
for chunk_idx in range(num_chunks):
cpu_block_id = chunk_idx
expected_k, expected_v = get_expected_pattern(chunk_idx)
# Check layer 0
cpu_k = oe.k_cache_cpu[0, cpu_block_id]
cpu_v = oe.v_cache_cpu[0, cpu_block_id]
k_ok, k_err = check_pattern(cpu_k, expected_k, f"Final CPU K block {cpu_block_id}")
v_ok, v_err = check_pattern(cpu_v, expected_v, f"Final CPU V block {cpu_block_id}")
if k_ok and v_ok:
num_ok += 1
# Only print first few and last few
if chunk_idx < 3 or chunk_idx >= num_chunks - 2:
actual_k_mean = cpu_k.float().mean().item()
actual_v_mean = cpu_v.float().mean().item()
print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
f"V={expected_v:.1f} ({actual_v_mean:.4f}) [OK]")
elif chunk_idx == 3:
print(f" ... (blocks 3 to {num_chunks - 3} verified OK) ...")
else:
num_fail += 1
actual_k_mean = cpu_k.float().mean().item()
actual_v_mean = cpu_v.float().mean().item()
print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
f"V={expected_v:.1f} ({actual_v_mean:.4f}) [FAIL]")
if k_err:
errors.append(k_err)
if v_err:
errors.append(v_err)
print(f"\nTotal: {num_ok} OK, {num_fail} FAIL out of {num_chunks} blocks")
def verify_block_coverage_summary(num_chunks: int):
"""Verify that all chunks had complete block coverage during prefill."""
print("\n" + "=" * 60)
print("Block Coverage Verification Summary")
print("=" * 60)
num_ok = 0
num_fail = 0
total_blocks_expected = 0
total_blocks_computed = 0
for chunk_idx in range(1, num_chunks): # Start from 1 (chunk 0 has no previous)
if chunk_idx not in block_coverage:
print(f" Chunk {chunk_idx}: NO COVERAGE DATA [FAIL]")
errors.append(f"Chunk {chunk_idx} has no block coverage data")
num_fail += 1
continue
coverage = block_coverage[chunk_idx]
expected = coverage['expected']
actual = coverage['actual']
missing = expected - actual
total_blocks_expected += len(expected)
total_blocks_computed += len(actual)
if not missing:
num_ok += 1
else:
num_fail += 1
# Print summary
if num_fail == 0:
print(f" All {num_ok} chunks had complete block coverage [OK]")
print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")
else:
print(f" {num_ok} chunks OK, {num_fail} chunks with missing blocks [FAIL]")
print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")
# Verify the total is correct: sum of 0+1+2+...+(n-1) = n*(n-1)/2
expected_total = num_chunks * (num_chunks - 1) // 2
if total_blocks_expected == expected_total:
print(f" Expected total blocks matches formula: {expected_total} [OK]")
else:
print(f" Expected total mismatch: got {total_blocks_expected}, formula gives {expected_total} [FAIL]")
errors.append(f"Block coverage total mismatch")
def verify_load_operations_summary(num_chunks: int):
"""Verify all H2D load operations transferred correct data."""
print("\n" + "=" * 60)
print("H2D Load Operations Verification Summary")
print("=" * 60)
if not load_operations:
print(" WARNING: No load operations recorded!")
print(" (This may indicate load verification was not installed)")
return
num_ok = 0
num_fail = 0
loads_per_chunk = {}
for op in load_operations:
chunk_idx = op['chunk_idx']
if chunk_idx not in loads_per_chunk:
loads_per_chunk[chunk_idx] = []
loads_per_chunk[chunk_idx].append(op)
if op['k_ok'] and op['v_ok']:
num_ok += 1
else:
num_fail += 1
# Print per-chunk summary for first/last chunks
for chunk_idx in sorted(loads_per_chunk.keys()):
ops = loads_per_chunk[chunk_idx]
chunk_ok = sum(1 for op in ops if op['k_ok'] and op['v_ok'])
chunk_fail = len(ops) - chunk_ok
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or chunk_fail > 0:
# Show loaded block IDs in order
block_ids = [op['cpu_block_id'] for op in ops]
if chunk_fail == 0:
print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks {block_ids} [OK]")
else:
print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks, {chunk_fail} FAILED [FAIL]")
for op in ops:
if not (op['k_ok'] and op['v_ok']):
print(f" CPU block {op['cpu_block_id']} -> slot {op['slot_idx']}: "
f"expected K={op['expected_k']:.1f}/V={op['expected_v']:.1f}, "
f"got K={op['actual_k']:.4f}/V={op['actual_v']:.4f}")
elif chunk_idx == 4:
print(f" ... (chunks 4-{num_chunks - 3} load verification running silently) ...")
# Print overall summary
print(f"\n Total load operations: {len(load_operations)}")
print(f" Successful: {num_ok}, Failed: {num_fail}")
if num_fail == 0:
print(f" All H2D transfers verified correct [OK]")
else:
print(f" {num_fail} H2D transfers had incorrect data [FAIL]")
# ============================================================
# Main Test Script
# ============================================================
if __name__ == "__main__":
print("=" * 60)
print("Test: CPU Offload Correctness with Distinctive KV Patterns")
print("=" * 60)
print(f"Input: {INPUT_LEN} tokens, {INPUT_LEN // BLOCK_SIZE} chunks")
print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}")
print(f"Pattern: K=chunk_idx+1, V=-(chunk_idx+1)")
print()
# 1. Initialize LLM
print("Initializing LLM...")
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
dtype="float16",
)
# 2. Register hooks
hooks = register_hooks(llm)
print(f"Registered {len(hooks)} hooks")
# 3. Install load verification (instrument load_to_slot_layer)
install_load_verification(llm)
# 4. Generate prompt
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
num_chunks = INPUT_LEN // BLOCK_SIZE
# 5. Run prefill
print("\n" + "=" * 60)
print("Running Prefill with KV Pattern Injection...")
print("=" * 60)
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
# 6. Remove hooks and uninstall load verification
for hook in hooks:
hook.remove()
uninstall_load_verification()
# 7. Final verification
verify_final_cpu_state(llm, num_chunks)
# 8. Block coverage summary
verify_block_coverage_summary(num_chunks)
# 9. H2D load operations summary
verify_load_operations_summary(num_chunks)
# 10. Report results
print("\n" + "=" * 60)
if errors:
print(f"test_offload_correctness: FAILED ({len(errors)} errors)")
for err in errors[:10]: # Show first 10 errors
print(f" - {err}")
exit(1)
else:
print("test_offload_correctness: PASSED")
print("=" * 60)


@@ -1,119 +0,0 @@
"""
Test script for OffloadEngine - CPU-GPU KV cache transfer engine.
Demonstrates: ring buffer, H2D/D2H transfers, CUDA events, KV access.
"""
import torch
from nanovllm.kvcache.offload_engine import OffloadEngine
# ============================================================
# Utility Functions
# ============================================================
def verify(tensor: torch.Tensor, expected: float, name: str) -> None:
"""Verify tensor contains expected value."""
actual = tensor.mean().item()
assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}"
# ============================================================
# Configuration
# ============================================================
NUM_LAYERS = 4
NUM_GPU_BLOCKS = 8
NUM_CPU_BLOCKS = 16
BLOCK_SIZE = 64
NUM_KV_HEADS = 4
HEAD_DIM = 32
# ============================================================
# Main Test Script
# ============================================================
# 1. Initialize
engine = OffloadEngine(
num_layers=NUM_LAYERS,
num_gpu_blocks=NUM_GPU_BLOCKS,
num_cpu_blocks=NUM_CPU_BLOCKS,
block_size=BLOCK_SIZE,
num_kv_heads=NUM_KV_HEADS,
head_dim=HEAD_DIM,
dtype=torch.float16,
)
# 2. Ring buffer slot management
for chunk_idx in range(12):
write_slot = engine.get_write_slot_for_prefill(chunk_idx)
load_slots = engine.get_load_slots_for_prefill(write_slot)
print("chunk idx", chunk_idx, "write slots:", write_slot, "load slots:", load_slots)
assert write_slot == chunk_idx % engine.num_ring_slots
assert write_slot not in load_slots
assert engine.decode_slot == 0
assert engine.get_load_slots_for_decode() == list(range(1, NUM_GPU_BLOCKS))
# 3. Per-slot per-layer H2D transfer
engine.k_cache_cpu[0, 0].fill_(42.0)
engine.v_cache_cpu[0, 0].fill_(42.5)
engine.load_to_slot_layer(slot_idx=1, layer_id=0, cpu_block_id=0)
engine.wait_slot_layer(slot_idx=1, layer_id=0)
verify(engine.k_cache_gpu[0, 1], 42.0, "H2D K")
verify(engine.v_cache_gpu[0, 1], 42.5, "H2D V")
# 4. Compute-done event (pipeline safety)
engine.record_slot_compute_done(slot_idx=1, layer_id=0)
engine.k_cache_cpu[0, 1].fill_(100.0)
engine.v_cache_cpu[0, 1].fill_(100.5)
engine.load_to_slot_layer(slot_idx=1, layer_id=0, cpu_block_id=1)
engine.wait_slot_layer(slot_idx=1, layer_id=0)
verify(engine.k_cache_gpu[0, 1], 100.0, "Reuse K")
verify(engine.v_cache_gpu[0, 1], 100.5, "Reuse V")
# 5. D2H offload
engine.k_cache_gpu[1, 2].fill_(77.0)
engine.v_cache_gpu[1, 2].fill_(77.5)
engine.offload_slot_to_cpu(slot_idx=2, cpu_block_id=5)
engine.wait_slot_offload(slot_idx=2)
verify(engine.k_cache_cpu[1, 5], 77.0, "D2H K")
verify(engine.v_cache_cpu[1, 5], 77.5, "D2H V")
# 6. KV access methods
k, v = engine.get_kv_for_slot(slot_idx=1, layer_id=0)
assert k.shape == (1, BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM)
k, v = engine.get_kv_for_slots(layer_id=0, slot_indices=[0, 1, 2])
assert k.shape == (1, 3 * BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM)
engine.k_cache_gpu[0, engine.decode_slot].fill_(33.0)
k, v = engine.get_kv_for_decode_slot_accumulated(layer_id=0, num_tokens=10)
assert k.shape == (1, 10, NUM_KV_HEADS, HEAD_DIM)
verify(k, 33.0, "Decode slot K")
# 7. Batch transfer
cpu_blocks = [2, 3, 4]
gpu_slots = [3, 4, 5]
for cpu_id in cpu_blocks:
engine.k_cache_cpu[0, cpu_id].fill_(50.0 + cpu_id)
engine.load_cpu_blocks_to_gpu_slots(layer_id=0, cpu_block_ids=cpu_blocks, gpu_slot_ids=gpu_slots)
for cpu_id, gpu_slot in zip(cpu_blocks, gpu_slots):
verify(engine.k_cache_gpu[0, gpu_slot], 50.0 + cpu_id, f"Batch slot {gpu_slot}")
# 8. Gather indices (CUDA graph compatible)
engine.update_gather_indices(layer_id=0, mappings=[(0, 0), (1, 1), (2, 2)])
assert engine.gather_indices_gpu[0, :3].tolist() == [0, 1, 2]
engine.clear_gather_indices(layer_id=0)
assert engine.gather_indices_gpu[0, 0].item() == -1
print("test_offload_engine: PASSED")


@@ -1,51 +0,0 @@
"""
Test script for chunked prefill with CPU offload.
Demonstrates: LLM initialization, prefill execution with CPU offload enabled.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
from random import randint, seed
from nanovllm import LLM, SamplingParams
# ============================================================
# Configuration
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 32 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
# ============================================================
# Main Test Script
# ============================================================
# 1. Initialize LLM with CPU offload
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=1024,
num_gpu_blocks=NUM_GPU_BLOCKS,
)
# 2. Generate random prompt tokens
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
# 3. Run prefill (max_tokens=1 to focus on prefill only)
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
# 4. Verify output
assert len(outputs) == 1
assert "token_ids" in outputs[0]
assert len(outputs[0]["token_ids"]) == 1
print("test_prefill: PASSED")

tests/test_quest_policy.py (new file, 136 lines)

@@ -0,0 +1,136 @@
"""
Test for QuestPolicy block selection with GQA (Grouped Query Attention).
Demonstrates the key limitation: scores are AVERAGED across heads,
so blocks strongly needed by one head but not others may be dropped.
This is the expected Quest behavior - not a bug.
"""
import torch
from nanovllm.kvcache.sparse import (
create_sparse_policy,
SparsePolicyType,
PolicyContext,
)
# ============================================================
# Test: Per-Head Score Averaging in GQA
# ============================================================
# Determine device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running test on device: {device}")
# Setup: 2 KV heads, 4 query heads (GQA group_size=2)
# topk=2 to make selection competitive
quest = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=2, threshold_blocks=0)
quest.initialize(
num_layers=1,
num_kv_heads=2,
head_dim=4,
num_cpu_blocks=6,
dtype=torch.float32,
device=device, # Metadata stored on GPU
)
metadata = quest.metadata
def set_key(block_id, head_id, values):
"""Set both key_min and key_max to same values for deterministic scoring."""
# Values need to be on the same device as metadata
tensor = torch.tensor(values, device=device)
metadata.key_min[block_id, 0, head_id, :] = tensor
metadata.key_max[block_id, 0, head_id, :] = tensor
# ============================================================
# Design: Different heads want different blocks
# ============================================================
#
# Query = [1,1,1,1] for all heads, so score = sum(key values)
#
# Block | Head 0 | Head 1 | Average | Result
# ------|--------|--------|---------|--------
# 0 | +4 | -4 | 0 | Head0 wants, Head1 doesn't → DROPPED
# 1 | -4 | +4 | 0 | Head1 wants, Head0 doesn't → DROPPED
# 2 | +4 | +4 | +4 | Both want → SELECTED (rank 1)
# 3 | +3 | +3 | +3 | Both want → SELECTED (rank 2)
# 4 | +4 | 0 | +2 | Head0 strongly wants, Head1 neutral → rank 3
# 5 | 0 | +4 | +2 | Head1 strongly wants, Head0 neutral → rank 3
# Block 0: Head 0 strongly wants, Head 1 strongly rejects
set_key(0, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(0, 1, [-1.0, -1.0, -1.0, -1.0]) # head1: -4
# Block 1: Head 1 strongly wants, Head 0 strongly rejects
set_key(1, 0, [-1.0, -1.0, -1.0, -1.0]) # head0: -4
set_key(1, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# Block 2: Both heads want equally (highest average)
set_key(2, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(2, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# Block 3: Both heads want moderately
set_key(3, 0, [0.75, 0.75, 0.75, 0.75]) # head0: +3
set_key(3, 1, [0.75, 0.75, 0.75, 0.75]) # head1: +3
# Block 4: Head 0 strongly wants, Head 1 neutral
set_key(4, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(4, 1, [0.0, 0.0, 0.0, 0.0]) # head1: 0
# Block 5: Head 1 strongly wants, Head 0 neutral
set_key(5, 0, [0.0, 0.0, 0.0, 0.0]) # head0: 0
set_key(5, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# ============================================================
# Run selection
# ============================================================
# Query on same device as metadata
query = torch.ones(1, 4, 4, device=device) # GQA: 4 query heads → 2 KV heads
ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=0,
query=query,
is_prefill=False,
block_size=1024,
total_kv_len=6144,
)
available = list(range(6))
selected = quest.select_blocks(available, ctx)
# ============================================================
# Verify: Averaging behavior
# ============================================================
# topk=2, so only blocks 2 (+4 avg) and 3 (+3 avg) should be selected
assert len(selected) == 2, f"Expected 2 blocks, got {len(selected)}"
assert selected == [2, 3], f"Expected [2, 3], got {selected}"
# Key insight: blocks 0 and 1 have score +4 for ONE head,
# but they cancel out due to averaging with the other head's -4
assert 0 not in selected, "Block 0 should NOT be selected (head scores cancel out)"
assert 1 not in selected, "Block 1 should NOT be selected (head scores cancel out)"
# Blocks 4 and 5 have +4 for one head, 0 for other → avg=+2
# But +2 < +3 (block 3), so they don't make the top-2
assert 4 not in selected, "Block 4 avg=+2 < block 3 avg=+3"
assert 5 not in selected, "Block 5 avg=+2 < block 3 avg=+3"
print("✓ Block 2 selected: both heads want it (+4, +4) → avg=+4")
print("✓ Block 3 selected: both heads want it (+3, +3) → avg=+3")
print("✓ Block 0 NOT selected: head0=+4, head1=-4 → avg=0 (cancel out)")
print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")
# Verify metadata is on correct device
assert metadata.key_min.device.type == device.type, f"key_min on wrong device: {metadata.key_min.device}"
assert metadata.key_max.device.type == device.type, f"key_max on wrong device: {metadata.key_max.device}"
print(f"✓ Metadata stored on {device.type.upper()}")
print("\ntest_quest_policy: PASSED")

tests/test_sequential.py (new file, 199 lines)

@@ -0,0 +1,199 @@
"""
Sequential inference test for LLM.
Tests: After completing one prompt, the system can correctly handle
a second prompt with a clean state (first prompt's KV cache deallocated).
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
from utils import generate_needle_prompt, check_needle_answer
def run_sequential_test(
model_path: str,
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
block_size: int = 1024,
enable_cpu_offload: bool = False,
verbose: bool = True,
) -> bool:
"""
Run sequential inference test with two different prompts.
Each prompt has a different needle value. Both must be retrieved correctly.
"""
if verbose:
print(f"\n{'='*60}")
print(f"Sequential Inference Test")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Block size: {block_size}")
print(f"CPU offload: {enable_cpu_offload}")
print(f"{'='*60}\n")
# Initialize LLM once
llm_kwargs = {
"enforce_eager": True,
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm = LLM(model_path, **llm_kwargs)
sampling_params = SamplingParams(
temperature=0.6,
max_tokens=32,
)
# ============================================================
# Test 1: First prompt with needle value "1234"
# ============================================================
needle_value_1 = "1234"
if verbose:
print(f"\n[Test 1] Generating prompt with needle value: {needle_value_1}")
prompt_1, expected_1 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_1,
)
outputs_1 = llm.generate([prompt_1], sampling_params, use_tqdm=True)
output_text_1 = outputs_1[0]["text"]
passed_1 = check_needle_answer(output_text_1, expected_1)
if verbose:
print(f" Expected: {expected_1}")
print(f" Output: {output_text_1[:100]}...")
print(f" Status: {'PASSED' if passed_1 else 'FAILED'}")
# ============================================================
# Test 2: Second prompt with needle value "5678"
# ============================================================
needle_value_2 = "5678"
if verbose:
print(f"\n[Test 2] Generating prompt with needle value: {needle_value_2}")
prompt_2, expected_2 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_2,
)
outputs_2 = llm.generate([prompt_2], sampling_params, use_tqdm=True)
output_text_2 = outputs_2[0]["text"]
passed_2 = check_needle_answer(output_text_2, expected_2)
if verbose:
print(f" Expected: {expected_2}")
print(f" Output: {output_text_2[:100]}...")
print(f" Status: {'PASSED' if passed_2 else 'FAILED'}")
# ============================================================
# Test 3: Third prompt with a new needle to ensure no state carries over from earlier prompts
# ============================================================
needle_value_3 = "9999"
if verbose:
print(f"\n[Test 3] Generating prompt with needle value: {needle_value_3}")
prompt_3, expected_3 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_3,
)
outputs_3 = llm.generate([prompt_3], sampling_params, use_tqdm=True)
output_text_3 = outputs_3[0]["text"]
passed_3 = check_needle_answer(output_text_3, expected_3)
if verbose:
print(f" Expected: {expected_3}")
print(f" Output: {output_text_3[:100]}...")
print(f" Status: {'PASSED' if passed_3 else 'FAILED'}")
# ============================================================
# Summary
# ============================================================
all_passed = passed_1 and passed_2 and passed_3
if verbose:
print(f"\n{'='*60}")
print(f"Summary")
print(f"{'='*60}")
print(f"Test 1 (needle={needle_value_1}): {'PASSED' if passed_1 else 'FAILED'}")
print(f"Test 2 (needle={needle_value_2}): {'PASSED' if passed_2 else 'FAILED'}")
print(f"Test 3 (needle={needle_value_3}): {'PASSED' if passed_3 else 'FAILED'}")
print(f"Overall: {'PASSED' if all_passed else 'FAILED'}")
print(f"{'='*60}\n")
return all_passed
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Sequential inference test")
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-0.6B/"),
help="Path to model"
)
parser.add_argument(
"--max-model-len",
type=int,
default=36 * 1024,
help="Maximum model context length"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--num-gpu-blocks",
type=int,
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--block-size",
type=int,
default=1024,
help="KV cache block size"
)
parser.add_argument(
"--enable-offload",
action="store_true",
help="Enable CPU offload"
)
args = parser.parse_args()
passed = run_sequential_test(
model_path=args.model,
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size,
enable_cpu_offload=args.enable_offload,
verbose=True,
)
if passed:
print("test_sequential: PASSED")
else:
print("test_sequential: FAILED")
exit(1)


@@ -1,176 +0,0 @@
"""
Tests for CUDA sgDMA (cudaMemcpy2D) extension.
Author: Zijie Tian
"""
import torch
import time
from nanovllm.comm import memcpy_2d
# ============================================================
# Configuration
# ============================================================
class Config:
num_layers = 32
num_blocks = 10
block_size = 4096
num_kv_heads = 8
head_dim = 128
dtype = torch.float16
@property
def features_per_block(self):
return self.block_size * self.num_kv_heads * self.head_dim
@property
def bytes_per_block(self):
return self.features_per_block * self.dtype.itemsize
@property
def bytes_per_layer(self):
return self.num_blocks * self.bytes_per_block
# ============================================================
# Performance Benchmark
# ============================================================
def benchmark_sgdma():
"""Benchmark cudaMemcpy2D vs standard PyTorch methods."""
print("\n=== Performance Benchmark ===")
cfg = Config()
print(f" Configuration:")
print(f" num_layers: {cfg.num_layers}")
print(f" num_blocks: {cfg.num_blocks}")
print(f" block_size: {cfg.block_size}")
print(f" dtype: {cfg.dtype}")
print(f" bytes_per_block: {cfg.bytes_per_block / 1024:.1f} KB")
print(f" total transfer size: {cfg.num_layers * cfg.bytes_per_block / 1024 / 1024:.1f} MB")
num_iterations = 10
warmup = 3
test_block_id = 5
# Allocate memory
cpu_strided = torch.randn(
cfg.num_layers,
cfg.num_blocks,
cfg.features_per_block,
dtype=cfg.dtype,
pin_memory=True
)
# ========================================
# Method A: cudaMemcpy2D with sgDMA
# ========================================
gpu_buffer_a = torch.empty(cfg.num_layers, cfg.features_per_block, dtype=cfg.dtype, device='cuda')
spitch = cfg.bytes_per_layer
dpitch = cfg.bytes_per_block
width = cfg.bytes_per_block
height = cfg.num_layers
src_view = cpu_strided[:, test_block_id, :]
# Warmup
for _ in range(warmup):
memcpy_2d(gpu_buffer_a, src_view, dpitch, spitch, width, height, "h2d")
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_iterations):
memcpy_2d(gpu_buffer_a, src_view, dpitch, spitch, width, height, "h2d")
torch.cuda.synchronize()
elapsed_a = time.perf_counter() - start
avg_time_a = elapsed_a / num_iterations * 1000 # ms
bandwidth_a = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_a
print(f"\n Method A (cudaMemcpy2D sgDMA):")
print(f" Avg time: {avg_time_a:.3f} ms")
print(f" Bandwidth: {bandwidth_a:.2f} GB/s")
# ========================================
# Method B: PyTorch .cuda() on strided view
# ========================================
# Warmup
for _ in range(warmup):
_ = cpu_strided[:, test_block_id, :].cuda()
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_iterations):
_ = cpu_strided[:, test_block_id, :].cuda()
torch.cuda.synchronize()
elapsed_b = time.perf_counter() - start
avg_time_b = elapsed_b / num_iterations * 1000 # ms
bandwidth_b = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_b
print(f"\n Method B (PyTorch .cuda() on strided):")
print(f" Avg time: {avg_time_b:.3f} ms")
print(f" Bandwidth: {bandwidth_b:.2f} GB/s")
# ========================================
# Method C: PyTorch .cuda() on contiguous (pinned)
# ========================================
# Create contiguous version with pinned memory
cpu_contiguous = torch.empty(
cfg.num_layers,
cfg.features_per_block,
dtype=cfg.dtype,
pin_memory=True
)
cpu_contiguous.copy_(cpu_strided[:, test_block_id, :])
# Warmup
for _ in range(warmup):
_ = cpu_contiguous.cuda()
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_iterations):
_ = cpu_contiguous.cuda()
torch.cuda.synchronize()
elapsed_c = time.perf_counter() - start
avg_time_c = elapsed_c / num_iterations * 1000 # ms
bandwidth_c = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_c
print(f"\n Method C (PyTorch .cuda() on contiguous):")
print(f" Avg time: {avg_time_c:.3f} ms")
print(f" Bandwidth: {bandwidth_c:.2f} GB/s")
# Summary
print(f"\n ========================================")
print(f" Performance Summary:")
print(f" Method A vs Method B: {bandwidth_a / bandwidth_b:.2f}x speedup")
print(f" Method A vs Method C: {bandwidth_a / bandwidth_c * 100:.2f}%")
print(f" ========================================")
# ============================================================
# Main
# ============================================================
if __name__ == "__main__":
print("=== CUDA sgDMA (cudaMemcpy2D) Benchmark ===")
# Check CUDA availability
if not torch.cuda.is_available():
print("CUDA not available. Skipping benchmark.")
exit(1)
# Print GPU info
print(f"Using GPU: {torch.cuda.get_device_name()}")
# Run benchmark
benchmark_sgdma()
print("\n=== Benchmark Complete ===")

tests/utils.py (new file, 181 lines)

@@ -0,0 +1,181 @@
"""
Test utilities for nano-vllm.
"""
import re
from typing import Tuple
# ============================================================
# Needle-in-Haystack Test Utilities
# ============================================================
# Haystack filler paragraphs (various topics to create realistic context)
HAYSTACK_PARAGRAPHS = [
"The weather today is quite pleasant with clear skies and moderate temperatures. "
"Many people are enjoying outdoor activities in the park. "
"Birds are singing in the trees and children are playing on the swings. ",
"In the world of technology, new innovations continue to emerge every day. "
"Researchers are working on advanced algorithms and computing systems. "
"The future of artificial intelligence looks promising with many breakthroughs. ",
"The history of human civilization spans thousands of years. "
"Ancient cultures developed writing, mathematics, and astronomy. "
"Trade routes connected distant lands and facilitated cultural exchange. ",
"Modern cooking combines traditional techniques with new ingredients. "
"Chefs around the world experiment with flavors and presentations. "
"Food brings people together and creates memorable experiences. ",
"The ocean covers more than seventy percent of Earth's surface. "
"Marine ecosystems support an incredible diversity of life forms. "
"Scientists continue to discover new species in the deep sea. ",
"Music has been a part of human culture since prehistoric times. "
"Different genres evolved across various regions and time periods. "
"Today, people can access millions of songs through digital platforms. ",
"Space exploration has revealed many secrets about our universe. "
"Telescopes can observe galaxies billions of light years away. "
"Future missions aim to establish human presence on other planets. ",
"The study of languages reveals patterns in human cognition. "
"Linguists analyze grammar, semantics, and phonetics across cultures. "
"Language continues to evolve with new words and expressions. ",
]
def generate_needle_prompt(
tokenizer,
target_length: int,
needle_position: float = 0.5,
needle_value: str = "7492",
use_chat_template: bool = True,
verbose: bool = True,
) -> Tuple[str, str]:
"""
Generate a needle-in-haystack prompt of approximately target_length tokens (adjusted to within a small tolerance).
Args:
tokenizer: HuggingFace tokenizer for length estimation
target_length: Target total sequence length in tokens
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
needle_value: The secret value to hide in the haystack
use_chat_template: Whether to use chat template for instruct models
verbose: Whether to print generation info
Returns:
(prompt, expected_answer): The full prompt and the expected needle value
"""
# The needle sentence
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
# Question text
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
else:
question_text = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
def build_prompt(haystack_parts, needle_idx):
"""Build full prompt from haystack parts with needle inserted."""
parts = haystack_parts.copy()
parts.insert(needle_idx, needle)
full_text = "".join(parts)
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
messages = [{"role": "user", "content": f"{full_text}\n\n{question_text}"}]
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
else:
return full_text + question_text
def count_tokens(prompt):
return len(tokenizer.encode(prompt, add_special_tokens=False))
def get_needle_idx(parts):
idx = int(len(parts) * needle_position)
return max(0, min(idx, len(parts)))
# Pre-compute tokens per paragraph for efficiency (avoid O(n²) tokenization)
para_tokens = []
for para in HAYSTACK_PARAGRAPHS:
para_tokens.append(len(tokenizer.encode(para, add_special_tokens=False)))
avg_tokens_per_para = sum(para_tokens) / len(para_tokens)
# Estimate overhead (needle + question + chat template)
overhead_prompt = build_prompt([HAYSTACK_PARAGRAPHS[0]], 0)
overhead_tokens = count_tokens(overhead_prompt) - para_tokens[0]
# Phase 1: Estimate number of paragraphs needed
estimated_paras = int((target_length - overhead_tokens) / avg_tokens_per_para) + 1
# Build haystack with estimated paragraphs
haystack_parts = []
for i in range(estimated_paras):
haystack_parts.append(HAYSTACK_PARAGRAPHS[i % len(HAYSTACK_PARAGRAPHS)])
# Phase 2: Adjust to get closer to target
prompt = build_prompt(haystack_parts, get_needle_idx(haystack_parts))
current_tokens = count_tokens(prompt)
# Add more if under target
para_idx = len(haystack_parts)
while current_tokens < target_length and para_idx < 100000:
para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
haystack_parts.append(para)
current_tokens += para_tokens[para_idx % len(HAYSTACK_PARAGRAPHS)]
para_idx += 1
# Remove if over target
while current_tokens > target_length + 100 and len(haystack_parts) > 1:
removed_para_idx = (len(haystack_parts) - 1) % len(HAYSTACK_PARAGRAPHS)
haystack_parts.pop()
current_tokens -= para_tokens[removed_para_idx]
# Final build
needle_idx = get_needle_idx(haystack_parts)
prompt = build_prompt(haystack_parts, needle_idx)
actual_tokens = count_tokens(prompt)
if verbose:
print(f"[NeedleTest] Target: {target_length}, Actual: {actual_tokens} tokens (diff={actual_tokens - target_length})")
return prompt, needle_value
def check_needle_answer(output_text: str, expected: str) -> bool:
"""Check if the model output contains the expected needle value."""
# Clean output - remove special tokens and whitespace
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
output_clean = ' '.join(output_clean.split()).lower()
expected_clean = expected.strip().lower()
# Check if expected value appears in output
# Also try to find it as a standalone number
if expected_clean in output_clean:
return True
# Try to extract numbers and check if expected is among them
numbers = re.findall(r'\d+', output_clean)
return expected_clean in numbers
def generate_random_token_ids(
length: int,
vocab_size: int = 10000,
seed: int = 42,
) -> list:
"""
Generate random token IDs for testing.
Args:
length: Number of tokens to generate
vocab_size: Maximum token ID (exclusive)
seed: Random seed for reproducibility
Returns:
List of random token IDs
"""
from random import randint, seed as set_seed
set_seed(seed)
return [randint(0, vocab_size - 1) for _ in range(length)]
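A hypothetical end-to-end usage of these helpers (tokenizer path is a placeholder; the real model call is elided):

# Sketch only: how the tests above consume these utilities.
import os
from transformers import AutoTokenizer
from utils import generate_needle_prompt, check_needle_answer

tokenizer = AutoTokenizer.from_pretrained(os.path.expanduser("~/models/Qwen3-0.6B/"))  # placeholder
prompt, expected = generate_needle_prompt(
    tokenizer,
    target_length=8 * 1024,
    needle_position=0.5,
    needle_value="7492",
)
# ... run the model on `prompt` to obtain output_text ...
output_text = f"The secret number is {expected}."  # stand-in for a real model response
assert check_needle_answer(output_text, expected)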