[feat] Fixed warmup memory overhead.
@@ -95,9 +95,14 @@ class ModelRunner:
     def warmup_model(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
-        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
-        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
-        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        # Use a reasonable warmup length instead of max_model_len
+        # Warmup only needs to trigger CUDA kernel JIT compilation
+        # Using 2 blocks is sufficient and avoids huge memory allocation
+        warmup_len = min(self.block_size * 2, self.config.max_model_len)
+        warmup_len = max(warmup_len, 128)  # At least 128 tokens
+        num_seqs = min(self.config.max_num_batched_tokens // warmup_len, self.config.max_num_seqs, 4)
+        num_seqs = max(num_seqs, 1)
+        seqs = [Sequence([0] * warmup_len) for _ in range(num_seqs)]
         self.run(seqs, True)
         torch.cuda.empty_cache()
 
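A quick illustration (not part of the diff): a rough sketch of the sizing math above, using hypothetical config values (block_size=16, max_model_len=4096, max_num_batched_tokens=16384, max_num_seqs=256; none of these numbers come from the commit) to show how much smaller the warmup batch becomes.

# Hypothetical numbers, only to illustrate the before/after token budget.
block_size = 16
max_model_len = 4096
max_num_batched_tokens = 16384
max_num_seqs = 256

# Old warmup: a batch of full-length sequences.
old_num_seqs = min(max_num_batched_tokens // max_model_len, max_num_seqs)          # 4
old_tokens = old_num_seqs * max_model_len                                          # 16384 tokens

# New warmup: short sequences, just enough to trigger kernel compilation.
warmup_len = max(min(block_size * 2, max_model_len), 128)                          # 128
new_num_seqs = max(min(max_num_batched_tokens // warmup_len, max_num_seqs, 4), 1)  # 4
new_tokens = new_num_seqs * warmup_len                                             # 512 tokens

print(old_tokens, new_tokens)  # 16384 vs 512 tokens run through the model

The activation and KV-cache footprint of the warmup pass grows with this token count, which is where the memory saving comes from.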
@@ -237,7 +237,7 @@ def flash_attn_with_lse(
     """
     Flash attention forward pass that returns both output and LSE.
 
-    Supports GQA (grouped query attention) where num_kv_heads < num_q_heads.
+    Uses the flash_attn library, which natively supports GQA without memory overhead.
 
     Args:
         q: Query tensor [batch, seqlen_q, nheads_q, headdim]
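For context (not part of the diff): flash_attn_func accepts K/V tensors with fewer heads than Q, so no head replication is needed for GQA. A minimal sketch, assuming a CUDA device and the flash_attn package; the shapes are made up for illustration.

import torch
from flash_attn import flash_attn_func

# Hypothetical GQA shapes: 32 query heads sharing 8 KV heads, head_dim 128.
q = torch.randn(1, 256, 32, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 256, 8, 128, device="cuda", dtype=torch.float16)
v = torch.randn(1, 256, 8, 128, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)  # no repeat_interleave needed
print(out.shape)  # torch.Size([1, 256, 32, 128])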
@@ -250,85 +250,28 @@ def flash_attn_with_lse(
         out: Output tensor [batch, seqlen_q, nheads_q, headdim]
         lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
     """
-    # Ensure contiguous
-    if not q.is_contiguous():
-        q = q.contiguous()
-    if not k.is_contiguous():
-        k = k.contiguous()
-    if not v.is_contiguous():
-        v = v.contiguous()
+    from flash_attn.flash_attn_interface import flash_attn_func
 
     batch, seqlen_q, nheads_q, headdim = q.shape
     _, seqlen_k, nheads_kv, _ = k.shape
 
-    assert k.shape == (batch, seqlen_k, nheads_kv, headdim)
-    assert v.shape == (batch, seqlen_k, nheads_kv, headdim)
-    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
-    assert q.dtype == k.dtype == v.dtype
-
-    # Handle GQA by repeating K/V heads
-    if nheads_kv != nheads_q:
-        assert nheads_q % nheads_kv == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_kv ({nheads_kv})"
-        repeat_factor = nheads_q // nheads_kv
-        # [batch, seqlen_k, nheads_kv, headdim] -> [batch, seqlen_k, nheads_q, headdim]
-        k = k.repeat_interleave(repeat_factor, dim=2)
-        v = v.repeat_interleave(repeat_factor, dim=2)
-        nheads = nheads_q
-    else:
-        nheads = nheads_q
-
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(headdim)
 
-    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
-    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
-    out = torch.empty_like(q)
-
-    BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
-    BLOCK = 128
-    num_warps = 4 if headdim <= 64 else 8
-    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
-
-    _fwd_kernel_with_lse[grid](
-        q,
-        k,
-        v,
-        out,
-        lse,
-        softmax_scale,
-        q.stride(0),
-        q.stride(2),
-        q.stride(1),
-        k.stride(0),
-        k.stride(2),
-        k.stride(1),
-        v.stride(0),
-        v.stride(2),
-        v.stride(1),
-        out.stride(0),
-        out.stride(2),
-        out.stride(1),
-        nheads,
-        seqlen_q,
-        seqlen_k,
-        seqlen_q_rounded,
-        headdim,
-        seqlen_q // 32,
-        seqlen_k // 32,
-        causal,
-        BLOCK_HEADDIM,
-        BLOCK_M=BLOCK,
-        BLOCK_N=BLOCK,
-        num_warps=num_warps,
-        num_stages=1,
+    # Use flash_attn_func which natively supports GQA (no memory overhead)
+    # flash_attn_func returns only the output by default; passing
+    # return_attn_probs=True also gives us the softmax LSE
+    out, lse, _ = flash_attn_func(
+        q, k, v,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        return_attn_probs=True,  # This makes it return (out, softmax_lse, S_dmask)
     )
 
-    # Trim LSE to actual seqlen_q
+    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
+    # Trim to actual seqlen_q
     lse = lse[:, :, :seqlen_q]
-
-    # Ensure output has same dtype as input
-    out = out.to(q.dtype)
 
     return out, lse
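Not part of the diff: one way to sanity-check the rewritten flash_attn_with_lse is to compare it against a naive PyTorch reference on small GQA inputs. The sketch below assumes flash_attn_with_lse is importable from the module this commit touches, a CUDA device, and flash_attn's bottom-right-aligned causal masking; tolerances are loose because the reference runs in fp32 while flash_attn runs in fp16.

import math
import torch

def naive_attn_with_lse(q, k, v, causal=True):
    # Reference-only implementation; shapes follow the docstring above:
    # q [B, Sq, Hq, D], k/v [B, Sk, Hkv, D] with Hkv dividing Hq.
    B, Sq, Hq, D = q.shape
    _, Sk, Hkv, _ = k.shape
    k = k.repeat_interleave(Hq // Hkv, dim=2).float()
    v = v.repeat_interleave(Hq // Hkv, dim=2).float()
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k) / math.sqrt(D)
    if causal:
        # Bottom-right aligned causal mask (matches flash_attn when Sq <= Sk).
        mask = torch.triu(torch.ones(Sq, Sk, dtype=torch.bool, device=q.device),
                          diagonal=Sk - Sq + 1)
        scores.masked_fill_(mask, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1)                             # [B, Hq, Sq]
    out = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)  # [B, Sq, Hq, D]
    return out.to(q.dtype), lse

# Small GQA inputs: 32 query heads sharing 8 KV heads.
q = torch.randn(1, 256, 32, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 256, 8, 128, device="cuda", dtype=torch.float16)
v = torch.randn(1, 256, 8, 128, device="cuda", dtype=torch.float16)

ref_out, ref_lse = naive_attn_with_lse(q, k, v, causal=True)
out, lse = flash_attn_with_lse(q, k, v, causal=True)
torch.testing.assert_close(out.float(), ref_out.float(), atol=2e-2, rtol=2e-2)
torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=2e-2)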