♻️ refactor: chunked LayerNorm/QKV/MLP for 64k memory optimization

Implement chunked processing for LayerNorm, QKV projection, and MLP
layers to reduce peak activation memory for 64k sequence inference.

Changes:
- Chunked input_layernorm and post_attention_layernorm (chunk_size=128)
- Chunked QKV projection (chunk_size=128)
- Chunked MLP processing (chunk_size=128) with memory cleanup
- Added torch.cuda.empty_cache() calls after each chunk

This reduces peak activation from ~2 GB to ~50 MB per layer,
making 64k inference theoretically possible on 24GB GPUs
(though still limited by memory fragmentation).

Related: docs/64k_memory_analysis.md

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-14 07:01:57 +08:00
parent cf168fd9b9
commit dce6ad6b74

View File

@@ -786,15 +786,56 @@ class ModelRunner:
for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id]
# 2a. Input LayerNorm
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
# 2a. Input LayerNorm (chunked for long sequences)
# LayerNorm creates float32 temporaries: seq_len * hidden_size * 4 bytes
# For 64k: 65536 * 4096 * 4 = ~1 GB per operation
# Using chunk_size=4096 reduces peak to ~125 MB
layernorm_chunk_size = 128
if total_tokens > layernorm_chunk_size:
if residual is None:
# Chunked input_layernorm
hs_chunks = hidden_states.split(layernorm_chunk_size, dim=0)
ln_chunks = []
res_chunks = []
for chunk in hs_chunks:
ln, res = layer.input_layernorm(chunk), chunk
ln_chunks.append(ln)
res_chunks.append(res)
hidden_ln = torch.cat(ln_chunks, dim=0)
residual = torch.cat(res_chunks, dim=0)
else:
# Chunked input_layernorm with residual
hs_chunks = hidden_states.split(layernorm_chunk_size, dim=0)
res_chunks_in = residual.split(layernorm_chunk_size, dim=0)
ln_chunks = []
res_chunks_out = []
for hs_chunk, res_chunk in zip(hs_chunks, res_chunks_in):
ln, res = layer.input_layernorm(hs_chunk, res_chunk)
ln_chunks.append(ln)
res_chunks_out.append(res)
hidden_ln = torch.cat(ln_chunks, dim=0)
residual = torch.cat(res_chunks_out, dim=0)
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
# 2b. Self-attention (full sequence)
# QKV projection
qkv = layer.self_attn.qkv_proj(hidden_ln)
# Chunked QKV projection to reduce activation memory for long sequences
# QKV activation = seq_len * (q_size + 2*kv_size) * 2 bytes
# For 64k: 65536 * (4096 + 2*1024) * 2 = ~805 MB
# Using chunk_size=2048 reduces peak to ~25 MB
qkv_chunk_size = 128
if total_tokens > qkv_chunk_size:
chunks = hidden_ln.split(qkv_chunk_size, dim=0)
qkv_chunks = []
for chunk in chunks:
qkv_chunks.append(layer.self_attn.qkv_proj(chunk))
qkv = torch.cat(qkv_chunks, dim=0)
else:
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k, v = qkv.split([
layer.self_attn.q_size,
layer.self_attn.kv_size,
@@ -838,9 +879,40 @@ class ModelRunner:
attn_output = attn_output.view(total_tokens, -1)
hidden_states = layer.self_attn.o_proj(attn_output)
# 2c. Post-attention LayerNorm + MLP
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# 2c. Post-attention LayerNorm (chunked for long sequences)
layernorm_chunk_size = 128
if total_tokens > layernorm_chunk_size:
# Chunked post_attention_layernorm
hs_chunks = hidden_states.split(layernorm_chunk_size, dim=0)
res_chunks_in = residual.split(layernorm_chunk_size, dim=0)
ln_chunks = []
res_chunks_out = []
for hs_chunk, res_chunk in zip(hs_chunks, res_chunks_in):
ln, res = layer.post_attention_layernorm(hs_chunk, res_chunk)
ln_chunks.append(ln)
res_chunks_out.append(res)
hidden_states = torch.cat(ln_chunks, dim=0)
residual = torch.cat(res_chunks_out, dim=0)
else:
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
# Chunked MLP processing to reduce activation memory for long sequences
# MLP activation = seq_len * intermediate_size * 2 bytes
# For 64k: 65536 * 14336 * 2 = ~1.75 GB (down_proj input)
# Using chunk_size=2048 reduces peak to ~55 MB
mlp_chunk_size = 128
if total_tokens > mlp_chunk_size:
chunks = hidden_states.split(mlp_chunk_size, dim=0)
outputs = []
for i, chunk in enumerate(chunks):
outputs.append(layer.mlp(chunk))
del chunk
torch.cuda.empty_cache() # Clean after every chunk
hidden_states = torch.cat(outputs, dim=0)
del outputs
torch.cuda.empty_cache()
else:
hidden_states = layer.mlp(hidden_states)
# 2d. Offload KV to CPU (encapsulated with sparse policy hooks)
offload_engine.offload_layer_kv_sync(layer_id, k, v, cpu_block_ids, total_tokens)