# CLAUDE.md
This file provides guidance to Claude Code when working with this repository.
## Overview
Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports Qwen3 models with CPU offload for long-context inference.
## Architecture

### Core Components
- **LLMEngine** (`llm_engine.py`): Main entry point, runs the prefill-decode loop
- **ModelRunner** (`model_runner.py`): Loads weights, allocates KV cache, captures CUDA graphs
- **Scheduler** (`scheduler.py`): Two-phase scheduling (prefill → decode)
- **BlockManager** (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
- **Attention** (`layers/attention.py`): FlashAttention with chunked methods for CPU offload
## CPU Offload System

### Ring Buffer Design
```
GPU Slots: [0] [1] [2] [3] ... (unified ring buffer)
Prefill:   slot = chunk_idx % N
Decode:    slot[0] = decode, slots[1:] = load previous chunks
```
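A minimal illustration of the slot arithmetic above (the real bookkeeping lives in `kvcache/offload_engine.py` and `kvcache/hybrid_manager.py`; `N` and the helper names here are purely for exposition):

```python
# Illustrative only: how chunks map onto the unified ring buffer of N GPU slots.
N = 4  # number of GPU slots (assumed for the example)

def prefill_slot(chunk_idx: int) -> int:
    """Prefill chunks cycle through the ring buffer round-robin."""
    return chunk_idx % N

def decode_slot_roles() -> dict:
    """During decode, slot 0 holds the decode KV; slots 1..N-1 stage
    previously offloaded chunks loaded back from CPU."""
    return {"decode": 0, "load_previous_chunks": list(range(1, N))}
```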
**Key Files**: `kvcache/offload_engine.py`, `kvcache/hybrid_manager.py`
**Memory Layout:**
- GPU: `[num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]`
- CPU: `[num_layers, num_cpu_blocks, ...]` (pinned memory)
**Key Methods:**
- `load_to_slot_layer(slot, layer, cpu_block)`: Async H2D load
- `offload_slot_to_cpu(slot, cpu_block)`: Async D2H offload
- Per-slot, per-layer CUDA events for fine-grained synchronization

**Pipeline**: N-way pipeline with dedicated streams for full compute-transfer overlap. Pipeline depth = N-1 (prefill), (N-1)/2 (decode).
### Stream Architecture
```
Transfer Streams: [slot_0_stream] [slot_1_stream] ... [slot_N_stream]
                        ↓               ↓                   ↓
GPU Slots:           [slot_0]        [slot_1]     ...   [slot_N]
                        ↓               ↓                   ↓
Compute Stream:   ←←←←←←←←←←←← [dedicated compute stream] →→→→→→→→→→→→
```
**Key Design Decisions:**
- **Per-slot transfer streams**: Each GPU slot has its own CUDA stream for H2D transfers, enabling parallel loading
- **Dedicated compute stream**: Created with `torch.cuda.Stream()` (NOT `current_stream()`) to avoid implicit synchronization with the default stream
- **CUDA events**: `ring_slot_ready` (transfer complete), `ring_slot_compute_done` (safe to overwrite)
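A self-contained sketch of this stream/event pattern, with dummy shapes and a stand-in compute op; the event variables only mirror the roles of `ring_slot_ready` / `ring_slot_compute_done` and are not the repository's objects:

```python
import torch

num_slots, num_chunks, block_elems = 4, 8, 1 << 20
cpu_chunks = [torch.randn(block_elems, pin_memory=True) for _ in range(num_chunks)]
gpu_slots = [torch.empty(block_elems, device="cuda") for _ in range(num_slots)]

transfer_streams = [torch.cuda.Stream() for _ in range(num_slots)]  # per-slot H2D streams
compute_stream = torch.cuda.Stream()           # dedicated, NOT torch.cuda.current_stream()
slot_ready = [torch.cuda.Event() for _ in range(num_slots)]  # transfer complete
slot_done = [torch.cuda.Event() for _ in range(num_slots)]   # safe to overwrite

for chunk_idx in range(num_chunks):
    slot = chunk_idx % num_slots
    with torch.cuda.stream(transfer_streams[slot]):
        transfer_streams[slot].wait_event(slot_done[slot])   # don't clobber in-flight compute
        gpu_slots[slot].copy_(cpu_chunks[chunk_idx], non_blocking=True)  # async H2D (pinned)
        slot_ready[slot].record()
    with torch.cuda.stream(compute_stream):
        compute_stream.wait_event(slot_ready[slot])           # wait only for this slot's load
        _ = gpu_slots[slot].sum()                             # stand-in for chunked attention
        slot_done[slot].record()

torch.cuda.synchronize()
```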
## Scatter-Gather DMA (sgDMA) - INTEGRATED ✓

### Problem & Solution
**Problem**: Strided CPU cache access (`k_cache_cpu[:, block_id]`) caused slow Device→Pageable transfers at ~1.4 GB/s instead of the ~24 GB/s available from pinned memory.

**Solution**: Implemented `cudaMemcpy2D` via a custom CUDA extension to handle strided layouts natively. Integration complete as of 2025-12-25.
### Quick Start
```python
from nanovllm.comm import memcpy_2d_async

# Transfer block_id across all layers
spitch = num_blocks * features * dtype_size  # source stride between layers (bytes)
dpitch = features * dtype_size               # contiguous destination pitch (bytes)
width = features * dtype_size                # bytes per row
height = num_layers                          # number of rows
memcpy_2d_async(gpu_buf, cpu_cache[:, block_id], dpitch, spitch, width, height, "h2d", stream)
```
### Benchmark Performance (Synthetic, 256 MB)
| Method | Bandwidth | Relative to sgDMA |
|---|---|---|
| cudaMemcpy2D (sgDMA) | 24.95 GB/s | Baseline |
| PyTorch strided | 4.25 GB/s | 5.87x slower |
| PyTorch contiguous | 24.92 GB/s | Same |
### Real-World Performance (A100, Attention Offload)

Measured from `test_attention_offload.py` profiling:
| Transfer Type | Count | Bandwidth | Previous | Speedup |
|---|---|---|---|---|
| Device→Pinned (D2H) | 416 | 21.49 GB/s | 1.40 GB/s | 15.35x |
| Pinned→Device (H2D) | 24,960 | 23.39 GB/s | N/A | N/A |
| Device→Pageable (D2H) | 0 | N/A | ~40 transfers | Eliminated |
**Verification**: All slow Device→Pageable transfers eliminated. System achieves near-optimal PCIe Gen3 x16 bandwidth.

**Build**: `python setup.py build_ext --inplace`
**Files:**
- `csrc/sgdma_kernel.cu`, `csrc/sgdma.cpp`: CUDA extension
- `nanovllm/comm/sgdma.py`: Python API
- `tests/test_sgdma.py`: Standalone benchmark
- `kvcache/offload_engine.py`: Integration (4 methods updated)
### Integration Details

Modified methods in `offload_engine.py`:
- `load_to_slot_all_layers()`: H2D ring buffer load
- `offload_slot_to_cpu()`: D2H ring buffer offload
- `offload_decode_slot()`: D2H decode slot offload
- `load_cpu_blocks_to_gpu_slots_all_layers()`: Batch H2D load
Example replacement:
# Before (slow, Device→Pageable fallback)
self.k_cache_gpu[:, slot].copy_(self.k_cache_cpu[:, cpu_block], non_blocking=True)
# After (fast, Device→Pinned via sgDMA)
memcpy_2d_async(
self.k_cache_gpu[:, slot], self.k_cache_cpu[:, cpu_block],
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
"h2d", stream=self.transfer_stream_main
)
**Actual Impact**: 15.35x faster D2H transfers, eliminating the memory-transfer bottleneck. Expected 2-3x overall prefill throughput improvement.
## Online Softmax Merge - Triton Fused Kernel ✓

### Problem & Solution
**Problem**: The original PyTorch implementation of `merge_attention_outputs()` launches 7 separate kernels per merge operation:
- `torch.maximum()` - max(lse1, lse2)
- `torch.exp()` (2x) - exp(lse1 - max), exp(lse2 - max)
- `transpose()` + `unsqueeze()` - reshape for broadcasting
- Accumulation (6x) - weighted sum operations
- Division - normalize output
- `torch.log()` - merge LSE
- `.to()` - type conversion
Profiling revealed that in ChunkedPrefill with 8 layers, these operations consumed 698 ms of GPU time (vs. 603 ms for FlashAttention itself), making the merge a major bottleneck.

**Solution**: Implemented Triton fused kernels that combine all of these operations into 2 kernels. Integration complete as of 2025-12-25.
### Implementation

**File**: `nanovllm/kvcache/chunked_attention.py:278-408`
Two Triton kernels replace all PyTorch operations:
```python
@triton.jit
def _merge_lse_kernel(...):
    """Fused: max + exp + log"""
    max_lse = tl.maximum(lse1, lse2)
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)
    lse_merged = max_lse + tl.log(exp1 + exp2)
    tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)


@triton.jit
def _merge_output_kernel(...):
    """Fused: broadcast + weighted sum + division"""
    # Load LSE, compute scaling factors
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)
    sum_exp = exp1 + exp2
    # Process headdim in chunks
    for d_offset in range(0, headdim, BLOCK_SIZE):
        o1_val = tl.load(o1_ptr + o_idx, mask=mask)
        o2_val = tl.load(o2_ptr + o_idx, mask=mask)
        o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
        tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
```
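For reference, a numerically equivalent PyTorch merge (a minimal sketch written for this document, not the repository's original 7-kernel implementation; the FlashAttention-style layouts in the docstring are assumptions):

```python
import torch

def merge_attention_outputs_ref(o1, lse1, o2, lse2):
    """Merge two partial attention results via online softmax.

    Assumed layouts (FlashAttention-style):
      o:   [batch, seqlen, nheads, headdim]
      lse: [batch, nheads, seqlen]
    """
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)           # [batch, nheads, seqlen]
    w2 = torch.exp(lse2 - max_lse)
    lse_merged = max_lse + torch.log(w1 + w2)
    # Reshape weights to broadcast over the output layout.
    w1 = w1.transpose(1, 2).unsqueeze(-1)    # [batch, seqlen, nheads, 1]
    w2 = w2.transpose(1, 2).unsqueeze(-1)
    o_merged = (o1.float() * w1 + o2.float() * w2) / (w1 + w2)
    return o_merged.to(o1.dtype), lse_merged
```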
### Performance Results

From `test_attention_offload.py` profiling (8 layers, 16K tokens, 16 chunks, 10 iterations):
| Metric | PyTorch (7 kernels) | Triton (2 kernels) | Speedup |
|---|---|---|---|
| GPU time (8 layers) | 698 ms | 160.7 ms | 4.3x |
| Per-layer time | 87.3 ms | 20.1 ms | 4.3x |
| Avg per merge | 56 µs | 12.9 µs | 4.3x |
| Kernel launches | 10,920 | 3,120 | 71% reduction |
**Breakdown** (per layer, 1,560 merges):
- `_merge_output_kernel`: 126.9 ms / 8 = 15.9 ms/layer (avg 10.2 µs/call)
- `_merge_lse_kernel`: 33.8 ms / 8 = 4.2 ms/layer (avg 2.7 µs/call)
### Overall ChunkedPrefill Impact

GPU time distribution (`test_attention_offload.py`):
| Component | Time (ms) | Percentage |
|---|---|---|
| FlashAttention | 603.2 | 74.8% |
| Triton Merge | 160.7 | 19.9% |
| Other | 42.1 | 5.3% |
| Total | 806.0 | 100% |
If using PyTorch merge (estimated):
- Total GPU time: ~1,343 ms
- Overall speedup with Triton: 1.67x
### Correctness Verification

**Test**: `tests/test_chunked_attention.py`
- 12 test cases (6 configs × 2 dtypes)
- All tests PASS with max error < 0.01
- float16: max_diff=0.000488, mean_diff~0.00001
- bfloat16: max_diff=0.003906, mean_diff~0.0001
### Key Files

- `nanovllm/kvcache/chunked_attention.py`: Triton kernels + merge function
- `tests/test_chunked_attention.py`: Correctness tests
- `tests/test_attention_offload.py`: Performance profiling
## Configuration
| Parameter | Default | Notes |
|---|---|---|
| `kvcache_block_size` | 4096 | Tokens per block |
| `max_num_batched_tokens` | 16384 | Set equal to `max_model_len` for long context |
| `gpu_memory_utilization` | 0.9 | GPU memory fraction |
| `enable_cpu_offload` | False | Enable for long context |
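A hedged usage sketch of these options, assuming the vLLM-style `LLM` / `SamplingParams` entry points and that the table's parameters are accepted as constructor keyword arguments (the exact signature should be checked against `llm_engine.py`):

```python
from nanovllm import LLM, SamplingParams  # import path assumed

llm = LLM(
    "Qwen/Qwen3-0.6B",              # model path / HF id (assumed)
    kvcache_block_size=4096,
    max_num_batched_tokens=16384,   # set equal to max_model_len for long context
    gpu_memory_utilization=0.9,
    enable_cpu_offload=True,        # needed for long-context runs
)
outputs = llm.generate(["Hello"], SamplingParams(temperature=0.6, max_tokens=64))
```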
## Benchmarking

**Files**: `bench.py` (GPU), `bench_offload.py` (CPU offload), `bench_vllm.py` (comparison)
**Common Issues:**
- `max_num_batched_tokens < max_model_len`: Set them equal for long context
- CUDA graph dimension mismatch: Ensure `input_len + output_len <= max_model_len`
- RoPE out of bounds: Check the model's `max_position_embeddings` in `config.json`
**Model Limits:**
- Qwen3-0.6B/4B: 40960 tokens
- Qwen2.5-7B-Instruct-1M: 1048576 tokens
**Performance (Qwen3-0.6B):**
- GPU: ~18k tok/s (prefill), ~100 tok/s (decode)
- CPU Offload (16K): ~14k tok/s (prefill)
- CPU Offload (32K): ~13k tok/s (prefill)
## Performance Summary

### Completed Optimizations ✓
- **sgDMA Integration** (2025-12-25)
  - Eliminated Device→Pageable transfers
  - Achieved 21-23 GB/s bandwidth (near the PCIe limit)
  - 15.35x speedup on memory transfers
- **Triton Fused Merge Kernel** (2025-12-25)
  - Reduced 7 PyTorch kernels → 2 Triton kernels
  - 4.3x speedup on merge operations
  - 1.67x overall ChunkedPrefill speedup
- **N-way Pipeline with Dedicated Streams** (2025-12-25)
  - Per-slot transfer streams for parallel H2D across slots
  - Dedicated compute stream (avoids implicit sync with the CUDA default stream)
  - N-way pipeline using all available slots (not just 2-slot double buffering)
  - 2.0x improvement: 7.2k → 14.1k tok/s (16K-token prefill)
### Current Performance Bottlenecks

From profiling (`test_attention_offload.py`, 8 layers, 16K tokens):
| Component | GPU Time | Percentage | Optimization Potential |
|---|---|---|---|
| FlashAttention | 603 ms | 74.8% | ⚠️ Main bottleneck |
| Triton Merge | 161 ms | 19.9% | ✓ Optimized |
| Other | 42 ms | 5.3% | Minor |
### Future Optimization Directions
- **FlashAttention Optimization** (highest priority)
  - Current: 74.8% of GPU time
  - Potential: Custom FlashAttention kernel for the chunked case
  - Expected: 1.5-2x additional speedup
- **Pipeline Optimization** ✓ COMPLETED
  - Better overlap between compute and memory transfer
  - Multi-stream execution
  - See: N-way Pipeline with Dedicated Streams above
- **Alternative to sgDMA** (lower priority, PyTorch-only)
  - Reorganize cache layout: `[num_cpu_blocks, num_layers, ...]` instead of `[num_layers, num_cpu_blocks, ...]` (see the layout sketch below)
  - Trade-off: Extensive refactoring vs. the minimal sgDMA approach
  - Same performance as sgDMA (~24 GB/s)
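A quick way to see the layout trade-off (illustrative dummy sizes, not the repository's tensors):

```python
import torch

num_layers, num_cpu_blocks, block_elems = 4, 8, 1024  # dummy sizes

# Current layout: a per-block slice spans all layers and is strided.
layer_major = torch.empty(num_layers, num_cpu_blocks, block_elems, pin_memory=True)
print(layer_major[:, 0].is_contiguous())   # False -> the strided access that triggered the slow path

# Alternative layout: a per-block slice is one contiguous pinned region.
block_major = torch.empty(num_cpu_blocks, num_layers, block_elems, pin_memory=True)
print(block_major[0].is_contiguous())      # True -> a single flat async copy_() suffices
```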
Author: Zijie Tian