diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 3ae1a57..325d0ea 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -786,15 +786,56 @@ class ModelRunner: for layer_id in range(num_layers): layer = self.model.model.layers[layer_id] - # 2a. Input LayerNorm - if residual is None: - hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states + # 2a. Input LayerNorm (chunked for long sequences) + # LayerNorm creates float32 temporaries: seq_len * hidden_size * 4 bytes + # For 64k: 65536 * 4096 * 4 = ~1 GB per operation + # Using chunk_size=128 reduces peak to ~2 MB + layernorm_chunk_size = 128 + if total_tokens > layernorm_chunk_size: + if residual is None: + # Chunked input_layernorm + hs_chunks = hidden_states.split(layernorm_chunk_size, dim=0) + ln_chunks = [] + res_chunks = [] + for chunk in hs_chunks: + ln, res = layer.input_layernorm(chunk), chunk + ln_chunks.append(ln) + res_chunks.append(res) + hidden_ln = torch.cat(ln_chunks, dim=0) + residual = torch.cat(res_chunks, dim=0) + else: + # Chunked input_layernorm with residual + hs_chunks = hidden_states.split(layernorm_chunk_size, dim=0) + res_chunks_in = residual.split(layernorm_chunk_size, dim=0) + ln_chunks = [] + res_chunks_out = [] + for hs_chunk, res_chunk in zip(hs_chunks, res_chunks_in): + ln, res = layer.input_layernorm(hs_chunk, res_chunk) + ln_chunks.append(ln) + res_chunks_out.append(res) + hidden_ln = torch.cat(ln_chunks, dim=0) + residual = torch.cat(res_chunks_out, dim=0) else: - hidden_ln, residual = layer.input_layernorm(hidden_states, residual) + if residual is None: + hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states + else: + hidden_ln, residual = layer.input_layernorm(hidden_states, residual) # 2b. 
Self-attention (full sequence) - # QKV projection - qkv = layer.self_attn.qkv_proj(hidden_ln) + # Chunked QKV projection to reduce activation memory for long sequences + # QKV activation = seq_len * (q_size + 2*kv_size) * 2 bytes + # For 64k: 65536 * (4096 + 2*1024) * 2 = ~805 MB + # Using chunk_size=128 reduces peak to ~1.5 MB + qkv_chunk_size = 128 + if total_tokens > qkv_chunk_size: + chunks = hidden_ln.split(qkv_chunk_size, dim=0) + qkv_chunks = [] + for chunk in chunks: + qkv_chunks.append(layer.self_attn.qkv_proj(chunk)) + qkv = torch.cat(qkv_chunks, dim=0) + else: + qkv = layer.self_attn.qkv_proj(hidden_ln) + q, k, v = qkv.split([ layer.self_attn.q_size, layer.self_attn.kv_size, @@ -838,9 +879,40 @@ class ModelRunner: attn_output = attn_output.view(total_tokens, -1) hidden_states = layer.self_attn.o_proj(attn_output) - # 2c. Post-attention LayerNorm + MLP - hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual) - hidden_states = layer.mlp(hidden_states) + # 2c. 
Post-attention LayerNorm (chunked for long sequences) + layernorm_chunk_size = 128 + if total_tokens > layernorm_chunk_size: + # Chunked post_attention_layernorm + hs_chunks = hidden_states.split(layernorm_chunk_size, dim=0) + res_chunks_in = residual.split(layernorm_chunk_size, dim=0) + ln_chunks = [] + res_chunks_out = [] + for hs_chunk, res_chunk in zip(hs_chunks, res_chunks_in): + ln, res = layer.post_attention_layernorm(hs_chunk, res_chunk) + ln_chunks.append(ln) + res_chunks_out.append(res) + hidden_states = torch.cat(ln_chunks, dim=0) + residual = torch.cat(res_chunks_out, dim=0) + else: + hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual) + + # Chunked MLP processing to reduce activation memory for long sequences + # MLP activation = seq_len * intermediate_size * 2 bytes + # For 64k: 65536 * 14336 * 2 = ~1.75 GB (down_proj input) + # Using chunk_size=128 reduces peak to ~3.5 MB + mlp_chunk_size = 128 + if total_tokens > mlp_chunk_size: + chunks = hidden_states.split(mlp_chunk_size, dim=0) + outputs = [] + for i, chunk in enumerate(chunks): + outputs.append(layer.mlp(chunk)) + del chunk + torch.cuda.empty_cache() # Clean after every chunk + hidden_states = torch.cat(outputs, dim=0) + del outputs + torch.cuda.empty_cache() + else: + hidden_states = layer.mlp(hidden_states) # 2d. Offload KV to CPU (encapsulated with sparse policy hooks) offload_engine.offload_layer_kv_sync(layer_id, k, v, cpu_block_ids, total_tokens)