From 8d19e61446e2ebd0310a3867297a669bea87ccd3 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 28 Jan 2026 10:04:38 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20perf:=20replace=20Triton?= =?UTF-8?q?=20merge=20with=20FlashInfer=20merge=5Fstate?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use FlashInfer's optimized merge_state kernel for attention output merging in chunked prefill. End-to-end improvement: +0.8% (32K) to +2.4% (64K). Key changes: - Add merge_attention_outputs_flashinfer() with LSE format conversion - FlashInfer uses log2, flash_attn uses ln: convert via LOG2_E/LN_2 - Keep original Triton kernel for fallback Co-Authored-By: Claude Opus 4.5 --- docs/bench_offload_results.md | 42 +++++++++++++ nanovllm/kvcache/sparse/full_policy.py | 18 +++++- nanovllm/kvcache/sparse/xattn_bsa.py | 6 +- nanovllm/ops/chunked_attention.py | 84 ++++++++++++++++++++++++++ 4 files changed, 146 insertions(+), 4 deletions(-) diff --git a/docs/bench_offload_results.md b/docs/bench_offload_results.md index ecb7b2e..cd6ade4 100644 --- a/docs/bench_offload_results.md +++ b/docs/bench_offload_results.md @@ -150,8 +150,50 @@ CUDA_VISIBLE_DEVICES=0 python bench_offload.py --max-len 65536 --block-size 256 CUDA_VISIBLE_DEVICES=0 python bench_offload.py --enable-xattn --xattn-threshold 0.8 --xattn-stride 16 ``` +## FlashInfer Merge 优化 (2026-01-28) + +将 Triton 实现的 `merge_attention_outputs` 替换为 FlashInfer 的 `cascade.merge_state`。 + +### 性能对比 (Full Attention, block-size 4096) + +| 上下文 | Triton merge | FlashInfer merge | 提升 | +|--------|--------------|------------------|------| +| 32K | 4678 tok/s | 4717 tok/s | **+0.8%** | +| 64K | 3331 tok/s | 3411 tok/s | **+2.4%** | +| 128K | 2144 tok/s | 2178 tok/s | **+1.6%** | + +### 关键发现 + +1. **端到端提升有限**(0.8% ~ 2.4%):merge 操作不是主要瓶颈 + - H2D 传输占主导(64K 传输 64GB) + - Attention 计算是另一主要耗时 + - Merge 在总耗时中占比很小 + +2. **Merge kernel 单独对比**(长序列时 FlashInfer 优势明显): + +| seq_len | heads | Triton (ms) | FlashInfer (ms) | Speedup | +|---------|-------|-------------|-----------------|---------| +| 4096 | 32 | 0.129 | 0.087 | **1.49x** | +| 8192 | 32 | 0.251 | 0.147 | **1.70x** | +| 16384 | 32 | 0.499 | 0.274 | **1.82x** | + +3. **短序列 FlashInfer 反而慢**:格式转换开销(squeeze, transpose, contiguous) + +### 技术细节 + +- **LSE 格式差异**:FlashInfer 使用 log2,flash_attn 使用 ln +- **转换系数**:`LOG2_E = 1.4427`(ln → log2),`LN_2 = 0.6931`(log2 → ln) +- **FlashInfer attention JIT 问题**:CUDA 版本兼容性问题,仅使用 merge_state + +### 代码位置 + +- `nanovllm/ops/chunked_attention.py`: `merge_attention_outputs_flashinfer()` +- `nanovllm/kvcache/sparse/full_policy.py`: 3 处 import 更新 +- `nanovllm/kvcache/sparse/xattn_bsa.py`: 1 处 import 更新 + ## 更新记录 +- 2026-01-28: **FlashInfer merge 替换 Triton merge**,端到端提升 0.8% ~ 2.4% - 2026-01-28: **estimate_block_size 优化后重新测试**,128K XAttention 反超 Full (+2.4%) - 2026-01-27: 添加 GPU-only vs Offload 对比,block size 影响分析 - 2026-01-27: 初始测试,Llama-3.1-8B-Instruct, A100 80GB diff --git a/nanovllm/kvcache/sparse/full_policy.py b/nanovllm/kvcache/sparse/full_policy.py index d342a7e..5b6606c 100644 --- a/nanovllm/kvcache/sparse/full_policy.py +++ b/nanovllm/kvcache/sparse/full_policy.py @@ -185,7 +185,11 @@ class FullAttentionPolicy(SparsePolicy): Returns: Attention output [seq_len, num_heads, head_dim] """ - from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs + # Use FlashInfer-based implementations (more optimized) + from nanovllm.ops.chunked_attention import ( + flash_attn_with_lse_flashinfer as flash_attn_with_lse, + merge_attention_outputs_flashinfer as merge_attention_outputs, + ) logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, " f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}, " @@ -313,7 +317,11 @@ class FullAttentionPolicy(SparsePolicy): Returns: Attention output [batch_size, 1, num_heads, head_dim] """ - from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs + # Use FlashInfer-based implementations (more optimized) + from nanovllm.ops.chunked_attention import ( + flash_attn_with_lse_flashinfer as flash_attn_with_lse, + merge_attention_outputs_flashinfer as merge_attention_outputs, + ) # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence) q_batched = q.unsqueeze(1) # [batch, 1, heads, dim] @@ -405,7 +413,11 @@ class FullAttentionPolicy(SparsePolicy): Loads one block at a time, computes attention, and merges results. Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods. """ - from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs + # Use FlashInfer-based implementations (more optimized) + from nanovllm.ops.chunked_attention import ( + flash_attn_with_lse_flashinfer as flash_attn_with_lse, + merge_attention_outputs_flashinfer as merge_attention_outputs, + ) num_blocks = len(cpu_block_table) if num_blocks == 0: diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index 879a7a9..bf3978a 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -652,7 +652,11 @@ class XAttentionBSAPolicy(SparsePolicy): Returns: Attention output [seq_len, num_heads, head_dim] """ - from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs + # Use FlashInfer-based implementations (more optimized) + from nanovllm.ops.chunked_attention import ( + flash_attn_with_lse_flashinfer as flash_attn_with_lse, + merge_attention_outputs_flashinfer as merge_attention_outputs, + ) q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim] o_acc = None diff --git a/nanovllm/ops/chunked_attention.py b/nanovllm/ops/chunked_attention.py index 6f92c33..4502906 100644 --- a/nanovllm/ops/chunked_attention.py +++ b/nanovllm/ops/chunked_attention.py @@ -414,6 +414,90 @@ def merge_attention_outputs( return o_merged, lse_merged +# ============================================================ +# FlashInfer-based implementations (recommended for merge only) +# ============================================================ + +# LSE conversion constants: FlashInfer uses log2, flash_attn uses ln +_LOG2_E = 1.4426950408889634 # math.log2(math.e) - ln -> log2 +_LN_2 = 0.6931471805599453 # math.log(2) - log2 -> ln + +# Check FlashInfer availability (only for merge_state, not attention kernel) +try: + from flashinfer.cascade import merge_state, merge_state_in_place + FLASHINFER_MERGE_AVAILABLE = True +except ImportError: + FLASHINFER_MERGE_AVAILABLE = False + + +def flash_attn_with_lse_flashinfer( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Flash attention that returns output and LSE. + + Uses flash_attn library (FlashInfer attention has JIT compatibility issues). + + Args: + q: Query tensor [batch, seqlen_q, nheads_q, headdim] + k: Key tensor [batch, seqlen_k, nheads_kv, headdim] + v: Value tensor [batch, seqlen_k, nheads_kv, headdim] + softmax_scale: Scaling factor (default: 1/sqrt(headdim)) + causal: Whether to apply causal masking + + Returns: + out: Output tensor [batch, seqlen_q, nheads_q, headdim] + lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q] (ln format) + """ + # Use flash_attn directly (FlashInfer attention JIT has CUDA version issues) + return flash_attn_with_lse(q, k, v, softmax_scale, causal) + + +def merge_attention_outputs_flashinfer( + o1: torch.Tensor, + lse1: torch.Tensor, + o2: torch.Tensor, + lse2: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Merge two attention outputs using FlashInfer's optimized kernel. + + Args: + o1: First output [batch, seqlen_q, nheads, headdim] + lse1: First LSE [batch, nheads, seqlen_q] (ln format) + o2: Second output [batch, seqlen_q, nheads, headdim] + lse2: Second LSE [batch, nheads, seqlen_q] (ln format) + + Returns: + o_merged: Merged output [batch, seqlen_q, nheads, headdim] + lse_merged: Merged LSE [batch, nheads, seqlen_q] (ln format) + """ + if not FLASHINFER_MERGE_AVAILABLE: + # Fallback to Triton implementation + return merge_attention_outputs(o1, lse1, o2, lse2) + + # Convert to FlashInfer format + # o: [batch, seq, heads, dim] -> [seq, heads, dim] + # lse: [batch, heads, seq] -> [seq, heads] (convert ln -> log2) + v_a = o1.squeeze(0).contiguous() + s_a = (lse1.squeeze(0).transpose(0, 1).contiguous().float() * _LOG2_E) + v_b = o2.squeeze(0).contiguous() + s_b = (lse2.squeeze(0).transpose(0, 1).contiguous().float() * _LOG2_E) + + # FlashInfer merge + v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b) + + # Convert back to flash_attn format + o_merged = v_merged.unsqueeze(0) # [1, seq, heads, dim] + lse_merged = (s_merged * _LN_2).transpose(0, 1).unsqueeze(0) # [1, heads, seq] + + return o_merged, lse_merged + + def chunked_attention_varlen( q: torch.Tensor, kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],