Compare commits: 2 commits (4484ebbb77, 8d19e61446; range 2c2383c786...8d19e61446)
@@ -34,6 +34,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
| [`docs/observer_architecture.md`](docs/observer_architecture.md) | 📊 Observer architecture: InferenceObserver (TTFT/TPOT) and MemoryObserver (H2D/D2H/D2D) design |
| [`docs/memory_communication_benchmark.md`](docs/memory_communication_benchmark.md) | 📊 Communication benchmark: Full vs XAttention transfer volume (32K/64K), per-phase statistics |
| [`docs/estimate_block_size_performance.md`](docs/estimate_block_size_performance.md) | 🔥 PERF: block_size analysis for the estimate phase; softmax_fuse_block_sum optimum at 512-1024, current 4096 is 15x slower |
| [`docs/long_context_models_1m.md`](docs/long_context_models_1m.md) | 📚 REF: models with 1M+ context length (Qwen/GLM/InternLM/Llama/VL), recommended ≤10B models |

## Rules Index
@@ -150,8 +150,50 @@ CUDA_VISIBLE_DEVICES=0 python bench_offload.py --max-len 65536 --block-size 256
CUDA_VISIBLE_DEVICES=0 python bench_offload.py --enable-xattn --xattn-threshold 0.8 --xattn-stride 16
```
## FlashInfer Merge Optimization (2026-01-28)

Replaced the Triton implementation of `merge_attention_outputs` with FlashInfer's `cascade.merge_state`.

### Performance Comparison (Full Attention, block-size 4096)

| Context | Triton merge | FlashInfer merge | Gain |
|--------|--------------|------------------|------|
| 32K | 4678 tok/s | 4717 tok/s | **+0.8%** |
| 64K | 3331 tok/s | 3411 tok/s | **+2.4%** |
| 128K | 2144 tok/s | 2178 tok/s | **+1.6%** |
### Key Findings

1. **Limited end-to-end gain** (0.8% to 2.4%): the merge operation is not the main bottleneck
   - H2D transfer dominates (64 GB transferred at 64K context)
   - Attention compute is the other major cost
   - Merge is a small fraction of total step time
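The ceiling implied by the first finding follows directly from Amdahl's law; a one-function sketch (the 4% merge fraction below is an illustrative assumption, not a measured value):

```python
# Amdahl's law: if merge takes fraction f of a step and its kernel becomes
# s times faster, the end-to-end speedup is capped at 1 / ((1 - f) + f / s).
def end_to_end_speedup(f: float, s: float) -> float:
    return 1.0 / ((1.0 - f) + f / s)

# Hypothetical: merge is ~4% of step time, kernel runs 1.8x faster.
print(round(end_to_end_speedup(0.04, 1.8), 3))  # 1.018, i.e. about +1.8%
```

A 1.8x kernel speedup on a ~4% slice lands in the observed +0.8% to +2.4% band, which is why the table above looks flat despite the kernel-level wins below.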
2. **Merge kernel in isolation** (FlashInfer's advantage grows with sequence length):

| seq_len | heads | Triton (ms) | FlashInfer (ms) | Speedup |
|---------|-------|-------------|-----------------|---------|
| 4096 | 32 | 0.129 | 0.087 | **1.49x** |
| 8192 | 32 | 0.251 | 0.147 | **1.70x** |
| 16384 | 32 | 0.499 | 0.274 | **1.82x** |
3. **FlashInfer is slower on short sequences**: format-conversion overhead (squeeze, transpose, contiguous)
### Technical Details

- **LSE format difference**: FlashInfer stores log-sum-exp in log2, flash_attn in ln
- **Conversion factors**: `LOG2_E = 1.4427` (ln → log2), `LN_2 = 0.6931` (log2 → ln)
- **FlashInfer attention JIT issue**: CUDA version compatibility problems, so only merge_state is used
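The merge rule and the ln/log2 conversion can be sanity-checked in plain Python (a minimal sketch; `merge_state` applies the same rescaling to GPU tensors):

```python
import math

LOG2_E = math.log2(math.e)   # ln -> log2
LN_2 = math.log(2.0)         # log2 -> ln

def merge_partial(o1, lse1, o2, lse2):
    """Merge two partial attention results via their log-sum-exp weights."""
    m = max(lse1, lse2)                  # subtract max for numerical stability
    w1, w2 = math.exp(lse1 - m), math.exp(lse2 - m)
    o = (o1 * w1 + o2 * w2) / (w1 + w2)
    lse = m + math.log(w1 + w2)
    return o, lse

def chunk_attn(scores, values):
    """Softmax-weighted average over one chunk, plus its LSE (ln format)."""
    z = sum(math.exp(s) for s in scores)
    o = sum(v * math.exp(s) for s, v in zip(scores, values)) / z
    return o, math.log(z)

scores = [1.0, 2.0, 3.0, 4.0]
values = [10.0, 20.0, 30.0, 40.0]

# Full softmax over all four scores at once
Z = sum(math.exp(s) for s in scores)
o_full = sum(v * math.exp(s) for s, v in zip(scores, values)) / Z

# Same result from two chunks merged by LSE
o1, lse1 = chunk_attn(scores[:2], values[:2])
o2, lse2 = chunk_attn(scores[2:], values[2:])
o_merged, lse_merged = merge_partial(o1, lse1, o2, lse2)

assert abs(o_merged - o_full) < 1e-9
# FlashInfer expects LSE in log2; multiplying by LOG2_E converts ln -> log2
assert abs(lse_merged * LOG2_E - math.log2(Z)) < 1e-9
```

Passing an ln-format LSE to a log2-format kernel without this conversion silently skews the merge weights, which is why the constants above appear in the implementation.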
### Code Locations

- `nanovllm/ops/chunked_attention.py`: `merge_attention_outputs_flashinfer()`
- `nanovllm/kvcache/sparse/full_policy.py`: 3 import sites updated
- `nanovllm/kvcache/sparse/xattn_bsa.py`: 1 import site updated
## Changelog

- 2026-01-28: **Replaced Triton merge with FlashInfer merge**, 0.8% to 2.4% end-to-end gain
- 2026-01-28: **Re-tested after estimate_block_size optimization**; at 128K, XAttention now edges out Full (+2.4%)
- 2026-01-27: Added GPU-only vs Offload comparison and block-size impact analysis
- 2026-01-27: Initial tests: Llama-3.1-8B-Instruct, A100 80GB
docs/long_context_models_1m.md (new file, 184 lines)
@@ -0,0 +1,184 @@
# Models with 1M+ Context Length

This document collects open-source models on Hugging Face that support context lengths of 1M (1,048,576) tokens or more.

> Last updated: 2026-01-28

---

## I. Language-Only Models (≤10B parameters)

### 1. Official Models

| Vendor | Model | Context | Size | Downloads | Link |
|------|------|--------|------|--------|------|
| **Qwen** | Qwen2.5-7B-Instruct-1M | 1M | 7B | 69.3K | [HF](https://hf.co/Qwen/Qwen2.5-7B-Instruct-1M) |
| **THUDM** | GLM-4-9B-Chat-1M | 1M | 9B | 5.0K | [HF](https://hf.co/zai-org/glm-4-9b-chat-1m) |
| **InternLM** | InternLM2.5-7B-Chat-1M | 1M | 7B | 322 | [HF](https://hf.co/internlm/internlm2_5-7b-chat-1m) |
| **NVIDIA** | Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct | 1M | 8B | 2.9K | [HF](https://hf.co/nvidia/Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct) |
| **LWM** | LWM-Text-1M | 1M | 7B | 75 | [HF](https://hf.co/LargeWorldModel/LWM-Text-1M) |
| **LWM** | LWM-Text-Chat-1M | 1M | 7B | 3.0K | [HF](https://hf.co/LargeWorldModel/LWM-Text-Chat-1M) |
### 2. Gradient AI Extended Series (based on Llama 3)

| Model | Context | Size | Downloads | Link |
|------|--------|------|--------|------|
| Llama-3-8B-Instruct-Gradient-1048k | **1M** | 8B | 44.8K | [HF](https://hf.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k) |
| Llama-3-8B-Instruct-Gradient-4194k | **4M** | 8B | 9 | [HF](https://hf.co/gradientai/Llama-3-8B-Instruct-Gradient-4194k) |
### 3. Community Derivatives (Abliterated)

| Model | Context | Base Model | Downloads | Link |
|------|--------|----------|--------|------|
| Qwen2.5-7B-Instruct-1M-abliterated | 1M | Qwen2.5-7B | 375 | [HF](https://hf.co/huihui-ai/Qwen2.5-7B-Instruct-1M-abliterated) |
| Nemotron-8B-UltraLong-1M-Abliterated | 1M | Nemotron-8B | 46 | [HF](https://hf.co/SicariusSicariiStuff/Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct_Abliterated) |

---
## II. Vision-Language Models (≤10B parameters)

### Qwen3 VL Series

#### Instruct Versions

| Model | Context | Size | Downloads | Link |
|------|--------|------|--------|------|
| Qwen3-VL-2B-Instruct-1M-GGUF | 1M | 2B | 824 | [HF](https://hf.co/unsloth/Qwen3-VL-2B-Instruct-1M-GGUF) |
| Qwen3-VL-4B-Instruct-1M-GGUF | 1M | 4B | 936 | [HF](https://hf.co/unsloth/Qwen3-VL-4B-Instruct-1M-GGUF) |
| Qwen3-VL-8B-Instruct-1M-GGUF | 1M | 8B | 962 | [HF](https://hf.co/unsloth/Qwen3-VL-8B-Instruct-1M-GGUF) |
#### Thinking (Reasoning) Versions

| Model | Context | Size | Downloads | Link |
|------|--------|------|--------|------|
| Qwen3-VL-2B-Thinking-1M-GGUF | 1M | 2B | 808 | [HF](https://hf.co/unsloth/Qwen3-VL-2B-Thinking-1M-GGUF) |
| Qwen3-VL-4B-Thinking-1M-GGUF | 1M | 4B | 666 | [HF](https://hf.co/unsloth/Qwen3-VL-4B-Thinking-1M-GGUF) |
| Qwen3-VL-8B-Thinking-1M-GGUF | 1M | 8B | 4.6K | [HF](https://hf.co/unsloth/Qwen3-VL-8B-Thinking-1M-GGUF) |

---
## III. Recommended Models (≤10B)

| Use Case | Recommended Model | Rationale |
|------|----------|------|
| **General chat** | Qwen2.5-7B-Instruct-1M | Official support, RULER score 93.1, Apache 2.0 |
| **Chinese-English bilingual** | GLM-4-9B-Chat-1M | From THUDM (Tsinghua), optimized for Chinese |
| **Longest context** | Llama-3-8B-Gradient-4194k | Supports 4M context |
| **Multimodal** | Qwen3-VL-8B-Thinking-1M | Vision understanding plus reasoning |
| **Uncensored** | Qwen2.5-7B-Instruct-1M-abliterated | Safety restrictions removed |

---
## IV. VRAM Requirements Reference

| Model Size | VRAM at 1M Context | Notes |
|----------|----------------|------|
| 7B (FP16) | ~120 GB | Multi-GPU required |
| 7B (INT4) | ~40 GB | Feasible on a single A100 |
| 8B (FP16) | ~130 GB | Multi-GPU required |
| 9B (FP16) | ~140 GB | Multi-GPU required |
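At these context lengths the KV cache, which grows linearly with tokens, dominates the FP16 budget. A rough estimator (the default config values are illustrative, Llama-like: 32 layers, 8 KV heads via GQA, head_dim 128):

```python
def kv_cache_gib(tokens: int, layers: int = 32, kv_heads: int = 8,
                 head_dim: int = 128, bytes_per_value: int = 2) -> float:
    """Rough KV-cache size in GiB: K and V each store
    layers * kv_heads * head_dim values per token."""
    return 2 * layers * kv_heads * head_dim * bytes_per_value * tokens / 2**30

print(kv_cache_gib(1 << 20))  # 128.0 GiB for the KV cache alone at 1M tokens
```

Adding FP16 weights (~14-16 GB for a 7B-8B model) gives totals of the same order as the figures above; exact numbers depend on the architecture (GQA ratio, head_dim) and activation overhead.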
---

## V. Technology Comparison

| Model Series | Extension Technique | RULER Score | License |
|---------|---------|------------|--------|
| Qwen2.5-1M | Dual Chunk Attention | 93.1 | Apache 2.0 |
| GLM-4-1M | - | 89.9 | Custom |
| Gradient-Llama | Progressive context extension | - | Llama 3 |
| Nemotron-1M | NVIDIA-trained | - | CC-BY-NC-4.0 |
| LWM-1M | RingAttention | - | Open source |
---

# Appendix: Large Models (>10B)

> The models below exceed 10B parameters and require more compute resources.
## A. Language-Only Models (>10B)

### Official Models

| Vendor | Model | Context | Size | Downloads | Link |
|------|------|--------|------|--------|------|
| **Qwen** | Qwen2.5-14B-Instruct-1M | 1M | 14B | 4.7K | [HF](https://hf.co/Qwen/Qwen2.5-14B-Instruct-1M) |
| **MiniMax** | MiniMax-Text-01 | 1M | 456B MoE | 721 | [HF](https://hf.co/MiniMaxAI/MiniMax-Text-01) |
| **Gradient** | Llama-3-70B-Instruct-Gradient-1048k | 1M | 70B | 9 | [HF](https://hf.co/gradientai/Llama-3-70B-Instruct-Gradient-1048k) |
### Qwen3 Coder Series (MoE)

| Model | Context | Total / Active Params | Downloads | Link |
|------|--------|-----------------|--------|------|
| Qwen3-Coder-30B-A3B-Instruct-1M-GGUF | 1M | 30B / 3B | 13.1K | [HF](https://hf.co/unsloth/Qwen3-Coder-30B-A3B-Instruct-1M-GGUF) |
| Qwen3-Coder-480B-A35B-Instruct-1M | 1M | 480B / 35B | 50 | [HF](https://hf.co/unsloth/Qwen3-Coder-480B-A35B-Instruct-1M) |
| Qwen3-Coder-480B-A35B-Instruct-1M-GGUF | 1M | 480B / 35B | 1.7K | [HF](https://hf.co/unsloth/Qwen3-Coder-480B-A35B-Instruct-1M-GGUF) |
| Qwen3-Coder-42B-A3B-TOTAL-RECALL-1M | 1M | 42B / 3B | - | [HF](https://hf.co/DavidAU/Qwen3-Coder-42B-A3B-Instruct-TOTAL-RECALL-MASTER-CODER-M-1million-ctx) |
### Community Derivatives

| Model | Context | Size | Downloads | Link |
|------|--------|------|--------|------|
| Qwen2.5-14B-Instruct-1M-abliterated | 1M | 14B | 147 | [HF](https://hf.co/huihui-ai/Qwen2.5-14B-Instruct-1M-abliterated) |

---
## B. Vision-Language Models (>10B)

### Meta Llama 4 Series (MoE, Multimodal)

| Model | Context | Total / Active Params | Downloads | Link |
|------|--------|-----------------|--------|------|
| Llama-4-Scout-17B-16E-Instruct | **10M** | 109B / 17B | 180K | [HF](https://hf.co/meta-llama/Llama-4-Scout-17B-16E-Instruct) |
| Llama-4-Maverick-17B-128E-Instruct | **1M** | 400B / 17B | 32.6K | [HF](https://hf.co/meta-llama/Llama-4-Maverick-17B-128E-Instruct) |
| Llama-4-Scout-17B-16E | 10M | 109B / 17B | 8.4K | [HF](https://hf.co/meta-llama/Llama-4-Scout-17B-16E) |
| Llama-4-Maverick-17B-128E | 1M | 400B / 17B | 368 | [HF](https://hf.co/meta-llama/Llama-4-Maverick-17B-128E) |
| Llama-4-Maverick-17B-128E-Instruct-FP8 | 1M | 400B / 17B | 29.6K | [HF](https://hf.co/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8) |
### Qwen3 VL Large Models

#### Dense Models

| Model | Context | Size | Downloads | Link |
|------|--------|------|--------|------|
| Qwen3-VL-32B-Instruct-1M-GGUF | 1M | 32B | 1.2K | [HF](https://hf.co/unsloth/Qwen3-VL-32B-Instruct-1M-GGUF) |
| Qwen3-VL-32B-Thinking-1M-GGUF | 1M | 32B | 452 | [HF](https://hf.co/unsloth/Qwen3-VL-32B-Thinking-1M-GGUF) |
#### MoE Models

| Model | Context | Total / Active Params | Downloads | Link |
|------|--------|-----------------|--------|------|
| Qwen3-VL-30B-A3B-Instruct-1M-GGUF | 1M | 30B / 3B | 821 | [HF](https://hf.co/unsloth/Qwen3-VL-30B-A3B-Instruct-1M-GGUF) |
| Qwen3-VL-30B-A3B-Thinking-1M-GGUF | 1M | 30B / 3B | 944 | [HF](https://hf.co/unsloth/Qwen3-VL-30B-A3B-Thinking-1M-GGUF) |
| Qwen3-VL-235B-A22B-Instruct-1M-GGUF | 1M | 235B / 22B | 581 | [HF](https://hf.co/unsloth/Qwen3-VL-235B-A22B-Instruct-1M-GGUF) |
| Qwen3-VL-235B-A22B-Thinking-1M-GGUF | 1M | 235B / 22B | 733 | [HF](https://hf.co/unsloth/Qwen3-VL-235B-A22B-Thinking-1M-GGUF) |
#### MXFP4 Quantized Versions

| Model | Context | Size | Downloads | Link |
|------|--------|------|--------|------|
| Qwen3-VL-30B-A3B-Instruct-1M-MXFP4_MOE-GGUF | 1M | 30B MoE | 689 | [HF](https://hf.co/noctrex/Qwen3-VL-30B-A3B-Instruct-1M-MXFP4_MOE-GGUF) |
| Qwen3-VL-30B-A3B-Thinking-1M-MXFP4_MOE-GGUF | 1M | 30B MoE | 565 | [HF](https://hf.co/noctrex/Qwen3-VL-30B-A3B-Thinking-1M-MXFP4_MOE-GGUF) |
| Qwen3-VL-235B-A22B-Instruct-1M-MXFP4_MOE-GGUF | 1M | 235B MoE | 136 | [HF](https://hf.co/noctrex/Qwen3-VL-235B-A22B-Instruct-1M-MXFP4_MOE-GGUF) |
| Qwen3-VL-235B-A22B-Thinking-1M-MXFP4_MOE-GGUF | 1M | 235B MoE | 244 | [HF](https://hf.co/noctrex/Qwen3-VL-235B-A22B-Thinking-1M-MXFP4_MOE-GGUF) |

---
## Summary Statistics

| Category | ≤10B Models | >10B Models | Max Context |
|------|-------------|-------------|-----------|
| Language-only | 10 | 8 | 4M |
| Vision-language | 6 | 14 | 10M |
| **Total** | **16** | **22** | **10M** |

---
## References

- [Qwen2.5-1M official blog](https://qwenlm.github.io/blog/qwen2.5-1m/)
- [LongRoPE paper](https://huggingface.co/papers/2402.13753)
- [InfiniteHiP paper](https://huggingface.co/papers/2502.08910)
- [Top LLMs for Long Context Windows](https://www.siliconflow.com/articles/en/top-LLMs-for-long-context-windows)
@@ -185,7 +185,11 @@ class FullAttentionPolicy(SparsePolicy):
         Returns:
             Attention output [seq_len, num_heads, head_dim]
         """
-        from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+        # Use FlashInfer-based implementations (more optimized)
+        from nanovllm.ops.chunked_attention import (
+            flash_attn_with_lse_flashinfer as flash_attn_with_lse,
+            merge_attention_outputs_flashinfer as merge_attention_outputs,
+        )
 
         logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
                      f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}, "
@@ -313,7 +317,11 @@ class FullAttentionPolicy(SparsePolicy):
         Returns:
             Attention output [batch_size, 1, num_heads, head_dim]
         """
-        from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+        # Use FlashInfer-based implementations (more optimized)
+        from nanovllm.ops.chunked_attention import (
+            flash_attn_with_lse_flashinfer as flash_attn_with_lse,
+            merge_attention_outputs_flashinfer as merge_attention_outputs,
+        )
 
         # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
         q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]
@@ -405,7 +413,11 @@ class FullAttentionPolicy(SparsePolicy):
         Loads one block at a time, computes attention, and merges results.
         Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods.
         """
-        from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+        # Use FlashInfer-based implementations (more optimized)
+        from nanovllm.ops.chunked_attention import (
+            flash_attn_with_lse_flashinfer as flash_attn_with_lse,
+            merge_attention_outputs_flashinfer as merge_attention_outputs,
+        )
 
         num_blocks = len(cpu_block_table)
         if num_blocks == 0:
@@ -652,7 +652,11 @@ class XAttentionBSAPolicy(SparsePolicy):
         Returns:
             Attention output [seq_len, num_heads, head_dim]
         """
-        from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+        # Use FlashInfer-based implementations (more optimized)
+        from nanovllm.ops.chunked_attention import (
+            flash_attn_with_lse_flashinfer as flash_attn_with_lse,
+            merge_attention_outputs_flashinfer as merge_attention_outputs,
+        )
 
         q_batched = q.unsqueeze(0)  # [1, seq_len, num_heads, head_dim]
         o_acc = None
@@ -414,6 +414,90 @@ def merge_attention_outputs(
    return o_merged, lse_merged


# ============================================================
# FlashInfer-based implementations (recommended for merge only)
# ============================================================

# LSE conversion constants: FlashInfer uses log2, flash_attn uses ln
_LOG2_E = 1.4426950408889634  # math.log2(math.e) - ln -> log2
_LN_2 = 0.6931471805599453    # math.log(2) - log2 -> ln

# Check FlashInfer availability (only for merge_state, not attention kernel)
try:
    from flashinfer.cascade import merge_state, merge_state_in_place
    FLASHINFER_MERGE_AVAILABLE = True
except ImportError:
    FLASHINFER_MERGE_AVAILABLE = False


def flash_attn_with_lse_flashinfer(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flash attention that returns output and LSE.

    Uses flash_attn library (FlashInfer attention has JIT compatibility issues).

    Args:
        q: Query tensor [batch, seqlen_q, nheads_q, headdim]
        k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
        v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
        softmax_scale: Scaling factor (default: 1/sqrt(headdim))
        causal: Whether to apply causal masking

    Returns:
        out: Output tensor [batch, seqlen_q, nheads_q, headdim]
        lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q] (ln format)
    """
    # Use flash_attn directly (FlashInfer attention JIT has CUDA version issues)
    return flash_attn_with_lse(q, k, v, softmax_scale, causal)


def merge_attention_outputs_flashinfer(
    o1: torch.Tensor,
    lse1: torch.Tensor,
    o2: torch.Tensor,
    lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge two attention outputs using FlashInfer's optimized kernel.

    Args:
        o1: First output [batch, seqlen_q, nheads, headdim]
        lse1: First LSE [batch, nheads, seqlen_q] (ln format)
        o2: Second output [batch, seqlen_q, nheads, headdim]
        lse2: Second LSE [batch, nheads, seqlen_q] (ln format)

    Returns:
        o_merged: Merged output [batch, seqlen_q, nheads, headdim]
        lse_merged: Merged LSE [batch, nheads, seqlen_q] (ln format)
    """
    if not FLASHINFER_MERGE_AVAILABLE:
        # Fallback to Triton implementation
        return merge_attention_outputs(o1, lse1, o2, lse2)

    # Convert to FlashInfer format
    # o: [batch, seq, heads, dim] -> [seq, heads, dim]
    # lse: [batch, heads, seq] -> [seq, heads] (convert ln -> log2)
    v_a = o1.squeeze(0).contiguous()
    s_a = lse1.squeeze(0).transpose(0, 1).contiguous().float() * _LOG2_E
    v_b = o2.squeeze(0).contiguous()
    s_b = lse2.squeeze(0).transpose(0, 1).contiguous().float() * _LOG2_E

    # FlashInfer merge
    v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)

    # Convert back to flash_attn format
    o_merged = v_merged.unsqueeze(0)  # [1, seq, heads, dim]
    lse_merged = (s_merged * _LN_2).transpose(0, 1).unsqueeze(0)  # [1, heads, seq]

    return o_merged, lse_merged


def chunked_attention_varlen(
    q: torch.Tensor,
    kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],