# Task Plan: Enable CUDA Graphs for CPU Offload Mode

## Current Status

### Completed: Refactor Offload Decode to Use Standard Attention Path

**Problem solved**: The original offload decode (`run_layerwise_offload_decode`) bypassed `Attention.forward()` by manually calling attention components, which was inconsistent with the standard execution path.

**Solution implemented**: Refactored to use `layer.forward()`, which goes through:

```
Qwen3DecoderLayer.forward()
  → Qwen3Attention.forward()
    → Attention.forward()   ← Now properly used!
```

### Code Changes Made

**File**: `nanovllm/engine/model_runner.py`

1. **`run_layerwise_offload_decode()` (lines 841-991)**: completely refactored.

Before (bypassed Attention):

```python
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k_new, v_new = qkv.split(...)
q = layer.self_attn.q_norm(...)
k = layer.self_attn.k_norm(...)
q, k = layer.self_attn.rotary_emb(...)
attn_output = flash_attn_varlen_func(q, k_full, v_full, ...)  # Direct call!
hidden_states = layer.self_attn.o_proj(attn_output)
```

After (uses standard path):

```python
# Set the Attention module's cache to the ring buffer
attn_module.k_cache = offload_engine.layer_k_cache[buffer_idx:buffer_idx+1]
attn_module.v_cache = offload_engine.layer_v_cache[buffer_idx:buffer_idx+1]

# Set context for contiguous mode
set_context(is_prefill=False, slot_mapping=..., context_lens=..., block_tables=None)

# Standard layer forward: goes through Attention.forward()
hidden_states, residual = layer(positions, hidden_states, residual)
```

2. **`ModelRunner.__init__()` (lines 46-57)**: conditional CUDA graph capture:

```python
if not self.enforce_eager:
    if config.enable_cpu_offload:
        # TODO: Implement capture_offload_cudagraph()
        pass  # Temporarily use eager execution
    else:
        self.capture_cudagraph()
```

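The mechanism behind change 1 is attribute redirection: `Attention.forward()` uses whichever tensors its `k_cache`/`v_cache` attributes point at when it runs, so reassigning them per layer steers the standard path onto the ring buffer without touching model code. A toy pure-Python sketch of the idea (hypothetical names, plain lists standing in for cache tensors):

```python
class ToyAttention:
    """Stand-in for the real Attention module: forward() uses whatever
    cache its k_cache attribute points at when it runs."""
    def __init__(self):
        self.k_cache = None  # reassigned from outside before each layer

    def forward(self, token):
        self.k_cache.append(token)   # analogue of store_kvcache()
        return sum(self.k_cache)     # "attend" over the full cache (toy)

ring_buffers = [[], []]              # two reusable buffer slots
attn = ToyAttention()

outputs = []
for layer_id in range(4):
    buffer_idx = layer_id % len(ring_buffers)
    attn.k_cache = ring_buffers[buffer_idx]  # redirect the standard path
    outputs.append(attn.forward(layer_id + 1))

print(outputs)  # [1, 2, 4, 6]: layers sharing a slot see its prior contents
```

Unlike the real flow, this toy never reloads a buffer between layers, so contents accumulate; only the redirection pattern is the point.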
### Test Results

| Test | Mode | Status |
|------|------|--------|
| `test_needle.py --input-len 4096` | GPU-only | PASSED |
| `test_needle.py --input-len 4096 --enable-offload` | CPU offload | PASSED |

## Remaining Work: Implement Offload CUDA Graph

### Why Standard `capture_cudagraph()` Cannot Be Used

The standard capture function captures the PagedAttention decode path:

```python
# capture_cudagraph() sets up:
k_cache: [num_blocks, block_size, kv_heads, head_dim]  # PagedAttention format
block_tables: [...]  # Block indices for paged indexing
```

But offload mode uses a contiguous ring buffer:

```python
# Offload decode sets up:
k_cache: [1, max_seq_len, kv_heads, head_dim]  # Contiguous format
block_tables: None  # No paging
```

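The incompatibility comes down to addressing: PagedAttention must gather KV through a `block_tables` indirection, while the ring buffer is read with a plain prefix slice. A pure-Python sketch with toy data (names and shapes are illustrative only, not the real cache layout):

```python
BLOCK_SIZE = 4

# Paged layout: KV scattered across physical blocks; block_tables maps
# a sequence's logical block order to physical block ids
physical_blocks = {0: ["t4", "t5", "t6", "t7"], 1: ["t0", "t1", "t2", "t3"]}
block_tables = [1, 0]  # this sequence's logical -> physical mapping

def gather_paged(context_len):
    out = []
    num_logical = (context_len + BLOCK_SIZE - 1) // BLOCK_SIZE
    for logical in range(num_logical):
        block = physical_blocks[block_tables[logical]]
        remaining = context_len - logical * BLOCK_SIZE
        out.extend(block[:min(BLOCK_SIZE, remaining)])
    return out

# Contiguous layout: one ring buffer; no indirection needed
contiguous_k = ["t0", "t1", "t2", "t3", "t4", "t5", "t6", "t7"]

def gather_contiguous(context_len):
    return contiguous_k[:context_len]

print(gather_paged(6))       # ['t0', 't1', 't2', 't3', 't4', 't5']
print(gather_contiguous(6))  # same tokens, no block_tables lookup
```

A graph captured for the paged gather bakes in the indirection and block-shaped tensors, so it cannot replay against the contiguous buffer.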
### Implementation Plan for `capture_offload_cudagraph()`

#### Phase 1: Prepare Fixed-Address Tensors

```python
@torch.inference_mode()
def capture_offload_cudagraph(self):
    """Capture CUDA graphs for offload decode using the ring buffer."""
    offload_engine = self.kvcache_manager.offload_engine
    num_buffers = offload_engine.num_kv_buffers

    # Fixed-address tensors for graph capture
    input_ids = torch.zeros(1, dtype=torch.int64, device="cuda")
    positions = torch.zeros(1, dtype=torch.int64, device="cuda")
    slot_mapping = torch.zeros(1, dtype=torch.int32, device="cuda")
    context_lens = torch.zeros(1, dtype=torch.int32, device="cuda")

    self.offload_graphs = {}
    self.offload_graph_pool = None
```

#### Phase 2: Capture Per-Buffer Graphs

Since layer processing rotates through ring buffers (`layer_id % num_buffers`), we need a graph for each buffer slot:

```python
for buffer_idx in range(num_buffers):
    graph = torch.cuda.CUDAGraph()

    # Set the Attention cache to this buffer slot (fixed address)
    for layer in self.model.model.layers:
        layer.self_attn.attn.k_cache = offload_engine.layer_k_cache[buffer_idx:buffer_idx+1]
        layer.self_attn.attn.v_cache = offload_engine.layer_v_cache[buffer_idx:buffer_idx+1]

    # Set context
    set_context(is_prefill=False, slot_mapping=slot_mapping,
                context_lens=context_lens, block_tables=None)

    # Warmup
    hidden = self.model.model.embed_tokens(input_ids)
    residual = None
    for layer_id, layer in enumerate(self.model.model.layers):
        if layer_id % num_buffers == buffer_idx:
            hidden, residual = layer(positions, hidden, residual)

    # Capture
    with torch.cuda.graph(graph, self.offload_graph_pool):
        # Same operations as the warmup
        ...

    self.offload_graphs[buffer_idx] = graph
```

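A quick sanity check of the rotation above (pure Python): the per-buffer graphs partition the layers, and each decode step picks the graph matching `layer_id % num_buffers`.

```python
num_layers, num_buffers = 8, 2

# Layers covered by each per-buffer graph (mirrors the capture loop's filter)
graph_layers = {
    buffer_idx: [l for l in range(num_layers) if l % num_buffers == buffer_idx]
    for buffer_idx in range(num_buffers)
}
print(graph_layers)   # {0: [0, 2, 4, 6], 1: [1, 3, 5, 7]}

# At decode time, layer_id selects which graph to replay
replay_order = [layer_id % num_buffers for layer_id in range(num_layers)]
print(replay_order)   # [0, 1, 0, 1, 0, 1, 0, 1]

# Every layer is covered exactly once across the graphs
covered = sorted(l for layers in graph_layers.values() for l in layers)
assert covered == list(range(num_layers))
```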
#### Phase 3: Use Graphs in Decode

Modify `run_layerwise_offload_decode()` to replay the captured graphs:

```python
for layer_id in range(num_layers):
    current_buffer = layer_id % num_buffers

    # Wait for H2D load
    offload_engine.wait_buffer_load(current_buffer)

    # Copy decode buffer to ring buffer (same as current)
    ...

    # Update graph input tensors in place
    self.offload_graph_vars["positions"][0] = positions[0]
    self.offload_graph_vars["slot_mapping"][0] = context_len
    self.offload_graph_vars["context_lens"][0] = context_len + 1

    # Replay graph instead of eager forward
    self.offload_graphs[current_buffer].replay()

    # Copy new KV to decode buffer (same as current)
    ...
```

### Challenges and Considerations

| Challenge | Solution |
|-----------|----------|
| H2D transfers interleaved with compute | H2D happens outside the graph; only compute is captured |
| Different layers use different buffers | Capture per-buffer graphs; replay the correct one |
| Variable context length | Use the `cache_seqlens` parameter (fixed address, variable value) |
| Per-layer buffer rotation | Graph captures a single-layer forward; the loop stays in Python |

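The "fixed address, variable value" row states the central CUDA-graph contract: a captured graph bakes tensor addresses in at capture time, so inputs must be updated in place (never reallocated) before each replay. A language-level analogy in pure Python, with a closure standing in for the captured graph:

```python
def capture(buf, results):
    """'Capture' a computation: the closure binds this exact list object,
    the way a CUDA graph bakes in a tensor's device address."""
    def replay():
        results.append(buf[0] * 2)  # reads whatever value is there now
    return replay

graph_input = [0]       # the fixed "address"
results = []
replay = capture(graph_input, results)

for step in range(3):
    graph_input[0] = step + 1   # in-place update: address unchanged
    replay()
print(results)          # [2, 4, 6]: each replay saw the new value

graph_input = [99]      # REBINDING allocates a new object; the "graph"
replay()                # still reads the old one (last value: 3), just as a
print(results)          # CUDA graph would read a stale tensor -> [2, 4, 6, 6]
```

This is why Phase 1 allocates `positions`, `slot_mapping`, and `context_lens` once and Phase 3 writes into them with `tensor[0] = ...` rather than creating fresh tensors.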
### Alternative: Full-Decode Graph (More Complex)

Instead of per-layer graphs, capture the entire decode step:

1. Complete all H2D loads before the graph runs
2. A single graph covers all layers
3. Better kernel fusion, less CPU overhead
4. More complex to implement (buffer rotation must be handled inside the graph)

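Listing the two schedules side by side as event sequences clarifies the trade-off (pure Python sketch, event strings only): per-layer graphs interleave each layer's replay with the next layer's H2D load, while a full-decode graph requires every load to complete up front before its single replay.

```python
num_layers = 4

# Per-layer graphs: replay and the next H2D load alternate
per_layer_schedule = []
for layer_id in range(num_layers):
    per_layer_schedule.append(f"replay(layer {layer_id})")
    if layer_id + 1 < num_layers:
        per_layer_schedule.append(f"start_h2d(layer {layer_id + 1})")

# Full-decode graph: all H2D loads complete before the single replay
full_decode_schedule = [f"h2d(layer {l})" for l in range(num_layers)]
full_decode_schedule.append("replay(all layers)")

print(per_layer_schedule)
print(full_decode_schedule)
```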
## Implementation Phases

| Phase | Description | Status |
|-------|-------------|--------|
| Phase 0 | Refactor offload decode to use `Attention.forward()` | ✅ Completed |
| Phase 1 | Implement `capture_offload_cudagraph()` with per-buffer graphs | ⬜ Pending |
| Phase 2 | Modify `run_layerwise_offload_decode()` to use graphs | ⬜ Pending |
| Phase 3 | Test and benchmark | ⬜ Pending |
| Phase 4 | (Optional) Optimize to a full-decode graph | ⬜ Future |

## Architecture After Refactoring

```
┌──────────────────────────────────────────────────────────────────────
│ Offload Decode Flow (After Refactoring)
├──────────────────────────────────────────────────────────────────────
│ For each layer:
│   1. Wait for H2D load (ring buffer has prefill KV)
│   2. Copy decode buffer → ring buffer (at prefill_len offset)
│   3. Set Attention.k_cache = ring_buffer[buffer_idx]
│   4. Set context (slot_mapping, context_lens, block_tables=None)
│   5. layer.forward() → Qwen3Attention.forward() → Attention.forward()
│        ├── store_kvcache() stores the new token to the ring buffer
│        └── flash_attn_with_kvcache() computes attention
│   6. Copy new-token KV: ring buffer → decode buffer
│   7. Start the next layer's H2D load
│
│ Key insight: the flow now uses the standard Attention path, just with
│ the ring buffer as k_cache/v_cache in contiguous format
│ (block_tables=None)
└──────────────────────────────────────────────────────────────────────
```

## Files Modified

| File | Changes |
|------|---------|
| `model_runner.py:46-57` | Conditional CUDA graph capture (skipped for offload) |
| `model_runner.py:841-991` | Refactored `run_layerwise_offload_decode()` to use the standard `layer.forward()` |

## Next Steps

1. Implement the `capture_offload_cudagraph()` method
2. Modify `run_layerwise_offload_decode()` to optionally use the captured graphs
3. Benchmark the performance improvement from CUDA graphs
4. Consider the full-decode graph optimization for maximum performance