[claudesquad] update from 'lw-offload-2' on 08 Jan 26 20:53 CST

This commit is contained in:
Zijie Tian
2026-01-08 20:53:08 +08:00
parent 85bcca3d17
commit a8c9f0d837
9 changed files with 894 additions and 1704 deletions

View File

@@ -0,0 +1,409 @@
# Layer-wise Offload Memory Analysis
This document provides a detailed analysis of memory allocations in the layer-wise CPU offload system, distinguishing between pre-allocated (managed) memory and temporary (non-pre-allocated) memory.
## Variable Notation
| Symbol | Description | Example (Qwen3-4B) |
|--------|-------------|-------------------|
| `seq_len` | Input sequence length | 131072 (128k) |
| `hidden_size` | Model hidden dimension | 2560 |
| `num_heads` | Number of attention heads | 20 |
| `num_kv_heads` | Number of KV heads (GQA) | 8 |
| `head_dim` | Dimension per head | 128 |
| `intermediate_size` | MLP intermediate dimension | 13696 |
| `num_layers` | Number of transformer layers | 36 |
| `block_size` | KV cache block size | 1024 |
| `num_kv_buffers` | Ring buffer count | 4 |
| `num_cpu_blocks` | Number of CPU cache blocks | 128 |
| `vocab_size` | Vocabulary size | 151936 |
| `dtype_size` | Bytes per element (fp16/bf16) | 2 |
Derived values:
- `kv_dim = num_kv_heads × head_dim`
- `q_size = num_heads × head_dim`
- `kv_size = num_kv_heads × head_dim`
- `qkv_size = q_size + 2 × kv_size`
---
## 1. Pre-allocated Memory (Managed by nanovllm)
These tensors are allocated once during initialization and reused throughout inference.
### 1.1 OffloadEngine Managed Memory
| Tensor | Shape | Size Formula | Location |
|--------|-------|--------------|----------|
| `layer_k_cache` | `[num_kv_buffers, seq_len, num_kv_heads, head_dim]` | `num_kv_buffers × seq_len × kv_dim × dtype_size` | GPU |
| `layer_v_cache` | `[num_kv_buffers, seq_len, num_kv_heads, head_dim]` | `num_kv_buffers × seq_len × kv_dim × dtype_size` | GPU |
| `decode_k_buffer` | `[num_layers, block_size, num_kv_heads, head_dim]` | `num_layers × block_size × kv_dim × dtype_size` | GPU |
| `decode_v_buffer` | `[num_layers, block_size, num_kv_heads, head_dim]` | `num_layers × block_size × kv_dim × dtype_size` | GPU |
| `k_cache_cpu` | `[num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim]` | `num_layers × num_cpu_blocks × block_size × kv_dim × dtype_size` | CPU (pinned) |
| `v_cache_cpu` | `[num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim]` | `num_layers × num_cpu_blocks × block_size × kv_dim × dtype_size` | CPU (pinned) |
**Total GPU (OffloadEngine)**: `2 × (num_kv_buffers × seq_len + num_layers × block_size) × kv_dim × dtype_size`
**Total CPU (OffloadEngine)**: `2 × num_layers × num_cpu_blocks × block_size × kv_dim × dtype_size`
### 1.2 Model Weights
| Component | Approximate Size |
|-----------|-----------------|
| Embedding | `vocab_size × hidden_size × dtype_size` |
| Per-layer QKV proj | `hidden_size × qkv_size × dtype_size` |
| Per-layer O proj | `q_size × hidden_size × dtype_size` |
| Per-layer MLP | `hidden_size × 2 × intermediate_size × dtype_size + intermediate_size × hidden_size × dtype_size` |
| Per-layer LayerNorm | `2 × hidden_size × dtype_size` |
| LM Head | `hidden_size × vocab_size × dtype_size` |
### 1.3 RoPE Cache
| Tensor | Shape | Size |
|--------|-------|------|
| `cos_sin_cache` | `[max_position, 1, head_dim]` | `max_position × head_dim × 4` (float32) |
---
## 2. Non-Pre-allocated Memory: Prefill Phase
Location: `model_runner.py:run_layerwise_offload_prefill()`
### 2.1 Persistent Tensors (Live Throughout Prefill)
| Variable | Line | Shape | Size | Notes |
|----------|------|-------|------|-------|
| `input_ids` | 488 | `[seq_len]` | `seq_len × 8` | int64 |
| `positions` | 489 | `[seq_len]` | `seq_len × 8` | int64 |
| `cu_seqlens` | 493 | `[2]` | negligible | int32 |
| `hidden_states` | 497 | `[seq_len, hidden_size]` | `seq_len × hidden_size × dtype_size` | Embedding output |
| `residual` | 506 | `[seq_len, hidden_size]` | `seq_len × hidden_size × dtype_size` | Residual connection |
### 2.2 Per-Layer Temporary Tensors
These are allocated and deallocated within each layer iteration.
#### 2.2.1 LayerNorm
| Variable | Line | Shape | Size | Notes |
|----------|------|-------|------|-------|
| `hidden_ln` | 506-508 | `[seq_len, hidden_size]` | `seq_len × hidden_size × dtype_size` | Input layernorm output |
**Inside RMSNorm** (`layernorm.py:add_rms_forward`):
| Variable | Shape | Size | Notes |
|----------|-------|------|-------|
| `x.float()` | `[seq_len, hidden_size]` | `seq_len × hidden_size × 4` | Upcasted to float32 |
| `var` | `[seq_len, 1]` | `seq_len × 4` | Variance |
#### 2.2.2 QKV Projection
| Variable | Line | Shape | Size | Notes |
|----------|------|-------|------|-------|
| `qkv` | 512 | `[seq_len, q_size + 2 × kv_size]` | `seq_len × qkv_size × dtype_size` | Merged QKV output |
| `q` | 513-519 | `[seq_len, num_heads, head_dim]` | 0 (view) | View of qkv |
| `k` | 513-520 | `[seq_len, num_kv_heads, head_dim]` | 0 (view) | View of qkv |
| `v` | 513-521 | `[seq_len, num_kv_heads, head_dim]` | 0 (view) | View of qkv |
#### 2.2.3 Q/K Norms (Qwen3 specific)
| Variable | Line | Shape | Size | Notes |
|----------|------|-------|------|-------|
| `q.reshape()` | 526 | `[seq_len × num_heads, head_dim]` | 0 (view) | Reshape for norm |
| `k.reshape()` | 528 | `[seq_len × num_kv_heads, head_dim]` | 0 (view) | Reshape for norm |
| RMSNorm intermediates | - | see above | `seq_len × num_heads × head_dim × 4` | Float32 upcasting |
#### 2.2.4 RoPE (Rotary Position Embedding)
Location: `rotary_embedding.py:apply_rotary_emb()`
| Variable | Line | Shape | Size | Notes |
|----------|------|-------|------|-------|
| `cos_sin` | 44 | `[seq_len, 1, head_dim]` | 0 (view) | View of cached cos_sin |
| `cos` | 45 | `[seq_len, 1, head_dim/2]` | 0 (view) | Chunk view |
| `sin` | 45 | `[seq_len, 1, head_dim/2]` | 0 (view) | Chunk view |
**Inside `apply_rotary_emb` for Q** (`rotary_embedding.py:6-14`):
| Variable | Shape | Size | Notes |
|----------|-------|------|-------|
| `x.float()` | `[seq_len, num_heads, head_dim]` | `seq_len × num_heads × head_dim × 4` | Upcast to float32 |
| `x1` | `[seq_len, num_heads, head_dim/2]` | 0 (view) | Chunk view |
| `x2` | `[seq_len, num_heads, head_dim/2]` | 0 (view) | Chunk view |
| `y1 = x1*cos - x2*sin` | `[seq_len, num_heads, head_dim/2]` | `seq_len × num_heads × head_dim/2 × 4` | New tensor |
| `y2 = x2*cos + x1*sin` | `[seq_len, num_heads, head_dim/2]` | `seq_len × num_heads × head_dim/2 × 4` | New tensor |
| `torch.cat((y1, y2))` | `[seq_len, num_heads, head_dim]` | `seq_len × num_heads × head_dim × 4` | New tensor |
| `.to(x.dtype)` | `[seq_len, num_heads, head_dim]` | `seq_len × num_heads × head_dim × dtype_size` | Downcast |
**Inside `apply_rotary_emb` for K**:
| Variable | Shape | Size | Notes |
|----------|-------|------|-------|
| Same pattern as Q | `[seq_len, num_kv_heads, head_dim]` | Similar, with `num_kv_heads` | |
**Total RoPE temporary for Q+K**: ~`seq_len × (num_heads + num_kv_heads) × head_dim × 4 × 3` (float32 intermediates)
#### 2.2.5 FlashAttention
| Variable | Line | Shape | Size | Notes |
|----------|------|-------|------|-------|
| `attn_output` | 535 | `[seq_len, num_heads, head_dim]` | `seq_len × num_heads × head_dim × dtype_size` | Attention output |
| Internal workspace | - | O(seq_len) | Variable | FlashAttention internal |
#### 2.2.6 Output Projection
| Variable | Line | Shape | Size | Notes |
|----------|------|-------|------|-------|
| `attn_output.view()` | 546 | `[seq_len, q_size]` | 0 (view) | Reshape for o_proj |
| `o_proj(attn_output)` | 547 | `[seq_len, hidden_size]` | `seq_len × hidden_size × dtype_size` | O projection output |
#### 2.2.7 Post-Attention LayerNorm
Same as input layernorm (2.2.1).
#### 2.2.8 MLP
Location: `qwen3.py:Qwen3MLP.forward()`
| Variable | Line | Shape | Size | Notes |
|----------|------|-------|------|-------|
| `gate_up` | 117 | `[seq_len, 2 × intermediate_size]` | `seq_len × 2 × intermediate_size × dtype_size` | **LARGEST TEMPORARY!** |
| `x, y = chunk()` | activation.py:13 | `[seq_len, intermediate_size]` × 2 | 0 (views) | Chunk views |
| `F.silu(x)` | activation.py:14 | `[seq_len, intermediate_size]` | `seq_len × intermediate_size × dtype_size` | SiLU activation |
| `silu(x) * y` | activation.py:14 | `[seq_len, intermediate_size]` | `seq_len × intermediate_size × dtype_size` | Gated output |
| `down_proj()` | 119 | `[seq_len, hidden_size]` | `seq_len × hidden_size × dtype_size` | MLP output |
### 2.3 Prefill Memory Summary
**Peak per-layer temporary memory**:
```
= qkv + RoPE_temps + attn_output + o_proj + layernorm + MLP_gate_up + MLP_activation
≈ seq_len × (qkv_size + (num_heads + num_kv_heads) × head_dim × 4 × 3
+ num_heads × head_dim + hidden_size × 2 + 2 × intermediate_size + intermediate_size) × dtype_size
```
**Dominant term**: `seq_len × 2 × intermediate_size × dtype_size` (MLP gate_up)
---
## 3. Non-Pre-allocated Memory: Decode Phase
Location: `model_runner.py:run_layerwise_offload_decode()`
### 3.1 Persistent Tensors
| Variable | Line | Shape | Size | Notes |
|----------|------|-------|------|-------|
| `input_ids` | 604 | `[1]` | 8 bytes | Single token |
| `positions` | 605 | `[1]` | 8 bytes | Single position |
| `cu_seqlens_q` | 631 | `[2]` | 8 bytes | Fixed |
| `valid_tokens_per_block` | 613-622 | Python list | negligible | |
### 3.2 Per-Layer Temporary Tensors
#### 3.2.1 Views (Zero Additional Memory)
| Variable | Line | Shape | Notes |
|----------|------|-------|-------|
| `k_prefill` | 682 | `[prefill_len, num_kv_heads, head_dim]` | View of ring buffer |
| `v_prefill` | 682 | `[prefill_len, num_kv_heads, head_dim]` | View of ring buffer |
| `k_decode_prev` | 686-687 | `[num_decode_tokens-1, num_kv_heads, head_dim]` | View of decode buffer |
| `v_decode_prev` | 686-688 | `[num_decode_tokens-1, num_kv_heads, head_dim]` | View of decode buffer |
#### 3.2.2 New Allocations
| Variable | Line | Shape | Size | Notes |
|----------|------|-------|------|-------|
| `hidden_ln` | 654-657 | `[1, hidden_size]` | `hidden_size × dtype_size` | Tiny |
| `qkv` | 660 | `[1, qkv_size]` | `qkv_size × dtype_size` | Tiny |
| `q` | 667 | `[1, num_heads, head_dim]` | 0 (view) | |
| `k_new` | 668 | `[1, num_kv_heads, head_dim]` | 0 (view) | |
| `v_new` | 669 | `[1, num_kv_heads, head_dim]` | 0 (view) | |
| **`k_full`** | 689/692 | `[prefill_len + num_decode_tokens, num_kv_heads, head_dim]` | `(prefill_len + num_decode_tokens) × kv_dim × dtype_size` | **torch.cat - NEW ALLOCATION** |
| **`v_full`** | 690/693 | `[prefill_len + num_decode_tokens, num_kv_heads, head_dim]` | `(prefill_len + num_decode_tokens) × kv_dim × dtype_size` | **torch.cat - NEW ALLOCATION** |
| `cu_seqlens_k` | 710 | `[2]` | 8 bytes | Created per layer |
| `attn_output` | 712 | `[1, num_heads, head_dim]` | `num_heads × head_dim × dtype_size` | Tiny |
| MLP temps | 728 | `[1, ...]` | negligible | Single token |
### 3.3 Decode Memory Summary
**Peak per-layer temporary memory**:
```
= k_full + v_full + small_tensors
≈ 2 × (prefill_len + num_decode_tokens) × num_kv_heads × head_dim × dtype_size
≈ 2 × seq_len × kv_dim × dtype_size
```
**Dominant term**: `k_full` and `v_full` from `torch.cat()`
---
## 4. Memory Comparison Table
For Qwen3-4B with 128k context:
| Category | Memory | Notes |
|----------|--------|-------|
| **Pre-allocated GPU** | ~2.2 GB | Ring buffer + decode buffer |
| **Pre-allocated CPU** | ~18.4 GB | Pinned memory |
| **Model Weights** | ~8 GB | |
| **Prefill Peak Temp** | ~10-12 GB | MLP gate_up dominant |
| **Decode Peak Temp** | ~512 MB | k_full + v_full |
---
## 5. Optimization Opportunities
### 5.1 Decode: Pre-allocate k_full/v_full
**Current** (L689-693):
```python
k_full = torch.cat([k_prefill, k_decode_prev, k_new], dim=0) # New allocation each layer
v_full = torch.cat([v_prefill, v_decode_prev, v_new], dim=0) # New allocation each layer
```
**Optimized**:
```python
# Pre-allocate in OffloadEngine.__init__():
self.k_full_buffer = torch.zeros(max_seq_len + block_size, num_kv_heads, head_dim, ...)
self.v_full_buffer = torch.zeros(max_seq_len + block_size, num_kv_heads, head_dim, ...)
# In decode loop:
total_len = prefill_len + num_decode_tokens
k_full = self.k_full_buffer[:total_len]
k_full[:prefill_len].copy_(k_prefill)
k_full[prefill_len:prefill_len+num_decode_prev].copy_(k_decode_prev)
k_full[-1:].copy_(k_new)
```
**Savings**: ~512 MB per decode step (for 128k)
### 5.2 Decode: Reuse cu_seqlens_k
**Current** (L710):
```python
cu_seqlens_k = torch.tensor([0, total_kv_tokens], dtype=torch.int32, device="cuda")
```
**Optimized**:
```python
# Pre-allocate once:
self.cu_seqlens_k = torch.zeros(2, dtype=torch.int32, device="cuda")
# In decode loop:
self.cu_seqlens_k[1] = total_kv_tokens
```
**Savings**: Negligible memory, but reduces allocation overhead.
### 5.3 RoPE: In-place or Pre-allocated Buffers
The RoPE implementation creates multiple float32 intermediate tensors. Options:
1. Pre-allocate buffers for Q and K rotary outputs
2. Use in-place operations where possible
3. Use fused RoPE kernel (e.g., from FlashAttention)
**Potential savings**: ~1.5 GB during prefill per layer
### 5.4 MLP: Cannot Optimize Easily
The MLP `gate_up` tensor is inherently required for the gated activation:
```python
gate_up = gate_up_proj(x) # [seq_len, 2 × intermediate_size]
x, y = gate_up.chunk(2, -1)
output = silu(x) * y
```
This is a fundamental computation pattern. Potential optimizations:
- Chunked MLP computation (process seq_len in chunks)
- Fused kernels that avoid materializing full gate_up
---
## 6. Memory Flow Diagram
### Prefill (per layer):
```
hidden_states ──┬──► LayerNorm ──► hidden_ln
residual ◄──────┘
hidden_ln ──► QKV_proj ──► qkv ──┬──► q ──► Q_norm ──► RoPE ──► q_rotated
├──► k ──► K_norm ──► RoPE ──► k_rotated
└──► v
q_rotated, k_rotated, v ──► FlashAttention ──► attn_output
attn_output ──► O_proj ──► hidden_states'
hidden_states', residual ──► LayerNorm ──► hidden_ln', residual'
hidden_ln' ──► MLP_gate_up ──► gate_up ──► SiLU×gate ──► MLP_down ──► hidden_states''
k_rotated, v ──► CPU_offload (sync copy)
```
### Decode (per layer):
```
[CPU] k_cache_cpu, v_cache_cpu
▼ (H2D async to ring buffer)
[GPU] layer_k_cache[buffer_idx], layer_v_cache[buffer_idx]
▼ (view)
k_prefill, v_prefill
├──► torch.cat([k_prefill, k_decode_prev, k_new]) ──► k_full ⚠️ NEW ALLOC
└──► torch.cat([v_prefill, v_decode_prev, v_new]) ──► v_full ⚠️ NEW ALLOC
q_new, k_full, v_full ──► FlashAttention ──► attn_output
k_new, v_new ──► decode_k_buffer, decode_v_buffer (in-place store)
```
---
## 7. Appendix: Size Calculations
### Qwen3-4B Example (128k context)
```python
# Model config
seq_len = 131072
hidden_size = 2560
num_heads = 20
num_kv_heads = 8
head_dim = 128
intermediate_size = 13696
num_layers = 36
block_size = 1024
num_kv_buffers = 4
num_cpu_blocks = 128
dtype_size = 2 # fp16/bf16
# Derived
kv_dim = num_kv_heads * head_dim # 1024
q_size = num_heads * head_dim # 2560
qkv_size = q_size + 2 * kv_dim # 4608
# Pre-allocated GPU (OffloadEngine)
ring_buffer = 2 * num_kv_buffers * seq_len * kv_dim * dtype_size
# = 2 * 4 * 131072 * 1024 * 2 = 2,147,483,648 bytes = 2048 MB
decode_buffer = 2 * num_layers * block_size * kv_dim * dtype_size
# = 2 * 36 * 1024 * 1024 * 2 = 150,994,944 bytes = 144 MB
# Pre-allocated CPU
cpu_cache = 2 * num_layers * num_cpu_blocks * block_size * kv_dim * dtype_size
# = 2 * 36 * 128 * 1024 * 1024 * 2 = 19,327,352,832 bytes = 18432 MB
# Prefill temporaries (per layer peak)
mlp_gate_up = seq_len * 2 * intermediate_size * dtype_size
# = 131072 * 2 * 13696 * 2 = 7,180,648,448 bytes = 6848 MB
# Decode temporaries (per layer)
k_full = seq_len * kv_dim * dtype_size
# = 131072 * 1024 * 2 = 268,435,456 bytes = 256 MB
v_full = k_full # = 256 MB
# Total: 512 MB
```

