diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index ad60f61..38cd684 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -95,9 +95,14 @@ class ModelRunner:
     def warmup_model(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
-        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
-        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
-        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        # Use a reasonable warmup length instead of max_model_len
+        # Warmup only needs to trigger CUDA kernel JIT compilation
+        # Using 2 blocks is sufficient and avoids huge memory allocation
+        warmup_len = min(self.block_size * 2, self.config.max_model_len)
+        warmup_len = max(warmup_len, 128)  # At least 128 tokens
+        num_seqs = min(self.config.max_num_batched_tokens // warmup_len, self.config.max_num_seqs, 4)
+        num_seqs = max(num_seqs, 1)
+        seqs = [Sequence([0] * warmup_len) for _ in range(num_seqs)]
         self.run(seqs, True)
         torch.cuda.empty_cache()

diff --git a/nanovllm/kvcache/chunked_attention.py b/nanovllm/kvcache/chunked_attention.py
index ddcb62c..bde7c58 100644
--- a/nanovllm/kvcache/chunked_attention.py
+++ b/nanovllm/kvcache/chunked_attention.py
@@ -237,7 +237,7 @@ def flash_attn_with_lse(
     """
     Flash attention forward pass that returns both output and LSE.

-    Supports GQA (grouped query attention) where num_kv_heads < num_q_heads.
+    Uses the flash_attn library, which natively supports GQA without memory overhead.

     Args:
         q: Query tensor [batch, seqlen_q, nheads_q, headdim]
@@ -250,85 +250,28 @@
         out: Output tensor [batch, seqlen_q, nheads_q, headdim]
         lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
     """
-    # Ensure contiguous
-    if not q.is_contiguous():
-        q = q.contiguous()
-    if not k.is_contiguous():
-        k = k.contiguous()
-    if not v.is_contiguous():
-        v = v.contiguous()
+    from flash_attn.flash_attn_interface import flash_attn_func

     batch, seqlen_q, nheads_q, headdim = q.shape
     _, seqlen_k, nheads_kv, _ = k.shape
-    assert k.shape == (batch, seqlen_k, nheads_kv, headdim)
-    assert v.shape == (batch, seqlen_k, nheads_kv, headdim)
-    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
-    assert q.dtype == k.dtype == v.dtype
-
-    # Handle GQA by repeating K/V heads
-    if nheads_kv != nheads_q:
-        assert nheads_q % nheads_kv == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_kv ({nheads_kv})"
-        repeat_factor = nheads_q // nheads_kv
-        # [batch, seqlen_k, nheads_kv, headdim] -> [batch, seqlen_k, nheads_q, headdim]
-        k = k.repeat_interleave(repeat_factor, dim=2)
-        v = v.repeat_interleave(repeat_factor, dim=2)
-        nheads = nheads_q
-    else:
-        nheads = nheads_q
-
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(headdim)

-    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
-    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
-    out = torch.empty_like(q)
-
-    BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
-    BLOCK = 128
-    num_warps = 4 if headdim <= 64 else 8
-    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
-
-    _fwd_kernel_with_lse[grid](
-        q,
-        k,
-        v,
-        out,
-        lse,
-        softmax_scale,
-        q.stride(0),
-        q.stride(2),
-        q.stride(1),
-        k.stride(0),
-        k.stride(2),
-        k.stride(1),
-        v.stride(0),
-        v.stride(2),
-        v.stride(1),
-        out.stride(0),
-        out.stride(2),
-        out.stride(1),
-        nheads,
-        seqlen_q,
-        seqlen_k,
-        seqlen_q_rounded,
-        headdim,
-        seqlen_q // 32,
-        seqlen_k // 32,
-        causal,
-        BLOCK_HEADDIM,
-        BLOCK_M=BLOCK,
-        BLOCK_N=BLOCK,
-        num_warps=num_warps,
-        num_stages=1,
+    # Use flash_attn_func, which natively supports GQA (no K/V head replication,
+    # so no extra memory). With return_attn_probs=True it returns
+    # (out, softmax_lse, S_dmask), which gives us the LSE directly.
+    out, lse, _ = flash_attn_func(
+        q, k, v,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        return_attn_probs=True,  # This makes it return (out, softmax_lse, S_dmask)
     )

-    # Trim LSE to actual seqlen_q
+    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
+    # Trim to actual seqlen_q
     lse = lse[:, :, :seqlen_q]

-    # Ensure output has same dtype as input
-    out = out.to(q.dtype)
-
     return out, lse
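
For a sense of scale, here is the new warmup sizing worked through with hypothetical config values (block_size=256, max_model_len=4096, max_num_batched_tokens=16384, max_num_seqs=128 are illustrative numbers, not nano-vllm defaults):

```python
# Hypothetical config values, chosen only to illustrate the sizing logic above.
block_size, max_model_len = 256, 4096
max_num_batched_tokens, max_num_seqs = 16384, 128

# Old warmup: 4 sequences of max_model_len tokens -> 16384 tokens total
old_num_seqs = min(max_num_batched_tokens // max_model_len, max_num_seqs)           # 4

# New warmup: 4 sequences of 2 blocks (512 tokens) each -> 2048 tokens total
warmup_len = max(min(block_size * 2, max_model_len), 128)                           # 512
new_num_seqs = max(min(max_num_batched_tokens // warmup_len, max_num_seqs, 4), 1)   # 4
```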
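The LSE that flash_attn_with_lse now obtains from flash_attn_func is what makes per-chunk attention results combinable. Below is a minimal sketch of the standard log-sum-exp merge of partial outputs; the helper name is hypothetical and the shapes follow the docstring above, so this illustrates the technique rather than reproducing the actual merge code in chunked_attention.py (which this diff does not show):

```python
import torch

def merge_chunk_outputs(outs, lses):
    """Merge attention outputs computed over separate KV chunks (sketch).

    outs: list of [batch, seqlen_q, nheads, headdim] partial outputs
    lses: list of [batch, nheads, seqlen_q] log-sum-exp tensors (float32)
    """
    lse_stack = torch.stack(lses)                         # [n_chunks, B, H, Sq]
    lse_total = torch.logsumexp(lse_stack, dim=0)         # [B, H, Sq]
    # Each chunk's weight is exp(lse_chunk - lse_total); the weights sum to 1.
    weights = torch.exp(lse_stack - lse_total)            # [n_chunks, B, H, Sq]
    weights = weights.permute(0, 1, 3, 2).unsqueeze(-1)   # [n_chunks, B, Sq, H, 1]
    out_stack = torch.stack(outs).float()                 # [n_chunks, B, Sq, H, D]
    merged = (weights * out_stack).sum(dim=0)             # [B, Sq, H, D]
    return merged.to(outs[0].dtype), lse_total
```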