[WIP] fixing attention compute error.

This commit is contained in:
Zijie Tian
2025-12-30 00:31:48 +08:00
parent bf4c63c7ec
commit 89f8020d38
12 changed files with 2175 additions and 103 deletions

View File

@@ -281,7 +281,11 @@ def _merge_lse_kernel(
num_elements: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging LSE values."""
"""Fused kernel for merging LSE values.
IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
"""
# Each program handles BLOCK_SIZE elements
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
@@ -289,21 +293,21 @@ def _merge_lse_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < num_elements
# Load lse values
lse1 = tl.load(lse1_ptr + offsets, mask=mask)
lse2 = tl.load(lse2_ptr + offsets, mask=mask)
# Load lse values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)
# Compute max for numerical stability
# Compute max for numerical stability (in fp32)
max_lse = tl.maximum(lse1, lse2)
# Compute exp(lse - max_lse)
# Compute exp(lse - max_lse) in fp32
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
# Compute merged LSE: max_lse + log(exp1 + exp2)
# Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
lse_merged = max_lse + tl.log(exp1 + exp2)
# Store result
# Store result (convert back to original dtype)
tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
@@ -313,7 +317,11 @@ def _merge_output_kernel(
batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging attention outputs."""
"""Fused kernel for merging attention outputs.
IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
This is critical for numerical accuracy in chunked attention.
"""
# Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
pid_batch = tl.program_id(0)
pid_seq = tl.program_id(1)
@@ -322,11 +330,11 @@ def _merge_output_kernel(
# Compute LSE index: [batch, nheads, seqlen_q]
lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq
# Load LSE values
lse1 = tl.load(lse1_ptr + lse_idx)
lse2 = tl.load(lse2_ptr + lse_idx)
# Load LSE values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)
# Compute max and scaling factors
# Compute max and scaling factors in fp32
max_lse = tl.maximum(lse1, lse2)
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
@@ -343,14 +351,14 @@ def _merge_output_kernel(
pid_head * headdim)
o_idx = base_idx + d_idx
# Load o1, o2
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0)
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0)
# Load o1, o2 and convert to fp32 for weighted sum
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
# Compute merged output: (o1 * exp1 + o2 * exp2) / sum_exp
# Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
# Store result
# Store result (Triton will convert back to original dtype)
tl.store(o_out_ptr + o_idx, o_merged, mask=mask)