⚡️ perf: replace Triton merge with FlashInfer merge_state
Use FlashInfer's optimized merge_state kernel for attention output merging in chunked prefill. End-to-end improvement: +0.8% (32K) to +2.4% (64K). Key changes: - Add merge_attention_outputs_flashinfer() with LSE format conversion - FlashInfer uses log2, flash_attn uses ln: convert via LOG2_E/LN_2 - Keep original Triton kernel for fallback Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -150,8 +150,50 @@ CUDA_VISIBLE_DEVICES=0 python bench_offload.py --max-len 65536 --block-size 256
|
||||
CUDA_VISIBLE_DEVICES=0 python bench_offload.py --enable-xattn --xattn-threshold 0.8 --xattn-stride 16
|
||||
```
|
||||
|
||||
## FlashInfer Merge 优化 (2026-01-28)
|
||||
|
||||
将 Triton 实现的 `merge_attention_outputs` 替换为 FlashInfer 的 `cascade.merge_state`。
|
||||
|
||||
### 性能对比 (Full Attention, block-size 4096)
|
||||
|
||||
| 上下文 | Triton merge | FlashInfer merge | 提升 |
|
||||
|--------|--------------|------------------|------|
|
||||
| 32K | 4678 tok/s | 4717 tok/s | **+0.8%** |
|
||||
| 64K | 3331 tok/s | 3411 tok/s | **+2.4%** |
|
||||
| 128K | 2144 tok/s | 2178 tok/s | **+1.6%** |
|
||||
|
||||
### 关键发现
|
||||
|
||||
1. **端到端提升有限**(0.8% ~ 2.4%):merge 操作不是主要瓶颈
|
||||
- H2D 传输占主导(64K 传输 64GB)
|
||||
- Attention 计算是另一主要耗时
|
||||
- Merge 在总耗时中占比很小
|
||||
|
||||
2. **Merge kernel 单独对比**(长序列时 FlashInfer 优势明显):
|
||||
|
||||
| seq_len | heads | Triton (ms) | FlashInfer (ms) | Speedup |
|
||||
|---------|-------|-------------|-----------------|---------|
|
||||
| 4096 | 32 | 0.129 | 0.087 | **1.49x** |
|
||||
| 8192 | 32 | 0.251 | 0.147 | **1.70x** |
|
||||
| 16384 | 32 | 0.499 | 0.274 | **1.82x** |
|
||||
|
||||
3. **短序列 FlashInfer 反而慢**:格式转换开销(squeeze, transpose, contiguous)
|
||||
|
||||
### 技术细节
|
||||
|
||||
- **LSE 格式差异**:FlashInfer 使用 log2,flash_attn 使用 ln
|
||||
- **转换系数**:`LOG2_E = 1.4427`(ln → log2),`LN_2 = 0.6931`(log2 → ln)
|
||||
- **FlashInfer attention JIT 问题**:CUDA 版本兼容性问题,仅使用 merge_state
|
||||
|
||||
### 代码位置
|
||||
|
||||
- `nanovllm/ops/chunked_attention.py`: `merge_attention_outputs_flashinfer()`
|
||||
- `nanovllm/kvcache/sparse/full_policy.py`: 3 处 import 更新
|
||||
- `nanovllm/kvcache/sparse/xattn_bsa.py`: 1 处 import 更新
|
||||
|
||||
## 更新记录
|
||||
|
||||
- 2026-01-28: **FlashInfer merge 替换 Triton merge**,端到端提升 0.8% ~ 2.4%
|
||||
- 2026-01-28: **estimate_block_size 优化后重新测试**,128K XAttention 反超 Full (+2.4%)
|
||||
- 2026-01-27: 添加 GPU-only vs Offload 对比,block size 影响分析
|
||||
- 2026-01-27: 初始测试,Llama-3.1-8B-Instruct, A100 80GB
|
||||
|
||||
Reference in New Issue
Block a user