nano-vllm/docs/layerwise_offload_memory_analysis.md

Layer-wise Offload Memory Analysis

This document provides a detailed analysis of memory allocations in the layer-wise CPU offload system, distinguishing between pre-allocated (managed) memory and temporary (non-pre-allocated) memory.

Variable Notation

| Symbol | Description | Example (Qwen3-4B) |
|---|---|---|
| seq_len | Input sequence length | 131072 (128k) |
| hidden_size | Model hidden dimension | 2560 |
| num_heads | Number of attention heads | 20 |
| num_kv_heads | Number of KV heads (GQA) | 8 |
| head_dim | Dimension per head | 128 |
| intermediate_size | MLP intermediate dimension | 13696 |
| num_layers | Number of transformer layers | 36 |
| block_size | KV cache block size | 1024 |
| num_kv_buffers | Ring buffer count | 4 |
| num_cpu_blocks | Number of CPU cache blocks | 128 |
| vocab_size | Vocabulary size | 151936 |
| dtype_size | Bytes per element (fp16/bf16) | 2 |

Derived values:

  • kv_dim = num_kv_heads × head_dim
  • q_size = num_heads × head_dim
  • kv_size = num_kv_heads × head_dim
  • qkv_size = q_size + 2 × kv_size

1. Pre-allocated Memory (Managed by nanovllm)

These tensors are allocated once during initialization and reused throughout inference.

1.1 OffloadEngine Managed Memory

| Tensor | Shape | Size Formula | Location |
|---|---|---|---|
| layer_k_cache | [num_kv_buffers, seq_len, num_kv_heads, head_dim] | num_kv_buffers × seq_len × kv_dim × dtype_size | GPU |
| layer_v_cache | [num_kv_buffers, seq_len, num_kv_heads, head_dim] | num_kv_buffers × seq_len × kv_dim × dtype_size | GPU |
| decode_k_buffer | [num_layers, block_size, num_kv_heads, head_dim] | num_layers × block_size × kv_dim × dtype_size | GPU |
| decode_v_buffer | [num_layers, block_size, num_kv_heads, head_dim] | num_layers × block_size × kv_dim × dtype_size | GPU |
| k_cache_cpu | [num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim] | num_layers × num_cpu_blocks × block_size × kv_dim × dtype_size | CPU (pinned) |
| v_cache_cpu | [num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim] | num_layers × num_cpu_blocks × block_size × kv_dim × dtype_size | CPU (pinned) |

Total GPU (OffloadEngine): 2 × (num_kv_buffers × seq_len + num_layers × block_size) × kv_dim × dtype_size

Total CPU (OffloadEngine): 2 × num_layers × num_cpu_blocks × block_size × kv_dim × dtype_size
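Plugging the Qwen3-4B values from the notation table into the two totals above gives a quick sanity check (the helper name and defaults are illustrative, not from the codebase):

```python
def offload_engine_bytes(num_kv_buffers=4, seq_len=131072, num_layers=36,
                         block_size=1024, num_kv_heads=8, head_dim=128,
                         num_cpu_blocks=128, dtype_size=2):
    # K and V each get a ring buffer plus per-layer decode buffer on GPU,
    # and a per-layer block pool in pinned CPU memory.
    kv_dim = num_kv_heads * head_dim
    gpu = 2 * (num_kv_buffers * seq_len + num_layers * block_size) * kv_dim * dtype_size
    cpu = 2 * num_layers * num_cpu_blocks * block_size * kv_dim * dtype_size
    return gpu, cpu

gpu, cpu = offload_engine_bytes()
print(gpu, cpu)  # 2298478592 (2048 + 144 MiB), 19327352832 (18432 MiB)
```

These numbers match the per-component calculations in the appendix (Section 7).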

1.2 Model Weights

| Component | Approximate Size |
|---|---|
| Embedding | vocab_size × hidden_size × dtype_size |
| Per-layer QKV proj | hidden_size × qkv_size × dtype_size |
| Per-layer O proj | q_size × hidden_size × dtype_size |
| Per-layer MLP | hidden_size × 2 × intermediate_size × dtype_size + intermediate_size × hidden_size × dtype_size |
| Per-layer LayerNorm | 2 × hidden_size × dtype_size |
| LM Head | hidden_size × vocab_size × dtype_size |

1.3 RoPE Cache

| Tensor | Shape | Size |
|---|---|---|
| cos_sin_cache | [max_position, 1, head_dim] | max_position × head_dim × 4 (float32) |
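The RoPE cache is small in absolute terms. Assuming max_position equals max_model_len (131072 for this configuration), the formula above works out as follows (helper name is illustrative):

```python
def rope_cache_bytes(max_position=131072, head_dim=128):
    # cos_sin_cache stores head_dim float32 values per position
    # (cos half and sin half concatenated).
    return max_position * head_dim * 4

print(rope_cache_bytes())  # 67108864 bytes = 64 MiB
```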

2. Non-Pre-allocated Memory: Prefill Phase

Location: model_runner.py:run_layerwise_offload_prefill()

2.1 Persistent Tensors (Live Throughout Prefill)

| Variable | Line | Shape | Size | Notes |
|---|---|---|---|---|
| input_ids | 488 | [seq_len] | seq_len × 8 | int64 |
| positions | 489 | [seq_len] | seq_len × 8 | int64 |
| cu_seqlens | 493 | [2] | negligible | int32 |
| hidden_states | 497 | [seq_len, hidden_size] | seq_len × hidden_size × dtype_size | Embedding output |
| residual | 506 | [seq_len, hidden_size] | seq_len × hidden_size × dtype_size | Residual connection |

2.2 Per-Layer Temporary Tensors

These are allocated and deallocated within each layer iteration.

2.2.1 LayerNorm

| Variable | Line | Shape | Size | Notes |
|---|---|---|---|---|
| hidden_ln | 506-508 | [seq_len, hidden_size] | seq_len × hidden_size × dtype_size | Input layernorm output |

Inside RMSNorm (layernorm.py:add_rms_forward):

| Variable | Shape | Size | Notes |
|---|---|---|---|
| x.float() | [seq_len, hidden_size] | seq_len × hidden_size × 4 | Upcast to float32 |
| var | [seq_len, 1] | seq_len × 4 | Variance |
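The upcast pattern behind these temporaries can be sketched as follows (a minimal re-creation of the add_rms_forward logic for illustration, not the exact nano-vllm code):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x.float() materializes a full-size float32 copy: seq_len × hidden_size × 4 bytes.
    xf = x.float()
    # var is tiny: seq_len × 1 × 4 bytes.
    var = xf.pow(2).mean(dim=-1, keepdim=True)
    xf = xf * torch.rsqrt(var + eps)
    # Downcast back to the model dtype before scaling by the learned weight.
    return xf.to(x.dtype) * weight
```

The float32 copy, not the normalization itself, is what shows up in the per-layer temporary budget.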

2.2.2 QKV Projection

| Variable | Line | Shape | Size | Notes |
|---|---|---|---|---|
| qkv | 512 | [seq_len, q_size + 2 × kv_size] | seq_len × qkv_size × dtype_size | Merged QKV output |
| q | 513-519 | [seq_len, num_heads, head_dim] | 0 (view) | View of qkv |
| k | 513-520 | [seq_len, num_kv_heads, head_dim] | 0 (view) | View of qkv |
| v | 513-521 | [seq_len, num_kv_heads, head_dim] | 0 (view) | View of qkv |

2.2.3 Q/K Norms (Qwen3 specific)

| Variable | Line | Shape | Size | Notes |
|---|---|---|---|---|
| q.reshape() | 526 | [seq_len × num_heads, head_dim] | 0 (view) | Reshape for norm |
| k.reshape() | 528 | [seq_len × num_kv_heads, head_dim] | 0 (view) | Reshape for norm |
| RMSNorm intermediates | - | see above | seq_len × num_heads × head_dim × 4 | Float32 upcasting |

2.2.4 RoPE (Rotary Position Embedding)

Location: rotary_embedding.py:apply_rotary_emb()

| Variable | Line | Shape | Size | Notes |
|---|---|---|---|---|
| cos_sin | 44 | [seq_len, 1, head_dim] | 0 (view) | View of cached cos_sin |
| cos | 45 | [seq_len, 1, head_dim/2] | 0 (view) | Chunk view |
| sin | 45 | [seq_len, 1, head_dim/2] | 0 (view) | Chunk view |

Inside apply_rotary_emb for Q (rotary_embedding.py:6-14):

| Variable | Shape | Size | Notes |
|---|---|---|---|
| x.float() | [seq_len, num_heads, head_dim] | seq_len × num_heads × head_dim × 4 | Upcast to float32 |
| x1 | [seq_len, num_heads, head_dim/2] | 0 (view) | Chunk view |
| x2 | [seq_len, num_heads, head_dim/2] | 0 (view) | Chunk view |
| y1 = x1×cos − x2×sin | [seq_len, num_heads, head_dim/2] | seq_len × num_heads × head_dim/2 × 4 | New tensor |
| y2 = x2×cos + x1×sin | [seq_len, num_heads, head_dim/2] | seq_len × num_heads × head_dim/2 × 4 | New tensor |
| torch.cat((y1, y2)) | [seq_len, num_heads, head_dim] | seq_len × num_heads × head_dim × 4 | New tensor |
| .to(x.dtype) | [seq_len, num_heads, head_dim] | seq_len × num_heads × head_dim × dtype_size | Downcast |

Inside apply_rotary_emb for K:

| Variable | Shape | Notes |
|---|---|---|
| Same pattern as Q | [seq_len, num_kv_heads, head_dim] | Similar, with num_kv_heads |

Total RoPE temporary for Q+K: ~seq_len × (num_heads + num_kv_heads) × head_dim × 4 × 3 (float32 intermediates)
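The allocation pattern in the tables above can be reproduced with a standalone sketch (illustrative names; the real implementation lives in rotary_embedding.py):

```python
import torch

def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    xf = x.float()                    # full-size float32 temporary
    x1, x2 = xf.chunk(2, dim=-1)      # views, no new memory
    y1 = x1 * cos - x2 * sin          # half-size float32 temporary
    y2 = x2 * cos + x1 * sin          # half-size float32 temporary
    # cat materializes a second full-size float32 tensor before the downcast
    return torch.cat((y1, y2), dim=-1).to(x.dtype)
```

Roughly three full-size float32 tensors are live at the peak (the upcast input, the two half-size products, and the concatenation), which is where the ×4 ×3 factor above comes from.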

2.2.5 FlashAttention

| Variable | Line | Shape | Size | Notes |
|---|---|---|---|---|
| attn_output | 535 | [seq_len, num_heads, head_dim] | seq_len × num_heads × head_dim × dtype_size | Attention output |
| Internal workspace | - | variable | O(seq_len) | FlashAttention internal |

2.2.6 Output Projection

| Variable | Line | Shape | Size | Notes |
|---|---|---|---|---|
| attn_output.view() | 546 | [seq_len, q_size] | 0 (view) | Reshape for o_proj |
| o_proj(attn_output) | 547 | [seq_len, hidden_size] | seq_len × hidden_size × dtype_size | O projection output |

2.2.7 Post-Attention LayerNorm

Same as input layernorm (2.2.1).

2.2.8 MLP

Location: qwen3.py:Qwen3MLP.forward()

| Variable | Line | Shape | Size | Notes |
|---|---|---|---|---|
| gate_up | 117 | [seq_len, 2 × intermediate_size] | seq_len × 2 × intermediate_size × dtype_size | LARGEST TEMPORARY |
| x, y = chunk() | activation.py:13 | [seq_len, intermediate_size] × 2 | 0 (views) | Chunk views |
| F.silu(x) | activation.py:14 | [seq_len, intermediate_size] | seq_len × intermediate_size × dtype_size | SiLU activation |
| silu(x) × y | activation.py:14 | [seq_len, intermediate_size] | seq_len × intermediate_size × dtype_size | Gated output |
| down_proj() | 119 | [seq_len, hidden_size] | seq_len × hidden_size × dtype_size | MLP output |

2.3 Prefill Memory Summary

Peak per-layer temporary memory:

= qkv + RoPE_temps + attn_output + o_proj + layernorm + MLP_gate_up + MLP_activation
≈ seq_len × (qkv_size + num_heads × head_dim + 2 × hidden_size + 3 × intermediate_size) × dtype_size
  + seq_len × (num_heads + num_kv_heads) × head_dim × 4 × 3   (float32 RoPE/norm intermediates, already in bytes)

Dominant term: seq_len × 2 × intermediate_size × dtype_size (MLP gate_up)
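As a quick check on the dominant term (helper name is illustrative):

```python
def gate_up_bytes(seq_len, intermediate_size=13696, dtype_size=2):
    # Largest prefill temporary: the merged gate/up projection output.
    return seq_len * 2 * intermediate_size * dtype_size

print(gate_up_bytes(131072))  # 7180648448 bytes = 6848 MiB
```

For a full 128k prefill this single tensor is ~6.7 GiB per layer, matching the appendix calculation in Section 7.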


3. Non-Pre-allocated Memory: Decode Phase

Location: model_runner.py:run_layerwise_offload_decode()

3.1 Persistent Tensors

| Variable | Line | Shape | Size | Notes |
|---|---|---|---|---|
| input_ids | 604 | [1] | 8 bytes | Single token |
| positions | 605 | [1] | 8 bytes | Single position |
| cu_seqlens_q | 631 | [2] | 8 bytes | Fixed |
| valid_tokens_per_block | 613-622 | Python list | negligible | |

3.2 Per-Layer Temporary Tensors

3.2.1 Views (Zero Additional Memory)

| Variable | Line | Shape | Notes |
|---|---|---|---|
| k_prefill | 682 | [prefill_len, num_kv_heads, head_dim] | View of ring buffer |
| v_prefill | 682 | [prefill_len, num_kv_heads, head_dim] | View of ring buffer |
| k_decode_prev | 686-687 | [num_decode_tokens-1, num_kv_heads, head_dim] | View of decode buffer |
| v_decode_prev | 686-688 | [num_decode_tokens-1, num_kv_heads, head_dim] | View of decode buffer |

3.2.2 New Allocations

| Variable | Line | Shape | Size | Notes |
|---|---|---|---|---|
| hidden_ln | 654-657 | [1, hidden_size] | hidden_size × dtype_size | Tiny |
| qkv | 660 | [1, qkv_size] | qkv_size × dtype_size | Tiny |
| q | 667 | [1, num_heads, head_dim] | 0 (view) | |
| k_new | 668 | [1, num_kv_heads, head_dim] | 0 (view) | |
| v_new | 669 | [1, num_kv_heads, head_dim] | 0 (view) | |
| k_full | 689/692 | [prefill_len + num_decode_tokens, num_kv_heads, head_dim] | (prefill_len + num_decode_tokens) × kv_dim × dtype_size | torch.cat - NEW ALLOCATION |
| v_full | 690/693 | [prefill_len + num_decode_tokens, num_kv_heads, head_dim] | (prefill_len + num_decode_tokens) × kv_dim × dtype_size | torch.cat - NEW ALLOCATION |
| cu_seqlens_k | 710 | [2] | 8 bytes | Created per layer |
| attn_output | 712 | [1, num_heads, head_dim] | num_heads × head_dim × dtype_size | Tiny |
| MLP temps | 728 | [1, ...] | negligible | Single token |

3.3 Decode Memory Summary

Peak per-layer temporary memory:

```
= k_full + v_full + small_tensors
≈ 2 × (prefill_len + num_decode_tokens) × num_kv_heads × head_dim × dtype_size
≈ 2 × seq_len × kv_dim × dtype_size
```

Dominant term: k_full and v_full from torch.cat()
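The decode peak reduces to a one-line formula (helper name is illustrative; defaults are the Qwen3-4B values):

```python
def decode_peak_bytes(prefill_len, num_decode_tokens,
                      num_kv_heads=8, head_dim=128, dtype_size=2):
    # k_full + v_full, each materialized by torch.cat every layer.
    kv_dim = num_kv_heads * head_dim
    return 2 * (prefill_len + num_decode_tokens) * kv_dim * dtype_size

print(decode_peak_bytes(131072, 0))  # 536870912 bytes = 512 MiB at full 128k context
```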


4. Memory Comparison Table

For Qwen3-4B with 128k context:

| Category | Memory | Notes |
|---|---|---|
| Pre-allocated GPU | ~2.2 GB | Ring buffer + decode buffer |
| Pre-allocated CPU | ~18.4 GB | Pinned memory |
| Model Weights | ~8 GB | |
| Prefill Peak Temp | ~10-12 GB | MLP gate_up dominant |
| Decode Peak Temp | ~512 MB | k_full + v_full |

5. Optimization Opportunities

5.1 Decode: Pre-allocate k_full/v_full

Current (L689-693):

```python
k_full = torch.cat([k_prefill, k_decode_prev, k_new], dim=0)  # New allocation each layer
v_full = torch.cat([v_prefill, v_decode_prev, v_new], dim=0)  # New allocation each layer
```

Optimized:

```python
# Pre-allocate in OffloadEngine.__init__():
self.k_full_buffer = torch.zeros(max_seq_len + block_size, num_kv_heads, head_dim, ...)
self.v_full_buffer = torch.zeros(max_seq_len + block_size, num_kv_heads, head_dim, ...)

# In decode loop:
total_len = prefill_len + num_decode_tokens
k_full = self.k_full_buffer[:total_len]
k_full[:prefill_len].copy_(k_prefill)
k_full[prefill_len:prefill_len + num_decode_prev].copy_(k_decode_prev)
k_full[-1:].copy_(k_new)
```

Savings: ~512 MB per decode step (for 128k)

5.2 Decode: Reuse cu_seqlens_k

Current (L710):

```python
cu_seqlens_k = torch.tensor([0, total_kv_tokens], dtype=torch.int32, device="cuda")
```

Optimized:

```python
# Pre-allocate once:
self.cu_seqlens_k = torch.zeros(2, dtype=torch.int32, device="cuda")

# In decode loop:
self.cu_seqlens_k[1] = total_kv_tokens
```

Savings: Negligible memory, but reduces allocation overhead.

5.3 RoPE: In-place or Pre-allocated Buffers

The RoPE implementation creates multiple float32 intermediate tensors. Options:

  1. Pre-allocate buffers for Q and K rotary outputs
  2. Use in-place operations where possible
  3. Use fused RoPE kernel (e.g., from FlashAttention)

Potential savings: ~1.5 GB during prefill per layer

5.4 MLP: Cannot Optimize Easily

The MLP gate_up tensor is inherently required for the gated activation:

```python
gate_up = gate_up_proj(x)  # [seq_len, 2 × intermediate_size]
x, y = gate_up.chunk(2, -1)
output = silu(x) * y
```

This is a fundamental computation pattern. Potential optimizations:

  • Chunked MLP computation (process seq_len in chunks)
  • Fused kernels that avoid materializing full gate_up
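The chunked variant can be sketched as follows (a hypothetical helper; gate_up_proj and down_proj stand in for the module's linear layers, and chunk_tokens is a tunable trade-off between peak memory and kernel efficiency):

```python
import torch
import torch.nn.functional as F

def chunked_gated_mlp(x, gate_up_proj, down_proj, chunk_tokens=8192):
    # Peak temporary shrinks from [seq_len, 2 × intermediate_size]
    # to [chunk_tokens, 2 × intermediate_size].
    outs = []
    for xc in x.split(chunk_tokens, dim=0):
        gate_up = gate_up_proj(xc)
        g, y = gate_up.chunk(2, dim=-1)
        outs.append(down_proj(F.silu(g) * y))
    return torch.cat(outs, dim=0)
```

Since the MLP is applied token-wise, chunking over seq_len is numerically identical to the unchunked computation; the cost is more, smaller GEMM launches.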

6. Memory Flow Diagram

Prefill (per layer):

```
hidden_states ──┬──► LayerNorm ──► hidden_ln
                │
residual ◄──────┘

hidden_ln ──► QKV_proj ──► qkv ──┬──► q ──► Q_norm ──► RoPE ──► q_rotated
                                 ├──► k ──► K_norm ──► RoPE ──► k_rotated
                                 └──► v

q_rotated, k_rotated, v ──► FlashAttention ──► attn_output

attn_output ──► O_proj ──► hidden_states'

hidden_states', residual ──► LayerNorm ──► hidden_ln', residual'

hidden_ln' ──► MLP_gate_up ──► gate_up ──► SiLU×gate ──► MLP_down ──► hidden_states''

k_rotated, v ──► CPU_offload (sync copy)
```

Decode (per layer):

```
[CPU] k_cache_cpu, v_cache_cpu
           │
           ▼ (H2D async to ring buffer)
[GPU] layer_k_cache[buffer_idx], layer_v_cache[buffer_idx]
           │
           ▼ (view)
      k_prefill, v_prefill
           │
           ├──► torch.cat([k_prefill, k_decode_prev, k_new]) ──► k_full  ⚠️ NEW ALLOC
           │
           └──► torch.cat([v_prefill, v_decode_prev, v_new]) ──► v_full  ⚠️ NEW ALLOC

q_new, k_full, v_full ──► FlashAttention ──► attn_output

k_new, v_new ──► decode_k_buffer, decode_v_buffer (in-place store)
```

7. Appendix: Size Calculations

Qwen3-4B Example (128k context)

```python
# Model config
seq_len = 131072
hidden_size = 2560
num_heads = 20
num_kv_heads = 8
head_dim = 128
intermediate_size = 13696
num_layers = 36
block_size = 1024
num_kv_buffers = 4
num_cpu_blocks = 128
dtype_size = 2  # fp16/bf16

# Derived
kv_dim = num_kv_heads * head_dim  # 1024
q_size = num_heads * head_dim     # 2560
qkv_size = q_size + 2 * kv_dim    # 4608

# Pre-allocated GPU (OffloadEngine)
ring_buffer = 2 * num_kv_buffers * seq_len * kv_dim * dtype_size
# = 2 * 4 * 131072 * 1024 * 2 = 2,147,483,648 bytes = 2048 MB

decode_buffer = 2 * num_layers * block_size * kv_dim * dtype_size
# = 2 * 36 * 1024 * 1024 * 2 = 150,994,944 bytes = 144 MB

# Pre-allocated CPU
cpu_cache = 2 * num_layers * num_cpu_blocks * block_size * kv_dim * dtype_size
# = 2 * 36 * 128 * 1024 * 1024 * 2 = 19,327,352,832 bytes = 18432 MB

# Prefill temporaries (per layer peak)
mlp_gate_up = seq_len * 2 * intermediate_size * dtype_size
# = 131072 * 2 * 13696 * 2 = 7,180,648,448 bytes = 6848 MB

# Decode temporaries (per layer)
k_full = seq_len * kv_dim * dtype_size
# = 131072 * 1024 * 2 = 268,435,456 bytes = 256 MB
v_full = k_full  # = 256 MB
# Total: 512 MB
```

8. Empirical Validation

This section validates the theoretical memory analysis against actual measurements.

8.1 Test Configuration

```bash
python tests/test_needle.py --enable-offload --input-len 100000 --block-size 1024
```

Parameters:

  • Model: Qwen3-4B-Instruct
  • seq_len = 100000 (actual tokens: 99925)
  • block_size = 1024
  • max_model_len = 131072
  • num_kv_buffers = 4

8.2 Theoretical Peak Memory Calculation

Step 1: Model Load Memory

| Component | Formula | Size |
|---|---|---|
| Model weights | ~4B params × 2 bytes | ~8 GB |
| Ring buffer | 2 × 4 × 131072 × 1024 × 2 | 2048 MB |
| Decode buffer | 2 × 36 × 1024 × 1024 × 2 | 144 MB |
| Subtotal | | ~10.2 GB |

Step 2: Prefill Activation Peak (per-layer)

| Component | Formula | Size |
|---|---|---|
| hidden_states | 100000 × 2560 × 2 | 512 MB |
| residual | 100000 × 2560 × 2 | 512 MB |
| MLP gate_up | 100000 × 27392 × 2 | 5478 MB |
| MLP silu×gate | 100000 × 13696 × 2 | 2739 MB |
| Other intermediates (qkv, RoPE, attn) | ~1-2 GB | ~1500 MB |
| Subtotal | | ~10 GB |

Step 3: Total Peak

```
Total Peak = Model Load + Activation Peak
           = 10.2 GB + 10 GB
           = ~20.2 GB
```

8.3 Actual Measurement Results

```python
import torch
torch.cuda.reset_peak_memory_stats()
# ... run inference ...
peak = torch.cuda.max_memory_allocated()
```

| Metric | Value |
|---|---|
| After model load | 9.82 GB |
| Peak during inference | 20.02 GB |
| Activation peak (delta) | 10.20 GB |

8.4 Comparison: Theory vs Actual

| Component | Theoretical | Actual | Error |
|---|---|---|---|
| Model load memory | ~10.2 GB | 9.82 GB | -3.7% |
| Activation peak | ~10 GB | 10.20 GB | +2.0% |
| Total peak | ~20.2 GB | 20.02 GB | < 1% |

8.5 Key Findings

  1. Theoretical model is accurate: < 5% error in all components.

  2. MLP gate_up is the dominant temporary:

    • Size: 5.35 GB (for 100k tokens)
    • Accounts for ~50% of activation peak
    • Formula: seq_len × 2 × intermediate_size × dtype_size
  3. Memory scaling with sequence length:

    | seq_len | Model Load | Activation Peak | Total Peak |
    |---|---|---|---|
    | 8k | ~10 GB | ~0.8 GB | ~11 GB |
    | 32k | ~10 GB | ~3.2 GB | ~13 GB |
    | 64k | ~10 GB | ~6.4 GB | ~16 GB |
    | 100k | ~10 GB | ~10 GB | ~20 GB |
    | 128k | ~10 GB | ~13 GB | ~23 GB |
  4. Decode memory is much smaller:

    • Per-step: ~512 MB for k_full + v_full (at 100k context)
    • Does not grow with decode steps (constant per layer)

8.6 Memory Profiling Script

To reproduce the measurement:

```python
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"

import torch
from nanovllm import LLM, SamplingParams
from tests.utils import generate_needle_prompt

# Reset memory stats
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()

# Initialize LLM
llm = LLM(
    "path/to/model",
    enforce_eager=True,
    max_model_len=131072,
    max_num_batched_tokens=131072,
    enable_cpu_offload=True,
    kvcache_block_size=1024,
    num_gpu_blocks=2,
)

after_load = torch.cuda.memory_allocated()
print(f"After model load: {after_load / 1024**3:.2f} GB")

# Generate prompt and run inference
prompt, expected = generate_needle_prompt(
    tokenizer=llm.tokenizer,
    target_length=100000,
    needle_position=0.5,
)

torch.cuda.reset_peak_memory_stats()
outputs = llm.generate([prompt], SamplingParams(max_tokens=32))

peak = torch.cuda.max_memory_allocated()
print(f"Peak during inference: {peak / 1024**3:.2f} GB")
```