️ perf: replace Triton merge with FlashInfer merge_state

Use FlashInfer's optimized merge_state kernel for attention output merging
in chunked prefill. End-to-end improvement: +0.8% (32K) to +2.4% (64K).

Key changes:
- Add merge_attention_outputs_flashinfer() with LSE format conversion
- FlashInfer uses log2, flash_attn uses ln: convert via LOG2_E/LN_2
- Keep original Triton kernel for fallback

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-28 10:04:38 +08:00
parent 4484ebbb77
commit 8d19e61446
4 changed files with 146 additions and 4 deletions

View File

@@ -652,7 +652,11 @@ class XAttentionBSAPolicy(SparsePolicy):
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# Use FlashInfer-based implementations (more optimized)
from nanovllm.ops.chunked_attention import (
flash_attn_with_lse_flashinfer as flash_attn_with_lse,
merge_attention_outputs_flashinfer as merge_attention_outputs,
)
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
o_acc = None