View File

@@ -32,6 +32,7 @@ class Config:
offload_policy: str = "lru" # "lru", "fifo", or full class path
num_transfer_streams: int = 4 # Number of CUDA streams for async transfers
num_gpu_blocks: int = -1 # User-specified GPU blocks count, -1 = auto (use max available)
num_kv_buffers: int = 4 # Ring buffer size for layer-wise offload (decode H2D pipeline)
# Computed fields for offload (set in __post_init__ or by ModelRunner)
num_gpu_kvcache_blocks: int = -1

View File

@@ -400,10 +400,8 @@ class ModelRunner:
@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
context = get_context()
# Use eager mode for: prefill, enforce_eager, large batch, or chunked attention
# Chunked attention requires dynamic KV loading that can't be captured in CUDA Graph
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
# Use eager mode for: prefill, enforce_eager, large batch
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512
if use_eager:
return self.model.compute_logits(self.model(input_ids, positions))
else:
@@ -462,13 +460,13 @@ class ModelRunner:
@torch.inference_mode()
def run_layerwise_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with layer-wise processing and CPU offload.
Run prefill with layer-wise processing and async CPU offload.
Key design:
- Process one layer at a time (not one chunk at a time)
- Each layer: full forward pass → offload KV to CPU
- Full KV stays on GPU during each layer's computation
- After layer completes, KV is offloaded to CPU
- Each layer: compute → async offload KV to CPU
- Offload of layer N overlaps with compute of layer N+1
- Uses OffloadEngine's async API with stream events
This enables future sparse attention methods (like MInference)
that need full KV context per layer for pattern estimation.
@@ -477,6 +475,7 @@ class ModelRunner:
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
compute_stream = offload_engine.compute_stream
num_layers = len(self.model.model.layers)
total_tokens = len(seq)
@@ -489,81 +488,91 @@ class ModelRunner:
input_ids = torch.tensor(seq[:], dtype=torch.int64, device="cuda")
positions = torch.arange(total_tokens, dtype=torch.int64, device="cuda")
# Step 1: Embedding
hidden_states = self.model.model.embed_tokens(input_ids)
residual = None
# Import FlashAttention once
from flash_attn.flash_attn_interface import flash_attn_varlen_func
cu_seqlens = torch.tensor([0, total_tokens], dtype=torch.int32, device="cuda")
# Step 2: Layer-by-layer processing
for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id]
# Step 1: Embedding (on compute stream)
with torch.cuda.stream(compute_stream):
hidden_states = self.model.model.embed_tokens(input_ids)
residual = None
# 2a. Input LayerNorm
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
# Step 2: Layer-by-layer processing
for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id]
# 2b. Self-attention (full sequence)
# QKV projection
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k, v = qkv.split([
layer.self_attn.q_size,
layer.self_attn.kv_size,
layer.self_attn.kv_size
], dim=-1)
# 2a. Input LayerNorm
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
q = q.view(total_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim)
k = k.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v = v.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# 2b. Self-attention (full sequence)
# QKV projection
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k, v = qkv.split([
layer.self_attn.q_size,
layer.self_attn.kv_size,
layer.self_attn.kv_size
], dim=-1)
# Q/K norms (Qwen3 specific)
if not layer.self_attn.qkv_bias:
num_tokens = q.shape[0]
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(num_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim)
k = layer.self_attn.k_norm(k.reshape(-1, layer.self_attn.head_dim))
k = k.view(num_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
q = q.view(total_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim)
k = k.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v = v.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k)
# Q/K norms (Qwen3 specific)
if not layer.self_attn.qkv_bias:
num_tokens = q.shape[0]
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(num_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim)
k = layer.self_attn.k_norm(k.reshape(-1, layer.self_attn.head_dim))
k = k.view(num_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Full attention using FlashAttention
from flash_attn.flash_attn_interface import flash_attn_varlen_func
cu_seqlens = torch.tensor([0, total_tokens], dtype=torch.int32, device="cuda")
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
# RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k)
# O projection
attn_output = attn_output.view(total_tokens, -1)
hidden_states = layer.self_attn.o_proj(attn_output)
# Full attention using FlashAttention
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
# 2c. Post-attention LayerNorm + MLP
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# O projection
attn_output = attn_output.view(total_tokens, -1)
hidden_states = layer.self_attn.o_proj(attn_output)
# 2d. Offload KV to CPU (synchronous for correctness)
# Use synchronous copy to ensure data is fully copied before moving to next layer
self._offload_layer_kv_to_cpu_sync(layer_id, k, v, cpu_block_ids, total_tokens)
# 2c. Post-attention LayerNorm + MLP
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# 2d. Offload KV to CPU (synchronous to avoid race condition)
# NOTE: Async offload has race condition where k,v memory gets reused
# before D2H copy completes. Use sync copy for correctness.
block_size = offload_engine.block_size
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * block_size
end = min(start + block_size, total_tokens)
actual_size = end - start
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
# Step 3: Final norm
hidden_states, _ = self.model.model.norm(hidden_states, residual)
# Step 4: Compute logits for last token
logits = self.model.compute_logits(hidden_states[-1:])
# Note: Using sync offload, no wait needed
# Mark all blocks as prefilled
for logical_id in logical_ids:
self.kvcache_manager.prefilled_blocks.add(logical_id)
# Sync offload completes within loop, no explicit wait needed
# Step 3: Final norm
hidden_states, _ = self.model.model.norm(hidden_states, residual)
# Step 4: Compute logits for last token
logits = self.model.compute_logits(hidden_states[-1:])
# Step 5: Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
@@ -572,236 +581,164 @@ class ModelRunner:
return token_ids
def _offload_layer_kv_to_cpu(
self,
layer_id: int,
k: torch.Tensor,
v: torch.Tensor,
cpu_block_ids: list[int],
total_tokens: int,
):
"""
Offload a layer's KV cache to CPU in blocks (async version).
Args:
layer_id: Layer index
k: Key tensor [seq_len, kv_heads, head_dim]
v: Value tensor [seq_len, kv_heads, head_dim]
cpu_block_ids: List of CPU block IDs to offload to
total_tokens: Total number of tokens
"""
offload_engine = self.kvcache_manager.offload_engine
block_size = offload_engine.block_size
stream = offload_engine.prefill_offload_streams[layer_id]
with torch.cuda.stream(stream):
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * block_size
end = min(start + block_size, total_tokens)
actual_size = end - start
# Copy K and V to CPU cache
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(
k[start:end], non_blocking=True
)
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(
v[start:end], non_blocking=True
)
# Record completion event
offload_engine.prefill_offload_events[layer_id].record(stream)
def _offload_layer_kv_to_cpu_sync(
self,
layer_id: int,
k: torch.Tensor,
v: torch.Tensor,
cpu_block_ids: list[int],
total_tokens: int,
):
"""
Offload a layer's KV cache to CPU in blocks (synchronous version).
This version uses synchronous copy to ensure correctness.
It's slower than async but guarantees data integrity.
"""
offload_engine = self.kvcache_manager.offload_engine
block_size = offload_engine.block_size
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * block_size
end = min(start + block_size, total_tokens)
actual_size = end - start
# Synchronous copy to CPU
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
@torch.inference_mode()
def run_layerwise_offload_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with layer-wise KV loading from CPU.
Run decode with ring-buffered layer-wise KV loading from CPU.
Key design:
- For each layer: load all prefilled KV from CPU
- Compute attention with loaded KV + new token's KV
- Store new token's KV for offload when block is full
- Ring buffer pipeline: load layer N+k while computing layer N
- Per-layer decode buffer for accumulating new tokens
- Async block offload when decode buffer is full
- Uses OffloadEngine's ring buffer API for H2D pipeline
"""
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
compute_stream = offload_engine.compute_stream
num_layers = len(self.model.model.layers)
num_buffers = offload_engine.num_kv_buffers
# Prepare inputs
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, device="cuda")
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, device="cuda")
# Get prefilled CPU blocks
# Get prefilled CPU blocks and compute valid tokens per block
cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
num_prefill_blocks = len(cpu_block_table)
total_prefill_tokens = self.kvcache_manager.get_prefill_len(seq)
# Calculate valid tokens in last prefill block
last_block_valid_tokens = total_prefill_tokens % self.block_size
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
last_block_valid_tokens = self.block_size
# Calculate valid tokens per block
valid_tokens_per_block = []
for block_idx in range(num_prefill_blocks):
if block_idx == num_prefill_blocks - 1:
# Last block may be partial
last_block_tokens = total_prefill_tokens % self.block_size
if last_block_tokens == 0 and total_prefill_tokens > 0:
last_block_tokens = self.block_size
valid_tokens_per_block.append(last_block_tokens)
else:
valid_tokens_per_block.append(self.block_size)
# Current decode position info
pos_in_block = (len(seq) - 1) % self.block_size
decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
num_decode_tokens = pos_in_block - decode_start_pos + 1
# Step 1: Embedding
hidden_states = self.model.model.embed_tokens(input_ids)
residual = None
# Import FlashAttention once
from flash_attn.flash_attn_interface import flash_attn_varlen_func
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
# Allocate buffers for new decode token's KV (per layer)
# These will be accumulated and offloaded when block is full
decode_k_cache = []
decode_v_cache = []
# Step 2: Layer-by-layer processing
for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id]
# 2a. Input LayerNorm
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
# 2b. QKV projection for new token
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k_new, v_new = qkv.split([
layer.self_attn.q_size,
layer.self_attn.kv_size,
layer.self_attn.kv_size
], dim=-1)
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v_new = v_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Q/K norms
if not layer.self_attn.qkv_bias:
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = layer.self_attn.k_norm(k_new.reshape(-1, layer.self_attn.head_dim))
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# RoPE
q, k_new = layer.self_attn.rotary_emb(positions, q, k_new)
# Store new KV for later offload
decode_k_cache.append(k_new.clone())
decode_v_cache.append(v_new.clone())
# 2c. Load prefilled KV from CPU
k_prefill_list = []
v_prefill_list = []
for block_idx, cpu_block_id in enumerate(cpu_block_table):
# Determine valid tokens in this block
if block_idx == num_prefill_blocks - 1:
valid_tokens = last_block_valid_tokens
else:
valid_tokens = self.block_size
k_block = offload_engine.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens].to("cuda", non_blocking=True)
v_block = offload_engine.v_cache_cpu[layer_id, cpu_block_id, :valid_tokens].to("cuda", non_blocking=True)
k_prefill_list.append(k_block)
v_prefill_list.append(v_block)
# Concatenate prefilled KV
if k_prefill_list:
k_prefill = torch.cat(k_prefill_list, dim=0) # [prefill_tokens, kv_heads, head_dim]
v_prefill = torch.cat(v_prefill_list, dim=0)
else:
k_prefill = torch.empty(0, layer.self_attn.num_kv_heads, layer.self_attn.head_dim, device="cuda")
v_prefill = torch.empty(0, layer.self_attn.num_kv_heads, layer.self_attn.head_dim, device="cuda")
# 2d. Get accumulated decode KV from decode buffer (if any previous decode tokens)
if num_decode_tokens > 1:
# Load previous decode tokens for this layer from decode buffer
k_decode_prev = offload_engine.decode_k_buffer[layer_id, decode_start_pos:pos_in_block]
v_decode_prev = offload_engine.decode_v_buffer[layer_id, decode_start_pos:pos_in_block]
k_full = torch.cat([k_prefill, k_decode_prev, k_new], dim=0)
v_full = torch.cat([v_prefill, v_decode_prev, v_new], dim=0)
else:
k_full = torch.cat([k_prefill, k_new], dim=0)
v_full = torch.cat([v_prefill, v_new], dim=0)
# Store new KV to decode buffer for future decode steps
offload_engine.decode_k_buffer[layer_id, pos_in_block].copy_(k_new.squeeze(0))
offload_engine.decode_v_buffer[layer_id, pos_in_block].copy_(v_new.squeeze(0))
# 2e. Compute attention
# For decode: query is at the last position, should attend to ALL previous keys
# Use causal=False because the single query token is conceptually at position N
# and should attend to all K tokens at positions 0 to N-1
from flash_attn.flash_attn_interface import flash_attn_varlen_func
total_kv_tokens = k_full.shape[0]
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, total_kv_tokens], dtype=torch.int32, device="cuda")
attn_output = flash_attn_varlen_func(
q, k_full, v_full,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=total_kv_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=False,
# Phase 1: Preload first N layers to ring buffer (fill pipeline)
num_preload = min(num_buffers, num_layers)
for i in range(num_preload):
offload_engine.load_layer_kv_to_buffer(
i, i, cpu_block_table, valid_tokens_per_block
)
# O projection
attn_output = attn_output.view(1, -1)
hidden_states = layer.self_attn.o_proj(attn_output)
# Step 1: Embedding (on compute stream)
with torch.cuda.stream(compute_stream):
hidden_states = self.model.model.embed_tokens(input_ids)
residual = None
# 2f. Post-attention LayerNorm + MLP
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# Phase 2: Layer-by-layer processing with ring buffer pipeline
for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id]
current_buffer = layer_id % num_buffers
# Step 3: Final norm
hidden_states, _ = self.model.model.norm(hidden_states, residual)
# 2a. Wait for current buffer's load to complete
offload_engine.wait_buffer_load(current_buffer)
# Step 4: Compute logits
logits = self.model.compute_logits(hidden_states)
# 2c. Input LayerNorm
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
# Step 5: Handle block-full offload
# 2d. QKV projection for new token
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k_new, v_new = qkv.split([
layer.self_attn.q_size,
layer.self_attn.kv_size,
layer.self_attn.kv_size
], dim=-1)
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v_new = v_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Q/K norms
if not layer.self_attn.qkv_bias:
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = layer.self_attn.k_norm(k_new.reshape(-1, layer.self_attn.head_dim))
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# RoPE
q, k_new = layer.self_attn.rotary_emb(positions, q, k_new)
# 2e. Get prefilled KV from ring buffer
k_prefill, v_prefill = offload_engine.get_buffer_kv(current_buffer, total_prefill_tokens)
# 2f. Get accumulated decode KV from decode buffer (if any previous decode tokens)
if num_decode_tokens > 1:
k_decode_prev, v_decode_prev = offload_engine.get_decode_kv(
layer_id, decode_start_pos, pos_in_block
)
k_full = torch.cat([k_prefill, k_decode_prev, k_new], dim=0)
v_full = torch.cat([v_prefill, v_decode_prev, v_new], dim=0)
else:
k_full = torch.cat([k_prefill, k_new], dim=0)
v_full = torch.cat([v_prefill, v_new], dim=0)
# 2g. Store new KV to decode buffer for future decode steps
offload_engine.store_decode_kv(layer_id, pos_in_block, k_new, v_new)
# 2h. Mark buffer compute done (allows next load to reuse this buffer)
offload_engine.record_buffer_compute_done(current_buffer)
# 2i. Start loading next layer to same buffer (after compute done)
next_layer_to_load = layer_id + num_buffers
if next_layer_to_load < num_layers:
offload_engine.load_layer_kv_to_buffer(
current_buffer, next_layer_to_load, cpu_block_table, valid_tokens_per_block
)
# 2j. Compute attention
total_kv_tokens = k_full.shape[0]
cu_seqlens_k = torch.tensor([0, total_kv_tokens], dtype=torch.int32, device="cuda")
attn_output = flash_attn_varlen_func(
q, k_full, v_full,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=total_kv_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=False,
)
# O projection
attn_output = attn_output.view(1, -1)
hidden_states = layer.self_attn.o_proj(attn_output)
# 2k. Post-attention LayerNorm + MLP
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# Step 3: Final norm
hidden_states, _ = self.model.model.norm(hidden_states, residual)
# Step 4: Compute logits
logits = self.model.compute_logits(hidden_states)
# Step 5: Handle block-full offload (async)
if pos_in_block == self.block_size - 1:
# Block is full, offload decode buffer to CPU
last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
if last_cpu_block >= 0:
for layer_id in range(num_layers):
offload_engine.k_cache_cpu[layer_id, last_cpu_block].copy_(
offload_engine.decode_k_buffer[layer_id], non_blocking=True
)
offload_engine.v_cache_cpu[layer_id, last_cpu_block].copy_(
offload_engine.decode_v_buffer[layer_id], non_blocking=True
)
torch.cuda.synchronize()
# Async offload decode buffer to CPU
offload_engine.offload_decode_buffer_async(last_cpu_block)
# Mark as prefilled for future decode steps
logical_id = seq.block_table[-1]

View File

@@ -76,6 +76,8 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
block_size=config.kvcache_block_size,
policy=eviction_policy,
sparse_policy=sparse_policy,
num_kv_buffers=getattr(config, 'num_kv_buffers', 4),
max_seq_len=config.max_model_len,
)

View File

@@ -65,23 +65,22 @@ class LogicalBlock:
class HybridKVCacheManager(KVCacheManager):
"""
Hybrid CPU-GPU KV cache manager with ring buffer design.
Hybrid CPU-GPU KV cache manager with layer-wise offload design.
Architecture (CPU-primary mode):
- CPU pool: Primary storage for all KV cache (num_cpu_blocks)
- GPU buffer: Ring buffer for computation only (num_gpu_slots)
- Logical blocks: What sequences reference (num_cpu_blocks)
- GPU ring buffer: For decode H2D pipeline (num_kv_buffers)
- Decode buffer: Per-layer accumulation of decode tokens (block_size)
Design:
- All KV cache is stored on CPU as primary storage
- GPU is used as a ring buffer for computation only (no persistent data)
- During prefill: KV is written to GPU ring slot, then offloaded to CPU
- During decode: Previous KV is loaded from CPU to GPU for attention
- Ring buffer enables pipelined H2D transfers overlapped with computation
- GPU ring buffer enables pipelined H2D transfers during decode
- During prefill: KV is computed and offloaded layer-by-layer to CPU
- During decode: Previous KV is loaded from CPU via ring buffer pipeline
Note:
- Logical blocks map 1:1 with CPU blocks (total_blocks = num_cpu_blocks)
- GPU slots are transient compute buffers, not tracked in logical blocks
- GPU ring buffer is for decode pipeline, not persistent storage
"""
def __init__(
@@ -91,25 +90,31 @@ class HybridKVCacheManager(KVCacheManager):
block_size: int,
policy: Optional[EvictionPolicy] = None,
sparse_policy: "SparsePolicy" = None,
num_kv_buffers: int = 4,
max_seq_len: int = 131072,
):
"""
Initialize hybrid manager with CPU-primary ring buffer design.
Initialize hybrid manager with layer-wise offload design.
All KV cache is stored on CPU as primary storage. GPU slots are used
as a ring buffer for computation only.
All KV cache is stored on CPU as primary storage. GPU ring buffer is used
for decode H2D pipeline.
Args:
num_gpu_slots: Number of GPU buffer slots (ring buffer for computation)
num_gpu_slots: Number of GPU buffer slots (kept for backward compat, not used)
num_cpu_blocks: Number of CPU pool blocks (primary storage)
block_size: Tokens per block
policy: Eviction policy (default: LRU, used for prefix cache management)
sparse_policy: Sparse attention policy (Quest for decode-only sparse)
num_kv_buffers: Ring buffer size for decode H2D pipeline
max_seq_len: Maximum sequence length for GPU buffer allocation
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
self.num_cpu_blocks = num_cpu_blocks
self.num_kv_buffers = num_kv_buffers
self.max_seq_len = max_seq_len
# In CPU-primary mode, logical blocks map 1:1 with CPU blocks
# GPU slots are transient compute buffers, not tracked as logical blocks
# GPU ring buffer is for decode pipeline, not persistent storage
self.total_blocks = num_cpu_blocks
# Eviction policy
@@ -147,7 +152,7 @@ class HybridKVCacheManager(KVCacheManager):
# Track blocks pending GPU load (for decode graph)
self.pending_gpu_loads: Set[int] = set() # logical_ids
# Track blocks that have been prefilled (KV written) for chunked prefill
# Track blocks that have been prefilled (KV offloaded to CPU)
self.prefilled_blocks: Set[int] = set() # logical_ids
# Track decode starting position within block (for batched offload optimization)
@@ -182,13 +187,21 @@ class HybridKVCacheManager(KVCacheManager):
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
num_kv_buffers=self.num_kv_buffers,
max_seq_len=self.max_seq_len,
sparse_policy=self.sparse_policy,
)
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
"""Get GPU K/V cache tensors for a layer."""
"""
Get GPU K/V cache tensors for a layer.
Note: In layer-wise offload mode, this returns empty tensors as KV
is managed directly by the offload engine's ring buffer.
"""
assert self.offload_engine is not None
return self.offload_engine.get_layer_cache(layer_id)
# Return empty tensors - actual KV is in offload_engine's ring buffer
return torch.empty(0), torch.empty(0)
def can_allocate(self, seq: Sequence) -> bool:
"""Check if we can allocate blocks for a new sequence."""
@@ -279,8 +292,8 @@ class HybridKVCacheManager(KVCacheManager):
"""
Prepare KV cache for attention computation.
In ring buffer mode, this is a no-op because chunked offload
paths handle H2D transfers directly in the attention layer.
In layer-wise offload mode, this is a no-op because KV transfers
are handled directly in model_runner's layer-by-layer methods.
"""
pass
@@ -291,12 +304,12 @@ class HybridKVCacheManager(KVCacheManager):
"""
Get GPU slot tables for sequences.
In ring buffer mode, all blocks are on CPU, so this raises an error
if called. Use run_chunked_offload_* methods instead.
In layer-wise offload mode, all blocks are on CPU, so this raises an error
if called. Use run_layerwise_offload_* methods instead.
"""
raise RuntimeError(
"get_gpu_block_tables should not be called in ring buffer mode. "
"Use run_chunked_offload_prefill/decode instead."
"get_gpu_block_tables should not be called in layer-wise offload mode. "
"Use run_layerwise_offload_prefill/decode instead."
)
def post_attention_cleanup(
@@ -307,18 +320,18 @@ class HybridKVCacheManager(KVCacheManager):
"""
Cleanup after attention.
In ring buffer mode, this is a no-op because offload is handled
directly in the chunked prefill/decode paths.
In layer-wise offload mode, this is a no-op because offload is handled
directly in model_runner's layer-by-layer methods.
"""
pass
# ========== Ring Buffer CPU-primary Chunked Prefill Support ==========
# ========== Layer-wise Offload Support ==========
def get_prefilled_cpu_blocks(self, seq: Sequence) -> List[int]:
"""
Get list of CPU block IDs for blocks that have been prefilled.
Used for loading previous KV during chunked prefill.
Used for loading prefilled KV during decode.
Returns:
List of CPU block IDs in sequence order
@@ -335,11 +348,11 @@ class HybridKVCacheManager(KVCacheManager):
# )
return cpu_blocks
# ========== Ring Buffer CPU-primary support ==========
# ========== CPU Block Allocation ==========
def allocate_cpu_only(self, seq: Sequence) -> None:
"""
Allocate CPU blocks for sequence (for ring buffer mode).
Allocate CPU blocks for sequence (for layer-wise offload mode).
Unlike allocate(), here all blocks are allocated to CPU,
GPU is only used as ring buffer for computation.
@@ -468,20 +481,6 @@ class HybridKVCacheManager(KVCacheManager):
return block.cpu_block_id
return -1
def get_write_slot_for_chunked_offload(self, seq: Sequence) -> int:
"""
Get GPU slot for writing new KV during chunked offload decode.
In ring buffer design, always use decode_slot (slot[0]) to write new KV.
This avoids conflicts with loading operations which use slots[1:].
Args:
seq: Sequence
Returns:
GPU slot ID (always decode_slot = 0)
"""
return self.offload_engine.decode_slot
def get_decode_start_pos(self, seq: Sequence) -> int:
"""

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,8 @@
import logging
import torch
import torch.cuda.nvtx
from torch import nn
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context
from nanovllm.kvcache.sparse.policy import PolicyContext
logger = logging.getLogger(__name__)
def store_kvcache(
@@ -60,12 +55,17 @@ def store_kvcache(
valid_values_flat = valid_values.reshape(-1, D)
# In-place scatter using index_copy_
# 即使 valid_slots 为空张量index_copy_ 也是安全的(不会修改数据)。
k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)
class Attention(nn.Module):
"""
Attention layer for GPU-only mode.
For CPU offload mode, attention is computed directly in model_runner's
run_layerwise_offload_prefill/decode methods using FlashAttention.
"""
def __init__(
self,
@@ -87,54 +87,12 @@ class Attention(nn.Module):
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
# Determine if we're in chunked offload mode
is_chunked_offload = (
context.is_chunked_prefill and
hasattr(context, 'kvcache_manager') and
context.kvcache_manager is not None and
hasattr(context.kvcache_manager, 'offload_engine')
)
#! Ensure synchronization before accessing k_cache/v_cache
# torch.cuda.synchronize()
#! =======================================================
if is_chunked_offload and context.is_prefill:
# Chunked prefill mode: write KV to per-layer prefill buffer (not GPU slot)
# This enables fully async offloads since each layer has its own buffer.
offload_engine = context.kvcache_manager.offload_engine
compute_stream = offload_engine.compute_stream
# Wait for default stream to ensure slot_mapping tensor transfer is complete
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
# Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
# k, v shape: [num_tokens, kv_heads, head_dim]
num_tokens = k.shape[0]
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
elif is_chunked_offload:
# Chunked decode mode: use compute_stream for store_kvcache
# This ensures proper synchronization with per-layer offload
compute_stream = context.kvcache_manager.offload_engine.compute_stream
if k_cache.numel() and v_cache.numel():
# CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
# slot_mapping is created with non_blocking=True on default stream, but we use it
# on compute_stream. Without this sync, index_copy_ can get corrupted indices.
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
else:
# Normal mode: store on default stream
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
# Store KV to cache (for GPU-only mode)
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
if context.is_chunked_prefill:
# Chunked prefill: merge attention from previous KV
o = self._chunked_prefill_attention(q, k, v, context)
elif context.block_tables is not None: # prefix cache
if context.block_tables is not None: # prefix cache
k, v = k_cache, v_cache
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
@@ -151,576 +109,7 @@ class Attention(nn.Module):
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else: # decode
if context.is_chunked_prefill:
# Chunked decode: need to load all KV from CPU+GPU
# Store current decode token to per-layer decode buffer
# This is needed because GPU cache has no layer dimension,
# so all layers would overwrite each other in decode_slot.
kvcache_manager = context.kvcache_manager
offload_engine = kvcache_manager.offload_engine
pos_in_block = context.decode_pos_in_block
# k, v shape: [1, kv_heads, head_dim]
offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
o = self._chunked_decode_attention(q, k, v, context)
else:
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
return o
def _chunked_prefill_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
context,
) -> torch.Tensor:
"""
Compute attention with per-layer prefill buffer for async offload.
Optimized design:
- Current chunk's KV is written to per-layer prefill buffer (not GPU slot)
- Previous chunks' KV are loaded from CPU using GPU slots
- Each layer offloads from its own buffer - no waiting required!
For each layer:
1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model)
2. Load previous chunks from CPU using available slots (pipeline)
3. Compute attention against previous KV (no causal mask)
4. Compute attention against current KV from prefill buffer (causal)
5. Merge all results using online softmax
6. Async offload prefill buffer to CPU (no waiting!)
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
current_chunk_idx = context.current_chunk_idx
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
# q shape: [total_tokens, num_heads, head_dim]
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
num_tokens = k.shape[0]
o_acc = None
lse_acc = None
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
# Get prefilled CPU blocks (blocks from previous chunks)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Apply sparse policy if enabled (Quest returns all blocks for prefill since query=None)
sparse_policy = kvcache_manager.sparse_policy
if cpu_block_table and sparse_policy is not None:
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=self.layer_id,
query=None, # Prefill typically doesn't use query for selection
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = sparse_policy.select_blocks(
cpu_block_table, policy_ctx
)
if cpu_block_table:
# Get available load slots (all slots can be used since we use prefill buffer)
load_slots = list(range(offload_engine.num_ring_slots))
pipeline_depth = len(load_slots)
if pipeline_depth == 0:
# Only 1 slot total, cannot pipeline - use sync loading
o_acc, lse_acc = self._sync_load_previous_chunks(
q_batched, cpu_block_table, offload_engine
)
else:
# Use ring buffer pipeline
o_acc, lse_acc = self._ring_buffer_pipeline_load(
q_batched, cpu_block_table, load_slots, offload_engine,
current_chunk_idx
)
# Get compute stream for all attention operations
compute_stream = offload_engine.compute_stream if offload_engine is not None else None
# Compute attention against current chunk's KV from prefill buffer (with causal mask)
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
# Get KV from per-layer prefill buffer
k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
else:
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
k_batched = k.unsqueeze(0)
v_batched = v.unsqueeze(0)
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
# Merge with accumulated (all on compute_stream for consistency)
if o_acc is None:
final_o = current_o
else:
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
else:
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop() # ChunkedPrefill
# Per-layer ASYNC offload: offload prefill buffer to CPU
# No waiting required! Each layer has its own buffer and stream.
if offload_engine is not None and seq is not None:
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
if current_chunk_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[current_chunk_idx]
# Async offload - no waiting, fully parallel across layers
offload_engine.offload_prefill_buffer_async(
self.layer_id, cpu_block_id, num_tokens
)
# Sync default stream with compute_stream before returning
# This ensures the result is ready for the rest of the model (layernorm, MLP)
if compute_stream is not None:
torch.cuda.default_stream().wait_stream(compute_stream)
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
def _sync_load_previous_chunks(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
offload_engine,
):
"""Synchronous loading fallback when pipeline_depth=0."""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
o_acc, lse_acc = None, None
compute_stream = offload_engine.compute_stream
for block_idx, cpu_block_id in enumerate(cpu_block_table):
# Load to slot 0 (single slot)
offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
offload_engine.wait_slot_layer(0)
# IMPORTANT: Must use compute_stream to match wait_slot_layer
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(0)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
def _ring_buffer_pipeline_load(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
load_slots: list,
offload_engine,
current_chunk_idx: int = -1,
):
"""
Ring buffer async pipeline loading with double buffering.
Uses compute_done events to ensure safe buffer reuse:
- Before loading to slot X, wait for previous compute on slot X to finish
- Before computing on slot X, wait for load to slot X to finish
Timeline with 2 slots (A, B):
┌──────────────┐
│ Load B0→A │
└──────────────┘
┌──────────────┐ ┌──────────────┐
│ Load B1→B │ │ Load B2→A │ ...
└──────────────┘ └──────────────┘
↘ ↘
┌──────────────┐ ┌──────────────┐
│ Compute(A) │ │ Compute(B) │ ...
└──────────────┘ └──────────────┘
The load_to_slot_layer internally waits for compute_done[slot] before
starting the transfer, ensuring no data race.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
num_blocks = len(cpu_block_table)
if num_blocks == 0:
return None, None
pipeline_depth = len(load_slots)
if pipeline_depth == 0:
return None, None
o_acc, lse_acc = None, None
if pipeline_depth == 1:
# Only 1 slot available, cannot pipeline - use synchronous mode
# IMPORTANT: Must use compute_stream to match synchronization in
# load_to_slot_layer (waits for compute_done) and wait_slot_layer
slot = load_slots[0]
compute_stream = offload_engine.compute_stream
for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream):
# Debug: call hooks on compute_stream (synchronized with transfer)
if offload_engine.debug_mode:
offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
# Record compute done so next load can safely reuse this slot
offload_engine.record_slot_compute_done(slot)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
# N-way pipeline: use ALL available slots for maximum overlap
# Pipeline depth = num_slots - 1 (num_slots blocks in flight)
num_slots = len(load_slots)
# Phase 1: Pre-load up to num_slots blocks to fill the pipeline
# This starts all transfers in parallel, utilizing full PCIe bandwidth
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
# Phase 2: Main loop - compute and immediately reuse slot for next transfer
# Use dedicated compute_stream (not default stream) to enable overlap with transfers
compute_stream = offload_engine.compute_stream
for block_idx in range(num_blocks):
torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}")
# Cycle through slots: slot[block_idx % num_slots]
current_slot = load_slots[block_idx % num_slots]
cpu_block_id = cpu_block_table[block_idx]
# Wait for current slot's transfer to complete (on compute_stream)
offload_engine.wait_slot_layer(current_slot)
# Compute attention on current slot's data
# IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
with torch.cuda.stream(compute_stream):
# Debug: call hooks on compute_stream (synchronized with transfer)
if offload_engine.debug_mode:
offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
torch.cuda.nvtx.range_pop()
# Record compute done - this allows the next transfer to safely overwrite this slot
offload_engine.record_slot_compute_done(current_slot)
# Immediately start loading the NEXT block into this slot (if more blocks remain)
# Key insight: reuse current_slot immediately after compute is done!
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
# Merge with accumulated (also on compute_stream for consistency)
with torch.cuda.stream(compute_stream):
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
torch.cuda.nvtx.range_pop() # PipelineBlock
return o_acc, lse_acc
def _chunked_decode_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
context,
) -> torch.Tensor:
"""
Compute decode attention using cross-layer pipeline.
Optimization: Uses double-buffered layer cache to overlap H2D transfer
with computation across layers:
- Layer N computes while Layer N+1's data is being loaded
- Each layer only waits for its own data, not all layers' data
This reduces effective latency from O(num_layers * transfer_time) to
O(transfer_time + num_layers * compute_time) when transfer < compute.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq
# Get only PREFILLED CPU blocks (exclude the current decode block)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
if self.layer_id == 0:
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Calculate valid tokens in the last CPU block
# CRITICAL: Use original prefill length, not current seq length!
# CPU blocks are fixed after prefill, their content doesn't change during decode.
block_size = kvcache_manager.block_size
num_prefill_blocks = len(cpu_block_table)
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
last_block_valid_tokens = total_prefill_tokens % block_size
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
last_block_valid_tokens = block_size # Last block was exactly full
# Apply sparse policy if enabled (Quest does Top-K selection for decode)
sparse_policy = kvcache_manager.sparse_policy
if sparse_policy is not None:
policy_ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=self.layer_id,
query=q_batched,
is_prefill=False,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = sparse_policy.select_blocks(
cpu_block_table, policy_ctx
)
offload_engine = kvcache_manager.offload_engine
# Use cross-layer pipeline if active (initialized in model_runner)
if offload_engine.is_pipeline_active():
o_acc, lse_acc = self._decode_with_layer_pipeline(
q_batched, cpu_block_table, offload_engine,
block_size, last_block_valid_tokens
)
else:
# Fallback to original ring buffer pipeline
load_slots = offload_engine.decode_load_slots
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
q_batched, cpu_block_table, load_slots, offload_engine,
block_size, last_block_valid_tokens
)
# Now attend to accumulated decode tokens from per-layer decode buffer
pos_in_block = context.decode_pos_in_block
start_pos = context.decode_start_pos_in_block
num_accumulated = pos_in_block - start_pos + 1
# Sync compute_stream with default stream before reading decode_buffer
compute_stream = offload_engine.compute_stream
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
if num_accumulated > 0:
# Read from per-layer decode buffer
decode_k = offload_engine.decode_k_buffer[self.layer_id, start_pos:pos_in_block+1]
decode_v = offload_engine.decode_v_buffer[self.layer_id, start_pos:pos_in_block+1]
decode_k = decode_k.unsqueeze(0)
decode_v = decode_v.unsqueeze(0)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, decode_k, decode_v,
softmax_scale=self.scale,
causal=False,
)
if o_acc is None:
o_acc = decode_o
else:
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available")
# Sync back to default stream before returning
torch.cuda.default_stream().wait_stream(compute_stream)
return o_acc
def _decode_ring_buffer_pipeline(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
load_slots: list,
offload_engine,
block_size: int,
last_block_valid_tokens: int,
):
"""
Ring buffer pipeline for decode prefill loading (same mechanism as prefill).
Loads one block at a time, computes attention, and merges results.
Uses the same load_to_slot_layer / wait_slot_layer / get_kv_for_slot
methods as prefill for proven correctness.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
num_blocks = len(cpu_block_table)
if num_blocks == 0:
return None, None
if not load_slots:
return None, None
o_acc, lse_acc = None, None
num_slots = len(load_slots)
compute_stream = offload_engine.compute_stream
# Phase 1: Pre-load up to num_slots blocks
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
# Phase 2: Process blocks with pipeline
for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots]
cpu_block_id = cpu_block_table[block_idx]
# Wait for current slot's transfer to complete
offload_engine.wait_slot_layer(current_slot)
with torch.cuda.stream(compute_stream):
# Get KV from slot
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
# Handle partial last block
is_last_block = (block_idx == num_blocks - 1)
if is_last_block and last_block_valid_tokens < block_size:
prev_k = prev_k[:, :last_block_valid_tokens, :, :]
prev_v = prev_v[:, :last_block_valid_tokens, :, :]
# Compute attention
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
# Record compute done for slot reuse
offload_engine.record_slot_compute_done(current_slot)
# Start loading next block (pipeline)
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
# Merge with accumulated
with torch.cuda.stream(compute_stream):
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
def _decode_with_layer_pipeline(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
offload_engine,
block_size: int,
last_block_valid_tokens: int,
):
"""
Decode using cross-layer pipeline for optimized H2D transfer.
This method uses pre-loaded layer buffers instead of loading
blocks one by one. The pipeline loads the next layer's data
while the current layer computes, achieving transfer/compute overlap.
The key insight is that each layer needs the SAME blocks but from
different layers of CPU cache. By double-buffering and pipelining
across layers, we reduce total latency.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
num_blocks = len(cpu_block_table)
if num_blocks == 0:
return None, None
compute_stream = offload_engine.compute_stream
# Get KV from pre-loaded layer buffer (triggers next layer loading)
prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
# prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
# Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
total_tokens = num_blocks * block_size
# Handle partial last block
if last_block_valid_tokens < block_size:
# Only use valid tokens from last block
actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
# Flatten and truncate
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
else:
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
# Add batch dimension: [1, total_tokens, kv_heads, head_dim]
prev_k_batched = prev_k_flat.unsqueeze(0)
prev_v_batched = prev_v_flat.unsqueeze(0)
# Compute attention on all prefilled blocks at once
with torch.cuda.stream(compute_stream):
o_acc, lse_acc = flash_attn_with_lse(
q_batched, prev_k_batched, prev_v_batched,
softmax_scale=self.scale,
causal=False,
)
return o_acc, lse_acc

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Any
from dataclasses import dataclass
from typing import Any
import torch
@@ -14,27 +14,6 @@ class Context:
context_lens: torch.Tensor | None = None
block_tables: torch.Tensor | None = None
# Chunked prefill support
is_chunked_prefill: bool = False
# Previous KV chunks info: List of (start_pos, end_pos) for blocks on CPU
prev_kv_ranges: List[Tuple[int, int]] = field(default_factory=list)
# Current chunk's position offset (for causal mask)
chunk_offset: int = 0
# Reference to kvcache manager for loading previous KV (HybridKVCacheManager)
kvcache_manager: Any = None
# Current layer's previous K/V chunks (loaded from CPU)
# Set by model_runner before each layer's forward
prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
# Current sequence being processed (for chunked prefill to load KV)
chunked_seq: Any = None
# Position within block for decode (used for reading from Decode region)
decode_pos_in_block: int = 0
# Starting position within block where decode tokens began (for accumulated token tracking)
# Used when batching decode offloads - we need to attend to all accumulated tokens
decode_start_pos_in_block: int = 0
# Current chunk index for ring buffer pipeline (prefill only)
current_chunk_idx: int = 0
# Sparse prefill attention support (GPU-only path)
# When set, uses policy.sparse_prefill_attention() instead of FlashAttention
sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True
@@ -56,14 +35,6 @@ def set_context(
slot_mapping=None,
context_lens=None,
block_tables=None,
is_chunked_prefill=False,
prev_kv_ranges=None,
chunk_offset=0,
kvcache_manager=None,
chunked_seq=None,
decode_pos_in_block=0,
decode_start_pos_in_block=0,
current_chunk_idx=0,
sparse_prefill_policy=None,
):
global _CONTEXT
@@ -76,14 +47,6 @@ def set_context(
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
is_chunked_prefill=is_chunked_prefill,
prev_kv_ranges=prev_kv_ranges or [],
chunk_offset=chunk_offset,
kvcache_manager=kvcache_manager,
chunked_seq=chunked_seq,
decode_pos_in_block=decode_pos_in_block,
decode_start_pos_in_block=decode_start_pos_in_block,
current_chunk_idx=current_chunk_idx,
sparse_prefill_policy=sparse_prefill_policy,
)

