⚡️ perf: replace Triton merge with FlashInfer merge_state
Use FlashInfer's optimized merge_state kernel for attention output merging in chunked prefill. End-to-end improvement: +0.8% (32K) to +2.4% (64K).

Key changes:
- Add merge_attention_outputs_flashinfer() with LSE format conversion
- FlashInfer uses log2, flash_attn uses ln: convert via LOG2_E / LN_2
- Keep the original Triton kernel as a fallback

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
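The LSE format conversion described above is just a change of logarithm base: flash_attn returns log-sum-exp values in natural log, while FlashInfer's merge kernel works in base 2. A minimal illustrative sketch of the round-trip (tensor shape and variable names here are hypothetical, not taken from the diff):

import math
import torch

LOG2_E = math.log2(math.e)  # multiply an ln-based LSE by this to get a log2-based LSE
LN_2 = math.log(2.0)        # multiply a log2-based LSE by this to get back to ln

lse_ln = torch.randn(32, 1024)   # hypothetical [heads, seq] LSE in ln format
lse_log2 = lse_ln * LOG2_E       # the format expected on the FlashInfer side
assert torch.allclose(lse_log2 * LN_2, lse_ln)  # round-trip differs only by float rounding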
@@ -414,6 +414,90 @@ def merge_attention_outputs(
    return o_merged, lse_merged


# ============================================================
# FlashInfer-based implementations (recommended for merge only)
# ============================================================

# LSE conversion constants: FlashInfer uses log2, flash_attn uses ln
_LOG2_E = 1.4426950408889634  # math.log2(math.e) - ln -> log2
_LN_2 = 0.6931471805599453  # math.log(2) - log2 -> ln

# Check FlashInfer availability (only for merge_state, not attention kernel)
try:
    from flashinfer.cascade import merge_state, merge_state_in_place
    FLASHINFER_MERGE_AVAILABLE = True
except ImportError:
    FLASHINFER_MERGE_AVAILABLE = False


def flash_attn_with_lse_flashinfer(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flash attention that returns output and LSE.

    Uses the flash_attn library (FlashInfer attention has JIT compatibility issues).

    Args:
        q: Query tensor [batch, seqlen_q, nheads_q, headdim]
        k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
        v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
        softmax_scale: Scaling factor (default: 1/sqrt(headdim))
        causal: Whether to apply causal masking

    Returns:
        out: Output tensor [batch, seqlen_q, nheads_q, headdim]
        lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q] (ln format)
    """
    # Use flash_attn directly (FlashInfer attention JIT has CUDA version issues)
    return flash_attn_with_lse(q, k, v, softmax_scale, causal)


def merge_attention_outputs_flashinfer(
    o1: torch.Tensor,
    lse1: torch.Tensor,
    o2: torch.Tensor,
    lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge two attention outputs using FlashInfer's optimized kernel.

    Args:
        o1: First output [batch, seqlen_q, nheads, headdim]
        lse1: First LSE [batch, nheads, seqlen_q] (ln format)
        o2: Second output [batch, seqlen_q, nheads, headdim]
        lse2: Second LSE [batch, nheads, seqlen_q] (ln format)

    Returns:
        o_merged: Merged output [batch, seqlen_q, nheads, headdim]
        lse_merged: Merged LSE [batch, nheads, seqlen_q] (ln format)
    """
    if not FLASHINFER_MERGE_AVAILABLE:
        # Fallback to Triton implementation
        return merge_attention_outputs(o1, lse1, o2, lse2)

    # Convert to FlashInfer format
    # o: [batch, seq, heads, dim] -> [seq, heads, dim]
    # lse: [batch, heads, seq] -> [seq, heads] (convert ln -> log2)
    v_a = o1.squeeze(0).contiguous()
    s_a = (lse1.squeeze(0).transpose(0, 1).contiguous().float() * _LOG2_E)
    v_b = o2.squeeze(0).contiguous()
    s_b = (lse2.squeeze(0).transpose(0, 1).contiguous().float() * _LOG2_E)

    # FlashInfer merge
    v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)

    # Convert back to flash_attn format
    o_merged = v_merged.unsqueeze(0)  # [1, seq, heads, dim]
    lse_merged = (s_merged * _LN_2).transpose(0, 1).unsqueeze(0)  # [1, heads, seq]

    return o_merged, lse_merged


def chunked_attention_varlen(
    q: torch.Tensor,
    kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
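For a quick sanity check of the new path, the FlashInfer merge can be compared against a pure-PyTorch reference that computes the same weighted combination of two partial attention states. The sketch below is illustrative and not part of the commit; the module name chunked_attention is a placeholder for wherever merge_attention_outputs_flashinfer actually lives, and it assumes a CUDA device with flash_attn and FlashInfer installed:

# Hypothetical validation sketch; "chunked_attention" is a placeholder module name.
import torch

from chunked_attention import merge_attention_outputs_flashinfer


def reference_merge(o1, lse1, o2, lse2):
    """Pure-PyTorch merge of two attention states (ln-format LSE)."""
    # o*: [batch, seq, heads, dim]; lse*: [batch, heads, seq]
    lse = torch.logaddexp(lse1, lse2)                      # merged LSE (ln format)
    w1 = (lse1 - lse).exp().transpose(1, 2).unsqueeze(-1)  # [batch, seq, heads, 1]
    w2 = (lse2 - lse).exp().transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse


batch, seq, heads, dim = 1, 1024, 32, 128
o1 = torch.randn(batch, seq, heads, dim, device="cuda", dtype=torch.float16)
o2 = torch.randn(batch, seq, heads, dim, device="cuda", dtype=torch.float16)
lse1 = torch.randn(batch, heads, seq, device="cuda", dtype=torch.float32)
lse2 = torch.randn(batch, heads, seq, device="cuda", dtype=torch.float32)

o_fi, lse_fi = merge_attention_outputs_flashinfer(o1, lse1, o2, lse2)
o_ref, lse_ref = reference_merge(o1.float(), lse1, o2.float(), lse2)

print("max |o diff|:  ", (o_fi.float() - o_ref).abs().max().item())
print("max |lse diff|:", (lse_fi.float() - lse_ref).abs().max().item())

The same comparison applies to the Triton fallback path, since it computes the same merge.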