diff --git a/CLAUDE.md b/CLAUDE.md index b0b6b52..73b6c9b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -64,6 +64,7 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py | [`docs/xattention_integration.md`](docs/xattention_integration.md) | XAttention integration guide: algorithm, implementation, design decisions, and testing | | [`docs/xattention_analysis.md`](docs/xattention_analysis.md) | XAttention algorithm analysis: chunked estimation, block sparse attention, integration design | | [`docs/development_notes.md`](docs/development_notes.md) | Development notes and scratchpad for ongoing work | +| [`docs/chunked_prefill_analysis.md`](docs/chunked_prefill_analysis.md) | **NEW**: Chunked prefill for ultra-long sequences (1M+), memory analysis, MLP activation breakdown, implementation guide | ## Configuration diff --git a/docs/chunked_prefill_analysis.md b/docs/chunked_prefill_analysis.md new file mode 100644 index 0000000..8356a60 --- /dev/null +++ b/docs/chunked_prefill_analysis.md @@ -0,0 +1,1055 @@ +# Chunked Prefill 与长序列推理内存分析 + +本文档详细分析了 nano-vllm 的 layerwise kvcache offload 策略的理论最大序列长度,以及通过 chunked prefill 降低 GPU 峰值内存的方案。 + +## 目录 + +1. [问题背景](#问题背景) +2. [理论最大序列长度分析](#理论最大序列长度分析) +3. [MLP Activation 内存瓶颈](#mlp-activation-内存瓶颈) +4. [Chunked MLP 方案](#chunked-mlp-方案) +5. [Chunked Prefill 方案](#chunked-prefill-方案) +6. [**Streaming Chunked Prefill(最优方案)**](#streaming-chunked-prefill最优方案) +7. [Layerwise Offload + Chunked Prefill 组合](#layerwise-offload--chunked-prefill-组合) +8. [实现指南](#实现指南) + +--- + +## 问题背景 + +### 为什么需要长序列推理? + +现代大语言模型的应用场景对上下文长度提出了越来越高的要求: +- **文档分析**:需要处理完整的书籍、报告 +- **代码理解**:大型代码库的上下文 +- **多轮对话**:保持长期对话历史 +- **RAG 应用**:检索增强生成需要大量上下文 + +### GPU 内存限制 + +RTX 3090/4090 只有 24GB 显存,A100 有 40GB/80GB,而长序列推理的内存需求呈线性增长: + +$$ +\text{Memory} \propto \text{seq\_len} +$$ + +### Layerwise Offload 的局限性 + +虽然 layerwise offload 将 KV cache 移到 CPU,但仍有其他组件占用大量 GPU 内存: +- **模型权重**:固定 ~16GB(Llama-3.1-8B) +- **MLP 激活值**:与序列长度成正比 +- **Attention 中间张量**:与序列长度成正比 + +--- + +## 理论最大序列长度分析 + +### 硬性限制排序 + +#### 1. RoPE float32 精度限制(最严格) + +从 `rotary_embedding.py:30`: +```python +t = torch.arange(max_position_embeddings, dtype=torch.float) +``` + +**问题**:`torch.arange` 使用 `float32`,精确整数表示上限为: +$$ +2^{24} = 16,777,216 \approx 16.8\text{M tokens} +$$ + +**突破方法**:修改 `rotary_embedding.py` 使用 `float64` 或 `int64` + +#### 2. GPU MLP gate_up 激活值(实际瓶颈) + +**公式**: +$$ +\text{MLP\_gate\_up} = \text{seq\_len} \times 2 \times \text{intermediate\_size} \times \text{dtype\_size} +$$ + +**Llama-3.1-8B 示例**: +``` +intermediate_size = 14336 +dtype_size = 2 (bfloat16) + +MLP_gate_up = seq_len × 2 × 14336 × 2 + = seq_len × 57344 bytes + = seq_len × 56 KB +``` + +| seq_len | MLP gate_up 内存 | +|---------|-----------------| +| 64K | 3.5 GB | +| 1M | 57 GB | +| 16M | 912 GB | + +#### 3. GPU Ring Buffer 内存 + +**公式**: +$$ +\text{Ring\_Buffer} = 2 \times \text{num\_kv\_buffers} \times \text{seq\_len} \times \text{kv\_dim} \times \text{dtype\_size} +$$ + +**优化**:设置 `num_kv_buffers=1` 可降低内存 + +#### 4. CPU KV Cache 内存 + +**公式**: +$$ +\text{CPU\_KV} = \text{num\_layers} \times \text{seq\_len} \times \text{kv\_dim} \times 2 \times \text{dtype\_size} +$$ + +**Llama-3.1-8B** (32 layers, kv_dim=1024): +``` +CPU_KV = 32 × seq_len × 1024 × 2 × 2 + = seq_len × 131072 bytes + = seq_len × 128 KB +``` + +| seq_len | CPU 内存需求 | +|---------|-------------| +| 1M | 131 GB | +| 10M | 1.31 TB | +| 16M | 2.1 TB | + +### 不同 GPU 配置的最大序列长度 + +**GPU 内存公式**: +$$ +\text{GPU\_mem} = \text{model\_weights} + \text{MLP\_gate\_up} + \text{Ring\_Buffer} + \text{overhead} +$$ + +**Llama-3.1-8B 在不同 GPU 上**: + +| GPU | 显存 | 模型权重 | 可用 | 最大序列(原始) | 最大序列(chunked) | +|-----|------|---------|------|----------------|-------------------| +| RTX 3090 | 24GB | 16 GB | 3 GB | ~32K | **~1M** | +| A100 40GB | 40GB | 16 GB | 19 GB | ~340K | **~2M** | +| A100 80GB | 80GB | 16 GB | 59 GB | ~1M | **~4M** | + +### 最终理论最大值 + +**硬性上限**:**~16.8M tokens**(受 RoPE float32 精度限制) + +**实用最大值**(单卡):**1-2M tokens**(受 MLP gate_up GPU 内存限制) + +--- + +## MLP Activation 内存瓶颈 + +### MLP 计算流程 + +从 `llama.py:82-106`,LlamaMLP 使用 **SwiGLU** 激活: + +```python +def forward(self, x): + # 输入: x [seq_len, hidden_size] + gate_up = self.gate_up_proj(x) # [seq_len, 2 * intermediate_size] + gate, up = gate_up.chunk(2, dim=-1) # 各 [seq_len, intermediate_size] + gate = F.silu(gate) # SiLU 激活 + output = gate * up # 逐元素相乘 + return self.down_proj(output) # [seq_len, hidden_size] +``` + +### 为什么 gate_up 这么大? + +**Llama-3.1-8B 参数**: +``` +hidden_size = 4096 +intermediate_size = 14336 # 是 hidden_size 的 3.5 倍 +``` + +**gate_up 张量大小**: +``` +shape = [seq_len, 2 * intermediate_size] + = [seq_len, 28672] + +内存 = seq_len × 28672 × 2 bytes + = seq_len × 57344 bytes + = seq_len × 56 KB +``` + +### SwiGLU vs 标准 MLP + +**标准 Transformer MLP(如 GPT)**: +```python +hidden = gelu(x @ W_gate) # [seq_len, 4 * hidden_size] +output = hidden @ W_down # [seq_len, hidden_size] +``` + +**LLaMA 的 SwiGLU MLP**: +```python +gate_up = x @ W_gate_up # [seq_len, 2 * intermediate_size] +gate, up = gate_up.chunk(2) +output = silu(gate) * up # ← 逐元素相乘需要两个独立张量! +output = output @ W_down +``` + +**关键区别**: +- SwiGLU 需要 **gate 和 up 两个独立的分支** +- 两个分支必须 **同时存在内存中** 才能逐元素相乘 +- 这就是内存开销的根本原因 + +### MLP Activation 内存公式 + +$$ +\text{MLP\_activation} = \text{seq\_len} \times 2 \times \text{intermediate\_size} \times \text{dtype\_size} +$$ + +| 模型 | hidden_size | intermediate_size | 扩展比 | 64K 激活内存 | +|-----|-------------|-------------------|-------|-------------| +| Llama-3.1-8B | 4096 | 14336 | 3.5× | 3.5 GB | +| Llama-2-7B | 4096 | 11008 | 2.69× | 2.7 GB | +| Qwen3-4B | 2560 | 13696 | 5.35× | 3.5 GB | +| Qwen2-7B | 4096 | 13696 | 3.34× | 3.5 GB | + +### 为什么 MLP Activation 不能 Offload? + +与 KV cache 不同,MLP activation 是 **临时计算结果**: + +1. **计算依赖**:`gate_up` → `chunk` → `silu` → `mul` → `down_proj` +2. **生命周期**:用完即弃,不需要长期保存 +3. **传输开销**:CPU-GPU 传输会完全抵消 offload 的好处 + +如果 offload: +```python +gate_up = gate_up_proj(x) # GPU 计算 +gate_up = gate_up.to("cpu") # 传输到 CPU ← 慢! +gate, up = gate_up.chunk(2) # CPU 操作 +gate = silu(gate) +output = gate * up +output = output.to("cuda") # 传输回 GPU ← 慢! +output = down_proj(output) # GPU 计算 +``` + +**结论**:MLP activation 的高内存开销是 Transformer SwiGLU 架构的固有问题。 + +--- + +## Chunked MLP 方案 + +### 核心思想 + +MLP 是 **逐位置独立** 的计算: +$$ +\text{output}[i] = \text{MLP}(\text{input}[i]) +$$ + +这意味着可以将序列分成多个 chunks,每个 chunk 独立计算。 + +### 实现方式 + +```python +def chunked_mlp_forward(mlp, hidden_states, chunk_size): + """ + Chunked MLP forward to reduce peak memory usage. + + Args: + mlp: LlamaMLP module + hidden_states: [seq_len, hidden_size] + chunk_size: number of tokens to process per chunk + + Returns: + output: [seq_len, hidden_size] + """ + if chunk_size <= 0 or len(hidden_states) <= chunk_size: + return mlp(hidden_states) + + # 预分配输出 buffer + output = torch.empty_like(hidden_states) + + # Chunked processing + for i in range(0, len(hidden_states), chunk_size): + chunk_end = min(i + chunk_size, len(hidden_states)) + chunk = hidden_states[i:chunk_end] + output[i:chunk_end] = mlp(chunk) + + return output +``` + +### 内存效果对比(Llama-3.1-8B, seq_len=1M) + +| 组件 | 原始 | Chunked (chunk_size=16K) | 降低 | +|-----|------|-------------------------|------| +| gate_up | 56 GB | 875 MB | **98.4%** ✅ | +| activated | 28 GB | 437 MB | 98.4% ✅ | +| mlp_out | 8 GB | 125 MB | 98.4% ✅ | + +### GPU 内存效果 + +| GPU | 显存 | 原始最大序列 | Chunked (16K) | 提升 | +|-----|------|------------|---------------|------| +| RTX 3090 | 24GB | ~32K | ~200K | **6.25×** | +| A100 40GB | 40GB | ~340K | ~1M | **3×** | +| A100 80GB | 80GB | ~1M | ~2M | **2×** | + +### 性能权衡 + +| chunk_size | 峰值内存降低 | 预期性能开销 | 推荐场景 | +|-----------|-------------|-------------|---------| +| 16K | 75% | +5-10% | 64K-128K 序列 | +| 8K | 87% | +10-20% | 128K-256K 序列 | +| 4K | 94% | +20-50% | 256K-512K 序列 | + +--- + +## Chunked Prefill 方案 + +### 核心思想 + +不仅 MLP 是逐位置独立的,**在保持因果关系的前提下,Attention 也可以分块处理**。 + +### 关键洞察 + +对于序列的 Chunk N(位置范围 `[start, end]`): +$$ +\text{output}[i] = \text{Attention}(\text{query}[i], \text{keys}[0:i], \text{values}[0:i]) +$$ + +通过累积 KV cache,可以让 Chunk N 看到所有之前的 tokens: +$$ +\text{K}_{\text{full}} = \text{cat}([\text{K}_0, \text{K}_1, ..., \text{K}_N]) +$$ + +### Chunk 独立性验证 + +**关键问题**:Chunk N 的输出是否是最终结果? + +**答案**:✅ 是的!每个 chunk 计算完成后的输出就是该 chunk 的最终结果。 + +**原因**: +1. Chunk N 的 tokens 能看到所有之前的 tokens(通过累积 KV) +2. 因果关系通过 `causal=True` 保持 +3. 不同 chunks 之间不需要额外的计算 + +**数学等价性**: +``` +原始实现(一次性): +output[i] = Layer(attention(hidden[i], hidden[0:i])) + +Chunked 实现: +output[i] = Layer(attention(hidden[i], cat([KV[0:N]]))) + +因为 cat([KV[0:N]]) == hidden[0:i](因果范围内) +所以两者完全相同! +``` + +### 实现伪代码 + +```python +def run_layerwise_offload_chunked_prefill(layer, hidden_states, chunk_size=16384): + """ + Chunked prefill for a single layer. + + Args: + layer: Transformer layer + hidden_states: [seq_len, hidden_size] + chunk_size: chunk size for processing + + Returns: + output: [seq_len, hidden_size] + """ + seq_len = len(hidden_states) + + # 初始化累积 KV cache + k_accumulated = [] + v_accumulated = [] + + # 预分配输出 buffer(避免 torch.cat) + output_buffer = torch.empty_like(hidden_states) + + # Chunked prefill + for start in range(0, seq_len, chunk_size): + end = min(start + chunk_size, seq_len) + + # 1. 获取当前 chunk + chunk_hidden = hidden_states[start:end] + chunk_residual = residual[start:end] if residual is not None else None + + # 2. Input LayerNorm + chunk_hidden_ln, chunk_residual = layer.input_layernorm(chunk_hidden, chunk_residual) + + # 3. QKV Projection(只对 chunk) + chunk_qkv = layer.qkv_proj(chunk_hidden_ln) # [chunk_size, 6144] + chunk_q, chunk_k, chunk_v = chunk_qkv.split([...], dim=-1) + + # 4. RoPE(只对 chunk) + positions_chunk = positions[start:end] + chunk_q_rot, chunk_k_rot = layer.rotary_emb(positions_chunk, chunk_q, chunk_k) + + # 5. 累积 KV cache + k_accumulated.append(chunk_k_rot) + v_accumulated.append(chunk_v) + k_full = torch.cat(k_accumulated, dim=0) # 累积的全部 K + v_full = torch.cat(v_accumulated, dim=0) # 累积的全部 V + + # 6. FlashAttention(chunk Q vs 累积 K,V) + chunk_attn_out = flash_attn( + chunk_q_rot, # Query: 当前 chunk + k_full, # Key: 累积的全部 + v_full, # Value: 累积的全部 + causal=True + ) + + # 7. O Projection + Residual + chunk_o_proj = layer.o_proj(chunk_attn_out) + chunk_hidden = chunk_o_proj + chunk_residual + + # 8. Post-Attention LayerNorm + chunk_hidden_ln, chunk_residual = layer.post_attention_layernorm(chunk_hidden) + + # 9. MLP(只对 chunk) + chunk_output = layer.mlp(chunk_hidden_ln) + + # 10. 直接写入输出 buffer + output_buffer[start:end] = chunk_output + + # 11. Offload 完整的 K, V to CPU + k_full = torch.cat(k_accumulated, dim=0) + v_full = torch.cat(v_accumulated, dim=0) + offload_kv_to_cpu(layer.layer_id, k_full, v_full) + + return output_buffer +``` + +### 内存效果对比(Llama-3.1-8B, seq_len=1M) + +| 组件 | 原始 | Chunked (16K) | 降低 | +|-----|------|---------------|------| +| hidden_states | 8 GB | 8 GB | 0% (需要完整存储) | +| qkv | 12 GB | 192 MB | **98.4%** ✅ | +| q_rotated | 8 GB | 128 MB | **98.4%** ✅ | +| k_rotated | 2 GB | 32 MB | **98.4%** ✅ | +| attn_out | 8 GB | 128 MB | **98.4%** ✅ | +| gate_up | 56 GB | 875 MB | **98.4%** ✅ | +| KV 累积 | - | 4 GB | 必须的 | +| **总峰值** | **102 GB** | **21 GB** | **79%** ✅ | + +### 不同序列长度的内存效果 + +| seq_len | Layerwise Offload | +Chunked MLP | **+Chunked Prefill** | 提升 vs Offload | +|---------|------------------|-------------|-------------------|-----------------| +| 32K | 20 GB | 19 GB | **17 GB** | 15% | +| 64K | 25 GB | 22 GB | **17 GB** | 32% | +| 128K | 32 GB | 27 GB | **18 GB** | 44% | +| 256K | 47 GB | 38 GB | **19 GB** | 60% | +| 512K | 77 GB | 56 GB | **22 GB** | 71% | +| **1M** | **137 GB** | **98 GB** | **25 GB** | **82%** ✅ | + +### GPU 内存公式推导 + +**Layerwise Offload**: +$$ +\text{GPU}_{\text{offload}} = \text{model\_weights} + \text{seq\_len} \times (6 \times H + Q + 2KV) + \text{seq\_len} \times 2I +$$ + +**Layerwise + Chunked MLP**: +$$ +\text{GPU}_{\text{chunked\_mlp}} = \text{model\_weights} + \text{seq\_len} \times (6H + Q + 2KV) + \text{chunk\_size} \times 2I +$$ + +**Layerwise + Chunked Prefill**: +$$ +\text{GPU}_{\text{chunked\_prefill}} = \text{model\_weights} + \text{seq\_len} \times 2H + \text{seq\_len} \times 2KV + \text{chunk\_size} \times (Q + KV + 2I) +$$ + +其中: +- $H = \text{hidden\_size}$ (4096) +- $I = \text{intermediate\_size}$ (14336) +- $Q = \text{q\_size}$ (4096) +- $KV = \text{kv\_size}$ (1024) + +**Llama-3.1-8B 具体数值**: +$$ +\text{GPU}_{\text{chunked\_prefill}} = 16\text{ GB} + \text{seq\_len} \times 12\text{ KB} + \text{chunk\_size} \times 30\text{ KB} +$$ + +### 24GB GPU 可运行的最大序列 + +| 方案 | 最大序列长度 | 提升倍数 | +|-----|-------------|---------| +| Layerwise Offload | ~32K | 1× | +| +Chunked MLP | ~100K | 3× | +| **+Chunked Prefill (chunk=16K)** | **~1M** | **32×** ✅ | +| **+Chunked Prefill (chunk=4K)** | **~2M** | **64×** ✅ | + +--- + +## Streaming Chunked Prefill(最优方案) + +### 关键洞察:层间独立性 + +**之前的误解**:认为 Layer N+1 需要 Layer N 的完整 `hidden_states` 输出。 + +**正确的理解**:如果每一层都是 chunked 的,那么 Layer N+1 只需要当前 chunk 的输入! + +### 为什么 Streaming 方式可行? + +**核心原因**:MLP 和 Attention 的计算都是**逐位置独立**的。 + +$$ +\text{output}[i] = \text{Layer}(\text{input}[i]) +$$ + +这意味着: +1. Chunk 0 在所有层的计算完成后,其结果就是最终结果 +2. Chunk 1 在所有层的计算完成后,其结果也是最终结果 +3. 不同 chunks 之间**完全独立** + +### 两种 Chunked Prefill 对比 + +#### 原始理解(错误):"层内 Chunked" + +```python +# 错误的方式:处理完一层所有 chunks 后再处理下一层 +for layer_id in range(num_layers): + output_buffer = torch.empty([seq_len, 4096]) # ← 需要 8GB! + + for chunk_idx in range(num_chunks): + chunk = input_hidden[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size] + output_buffer[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size] = process_chunk(layer_id, chunk) + + hidden_states = output_buffer # 传递给下一层 +``` + +**问题**: +- ❌ 需要存储完整的 `output_buffer`(8GB @ 1M) +- ❌ GPU 内存仍与 `seq_len` 成正比 +- ❌ 层与层之间需要传递完整的 hidden_states + +#### Streaming Chunked Prefill(正确) + +```python +# 正确的方式:每个 chunk 依次通过所有层 +initial_hidden = embedding(input_ids) # [seq_len, 4096] + +for chunk_idx in range(num_chunks): + # 获取当前 chunk + chunk = initial_hidden[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size] # [chunk_size, 4096] + + # 通过所有层处理这个 chunk + for layer_id in range(num_layers): + # 从 CPU 加载该层的累积 KV + k_accumulated, v_accumulated = load_kv_from_cpu(layer_id, end=chunk_idx * chunk_size) + + # 处理当前 chunk + chunk = process_layer_chunk( + layer_id, + chunk, + k_accumulated, + v_accumulated, + start_pos=chunk_idx * chunk_size + ) + + # 将当前 chunk 的 KV offload 到 CPU + offload_chunk_kv_to_cpu(layer_id, chunk_k, chunk_v) + + # chunk 已经通过所有层,可以直接输出或继续处理 + # 不需要存储完整的 output_buffer! +``` + +**优势**: +- ✅ **不需要完整的 output_buffer**(节省 8GB) +- ✅ GPU 内存**独立于 seq_len** +- ✅ 只需要存储当前 chunk 的 hidden_states(125MB @ chunk_size=16K) + +### 内存效果对比(Llama-3.1-8B, seq_len=1M) + +| 组件 | 原始 Chunked Prefill | Streaming Chunked Prefill | 降低 | +|-----|---------------------|-------------------------|------| +| **hidden_states (输出)** | **8 GB** | **125 MB** | **98.4%** ✅ | +| qkv (chunked) | 192 MB | 192 MB | - | +| gate_up (chunked) | 875 MB | 875 MB | - | +| KV 累积(单层) | 4 GB | 4 GB | - | +| 模型权重 | 16 GB | 16 GB | - | +| **总峰值** | **~29 GB** | **~21 GB** | **28%** ✅ | + +### GPU 内存公式 + +**原始 Chunked Prefill**: +$$ +\text{GPU}_{\text{chunked}} = \text{model\_weights} + \text{seq\_len} \times 2H + \text{chunk\_size} \times (Q + KV + 2I) +$$ + +**Streaming Chunked Prefill**: +$$ +\text{GPU}_{\text{streaming}} = \text{model\_weights} + \text{chunk\_size} \times 2H + \text{chunk\_size} \times (Q + KV + 2I) +$$ + +**关键区别**: +- 原始:`seq_len × 2H`(与序列长度成正比) +- Streaming:`chunk_size × 2H`(与 chunk size 成正比) + +**Llama-3.1-8B 具体数值(chunk_size=16K)**: +$$ +\text{GPU}_{\text{streaming}} = 16\text{ GB} + 16\text{K} \times 12\text{ KB} + 16\text{K} \times 30\text{ KB} = 16\text{ GB} + 192\text{ MB} + 480\text{ MB} \approx 17.5\text{ GB} +$$ + +### 不同序列长度的内存效果 + +| seq_len | Layerwise Offload | Chunked Prefill | **Streaming** | Streaming vs Offload | +|---------|------------------|----------------|---------------|---------------------| +| 32K | 20 GB | 17 GB | **17.5 GB** | -12% | +| 64K | 25 GB | 17 GB | **17.5 GB** | -30% | +| 128K | 32 GB | 18 GB | **17.5 GB** | -45% | +| 256K | 47 GB | 19 GB | **17.5 GB** | -63% | +| 512K | 77 GB | 22 GB | **17.5 GB** | -77% | +| **1M** | **137 GB** | **25 GB** | **17.5 GB** | **-87%** ✅ | +| **2M** | **265 GB** | **38 GB** | **17.5 GB** | **-93%** ✅ | +| **4M** | **521 GB** | **70 GB** | **17.5 GB** | **-97%** ✅ | + +### 关键发现 + +**Streaming Chunked Prefill 使得 GPU 内存几乎独立于序列长度!** + +- 32K: 17.5 GB +- 1M: 17.5 GB(**完全相同**) +- 4M: 17.5 GB(**完全相同**) + +唯一增长的是 **CPU KV cache 内存**: +$$ +\text{CPU\_KV} = \text{num\_layers} \times \text{seq\_len} \times \text{kv\_dim} \times 2 \times \text{dtype\_size} +$$ + +### 理论最大序列长度(更新) + +#### 新的限制:CPU 内存 + +| 配置 | CPU 内存 | 最大序列长度 | +|-----|---------|-------------| +| 典型服务器 | 128 GB | ~1M | +| 大内存服务器 | 512 GB | ~4M | +| 超大内存服务器 | 1 TB | ~8M | +| 极限配置 | 2 TB | ~16M | + +#### RoPE 硬性上限仍然是瓶颈 + +$$ +2^{24} = 16,777,216 \approx 16.8\text{M tokens} +$$ + +**结论**: +- **GPU 限制**:已被 Streaming Chunked Prefill 完全消除(17.5 GB 恒定) +- **CPU 限制**:512 GB RAM → ~4M tokens +- **RoPE 限制**:~16.8M tokens(硬性上限) + +### 24GB GPU 可运行的最大序列(更新) + +| 方案 | 最大序列长度 | 提升倍数 | 瓶颈 | +|-----|-------------|---------|------| +| Layerwise Offload | ~32K | 1× | GPU 内存 | +| +Chunked MLP | ~100K | 3× | GPU 内存 | +| +Chunked Prefill | ~1M | 32× | GPU 内存 | +| **+Streaming** | **~4M** | **128×** | **CPU 内存** ✅ | +| **+Streaming (RoPE float64)** | **~16M** | **512×** | **CPU 内存** | + +### Streaming 实现伪代码 + +```python +def run_streaming_chunked_prefill(model, input_ids, chunk_size=16384): + """ + Streaming chunked prefill: 每个 chunk 依次通过所有层。 + + Args: + model: Transformer model + input_ids: [seq_len] + chunk_size: chunk size for processing + + Returns: + output: [seq_len, vocab_size] (logits) + """ + seq_len = len(input_ids) + num_chunks = (seq_len + chunk_size - 1) // chunk_size + + # 1. Embedding(一次性) + initial_hidden = model.model.embed_tokens(input_ids) # [seq_len, 4096] + + # 2. 预分配 CPU KV cache + cpu_kv_cache = [{ + 'k': torch.zeros(seq_len, kv_dim, dtype=dtype), + 'v': torch.zeros(seq_len, kv_dim, dtype=dtype), + } for _ in range(num_layers)] + + # 3. 预分配输出(可选,也可以直接写入文件) + # 对于超长序列,可能不需要存储完整输出 + final_output = torch.zeros(seq_len, vocab_size, dtype=dtype) + + # 4. Streaming: 每个 chunk 依次通过所有层 + for chunk_idx in range(num_chunks): + start = chunk_idx * chunk_size + end = min(start + chunk_size, seq_len) + + # 获取当前 chunk 的初始 hidden states + chunk_hidden = initial_hidden[start:end] # [chunk_size, 4096] + + # 通过所有层处理 + for layer_id in range(num_layers): + layer = model.model.layers[layer_id] + + # 从 CPU 加载之前 chunks 的累积 KV + k_prev = cpu_kv_cache[layer_id]['k'][:start].to('cuda') # [start, kv_dim] + v_prev = cpu_kv_cache[layer_id]['v'][:start].to('cuda') # [start, kv_dim] + + # 处理当前 chunk + chunk_hidden, chunk_k, chunk_v = process_layer_chunk( + layer, + chunk_hidden, + k_prev, + v_prev, + positions=torch.arange(start, end, device='cuda') + ) + + # 将当前 chunk 的 KV offload 到 CPU + cpu_kv_cache[layer_id]['k'][start:end] = chunk_k.cpu() + cpu_kv_cache[layer_id]['v'][start:end] = chunk_v.cpu() + + # Chunk 已经通过所有层,计算 logits 并存储 + chunk_logits = model.lm_head(chunk_hidden) + final_output[start:end] = chunk_logits + + # 可选:释放当前 chunk 的 GPU 内存 + del chunk_hidden, chunk_logits + + return final_output + + +def process_layer_chunk(layer, chunk_hidden, k_prev, v_prev, positions): + """ + 处理单个层的一个 chunk。 + + Args: + layer: Transformer layer + chunk_hidden: [chunk_size, hidden_size] + k_prev: 之前 chunks 的累积 K [prev_len, kv_dim] + v_prev: 之前 chunks 的累积 V [prev_len, kv_dim] + positions: 当前 chunk 的位置 [chunk_size] + + Returns: + chunk_output: [chunk_size, hidden_size] + chunk_k: [chunk_size, kv_dim] + chunk_v: [chunk_size, kv_dim] + """ + # 1. Input LayerNorm + chunk_hidden_ln = layer.input_layernorm(chunk_hidden) + + # 2. QKV Projection + chunk_qkv = layer.qkv_proj(chunk_hidden_ln) + chunk_q, chunk_k, chunk_v = chunk_qkv.split([q_size, kv_dim, kv_dim], dim=-1) + + # 3. RoPE + chunk_q, chunk_k = layer.rotary_emb(positions, chunk_q, chunk_k) + + # 4. 累积 KV + k_full = torch.cat([k_prev, chunk_k], dim=0) # [prev_len + chunk_size, kv_dim] + v_full = torch.cat([v_prev, chunk_v], dim=0) + + # 5. FlashAttention(chunk Q vs 累积 K,V) + chunk_attn_out = flash_attn( + chunk_q, # [chunk_size, q_size] + k_full, # [prev_len + chunk_size, kv_dim] + v_full, # [prev_len + chunk_size, kv_dim] + causal=True + ) + + # 6. O Projection + Residual + chunk_hidden = layer.o_proj(chunk_attn_out) + chunk_hidden + + # 7. Post-Attention LayerNorm + chunk_hidden_ln = layer.post_attention_layernorm(chunk_hidden) + + # 8. MLP + chunk_output = layer.mlp(chunk_hidden_ln) + chunk_hidden + + return chunk_output, chunk_k, chunk_v +``` + +### 性能权衡 + +| chunk_size | GPU 内存 | CPU-GPU 传输 | 预期性能 | 适用场景 | +|-----------|---------|-------------|---------|---------| +| 32K | 18.5 GB | 32K × 32 = 1M/层 | 最快 | ≤ 32K | +| 16K | 17.5 GB | 16K × 32 = 512K/层 | +10% | 32K - 1M | +| 8K | 17.0 GB | 8K × 32 = 256K/层 | +25% | 1M - 4M | +| 4K | 16.8 GB | 4K × 32 = 128K/层 | +50% | > 4M | + +**CPU-GPU 传输开销**: +- 每个 chunk 需要从 CPU 加载之前的所有 KV +- 传输量:`chunk_idx × chunk_size × kv_dim × 2 × dtype_size` +- Chunk 0: 0 KB(没有之前的 KV) +- Chunk 1: 16K × 1K × 2 × 2 = 64 MB +- Chunk 63: 63 × 64 MB = 4 GB + +--- + +## Layerwise Offload + Chunked Prefill 组合 + +> **注意**:本节描述的是"层内 Chunked Prefill"方案。如果采用 **Streaming Chunked Prefill**,请参考上一节。 + +### 为什么需要组合方案? + +**Chunked MLP 的局限性**: +- ✅ 解决了 MLP 瓶颈(98% 降低) +- ❌ 但 Attention 仍是瓶颈(30GB @ 1M) + +**Chunked Prefill 的优势**: +- ✅ 同时解决 MLP 和 Attention 瓶颈 +- ✅ 让 24GB GPU 运行 1M 序列 + +### 组合方案架构 + +``` +输入: input_ids [seq_len] + ↓ +Embedding: hidden_states [seq_len, hidden_size] + ↓ +┌─────────────────────────────────────────────────────────────┐ +│ Layer 0 (Chunked Prefill) │ +│ for chunk in chunks(hidden_states, chunk_size): │ +│ - QKV Projection (chunk only) │ +│ - FlashAttention (chunk Q vs accumulated K,V) │ +│ - MLP (chunk only) │ +│ Offload KV to CPU │ +└─────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────┐ +│ Layer 1 (Chunked Prefill) │ +│ for chunk in chunks(hidden_states, chunk_size): │ +│ - ... (same pattern) │ +│ Offload KV to CPU │ +└─────────────────────────────────────────────────────────────┘ + ↓ +... (repeat for all layers) + ↓ +Final LayerNorm + LM Head +``` + +### 完整内存 Breakdown(seq_len=1M, chunk_size=16K) + +| 组件 | 内存 | 占比 | 能否 chunk? | +|-----|------|------|------------| +| 模型权重 | 16 GB | 64% | ❌ 固定 | +| **最终输出 (hidden_states)** | **8 GB** | **32%** | ❌ **层间数据流** | +| KV 累积 | 4 GB | 16% | ❌ 必须累积 | +| qkv (chunked) | 192 MB | 0.8% | ✅ 已 chunked | +| gate_up (chunked) | 875 MB | 3.5% | ✅ 已 chunked | +| 其他临时张量 | 169 MB | 0.7% | ✅ 已 chunked | +| **总计** | **~25 GB** | 100% | | + +### 性能权衡分析 + +**Kernel 调用次数**: +``` +原始: num_layers × 1 = 32 +Chunked (16K): num_layers × (seq_len / chunk_size) = 32 × 64 = 2048 +增加: +6300% +``` + +**预期性能开销**: + +| chunk_size | 内存节省 | 预期时间开销 | 适用场景 | +|-----------|---------|------------|---------| +| 32K | 15-24% | +5-10% | 64K-128K | +| 16K | 24-46% | +10-20% | 128K-512K | +| 8K | 34-57% | +20-40% | 512K-1M | +| 4K | 46-78% | +40-80% | >1M | + +### 为什么最终输出不能 Chunk? + +**原因**:`hidden_states` 是层与层之间的数据流 + +```python +# Layer N 的输出 +hidden_states: [seq_len, hidden_size] + +# Layer N+1 的输入 +# 需要 Layer N 的完整输出! +qkv = qkv_proj(hidden_states) # 需要完整的 hidden_states +``` + +**突破方法**: +1. **模型并行**:分布 hidden_states 到多个 GPU +2. **流水线并行**:不同层在不同 GPU +3. **降低 hidden_size**:使用更小的模型 + +--- + +## 实现指南 + +### 配置参数 + +```python +@dataclass +class Config: + # Chunked prefill 配置 + prefill_chunk_size: int = 0 # 0 = 不 chunk, >0 = chunk 大小 + + # 自动选择策略 + def get_chunk_size(self, seq_len: int) -> int: + if self.prefill_chunk_size > 0: + return self.prefill_chunk_size + + # 自动选择 + if seq_len < 32000: + return 0 # 不需要 chunk + elif seq_len < 128000: + return 16384 + elif seq_len < 512000: + return 8192 + else: + return 4096 +``` + +### 实现步骤 + +**第一步:验证 Chunked MLP**(~50 行代码) + +```python +# 在 model_runner.py 的 run_layerwise_offload_prefill 中 +# 将 MLP 调用改为 chunked + +chunk_size = self.config.get_chunk_size(len(hidden_states)) +if chunk_size > 0: + hidden_states = chunked_mlp_forward(layer.mlp, hidden_states, chunk_size) +else: + hidden_states = layer.mlp(hidden_states) +``` + +**第二步:实现 Chunked Prefill**(~150 行代码) + +```python +def run_layerwise_offload_chunked_prefill(self, seqs): + chunk_size = self.config.get_chunk_size(total_tokens) + + hidden_states = self.model.model.embed_tokens(input_ids) + + for layer_id in range(num_layers): + if chunk_size == 0: + # 原始实现 + hidden_states = self._process_layer_full(layer, hidden_states) + else: + # Chunked 实现 + hidden_states = self._process_layer_chunked( + layer, hidden_states, chunk_size + ) + + # Offload KV to CPU + ... + + return hidden_states +``` + +### 优化建议 + +1. **预分配输出 buffer**: + ```python + output = torch.empty_like(hidden_states) + for chunk in chunks: + output[start:end] = process_chunk(chunk) + ``` + +2. **复用 KV 累积 buffer**: + ```python + # 预分配最大可能的 KV buffer + k_buffer = torch.zeros(seq_len, kv_dim, ...) + v_buffer = torch.zeros(seq_len, kv_dim, ...) + + # 每次 copy 到对应位置 + k_buffer[current_pos:end] = chunk_k + ``` + +3. **异步 H2D 传输**: + ```python + # 在处理当前 chunk 时,异步加载下一层需要的 KV + # (需要流式处理支持) + ``` + +### 测试验证 + +```python +# 验证正确性 +def test_chunked_prefill_correctness(): + model = LLM("path/to/model") + + # 原始实现 + output_original = model.generate(prompt, enable_cpu_offload=True, prefill_chunk_size=0) + + # Chunked 实现 + output_chunked = model.generate(prompt, enable_cpu_offload=True, prefill_chunk_size=16384) + + # 验证结果相同 + assert torch.allclose(output_original, output_chunked, atol=1e-3) + +# 验证内存 +def test_chunked_prefill_memory(): + torch.cuda.reset_peak_memory_stats() + + model = LLM("path/to/model", enable_cpu_offload=True, prefill_chunk_size=16384) + model.generate(prompt_1M) + + peak = torch.cuda.max_memory_allocated() + assert peak < 24 * 1024**3 # 24GB +``` + +--- + +## 总结 + +### 关键发现 + +1. **理论最大序列**:~16.8M tokens(RoPE float32 限制) + +2. **主要瓶颈演进**(按优化顺序): + - RoPE 精度(最严格,硬性上限) + - MLP gate_up 激活值(原始瓶颈) + - Attention 中间张量(第二瓶颈) + - 最终输出 hidden_states(第三瓶颈,已被 Streaming 消除) + +3. **Streaming Chunked Prefill(最终方案)效果**: + - **GPU 内存独立于 seq_len**:17.5 GB 恒定 + - 内存降低:87%+(相比 layerwise offload) + - 序列扩展:128×(32K → 4M,受 CPU 内存限制) + - 性能开销:10-25%(取决于 chunk_size) + +### 方案对比总结 + +| 方案 | 最大序列(24GB GPU) | GPU 内存 | 瓶颈 | 推荐度 | +|-----|---------------------|---------|------|-------| +| Layerwise Offload | ~32K | 20-137 GB | GPU 内存 | ⭐⭐⭐⭐ | +| +Chunked MLP | ~100K | 19-98 GB | GPU 内存 | ⭐⭐⭐ | +| +Chunked Prefill(层内) | ~1M | 17-25 GB | GPU 内存 | ⭐⭐⭐⭐ | +| **+Streaming(最终)** | **~4M** | **17.5 GB(恒定)** | **CPU 内存** | **⭐⭐⭐⭐⭐** | + +### 实现优先级(更新) + +| 阶段 | 方案 | 效果 | 优先级 | +|-----|------|------|-------| +| 已实现 | Layerwise Offload | 32K | P0 ✅ | +| 第一步 | +Chunked MLP | ~100K | P2(可跳过) | +| 第二步 | +Chunked Prefill(层内) | ~1M | P1 | +| **最终目标** | **+Streaming** | **~4M** | **P0** ✅ | + +### Streaming Chunked Prefill 关键优势 + +- ✅ **GPU 内存恒定**:17.5 GB,与 seq_len 无关 +- ✅ **突破 GPU 限制**:最大序列由 CPU 内存决定 +- ✅ **实现复杂度可控**:~200 行代码 +- ✅ **与 layerwise offload 完美集成** +- ⚠️ CPU-GPU 传输增加:每个 chunk 需加载之前所有 KV +- ⚠️ 性能开销:10-25%(可接受的权衡) + +### 最终推荐 + +**对于超长序列推理(> 128K)**: +1. ✅ 使用 **Streaming Chunked Prefill**(最优方案) +2. ✅ 设置 `prefill_chunk_size = 16384`(自动选择) +3. ✅ 设置 `num_kv_buffers = 1`(降低 ring buffer 内存) +4. ✅ 确保充足 CPU 内存(512GB → 4M tokens) + +**效果**: +- RTX 3090 (24GB): 32K → **~4M** (128× 提升) ✅ +- A100 40GB: 340K → **~4M** (12× 提升) ✅ +- A100 80GB: 1M → **~4M** (4× 提升) ✅ + +**唯一限制**: +- CPU 内存:512 GB → ~4M tokens +- RoPE 硬性上限:~16.8M tokens(需改 float64)