View File

@@ -4,12 +4,12 @@
Refactor layerwise offload to use proper OffloadEngine API, pre-allocate buffers, remove chunked prefill code, and pass needle test.
## Phases
- [ ] Phase 1: Add layerwise API to OffloadEngine
- [ ] Phase 2: Pre-allocate buffers in ModelRunner
- [ ] Phase 3: Refactor run_layerwise_offload_prefill()
- [ ] Phase 4: Refactor run_layerwise_offload_decode()
- [ ] Phase 5: Remove chunked prefill code
- [ ] Phase 6: Verify with needle test
- [x] Phase 1: Add layerwise API to OffloadEngine
- [x] Phase 2: Pre-allocate buffers in ModelRunner (skipped - handled by ring buffer)
- [x] Phase 3: Refactor run_layerwise_offload_prefill()
- [x] Phase 4: Refactor run_layerwise_offload_decode()
- [x] Phase 5: Remove chunked prefill code
- [x] Phase 6: Verify with needle test
## Key Questions
1. Should we keep chunked_attention.py for MInference use?
@@ -29,7 +29,7 @@ Refactor layerwise offload to use proper OffloadEngine API, pre-allocate buffers
(none yet)
## Status
**Currently in Phase 0** - Planning complete, awaiting user approval
**COMPLETE** - All phases implemented and needle test passes
---