diff --git a/CLAUDE.md b/CLAUDE.md
index a2d5ea8..b3514a6 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -110,10 +110,10 @@ Block C: both heads moderately need (+2, +2) → avg = +2 → selected
 ### Core Components
 
 - **LLMEngine** (`llm_engine.py`): Main entry, runs prefill-decode loop
-- **ModelRunner** (`model_runner.py`): Loads weights, allocates KV cache, CUDA graphs
+- **ModelRunner** (`model_runner.py`): Loads weights, allocates KV cache, CUDA graphs, layer-wise offload
 - **Scheduler** (`scheduler.py`): Two-phase scheduling (prefill → decode)
 - **BlockManager** (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
-- **Attention** (`layers/attention.py`): FlashAttention with chunked methods for CPU offload
+- **Attention** (`layers/attention.py`): FlashAttention for standard inference
 
 ## PyTorch Hooks for Debugging
 
@@ -183,248 +183,142 @@ Key files:
 2. **Hook position**: `self_attn` captures after o_proj, `self_attn.attn` captures before o_proj
 3. **Output format**: nanovllm returns tuple `(attn_output, None)`, handle with `output[0]`
 
-## CPU Offload System
+## Layer-wise CPU Offload System
 
-### Ring Buffer Design
+### Design Philosophy
+
+Unlike chunked prefill (which processes chunks across all layers), **layer-wise offload** processes the entire sequence through one layer at a time:
 
 ```
-GPU Slots: [0] [1] [2] [3] ... (unified ring buffer)
-Prefill: slot = chunk_idx % N
-Decode: slot[0] = decode, slots[1:] = load previous chunks
+Layer 0: [full sequence] → compute → offload K,V to CPU
+Layer 1: [full sequence] → compute → offload K,V to CPU
+...
+Layer N: [full sequence] → compute → offload K,V to CPU
 ```
 
-**Key Files**: `kvcache/offload_engine.py`, `kvcache/hybrid_manager.py`
-
-**Memory Layout**:
-- GPU: `[num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]`
-- CPU: `[num_layers, num_cpu_blocks, ...]` (pinned memory)
-
-**Key Methods**:
-- `load_to_slot_layer(slot, layer, cpu_block)`: Async H2D load
-- `offload_slot_to_cpu(slot, cpu_block)`: Async D2H offload
-- Per-slot per-layer CUDA events for fine-grained synchronization
-
-**Pipeline**: N-way pipeline with dedicated streams for full compute-transfer overlap. Pipeline depth = N-1 (prefill), (N-1)/2 (decode).
-
-### Stream Architecture
-
-```
-Transfer Streams: [slot_0_stream] [slot_1_stream] ... [slot_N_stream]
-                        ↓               ↓                   ↓
-GPU Slots:          [slot_0]        [slot_1]     ...    [slot_N]
-                        ↓               ↓                   ↓
-Compute Stream:   ←←←←←←←←←←←← [dedicated compute stream] →→→→→→→→→→→→
-```
-
-**Key Design Decisions**:
-- **Per-slot transfer streams**: Each GPU slot has its own CUDA stream for H2D transfers, enabling parallel loading
-- **Dedicated compute stream**: Created with `torch.cuda.Stream()` (NOT `current_stream()`) to avoid implicit synchronization with default stream
-- **CUDA Events**: `ring_slot_ready` (transfer complete), `ring_slot_compute_done` (safe to overwrite)
-
-## Scatter-Gather DMA (sgDMA) - INTEGRATED ✓
-
-### Problem & Solution
-
-**Problem**: Strided CPU cache access `k_cache_cpu[:, block_id]` caused slow Device→Pageable transfers at ~1.4 GB/s instead of optimal ~24 GB/s pinned memory bandwidth.
-
-**Solution**: Implemented `cudaMemcpy2D` via custom CUDA extension to handle strided layouts natively. **Integration complete** as of 2025-12-25.
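+A minimal sketch of the loop-order inversion, using hypothetical `process_layer` / `offload_kv` stand-ins (not the real API):
+
+```python
+# Chunked prefill (old design): chunks outer, layers inner; every chunk
+# revisits all layers, so KV crosses PCIe once per chunk per layer.
+# Layer-wise offload (this design): layers outer, full sequence inner;
+# each layer is computed once, then its entire KV leaves the GPU.
+
+def process_layer(layer_id: int, hidden):   # stand-in for QKV + attention + MLP
+    return hidden
+
+def offload_kv(layer_id: int) -> None:      # stand-in for the sync D2H copy
+    pass
+
+hidden, num_layers = "embedded sequence", 4
+for layer_id in range(num_layers):
+    hidden = process_layer(layer_id, hidden)  # entire sequence at once
+    offload_kv(layer_id)                      # exactly one offload per layer
+```
+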
-
-### Quick Start
-
-```python
-from nanovllm.comm import memcpy_2d_async
-
-# Transfer block_id across all layers
-spitch = num_blocks * features * dtype_size  # stride between layers
-dpitch = features * dtype_size               # contiguous destination
-width = features * dtype_size                # bytes per row
-height = num_layers                          # number of rows
-
-memcpy_2d_async(gpu_buf, cpu_cache[:, block_id], dpitch, spitch, width, height, "h2d", stream)
-```
-
-### Benchmark Performance (Synthetic, 256MB)
-
-| Method | Bandwidth | Speedup |
-|--------|-----------|---------|
-| **cudaMemcpy2D (sgDMA)** | **24.95 GB/s** | **Baseline** |
-| PyTorch strided | 4.25 GB/s | **5.87x slower** |
-| PyTorch contiguous | 24.92 GB/s | Same |
-
-### Real-World Performance (A100, Attention Offload)
-
-**Measured from `test_attention_offload.py` profiling**:
-
-| Transfer Type | Count | Bandwidth | Previous | Speedup |
-|---------------|-------|-----------|----------|---------|
-| **Device→Pinned (D2H)** | 416 | **21.49 GB/s** | 1.40 GB/s | **15.35x** |
-| **Pinned→Device (H2D)** | 24,960 | **23.39 GB/s** | N/A | N/A |
-| Device→Pageable (D2H) | **0** | N/A | ~40 transfers | **Eliminated** |
-
-**Verification**: All slow Device→Pageable transfers eliminated. System achieves near-optimal PCIe Gen3 x16 bandwidth.
-
-**Build**: `python setup.py build_ext --inplace`
-
-**Files**:
-- `csrc/sgdma_kernel.cu`, `csrc/sgdma.cpp`: CUDA extension
-- `nanovllm/comm/sgdma.py`: Python API
-- `kvcache/offload_engine.py`: Integration (4 methods updated)
-
-### Integration Details
-
-**Modified methods in `offload_engine.py`**:
-- `load_to_slot_all_layers()`: H2D ring buffer load
-- `offload_slot_to_cpu()`: D2H ring buffer offload
-- `offload_decode_slot()`: D2H decode slot offload
-- `load_cpu_blocks_to_gpu_slots_all_layers()`: Batch H2D load
-
-**Example replacement**:
-```python
-# Before (slow, Device→Pageable fallback)
-self.k_cache_gpu[:, slot].copy_(self.k_cache_cpu[:, cpu_block], non_blocking=True)
-
-# After (fast, Device→Pinned via sgDMA)
-memcpy_2d_async(
-    self.k_cache_gpu[:, slot], self.k_cache_cpu[:, cpu_block],
-    self.gpu_pitch, self.cpu_pitch, self.width, self.height,
-    "h2d", stream=self.transfer_stream_main
-)
-```
-
-**Actual Impact**: 15.35x faster D2H transfers, eliminates memory transfer bottleneck. Expected 2-3x overall prefill throughput improvement.
-
-## Online Softmax Merge - Triton Fused Kernel ✓
-
-### Problem & Solution
-
-**Problem**: Original PyTorch implementation of `merge_attention_outputs()` launches 7 separate kernels per merge operation:
-1. `torch.maximum()` - max(lse1, lse2)
-2. `torch.exp()` (2x) - exp(lse1-max), exp(lse2-max)
-3. `transpose()` + `unsqueeze()` - reshape for broadcasting
-4. Accumulation (6x) - weighted sum operations
-5. Division - normalize output
-6. `torch.log()` - merge LSE
-7. `.to()` - type conversion
-
-**Profiling revealed**: In ChunkedPrefill with 8 layers, these operations consumed **698 ms** GPU time (vs FlashAttention 603 ms), becoming a major bottleneck.
-
-**Solution**: Implemented Triton fused kernels that combine all operations into 2 kernels. **Integration complete** as of 2025-12-25.
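+Rough sizing of the peak-memory win, as a back-of-the-envelope sketch (the 36-layer count for Qwen3-4B is an assumption here; read the real value from `config.json`):
+
+```python
+# Per-token KV bytes for ONE layer: kv_heads × head_dim × fp16 × (K and V)
+kv_bytes_per_token = 8 * 128 * 2 * 2              # = 4 KB/token
+context_len = 128 * 1024
+num_layers = 36                                    # assumed for Qwen3-4B
+
+per_layer_kv = context_len * kv_bytes_per_token    # 512 MB resident at once
+all_layers_kv = num_layers * per_layer_kv          # 18 GB if kept fully on GPU
+print(f"{per_layer_kv / 2**20:.0f} MB vs {all_layers_kv / 2**30:.0f} GB")
+```
+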
-
-### Implementation
-
-**File**: `nanovllm/kvcache/chunked_attention.py:278-408`
-
-Two Triton kernels replace all PyTorch operations:
-
-```python
-@triton.jit
-def _merge_lse_kernel(...):
-    """Fused: max + exp + log"""
-    max_lse = tl.maximum(lse1, lse2)
-    exp1 = tl.exp(lse1 - max_lse)
-    exp2 = tl.exp(lse2 - max_lse)
-    lse_merged = max_lse + tl.log(exp1 + exp2)
-    tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
-
-@triton.jit
-def _merge_output_kernel(...):
-    """Fused: broadcast + weighted sum + division"""
-    # Load LSE, compute scaling factors
-    exp1 = tl.exp(lse1 - max_lse)
-    exp2 = tl.exp(lse2 - max_lse)
-    sum_exp = exp1 + exp2
-
-    # Process headdim in chunks
-    for d_offset in range(0, headdim, BLOCK_SIZE):
-        o1_val = tl.load(o1_ptr + o_idx, mask=mask)
-        o2_val = tl.load(o2_ptr + o_idx, mask=mask)
-        o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
-        tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
-```
-
-### Performance Results
-
-**From `test_attention_offload.py` profiling** (8 layers, 16K tokens, 16 chunks, 10 iterations):
-
-| Metric | PyTorch (7 kernels) | Triton (2 kernels) | Speedup |
-|--------|---------------------|---------------------|---------|
-| **GPU time (8 layers)** | 698 ms | 160.7 ms | **4.3x** |
-| **Per-layer time** | 87.3 ms | 20.1 ms | **4.3x** |
-| **Avg per merge** | 56 µs | 12.9 µs | **4.3x** |
-| **Kernel launches** | 10,920 | 3,120 | **71% reduction** |
-
-**Breakdown** (per-layer, 1,560 merges):
-- `_merge_output_kernel`: 126.9 ms / 8 = 15.9 ms/layer (avg 10.2 µs/call)
-- `_merge_lse_kernel`: 33.8 ms / 8 = 4.2 ms/layer (avg 2.7 µs/call)
-
-### Overall ChunkedPrefill Impact
-
-**GPU time distribution** (test_attention_offload.py):
-
-| Component | Time (ms) | Percentage |
-|-----------|-----------|------------|
-| FlashAttention | 603.2 | 74.8% |
-| Triton Merge | 160.7 | 19.9% |
-| Other | 42.1 | 5.3% |
-| **Total** | **806.0** | **100%** |
-
-**If using PyTorch merge** (estimated):
-- Total GPU time: ~1,343 ms
-- **Overall speedup with Triton**: 1.67x
+**Benefits**:
+- Supports MInference sparse attention (requires full KV access per layer)
+- Simpler memory management (one layer's KV in GPU at a time)
+- Peak GPU memory = one layer's KV cache + attention workspace
 
 ### Key Files
 
-- `nanovllm/kvcache/chunked_attention.py`: Triton kernels + merge function
+- `nanovllm/engine/model_runner.py`: Main implementation (`run_layerwise_offload_prefill`, `run_layerwise_offload_decode`)
+- `nanovllm/kvcache/hybrid_manager.py`: CPU block management helpers
+- `nanovllm/kvcache/offload_engine.py`: CPU/GPU cache storage
 
-## Known Issues and Fixes
-
-### Partial Last Block Bug (FIXED ✓)
-
-**Problem**: When prefill token count is not an exact multiple of `block_size`, decode outputs garbage.
-
-**Root Cause**: `_chunked_decode_attention` calculated `last_block_valid_tokens` using `len(seq) - 1`, which increases during decode. But CPU blocks are fixed after prefill!
+### Memory Layout
 
+**CPU Cache** (pinned memory):
 ```python
-# BUG: len(seq) increases each decode step
-total_prefill_tokens = len(seq) - 1  # Wrong!
-last_block_valid_tokens = total_prefill_tokens % block_size  # Reads garbage from CPU
+k_cache_cpu: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
+v_cache_cpu: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
 ```
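+
+Given this layout, the synchronous offload helper called from the prefill flow below can be sketched as follows (shapes assumed `[total_tokens, kv_heads, head_dim]` and `block_size` assumed to be an `offload_engine` attribute; the real `_offload_layer_kv_to_cpu_sync` in `model_runner.py` may differ):
+
+```python
+def _offload_layer_kv_to_cpu_sync(self, layer_id, k, v, cpu_block_ids, total_tokens):
+    """Copy one layer's K/V into pinned CPU blocks, block_size tokens at a time."""
+    block_size = self.offload_engine.block_size   # assumed attribute name
+    for i, cpu_block_id in enumerate(cpu_block_ids):
+        start = i * block_size
+        end = min(start + block_size, total_tokens)
+        if start >= end:
+            break
+        n = end - start
+        # Blocking copies on purpose (see "Critical Implementation Details" below)
+        self.offload_engine.k_cache_cpu[layer_id, cpu_block_id, :n].copy_(k[start:end])
+        self.offload_engine.v_cache_cpu[layer_id, cpu_block_id, :n].copy_(v[start:end])
+```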
 
-**Fix**: Cache original prefill length in `HybridKVCacheManager.get_prefill_len()`:
+**Per-layer KV size** (Qwen3-4B: 8 kv_heads × 128 head_dim × 2 bytes × 2 for K+V = 4KB/token):
+
+| Context Length | KV per Layer |
+|----------------|--------------|
+| 128K tokens | 512 MB |
+| 256K tokens | 1 GB |
+| 512K tokens | 2 GB |
+| 1M tokens | 4 GB |
+
+### Prefill Flow
 
 ```python
-# CORRECT: Use cached prefill length
-total_prefill_tokens = kvcache_manager.get_prefill_len(seq)  # Fixed value
+def run_layerwise_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
+    # 1. Embedding
+    hidden_states = self.model.model.embed_tokens(input_ids)
+
+    # 2. Process each layer
+    for layer_id in range(num_layers):
+        # QKV projection + norms + RoPE
+        q = apply_rotary_pos_emb(q_proj(hidden_states), cos, sin)
+        k = apply_rotary_pos_emb(k_proj(hidden_states), cos, sin)
+        v = v_proj(hidden_states)
+
+        # Full FlashAttention (entire sequence)
+        attn_out = flash_attn_varlen_func(q, k, v, cu_seqlens, max_seqlen, causal=True)
+
+        # MLP
+        hidden_states = mlp(attn_out + residual)
+
+        # Synchronous offload to CPU (CRITICAL: must be sync to avoid memory reuse bugs)
+        self._offload_layer_kv_to_cpu_sync(layer_id, k, v, cpu_block_ids, total_tokens)
+
+    # 3. Final norm + sampling
+    return sampled_tokens
 ```
 
-**Files Modified**:
-- `nanovllm/kvcache/hybrid_manager.py`: Added `_prefill_len` dict and `get_prefill_len()` method
-- `nanovllm/layers/attention.py`: Use `get_prefill_len()` instead of `len(seq) - 1`
+### Decode Flow
 
-### Block Size 4096 Race Condition (FIXED ✓)
-
-**Problem**: `block_size=4096` with multiple chunks produced `index_copy_(): index out of bounds` CUDA error during Chunk 2 processing.
-
-**Root Cause**: Race condition between default stream and compute stream. In `_prepare_chunked_offload_chunk()`, `slot_mapping` tensor was created with `non_blocking=True` H2D transfer on the default stream. However, `store_kvcache` runs on `compute_stream`. Without synchronization, `compute_stream` could use `slot_mapping` before its transfer completed, causing corrupted indices.
-
-**Fix** (in `attention.py`):
 ```python
-if is_chunked_offload:
-    compute_stream = context.kvcache_manager.offload_engine.compute_stream
-    if k_cache.numel() and v_cache.numel():
-        # CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
-        compute_stream.wait_stream(torch.cuda.default_stream())
-        with torch.cuda.stream(compute_stream):
-            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+def run_layerwise_offload_decode(self, seqs: list[Sequence]) -> list[int]:
+    # For each layer:
+    for layer_id in range(num_layers):
+        # 1. Load this layer's prefilled KV from CPU and stitch the blocks together
+        k_blocks, v_blocks = [], []
+        for cpu_block_id in cpu_block_table:
+            k_blocks.append(offload_engine.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens].to("cuda"))
+            v_blocks.append(offload_engine.v_cache_cpu[layer_id, cpu_block_id, :valid_tokens].to("cuda"))
+        k_prefill = torch.cat(k_blocks, dim=0)
+        v_prefill = torch.cat(v_blocks, dim=0)
+
+        # 2. Compute new Q,K,V for the current token
+        q_new = apply_rotary_pos_emb(q_proj(hidden_states), cos, sin)
+        k_new = apply_rotary_pos_emb(k_proj(hidden_states), cos, sin)
+        v_new = v_proj(hidden_states)
+
+        # 3. Concatenate and compute attention
+        k_full = torch.cat([k_prefill, k_new], dim=0)
+        v_full = torch.cat([v_prefill, v_new], dim=0)
+        attn_out = flash_attn_varlen_func(q_new, k_full, v_full, ..., causal=False)
+        # Note: causal=False because the single query token must attend to ALL keys
 ```
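+
+A standalone illustration of the single-token varlen call (head counts, `P`, and dtype are made up; only the `cu_seqlens` pattern matters here):
+
+```python
+import torch
+from flash_attn import flash_attn_varlen_func
+
+P, H, H_KV, D = 4096, 16, 8, 128   # prefill length, query heads, KV heads, head dim
+q = torch.randn(1, H, D, device="cuda", dtype=torch.float16)         # one decode token
+k = torch.randn(P + 1, H_KV, D, device="cuda", dtype=torch.float16)  # prefill + new
+v = torch.randn(P + 1, H_KV, D, device="cuda", dtype=torch.float16)
+cu_q = torch.tensor([0, 1], device="cuda", dtype=torch.int32)        # one query row
+cu_k = torch.tensor([0, P + 1], device="cuda", dtype=torch.int32)    # P+1 key rows
+
+# causal=False: the lone query row must see every one of the P+1 keys
+out = flash_attn_varlen_func(q, k, v, cu_q, cu_k,
+                             max_seqlen_q=1, max_seqlen_k=P + 1, causal=False)
+```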
 
-**Tested block sizes**: 512, 1024, 4096, 8192 - all pass.
+### Critical Implementation Details
+
+**1. Synchronous Offload Required**
+
+Async offload with `non_blocking=True` causes memory reuse bugs:
+```python
+# BUG: PyTorch may reuse k,v GPU memory before the async copy completes
+offload_engine.k_cache_cpu[layer_id, block_id].copy_(k[start:end], non_blocking=True)
+
+# CORRECT: Synchronous copy ensures data integrity
+offload_engine.k_cache_cpu[layer_id, block_id, :size].copy_(k[start:end])  # sync
+```
+
+**2. Decode Attention: causal=False**
+
+During decode there is a single query token, and it must attend to every key (all prefill tokens plus itself), so no causal mask is applied:
+```python
+# Prefill: causal=True (each token attends only to earlier tokens and itself)
+attn_out = flash_attn_varlen_func(..., causal=True)
+
+# Decode: causal=False (the lone query attends to all prefill keys plus itself)
+attn_out = flash_attn_varlen_func(..., causal=False)
+```
+
+### Helper Methods in HybridKVCacheManager
+
+```python
+# Get all CPU blocks for a sequence
+cpu_blocks = manager.get_all_cpu_blocks(seq)  # List[int]
+
+# Get only prefilled (offloaded) CPU blocks
+prefilled_blocks = manager.get_prefilled_cpu_blocks(seq)  # List[int]
+
+# Get cached prefill length (doesn't change during decode)
+prefill_len = manager.get_prefill_len(seq)  # int
+
+# Get decode start position
+decode_pos = manager.get_decode_start_pos(seq)  # int
+```
 
 ## Configuration
 
 | Parameter | Default | Notes |
 |-----------|---------|-------|
-| `kvcache_block_size` | 1024 | Tokens per block (4096 now works after race condition fix) |
+| `kvcache_block_size` | 4096 | Tokens per block |
 | `max_num_batched_tokens` | 16384 | Set = max_model_len for long context |
 | `gpu_memory_utilization` | 0.9 | GPU memory fraction |
 | `enable_cpu_offload` | False | Enable for long context |
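+
+Minimal usage sketch (assuming these config fields are accepted as `LLM` constructor kwargs, as in upstream nano-vllm; the model path is a placeholder):
+
+```python
+from nanovllm import LLM, SamplingParams
+
+llm = LLM(
+    "/path/to/Qwen3-4B",              # placeholder path
+    enforce_eager=True,
+    max_num_batched_tokens=131072,    # set = max_model_len for long context
+    kvcache_block_size=4096,
+    gpu_memory_utilization=0.9,
+    enable_cpu_offload=True,          # turn on layer-wise offload
+)
+sp = SamplingParams(temperature=0.6, max_tokens=64)
+out = llm.generate(["<a very long prompt>"], sp)
+```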
@@ -447,53 +341,6 @@
 - CPU Offload (16K): ~14k tok/s (prefill)
 - CPU Offload (32K): ~13k tok/s (prefill)
 
-## Performance Summary
-
-### Completed Optimizations ✓
-
-1. **sgDMA Integration** (2025-12-25)
-   - Eliminated Device→Pageable transfers
-   - Achieved 21-23 GB/s bandwidth (near PCIe limit)
-   - 15.35x speedup on memory transfers
-
-2. **Triton Fused Merge Kernel** (2025-12-25)
-   - Reduced 7 PyTorch kernels → 2 Triton kernels
-   - 4.3x speedup on merge operations
-   - 1.67x overall ChunkedPrefill speedup
-
-3. **N-way Pipeline with Dedicated Streams** (2025-12-25)
-   - Per-slot transfer streams for parallel H2D across slots
-   - Dedicated compute stream (avoids CUDA default stream implicit sync)
-   - N-way pipeline using all available slots (not just 2-slot double buffering)
-   - **2.0x improvement**: 7.2k → 14.1k tok/s (16K tokens prefill)
-
-### Current Performance Bottlenecks
-
-**From profiling** (`test_attention_offload.py`, 8 layers, 16K tokens):
-
-| Component | GPU Time | Percentage | Optimization Potential |
-|-----------|----------|------------|------------------------|
-| FlashAttention | 603 ms | 74.8% | ⚠️ Main bottleneck |
-| Triton Merge | 161 ms | 19.9% | ✓ Optimized |
-| Other | 42 ms | 5.3% | Minor |
-
-### Future Optimization Directions
-
-1. **FlashAttention Optimization** (highest priority)
-   - Current: 74.8% of GPU time
-   - Potential: Custom FlashAttention kernel for chunked case
-   - Expected: 1.5-2x additional speedup
-
-2. ~~**Pipeline Optimization**~~ ✓ COMPLETED
-   - ~~Better overlap between compute and memory transfer~~
-   - ~~Multi-stream execution~~
-   - See: N-way Pipeline with Dedicated Streams above
-
-3. **Alternative to sgDMA** (lower priority, PyTorch-only)
-   - Reorganize cache layout: `[num_cpu_blocks, num_layers, ...]` instead of `[num_layers, num_cpu_blocks, ...]`
-   - Trade-off: Extensive refactoring vs minimal sgDMA approach
-   - Same performance as sgDMA (~24 GB/s)
-
 ---
 
 **Author**: Zijie Tian