[WIP] replace merge attention with triton kernel.

This commit is contained in:
Zijie Tian
2025-12-25 01:07:05 +08:00
parent cf5e7df093
commit 16fcf8350b
5 changed files with 490 additions and 405 deletions

View File

@@ -275,6 +275,85 @@ def flash_attn_with_lse(
return out, lse
@triton.jit
def _merge_lse_kernel(
lse1_ptr, lse2_ptr, lse_out_ptr,
num_elements: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging LSE values."""
# Each program handles BLOCK_SIZE elements
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < num_elements
# Load lse values
lse1 = tl.load(lse1_ptr + offsets, mask=mask)
lse2 = tl.load(lse2_ptr + offsets, mask=mask)
# Compute max for numerical stability
max_lse = tl.maximum(lse1, lse2)
# Compute exp(lse - max_lse)
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
# Compute merged LSE: max_lse + log(exp1 + exp2)
lse_merged = max_lse + tl.log(exp1 + exp2)
# Store result
tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
@triton.jit
def _merge_output_kernel(
o1_ptr, o2_ptr, lse1_ptr, lse2_ptr, o_out_ptr,
batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging attention outputs."""
# Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
pid_batch = tl.program_id(0)
pid_seq = tl.program_id(1)
pid_head = tl.program_id(2)
# Compute LSE index: [batch, nheads, seqlen_q]
lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq
# Load LSE values
lse1 = tl.load(lse1_ptr + lse_idx)
lse2 = tl.load(lse2_ptr + lse_idx)
# Compute max and scaling factors
max_lse = tl.maximum(lse1, lse2)
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
sum_exp = exp1 + exp2
# Process headdim in chunks
for d_offset in range(0, headdim, BLOCK_SIZE):
d_idx = d_offset + tl.arange(0, BLOCK_SIZE)
mask = d_idx < headdim
# Compute output index: [batch, seqlen_q, nheads, headdim]
base_idx = (pid_batch * seqlen_q * nheads * headdim +
pid_seq * nheads * headdim +
pid_head * headdim)
o_idx = base_idx + d_idx
# Load o1, o2
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0)
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0)
# Compute merged output: (o1 * exp1 + o2 * exp2) / sum_exp
o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
# Store result
tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
def merge_attention_outputs(
o1: torch.Tensor,
lse1: torch.Tensor,
@@ -282,7 +361,7 @@ def merge_attention_outputs(
lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Merge two attention outputs using online softmax.
Merge two attention outputs using online softmax (Triton fused kernel).
This implements the online softmax merging formula:
- m_new = max(lse1, lse2)
@@ -299,31 +378,30 @@ def merge_attention_outputs(
o_merged: Merged output [batch, seqlen_q, nheads, headdim]
lse_merged: Merged LSE [batch, nheads, seqlen_q]
"""
# lse shape: [batch, nheads, seqlen_q]
# o shape: [batch, seqlen_q, nheads, headdim]
batch, seqlen_q, nheads, headdim = o1.shape
# Compute max for numerical stability
max_lse = torch.maximum(lse1, lse2)
# Allocate output tensors
o_merged = torch.empty_like(o1)
lse_merged = torch.empty_like(lse1)
# Compute scaling factors
# exp1, exp2 shape: [batch, nheads, seqlen_q]
exp1 = torch.exp(lse1 - max_lse)
exp2 = torch.exp(lse2 - max_lse)
# Launch LSE merge kernel
num_lse_elements = batch * nheads * seqlen_q
BLOCK_SIZE_LSE = 256
grid_lse = (triton.cdiv(num_lse_elements, BLOCK_SIZE_LSE),)
_merge_lse_kernel[grid_lse](
lse1, lse2, lse_merged,
num_lse_elements,
BLOCK_SIZE=BLOCK_SIZE_LSE,
)
# Reshape for broadcasting with output
# [batch, nheads, seqlen_q] -> [batch, seqlen_q, nheads, 1]
exp1_broad = exp1.transpose(1, 2).unsqueeze(-1)
exp2_broad = exp2.transpose(1, 2).unsqueeze(-1)
# Merge outputs
sum_exp = exp1_broad + exp2_broad
o_merged = (o1 * exp1_broad + o2 * exp2_broad) / sum_exp
# Compute merged LSE
lse_merged = max_lse + torch.log(exp1 + exp2)
# Ensure output has same dtype as input
o_merged = o_merged.to(o1.dtype)
# Launch output merge kernel
BLOCK_SIZE = 128
grid_output = (batch, seqlen_q, nheads)
_merge_output_kernel[grid_output](
o1, o2, lse1, lse2, o_merged,
batch, seqlen_q, nheads, headdim,
BLOCK_SIZE=BLOCK_SIZE,
)
return o_merged, lse_merged