# Chunked Prefill 与长序列推理内存分析 本文档详细分析了 nano-vllm 的 layerwise kvcache offload 策略的理论最大序列长度,以及通过 chunked prefill 降低 GPU 峰值内存的方案。 ## 目录 1. [问题背景](#问题背景) 2. [理论最大序列长度分析](#理论最大序列长度分析) 3. [MLP Activation 内存瓶颈](#mlp-activation-内存瓶颈) 4. [Chunked MLP 方案](#chunked-mlp-方案) 5. [Chunked Prefill 方案](#chunked-prefill-方案) 6. [**Streaming Chunked Prefill(最优方案)**](#streaming-chunked-prefill最优方案) 7. [Layerwise Offload + Chunked Prefill 组合](#layerwise-offload--chunked-prefill-组合) 8. [实现指南](#实现指南) --- ## 问题背景 ### 为什么需要长序列推理? 现代大语言模型的应用场景对上下文长度提出了越来越高的要求: - **文档分析**:需要处理完整的书籍、报告 - **代码理解**:大型代码库的上下文 - **多轮对话**:保持长期对话历史 - **RAG 应用**:检索增强生成需要大量上下文 ### GPU 内存限制 RTX 3090/4090 只有 24GB 显存,A100 有 40GB/80GB,而长序列推理的内存需求呈线性增长: $$ \text{Memory} \propto \text{seq\_len} $$ ### Layerwise Offload 的局限性 虽然 layerwise offload 将 KV cache 移到 CPU,但仍有其他组件占用大量 GPU 内存: - **模型权重**:固定 ~16GB(Llama-3.1-8B) - **MLP 激活值**:与序列长度成正比 - **Attention 中间张量**:与序列长度成正比 --- ## 理论最大序列长度分析 ### 硬性限制排序 #### 1. RoPE float32 精度限制(最严格) 从 `rotary_embedding.py:30`: ```python t = torch.arange(max_position_embeddings, dtype=torch.float) ``` **问题**:`torch.arange` 使用 `float32`,精确整数表示上限为: $$ 2^{24} = 16,777,216 \approx 16.8\text{M tokens} $$ **突破方法**:修改 `rotary_embedding.py` 使用 `float64` 或 `int64` #### 2. GPU MLP gate_up 激活值(实际瓶颈) **公式**: $$ \text{MLP\_gate\_up} = \text{seq\_len} \times 2 \times \text{intermediate\_size} \times \text{dtype\_size} $$ **Llama-3.1-8B 示例**: ``` intermediate_size = 14336 dtype_size = 2 (bfloat16) MLP_gate_up = seq_len × 2 × 14336 × 2 = seq_len × 57344 bytes = seq_len × 56 KB ``` | seq_len | MLP gate_up 内存 | |---------|-----------------| | 64K | 3.5 GB | | 1M | 57 GB | | 16M | 912 GB | #### 3. GPU Ring Buffer 内存 **公式**: $$ \text{Ring\_Buffer} = 2 \times \text{num\_kv\_buffers} \times \text{seq\_len} \times \text{kv\_dim} \times \text{dtype\_size} $$ **优化**:设置 `num_kv_buffers=1` 可降低内存 #### 4. CPU KV Cache 内存 **公式**: $$ \text{CPU\_KV} = \text{num\_layers} \times \text{seq\_len} \times \text{kv\_dim} \times 2 \times \text{dtype\_size} $$ **Llama-3.1-8B** (32 layers, kv_dim=1024): ``` CPU_KV = 32 × seq_len × 1024 × 2 × 2 = seq_len × 131072 bytes = seq_len × 128 KB ``` | seq_len | CPU 内存需求 | |---------|-------------| | 1M | 131 GB | | 10M | 1.31 TB | | 16M | 2.1 TB | ### 不同 GPU 配置的最大序列长度 **GPU 内存公式**: $$ \text{GPU\_mem} = \text{model\_weights} + \text{MLP\_gate\_up} + \text{Ring\_Buffer} + \text{overhead} $$ **Llama-3.1-8B 在不同 GPU 上**: | GPU | 显存 | 模型权重 | 可用 | 最大序列(原始) | 最大序列(chunked) | |-----|------|---------|------|----------------|-------------------| | RTX 3090 | 24GB | 16 GB | 3 GB | ~32K | **~1M** | | A100 40GB | 40GB | 16 GB | 19 GB | ~340K | **~2M** | | A100 80GB | 80GB | 16 GB | 59 GB | ~1M | **~4M** | ### 最终理论最大值 **硬性上限**:**~16.8M tokens**(受 RoPE float32 精度限制) **实用最大值**(单卡):**1-2M tokens**(受 MLP gate_up GPU 内存限制) --- ## MLP Activation 内存瓶颈 ### MLP 计算流程 从 `llama.py:82-106`,LlamaMLP 使用 **SwiGLU** 激活: ```python def forward(self, x): # 输入: x [seq_len, hidden_size] gate_up = self.gate_up_proj(x) # [seq_len, 2 * intermediate_size] gate, up = gate_up.chunk(2, dim=-1) # 各 [seq_len, intermediate_size] gate = F.silu(gate) # SiLU 激活 output = gate * up # 逐元素相乘 return self.down_proj(output) # [seq_len, hidden_size] ``` ### 为什么 gate_up 这么大? **Llama-3.1-8B 参数**: ``` hidden_size = 4096 intermediate_size = 14336 # 是 hidden_size 的 3.5 倍 ``` **gate_up 张量大小**: ``` shape = [seq_len, 2 * intermediate_size] = [seq_len, 28672] 内存 = seq_len × 28672 × 2 bytes = seq_len × 57344 bytes = seq_len × 56 KB ``` ### SwiGLU vs 标准 MLP **标准 Transformer MLP(如 GPT)**: ```python hidden = gelu(x @ W_gate) # [seq_len, 4 * hidden_size] output = hidden @ W_down # [seq_len, hidden_size] ``` **LLaMA 的 SwiGLU MLP**: ```python gate_up = x @ W_gate_up # [seq_len, 2 * intermediate_size] gate, up = gate_up.chunk(2) output = silu(gate) * up # ← 逐元素相乘需要两个独立张量! output = output @ W_down ``` **关键区别**: - SwiGLU 需要 **gate 和 up 两个独立的分支** - 两个分支必须 **同时存在内存中** 才能逐元素相乘 - 这就是内存开销的根本原因 ### MLP Activation 内存公式 $$ \text{MLP\_activation} = \text{seq\_len} \times 2 \times \text{intermediate\_size} \times \text{dtype\_size} $$ | 模型 | hidden_size | intermediate_size | 扩展比 | 64K 激活内存 | |-----|-------------|-------------------|-------|-------------| | Llama-3.1-8B | 4096 | 14336 | 3.5× | 3.5 GB | | Llama-2-7B | 4096 | 11008 | 2.69× | 2.7 GB | | Qwen3-4B | 2560 | 13696 | 5.35× | 3.5 GB | | Qwen2-7B | 4096 | 13696 | 3.34× | 3.5 GB | ### 为什么 MLP Activation 不能 Offload? 与 KV cache 不同,MLP activation 是 **临时计算结果**: 1. **计算依赖**:`gate_up` → `chunk` → `silu` → `mul` → `down_proj` 2. **生命周期**:用完即弃,不需要长期保存 3. **传输开销**:CPU-GPU 传输会完全抵消 offload 的好处 如果 offload: ```python gate_up = gate_up_proj(x) # GPU 计算 gate_up = gate_up.to("cpu") # 传输到 CPU ← 慢! gate, up = gate_up.chunk(2) # CPU 操作 gate = silu(gate) output = gate * up output = output.to("cuda") # 传输回 GPU ← 慢! output = down_proj(output) # GPU 计算 ``` **结论**:MLP activation 的高内存开销是 Transformer SwiGLU 架构的固有问题。 --- ## Chunked MLP 方案 ### 核心思想 MLP 是 **逐位置独立** 的计算: $$ \text{output}[i] = \text{MLP}(\text{input}[i]) $$ 这意味着可以将序列分成多个 chunks,每个 chunk 独立计算。 ### 实现方式 ```python def chunked_mlp_forward(mlp, hidden_states, chunk_size): """ Chunked MLP forward to reduce peak memory usage. Args: mlp: LlamaMLP module hidden_states: [seq_len, hidden_size] chunk_size: number of tokens to process per chunk Returns: output: [seq_len, hidden_size] """ if chunk_size <= 0 or len(hidden_states) <= chunk_size: return mlp(hidden_states) # 预分配输出 buffer output = torch.empty_like(hidden_states) # Chunked processing for i in range(0, len(hidden_states), chunk_size): chunk_end = min(i + chunk_size, len(hidden_states)) chunk = hidden_states[i:chunk_end] output[i:chunk_end] = mlp(chunk) return output ``` ### 内存效果对比(Llama-3.1-8B, seq_len=1M) | 组件 | 原始 | Chunked (chunk_size=16K) | 降低 | |-----|------|-------------------------|------| | gate_up | 56 GB | 875 MB | **98.4%** ✅ | | activated | 28 GB | 437 MB | 98.4% ✅ | | mlp_out | 8 GB | 125 MB | 98.4% ✅ | ### GPU 内存效果 | GPU | 显存 | 原始最大序列 | Chunked (16K) | 提升 | |-----|------|------------|---------------|------| | RTX 3090 | 24GB | ~32K | ~200K | **6.25×** | | A100 40GB | 40GB | ~340K | ~1M | **3×** | | A100 80GB | 80GB | ~1M | ~2M | **2×** | ### 性能权衡 | chunk_size | 峰值内存降低 | 预期性能开销 | 推荐场景 | |-----------|-------------|-------------|---------| | 16K | 75% | +5-10% | 64K-128K 序列 | | 8K | 87% | +10-20% | 128K-256K 序列 | | 4K | 94% | +20-50% | 256K-512K 序列 | --- ## Chunked Prefill 方案 ### 核心思想 不仅 MLP 是逐位置独立的,**在保持因果关系的前提下,Attention 也可以分块处理**。 ### 关键洞察 对于序列的 Chunk N(位置范围 `[start, end]`): $$ \text{output}[i] = \text{Attention}(\text{query}[i], \text{keys}[0:i], \text{values}[0:i]) $$ 通过累积 KV cache,可以让 Chunk N 看到所有之前的 tokens: $$ \text{K}_{\text{full}} = \text{cat}([\text{K}_0, \text{K}_1, ..., \text{K}_N]) $$ ### Chunk 独立性验证 **关键问题**:Chunk N 的输出是否是最终结果? **答案**:✅ 是的!每个 chunk 计算完成后的输出就是该 chunk 的最终结果。 **原因**: 1. Chunk N 的 tokens 能看到所有之前的 tokens(通过累积 KV) 2. 因果关系通过 `causal=True` 保持 3. 不同 chunks 之间不需要额外的计算 **数学等价性**: ``` 原始实现(一次性): output[i] = Layer(attention(hidden[i], hidden[0:i])) Chunked 实现: output[i] = Layer(attention(hidden[i], cat([KV[0:N]]))) 因为 cat([KV[0:N]]) == hidden[0:i](因果范围内) 所以两者完全相同! ``` ### 实现伪代码 ```python def run_layerwise_offload_chunked_prefill(layer, hidden_states, chunk_size=16384): """ Chunked prefill for a single layer. Args: layer: Transformer layer hidden_states: [seq_len, hidden_size] chunk_size: chunk size for processing Returns: output: [seq_len, hidden_size] """ seq_len = len(hidden_states) # 初始化累积 KV cache k_accumulated = [] v_accumulated = [] # 预分配输出 buffer(避免 torch.cat) output_buffer = torch.empty_like(hidden_states) # Chunked prefill for start in range(0, seq_len, chunk_size): end = min(start + chunk_size, seq_len) # 1. 获取当前 chunk chunk_hidden = hidden_states[start:end] chunk_residual = residual[start:end] if residual is not None else None # 2. Input LayerNorm chunk_hidden_ln, chunk_residual = layer.input_layernorm(chunk_hidden, chunk_residual) # 3. QKV Projection(只对 chunk) chunk_qkv = layer.qkv_proj(chunk_hidden_ln) # [chunk_size, 6144] chunk_q, chunk_k, chunk_v = chunk_qkv.split([...], dim=-1) # 4. RoPE(只对 chunk) positions_chunk = positions[start:end] chunk_q_rot, chunk_k_rot = layer.rotary_emb(positions_chunk, chunk_q, chunk_k) # 5. 累积 KV cache k_accumulated.append(chunk_k_rot) v_accumulated.append(chunk_v) k_full = torch.cat(k_accumulated, dim=0) # 累积的全部 K v_full = torch.cat(v_accumulated, dim=0) # 累积的全部 V # 6. FlashAttention(chunk Q vs 累积 K,V) chunk_attn_out = flash_attn( chunk_q_rot, # Query: 当前 chunk k_full, # Key: 累积的全部 v_full, # Value: 累积的全部 causal=True ) # 7. O Projection + Residual chunk_o_proj = layer.o_proj(chunk_attn_out) chunk_hidden = chunk_o_proj + chunk_residual # 8. Post-Attention LayerNorm chunk_hidden_ln, chunk_residual = layer.post_attention_layernorm(chunk_hidden) # 9. MLP(只对 chunk) chunk_output = layer.mlp(chunk_hidden_ln) # 10. 直接写入输出 buffer output_buffer[start:end] = chunk_output # 11. Offload 完整的 K, V to CPU k_full = torch.cat(k_accumulated, dim=0) v_full = torch.cat(v_accumulated, dim=0) offload_kv_to_cpu(layer.layer_id, k_full, v_full) return output_buffer ``` ### 内存效果对比(Llama-3.1-8B, seq_len=1M) | 组件 | 原始 | Chunked (16K) | 降低 | |-----|------|---------------|------| | hidden_states | 8 GB | 8 GB | 0% (需要完整存储) | | qkv | 12 GB | 192 MB | **98.4%** ✅ | | q_rotated | 8 GB | 128 MB | **98.4%** ✅ | | k_rotated | 2 GB | 32 MB | **98.4%** ✅ | | attn_out | 8 GB | 128 MB | **98.4%** ✅ | | gate_up | 56 GB | 875 MB | **98.4%** ✅ | | KV 累积 | - | 4 GB | 必须的 | | **总峰值** | **102 GB** | **21 GB** | **79%** ✅ | ### 不同序列长度的内存效果 | seq_len | Layerwise Offload | +Chunked MLP | **+Chunked Prefill** | 提升 vs Offload | |---------|------------------|-------------|-------------------|-----------------| | 32K | 20 GB | 19 GB | **17 GB** | 15% | | 64K | 25 GB | 22 GB | **17 GB** | 32% | | 128K | 32 GB | 27 GB | **18 GB** | 44% | | 256K | 47 GB | 38 GB | **19 GB** | 60% | | 512K | 77 GB | 56 GB | **22 GB** | 71% | | **1M** | **137 GB** | **98 GB** | **25 GB** | **82%** ✅ | ### GPU 内存公式推导 **Layerwise Offload**: $$ \text{GPU}_{\text{offload}} = \text{model\_weights} + \text{seq\_len} \times (6 \times H + Q + 2KV) + \text{seq\_len} \times 2I $$ **Layerwise + Chunked MLP**: $$ \text{GPU}_{\text{chunked\_mlp}} = \text{model\_weights} + \text{seq\_len} \times (6H + Q + 2KV) + \text{chunk\_size} \times 2I $$ **Layerwise + Chunked Prefill**: $$ \text{GPU}_{\text{chunked\_prefill}} = \text{model\_weights} + \text{seq\_len} \times 2H + \text{seq\_len} \times 2KV + \text{chunk\_size} \times (Q + KV + 2I) $$ 其中: - $H = \text{hidden\_size}$ (4096) - $I = \text{intermediate\_size}$ (14336) - $Q = \text{q\_size}$ (4096) - $KV = \text{kv\_size}$ (1024) **Llama-3.1-8B 具体数值**: $$ \text{GPU}_{\text{chunked\_prefill}} = 16\text{ GB} + \text{seq\_len} \times 12\text{ KB} + \text{chunk\_size} \times 30\text{ KB} $$ ### 24GB GPU 可运行的最大序列 | 方案 | 最大序列长度 | 提升倍数 | |-----|-------------|---------| | Layerwise Offload | ~32K | 1× | | +Chunked MLP | ~100K | 3× | | **+Chunked Prefill (chunk=16K)** | **~1M** | **32×** ✅ | | **+Chunked Prefill (chunk=4K)** | **~2M** | **64×** ✅ | --- ## Streaming Chunked Prefill(最优方案) ### 关键洞察:层间独立性 **之前的误解**:认为 Layer N+1 需要 Layer N 的完整 `hidden_states` 输出。 **正确的理解**:如果每一层都是 chunked 的,那么 Layer N+1 只需要当前 chunk 的输入! ### 为什么 Streaming 方式可行? **核心原因**:MLP 和 Attention 的计算都是**逐位置独立**的。 $$ \text{output}[i] = \text{Layer}(\text{input}[i]) $$ 这意味着: 1. Chunk 0 在所有层的计算完成后,其结果就是最终结果 2. Chunk 1 在所有层的计算完成后,其结果也是最终结果 3. 不同 chunks 之间**完全独立** ### 两种 Chunked Prefill 对比 #### 原始理解(错误):"层内 Chunked" ```python # 错误的方式:处理完一层所有 chunks 后再处理下一层 for layer_id in range(num_layers): output_buffer = torch.empty([seq_len, 4096]) # ← 需要 8GB! for chunk_idx in range(num_chunks): chunk = input_hidden[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size] output_buffer[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size] = process_chunk(layer_id, chunk) hidden_states = output_buffer # 传递给下一层 ``` **问题**: - ❌ 需要存储完整的 `output_buffer`(8GB @ 1M) - ❌ GPU 内存仍与 `seq_len` 成正比 - ❌ 层与层之间需要传递完整的 hidden_states #### Streaming Chunked Prefill(正确) ```python # 正确的方式:每个 chunk 依次通过所有层 initial_hidden = embedding(input_ids) # [seq_len, 4096] for chunk_idx in range(num_chunks): # 获取当前 chunk chunk = initial_hidden[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size] # [chunk_size, 4096] # 通过所有层处理这个 chunk for layer_id in range(num_layers): # 从 CPU 加载该层的累积 KV k_accumulated, v_accumulated = load_kv_from_cpu(layer_id, end=chunk_idx * chunk_size) # 处理当前 chunk chunk = process_layer_chunk( layer_id, chunk, k_accumulated, v_accumulated, start_pos=chunk_idx * chunk_size ) # 将当前 chunk 的 KV offload 到 CPU offload_chunk_kv_to_cpu(layer_id, chunk_k, chunk_v) # chunk 已经通过所有层,可以直接输出或继续处理 # 不需要存储完整的 output_buffer! ``` **优势**: - ✅ **不需要完整的 output_buffer**(节省 8GB) - ✅ GPU 内存**独立于 seq_len** - ✅ 只需要存储当前 chunk 的 hidden_states(125MB @ chunk_size=16K) ### 内存效果对比(Llama-3.1-8B, seq_len=1M) | 组件 | 原始 Chunked Prefill | Streaming Chunked Prefill | 降低 | |-----|---------------------|-------------------------|------| | **hidden_states (输出)** | **8 GB** | **125 MB** | **98.4%** ✅ | | qkv (chunked) | 192 MB | 192 MB | - | | gate_up (chunked) | 875 MB | 875 MB | - | | KV 累积(单层) | 4 GB | 4 GB | - | | 模型权重 | 16 GB | 16 GB | - | | **总峰值** | **~29 GB** | **~21 GB** | **28%** ✅ | ### GPU 内存公式 **原始 Chunked Prefill**: $$ \text{GPU}_{\text{chunked}} = \text{model\_weights} + \text{seq\_len} \times 2H + \text{chunk\_size} \times (Q + KV + 2I) $$ **Streaming Chunked Prefill**: $$ \text{GPU}_{\text{streaming}} = \text{model\_weights} + \text{chunk\_size} \times 2H + \text{chunk\_size} \times (Q + KV + 2I) $$ **关键区别**: - 原始:`seq_len × 2H`(与序列长度成正比) - Streaming:`chunk_size × 2H`(与 chunk size 成正比) **Llama-3.1-8B 具体数值(chunk_size=16K)**: $$ \text{GPU}_{\text{streaming}} = 16\text{ GB} + 16\text{K} \times 12\text{ KB} + 16\text{K} \times 30\text{ KB} = 16\text{ GB} + 192\text{ MB} + 480\text{ MB} \approx 17.5\text{ GB} $$ ### 不同序列长度的内存效果 | seq_len | Layerwise Offload | Chunked Prefill | **Streaming** | Streaming vs Offload | |---------|------------------|----------------|---------------|---------------------| | 32K | 20 GB | 17 GB | **17.5 GB** | -12% | | 64K | 25 GB | 17 GB | **17.5 GB** | -30% | | 128K | 32 GB | 18 GB | **17.5 GB** | -45% | | 256K | 47 GB | 19 GB | **17.5 GB** | -63% | | 512K | 77 GB | 22 GB | **17.5 GB** | -77% | | **1M** | **137 GB** | **25 GB** | **17.5 GB** | **-87%** ✅ | | **2M** | **265 GB** | **38 GB** | **17.5 GB** | **-93%** ✅ | | **4M** | **521 GB** | **70 GB** | **17.5 GB** | **-97%** ✅ | ### 关键发现 **Streaming Chunked Prefill 使得 GPU 内存几乎独立于序列长度!** - 32K: 17.5 GB - 1M: 17.5 GB(**完全相同**) - 4M: 17.5 GB(**完全相同**) 唯一增长的是 **CPU KV cache 内存**: $$ \text{CPU\_KV} = \text{num\_layers} \times \text{seq\_len} \times \text{kv\_dim} \times 2 \times \text{dtype\_size} $$ ### 理论最大序列长度(更新) #### 新的限制:CPU 内存 | 配置 | CPU 内存 | 最大序列长度 | |-----|---------|-------------| | 典型服务器 | 128 GB | ~1M | | 大内存服务器 | 512 GB | ~4M | | 超大内存服务器 | 1 TB | ~8M | | 极限配置 | 2 TB | ~16M | #### RoPE 硬性上限仍然是瓶颈 $$ 2^{24} = 16,777,216 \approx 16.8\text{M tokens} $$ **结论**: - **GPU 限制**:已被 Streaming Chunked Prefill 完全消除(17.5 GB 恒定) - **CPU 限制**:512 GB RAM → ~4M tokens - **RoPE 限制**:~16.8M tokens(硬性上限) ### 24GB GPU 可运行的最大序列(更新) | 方案 | 最大序列长度 | 提升倍数 | 瓶颈 | |-----|-------------|---------|------| | Layerwise Offload | ~32K | 1× | GPU 内存 | | +Chunked MLP | ~100K | 3× | GPU 内存 | | +Chunked Prefill | ~1M | 32× | GPU 内存 | | **+Streaming** | **~4M** | **128×** | **CPU 内存** ✅ | | **+Streaming (RoPE float64)** | **~16M** | **512×** | **CPU 内存** | ### Streaming 实现伪代码 ```python def run_streaming_chunked_prefill(model, input_ids, chunk_size=16384): """ Streaming chunked prefill: 每个 chunk 依次通过所有层。 Args: model: Transformer model input_ids: [seq_len] chunk_size: chunk size for processing Returns: output: [seq_len, vocab_size] (logits) """ seq_len = len(input_ids) num_chunks = (seq_len + chunk_size - 1) // chunk_size # 1. Embedding(一次性) initial_hidden = model.model.embed_tokens(input_ids) # [seq_len, 4096] # 2. 预分配 CPU KV cache cpu_kv_cache = [{ 'k': torch.zeros(seq_len, kv_dim, dtype=dtype), 'v': torch.zeros(seq_len, kv_dim, dtype=dtype), } for _ in range(num_layers)] # 3. 预分配输出(可选,也可以直接写入文件) # 对于超长序列,可能不需要存储完整输出 final_output = torch.zeros(seq_len, vocab_size, dtype=dtype) # 4. Streaming: 每个 chunk 依次通过所有层 for chunk_idx in range(num_chunks): start = chunk_idx * chunk_size end = min(start + chunk_size, seq_len) # 获取当前 chunk 的初始 hidden states chunk_hidden = initial_hidden[start:end] # [chunk_size, 4096] # 通过所有层处理 for layer_id in range(num_layers): layer = model.model.layers[layer_id] # 从 CPU 加载之前 chunks 的累积 KV k_prev = cpu_kv_cache[layer_id]['k'][:start].to('cuda') # [start, kv_dim] v_prev = cpu_kv_cache[layer_id]['v'][:start].to('cuda') # [start, kv_dim] # 处理当前 chunk chunk_hidden, chunk_k, chunk_v = process_layer_chunk( layer, chunk_hidden, k_prev, v_prev, positions=torch.arange(start, end, device='cuda') ) # 将当前 chunk 的 KV offload 到 CPU cpu_kv_cache[layer_id]['k'][start:end] = chunk_k.cpu() cpu_kv_cache[layer_id]['v'][start:end] = chunk_v.cpu() # Chunk 已经通过所有层,计算 logits 并存储 chunk_logits = model.lm_head(chunk_hidden) final_output[start:end] = chunk_logits # 可选:释放当前 chunk 的 GPU 内存 del chunk_hidden, chunk_logits return final_output def process_layer_chunk(layer, chunk_hidden, k_prev, v_prev, positions): """ 处理单个层的一个 chunk。 Args: layer: Transformer layer chunk_hidden: [chunk_size, hidden_size] k_prev: 之前 chunks 的累积 K [prev_len, kv_dim] v_prev: 之前 chunks 的累积 V [prev_len, kv_dim] positions: 当前 chunk 的位置 [chunk_size] Returns: chunk_output: [chunk_size, hidden_size] chunk_k: [chunk_size, kv_dim] chunk_v: [chunk_size, kv_dim] """ # 1. Input LayerNorm chunk_hidden_ln = layer.input_layernorm(chunk_hidden) # 2. QKV Projection chunk_qkv = layer.qkv_proj(chunk_hidden_ln) chunk_q, chunk_k, chunk_v = chunk_qkv.split([q_size, kv_dim, kv_dim], dim=-1) # 3. RoPE chunk_q, chunk_k = layer.rotary_emb(positions, chunk_q, chunk_k) # 4. 累积 KV k_full = torch.cat([k_prev, chunk_k], dim=0) # [prev_len + chunk_size, kv_dim] v_full = torch.cat([v_prev, chunk_v], dim=0) # 5. FlashAttention(chunk Q vs 累积 K,V) chunk_attn_out = flash_attn( chunk_q, # [chunk_size, q_size] k_full, # [prev_len + chunk_size, kv_dim] v_full, # [prev_len + chunk_size, kv_dim] causal=True ) # 6. O Projection + Residual chunk_hidden = layer.o_proj(chunk_attn_out) + chunk_hidden # 7. Post-Attention LayerNorm chunk_hidden_ln = layer.post_attention_layernorm(chunk_hidden) # 8. MLP chunk_output = layer.mlp(chunk_hidden_ln) + chunk_hidden return chunk_output, chunk_k, chunk_v ``` ### 性能权衡 | chunk_size | GPU 内存 | CPU-GPU 传输 | 预期性能 | 适用场景 | |-----------|---------|-------------|---------|---------| | 32K | 18.5 GB | 32K × 32 = 1M/层 | 最快 | ≤ 32K | | 16K | 17.5 GB | 16K × 32 = 512K/层 | +10% | 32K - 1M | | 8K | 17.0 GB | 8K × 32 = 256K/层 | +25% | 1M - 4M | | 4K | 16.8 GB | 4K × 32 = 128K/层 | +50% | > 4M | **CPU-GPU 传输开销**: - 每个 chunk 需要从 CPU 加载之前的所有 KV - 传输量:`chunk_idx × chunk_size × kv_dim × 2 × dtype_size` - Chunk 0: 0 KB(没有之前的 KV) - Chunk 1: 16K × 1K × 2 × 2 = 64 MB - Chunk 63: 63 × 64 MB = 4 GB --- ## Layerwise Offload + Chunked Prefill 组合 > **注意**:本节描述的是"层内 Chunked Prefill"方案。如果采用 **Streaming Chunked Prefill**,请参考上一节。 ### 为什么需要组合方案? **Chunked MLP 的局限性**: - ✅ 解决了 MLP 瓶颈(98% 降低) - ❌ 但 Attention 仍是瓶颈(30GB @ 1M) **Chunked Prefill 的优势**: - ✅ 同时解决 MLP 和 Attention 瓶颈 - ✅ 让 24GB GPU 运行 1M 序列 ### 组合方案架构 ``` 输入: input_ids [seq_len] ↓ Embedding: hidden_states [seq_len, hidden_size] ↓ ┌─────────────────────────────────────────────────────────────┐ │ Layer 0 (Chunked Prefill) │ │ for chunk in chunks(hidden_states, chunk_size): │ │ - QKV Projection (chunk only) │ │ - FlashAttention (chunk Q vs accumulated K,V) │ │ - MLP (chunk only) │ │ Offload KV to CPU │ └─────────────────────────────────────────────────────────────┘ ↓ ┌─────────────────────────────────────────────────────────────┐ │ Layer 1 (Chunked Prefill) │ │ for chunk in chunks(hidden_states, chunk_size): │ │ - ... (same pattern) │ │ Offload KV to CPU │ └─────────────────────────────────────────────────────────────┘ ↓ ... (repeat for all layers) ↓ Final LayerNorm + LM Head ``` ### 完整内存 Breakdown(seq_len=1M, chunk_size=16K) | 组件 | 内存 | 占比 | 能否 chunk? | |-----|------|------|------------| | 模型权重 | 16 GB | 64% | ❌ 固定 | | **最终输出 (hidden_states)** | **8 GB** | **32%** | ❌ **层间数据流** | | KV 累积 | 4 GB | 16% | ❌ 必须累积 | | qkv (chunked) | 192 MB | 0.8% | ✅ 已 chunked | | gate_up (chunked) | 875 MB | 3.5% | ✅ 已 chunked | | 其他临时张量 | 169 MB | 0.7% | ✅ 已 chunked | | **总计** | **~25 GB** | 100% | | ### 性能权衡分析 **Kernel 调用次数**: ``` 原始: num_layers × 1 = 32 Chunked (16K): num_layers × (seq_len / chunk_size) = 32 × 64 = 2048 增加: +6300% ``` **预期性能开销**: | chunk_size | 内存节省 | 预期时间开销 | 适用场景 | |-----------|---------|------------|---------| | 32K | 15-24% | +5-10% | 64K-128K | | 16K | 24-46% | +10-20% | 128K-512K | | 8K | 34-57% | +20-40% | 512K-1M | | 4K | 46-78% | +40-80% | >1M | ### 为什么最终输出不能 Chunk? **原因**:`hidden_states` 是层与层之间的数据流 ```python # Layer N 的输出 hidden_states: [seq_len, hidden_size] # Layer N+1 的输入 # 需要 Layer N 的完整输出! qkv = qkv_proj(hidden_states) # 需要完整的 hidden_states ``` **突破方法**: 1. **模型并行**:分布 hidden_states 到多个 GPU 2. **流水线并行**:不同层在不同 GPU 3. **降低 hidden_size**:使用更小的模型 --- ## 实现指南 ### 配置参数 ```python @dataclass class Config: # Chunked prefill 配置 prefill_chunk_size: int = 0 # 0 = 不 chunk, >0 = chunk 大小 # 自动选择策略 def get_chunk_size(self, seq_len: int) -> int: if self.prefill_chunk_size > 0: return self.prefill_chunk_size # 自动选择 if seq_len < 32000: return 0 # 不需要 chunk elif seq_len < 128000: return 16384 elif seq_len < 512000: return 8192 else: return 4096 ``` ### 实现步骤 **第一步:验证 Chunked MLP**(~50 行代码) ```python # 在 model_runner.py 的 run_layerwise_offload_prefill 中 # 将 MLP 调用改为 chunked chunk_size = self.config.get_chunk_size(len(hidden_states)) if chunk_size > 0: hidden_states = chunked_mlp_forward(layer.mlp, hidden_states, chunk_size) else: hidden_states = layer.mlp(hidden_states) ``` **第二步:实现 Chunked Prefill**(~150 行代码) ```python def run_layerwise_offload_chunked_prefill(self, seqs): chunk_size = self.config.get_chunk_size(total_tokens) hidden_states = self.model.model.embed_tokens(input_ids) for layer_id in range(num_layers): if chunk_size == 0: # 原始实现 hidden_states = self._process_layer_full(layer, hidden_states) else: # Chunked 实现 hidden_states = self._process_layer_chunked( layer, hidden_states, chunk_size ) # Offload KV to CPU ... return hidden_states ``` ### 优化建议 1. **预分配输出 buffer**: ```python output = torch.empty_like(hidden_states) for chunk in chunks: output[start:end] = process_chunk(chunk) ``` 2. **复用 KV 累积 buffer**: ```python # 预分配最大可能的 KV buffer k_buffer = torch.zeros(seq_len, kv_dim, ...) v_buffer = torch.zeros(seq_len, kv_dim, ...) # 每次 copy 到对应位置 k_buffer[current_pos:end] = chunk_k ``` 3. **异步 H2D 传输**: ```python # 在处理当前 chunk 时,异步加载下一层需要的 KV # (需要流式处理支持) ``` ### 测试验证 ```python # 验证正确性 def test_chunked_prefill_correctness(): model = LLM("path/to/model") # 原始实现 output_original = model.generate(prompt, enable_cpu_offload=True, prefill_chunk_size=0) # Chunked 实现 output_chunked = model.generate(prompt, enable_cpu_offload=True, prefill_chunk_size=16384) # 验证结果相同 assert torch.allclose(output_original, output_chunked, atol=1e-3) # 验证内存 def test_chunked_prefill_memory(): torch.cuda.reset_peak_memory_stats() model = LLM("path/to/model", enable_cpu_offload=True, prefill_chunk_size=16384) model.generate(prompt_1M) peak = torch.cuda.max_memory_allocated() assert peak < 24 * 1024**3 # 24GB ``` --- ## 总结 ### 关键发现 1. **理论最大序列**:~16.8M tokens(RoPE float32 限制) 2. **主要瓶颈演进**(按优化顺序): - RoPE 精度(最严格,硬性上限) - MLP gate_up 激活值(原始瓶颈) - Attention 中间张量(第二瓶颈) - 最终输出 hidden_states(第三瓶颈,已被 Streaming 消除) 3. **Streaming Chunked Prefill(最终方案)效果**: - **GPU 内存独立于 seq_len**:17.5 GB 恒定 - 内存降低:87%+(相比 layerwise offload) - 序列扩展:128×(32K → 4M,受 CPU 内存限制) - 性能开销:10-25%(取决于 chunk_size) ### 方案对比总结 | 方案 | 最大序列(24GB GPU) | GPU 内存 | 瓶颈 | 推荐度 | |-----|---------------------|---------|------|-------| | Layerwise Offload | ~32K | 20-137 GB | GPU 内存 | ⭐⭐⭐⭐ | | +Chunked MLP | ~100K | 19-98 GB | GPU 内存 | ⭐⭐⭐ | | +Chunked Prefill(层内) | ~1M | 17-25 GB | GPU 内存 | ⭐⭐⭐⭐ | | **+Streaming(最终)** | **~4M** | **17.5 GB(恒定)** | **CPU 内存** | **⭐⭐⭐⭐⭐** | ### 实现优先级(更新) | 阶段 | 方案 | 效果 | 优先级 | |-----|------|------|-------| | 已实现 | Layerwise Offload | 32K | P0 ✅ | | 第一步 | +Chunked MLP | ~100K | P2(可跳过) | | 第二步 | +Chunked Prefill(层内) | ~1M | P1 | | **最终目标** | **+Streaming** | **~4M** | **P0** ✅ | ### Streaming Chunked Prefill 关键优势 - ✅ **GPU 内存恒定**:17.5 GB,与 seq_len 无关 - ✅ **突破 GPU 限制**:最大序列由 CPU 内存决定 - ✅ **实现复杂度可控**:~200 行代码 - ✅ **与 layerwise offload 完美集成** - ⚠️ CPU-GPU 传输增加:每个 chunk 需加载之前所有 KV - ⚠️ 性能开销:10-25%(可接受的权衡) ### 最终推荐 **对于超长序列推理(> 128K)**: 1. ✅ 使用 **Streaming Chunked Prefill**(最优方案) 2. ✅ 设置 `prefill_chunk_size = 16384`(自动选择) 3. ✅ 设置 `num_kv_buffers = 1`(降低 ring buffer 内存) 4. ✅ 确保充足 CPU 内存(512GB → 4M tokens) **效果**: - RTX 3090 (24GB): 32K → **~4M** (128× 提升) ✅ - A100 40GB: 340K → **~4M** (12× 提升) ✅ - A100 80GB: 1M → **~4M** (4× 提升) ✅ **唯一限制**: - CPU 内存:512 GB → ~4M tokens - RoPE 硬性上限:~16.8M tokens(需改 float64)