[feat] Fixed warmup memory overhead.

Zijie Tian
2025-12-15 21:39:14 +08:00
parent 91a0f09a24
commit dc7807a211
2 changed files with 20 additions and 72 deletions


@@ -237,7 +237,7 @@ def flash_attn_with_lse(
"""
Flash attention forward pass that returns both output and LSE.
Supports GQA (grouped query attention) where num_kv_heads < num_q_heads.
Uses flash_attn library which natively supports GQA without memory overhead.
Args:
q: Query tensor [batch, seqlen_q, nheads_q, headdim]
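The docstring's "without memory overhead" claim is the heart of this commit: the removed path below expands K/V with repeat_interleave so every query head gets its own copy (presumably the warmup memory overhead the commit title refers to), while flash_attn resolves the head grouping inside its kernel and never materializes the expanded tensors. A minimal sketch of that cost, using hypothetical shapes that are not taken from this repository:

import torch

# Hypothetical GQA shapes for illustration only.
batch, seqlen_k, headdim = 1, 4096, 128
nheads_q, nheads_kv = 32, 8              # 4 query heads share each K/V head
group_size = nheads_q // nheads_kv

k = torch.zeros(batch, seqlen_k, nheads_kv, headdim, dtype=torch.float16)

# Grouping rule: query head i reads K/V head i // group_size.
kv_head_of_q = [i // group_size for i in range(nheads_q)]

# Explicit expansion (the approach being removed) multiplies K/V memory by
# group_size: 8 MiB -> 32 MiB for this tensor alone, and the same again for v.
k_expanded = k.repeat_interleave(group_size, dim=2)
print(k.nelement() * k.element_size() // 2**20,                      # 8
      k_expanded.nelement() * k_expanded.element_size() // 2**20)    # 32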
@@ -250,85 +250,28 @@ def flash_attn_with_lse(
        out: Output tensor [batch, seqlen_q, nheads_q, headdim]
        lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
    """
    # Ensure contiguous
    if not q.is_contiguous():
        q = q.contiguous()
    if not k.is_contiguous():
        k = k.contiguous()
    if not v.is_contiguous():
        v = v.contiguous()
    from flash_attn.flash_attn_interface import flash_attn_func
    batch, seqlen_q, nheads_q, headdim = q.shape
    _, seqlen_k, nheads_kv, _ = k.shape
    assert k.shape == (batch, seqlen_k, nheads_kv, headdim)
    assert v.shape == (batch, seqlen_k, nheads_kv, headdim)
    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.dtype == k.dtype == v.dtype
    # Handle GQA by repeating K/V heads
    if nheads_kv != nheads_q:
        assert nheads_q % nheads_kv == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_kv ({nheads_kv})"
        repeat_factor = nheads_q // nheads_kv
        # [batch, seqlen_k, nheads_kv, headdim] -> [batch, seqlen_k, nheads_q, headdim]
        k = k.repeat_interleave(repeat_factor, dim=2)
        v = v.repeat_interleave(repeat_factor, dim=2)
        nheads = nheads_q
    else:
        nheads = nheads_q
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    out = torch.empty_like(q)
    BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
    BLOCK = 128
    num_warps = 4 if headdim <= 64 else 8
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _fwd_kernel_with_lse[grid](
        q,
        k,
        v,
        out,
        lse,
        softmax_scale,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        k.stride(0),
        k.stride(2),
        k.stride(1),
        v.stride(0),
        v.stride(2),
        v.stride(1),
        out.stride(0),
        out.stride(2),
        out.stride(1),
        nheads,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        headdim,
        seqlen_q // 32,
        seqlen_k // 32,
        causal,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    # Use flash_attn_func, which natively supports GQA: K/V heads are shared inside
    # the kernel, so no repeat_interleave expansion (and no extra memory) is needed.
    out, lse, _ = flash_attn_func(
        q, k, v,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=True,  # This makes it return (out, softmax_lse, S_dmask)
    )
    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]; trim to actual seqlen_q
    lse = lse[:, :, :seqlen_q]
    # Ensure output has same dtype as input
    out = out.to(q.dtype)
    return out, lse
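For context, the new path reduces to a single library call plus the LSE trim. A standalone usage sketch, assuming a CUDA device and a flash-attn 2.x install, with illustrative shapes that are not taken from this repository (only out and lse are consumed; the third return value is only for debugging and is discarded):

import math
import torch
from flash_attn.flash_attn_interface import flash_attn_func

# Illustrative shapes; requires a CUDA GPU and flash-attn 2.x.
batch, seqlen, headdim = 2, 1024, 128
nheads_q, nheads_kv = 32, 8   # GQA: K/V keep their 8 heads, no expansion

q = torch.randn(batch, seqlen, nheads_q, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, nheads_kv, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn_like(k)

# return_attn_probs=True makes flash_attn_func return (out, softmax_lse, S_dmask).
out, lse, _ = flash_attn_func(
    q, k, v,
    softmax_scale=1.0 / math.sqrt(headdim),
    causal=True,
    return_attn_probs=True,
)

# Defensive trim mirroring the patch, in case the LSE is padded to a rounded
# sequence length; a no-op when it is already [batch, nheads_q, seqlen].
lse = lse[:, :, :seqlen]
print(out.shape, lse.shape)   # torch.Size([2, 1024, 32, 128]) torch.Size([2, 32, 1024])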