[fix] Reduce warmup memory overhead.

Zijie Tian
2025-12-15 21:39:14 +08:00
parent 91a0f09a24
commit dc7807a211
2 changed files with 20 additions and 72 deletions


@@ -95,9 +95,14 @@ class ModelRunner:
     def warmup_model(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
-        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
-        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
-        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        # Use a reasonable warmup length instead of max_model_len:
+        # warmup only needs to trigger CUDA kernel JIT compilation,
+        # so 2 blocks are sufficient and avoid a huge memory allocation.
+        warmup_len = min(self.block_size * 2, self.config.max_model_len)
+        warmup_len = max(warmup_len, 128)  # at least 128 tokens
+        num_seqs = min(self.config.max_num_batched_tokens // warmup_len, self.config.max_num_seqs, 4)
+        num_seqs = max(num_seqs, 1)
+        seqs = [Sequence([0] * warmup_len) for _ in range(num_seqs)]
         self.run(seqs, True)
         torch.cuda.empty_cache()
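
The hunk above bounds the warmup batch to at most 4 sequences of roughly two KV-cache blocks (at least 128 tokens) each, instead of max_num_seqs sequences of max_model_len tokens. A minimal sketch of how the effect could be checked, assuming a CUDA device and a runner object exposing warmup_model() as above; report_warmup_peak is a hypothetical helper, not repository code:

import torch

def report_warmup_peak(runner) -> float:
    # Hypothetical helper: run warmup_model() and report the peak CUDA
    # memory it allocated, in GiB. Assumes a CUDA device is available.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    runner.warmup_model()
    return torch.cuda.max_memory_allocated() / 1024 ** 3

With the old code the peak scaled with max_model_len; with this change it is bounded by the small fixed warmup batch.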


@@ -237,7 +237,7 @@ def flash_attn_with_lse(
     """
     Flash attention forward pass that returns both output and LSE.
-    Supports GQA (grouped query attention) where num_kv_heads < num_q_heads.
+    Uses flash_attn library which natively supports GQA without memory overhead.
 
     Args:
         q: Query tensor [batch, seqlen_q, nheads_q, headdim]
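
The "memory overhead" mentioned in the new docstring line is the temporary K/V expansion performed by the GQA path removed in the next hunk (repeat_interleave from nheads_kv to nheads_q heads before the Triton kernel). A rough, illustrative sketch of that allocation; the example sizes are assumptions, not values from this repository:

def gqa_repeat_extra_bytes(batch, seqlen_k, nheads_q, nheads_kv, headdim, dtype_bytes=2):
    # Illustrative sketch only: extra bytes allocated when K and V are
    # expanded from nheads_kv to nheads_q heads via repeat_interleave,
    # as the removed code path below used to do.
    per_head_tensor = batch * seqlen_k * headdim * dtype_bytes
    return 2 * per_head_tensor * (nheads_q - nheads_kv)  # both K and V

# Example: batch=1, seqlen_k=32768, nheads_q=32, nheads_kv=8, headdim=128, bf16
# -> 2 * (32768 * 128 * 2) * 24 bytes = 384 MiB of temporary K/V per call,
# which flash_attn_func avoids by handling GQA natively.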
@@ -250,85 +250,28 @@ def flash_attn_with_lse(
         out: Output tensor [batch, seqlen_q, nheads_q, headdim]
         lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
     """
-    # Ensure contiguous
-    if not q.is_contiguous():
-        q = q.contiguous()
-    if not k.is_contiguous():
-        k = k.contiguous()
-    if not v.is_contiguous():
-        v = v.contiguous()
+    from flash_attn.flash_attn_interface import flash_attn_func
 
     batch, seqlen_q, nheads_q, headdim = q.shape
     _, seqlen_k, nheads_kv, _ = k.shape
-    assert k.shape == (batch, seqlen_k, nheads_kv, headdim)
-    assert v.shape == (batch, seqlen_k, nheads_kv, headdim)
-    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
-    assert q.dtype == k.dtype == v.dtype
-
-    # Handle GQA by repeating K/V heads
-    if nheads_kv != nheads_q:
-        assert nheads_q % nheads_kv == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_kv ({nheads_kv})"
-        repeat_factor = nheads_q // nheads_kv
-        # [batch, seqlen_k, nheads_kv, headdim] -> [batch, seqlen_k, nheads_q, headdim]
-        k = k.repeat_interleave(repeat_factor, dim=2)
-        v = v.repeat_interleave(repeat_factor, dim=2)
-        nheads = nheads_q
-    else:
-        nheads = nheads_q
 
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(headdim)
 
-    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
-    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
-    out = torch.empty_like(q)
-
-    BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
-    BLOCK = 128
-    num_warps = 4 if headdim <= 64 else 8
-    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
-    _fwd_kernel_with_lse[grid](
-        q,
-        k,
-        v,
-        out,
-        lse,
-        softmax_scale,
-        q.stride(0),
-        q.stride(2),
-        q.stride(1),
-        k.stride(0),
-        k.stride(2),
-        k.stride(1),
-        v.stride(0),
-        v.stride(2),
-        v.stride(1),
-        out.stride(0),
-        out.stride(2),
-        out.stride(1),
-        nheads,
-        seqlen_q,
-        seqlen_k,
-        seqlen_q_rounded,
-        headdim,
-        seqlen_q // 32,
-        seqlen_k // 32,
-        causal,
-        BLOCK_HEADDIM,
-        BLOCK_M=BLOCK,
-        BLOCK_N=BLOCK,
-        num_warps=num_warps,
-        num_stages=1,
-    )
-    # Trim LSE to actual seqlen_q
+    # Use flash_attn_func which natively supports GQA (no memory overhead)
+    # It returns (output, softmax_lse) when return_attn_probs=True is not set
+    # We need to use the internal function to get LSE
+    out, lse, _ = flash_attn_func(
+        q, k, v,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        return_attn_probs=True,  # This makes it return (out, softmax_lse, S_dmask)
+    )
+    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
+    # Trim to actual seqlen_q
     lse = lse[:, :, :seqlen_q]
-    # Ensure output has same dtype as input
-    out = out.to(q.dtype)
     return out, lse
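
For reference, a minimal usage sketch of the rewritten function, assuming the signature suggested by the diff (q, k, v, softmax_scale=None, causal flag), that flash_attn_with_lse is importable from this file, that the flash_attn package is installed, and that a CUDA device is available; the tensor sizes and the 32:8 GQA ratio are arbitrary examples:

import torch

batch, seqlen, nheads_q, nheads_kv, headdim = 2, 1024, 32, 8, 128
q = torch.randn(batch, seqlen, nheads_q, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, nheads_kv, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn_like(k)

# K/V keep their 8 heads; no repeat_interleave copy is made before the call.
out, lse = flash_attn_with_lse(q, k, v, causal=True)
assert out.shape == (batch, seqlen, nheads_q, headdim)
assert lse.shape == (batch, nheads_q, seqlen)  # LSE trimmed to the actual seqlen_q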