Add comprehensive analysis document covering: - MLP activation memory bottlenecks with SwiGLU architecture - Chunked MLP strategy (98% memory reduction) - Chunked prefill for single layers (78% memory reduction) - Streaming Chunked Prefill (最优方案): GPU memory becomes constant - Memory formulas and implementation guidance - Theoretical maximum: 4M tokens on 24GB GPU (128× improvement) Co-Authored-By: Claude <noreply@anthropic.com>
33 KiB
Chunked Prefill 与长序列推理内存分析
本文档详细分析了 nano-vllm 的 layerwise kvcache offload 策略的理论最大序列长度,以及通过 chunked prefill 降低 GPU 峰值内存的方案。
目录
- 问题背景
- 理论最大序列长度分析
- MLP Activation 内存瓶颈
- Chunked MLP 方案
- Chunked Prefill 方案
- Streaming Chunked Prefill(最优方案)
- Layerwise Offload + Chunked Prefill 组合
- 实现指南
问题背景
为什么需要长序列推理?
现代大语言模型的应用场景对上下文长度提出了越来越高的要求:
- 文档分析:需要处理完整的书籍、报告
- 代码理解:大型代码库的上下文
- 多轮对话:保持长期对话历史
- RAG 应用:检索增强生成需要大量上下文
GPU 内存限制
RTX 3090/4090 只有 24GB 显存,A100 有 40GB/80GB,而长序列推理的内存需求呈线性增长:
\text{Memory} \propto \text{seq\_len}
Layerwise Offload 的局限性
虽然 layerwise offload 将 KV cache 移到 CPU,但仍有其他组件占用大量 GPU 内存:
- 模型权重:固定 ~16GB(Llama-3.1-8B)
- MLP 激活值:与序列长度成正比
- Attention 中间张量:与序列长度成正比
理论最大序列长度分析
硬性限制排序
1. RoPE float32 精度限制(最严格)
从 rotary_embedding.py:30:
t = torch.arange(max_position_embeddings, dtype=torch.float)
问题:torch.arange 使用 float32,精确整数表示上限为:
2^{24} = 16,777,216 \approx 16.8\text{M tokens}
突破方法:修改 rotary_embedding.py 使用 float64 或 int64
2. GPU MLP gate_up 激活值(实际瓶颈)
公式:
\text{MLP\_gate\_up} = \text{seq\_len} \times 2 \times \text{intermediate\_size} \times \text{dtype\_size}
Llama-3.1-8B 示例:
intermediate_size = 14336
dtype_size = 2 (bfloat16)
MLP_gate_up = seq_len × 2 × 14336 × 2
= seq_len × 57344 bytes
= seq_len × 56 KB
| seq_len | MLP gate_up 内存 |
|---|---|
| 64K | 3.5 GB |
| 1M | 57 GB |
| 16M | 912 GB |
3. GPU Ring Buffer 内存
公式:
\text{Ring\_Buffer} = 2 \times \text{num\_kv\_buffers} \times \text{seq\_len} \times \text{kv\_dim} \times \text{dtype\_size}
优化:设置 num_kv_buffers=1 可降低内存
4. CPU KV Cache 内存
公式:
\text{CPU\_KV} = \text{num\_layers} \times \text{seq\_len} \times \text{kv\_dim} \times 2 \times \text{dtype\_size}
Llama-3.1-8B (32 layers, kv_dim=1024):
CPU_KV = 32 × seq_len × 1024 × 2 × 2
= seq_len × 131072 bytes
= seq_len × 128 KB
| seq_len | CPU 内存需求 |
|---|---|
| 1M | 131 GB |
| 10M | 1.31 TB |
| 16M | 2.1 TB |
不同 GPU 配置的最大序列长度
GPU 内存公式:
\text{GPU\_mem} = \text{model\_weights} + \text{MLP\_gate\_up} + \text{Ring\_Buffer} + \text{overhead}
Llama-3.1-8B 在不同 GPU 上:
| GPU | 显存 | 模型权重 | 可用 | 最大序列(原始) | 最大序列(chunked) |
|---|---|---|---|---|---|
| RTX 3090 | 24GB | 16 GB | 3 GB | ~32K | ~1M |
| A100 40GB | 40GB | 16 GB | 19 GB | ~340K | ~2M |
| A100 80GB | 80GB | 16 GB | 59 GB | ~1M | ~4M |
最终理论最大值
硬性上限:~16.8M tokens(受 RoPE float32 精度限制)
实用最大值(单卡):1-2M tokens(受 MLP gate_up GPU 内存限制)
MLP Activation 内存瓶颈
MLP 计算流程
从 llama.py:82-106,LlamaMLP 使用 SwiGLU 激活:
def forward(self, x):
# 输入: x [seq_len, hidden_size]
gate_up = self.gate_up_proj(x) # [seq_len, 2 * intermediate_size]
gate, up = gate_up.chunk(2, dim=-1) # 各 [seq_len, intermediate_size]
gate = F.silu(gate) # SiLU 激活
output = gate * up # 逐元素相乘
return self.down_proj(output) # [seq_len, hidden_size]
为什么 gate_up 这么大?
Llama-3.1-8B 参数:
hidden_size = 4096
intermediate_size = 14336 # 是 hidden_size 的 3.5 倍
gate_up 张量大小:
shape = [seq_len, 2 * intermediate_size]
= [seq_len, 28672]
内存 = seq_len × 28672 × 2 bytes
= seq_len × 57344 bytes
= seq_len × 56 KB
SwiGLU vs 标准 MLP
标准 Transformer MLP(如 GPT):
hidden = gelu(x @ W_gate) # [seq_len, 4 * hidden_size]
output = hidden @ W_down # [seq_len, hidden_size]
LLaMA 的 SwiGLU MLP:
gate_up = x @ W_gate_up # [seq_len, 2 * intermediate_size]
gate, up = gate_up.chunk(2)
output = silu(gate) * up # ← 逐元素相乘需要两个独立张量!
output = output @ W_down
关键区别:
- SwiGLU 需要 gate 和 up 两个独立的分支
- 两个分支必须 同时存在内存中 才能逐元素相乘
- 这就是内存开销的根本原因
MLP Activation 内存公式
\text{MLP\_activation} = \text{seq\_len} \times 2 \times \text{intermediate\_size} \times \text{dtype\_size}
| 模型 | hidden_size | intermediate_size | 扩展比 | 64K 激活内存 |
|---|---|---|---|---|
| Llama-3.1-8B | 4096 | 14336 | 3.5× | 3.5 GB |
| Llama-2-7B | 4096 | 11008 | 2.69× | 2.7 GB |
| Qwen3-4B | 2560 | 13696 | 5.35× | 3.5 GB |
| Qwen2-7B | 4096 | 13696 | 3.34× | 3.5 GB |
为什么 MLP Activation 不能 Offload?
与 KV cache 不同,MLP activation 是 临时计算结果:
- 计算依赖:
gate_up→chunk→silu→mul→down_proj - 生命周期:用完即弃,不需要长期保存
- 传输开销:CPU-GPU 传输会完全抵消 offload 的好处
如果 offload:
gate_up = gate_up_proj(x) # GPU 计算
gate_up = gate_up.to("cpu") # 传输到 CPU ← 慢!
gate, up = gate_up.chunk(2) # CPU 操作
gate = silu(gate)
output = gate * up
output = output.to("cuda") # 传输回 GPU ← 慢!
output = down_proj(output) # GPU 计算
结论:MLP activation 的高内存开销是 Transformer SwiGLU 架构的固有问题。
Chunked MLP 方案
核心思想
MLP 是 逐位置独立 的计算:
\text{output}[i] = \text{MLP}(\text{input}[i])
这意味着可以将序列分成多个 chunks,每个 chunk 独立计算。
实现方式
def chunked_mlp_forward(mlp, hidden_states, chunk_size):
"""
Chunked MLP forward to reduce peak memory usage.
Args:
mlp: LlamaMLP module
hidden_states: [seq_len, hidden_size]
chunk_size: number of tokens to process per chunk
Returns:
output: [seq_len, hidden_size]
"""
if chunk_size <= 0 or len(hidden_states) <= chunk_size:
return mlp(hidden_states)
# 预分配输出 buffer
output = torch.empty_like(hidden_states)
# Chunked processing
for i in range(0, len(hidden_states), chunk_size):
chunk_end = min(i + chunk_size, len(hidden_states))
chunk = hidden_states[i:chunk_end]
output[i:chunk_end] = mlp(chunk)
return output
内存效果对比(Llama-3.1-8B, seq_len=1M)
| 组件 | 原始 | Chunked (chunk_size=16K) | 降低 |
|---|---|---|---|
| gate_up | 56 GB | 875 MB | 98.4% ✅ |
| activated | 28 GB | 437 MB | 98.4% ✅ |
| mlp_out | 8 GB | 125 MB | 98.4% ✅ |
GPU 内存效果
| GPU | 显存 | 原始最大序列 | Chunked (16K) | 提升 |
|---|---|---|---|---|
| RTX 3090 | 24GB | ~32K | ~200K | 6.25× |
| A100 40GB | 40GB | ~340K | ~1M | 3× |
| A100 80GB | 80GB | ~1M | ~2M | 2× |
性能权衡
| chunk_size | 峰值内存降低 | 预期性能开销 | 推荐场景 |
|---|---|---|---|
| 16K | 75% | +5-10% | 64K-128K 序列 |
| 8K | 87% | +10-20% | 128K-256K 序列 |
| 4K | 94% | +20-50% | 256K-512K 序列 |
Chunked Prefill 方案
核心思想
不仅 MLP 是逐位置独立的,在保持因果关系的前提下,Attention 也可以分块处理。
关键洞察
对于序列的 Chunk N(位置范围 [start, end]):
\text{output}[i] = \text{Attention}(\text{query}[i], \text{keys}[0:i], \text{values}[0:i])
通过累积 KV cache,可以让 Chunk N 看到所有之前的 tokens:
\text{K}_{\text{full}} = \text{cat}([\text{K}_0, \text{K}_1, ..., \text{K}_N])
Chunk 独立性验证
关键问题:Chunk N 的输出是否是最终结果?
答案:✅ 是的!每个 chunk 计算完成后的输出就是该 chunk 的最终结果。
原因:
- Chunk N 的 tokens 能看到所有之前的 tokens(通过累积 KV)
- 因果关系通过
causal=True保持 - 不同 chunks 之间不需要额外的计算
数学等价性:
原始实现(一次性):
output[i] = Layer(attention(hidden[i], hidden[0:i]))
Chunked 实现:
output[i] = Layer(attention(hidden[i], cat([KV[0:N]])))
因为 cat([KV[0:N]]) == hidden[0:i](因果范围内)
所以两者完全相同!
实现伪代码
def run_layerwise_offload_chunked_prefill(layer, hidden_states, chunk_size=16384):
"""
Chunked prefill for a single layer.
Args:
layer: Transformer layer
hidden_states: [seq_len, hidden_size]
chunk_size: chunk size for processing
Returns:
output: [seq_len, hidden_size]
"""
seq_len = len(hidden_states)
# 初始化累积 KV cache
k_accumulated = []
v_accumulated = []
# 预分配输出 buffer(避免 torch.cat)
output_buffer = torch.empty_like(hidden_states)
# Chunked prefill
for start in range(0, seq_len, chunk_size):
end = min(start + chunk_size, seq_len)
# 1. 获取当前 chunk
chunk_hidden = hidden_states[start:end]
chunk_residual = residual[start:end] if residual is not None else None
# 2. Input LayerNorm
chunk_hidden_ln, chunk_residual = layer.input_layernorm(chunk_hidden, chunk_residual)
# 3. QKV Projection(只对 chunk)
chunk_qkv = layer.qkv_proj(chunk_hidden_ln) # [chunk_size, 6144]
chunk_q, chunk_k, chunk_v = chunk_qkv.split([...], dim=-1)
# 4. RoPE(只对 chunk)
positions_chunk = positions[start:end]
chunk_q_rot, chunk_k_rot = layer.rotary_emb(positions_chunk, chunk_q, chunk_k)
# 5. 累积 KV cache
k_accumulated.append(chunk_k_rot)
v_accumulated.append(chunk_v)
k_full = torch.cat(k_accumulated, dim=0) # 累积的全部 K
v_full = torch.cat(v_accumulated, dim=0) # 累积的全部 V
# 6. FlashAttention(chunk Q vs 累积 K,V)
chunk_attn_out = flash_attn(
chunk_q_rot, # Query: 当前 chunk
k_full, # Key: 累积的全部
v_full, # Value: 累积的全部
causal=True
)
# 7. O Projection + Residual
chunk_o_proj = layer.o_proj(chunk_attn_out)
chunk_hidden = chunk_o_proj + chunk_residual
# 8. Post-Attention LayerNorm
chunk_hidden_ln, chunk_residual = layer.post_attention_layernorm(chunk_hidden)
# 9. MLP(只对 chunk)
chunk_output = layer.mlp(chunk_hidden_ln)
# 10. 直接写入输出 buffer
output_buffer[start:end] = chunk_output
# 11. Offload 完整的 K, V to CPU
k_full = torch.cat(k_accumulated, dim=0)
v_full = torch.cat(v_accumulated, dim=0)
offload_kv_to_cpu(layer.layer_id, k_full, v_full)
return output_buffer
内存效果对比(Llama-3.1-8B, seq_len=1M)
| 组件 | 原始 | Chunked (16K) | 降低 |
|---|---|---|---|
| hidden_states | 8 GB | 8 GB | 0% (需要完整存储) |
| qkv | 12 GB | 192 MB | 98.4% ✅ |
| q_rotated | 8 GB | 128 MB | 98.4% ✅ |
| k_rotated | 2 GB | 32 MB | 98.4% ✅ |
| attn_out | 8 GB | 128 MB | 98.4% ✅ |
| gate_up | 56 GB | 875 MB | 98.4% ✅ |
| KV 累积 | - | 4 GB | 必须的 |
| 总峰值 | 102 GB | 21 GB | 79% ✅ |
不同序列长度的内存效果
| seq_len | Layerwise Offload | +Chunked MLP | +Chunked Prefill | 提升 vs Offload |
|---|---|---|---|---|
| 32K | 20 GB | 19 GB | 17 GB | 15% |
| 64K | 25 GB | 22 GB | 17 GB | 32% |
| 128K | 32 GB | 27 GB | 18 GB | 44% |
| 256K | 47 GB | 38 GB | 19 GB | 60% |
| 512K | 77 GB | 56 GB | 22 GB | 71% |
| 1M | 137 GB | 98 GB | 25 GB | 82% ✅ |
GPU 内存公式推导
Layerwise Offload:
\text{GPU}_{\text{offload}} = \text{model\_weights} + \text{seq\_len} \times (6 \times H + Q + 2KV) + \text{seq\_len} \times 2I
Layerwise + Chunked MLP:
\text{GPU}_{\text{chunked\_mlp}} = \text{model\_weights} + \text{seq\_len} \times (6H + Q + 2KV) + \text{chunk\_size} \times 2I
Layerwise + Chunked Prefill:
\text{GPU}_{\text{chunked\_prefill}} = \text{model\_weights} + \text{seq\_len} \times 2H + \text{seq\_len} \times 2KV + \text{chunk\_size} \times (Q + KV + 2I)
其中:
H = \text{hidden\_size}(4096)I = \text{intermediate\_size}(14336)Q = \text{q\_size}(4096)KV = \text{kv\_size}(1024)
Llama-3.1-8B 具体数值:
\text{GPU}_{\text{chunked\_prefill}} = 16\text{ GB} + \text{seq\_len} \times 12\text{ KB} + \text{chunk\_size} \times 30\text{ KB}
24GB GPU 可运行的最大序列
| 方案 | 最大序列长度 | 提升倍数 |
|---|---|---|
| Layerwise Offload | ~32K | 1× |
| +Chunked MLP | ~100K | 3× |
| +Chunked Prefill (chunk=16K) | ~1M | 32× ✅ |
| +Chunked Prefill (chunk=4K) | ~2M | 64× ✅ |
Streaming Chunked Prefill(最优方案)
关键洞察:层间独立性
之前的误解:认为 Layer N+1 需要 Layer N 的完整 hidden_states 输出。
正确的理解:如果每一层都是 chunked 的,那么 Layer N+1 只需要当前 chunk 的输入!
为什么 Streaming 方式可行?
核心原因:MLP 和 Attention 的计算都是逐位置独立的。
\text{output}[i] = \text{Layer}(\text{input}[i])
这意味着:
- Chunk 0 在所有层的计算完成后,其结果就是最终结果
- Chunk 1 在所有层的计算完成后,其结果也是最终结果
- 不同 chunks 之间完全独立
两种 Chunked Prefill 对比
原始理解(错误):"层内 Chunked"
# 错误的方式:处理完一层所有 chunks 后再处理下一层
for layer_id in range(num_layers):
output_buffer = torch.empty([seq_len, 4096]) # ← 需要 8GB!
for chunk_idx in range(num_chunks):
chunk = input_hidden[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size]
output_buffer[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size] = process_chunk(layer_id, chunk)
hidden_states = output_buffer # 传递给下一层
问题:
- ❌ 需要存储完整的
output_buffer(8GB @ 1M) - ❌ GPU 内存仍与
seq_len成正比 - ❌ 层与层之间需要传递完整的 hidden_states
Streaming Chunked Prefill(正确)
# 正确的方式:每个 chunk 依次通过所有层
initial_hidden = embedding(input_ids) # [seq_len, 4096]
for chunk_idx in range(num_chunks):
# 获取当前 chunk
chunk = initial_hidden[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size] # [chunk_size, 4096]
# 通过所有层处理这个 chunk
for layer_id in range(num_layers):
# 从 CPU 加载该层的累积 KV
k_accumulated, v_accumulated = load_kv_from_cpu(layer_id, end=chunk_idx * chunk_size)
# 处理当前 chunk
chunk = process_layer_chunk(
layer_id,
chunk,
k_accumulated,
v_accumulated,
start_pos=chunk_idx * chunk_size
)
# 将当前 chunk 的 KV offload 到 CPU
offload_chunk_kv_to_cpu(layer_id, chunk_k, chunk_v)
# chunk 已经通过所有层,可以直接输出或继续处理
# 不需要存储完整的 output_buffer!
优势:
- ✅ 不需要完整的 output_buffer(节省 8GB)
- ✅ GPU 内存独立于 seq_len
- ✅ 只需要存储当前 chunk 的 hidden_states(125MB @ chunk_size=16K)
内存效果对比(Llama-3.1-8B, seq_len=1M)
| 组件 | 原始 Chunked Prefill | Streaming Chunked Prefill | 降低 |
|---|---|---|---|
| hidden_states (输出) | 8 GB | 125 MB | 98.4% ✅ |
| qkv (chunked) | 192 MB | 192 MB | - |
| gate_up (chunked) | 875 MB | 875 MB | - |
| KV 累积(单层) | 4 GB | 4 GB | - |
| 模型权重 | 16 GB | 16 GB | - |
| 总峰值 | ~29 GB | ~21 GB | 28% ✅ |
GPU 内存公式
原始 Chunked Prefill:
\text{GPU}_{\text{chunked}} = \text{model\_weights} + \text{seq\_len} \times 2H + \text{chunk\_size} \times (Q + KV + 2I)
Streaming Chunked Prefill:
\text{GPU}_{\text{streaming}} = \text{model\_weights} + \text{chunk\_size} \times 2H + \text{chunk\_size} \times (Q + KV + 2I)
关键区别:
- 原始:
seq_len × 2H(与序列长度成正比) - Streaming:
chunk_size × 2H(与 chunk size 成正比)
Llama-3.1-8B 具体数值(chunk_size=16K):
\text{GPU}_{\text{streaming}} = 16\text{ GB} + 16\text{K} \times 12\text{ KB} + 16\text{K} \times 30\text{ KB} = 16\text{ GB} + 192\text{ MB} + 480\text{ MB} \approx 17.5\text{ GB}
不同序列长度的内存效果
| seq_len | Layerwise Offload | Chunked Prefill | Streaming | Streaming vs Offload |
|---|---|---|---|---|
| 32K | 20 GB | 17 GB | 17.5 GB | -12% |
| 64K | 25 GB | 17 GB | 17.5 GB | -30% |
| 128K | 32 GB | 18 GB | 17.5 GB | -45% |
| 256K | 47 GB | 19 GB | 17.5 GB | -63% |
| 512K | 77 GB | 22 GB | 17.5 GB | -77% |
| 1M | 137 GB | 25 GB | 17.5 GB | -87% ✅ |
| 2M | 265 GB | 38 GB | 17.5 GB | -93% ✅ |
| 4M | 521 GB | 70 GB | 17.5 GB | -97% ✅ |
关键发现
Streaming Chunked Prefill 使得 GPU 内存几乎独立于序列长度!
- 32K: 17.5 GB
- 1M: 17.5 GB(完全相同)
- 4M: 17.5 GB(完全相同)
唯一增长的是 CPU KV cache 内存:
\text{CPU\_KV} = \text{num\_layers} \times \text{seq\_len} \times \text{kv\_dim} \times 2 \times \text{dtype\_size}
理论最大序列长度(更新)
新的限制:CPU 内存
| 配置 | CPU 内存 | 最大序列长度 |
|---|---|---|
| 典型服务器 | 128 GB | ~1M |
| 大内存服务器 | 512 GB | ~4M |
| 超大内存服务器 | 1 TB | ~8M |
| 极限配置 | 2 TB | ~16M |
RoPE 硬性上限仍然是瓶颈
2^{24} = 16,777,216 \approx 16.8\text{M tokens}
结论:
- GPU 限制:已被 Streaming Chunked Prefill 完全消除(17.5 GB 恒定)
- CPU 限制:512 GB RAM → ~4M tokens
- RoPE 限制:~16.8M tokens(硬性上限)
24GB GPU 可运行的最大序列(更新)
| 方案 | 最大序列长度 | 提升倍数 | 瓶颈 |
|---|---|---|---|
| Layerwise Offload | ~32K | 1× | GPU 内存 |
| +Chunked MLP | ~100K | 3× | GPU 内存 |
| +Chunked Prefill | ~1M | 32× | GPU 内存 |
| +Streaming | ~4M | 128× | CPU 内存 ✅ |
| +Streaming (RoPE float64) | ~16M | 512× | CPU 内存 |
Streaming 实现伪代码
def run_streaming_chunked_prefill(model, input_ids, chunk_size=16384):
"""
Streaming chunked prefill: 每个 chunk 依次通过所有层。
Args:
model: Transformer model
input_ids: [seq_len]
chunk_size: chunk size for processing
Returns:
output: [seq_len, vocab_size] (logits)
"""
seq_len = len(input_ids)
num_chunks = (seq_len + chunk_size - 1) // chunk_size
# 1. Embedding(一次性)
initial_hidden = model.model.embed_tokens(input_ids) # [seq_len, 4096]
# 2. 预分配 CPU KV cache
cpu_kv_cache = [{
'k': torch.zeros(seq_len, kv_dim, dtype=dtype),
'v': torch.zeros(seq_len, kv_dim, dtype=dtype),
} for _ in range(num_layers)]
# 3. 预分配输出(可选,也可以直接写入文件)
# 对于超长序列,可能不需要存储完整输出
final_output = torch.zeros(seq_len, vocab_size, dtype=dtype)
# 4. Streaming: 每个 chunk 依次通过所有层
for chunk_idx in range(num_chunks):
start = chunk_idx * chunk_size
end = min(start + chunk_size, seq_len)
# 获取当前 chunk 的初始 hidden states
chunk_hidden = initial_hidden[start:end] # [chunk_size, 4096]
# 通过所有层处理
for layer_id in range(num_layers):
layer = model.model.layers[layer_id]
# 从 CPU 加载之前 chunks 的累积 KV
k_prev = cpu_kv_cache[layer_id]['k'][:start].to('cuda') # [start, kv_dim]
v_prev = cpu_kv_cache[layer_id]['v'][:start].to('cuda') # [start, kv_dim]
# 处理当前 chunk
chunk_hidden, chunk_k, chunk_v = process_layer_chunk(
layer,
chunk_hidden,
k_prev,
v_prev,
positions=torch.arange(start, end, device='cuda')
)
# 将当前 chunk 的 KV offload 到 CPU
cpu_kv_cache[layer_id]['k'][start:end] = chunk_k.cpu()
cpu_kv_cache[layer_id]['v'][start:end] = chunk_v.cpu()
# Chunk 已经通过所有层,计算 logits 并存储
chunk_logits = model.lm_head(chunk_hidden)
final_output[start:end] = chunk_logits
# 可选:释放当前 chunk 的 GPU 内存
del chunk_hidden, chunk_logits
return final_output
def process_layer_chunk(layer, chunk_hidden, k_prev, v_prev, positions):
"""
处理单个层的一个 chunk。
Args:
layer: Transformer layer
chunk_hidden: [chunk_size, hidden_size]
k_prev: 之前 chunks 的累积 K [prev_len, kv_dim]
v_prev: 之前 chunks 的累积 V [prev_len, kv_dim]
positions: 当前 chunk 的位置 [chunk_size]
Returns:
chunk_output: [chunk_size, hidden_size]
chunk_k: [chunk_size, kv_dim]
chunk_v: [chunk_size, kv_dim]
"""
# 1. Input LayerNorm
chunk_hidden_ln = layer.input_layernorm(chunk_hidden)
# 2. QKV Projection
chunk_qkv = layer.qkv_proj(chunk_hidden_ln)
chunk_q, chunk_k, chunk_v = chunk_qkv.split([q_size, kv_dim, kv_dim], dim=-1)
# 3. RoPE
chunk_q, chunk_k = layer.rotary_emb(positions, chunk_q, chunk_k)
# 4. 累积 KV
k_full = torch.cat([k_prev, chunk_k], dim=0) # [prev_len + chunk_size, kv_dim]
v_full = torch.cat([v_prev, chunk_v], dim=0)
# 5. FlashAttention(chunk Q vs 累积 K,V)
chunk_attn_out = flash_attn(
chunk_q, # [chunk_size, q_size]
k_full, # [prev_len + chunk_size, kv_dim]
v_full, # [prev_len + chunk_size, kv_dim]
causal=True
)
# 6. O Projection + Residual
chunk_hidden = layer.o_proj(chunk_attn_out) + chunk_hidden
# 7. Post-Attention LayerNorm
chunk_hidden_ln = layer.post_attention_layernorm(chunk_hidden)
# 8. MLP
chunk_output = layer.mlp(chunk_hidden_ln) + chunk_hidden
return chunk_output, chunk_k, chunk_v
性能权衡
| chunk_size | GPU 内存 | CPU-GPU 传输 | 预期性能 | 适用场景 |
|---|---|---|---|---|
| 32K | 18.5 GB | 32K × 32 = 1M/层 | 最快 | ≤ 32K |
| 16K | 17.5 GB | 16K × 32 = 512K/层 | +10% | 32K - 1M |
| 8K | 17.0 GB | 8K × 32 = 256K/层 | +25% | 1M - 4M |
| 4K | 16.8 GB | 4K × 32 = 128K/层 | +50% | > 4M |
CPU-GPU 传输开销:
- 每个 chunk 需要从 CPU 加载之前的所有 KV
- 传输量:
chunk_idx × chunk_size × kv_dim × 2 × dtype_size - Chunk 0: 0 KB(没有之前的 KV)
- Chunk 1: 16K × 1K × 2 × 2 = 64 MB
- Chunk 63: 63 × 64 MB = 4 GB
Layerwise Offload + Chunked Prefill 组合
注意:本节描述的是"层内 Chunked Prefill"方案。如果采用 Streaming Chunked Prefill,请参考上一节。
为什么需要组合方案?
Chunked MLP 的局限性:
- ✅ 解决了 MLP 瓶颈(98% 降低)
- ❌ 但 Attention 仍是瓶颈(30GB @ 1M)
Chunked Prefill 的优势:
- ✅ 同时解决 MLP 和 Attention 瓶颈
- ✅ 让 24GB GPU 运行 1M 序列
组合方案架构
输入: input_ids [seq_len]
↓
Embedding: hidden_states [seq_len, hidden_size]
↓
┌─────────────────────────────────────────────────────────────┐
│ Layer 0 (Chunked Prefill) │
│ for chunk in chunks(hidden_states, chunk_size): │
│ - QKV Projection (chunk only) │
│ - FlashAttention (chunk Q vs accumulated K,V) │
│ - MLP (chunk only) │
│ Offload KV to CPU │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Layer 1 (Chunked Prefill) │
│ for chunk in chunks(hidden_states, chunk_size): │
│ - ... (same pattern) │
│ Offload KV to CPU │
└─────────────────────────────────────────────────────────────┘
↓
... (repeat for all layers)
↓
Final LayerNorm + LM Head
完整内存 Breakdown(seq_len=1M, chunk_size=16K)
| 组件 | 内存 | 占比 | 能否 chunk? |
|---|---|---|---|
| 模型权重 | 16 GB | 64% | ❌ 固定 |
| 最终输出 (hidden_states) | 8 GB | 32% | ❌ 层间数据流 |
| KV 累积 | 4 GB | 16% | ❌ 必须累积 |
| qkv (chunked) | 192 MB | 0.8% | ✅ 已 chunked |
| gate_up (chunked) | 875 MB | 3.5% | ✅ 已 chunked |
| 其他临时张量 | 169 MB | 0.7% | ✅ 已 chunked |
| 总计 | ~25 GB | 100% |
性能权衡分析
Kernel 调用次数:
原始: num_layers × 1 = 32
Chunked (16K): num_layers × (seq_len / chunk_size) = 32 × 64 = 2048
增加: +6300%
预期性能开销:
| chunk_size | 内存节省 | 预期时间开销 | 适用场景 |
|---|---|---|---|
| 32K | 15-24% | +5-10% | 64K-128K |
| 16K | 24-46% | +10-20% | 128K-512K |
| 8K | 34-57% | +20-40% | 512K-1M |
| 4K | 46-78% | +40-80% | >1M |
为什么最终输出不能 Chunk?
原因:hidden_states 是层与层之间的数据流
# Layer N 的输出
hidden_states: [seq_len, hidden_size]
# Layer N+1 的输入
# 需要 Layer N 的完整输出!
qkv = qkv_proj(hidden_states) # 需要完整的 hidden_states
突破方法:
- 模型并行:分布 hidden_states 到多个 GPU
- 流水线并行:不同层在不同 GPU
- 降低 hidden_size:使用更小的模型
实现指南
配置参数
@dataclass
class Config:
# Chunked prefill 配置
prefill_chunk_size: int = 0 # 0 = 不 chunk, >0 = chunk 大小
# 自动选择策略
def get_chunk_size(self, seq_len: int) -> int:
if self.prefill_chunk_size > 0:
return self.prefill_chunk_size
# 自动选择
if seq_len < 32000:
return 0 # 不需要 chunk
elif seq_len < 128000:
return 16384
elif seq_len < 512000:
return 8192
else:
return 4096
实现步骤
第一步:验证 Chunked MLP(~50 行代码)
# 在 model_runner.py 的 run_layerwise_offload_prefill 中
# 将 MLP 调用改为 chunked
chunk_size = self.config.get_chunk_size(len(hidden_states))
if chunk_size > 0:
hidden_states = chunked_mlp_forward(layer.mlp, hidden_states, chunk_size)
else:
hidden_states = layer.mlp(hidden_states)
第二步:实现 Chunked Prefill(~150 行代码)
def run_layerwise_offload_chunked_prefill(self, seqs):
chunk_size = self.config.get_chunk_size(total_tokens)
hidden_states = self.model.model.embed_tokens(input_ids)
for layer_id in range(num_layers):
if chunk_size == 0:
# 原始实现
hidden_states = self._process_layer_full(layer, hidden_states)
else:
# Chunked 实现
hidden_states = self._process_layer_chunked(
layer, hidden_states, chunk_size
)
# Offload KV to CPU
...
return hidden_states
优化建议
-
预分配输出 buffer:
output = torch.empty_like(hidden_states) for chunk in chunks: output[start:end] = process_chunk(chunk) -
复用 KV 累积 buffer:
# 预分配最大可能的 KV buffer k_buffer = torch.zeros(seq_len, kv_dim, ...) v_buffer = torch.zeros(seq_len, kv_dim, ...) # 每次 copy 到对应位置 k_buffer[current_pos:end] = chunk_k -
异步 H2D 传输:
# 在处理当前 chunk 时,异步加载下一层需要的 KV # (需要流式处理支持)
测试验证
# 验证正确性
def test_chunked_prefill_correctness():
model = LLM("path/to/model")
# 原始实现
output_original = model.generate(prompt, enable_cpu_offload=True, prefill_chunk_size=0)
# Chunked 实现
output_chunked = model.generate(prompt, enable_cpu_offload=True, prefill_chunk_size=16384)
# 验证结果相同
assert torch.allclose(output_original, output_chunked, atol=1e-3)
# 验证内存
def test_chunked_prefill_memory():
torch.cuda.reset_peak_memory_stats()
model = LLM("path/to/model", enable_cpu_offload=True, prefill_chunk_size=16384)
model.generate(prompt_1M)
peak = torch.cuda.max_memory_allocated()
assert peak < 24 * 1024**3 # 24GB
总结
关键发现
-
理论最大序列:~16.8M tokens(RoPE float32 限制)
-
主要瓶颈演进(按优化顺序):
- RoPE 精度(最严格,硬性上限)
- MLP gate_up 激活值(原始瓶颈)
- Attention 中间张量(第二瓶颈)
- 最终输出 hidden_states(第三瓶颈,已被 Streaming 消除)
-
Streaming Chunked Prefill(最终方案)效果:
- GPU 内存独立于 seq_len:17.5 GB 恒定
- 内存降低:87%+(相比 layerwise offload)
- 序列扩展:128×(32K → 4M,受 CPU 内存限制)
- 性能开销:10-25%(取决于 chunk_size)
方案对比总结
| 方案 | 最大序列(24GB GPU) | GPU 内存 | 瓶颈 | 推荐度 |
|---|---|---|---|---|
| Layerwise Offload | ~32K | 20-137 GB | GPU 内存 | ⭐⭐⭐⭐ |
| +Chunked MLP | ~100K | 19-98 GB | GPU 内存 | ⭐⭐⭐ |
| +Chunked Prefill(层内) | ~1M | 17-25 GB | GPU 内存 | ⭐⭐⭐⭐ |
| +Streaming(最终) | ~4M | 17.5 GB(恒定) | CPU 内存 | ⭐⭐⭐⭐⭐ |
实现优先级(更新)
| 阶段 | 方案 | 效果 | 优先级 |
|---|---|---|---|
| 已实现 | Layerwise Offload | 32K | P0 ✅ |
| 第一步 | +Chunked MLP | ~100K | P2(可跳过) |
| 第二步 | +Chunked Prefill(层内) | ~1M | P1 |
| 最终目标 | +Streaming | ~4M | P0 ✅ |
Streaming Chunked Prefill 关键优势
- ✅ GPU 内存恒定:17.5 GB,与 seq_len 无关
- ✅ 突破 GPU 限制:最大序列由 CPU 内存决定
- ✅ 实现复杂度可控:~200 行代码
- ✅ 与 layerwise offload 完美集成
- ⚠️ CPU-GPU 传输增加:每个 chunk 需加载之前所有 KV
- ⚠️ 性能开销:10-25%(可接受的权衡)
最终推荐
对于超长序列推理(> 128K):
- ✅ 使用 Streaming Chunked Prefill(最优方案)
- ✅ 设置
prefill_chunk_size = 16384(自动选择) - ✅ 设置
num_kv_buffers = 1(降低 ring buffer 内存) - ✅ 确保充足 CPU 内存(512GB → 4M tokens)
效果:
- RTX 3090 (24GB): 32K → ~4M (128× 提升) ✅
- A100 40GB: 340K → ~4M (12× 提升) ✅
- A100 80GB: 1M → ~4M (4× 提升) ✅
唯一限制:
- CPU 内存:512 GB → ~4M tokens
- RoPE 硬性上限:~16.8M tokens(需改 float64)