CLAUDE.md
This file provides guidance to Claude Code when working with this repository.
Overview
Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports Qwen3 models with CPU offload for long-context inference.
GPU Mutex for Multi-Instance Debugging
IMPORTANT: When running multiple Claude instances for parallel debugging, only one GPU (cuda:0) is available. Before executing ANY command that uses the GPU (python scripts, benchmarks, tests), Claude MUST:
1. Check GPU availability by running:

   ```bash
   nvidia-smi --query-compute-apps=pid,name,used_memory --format=csv,noheader
   ```

2. If processes are running on the GPU, wait and retry every 10 seconds until it is free, using this polling loop:

   ```bash
   while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do
       echo "GPU busy, waiting 10s..."
       sleep 10
   done
   ```

3. Only proceed when `nvidia-smi --query-compute-apps=pid --format=csv,noheader` returns empty output.
Example workflow:
```bash
# First check if GPU is in use
nvidia-smi --query-compute-apps=pid,name,used_memory --format=csv,noheader

# If output is empty, proceed with your command
python bench_offload.py

# If output shows processes, wait until they finish
```
Note: This applies to ALL GPU operations including:
- Running tests (`python tests/test_*.py`)
- Running benchmarks (`python bench*.py`)
- Running examples (`python example.py`)
- Any script that imports torch/cuda
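For Python entry points, the same guard can be expressed as a small helper; this is a minimal sketch (the function name and polling interval are illustrative, the `nvidia-smi` query is the one shown above):

```python
import subprocess
import time

def wait_for_free_gpu(poll_seconds: int = 10) -> None:
    """Block until nvidia-smi reports no compute processes on the GPU."""
    query = ["nvidia-smi", "--query-compute-apps=pid", "--format=csv,noheader"]
    while subprocess.run(query, capture_output=True, text=True).stdout.strip():
        print(f"GPU busy, waiting {poll_seconds}s...")
        time.sleep(poll_seconds)

if __name__ == "__main__":
    wait_for_free_gpu()  # call before importing torch / launching any GPU work
```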
Local Package Installation for Multi-Instance
CRITICAL: After ANY code modification in the nanovllm/ directory, you MUST reinstall the package before running tests or benchmarks:
pip install -e . --prefix=./.local --no-deps
Then run with PYTHONPATH:
PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python <script.py>
IMPORTANT: When running multiple Claude instances on different worktrees, do NOT use pip install -e . globally as it will affect other instances. Instead, use local installation:
1. Install to a worktree-local directory:

   ```bash
   pip install -e . --prefix=./.local --no-deps
   ```

2. Set PYTHONPATH before running any Python command:

   ```bash
   export PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH
   ```

3. Combined example:

   ```bash
   # One-liner for running tests with local package
   PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python tests/test_needle.py
   ```
Note: The Python version in the path (python3.10) should match your environment.
CRITICAL: After making code changes to nanovllm/ source files, you MUST reinstall the package for changes to take effect:
pip install -e . --prefix=./.local --no-deps
Without reinstallation, Python will use the old cached version and your changes will NOT be reflected!
Sparse Attention
For sparse attention related content (block sparse attention, MInference, FlexPrefill, XAttention, AvgPool, etc.), refer to docs/sparse_attention_guide.md.
Quest Sparse Policy
Files: nanovllm/kvcache/sparse/quest.py, nanovllm/kvcache/sparse/policy.py
Quest policy selects Top-K blocks based on query-key similarity bounds using min/max key metadata.
Scoring Mechanism:
```python
score_min = torch.einsum('hd,bhd->bh', q, key_min)         # [num_blocks, kv_heads]
score_max = torch.einsum('hd,bhd->bh', q, key_max)         # [num_blocks, kv_heads]
scores = torch.maximum(score_min, score_max).mean(dim=-1)  # [num_blocks] ← averaged!
```
Critical Limitation - No Per-Head Scheduling:
The `.mean(dim=-1)` averages scores across all heads, so a single, unified block selection is shared by every head:
Block A: head0 needs (+4), head1 doesn't (-4) → avg = 0 → NOT selected
Block B: head0 doesn't (-4), head1 needs (+4) → avg = 0 → NOT selected
Block C: both heads moderately need (+2, +2) → avg = +2 → selected
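A minimal sketch of how the averaged scores drive one unified Top-K selection (variable names and `top_k` are illustrative; the real logic lives in `quest.py`):

```python
import torch

def select_blocks(q, key_min, key_max, top_k):
    """q: [kv_heads, head_dim]; key_min/key_max: [num_blocks, kv_heads, head_dim] (assumed shapes)."""
    score_min = torch.einsum('hd,bhd->bh', q, key_min)
    score_max = torch.einsum('hd,bhd->bh', q, key_max)
    scores = torch.maximum(score_min, score_max).mean(dim=-1)  # per-head information is lost here
    # One block set is returned for ALL heads; a block needed by only one head can still be dropped.
    return torch.topk(scores, k=min(top_k, scores.numel())).indices
```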
Why Per-Head Scheduling is Infeasible:
- Memory Layout: the GPU cache stores all heads together as `[block_size, kv_heads, head_dim]`
- FlashAttention: requires complete heads; partial heads cause dimension mismatches
- Block Granularity: If any head needs a block, the entire block (all heads) must be loaded
Policy Types:
- `FullAttentionPolicy`: `supports_prefill=True, supports_decode=True` - loads all blocks
- `QuestPolicy`: `supports_prefill=False, supports_decode=True` - decode-only Top-K selection
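A hypothetical illustration of how these flags gate policy selection per phase (not code from `policy.py`):

```python
def policy_for_phase(phase: str, sparse_policy, full_policy):
    """Use the sparse policy only in phases it supports; otherwise fall back to full attention.

    Hypothetical helper: only the supports_prefill / supports_decode flags come from the doc above.
    """
    supported = sparse_policy.supports_prefill if phase == "prefill" else sparse_policy.supports_decode
    return sparse_policy if supported else full_policy
```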
Architecture
Core Components
- LLMEngine (`llm_engine.py`): Main entry, runs the prefill-decode loop
- ModelRunner (`model_runner.py`): Loads weights, allocates KV cache, CUDA graphs
- Scheduler (`scheduler.py`): Two-phase scheduling (prefill → decode)
- BlockManager (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
- Attention (`layers/attention.py`): FlashAttention with chunked methods for CPU offload
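A conceptual sketch of how these components interact in the generate loop (method names are simplified assumptions, not the exact `llm_engine.py` API):

```python
def generate_sketch(engine, prompts, sampling_params):
    # Conceptual only: shows the prefill -> decode flow across components.
    engine.scheduler.add_requests(prompts, sampling_params)     # enqueue sequences
    while not engine.scheduler.is_finished():
        seqs, is_prefill = engine.scheduler.schedule()          # prefill first, then decode
        token_ids = engine.model_runner.run(seqs, is_prefill)   # KV blocks come from BlockManager
        engine.scheduler.postprocess(seqs, token_ids)           # append tokens, retire finished seqs
    return engine.scheduler.collect_outputs()
```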
PyTorch Hooks for Debugging
Hook Positions in Qwen3
```
decoder_layer
├── input_layernorm (RMSNorm)
├── self_attn (Qwen3Attention)          ← Hook here for attention I/O after o_proj
│   ├── q_proj → q_norm → RoPE
│   ├── k_proj → k_norm → RoPE
│   ├── v_proj
│   ├── attn (Attention)                ← Hook here for Q/K/V tensors
│   │   └── FlashAttention / SDPA
│   └── o_proj
├── post_attention_layernorm (RMSNorm)
└── mlp (Qwen3MLP)
```
Hook Types & Data Shapes
| Hook Position | Type | Captured Data |
|---|---|---|
| `self_attn` | post | `[batch, seq_len, hidden_size]` - after o_proj |
| `self_attn.attn` | pre | Q, K, V: `[seq_len, num_heads, head_dim]` - after RoPE |
| `self_attn.attn` | post | `[seq_len, num_heads, head_dim]` - before o_proj |
Example: Capture Attention Outputs
```python
storage = {}

def make_hook(layer_id: int, storage: dict):
    def hook(module, inputs, output):
        if isinstance(output, tuple):
            attn_output = output[0]
        else:
            attn_output = output
        # nanovllm shape: [num_tokens, hidden_size] -> add batch dim
        if attn_output.dim() == 2:
            attn_output = attn_output.unsqueeze(0)
        storage[layer_id] = attn_output.detach().clone()
    return hook

# Register hooks
hooks = []
for layer_idx, layer in enumerate(model.model.layers):
    hooks.append(layer.self_attn.register_forward_hook(make_hook(layer_idx, storage)))

# Run inference...

# Cleanup
for hook in hooks:
    hook.remove()
```
Reference Implementation
Key files:
- `tests/modeling_qwen3.py`: Reference Qwen3 implementation (torch + transformers only)
- `tests/test_needle_ref.py`: Reference needle test using the custom Qwen3
- `tests/test_needle.py`: Needle-in-haystack test for nanovllm
Common Pitfalls
- Shape mismatch: nanovllm uses `[num_tokens, ...]` while torch uses `[batch, seq_len, ...]`
- Hook position: `self_attn` captures after o_proj, `self_attn.attn` captures before o_proj
- Output format: nanovllm returns the tuple `(attn_output, None)`; handle it with `output[0]`
CPU Offload System
Ring Buffer Design
```
GPU Slots: [0] [1] [2] [3] ... (unified ring buffer)
Prefill:   slot = chunk_idx % N
Decode:    slot[0] = decode, slots[1:] = load previous chunks
```
Key Files: kvcache/offload_engine.py, kvcache/hybrid_manager.py
Memory Layout:
- GPU: `[num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]`
- CPU: `[num_layers, num_cpu_blocks, ...]` (pinned memory)
Key Methods:
- `load_to_slot_layer(slot, layer, cpu_block)`: Async H2D load
- `offload_slot_to_cpu(slot, cpu_block)`: Async D2H offload
- Per-slot per-layer CUDA events for fine-grained synchronization
Pipeline: N-way pipeline with dedicated streams for full compute-transfer overlap. Pipeline depth = N-1 (prefill), (N-1)/2 (decode).
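A minimal sketch of the ring-buffer layout and slot assignment described above (all sizes are illustrative; the real implementation is in `offload_engine.py` / `hybrid_manager.py`):

```python
import torch

# Illustrative sizes (not the defaults of any particular model)
num_layers, num_gpu_slots, num_cpu_blocks = 28, 4, 64
block_size, kv_heads, head_dim = 1024, 8, 128

k_cache_gpu = torch.empty(num_layers, num_gpu_slots, block_size, kv_heads, head_dim,
                          dtype=torch.float16, device="cuda")
k_cache_cpu = torch.empty(num_layers, num_cpu_blocks, block_size, kv_heads, head_dim,
                          dtype=torch.float16, pin_memory=True)  # pinned for async DMA

def prefill_slot(chunk_idx: int, num_slots: int = num_gpu_slots) -> int:
    """Prefill writes chunk i into GPU slot i % N, then offloads that slot to its CPU block."""
    return chunk_idx % num_slots

# Decode: slot 0 holds the block being decoded; slots 1..N-1 stream in previously offloaded chunks.
```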
Stream Architecture
```
Transfer Streams: [slot_0_stream] [slot_1_stream] ... [slot_N_stream]
                        ↓               ↓                   ↓
GPU Slots:          [slot_0]        [slot_1]    ...     [slot_N]
                        ↓               ↓                   ↓
Compute Stream:   ←←←←←←←←←←←← [dedicated compute stream] →→→→→→→→→→→→
```
Key Design Decisions:
- Per-slot transfer streams: Each GPU slot has its own CUDA stream for H2D transfers, enabling parallel loading
- Dedicated compute stream: Created with `torch.cuda.Stream()` (NOT `current_stream()`) to avoid implicit synchronization with the default stream
- CUDA Events: `ring_slot_ready` (transfer complete), `ring_slot_compute_done` (safe to overwrite)
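A minimal sketch of the stream/event wiring described above (the event and stream names mirror this doc; the function itself is an assumption, not `offload_engine.py` code):

```python
import torch

num_slots = 4
transfer_streams = [torch.cuda.Stream() for _ in range(num_slots)]  # one H2D stream per slot
compute_stream = torch.cuda.Stream()                                # dedicated, NOT current_stream()
ring_slot_ready = [torch.cuda.Event() for _ in range(num_slots)]         # transfer complete
ring_slot_compute_done = [torch.cuda.Event() for _ in range(num_slots)]  # slot safe to overwrite

def load_then_compute(slot, do_load, do_attention):
    with torch.cuda.stream(transfer_streams[slot]):
        transfer_streams[slot].wait_event(ring_slot_compute_done[slot])  # don't clobber data in use
        do_load(slot)                                                    # async H2D copy into the slot
        ring_slot_ready[slot].record(transfer_streams[slot])
    with torch.cuda.stream(compute_stream):
        compute_stream.wait_event(ring_slot_ready[slot])                 # wait only for this slot
        do_attention(slot)
        ring_slot_compute_done[slot].record(compute_stream)
```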
Scatter-Gather DMA (sgDMA) - INTEGRATED ✓
Problem & Solution
Problem: Strided CPU cache access k_cache_cpu[:, block_id] caused slow Device→Pageable transfers at ~1.4 GB/s instead of optimal ~24 GB/s pinned memory bandwidth.
Solution: Implemented cudaMemcpy2D via custom CUDA extension to handle strided layouts natively. Integration complete as of 2025-12-25.
Quick Start
```python
from nanovllm.comm import memcpy_2d_async

# Transfer block_id across all layers
spitch = num_blocks * features * dtype_size  # stride between layers
dpitch = features * dtype_size               # contiguous destination
width = features * dtype_size                # bytes per row
height = num_layers                          # number of rows
memcpy_2d_async(gpu_buf, cpu_cache[:, block_id], dpitch, spitch, width, height, "h2d", stream)
```
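For the KV cache layouts above, `features` is the per-block element count; a hedged sketch of deriving the pitch arguments (assuming the `[num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]` CPU layout and float16):

```python
# Illustrative sizes
num_layers, num_cpu_blocks = 28, 64
block_size, kv_heads, head_dim = 1024, 8, 128
dtype_size = 2  # bytes per float16 element

features = block_size * kv_heads * head_dim      # elements per (layer, block)
spitch = num_cpu_blocks * features * dtype_size  # bytes between consecutive layers in the CPU cache
dpitch = features * dtype_size                   # GPU slot rows are contiguous
width = features * dtype_size                    # bytes copied per row (one block)
height = num_layers                              # one row per layer
```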
Benchmark Performance (Synthetic, 256MB)
| Method | Bandwidth | Speedup |
|---|---|---|
| cudaMemcpy2D (sgDMA) | 24.95 GB/s | Baseline |
| PyTorch strided | 4.25 GB/s | 5.87x slower |
| PyTorch contiguous | 24.92 GB/s | Same |
Real-World Performance (A100, Attention Offload)
Measured from test_attention_offload.py profiling:
| Transfer Type | Count | Bandwidth | Previous | Speedup |
|---|---|---|---|---|
| Device→Pinned (D2H) | 416 | 21.49 GB/s | 1.40 GB/s | 15.35x |
| Pinned→Device (H2D) | 24,960 | 23.39 GB/s | N/A | N/A |
| Device→Pageable (D2H) | 0 | N/A | ~40 transfers | Eliminated |
Verification: All slow Device→Pageable transfers eliminated. System achieves near-optimal PCIe Gen3 x16 bandwidth.
Build: python setup.py build_ext --inplace
Files:
- `csrc/sgdma_kernel.cu`, `csrc/sgdma.cpp`: CUDA extension
- `nanovllm/comm/sgdma.py`: Python API
- `kvcache/offload_engine.py`: Integration (4 methods updated)
Integration Details
Modified methods in offload_engine.py:
- `load_to_slot_all_layers()`: H2D ring buffer load
- `offload_slot_to_cpu()`: D2H ring buffer offload
- `offload_decode_slot()`: D2H decode slot offload
- `load_cpu_blocks_to_gpu_slots_all_layers()`: Batch H2D load
Example replacement:
```python
# Before (slow, Device→Pageable fallback)
self.k_cache_gpu[:, slot].copy_(self.k_cache_cpu[:, cpu_block], non_blocking=True)

# After (fast, Device→Pinned via sgDMA)
memcpy_2d_async(
    self.k_cache_gpu[:, slot], self.k_cache_cpu[:, cpu_block],
    self.gpu_pitch, self.cpu_pitch, self.width, self.height,
    "h2d", stream=self.transfer_stream_main
)
```
Actual Impact: 15.35x faster D2H transfers, eliminates memory transfer bottleneck. Expected 2-3x overall prefill throughput improvement.
Online Softmax Merge - Triton Fused Kernel ✓
Problem & Solution
Problem: Original PyTorch implementation of merge_attention_outputs() launches 7 separate kernels per merge operation:
- `torch.maximum()` - max(lse1, lse2)
- `torch.exp()` (2x) - exp(lse1-max), exp(lse2-max)
- `transpose()` + `unsqueeze()` - reshape for broadcasting
- Accumulation (6x) - weighted sum operations
- Division - normalize output
- `torch.log()` - merge LSE
- `.to()` - type conversion
Profiling revealed: In ChunkedPrefill with 8 layers, these operations consumed 698 ms GPU time (vs FlashAttention 603 ms), becoming a major bottleneck.
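For reference, the math these operations implement is a standard log-sum-exp (online softmax) merge; a minimal PyTorch sketch, with assumed shapes `o: [tokens, heads, headdim]` and `lse: [tokens, heads]` (not the exact `merge_attention_outputs()` code):

```python
import torch

def merge_reference(o1, lse1, o2, lse2):
    """Merge two partial attention outputs using their log-sum-exp statistics."""
    max_lse = torch.maximum(lse1, lse2)
    exp1 = torch.exp(lse1 - max_lse)   # rescale factors relative to the running max
    exp2 = torch.exp(lse2 - max_lse)
    sum_exp = exp1 + exp2
    o = (o1 * exp1.unsqueeze(-1) + o2 * exp2.unsqueeze(-1)) / sum_exp.unsqueeze(-1)
    lse = max_lse + torch.log(sum_exp)  # merged LSE, reusable for the next chunk
    return o.to(o1.dtype), lse
```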
Solution: Implemented Triton fused kernels that combine all operations into 2 kernels. Integration complete as of 2025-12-25.
Implementation
File: nanovllm/kvcache/chunked_attention.py:278-408
Two Triton kernels replace all PyTorch operations:
```python
@triton.jit
def _merge_lse_kernel(...):
    """Fused: max + exp + log"""
    max_lse = tl.maximum(lse1, lse2)
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)
    lse_merged = max_lse + tl.log(exp1 + exp2)
    tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)

@triton.jit
def _merge_output_kernel(...):
    """Fused: broadcast + weighted sum + division"""
    # Load LSE, compute scaling factors
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)
    sum_exp = exp1 + exp2
    # Process headdim in chunks
    for d_offset in range(0, headdim, BLOCK_SIZE):
        o1_val = tl.load(o1_ptr + o_idx, mask=mask)
        o2_val = tl.load(o2_ptr + o_idx, mask=mask)
        o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
        tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
```
Performance Results
From test_attention_offload.py profiling (8 layers, 16K tokens, 16 chunks, 10 iterations):
| Metric | PyTorch (7 kernels) | Triton (2 kernels) | Speedup |
|---|---|---|---|
| GPU time (8 layers) | 698 ms | 160.7 ms | 4.3x |
| Per-layer time | 87.3 ms | 20.1 ms | 4.3x |
| Avg per merge | 56 µs | 12.9 µs | 4.3x |
| Kernel launches | 10,920 | 3,120 | 71% reduction |
Breakdown (per-layer, 1,560 merges):
- `_merge_output_kernel`: 126.9 ms / 8 = 15.9 ms/layer (avg 10.2 µs/call)
- `_merge_lse_kernel`: 33.8 ms / 8 = 4.2 ms/layer (avg 2.7 µs/call)
Overall ChunkedPrefill Impact
GPU time distribution (test_attention_offload.py):
| Component | Time (ms) | Percentage |
|---|---|---|
| FlashAttention | 603.2 | 74.8% |
| Triton Merge | 160.7 | 19.9% |
| Other | 42.1 | 5.3% |
| Total | 806.0 | 100% |
If using PyTorch merge (estimated):
- Total GPU time: ~1,343 ms
- Overall speedup with Triton: 1.67x
Key Files
- `nanovllm/kvcache/chunked_attention.py`: Triton kernels + merge function
Known Issues and Fixes
Partial Last Block Bug (FIXED ✓)
Problem: When prefill token count is not an exact multiple of block_size, decode outputs garbage.
Root Cause: _chunked_decode_attention calculated last_block_valid_tokens using len(seq) - 1, which increases during decode. But CPU blocks are fixed after prefill!
```python
# BUG: len(seq) increases each decode step
total_prefill_tokens = len(seq) - 1                          # Wrong!
last_block_valid_tokens = total_prefill_tokens % block_size  # Reads garbage from CPU
```

Fix: Cache the original prefill length and read it via `HybridKVCacheManager.get_prefill_len()`:

```python
# CORRECT: Use cached prefill length
total_prefill_tokens = kvcache_manager.get_prefill_len(seq)  # Fixed value
```
Files Modified:
- `nanovllm/kvcache/hybrid_manager.py`: Added `_prefill_len` dict and `get_prefill_len()` method
- `nanovllm/layers/attention.py`: Use `get_prefill_len()` instead of `len(seq) - 1`
Block Size 4096 Race Condition (FIXED ✓)
Problem: block_size=4096 with multiple chunks produced index_copy_(): index out of bounds CUDA error during Chunk 2 processing.
Root Cause: Race condition between default stream and compute stream. In _prepare_chunked_offload_chunk(), slot_mapping tensor was created with non_blocking=True H2D transfer on the default stream. However, store_kvcache runs on compute_stream. Without synchronization, compute_stream could use slot_mapping before its transfer completed, causing corrupted indices.
Fix (in attention.py):
```python
if is_chunked_offload:
    compute_stream = context.kvcache_manager.offload_engine.compute_stream
    if k_cache.numel() and v_cache.numel():
        # CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
        compute_stream.wait_stream(torch.cuda.default_stream())
        with torch.cuda.stream(compute_stream):
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
```
Tested block sizes: 512, 1024, 4096, 8192 - all pass.
Configuration
| Parameter | Default | Notes |
|---|---|---|
| `kvcache_block_size` | 1024 | Tokens per block (4096 now works after the race condition fix) |
| `max_num_batched_tokens` | 16384 | Set equal to `max_model_len` for long context |
| `gpu_memory_utilization` | 0.9 | GPU memory fraction |
| `enable_cpu_offload` | False | Enable for long context |
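A hedged usage sketch that wires these parameters together (assuming the `LLM` entry point and that these keyword names pass through to the config; check `example.py` for the canonical invocation):

```python
from nanovllm import LLM, SamplingParams

llm = LLM(
    "Qwen/Qwen3-0.6B",
    max_model_len=32768,
    max_num_batched_tokens=32768,   # set equal to max_model_len for long context
    kvcache_block_size=1024,
    gpu_memory_utilization=0.9,
    enable_cpu_offload=True,        # enable for long-context inference
)
outputs = llm.generate(["Hello"], SamplingParams(temperature=0.6, max_tokens=64))
```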
Benchmarking
Files: bench.py (GPU), bench_offload.py (CPU offload), bench_vllm.py (comparison)
Common Issues:
- `max_num_batched_tokens < max_model_len`: Set them equal for long context
- CUDA graph dimension mismatch: Ensure `input_len + output_len <= max_model_len`
- RoPE out of bounds: Check the model's `max_position_embeddings` in config.json
Model Limits:
- Qwen3-0.6B/4B: 40960 tokens
- Qwen2.5-7B-Instruct-1M: 1048576 tokens
Performance (Qwen3-0.6B):
- GPU: ~18k tok/s (prefill), ~100 tok/s (decode)
- CPU Offload (16K): ~14k tok/s (prefill)
- CPU Offload (32K): ~13k tok/s (prefill)
Performance Summary
Completed Optimizations ✓
- sgDMA Integration (2025-12-25)
  - Eliminated Device→Pageable transfers
  - Achieved 21-23 GB/s bandwidth (near PCIe limit)
  - 15.35x speedup on memory transfers
- Triton Fused Merge Kernel (2025-12-25)
  - Reduced 7 PyTorch kernels → 2 Triton kernels
  - 4.3x speedup on merge operations
  - 1.67x overall ChunkedPrefill speedup
- N-way Pipeline with Dedicated Streams (2025-12-25)
  - Per-slot transfer streams for parallel H2D across slots
  - Dedicated compute stream (avoids CUDA default stream implicit sync)
  - N-way pipeline using all available slots (not just 2-slot double buffering)
  - 2.0x improvement: 7.2k → 14.1k tok/s (16K tokens prefill)
Current Performance Bottlenecks
From profiling (test_attention_offload.py, 8 layers, 16K tokens):
| Component | GPU Time | Percentage | Optimization Potential |
|---|---|---|---|
| FlashAttention | 603 ms | 74.8% | ⚠️ Main bottleneck |
| Triton Merge | 161 ms | 19.9% | ✓ Optimized |
| Other | 42 ms | 5.3% | Minor |
Future Optimization Directions
- FlashAttention Optimization (highest priority)
  - Current: 74.8% of GPU time
  - Potential: Custom FlashAttention kernel for the chunked case
  - Expected: 1.5-2x additional speedup
- Pipeline Optimization ✓ COMPLETED
  - Better overlap between compute and memory transfer
  - Multi-stream execution
  - See: N-way Pipeline with Dedicated Streams above
- Alternative to sgDMA (lower priority, PyTorch-only)
  - Reorganize cache layout: `[num_cpu_blocks, num_layers, ...]` instead of `[num_layers, num_cpu_blocks, ...]`
  - Trade-off: extensive refactoring vs the minimal sgDMA approach
  - Same performance as sgDMA (~24 GB/s)
Author: Zijie Tian