# Chunked Prefill and Long-Sequence Inference Memory Analysis

This document analyzes the theoretical maximum sequence length of nano-vllm's layerwise KV-cache offload strategy, and the chunked-prefill schemes that lower peak GPU memory.

## Contents

1. [Background](#background)
2. [Theoretical Maximum Sequence Length](#theoretical-maximum-sequence-length)
3. [The MLP Activation Memory Bottleneck](#the-mlp-activation-memory-bottleneck)
4. [Chunked MLP](#chunked-mlp)
5. [Chunked Prefill](#chunked-prefill)
6. [**Streaming Chunked Prefill (the Optimal Scheme)**](#streaming-chunked-prefill-the-optimal-scheme)
7. [Layerwise Offload + Chunked Prefill Combined](#layerwise-offload--chunked-prefill-combined)
8. [Implementation Guide](#implementation-guide)

---
## Background

### Why Long-Sequence Inference?

Modern LLM applications place ever-growing demands on context length:

- **Document analysis**: processing entire books and reports
- **Code understanding**: context for large codebases
- **Multi-turn dialogue**: maintaining long conversation history
- **RAG**: retrieval-augmented generation needs large amounts of context

### GPU Memory Limits

An RTX 3090/4090 has only 24 GB of VRAM, and an A100 has 40 GB/80 GB, while the memory demand of long-sequence inference grows linearly:

$$
\text{Memory} \propto \text{seq\_len}
$$

### Limitations of Layerwise Offload

Layerwise offload moves the KV cache to the CPU, but other components still occupy substantial GPU memory:

- **Model weights**: fixed, ~16 GB (Llama-3.1-8B)
- **MLP activations**: proportional to sequence length
- **Attention intermediate tensors**: proportional to sequence length

---
## Theoretical Maximum Sequence Length

### Hard Limits, Ranked

#### 1. RoPE float32 Precision (the Strictest)

From `rotary_embedding.py:30`:

```python
t = torch.arange(max_position_embeddings, dtype=torch.float)
```

**Problem**: `torch.arange` uses `float32`, which can represent integers exactly only up to:

$$
2^{24} = 16{,}777{,}216 \approx 16.8\text{M tokens}
$$

**Workaround**: change `rotary_embedding.py` to use `float64` or `int64`.

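A quick demonstration of the precision collapse, plus a minimal sketch of the fix; the `inv_freq` construction below is the standard RoPE recipe, included here only to make the snippet self-contained:

```python
import torch

# Beyond 2**24, float32 cannot represent every integer, so positions collide:
t = torch.arange(2**24 - 1, 2**24 + 3, dtype=torch.float32)
print(t)             # tensor([16777215., 16777216., 16777216., 16777218.])
print(t[1] == t[2])  # True -> two distinct positions get identical RoPE angles

# Fix sketch: keep positions in int64 and compute the angles in float64.
head_dim = 128
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))
positions = torch.tensor([2**24 + 1], dtype=torch.int64)    # represented exactly
freqs = torch.outer(positions.to(torch.float64), inv_freq)  # no precision collapse
```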
#### 2. GPU MLP gate_up Activations (the Practical Bottleneck)

**Formula**:

$$
\text{MLP\_gate\_up} = \text{seq\_len} \times 2 \times \text{intermediate\_size} \times \text{dtype\_size}
$$

**Llama-3.1-8B example**:

```
intermediate_size = 14336
dtype_size = 2 (bfloat16)

MLP_gate_up = seq_len × 2 × 14336 × 2
            = seq_len × 57344 bytes
            = seq_len × 56 KB
```

| seq_len | MLP gate_up memory |
|---------|--------------------|
| 64K | 3.5 GB |
| 1M | 57 GB |
| 16M | 912 GB |
#### 3. GPU Ring Buffer Memory

**Formula**:

$$
\text{Ring\_Buffer} = 2 \times \text{num\_kv\_buffers} \times \text{seq\_len} \times \text{kv\_dim} \times \text{dtype\_size}
$$

**Optimization**: set `num_kv_buffers=1` to reduce this.

#### 4. CPU KV Cache Memory

**Formula**:

$$
\text{CPU\_KV} = \text{num\_layers} \times \text{seq\_len} \times \text{kv\_dim} \times 2 \times \text{dtype\_size}
$$

**Llama-3.1-8B** (32 layers, kv_dim=1024):

```
CPU_KV = 32 × seq_len × 1024 × 2 × 2
       = seq_len × 131072 bytes
       = seq_len × 128 KB
```

| seq_len | CPU memory required |
|---------|---------------------|
| 1M | 131 GB |
| 10M | 1.31 TB |
| 16M | 2.1 TB |

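Both of the dominant formulas are easy to evaluate directly. A small sketch using the Llama-3.1-8B constants from above (printed in binary GiB, so the values differ slightly from the decimal-rounded tables):

```python
def mlp_gate_up_bytes(seq_len, intermediate_size=14336, dtype_size=2):
    """Peak gate_up activation for one full-sequence MLP forward."""
    return seq_len * 2 * intermediate_size * dtype_size

def cpu_kv_bytes(seq_len, num_layers=32, kv_dim=1024, dtype_size=2):
    """CPU-resident K and V for every layer and position."""
    return num_layers * seq_len * kv_dim * 2 * dtype_size

for n in (64 * 2**10, 2**20, 16 * 2**20):
    print(f"{n / 2**20:5.2f}M tokens: gate_up {mlp_gate_up_bytes(n) / 2**30:6.1f} GiB, "
          f"CPU KV {cpu_kv_bytes(n) / 2**30:7.1f} GiB")
```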
### Maximum Sequence Length by GPU

**GPU memory formula**:

$$
\text{GPU\_mem} = \text{model\_weights} + \text{MLP\_gate\_up} + \text{Ring\_Buffer} + \text{overhead}
$$

**Llama-3.1-8B across GPUs**:

| GPU | VRAM | Model weights | Headroom | Max sequence (baseline) | Max sequence (chunked) |
|-----|------|---------------|----------|-------------------------|------------------------|
| RTX 3090 | 24GB | 16 GB | 3 GB | ~32K | **~1M** |
| A100 40GB | 40GB | 16 GB | 19 GB | ~340K | **~2M** |
| A100 80GB | 80GB | 16 GB | 59 GB | ~1M | **~4M** |

### Bottom Line on the Theoretical Maximum

**Hard ceiling**: **~16.8M tokens** (RoPE float32 precision)

**Practical maximum** (single GPU): **1-2M tokens** (limited by MLP gate_up GPU memory)

---
## The MLP Activation Memory Bottleneck

### The MLP Computation

From `llama.py:82-106`, LlamaMLP uses the **SwiGLU** activation:

```python
def forward(self, x):
    # input: x [seq_len, hidden_size]
    gate_up = self.gate_up_proj(x)       # [seq_len, 2 * intermediate_size]
    gate, up = gate_up.chunk(2, dim=-1)  # each [seq_len, intermediate_size]
    gate = F.silu(gate)                  # SiLU activation
    output = gate * up                   # element-wise product
    return self.down_proj(output)        # [seq_len, hidden_size]
```
### Why Is gate_up So Large?

**Llama-3.1-8B parameters**:

```
hidden_size = 4096
intermediate_size = 14336  # 3.5× hidden_size
```

**gate_up tensor size**:

```
shape = [seq_len, 2 * intermediate_size]
      = [seq_len, 28672]

memory = seq_len × 28672 × 2 bytes
       = seq_len × 57344 bytes
       = seq_len × 56 KB
```
### SwiGLU vs. a Standard MLP

**Standard Transformer MLP (e.g. GPT)**:

```python
hidden = gelu(x @ W_gate)  # [seq_len, 4 * hidden_size]
output = hidden @ W_down   # [seq_len, hidden_size]
```

**LLaMA's SwiGLU MLP**:

```python
gate_up = x @ W_gate_up   # [seq_len, 2 * intermediate_size]
gate, up = gate_up.chunk(2)
output = silu(gate) * up  # ← the element-wise product needs both tensors!
output = output @ W_down
```

**Key difference**:

- SwiGLU needs **two independent branches, gate and up**
- Both branches must be **resident in memory at the same time** for the element-wise product
- This is the root cause of the memory overhead
### MLP Activation Memory Formula

$$
\text{MLP\_activation} = \text{seq\_len} \times 2 \times \text{intermediate\_size} \times \text{dtype\_size}
$$

| Model | hidden_size | intermediate_size | Expansion ratio | Activation memory @ 64K |
|-------|-------------|-------------------|-----------------|-------------------------|
| Llama-3.1-8B | 4096 | 14336 | 3.5× | 3.5 GB |
| Llama-2-7B | 4096 | 11008 | 2.69× | 2.7 GB |
| Qwen3-4B | 2560 | 13696 | 5.35× | 3.5 GB |
| Qwen2-7B | 4096 | 13696 | 3.34× | 3.5 GB |
### Why Can't MLP Activations Be Offloaded?

Unlike the KV cache, MLP activations are **transient intermediate results**:

1. **Compute dependency**: `gate_up` → `chunk` → `silu` → `mul` → `down_proj`
2. **Lifetime**: used once and discarded; nothing needs to persist
3. **Transfer cost**: CPU-GPU transfers would completely erase any benefit

If we tried to offload:

```python
gate_up = gate_up_proj(x)    # compute on GPU
gate_up = gate_up.to("cpu")  # transfer to CPU  ← slow!
gate, up = gate_up.chunk(2)  # CPU ops
gate = silu(gate)
output = gate * up
output = output.to("cuda")   # transfer back to GPU  ← slow!
output = down_proj(output)   # compute on GPU
```

**Conclusion**: the high activation memory of the SwiGLU MLP is inherent to the Transformer architecture.

---
## Chunked MLP

### Core Idea

The MLP is computed **independently per position**:

$$
\text{output}[i] = \text{MLP}(\text{input}[i])
$$

This means the sequence can be split into chunks, each computed independently.

### Implementation

```python
import torch

def chunked_mlp_forward(mlp, hidden_states, chunk_size):
    """
    Chunked MLP forward to reduce peak memory usage.

    Args:
        mlp: LlamaMLP module
        hidden_states: [seq_len, hidden_size]
        chunk_size: number of tokens to process per chunk

    Returns:
        output: [seq_len, hidden_size]
    """
    if chunk_size <= 0 or len(hidden_states) <= chunk_size:
        return mlp(hidden_states)

    # Pre-allocate the output buffer
    output = torch.empty_like(hidden_states)

    # Chunked processing
    for i in range(0, len(hidden_states), chunk_size):
        chunk_end = min(i + chunk_size, len(hidden_states))
        chunk = hidden_states[i:chunk_end]
        output[i:chunk_end] = mlp(chunk)

    return output
```
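A quick numerical self-check that chunking changes nothing. `ToySwiGLU` is a stand-in for LlamaMLP, made up here so the snippet is self-contained; it reuses the `chunked_mlp_forward` defined above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySwiGLU(nn.Module):
    """Minimal SwiGLU MLP with the same fused gate_up layout as LlamaMLP."""
    def __init__(self, hidden=64, inter=224):
        super().__init__()
        self.gate_up_proj = nn.Linear(hidden, 2 * inter, bias=False)
        self.down_proj = nn.Linear(inter, hidden, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)

mlp = ToySwiGLU()
x = torch.randn(1000, 64)
# The computation is identical row by row, so the outputs match.
assert torch.allclose(mlp(x), chunked_mlp_forward(mlp, x, chunk_size=128), atol=1e-5)
```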
### Memory Impact (Llama-3.1-8B, seq_len=1M)

| Component | Baseline | Chunked (chunk_size=16K) | Reduction |
|-----------|----------|--------------------------|-----------|
| gate_up | 56 GB | 875 MB | **98.4%** ✅ |
| activated | 28 GB | 437 MB | 98.4% ✅ |
| mlp_out | 8 GB | 125 MB | 98.4% ✅ |

### GPU Memory Impact

| GPU | VRAM | Baseline max sequence | Chunked (16K) | Gain |
|-----|------|-----------------------|---------------|------|
| RTX 3090 | 24GB | ~32K | ~200K | **6.25×** |
| A100 40GB | 40GB | ~340K | ~1M | **3×** |
| A100 80GB | 80GB | ~1M | ~2M | **2×** |

### Performance Trade-offs

| chunk_size | Peak memory reduction | Expected overhead | Recommended for |
|-----------|-----------------------|-------------------|-----------------|
| 16K | 75% | +5-10% | 64K-128K sequences |
| 8K | 87% | +10-20% | 128K-256K sequences |
| 4K | 94% | +20-50% | 256K-512K sequences |

---
## Chunked Prefill

### Core Idea

Not only is the MLP position-independent: **as long as causality is preserved, attention can be processed in chunks too**.

### Key Insight

For chunk N of the sequence (position range `[start, end]`):

$$
\text{output}[i] = \text{Attention}(\text{query}[i], \text{keys}[0:i], \text{values}[0:i])
$$

By accumulating the KV cache, chunk N can attend to all previous tokens:

$$
K_{\text{full}} = \text{cat}([K_0, K_1, ..., K_N])
$$

### Verifying Chunk Independence

**Key question**: is chunk N's output already the final result?

**Answer**: ✅ Yes. Once a chunk has been processed, its output is that chunk's final result.

**Why**:

1. Chunk N's tokens see all previous tokens (through the accumulated KV)
2. Causality is preserved via `causal=True`
3. No extra computation is needed across chunks

**Mathematical equivalence**:

```
Original implementation (all at once):
  output[i] = Layer(attention(hidden[i], hidden[0:i]))

Chunked implementation:
  output[i] = Layer(attention(hidden[i], cat([KV[0:N]])))

Since cat([KV[0:N]]) == hidden[0:i] within the causal range,
the two are exactly identical!
```
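The equivalence is easy to check numerically with plain-PyTorch attention. A minimal sketch with toy sizes; the explicit mask handles the position offset of chunked queries, which is the one subtle point:

```python
import torch

torch.manual_seed(0)
seq_len, chunk, dim = 64, 16, 32
q, k, v = (torch.randn(seq_len, dim) for _ in range(3))

def causal_attn(q, k, v, q_start):
    # Mask by absolute position: the query at q_start+i sees keys j <= q_start+i.
    scores = (q @ k.T) / dim ** 0.5
    pos_q = torch.arange(q_start, q_start + q.shape[0]).unsqueeze(1)
    pos_k = torch.arange(k.shape[0]).unsqueeze(0)
    scores = scores.masked_fill(pos_k > pos_q, float("-inf"))
    return scores.softmax(dim=-1) @ v

full = causal_attn(q, k, v, q_start=0)  # one-shot prefill

outs = []
for s in range(0, seq_len, chunk):
    e = s + chunk
    # Chunked prefill: chunk queries vs. KV accumulated up to the chunk's end.
    outs.append(causal_attn(q[s:e], k[:e], v[:e], q_start=s))

print(torch.allclose(full, torch.cat(outs), atol=1e-6))  # True
```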
### Implementation Pseudocode

```python
def run_layerwise_offload_chunked_prefill(layer, hidden_states, positions, chunk_size=16384):
    """
    Chunked prefill for a single layer.

    Args:
        layer: Transformer layer
        hidden_states: [seq_len, hidden_size]
        positions: [seq_len] absolute positions for RoPE
        chunk_size: chunk size for processing

    Returns:
        output: [seq_len, hidden_size]
    """
    seq_len = len(hidden_states)

    # Initialize the accumulated KV cache
    k_accumulated = []
    v_accumulated = []

    # Pre-allocate the output buffer (avoids a final torch.cat)
    output_buffer = torch.empty_like(hidden_states)

    # Chunked prefill
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)

        # 1. Slice out the current chunk
        chunk_hidden = hidden_states[start:end]

        # 2. Input LayerNorm (keep the pre-norm input as the residual)
        chunk_residual = chunk_hidden
        chunk_hidden_ln = layer.input_layernorm(chunk_hidden)

        # 3. QKV projection (chunk only)
        chunk_qkv = layer.qkv_proj(chunk_hidden_ln)  # [chunk_size, 6144]
        chunk_q, chunk_k, chunk_v = chunk_qkv.split([q_size, kv_dim, kv_dim], dim=-1)

        # 4. RoPE (chunk only)
        positions_chunk = positions[start:end]
        chunk_q_rot, chunk_k_rot = layer.rotary_emb(positions_chunk, chunk_q, chunk_k)

        # 5. Accumulate the KV cache
        k_accumulated.append(chunk_k_rot)
        v_accumulated.append(chunk_v)
        k_full = torch.cat(k_accumulated, dim=0)  # all K so far
        v_full = torch.cat(v_accumulated, dim=0)  # all V so far

        # 6. FlashAttention (chunk Q vs. accumulated K,V)
        chunk_attn_out = flash_attn(
            chunk_q_rot,  # query: the current chunk
            k_full,       # key: everything accumulated
            v_full,       # value: everything accumulated
            causal=True,
        )

        # 7. O projection + residual
        chunk_hidden = layer.o_proj(chunk_attn_out) + chunk_residual

        # 8. Post-attention LayerNorm
        chunk_residual = chunk_hidden
        chunk_hidden_ln = layer.post_attention_layernorm(chunk_hidden)

        # 9. MLP (chunk only) + residual
        chunk_output = layer.mlp(chunk_hidden_ln) + chunk_residual

        # 10. Write straight into the output buffer
        output_buffer[start:end] = chunk_output

    # 11. Offload the full K, V to CPU (after the last chunk,
    #     k_full/v_full already hold the whole sequence)
    offload_kv_to_cpu(layer.layer_id, k_full, v_full)

    return output_buffer
```
### Memory Impact (Llama-3.1-8B, seq_len=1M)

| Component | Baseline | Chunked (16K) | Reduction |
|-----------|----------|---------------|-----------|
| hidden_states | 8 GB | 8 GB | 0% (must be stored in full) |
| qkv | 12 GB | 192 MB | **98.4%** ✅ |
| q_rotated | 8 GB | 128 MB | **98.4%** ✅ |
| k_rotated | 2 GB | 32 MB | **98.4%** ✅ |
| attn_out | 8 GB | 128 MB | **98.4%** ✅ |
| gate_up | 56 GB | 875 MB | **98.4%** ✅ |
| KV accumulation | - | 4 GB | required |
| **Total peak** | **102 GB** | **21 GB** | **79%** ✅ |

### Memory Impact Across Sequence Lengths

| seq_len | Layerwise Offload | +Chunked MLP | **+Chunked Prefill** | Gain vs. Offload |
|---------|------------------|--------------|----------------------|------------------|
| 32K | 20 GB | 19 GB | **17 GB** | 15% |
| 64K | 25 GB | 22 GB | **17 GB** | 32% |
| 128K | 32 GB | 27 GB | **18 GB** | 44% |
| 256K | 47 GB | 38 GB | **19 GB** | 60% |
| 512K | 77 GB | 56 GB | **22 GB** | 71% |
| **1M** | **137 GB** | **98 GB** | **25 GB** | **82%** ✅ |
### Deriving the GPU Memory Formulas

**Layerwise offload**:

$$
\text{GPU}_{\text{offload}} = \text{model\_weights} + \text{seq\_len} \times (6H + Q + 2KV) + \text{seq\_len} \times 2I
$$

**Layerwise + chunked MLP**:

$$
\text{GPU}_{\text{chunked\_mlp}} = \text{model\_weights} + \text{seq\_len} \times (6H + Q + 2KV) + \text{chunk\_size} \times 2I
$$

**Layerwise + chunked prefill**:

$$
\text{GPU}_{\text{chunked\_prefill}} = \text{model\_weights} + \text{seq\_len} \times 2H + \text{seq\_len} \times 2KV + \text{chunk\_size} \times (Q + KV + 2I)
$$

where:

- $H = \text{hidden\_size}$ (4096)
- $I = \text{intermediate\_size}$ (14336)
- $Q = \text{q\_size}$ (4096)
- $KV = \text{kv\_size}$ (1024)

**Concrete numbers for Llama-3.1-8B**:

$$
\text{GPU}_{\text{chunked\_prefill}} = 16\text{ GB} + \text{seq\_len} \times 12\text{ KB} + \text{chunk\_size} \times 30\text{ KB}
$$
### Maximum Sequence on a 24GB GPU

| Scheme | Max sequence length | Gain |
|--------|---------------------|------|
| Layerwise Offload | ~32K | 1× |
| +Chunked MLP | ~100K | 3× |
| **+Chunked Prefill (chunk=16K)** | **~1M** | **32×** ✅ |
| **+Chunked Prefill (chunk=4K)** | **~2M** | **64×** ✅ |

---
## Streaming Chunked Prefill (the Optimal Scheme)

### Key Insight: Independence Across Layers

**The earlier misconception**: that layer N+1 needs layer N's complete `hidden_states` output.

**The correct view**: if every layer is chunked, layer N+1 only needs the current chunk as its input!

### Why Does Streaming Work?

**Core reason**: both the MLP and attention are computed **independently per position**.

$$
\text{output}[i] = \text{Layer}(\text{input}[i])
$$

This means:

1. Once chunk 0 has passed through all layers, its result is final
2. Once chunk 1 has passed through all layers, its result is final too
3. Different chunks are **fully independent** of each other
### Two Flavors of Chunked Prefill Compared

#### The Original Understanding (Wrong): Chunking Within a Layer

```python
# Wrong approach: finish all chunks of one layer before moving to the next
for layer_id in range(num_layers):
    output_buffer = torch.empty([seq_len, 4096])  # ← needs 8GB!

    for chunk_idx in range(num_chunks):
        chunk = input_hidden[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size]
        output_buffer[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size] = process_chunk(layer_id, chunk)

    hidden_states = output_buffer  # hand off to the next layer
```

**Problems**:

- ❌ The full `output_buffer` must be stored (8GB @ 1M)
- ❌ GPU memory still scales with `seq_len`
- ❌ The complete hidden_states must be passed between layers
#### Streaming Chunked Prefill (Correct)

```python
# Correct approach: each chunk travels through all layers in turn
initial_hidden = embedding(input_ids)  # [seq_len, 4096]

for chunk_idx in range(num_chunks):
    # Slice out the current chunk
    chunk = initial_hidden[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size]  # [chunk_size, 4096]

    # Push this chunk through every layer
    for layer_id in range(num_layers):
        # Load this layer's accumulated KV from the CPU
        k_accumulated, v_accumulated = load_kv_from_cpu(layer_id, end=chunk_idx * chunk_size)

        # Process the current chunk (also returns the chunk's K and V)
        chunk, chunk_k, chunk_v = process_layer_chunk(
            layer_id,
            chunk,
            k_accumulated,
            v_accumulated,
            start_pos=chunk_idx * chunk_size,
        )

        # Offload this chunk's KV to the CPU
        offload_chunk_kv_to_cpu(layer_id, chunk_k, chunk_v)

    # The chunk has now passed through all layers and can be emitted
    # or post-processed directly. No full-size output_buffer is needed!
```

**Advantages**:

- ✅ **No full output_buffer** (saves 8GB)
- ✅ GPU memory is **independent of seq_len**
- ✅ Only the current chunk's hidden_states are kept (125MB @ chunk_size=16K)
### Memory Impact (Llama-3.1-8B, seq_len=1M)

| Component | Within-Layer Chunked Prefill | Streaming Chunked Prefill | Reduction |
|-----------|------------------------------|---------------------------|-----------|
| **hidden_states (output)** | **8 GB** | **125 MB** | **98.4%** ✅ |
| qkv (chunked) | 192 MB | 192 MB | - |
| gate_up (chunked) | 875 MB | 875 MB | - |
| KV accumulation (one layer) | 4 GB | 4 GB | - |
| Model weights | 16 GB | 16 GB | - |
| **Total peak** | **~29 GB** | **~21 GB** | **28%** ✅ |
### GPU Memory Formulas

**Within-layer chunked prefill**:

$$
\text{GPU}_{\text{chunked}} = \text{model\_weights} + \text{seq\_len} \times 2H + \text{chunk\_size} \times (Q + KV + 2I)
$$

**Streaming chunked prefill**:

$$
\text{GPU}_{\text{streaming}} = \text{model\_weights} + \text{chunk\_size} \times 2H + \text{chunk\_size} \times (Q + KV + 2I)
$$

**Key difference**:

- Within-layer: `seq_len × 2H` (scales with sequence length)
- Streaming: `chunk_size × 2H` (scales with chunk size)

**Concrete numbers for Llama-3.1-8B (chunk_size=16K)**:

$$
\text{GPU}_{\text{streaming}} = 16\text{ GB} + 16\text{K} \times 12\text{ KB} + 16\text{K} \times 30\text{ KB} = 16\text{ GB} + 192\text{ MB} + 480\text{ MB} \approx 16.7\text{ GB}
$$

With runtime overhead this comes to roughly 17.5 GB, the figure used in the tables below.

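Evaluated directly; a tiny sketch using the per-token constants above, where the ~0.8 GB overhead term is an assumption added only to reproduce the 17.5 GB figure:

```python
def streaming_gpu_gb(chunk_size, weights_gb=16.0, per_chunk_token_kb=12 + 30, overhead_gb=0.8):
    # GPU_streaming = weights + chunk_size × (12 KB + 30 KB) + overhead; no seq_len term.
    return weights_gb + chunk_size * per_chunk_token_kb * 1024 / 2**30 + overhead_gb

for c in (4096, 8192, 16384, 32768):
    print(f"chunk_size={c:>5}: {streaming_gpu_gb(c):.1f} GB")  # the same for any seq_len
```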
### Memory Impact Across Sequence Lengths

| seq_len | Layerwise Offload | Chunked Prefill | **Streaming** | Streaming vs. Offload |
|---------|------------------|-----------------|---------------|-----------------------|
| 32K | 20 GB | 17 GB | **17.5 GB** | -12% |
| 64K | 25 GB | 17 GB | **17.5 GB** | -30% |
| 128K | 32 GB | 18 GB | **17.5 GB** | -45% |
| 256K | 47 GB | 19 GB | **17.5 GB** | -63% |
| 512K | 77 GB | 22 GB | **17.5 GB** | -77% |
| **1M** | **137 GB** | **25 GB** | **17.5 GB** | **-87%** ✅ |
| **2M** | **265 GB** | **38 GB** | **17.5 GB** | **-93%** ✅ |
| **4M** | **521 GB** | **70 GB** | **17.5 GB** | **-97%** ✅ |

### Key Finding

**Streaming chunked prefill makes GPU memory essentially independent of sequence length!**

- 32K: 17.5 GB
- 1M: 17.5 GB (**identical**)
- 4M: 17.5 GB (**identical**)

The only thing that still grows is **CPU KV cache memory**:

$$
\text{CPU\_KV} = \text{num\_layers} \times \text{seq\_len} \times \text{kv\_dim} \times 2 \times \text{dtype\_size}
$$
### Theoretical Maximum Sequence Length (Updated)

#### The New Limit: CPU Memory

| Configuration | CPU memory | Max sequence length |
|---------------|-----------|---------------------|
| Typical server | 128 GB | ~1M |
| Large-memory server | 512 GB | ~4M |
| Very-large-memory server | 1 TB | ~8M |
| Extreme configuration | 2 TB | ~16M |

#### The RoPE Hard Ceiling Still Stands

$$
2^{24} = 16{,}777{,}216 \approx 16.8\text{M tokens}
$$

**Conclusions**:

- **GPU limit**: fully removed by streaming chunked prefill (constant 17.5 GB)
- **CPU limit**: 512 GB RAM → ~4M tokens
- **RoPE limit**: ~16.8M tokens (hard ceiling)

### Maximum Sequence on a 24GB GPU (Updated)

| Scheme | Max sequence length | Gain | Bottleneck |
|--------|---------------------|------|------------|
| Layerwise Offload | ~32K | 1× | GPU memory |
| +Chunked MLP | ~100K | 3× | GPU memory |
| +Chunked Prefill | ~1M | 32× | GPU memory |
| **+Streaming** | **~4M** | **128×** | **CPU memory** ✅ |
| **+Streaming (RoPE float64)** | **~16M** | **512×** | **CPU memory** |
### Streaming Implementation Pseudocode

```python
def run_streaming_chunked_prefill(model, input_ids, chunk_size=16384):
    """
    Streaming chunked prefill: each chunk passes through all layers in turn.

    (num_layers, kv_dim, vocab_size and dtype are taken from the model config.)

    Args:
        model: Transformer model
        input_ids: [seq_len]
        chunk_size: chunk size for processing

    Returns:
        output: [seq_len, vocab_size] (logits)
    """
    seq_len = len(input_ids)
    num_chunks = (seq_len + chunk_size - 1) // chunk_size

    # 1. Embedding (done once)
    initial_hidden = model.model.embed_tokens(input_ids)  # [seq_len, 4096]

    # 2. Pre-allocate the CPU KV cache
    cpu_kv_cache = [{
        'k': torch.zeros(seq_len, kv_dim, dtype=dtype),
        'v': torch.zeros(seq_len, kv_dim, dtype=dtype),
    } for _ in range(num_layers)]

    # 3. Pre-allocate the output (optional; could also be written straight to disk).
    #    For very long sequences the full logits may not be needed at all.
    final_output = torch.zeros(seq_len, vocab_size, dtype=dtype)

    # 4. Streaming: each chunk passes through all layers
    for chunk_idx in range(num_chunks):
        start = chunk_idx * chunk_size
        end = min(start + chunk_size, seq_len)

        # Initial hidden states of the current chunk
        chunk_hidden = initial_hidden[start:end]  # [chunk_size, 4096]

        # Push through every layer
        for layer_id in range(num_layers):
            layer = model.model.layers[layer_id]

            # Load the KV accumulated from previous chunks off the CPU
            k_prev = cpu_kv_cache[layer_id]['k'][:start].to('cuda')  # [start, kv_dim]
            v_prev = cpu_kv_cache[layer_id]['v'][:start].to('cuda')  # [start, kv_dim]

            # Process the current chunk
            chunk_hidden, chunk_k, chunk_v = process_layer_chunk(
                layer,
                chunk_hidden,
                k_prev,
                v_prev,
                positions=torch.arange(start, end, device='cuda')
            )

            # Offload this chunk's KV to the CPU
            cpu_kv_cache[layer_id]['k'][start:end] = chunk_k.cpu()
            cpu_kv_cache[layer_id]['v'][start:end] = chunk_v.cpu()

        # The chunk has passed through all layers: compute and store its logits
        chunk_logits = model.lm_head(chunk_hidden)
        final_output[start:end] = chunk_logits

        # Optional: free this chunk's GPU memory
        del chunk_hidden, chunk_logits

    return final_output


def process_layer_chunk(layer, chunk_hidden, k_prev, v_prev, positions):
    """
    Process one chunk through a single layer.

    Args:
        layer: Transformer layer
        chunk_hidden: [chunk_size, hidden_size]
        k_prev: accumulated K from previous chunks [prev_len, kv_dim]
        v_prev: accumulated V from previous chunks [prev_len, kv_dim]
        positions: absolute positions of the current chunk [chunk_size]

    Returns:
        chunk_output: [chunk_size, hidden_size]
        chunk_k: [chunk_size, kv_dim]
        chunk_v: [chunk_size, kv_dim]
    """
    # 1. Input LayerNorm
    chunk_hidden_ln = layer.input_layernorm(chunk_hidden)

    # 2. QKV projection
    chunk_qkv = layer.qkv_proj(chunk_hidden_ln)
    chunk_q, chunk_k, chunk_v = chunk_qkv.split([q_size, kv_dim, kv_dim], dim=-1)

    # 3. RoPE
    chunk_q, chunk_k = layer.rotary_emb(positions, chunk_q, chunk_k)

    # 4. Accumulate KV
    k_full = torch.cat([k_prev, chunk_k], dim=0)  # [prev_len + chunk_size, kv_dim]
    v_full = torch.cat([v_prev, chunk_v], dim=0)

    # 5. FlashAttention (chunk Q vs. accumulated K,V)
    chunk_attn_out = flash_attn(
        chunk_q,  # [chunk_size, q_size]
        k_full,   # [prev_len + chunk_size, kv_dim]
        v_full,   # [prev_len + chunk_size, kv_dim]
        causal=True
    )

    # 6. O projection + residual (chunk_hidden is still the pre-norm input here)
    chunk_hidden = layer.o_proj(chunk_attn_out) + chunk_hidden

    # 7. Post-attention LayerNorm
    chunk_hidden_ln = layer.post_attention_layernorm(chunk_hidden)

    # 8. MLP + residual
    chunk_output = layer.mlp(chunk_hidden_ln) + chunk_hidden

    return chunk_output, chunk_k, chunk_v
```
### Performance Trade-offs

| chunk_size | GPU memory | CPU-GPU transfer | Expected speed | Suited for |
|-----------|-----------|------------------|----------------|------------|
| 32K | 18.5 GB | 32K × 32 = 1M / layer | fastest | ≤ 32K |
| 16K | 17.5 GB | 16K × 32 = 512K / layer | +10% | 32K - 1M |
| 8K | 17.0 GB | 8K × 32 = 256K / layer | +25% | 1M - 4M |
| 4K | 16.8 GB | 4K × 32 = 128K / layer | +50% | > 4M |

**CPU-GPU transfer overhead**:

- Every chunk must load all previous KV from the CPU (per layer)
- Transfer volume: `chunk_idx × chunk_size × kv_dim × 2 × dtype_size`
- Chunk 0: 0 KB (no prior KV)
- Chunk 1: 16K × 1K × 2 × 2 = 64 MB
- Chunk 63: 63 × 64 MB = 4 GB

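Summing this over all chunks and all layers shows the total traffic grows quadratically with seq_len; a quick sketch using the same constants:

```python
def streaming_h2d_bytes(seq_len, chunk_size, num_layers=32, kv_dim=1024, dtype_size=2):
    """Total host-to-device KV traffic: each chunk reloads the whole prefix, per layer."""
    num_chunks = (seq_len + chunk_size - 1) // chunk_size
    per_token = kv_dim * 2 * dtype_size  # K and V in bf16: 4 KB/token
    # Chunk i reloads a prefix of i * chunk_size tokens, for every layer.
    return sum(i * chunk_size * per_token * num_layers for i in range(num_chunks))

print(f"{streaming_h2d_bytes(2**20, 16384) / 2**40:.1f} TiB")  # ~3.9 TiB for a 1M-token prefill
```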
---

## Layerwise Offload + Chunked Prefill Combined

> **Note**: this section describes the within-layer chunked prefill scheme. If you adopt **Streaming Chunked Prefill**, see the previous section instead.

### Why Combine Them?

**Limitations of chunked MLP**:

- ✅ Removes the MLP bottleneck (98% reduction)
- ❌ But attention remains a bottleneck (30GB @ 1M)

**Advantages of chunked prefill**:

- ✅ Removes both the MLP and attention bottlenecks
- ✅ Lets a 24GB GPU run 1M-token sequences

### Architecture of the Combined Scheme

```
Input: input_ids [seq_len]
    ↓
Embedding: hidden_states [seq_len, hidden_size]
    ↓
┌─────────────────────────────────────────────────────────────┐
│ Layer 0 (Chunked Prefill)                                   │
│   for chunk in chunks(hidden_states, chunk_size):           │
│     - QKV Projection (chunk only)                           │
│     - FlashAttention (chunk Q vs accumulated K,V)           │
│     - MLP (chunk only)                                      │
│   Offload KV to CPU                                         │
└─────────────────────────────────────────────────────────────┘
    ↓
┌─────────────────────────────────────────────────────────────┐
│ Layer 1 (Chunked Prefill)                                   │
│   for chunk in chunks(hidden_states, chunk_size):           │
│     - ... (same pattern)                                    │
│   Offload KV to CPU                                         │
└─────────────────────────────────────────────────────────────┘
    ↓
... (repeat for all layers)
    ↓
Final LayerNorm + LM Head
```
### Full Memory Breakdown (seq_len=1M, chunk_size=16K)

| Component | Memory | Share | Chunkable? |
|-----------|--------|-------|------------|
| Model weights | 16 GB | 64% | ❌ fixed |
| **Final output (hidden_states)** | **8 GB** | **32%** | ❌ **inter-layer data flow** |
| KV accumulation | 4 GB | 16% | ❌ must accumulate |
| qkv (chunked) | 192 MB | 0.8% | ✅ already chunked |
| gate_up (chunked) | 875 MB | 3.5% | ✅ already chunked |
| Other temporaries | 169 MB | 0.7% | ✅ already chunked |
| **Total** | **~25 GB** | 100% | |

### Performance Trade-off Analysis

**Kernel launch count**:

```
Baseline:      num_layers × 1 = 32
Chunked (16K): num_layers × (seq_len / chunk_size) = 32 × 64 = 2048
Increase:      +6300%
```

**Expected overhead**:

| chunk_size | Memory saved | Expected time overhead | Suited for |
|-----------|--------------|------------------------|------------|
| 32K | 15-24% | +5-10% | 64K-128K |
| 16K | 24-46% | +10-20% | 128K-512K |
| 8K | 34-57% | +20-40% | 512K-1M |
| 4K | 46-78% | +40-80% | >1M |
### Why Can't the Final Output Be Chunked?

**Reason**: `hidden_states` is the data flow between layers.

```python
# Output of layer N
hidden_states: [seq_len, hidden_size]

# Input of layer N+1:
# it needs layer N's complete output!
qkv = qkv_proj(hidden_states)  # requires the full hidden_states
```

**Ways around it**:

1. **Model parallelism**: shard hidden_states across multiple GPUs
2. **Pipeline parallelism**: different layers on different GPUs
3. **Smaller hidden_size**: use a smaller model

---

## Implementation Guide

### Configuration

```python
from dataclasses import dataclass

@dataclass
class Config:
    # Chunked prefill configuration
    prefill_chunk_size: int = 0  # 0 = no chunking, >0 = chunk size

    # Automatic selection policy
    def get_chunk_size(self, seq_len: int) -> int:
        if self.prefill_chunk_size > 0:
            return self.prefill_chunk_size

        # Auto-select
        if seq_len < 32000:
            return 0  # no chunking needed
        elif seq_len < 128000:
            return 16384
        elif seq_len < 512000:
            return 8192
        else:
            return 4096
```
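Sanity-checking the auto-selection thresholds:

```python
cfg = Config()
assert cfg.get_chunk_size(16_000) == 0         # short prompt: chunking disabled
assert cfg.get_chunk_size(100_000) == 16384
assert cfg.get_chunk_size(1_000_000) == 4096
assert Config(prefill_chunk_size=8192).get_chunk_size(64_000) == 8192  # explicit override wins
```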
### Implementation Steps

**Step 1: validate chunked MLP** (~50 lines of code)

```python
# In run_layerwise_offload_prefill in model_runner.py,
# switch the MLP call to the chunked version:

chunk_size = self.config.get_chunk_size(len(hidden_states))
if chunk_size > 0:
    hidden_states = chunked_mlp_forward(layer.mlp, hidden_states, chunk_size)
else:
    hidden_states = layer.mlp(hidden_states)
```

**Step 2: implement chunked prefill** (~150 lines of code)

```python
def run_layerwise_offload_chunked_prefill(self, seqs):
    chunk_size = self.config.get_chunk_size(total_tokens)

    hidden_states = self.model.model.embed_tokens(input_ids)

    for layer_id in range(num_layers):
        if chunk_size == 0:
            # original path
            hidden_states = self._process_layer_full(layer, hidden_states)
        else:
            # chunked path
            hidden_states = self._process_layer_chunked(
                layer, hidden_states, chunk_size
            )

        # Offload KV to CPU
        ...

    return hidden_states
```
### Optimization Tips

1. **Pre-allocate the output buffer**:

```python
output = torch.empty_like(hidden_states)
for chunk in chunks:
    output[start:end] = process_chunk(chunk)
```

2. **Reuse the KV accumulation buffer**:

```python
# Pre-allocate the largest possible KV buffer
k_buffer = torch.zeros(seq_len, kv_dim, ...)
v_buffer = torch.zeros(seq_len, kv_dim, ...)

# Copy each chunk into place
k_buffer[current_pos:end] = chunk_k
```

3. **Asynchronous H2D transfers** (see the sketch below):

```python
# While computing the current chunk, asynchronously load
# the KV the next layer needs (requires stream support).
```
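A minimal sketch of what that overlap could look like, assuming the CPU KV tensors are allocated with `pin_memory=True`; the integration points are hypothetical:

```python
import torch

copy_stream = torch.cuda.Stream()  # dedicated stream for H2D copies

def prefetch_kv(cpu_k, cpu_v, end):
    """Kick off async copies of the KV prefix [0:end) on the copy stream."""
    with torch.cuda.stream(copy_stream):
        k = cpu_k[:end].to("cuda", non_blocking=True)  # async only with pinned memory
        v = cpu_v[:end].to("cuda", non_blocking=True)
    return k, v

# Inside the chunk loop (hypothetical integration):
#   k_next, v_next = prefetch_kv(cpu_k, cpu_v, end=next_start)
#   ... compute the current chunk on the default stream ...
#   torch.cuda.current_stream().wait_stream(copy_stream)  # sync before using k_next
```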
### Testing and Validation

```python
import torch

# Correctness check
def test_chunked_prefill_correctness():
    model = LLM("path/to/model")

    # Baseline implementation
    output_original = model.generate(prompt, enable_cpu_offload=True, prefill_chunk_size=0)

    # Chunked implementation
    output_chunked = model.generate(prompt, enable_cpu_offload=True, prefill_chunk_size=16384)

    # Results should match
    assert torch.allclose(output_original, output_chunked, atol=1e-3)

# Memory check
def test_chunked_prefill_memory():
    torch.cuda.reset_peak_memory_stats()

    model = LLM("path/to/model", enable_cpu_offload=True, prefill_chunk_size=16384)
    model.generate(prompt_1M)

    peak = torch.cuda.max_memory_allocated()
    assert peak < 24 * 1024**3  # 24GB
```
---

## Summary

### Key Findings

1. **Theoretical maximum sequence**: ~16.8M tokens (RoPE float32 limit)

2. **How the bottleneck evolves** (in optimization order):
   - RoPE precision (strictest; the hard ceiling)
   - MLP gate_up activations (the original bottleneck)
   - Attention intermediate tensors (the second bottleneck)
   - Final hidden_states output (the third bottleneck, removed by streaming)

3. **Streaming chunked prefill (the final scheme)**:
   - **GPU memory independent of seq_len**: constant 17.5 GB
   - Memory reduction: 87%+ (vs. layerwise offload)
   - Sequence scaling: 128× (32K → 4M, bounded by CPU memory)
   - Performance overhead: 10-25% (depending on chunk_size)

### Scheme Comparison

| Scheme | Max sequence (24GB GPU) | GPU memory | Bottleneck | Recommendation |
|--------|-------------------------|-----------|------------|----------------|
| Layerwise Offload | ~32K | 20-137 GB | GPU memory | ⭐⭐⭐⭐ |
| +Chunked MLP | ~100K | 19-98 GB | GPU memory | ⭐⭐⭐ |
| +Chunked Prefill (within-layer) | ~1M | 17-25 GB | GPU memory | ⭐⭐⭐⭐ |
| **+Streaming (final)** | **~4M** | **17.5 GB (constant)** | **CPU memory** | **⭐⭐⭐⭐⭐** |

### Implementation Priority (Updated)

| Stage | Scheme | Result | Priority |
|-------|--------|--------|----------|
| Done | Layerwise Offload | 32K | P0 ✅ |
| Step 1 | +Chunked MLP | ~100K | P2 (skippable) |
| Step 2 | +Chunked Prefill (within-layer) | ~1M | P1 |
| **Final goal** | **+Streaming** | **~4M** | **P0** ✅ |

### Key Advantages of Streaming Chunked Prefill

- ✅ **Constant GPU memory**: 17.5 GB, regardless of seq_len
- ✅ **Breaks the GPU barrier**: the maximum sequence is set by CPU memory
- ✅ **Manageable implementation complexity**: ~200 lines of code
- ✅ **Integrates cleanly with layerwise offload**
- ⚠️ More CPU-GPU traffic: every chunk reloads all prior KV
- ⚠️ Performance overhead: 10-25% (an acceptable trade-off)

### Final Recommendation

**For very-long-sequence inference (> 128K)**:

1. ✅ Use **Streaming Chunked Prefill** (the optimal scheme)
2. ✅ Set `prefill_chunk_size = 16384` (or rely on auto-selection)
3. ✅ Set `num_kv_buffers = 1` (shrinks the ring buffer)
4. ✅ Provision enough CPU memory (512GB → 4M tokens)

**Results**:

- RTX 3090 (24GB): 32K → **~4M** (128× gain) ✅
- A100 40GB: 340K → **~4M** (12× gain) ✅
- A100 80GB: 1M → **~4M** (4× gain) ✅

**Remaining limits**:

- CPU memory: 512 GB → ~4M tokens
- RoPE hard ceiling: ~16.8M tokens (needs the float64 fix)