[feat] Fixed warmup memory overhead.
This commit is contained in:
@@ -237,7 +237,7 @@ def flash_attn_with_lse(
|
||||
"""
|
||||
Flash attention forward pass that returns both output and LSE.
|
||||
|
||||
Supports GQA (grouped query attention) where num_kv_heads < num_q_heads.
|
||||
Uses flash_attn library which natively supports GQA without memory overhead.
|
||||
|
||||
Args:
|
||||
q: Query tensor [batch, seqlen_q, nheads_q, headdim]
|
||||
@@ -250,85 +250,28 @@ def flash_attn_with_lse(
|
||||
out: Output tensor [batch, seqlen_q, nheads_q, headdim]
|
||||
lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
|
||||
"""
|
||||
# Ensure contiguous
|
||||
if not q.is_contiguous():
|
||||
q = q.contiguous()
|
||||
if not k.is_contiguous():
|
||||
k = k.contiguous()
|
||||
if not v.is_contiguous():
|
||||
v = v.contiguous()
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
|
||||
batch, seqlen_q, nheads_q, headdim = q.shape
|
||||
_, seqlen_k, nheads_kv, _ = k.shape
|
||||
|
||||
assert k.shape == (batch, seqlen_k, nheads_kv, headdim)
|
||||
assert v.shape == (batch, seqlen_k, nheads_kv, headdim)
|
||||
assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
|
||||
# Handle GQA by repeating K/V heads
|
||||
if nheads_kv != nheads_q:
|
||||
assert nheads_q % nheads_kv == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_kv ({nheads_kv})"
|
||||
repeat_factor = nheads_q // nheads_kv
|
||||
# [batch, seqlen_k, nheads_kv, headdim] -> [batch, seqlen_k, nheads_q, headdim]
|
||||
k = k.repeat_interleave(repeat_factor, dim=2)
|
||||
v = v.repeat_interleave(repeat_factor, dim=2)
|
||||
nheads = nheads_q
|
||||
else:
|
||||
nheads = nheads_q
|
||||
|
||||
if softmax_scale is None:
|
||||
softmax_scale = 1.0 / math.sqrt(headdim)
|
||||
|
||||
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
|
||||
lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
|
||||
out = torch.empty_like(q)
|
||||
|
||||
BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
|
||||
BLOCK = 128
|
||||
num_warps = 4 if headdim <= 64 else 8
|
||||
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
|
||||
|
||||
_fwd_kernel_with_lse[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
lse,
|
||||
softmax_scale,
|
||||
q.stride(0),
|
||||
q.stride(2),
|
||||
q.stride(1),
|
||||
k.stride(0),
|
||||
k.stride(2),
|
||||
k.stride(1),
|
||||
v.stride(0),
|
||||
v.stride(2),
|
||||
v.stride(1),
|
||||
out.stride(0),
|
||||
out.stride(2),
|
||||
out.stride(1),
|
||||
nheads,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
seqlen_q_rounded,
|
||||
headdim,
|
||||
seqlen_q // 32,
|
||||
seqlen_k // 32,
|
||||
causal,
|
||||
BLOCK_HEADDIM,
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
# Use flash_attn_func which natively supports GQA (no memory overhead)
|
||||
# It returns (output, softmax_lse) when return_attn_probs=True is not set
|
||||
# We need to use the internal function to get LSE
|
||||
out, lse, _ = flash_attn_func(
|
||||
q, k, v,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
return_attn_probs=True, # This makes it return (out, softmax_lse, S_dmask)
|
||||
)
|
||||
|
||||
# Trim LSE to actual seqlen_q
|
||||
# lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
|
||||
# Trim to actual seqlen_q
|
||||
lse = lse[:, :, :seqlen_q]
|
||||
|
||||
# Ensure output has same dtype as input
|
||||
out = out.to(q.dtype)
|
||||
|
||||
return out, lse
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user