[docs] Replace chunked prefill docs with layer-wise offload strategy

Remove all chunked prefill related documentation (ring buffer, sgDMA,
Triton merge kernels, known issues) and replace with layer-wise offload
system documentation including:
- Design philosophy and benefits
- Memory layout and per-layer KV size table
- Prefill and decode flow pseudocode
- Critical implementation details (sync offload, causal=False for decode)
- Helper methods in HybridKVCacheManager

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Zijie Tian
2026-01-08 05:39:26 +08:00
parent bbbfd1e7da
commit b5c0ef3b7a

377
CLAUDE.md

@@ -110,10 +110,10 @@ Block C: both heads moderately need (+2, +2) → avg = +2 → selected
### Core Components
- **LLMEngine** (`llm_engine.py`): Main entry, runs prefill-decode loop
- **ModelRunner** (`model_runner.py`): Loads weights, allocates KV cache, CUDA graphs
- **ModelRunner** (`model_runner.py`): Loads weights, allocates KV cache, CUDA graphs, layer-wise offload
- **Scheduler** (`scheduler.py`): Two-phase scheduling (prefill → decode)
- **BlockManager** (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
- **Attention** (`layers/attention.py`): FlashAttention with chunked methods for CPU offload
- **Attention** (`layers/attention.py`): FlashAttention for standard inference
## PyTorch Hooks for Debugging
@@ -183,248 +183,142 @@ Key files:
2. **Hook position**: `self_attn` captures after o_proj, `self_attn.attn` captures before o_proj
3. **Output format**: nanovllm returns tuple `(attn_output, None)`, handle with `output[0]`
## CPU Offload System
## Layer-wise CPU Offload System
### Ring Buffer Design
### Design Philosophy
Unlike chunked prefill, which runs each chunk through all layers before moving to the next chunk, **layer-wise offload** runs the entire sequence through one layer at a time and offloads that layer's K,V to CPU before starting the next layer:
```
GPU Slots: [0] [1] [2] [3] ... (unified ring buffer)
Prefill: slot = chunk_idx % N
Decode: slot[0] = decode, slots[1:] = load previous chunks
Layer 0: [full sequence] → compute → offload K,V to CPU
Layer 1: [full sequence] → compute → offload K,V to CPU
...
Layer N: [full sequence] → compute → offload K,V to CPU
```
**Key Files**: `kvcache/offload_engine.py`, `kvcache/hybrid_manager.py`
**Memory Layout**:
- GPU: `[num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]`
- CPU: `[num_layers, num_cpu_blocks, ...]` (pinned memory)
**Key Methods**:
- `load_to_slot_layer(slot, layer, cpu_block)`: Async H2D load
- `offload_slot_to_cpu(slot, cpu_block)`: Async D2H offload
- Per-slot per-layer CUDA events for fine-grained synchronization
**Pipeline**: N-way pipeline with dedicated streams for full compute-transfer overlap. Pipeline depth = N-1 (prefill), (N-1)/2 (decode).
### Stream Architecture
```
Transfer Streams: [slot_0_stream] [slot_1_stream] ... [slot_N_stream]
↓ ↓ ↓
GPU Slots: [slot_0] [slot_1] ... [slot_N]
↓ ↓ ↓
Compute Stream: ←←←←←←←←←←←← [dedicated compute stream] →→→→→→→→→→→→
```
**Key Design Decisions**:
- **Per-slot transfer streams**: Each GPU slot has its own CUDA stream for H2D transfers, enabling parallel loading
- **Dedicated compute stream**: Created with `torch.cuda.Stream()` (NOT `current_stream()`) to avoid implicit synchronization with default stream
- **CUDA Events**: `ring_slot_ready` (transfer complete), `ring_slot_compute_done` (safe to overwrite)
## Scatter-Gather DMA (sgDMA) - INTEGRATED ✓
### Problem & Solution
**Problem**: Strided CPU cache access `k_cache_cpu[:, block_id]` caused slow Device→Pageable transfers at ~1.4 GB/s instead of optimal ~24 GB/s pinned memory bandwidth.
**Solution**: Implemented `cudaMemcpy2D` via custom CUDA extension to handle strided layouts natively. **Integration complete** as of 2025-12-25.
### Quick Start
```python
from nanovllm.comm import memcpy_2d_async
# Transfer block_id across all layers
spitch = num_blocks * features * dtype_size # stride between layers
dpitch = features * dtype_size # contiguous destination
width = features * dtype_size # bytes per row
height = num_layers # number of rows
memcpy_2d_async(gpu_buf, cpu_cache[:, block_id], dpitch, spitch, width, height, "h2d", stream)
```
### Benchmark Performance (Synthetic, 256MB)
| Method | Bandwidth | Speedup |
|--------|-----------|---------|
| **cudaMemcpy2D (sgDMA)** | **24.95 GB/s** | **Baseline** |
| PyTorch strided | 4.25 GB/s | **5.87x slower** |
| PyTorch contiguous | 24.92 GB/s | Same |
### Real-World Performance (A100, Attention Offload)
**Measured from `test_attention_offload.py` profiling**:
| Transfer Type | Count | Bandwidth | Previous | Speedup |
|---------------|-------|-----------|----------|---------|
| **Device→Pinned (D2H)** | 416 | **21.49 GB/s** | 1.40 GB/s | **15.35x** |
| **Pinned→Device (H2D)** | 24,960 | **23.39 GB/s** | N/A | N/A |
| Device→Pageable (D2H) | **0** | N/A | ~40 transfers | **Eliminated** |
**Verification**: All slow Device→Pageable transfers eliminated. System achieves near-optimal pinned-memory PCIe bandwidth (21-23 GB/s is above what PCIe Gen3 x16 can deliver, so the link is presumably Gen4 x16).
**Build**: `python setup.py build_ext --inplace`
**Files**:
- `csrc/sgdma_kernel.cu`, `csrc/sgdma.cpp`: CUDA extension
- `nanovllm/comm/sgdma.py`: Python API
- `kvcache/offload_engine.py`: Integration (4 methods updated)
### Integration Details
**Modified methods in `offload_engine.py`**:
- `load_to_slot_all_layers()`: H2D ring buffer load
- `offload_slot_to_cpu()`: D2H ring buffer offload
- `offload_decode_slot()`: D2H decode slot offload
- `load_cpu_blocks_to_gpu_slots_all_layers()`: Batch H2D load
**Example replacement**:
```python
# Before (slow, Device→Pageable fallback)
self.k_cache_gpu[:, slot].copy_(self.k_cache_cpu[:, cpu_block], non_blocking=True)
# After (fast, Device→Pinned via sgDMA)
memcpy_2d_async(
self.k_cache_gpu[:, slot], self.k_cache_cpu[:, cpu_block],
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
"h2d", stream=self.transfer_stream_main
)
```
**Actual Impact**: 15.35x faster D2H transfers, eliminates memory transfer bottleneck. Expected 2-3x overall prefill throughput improvement.
## Online Softmax Merge - Triton Fused Kernel ✓
### Problem & Solution
**Problem**: Original PyTorch implementation of `merge_attention_outputs()` launches 7 separate kernels per merge operation:
1. `torch.maximum()` - max(lse1, lse2)
2. `torch.exp()` (2x) - exp(lse1-max), exp(lse2-max)
3. `transpose()` + `unsqueeze()` - reshape for broadcasting
4. Accumulation (6x) - weighted sum operations
5. Division - normalize output
6. `torch.log()` - merge LSE
7. `.to()` - type conversion
**Profiling revealed**: In ChunkedPrefill with 8 layers, these operations consumed **698 ms** GPU time (vs FlashAttention 603 ms), becoming a major bottleneck.
**Solution**: Implemented Triton fused kernels that combine all operations into 2 kernels. **Integration complete** as of 2025-12-25.
### Implementation
**File**: `nanovllm/kvcache/chunked_attention.py:278-408`
Two Triton kernels replace all PyTorch operations:
```python
@triton.jit
def _merge_lse_kernel(...):
    """Fused: max + exp + log"""
    max_lse = tl.maximum(lse1, lse2)
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)
    lse_merged = max_lse + tl.log(exp1 + exp2)
    tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)

@triton.jit
def _merge_output_kernel(...):
    """Fused: broadcast + weighted sum + division"""
    # Load LSE, compute scaling factors
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)
    sum_exp = exp1 + exp2
    # Process headdim in chunks
    for d_offset in range(0, headdim, BLOCK_SIZE):
        o1_val = tl.load(o1_ptr + o_idx, mask=mask)
        o2_val = tl.load(o2_ptr + o_idx, mask=mask)
        o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
        tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
```
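The merge rule used by the two kernels can be checked on the CPU. The following is a minimal pure-Python sketch, not the project's code: `attention_chunk` and `merge` are hypothetical helper names, the usual `1/sqrt(head_dim)` scaling is omitted for brevity, and the inputs are made-up numbers. It shows that merging two per-chunk attention results via their log-sum-exp (LSE) values reproduces attention over the full key set.

```python
import math

def attention_chunk(q, keys, values):
    """Softmax attention of one query over a chunk of keys; returns (output, lse)."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, m + math.log(total)

def merge(o1, lse1, o2, lse2):
    """Same arithmetic as the kernels above: rescale by exp(lse - max), then normalize."""
    max_lse = max(lse1, lse2)
    exp1 = math.exp(lse1 - max_lse)
    exp2 = math.exp(lse2 - max_lse)
    sum_exp = exp1 + exp2
    merged = [(a * exp1 + b * exp2) / sum_exp for a, b in zip(o1, o2)]
    return merged, max_lse + math.log(sum_exp)

# Illustrative data: one query, four keys split into two chunks of two.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

full_out, full_lse = attention_chunk(q, keys, values)
o1, lse1 = attention_chunk(q, keys[:2], values[:2])
o2, lse2 = attention_chunk(q, keys[2:], values[2:])
merged_out, merged_lse = merge(o1, lse1, o2, lse2)
```

Subtracting `max_lse` before exponentiating keeps both exponents at or below zero, which is why the merge stays numerically stable for long sequences.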
### Performance Results
**From `test_attention_offload.py` profiling** (8 layers, 16K tokens, 16 chunks, 10 iterations):
| Metric | PyTorch (7 kernels) | Triton (2 kernels) | Speedup |
|--------|---------------------|---------------------|---------|
| **GPU time (8 layers)** | 698 ms | 160.7 ms | **4.3x** |
| **Per-layer time** | 87.3 ms | 20.1 ms | **4.3x** |
| **Avg per merge** | 56 µs | 12.9 µs | **4.3x** |
| **Kernel launches** | 10,920 | 3,120 | **71% reduction** |
**Breakdown** (per-layer, 1,560 merges):
- `_merge_output_kernel`: 126.9 ms / 8 = 15.9 ms/layer (avg 10.2 µs/call)
- `_merge_lse_kernel`: 33.8 ms / 8 = 4.2 ms/layer (avg 2.7 µs/call)
### Overall ChunkedPrefill Impact
**GPU time distribution** (test_attention_offload.py):
| Component | Time (ms) | Percentage |
|-----------|-----------|------------|
| FlashAttention | 603.2 | 74.8% |
| Triton Merge | 160.7 | 19.9% |
| Other | 42.1 | 5.3% |
| **Total** | **806.0** | **100%** |
**If using PyTorch merge** (estimated):
- Total GPU time: ~1,343 ms
- **Overall speedup with Triton**: 1.67x
**Benefits**:
- Supports MInference sparse attention (requires full KV access per layer)
- Simpler memory management (one layer's KV in GPU at a time)
- Peak GPU memory = one layer's KV cache + attention workspace
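
The peak-memory claim can be illustrated with a toy simulation. This is a sketch under stated assumptions: `simulate_layerwise_prefill` is a hypothetical function, sizes are abstract byte counts, and attention workspace is ignored. It only models the bookkeeping of "compute one layer, offload its K,V, free the GPU copy".

```python
def simulate_layerwise_prefill(num_layers, layer_kv_bytes):
    """Track GPU-resident KV bytes while layers are processed one at a time."""
    cpu_store = {}        # layer_id -> bytes held in the CPU cache
    gpu_resident = 0      # KV bytes currently on the GPU
    peak_gpu = 0
    for layer_id in range(num_layers):
        gpu_resident += layer_kv_bytes           # K,V produced for this layer
        peak_gpu = max(peak_gpu, gpu_resident)
        cpu_store[layer_id] = layer_kv_bytes     # offload finishes before the next layer
        gpu_resident -= layer_kv_bytes           # GPU copy can now be released
    return peak_gpu, sum(cpu_store.values())

peak_gpu, total_cpu = simulate_layerwise_prefill(num_layers=4, layer_kv_bytes=512)
```

The GPU peak stays at one layer's KV regardless of `num_layers`, while the CPU cache grows linearly with depth.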
### Key Files
- `nanovllm/kvcache/chunked_attention.py`: Triton kernels + merge function
- `nanovllm/engine/model_runner.py`: Main implementation (`run_layerwise_offload_prefill`, `run_layerwise_offload_decode`)
- `nanovllm/kvcache/hybrid_manager.py`: CPU block management helpers
- `nanovllm/kvcache/offload_engine.py`: CPU/GPU cache storage
## Known Issues and Fixes
### Partial Last Block Bug (FIXED ✓)
**Problem**: When prefill token count is not an exact multiple of `block_size`, decode outputs garbage.
**Root Cause**: `_chunked_decode_attention` calculated `last_block_valid_tokens` using `len(seq) - 1`, which increases during decode. But CPU blocks are fixed after prefill!
### Memory Layout
**CPU Cache** (pinned memory):
```python
# BUG: len(seq) increases each decode step
total_prefill_tokens = len(seq) - 1 # Wrong!
last_block_valid_tokens = total_prefill_tokens % block_size # Reads garbage from CPU
k_cache_cpu: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
v_cache_cpu: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
```
**Fix**: Cache original prefill length in `HybridKVCacheManager.get_prefill_len()`:
**Per-layer KV size** (Qwen3-4B: 8 kv_heads × 128 head_dim × 2 bytes × 2 for K+V = 4KB/token):
| Context Length | KV per Layer |
|----------------|--------------|
| 128K tokens | 512 MB |
| 256K tokens | 1 GB |
| 512K tokens | 2 GB |
| 1M tokens | 4 GB |
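
The table follows directly from the per-token formula. A small sketch that reproduces it, assuming the Qwen3-4B figures quoted above and binary units (1 MB = 2^20 bytes); the function names are illustrative only:

```python
def kv_bytes_per_token(kv_heads=8, head_dim=128, dtype_bytes=2):
    """Bytes of K plus V stored per token for a single layer."""
    return kv_heads * head_dim * dtype_bytes * 2  # final factor 2 covers K and V

def format_size(num_bytes):
    """Render a byte count as MB or GB using binary units."""
    mib = num_bytes / 2**20
    return f"{mib / 1024:g} GB" if mib >= 1024 else f"{mib:g} MB"

per_token = kv_bytes_per_token()
contexts = [("128K", 128 * 1024), ("256K", 256 * 1024),
            ("512K", 512 * 1024), ("1M", 1024 * 1024)]
rows = {label: format_size(tokens * per_token) for label, tokens in contexts}
```

Multiplying a row by the model's layer count gives the total CPU cache needed for that context length.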
### Prefill Flow
```python
# CORRECT: Use cached prefill length
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Fixed value
def run_layerwise_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
    # 1. Embedding
    hidden_states = self.model.model.embed_tokens(input_ids)
    # 2. Process each layer
    for layer_id in range(num_layers):
        # QKV projection + norms + RoPE
        q = apply_rotary_pos_emb(q_proj(hidden_states), cos, sin)
        k = apply_rotary_pos_emb(k_proj(hidden_states), cos, sin)
        v = v_proj(hidden_states)
        # Full FlashAttention (entire sequence); in prefill, q and k share cu_seqlens and max_seqlen
        attn_out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True)
        # MLP
        hidden_states = mlp(attn_out + residual)
        # Synchronous offload to CPU (CRITICAL: must be sync to avoid memory reuse bugs)
        self._offload_layer_kv_to_cpu_sync(layer_id, k, v, cpu_block_ids, total_tokens)
    # 3. Final norm + sampling
    return sampled_tokens
```
**Files Modified**:
- `nanovllm/kvcache/hybrid_manager.py`: Added `_prefill_len` dict and `get_prefill_len()` method
- `nanovllm/layers/attention.py`: Use `get_prefill_len()` instead of `len(seq) - 1`
### Decode Flow
### Block Size 4096 Race Condition (FIXED ✓)
**Problem**: `block_size=4096` with multiple chunks produced `index_copy_(): index out of bounds` CUDA error during Chunk 2 processing.
**Root Cause**: Race condition between default stream and compute stream. In `_prepare_chunked_offload_chunk()`, `slot_mapping` tensor was created with `non_blocking=True` H2D transfer on the default stream. However, `store_kvcache` runs on `compute_stream`. Without synchronization, `compute_stream` could use `slot_mapping` before its transfer completed, causing corrupted indices.
**Fix** (in `attention.py`):
```python
if is_chunked_offload:
    compute_stream = context.kvcache_manager.offload_engine.compute_stream
    if k_cache.numel() and v_cache.numel():
        # CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
        compute_stream.wait_stream(torch.cuda.default_stream())
        with torch.cuda.stream(compute_stream):
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
def run_layerwise_offload_decode(self, seqs: list[Sequence]) -> list[int]:
    # For each layer:
    for layer_id in range(num_layers):
        # 1. Load all prefilled KV from CPU (only the valid tokens of each block)
        k_blocks, v_blocks = [], []
        for block_idx, cpu_block_id in enumerate(cpu_block_table):
            k_blocks.append(offload_engine.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens].to("cuda"))
            v_blocks.append(offload_engine.v_cache_cpu[layer_id, cpu_block_id, :valid_tokens].to("cuda"))
        k_prefill = torch.cat(k_blocks, dim=0)
        v_prefill = torch.cat(v_blocks, dim=0)
        # 2. Compute new Q,K,V for current token
        q_new = apply_rotary_pos_emb(q_proj(hidden_states), cos, sin)
        k_new = apply_rotary_pos_emb(k_proj(hidden_states), cos, sin)
        v_new = v_proj(hidden_states)
        # 3. Concatenate and compute attention
        k_full = torch.cat([k_prefill, k_new], dim=0)
        v_full = torch.cat([v_prefill, v_new], dim=0)
        attn_out = flash_attn_varlen_func(q_new, k_full, v_full, ..., causal=False)
        # Note: causal=False because single query token should attend to ALL keys
```
**Tested block sizes**: 512, 1024, 4096, 8192 - all pass.
### Critical Implementation Details
**1. Synchronous Offload Required**
Async offload with `non_blocking=True` can cause memory reuse bugs, because the GPU memory backing `k` and `v` may be recycled before the copy to CPU has finished:
```python
# BUG: PyTorch may reuse k,v GPU memory before async copy completes
offload_engine.k_cache_cpu[layer_id, block_id].copy_(k[start:end], non_blocking=True)
# CORRECT: Synchronous copy ensures data integrity
offload_engine.k_cache_cpu[layer_id, block_id, :size].copy_(k[start:end]) # sync
```
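
The `k[start:end]` and `:size` slices in the snippet above come from cutting the sequence into block-sized ranges. A minimal sketch of that bookkeeping, assuming blocks are filled in order and only the last one may be partial; `block_slices` is a hypothetical helper, not a method of the offload engine:

```python
def block_slices(total_tokens, block_size, cpu_block_ids):
    """Map each CPU block to the token range it receives: (block_id, start, end, size)."""
    slices = []
    for i, block_id in enumerate(cpu_block_ids):
        start = i * block_size
        end = min(start + block_size, total_tokens)
        if start >= end:
            break  # remaining blocks are not needed for this sequence
        slices.append((block_id, start, end, end - start))
    return slices

# 10 tokens, block size 4, three CPU blocks with arbitrary ids.
plan = block_slices(total_tokens=10, block_size=4, cpu_block_ids=[7, 3, 9])
```

Each tuple corresponds to one synchronous copy of `k[start:end]` into the first `size` rows of the block `block_id`.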
**2. Decode Attention: causal=False**
During decode, the single query token must attend to ALL keys (not just preceding ones):
```python
# Prefill: causal=True (each token only attends to previous tokens)
attn_out = flash_attn_varlen_func(..., causal=True)
# Decode: causal=False (the single query attends to every preceding key plus its own)
attn_out = flash_attn_varlen_func(..., causal=False)
```
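
Why the flag matters with a single query can be seen from the mask shapes. The sketch below is plain Python, not FlashAttention code: `visible_keys` is a hypothetical helper, and the two alignment modes are illustrative conventions for a non-square causal mask. Which convention a given attention library and version applies is an assumption to verify against its documentation; passing `causal=False` avoids depending on it.

```python
def visible_keys(seqlen_q, seqlen_k, causal, align="top-left"):
    """For each query row, list the key indices that row may attend to."""
    rows = []
    for i in range(seqlen_q):
        if not causal:
            limit = seqlen_k - 1               # no mask: every key is visible
        elif align == "top-left":
            limit = i                          # query i sees keys 0..i
        else:                                  # "bottom-right"
            limit = i + seqlen_k - seqlen_q    # last query sees all keys
        rows.append([j for j in range(seqlen_k) if j <= limit])
    return rows

decode_full = visible_keys(1, 5, causal=False)
decode_top_left = visible_keys(1, 5, causal=True, align="top-left")
decode_bottom_right = visible_keys(1, 5, causal=True, align="bottom-right")
```

With a top-left aligned causal mask the lone decode query would see only the first key, which is the failure mode the `causal=False` setting rules out.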
### Helper Methods in HybridKVCacheManager
```python
# Get all CPU blocks for a sequence
cpu_blocks = manager.get_all_cpu_blocks(seq) # List[int]
# Get only prefilled (offloaded) CPU blocks
prefilled_blocks = manager.get_prefilled_cpu_blocks(seq) # List[int]
# Get cached prefill length (doesn't change during decode)
prefill_len = manager.get_prefill_len(seq) # int
# Get decode start position
decode_pos = manager.get_decode_start_pos(seq) # int
```
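
The reason `get_prefill_len` returns a cached value can be shown with a stand-in class. This is a hypothetical sketch, not the real `HybridKVCacheManager`: the class name, `record_prefill`, and `last_block_valid_tokens` are invented for illustration. It contrasts the cached length with a length recomputed from a sequence that keeps growing during decode.

```python
class PrefillLenCache:
    """Hypothetical stand-in that remembers each sequence's prefill length."""

    def __init__(self, block_size):
        self.block_size = block_size
        self._prefill_len = {}

    def record_prefill(self, seq_id, num_tokens):
        # Called once, when prefill finishes and CPU blocks are fixed.
        self._prefill_len[seq_id] = num_tokens

    def get_prefill_len(self, seq_id):
        return self._prefill_len[seq_id]

    def last_block_valid_tokens(self, seq_id):
        rem = self._prefill_len[seq_id] % self.block_size
        return rem if rem else self.block_size

cache = PrefillLenCache(block_size=4096)
tokens = list(range(5000))             # prompt of 5000 tokens
cache.record_prefill("seq0", len(tokens))
tokens.extend([1, 2, 3])               # three decode steps grow the sequence

cached_valid = cache.last_block_valid_tokens("seq0")
recomputed_valid = (len(tokens) - 1) % cache.block_size  # drifts with every step
```

The cached value keeps pointing at the 904 tokens actually written to the last CPU block, while the recomputed one moves past them as decoding proceeds.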
## Configuration
| Parameter | Default | Notes |
|-----------|---------|-------|
| `kvcache_block_size` | 1024 | Tokens per block (4096 now works after race condition fix) |
| `kvcache_block_size` | 4096 | Tokens per block |
| `max_num_batched_tokens` | 16384 | Set = max_model_len for long context |
| `gpu_memory_utilization` | 0.9 | GPU memory fraction |
| `enable_cpu_offload` | False | Enable for long context |
@@ -447,53 +341,6 @@ if is_chunked_offload:
- CPU Offload (16K): ~14k tok/s (prefill)
- CPU Offload (32K): ~13k tok/s (prefill)
## Performance Summary
### Completed Optimizations ✓
1. **sgDMA Integration** (2025-12-25)
   - Eliminated Device→Pageable transfers
   - Achieved 21-23 GB/s bandwidth (near PCIe limit)
   - 15.35x speedup on memory transfers
2. **Triton Fused Merge Kernel** (2025-12-25)
   - Reduced 7 PyTorch kernels → 2 Triton kernels
   - 4.3x speedup on merge operations
   - 1.67x overall ChunkedPrefill speedup
3. **N-way Pipeline with Dedicated Streams** (2025-12-25)
   - Per-slot transfer streams for parallel H2D across slots
   - Dedicated compute stream (avoids CUDA default stream implicit sync)
   - N-way pipeline using all available slots (not just 2-slot double buffering)
   - **2.0x improvement**: 7.2k → 14.1k tok/s (16K tokens prefill)
### Current Performance Bottlenecks
**From profiling** (`test_attention_offload.py`, 8 layers, 16K tokens):
| Component | GPU Time | Percentage | Optimization Potential |
|-----------|----------|------------|------------------------|
| FlashAttention | 603 ms | 74.8% | ⚠️ Main bottleneck |
| Triton Merge | 161 ms | 19.9% | ✓ Optimized |
| Other | 42 ms | 5.3% | Minor |
### Future Optimization Directions
1. **FlashAttention Optimization** (highest priority)
   - Current: 74.8% of GPU time
   - Potential: Custom FlashAttention kernel for chunked case
   - Expected: 1.5-2x additional speedup
2. ~~**Pipeline Optimization**~~ ✓ COMPLETED
   - ~~Better overlap between compute and memory transfer~~
   - ~~Multi-stream execution~~
   - See: N-way Pipeline with Dedicated Streams above
3. **Alternative to sgDMA** (lower priority, PyTorch-only)
   - Reorganize cache layout: `[num_cpu_blocks, num_layers, ...]` instead of `[num_layers, num_cpu_blocks, ...]`
   - Trade-off: Extensive refactoring vs minimal sgDMA approach
   - Same performance as sgDMA (~24 GB/s)
---
**Author**: Zijie Tian