Compare commits
14 Commits
cf168fd9b9
...
tzj/layer-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5fb0f67295 | ||
|
|
69b779e252 | ||
|
|
e313dd795a | ||
|
|
9f3ee9279e | ||
|
|
2826a649de | ||
|
|
24baeb6d5a | ||
|
|
57f4e9c6e6 | ||
|
|
ac1ccbceaa | ||
|
|
029894118d | ||
|
|
8d6fde3b23 | ||
|
|
6a6bd75685 | ||
|
|
86633004ca | ||
|
|
c51a640a29 | ||
|
|
dce6ad6b74 |
158
.claude/commands/exec-plan.md
Normal file
158
.claude/commands/exec-plan.md
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
---
|
||||||
|
allowed-tools: Bash(CUDA_VISIBLE_DEVICES=*), Bash(PYTHONPATH=*), Bash(python*), Bash(git*), Bash(rm*), Bash(ls*), Bash(cat*), Bash(nvidia-smi*), Read, Edit, Write, Glob, Grep, TodoWrite, Task
|
||||||
|
argument-hint: --gpu <id> [--no-interrupt]
|
||||||
|
description: Execute task_plan.md refactoring with specified GPU, optionally without user interruption
|
||||||
|
---
|
||||||
|
|
||||||
|
# Execute Task Plan (exec-plan)
|
||||||
|
|
||||||
|
按照 `task_plan.md` 的要求执行代码重构,确保计划中的最终目标圆满实现。
|
||||||
|
|
||||||
|
## 参数说明
|
||||||
|
|
||||||
|
命令格式: `/exec-plan --gpu <id> [--no-interrupt]`
|
||||||
|
|
||||||
|
| 参数 | 说明 | 示例 |
|
||||||
|
|------|------|------|
|
||||||
|
| `--gpu <id>` | **必需**。指定可用的 GPU ID,只能使用此 GPU 进行调试 | `--gpu 0`, `--gpu 2` |
|
||||||
|
| `--no-interrupt` | 可选。禁止中断执行,遇到问题不与用户交互,自动解决或跳过 | `--no-interrupt` |
|
||||||
|
|
||||||
|
## 当前参数
|
||||||
|
|
||||||
|
```
|
||||||
|
$ARGUMENTS
|
||||||
|
```
|
||||||
|
|
||||||
|
## 执行前准备
|
||||||
|
|
||||||
|
### 1. 解析参数
|
||||||
|
|
||||||
|
从 `$ARGUMENTS` 中解析:
|
||||||
|
- `GPU_ID`: 从 `--gpu <id>` 或 `-g <id>` 提取
|
||||||
|
- `NO_INTERRUPT`: 是否存在 `--no-interrupt` 或 `-n` 标志
|
||||||
|
|
||||||
|
### 2. 参数验证
|
||||||
|
|
||||||
|
**必须验证**:
|
||||||
|
- GPU_ID 必须是有效的数字
|
||||||
|
- 运行 `nvidia-smi -i <GPU_ID>` 验证 GPU 存在
|
||||||
|
|
||||||
|
### 3. 读取 task_plan.md
|
||||||
|
|
||||||
|
读取项目根目录下的 `task_plan.md` 文件,理解:
|
||||||
|
- 总体目标
|
||||||
|
- 分阶段计划 (Phase 1, 2, 3...)
|
||||||
|
- 文件修改清单
|
||||||
|
- 风险和注意事项
|
||||||
|
- 测试计划
|
||||||
|
|
||||||
|
## 执行流程
|
||||||
|
|
||||||
|
### Step 1: 创建执行计划
|
||||||
|
|
||||||
|
使用 TodoWrite 工具创建详细的执行计划,包括:
|
||||||
|
- 从 task_plan.md 提取的所有 Phase
|
||||||
|
- 每个 Phase 的子任务
|
||||||
|
- 测试验证步骤
|
||||||
|
|
||||||
|
### Step 2: 按 Phase 执行重构
|
||||||
|
|
||||||
|
对于 task_plan.md 中的每个 Phase:
|
||||||
|
|
||||||
|
1. **读取当前代码**: 使用 Read/Grep 理解现有实现
|
||||||
|
2. **实施修改**: 使用 Edit/Write 进行代码修改
|
||||||
|
3. **验证修改**: 运行相关测试
|
||||||
|
|
||||||
|
### Step 3: 运行测试验证
|
||||||
|
|
||||||
|
执行 task_plan.md 中定义的测试计划,验证重构成功。
|
||||||
|
|
||||||
|
## GPU 限制规则
|
||||||
|
|
||||||
|
**严格限制**: 只能使用指定的 GPU,所有涉及 GPU 的命令必须加 `CUDA_VISIBLE_DEVICES` 前缀:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 正确
|
||||||
|
CUDA_VISIBLE_DEVICES=$GPU_ID PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python test.py
|
||||||
|
|
||||||
|
# 错误 - 禁止使用其他 GPU
|
||||||
|
python test.py # 可能使用默认 GPU 0
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 python test.py # 使用多个 GPU
|
||||||
|
```
|
||||||
|
|
||||||
|
## 中断模式规则
|
||||||
|
|
||||||
|
### 当 `--no-interrupt` 生效时
|
||||||
|
|
||||||
|
遇到以下情况**不停下来询问用户**,而是:
|
||||||
|
|
||||||
|
| 情况 | 处理方式 |
|
||||||
|
|------|----------|
|
||||||
|
| 测试失败 | 记录失败原因,尝试自动修复,继续下一步 |
|
||||||
|
| 代码冲突 | 尝试合理解决,记录解决方案 |
|
||||||
|
| 不确定的实现细节 | 选择最合理的方案继续 |
|
||||||
|
| 执行错误 | 分析错误,尝试修复,记录问题 |
|
||||||
|
|
||||||
|
**自动决策原则**:
|
||||||
|
1. 优先保证功能正确性
|
||||||
|
2. 遵循现有代码风格
|
||||||
|
3. 选择简单直接的实现
|
||||||
|
4. 记录所有自动决策到 `progress.md`
|
||||||
|
|
||||||
|
### 当未指定 `--no-interrupt` 时
|
||||||
|
|
||||||
|
遇到以下情况**可以询问用户**:
|
||||||
|
- 多个实现方案需要选择
|
||||||
|
- 测试持续失败无法自动修复
|
||||||
|
- 发现 task_plan.md 中的问题或矛盾
|
||||||
|
|
||||||
|
## 执行记录
|
||||||
|
|
||||||
|
### 进度文件: progress.md
|
||||||
|
|
||||||
|
实时更新 `progress.md` 记录:
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## 执行进度
|
||||||
|
|
||||||
|
### Phase X: [名称]
|
||||||
|
- 状态: [进行中/完成/失败]
|
||||||
|
- 开始时间: [时间]
|
||||||
|
- 完成时间: [时间]
|
||||||
|
- 修改文件: [文件列表]
|
||||||
|
- 自动决策: [如果有]
|
||||||
|
- 问题记录: [如果有]
|
||||||
|
```
|
||||||
|
|
||||||
|
### 发现记录: findings.md
|
||||||
|
|
||||||
|
记录执行过程中的重要发现到 `findings.md`。
|
||||||
|
|
||||||
|
## 示例用法
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 使用 GPU 2,允许中断
|
||||||
|
/exec-plan --gpu 2
|
||||||
|
|
||||||
|
# 使用 GPU 0,不中断执行
|
||||||
|
/exec-plan --gpu 0 --no-interrupt
|
||||||
|
|
||||||
|
# 简短形式
|
||||||
|
/exec-plan -g 1 -n
|
||||||
|
```
|
||||||
|
|
||||||
|
## 完成标准
|
||||||
|
|
||||||
|
执行完成后,确保:
|
||||||
|
|
||||||
|
1. **所有 Phase 完成**: task_plan.md 中的所有 Phase 都已实施
|
||||||
|
2. **测试通过**: task_plan.md 中的测试计划全部通过
|
||||||
|
3. **代码质量**: 修改符合项目代码规范
|
||||||
|
4. **文档更新**: progress.md 包含完整执行记录
|
||||||
|
|
||||||
|
## 重要约束
|
||||||
|
|
||||||
|
1. **GPU 隔离**: 绝对不能使用指定 GPU 以外的设备
|
||||||
|
2. **遵循计划**: 严格按照 task_plan.md 执行,不做计划外的修改
|
||||||
|
3. **渐进式修改**: 每个 Phase 完成后验证,而不是最后一起验证
|
||||||
|
4. **回滚准备**: 重大修改前考虑是否需要 git commit 保存点
|
||||||
50
.claude/rules/planning-with-files.md
Normal file
50
.claude/rules/planning-with-files.md
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# Planning with Files Rule
|
||||||
|
|
||||||
|
## 自动清理旧计划文件
|
||||||
|
|
||||||
|
**重要**:每次开始新的复杂任务使用 planning-with-files 时,先删除旧的计划文件。
|
||||||
|
|
||||||
|
### 使用前执行以下命令
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 在项目根目录执行,删除旧的计划文件
|
||||||
|
cd /home/zijie/Code/nano-vllm
|
||||||
|
rm -f task_plan.md findings.md progress.md
|
||||||
|
rm -f task_plan_*.md findings_*.md progress_*.md
|
||||||
|
```
|
||||||
|
|
||||||
|
### 为什么需要这个规则
|
||||||
|
|
||||||
|
1. **避免混淆**:不同任务有不同计划,旧的计划文件会干扰新任务
|
||||||
|
2. **保持简洁**:只保留当前任务的计划文件
|
||||||
|
3. **自动清理**:无需手动检查文件内容,直接删除即可
|
||||||
|
|
||||||
|
### 使用 planning-with-files 的完整流程
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Step 1: 清理旧计划文件
|
||||||
|
rm -f task_plan.md findings.md progress.md task_plan_*.md findings_*.md progress_*.md
|
||||||
|
|
||||||
|
# Step 2: 启动 planning-with-files 技能
|
||||||
|
# 在 Claude 中调用 /planning-with-files 或 Skill tool
|
||||||
|
|
||||||
|
# Step 3: 技能会自动创建新的计划文件
|
||||||
|
# - task_plan.md (或 task_plan_<任务名>.md)
|
||||||
|
# - findings.md (或 findings_<任务名>.md)
|
||||||
|
# - progress.md (或 progress_<任务名>.md)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 文件命名建议
|
||||||
|
|
||||||
|
| 场景 | 文件命名 | 示例 |
|
||||||
|
|------|----------|------|
|
||||||
|
| 通用任务 | task_plan.md, findings.md, progress.md | 临时调试任务 |
|
||||||
|
| 特定功能 | task_plan_<feature>.md | task_plan_xattn.md |
|
||||||
|
| Bug 修复 | task_plan_bug_<name>.md | task_plan_bug_offload.md |
|
||||||
|
|
||||||
|
### 注意事项
|
||||||
|
|
||||||
|
- 计划文件存储在**项目根目录**,不是技能目录
|
||||||
|
- 技能目录:`/home/zijie/.claude/plugins/cache/planning-with-files/...`
|
||||||
|
- 项目目录:`/home/zijie/Code/nano-vllm/`
|
||||||
|
- 每个任务完成后,可以选择保留或删除计划文件
|
||||||
70
.claude/settings.json
Normal file
70
.claude/settings.json
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
{
|
||||||
|
"hooks": {
|
||||||
|
"SessionStart": [
|
||||||
|
{
|
||||||
|
"hooks": [
|
||||||
|
{
|
||||||
|
"type": "command",
|
||||||
|
"command": "npx @claude-flow/cli@latest daemon start --quiet 2>/dev/null || true",
|
||||||
|
"timeout": 5000,
|
||||||
|
"continueOnError": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "command",
|
||||||
|
"command": "[ -n \"$SESSION_ID\" ] && npx @claude-flow/cli@latest hooks session-restore --session-id \"$SESSION_ID\" 2>/dev/null || true",
|
||||||
|
"timeout": 10000,
|
||||||
|
"continueOnError": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"Stop": [
|
||||||
|
{
|
||||||
|
"hooks": [
|
||||||
|
{
|
||||||
|
"type": "command",
|
||||||
|
"command": "echo '{\"ok\": true}'",
|
||||||
|
"timeout": 1000
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"PermissionRequest": [
|
||||||
|
{
|
||||||
|
"matcher": "^mcp__claude-flow__.*$",
|
||||||
|
"hooks": [
|
||||||
|
{
|
||||||
|
"type": "command",
|
||||||
|
"command": "echo '{\"decision\": \"allow\", \"reason\": \"claude-flow MCP tool auto-approved\"}'",
|
||||||
|
"timeout": 1000
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"matcher": "^Bash\\(npx @?claude-flow.*\\)$",
|
||||||
|
"hooks": [
|
||||||
|
{
|
||||||
|
"type": "command",
|
||||||
|
"command": "echo '{\"decision\": \"allow\", \"reason\": \"claude-flow CLI auto-approved\"}'",
|
||||||
|
"timeout": 1000
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"permissions": {
|
||||||
|
"allow": [
|
||||||
|
"Bash(npx claude-flow*)",
|
||||||
|
"Bash(npx @claude-flow/*)",
|
||||||
|
"mcp__claude-flow__*"
|
||||||
|
],
|
||||||
|
"deny": []
|
||||||
|
},
|
||||||
|
"claudeFlow": {
|
||||||
|
"version": "3.0.0",
|
||||||
|
"enabled": true,
|
||||||
|
"daemon": {
|
||||||
|
"autoStart": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
4
.gitmodules
vendored
Normal file
4
.gitmodules
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
[submodule "3rdparty/Block-Sparse-Attention"]
|
||||||
|
path = 3rdparty/Block-Sparse-Attention
|
||||||
|
url = git@github.com:Zijie-Tian/Block-Sparse-Attention.git
|
||||||
|
branch = tzj/minference
|
||||||
1
3rdparty/Block-Sparse-Attention
vendored
Submodule
1
3rdparty/Block-Sparse-Attention
vendored
Submodule
Submodule 3rdparty/Block-Sparse-Attention added at 6ec5a27a0c
@@ -53,12 +53,17 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
|
|||||||
| [`docs/multi_model_support.md`](docs/multi_model_support.md) | Model registry system, adding new models (Qwen3/Llama), architecture differences, RoPE scaling |
|
| [`docs/multi_model_support.md`](docs/multi_model_support.md) | Model registry system, adding new models (Qwen3/Llama), architecture differences, RoPE scaling |
|
||||||
| [`docs/cuda_graph_offload_guide.md`](docs/cuda_graph_offload_guide.md) | CUDA graph support for CPU offload decode path, 4x decode speedup |
|
| [`docs/cuda_graph_offload_guide.md`](docs/cuda_graph_offload_guide.md) | CUDA graph support for CPU offload decode path, 4x decode speedup |
|
||||||
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (MInference, FlexPrefill, XAttention, Quest), computation flow |
|
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (MInference, FlexPrefill, XAttention, Quest), computation flow |
|
||||||
|
| [`docs/block_sparse_attention_lib.md`](docs/block_sparse_attention_lib.md) | MIT-Han-Lab Block-Sparse-Attention library reference: sparse modes, API, performance |
|
||||||
| [`docs/sparse_prefill_integration_plan.md`](docs/sparse_prefill_integration_plan.md) | Integration plan for MInference/XAttention/FlexPrefill with unified BlockMask interface |
|
| [`docs/sparse_prefill_integration_plan.md`](docs/sparse_prefill_integration_plan.md) | Integration plan for MInference/XAttention/FlexPrefill with unified BlockMask interface |
|
||||||
| [`docs/sparse_offload_integration.md`](docs/sparse_offload_integration.md) | Sparse policy integration with layerwise offload, `requires_block_selection` interface design |
|
| [`docs/sparse_offload_integration.md`](docs/sparse_offload_integration.md) | Sparse policy integration with layerwise offload, `requires_block_selection` interface design |
|
||||||
| [`docs/layerwise_offload_memory_analysis.md`](docs/layerwise_offload_memory_analysis.md) | Memory allocation analysis with theoretical formulas and empirical validation (< 5% error) |
|
| [`docs/layerwise_offload_memory_analysis.md`](docs/layerwise_offload_memory_analysis.md) | Memory allocation analysis with theoretical formulas and empirical validation (< 5% error) |
|
||||||
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, tensor comparison, memory profiling |
|
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, tensor comparison, memory profiling |
|
||||||
| [`docs/gpu_only_performance_issue.md`](docs/gpu_only_performance_issue.md) | GPU-only mode slower than offload due to PagedAttention scatter overhead, optimization proposals |
|
| [`docs/gpu_only_performance_issue.md`](docs/gpu_only_performance_issue.md) | GPU-only mode slower than offload due to PagedAttention scatter overhead, optimization proposals |
|
||||||
| [`docs/offload_accuracy_issue.md`](docs/offload_accuracy_issue.md) | **BUG**: CPU offload mode 66% accuracy vs 100% non-offload on RULER NIAH benchmark |
|
| [`docs/offload_accuracy_issue.md`](docs/offload_accuracy_issue.md) | **BUG**: CPU offload mode 66% accuracy vs 100% non-offload on RULER NIAH benchmark |
|
||||||
|
| [`docs/64k_memory_analysis.md`](docs/64k_memory_analysis.md) | 64k inference memory analysis: GPU-only vs offload, OOM root cause (fragmentation), RTX 3090 limitations |
|
||||||
|
| [`docs/xattention_integration.md`](docs/xattention_integration.md) | XAttention integration guide: algorithm, implementation, design decisions, and testing |
|
||||||
|
| [`docs/xattention_analysis.md`](docs/xattention_analysis.md) | XAttention algorithm analysis: chunked estimation, block sparse attention, integration design |
|
||||||
|
| [`docs/development_notes.md`](docs/development_notes.md) | Development notes and scratchpad for ongoing work |
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
@@ -69,7 +74,7 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
|
|||||||
| `gpu_memory_utilization` | 0.9 | GPU memory fraction |
|
| `gpu_memory_utilization` | 0.9 | GPU memory fraction |
|
||||||
| `enable_cpu_offload` | False | Enable for long context |
|
| `enable_cpu_offload` | False | Enable for long context |
|
||||||
| `num_gpu_blocks` | 2 | GPU blocks for offload mode |
|
| `num_gpu_blocks` | 2 | GPU blocks for offload mode |
|
||||||
| `num_kv_buffers` | 4 | Ring buffer size for decode pipeline |
|
| `num_kv_buffers` | 4 | Ring buffer size (1-4), lower = less memory but slower decode |
|
||||||
| `enforce_eager` | False | Set True to disable CUDA graphs |
|
| `enforce_eager` | False | Set True to disable CUDA graphs |
|
||||||
|
|
||||||
## Benchmarking
|
## Benchmarking
|
||||||
@@ -85,6 +90,7 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
|
|||||||
- Qwen3-0.6B/4B: 40960 tokens
|
- Qwen3-0.6B/4B: 40960 tokens
|
||||||
- Qwen2.5-7B-Instruct-1M: 1048576 tokens
|
- Qwen2.5-7B-Instruct-1M: 1048576 tokens
|
||||||
- Llama-3.1-8B-Instruct: 131072 tokens
|
- Llama-3.1-8B-Instruct: 131072 tokens
|
||||||
|
- **64k on RTX 3090/4090 (24GB)**: Requires CPU offload + optimizations, see [`docs/64k_memory_analysis.md`](docs/64k_memory_analysis.md)
|
||||||
|
|
||||||
**Performance (Qwen3-4B, CPU Offload)**:
|
**Performance (Qwen3-4B, CPU Offload)**:
|
||||||
- Prefill: ~5700-8000 tok/s (varies by context length)
|
- Prefill: ~5700-8000 tok/s (varies by context length)
|
||||||
|
|||||||
103
DEBUG_SUMMARY.md
103
DEBUG_SUMMARY.md
@@ -1,103 +0,0 @@
|
|||||||
# Chunked Prefill Bug Debug Summary
|
|
||||||
|
|
||||||
## Problem
|
|
||||||
`test_needle.py --enable-offload --input-len 8192` fails with garbage output.
|
|
||||||
|
|
||||||
The model generates completely wrong tokens instead of the expected "7492".
|
|
||||||
|
|
||||||
## Investigation Progress
|
|
||||||
|
|
||||||
### 1. Stream Synchronization Fix (Completed)
|
|
||||||
- Replaced Triton `store_kvcache` kernel with pure PyTorch operations
|
|
||||||
- Moved `store_kvcache` to `compute_stream` in chunked prefill mode
|
|
||||||
- Added sync: `compute_stream.wait_event(offload_done)` after per-layer offload
|
|
||||||
- Added sync: `default_stream.wait_stream(compute_stream)` before return
|
|
||||||
|
|
||||||
### 2. KV Cache Alignment Verification (Completed)
|
|
||||||
Created alignment tests to compare K/V tensors between torch reference and nanovllm:
|
|
||||||
|
|
||||||
**RoPE Alignment:**
|
|
||||||
- RoPE implementations match perfectly (max_diff=0.002, cosine ~1.0)
|
|
||||||
- Confirmed RoPE is NOT the cause of the bug
|
|
||||||
|
|
||||||
**K/V Cache Alignment (Chunk 0):**
|
|
||||||
- Cosine similarity: ~1.0 for all layers
|
|
||||||
- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision)
|
|
||||||
- Mean diff: < 0.001
|
|
||||||
- **Conclusion: K/V cache offload is working correctly**
|
|
||||||
|
|
||||||
### 3. Layer Output Divergence Analysis (Completed)
|
|
||||||
Created per-chunk layer output comparison:
|
|
||||||
|
|
||||||
**Chunk 0 (tokens 0-4096):**
|
|
||||||
- All layers pass with excellent cosine similarity (0.999+)
|
|
||||||
- Max diff grows in later layers but within acceptable range
|
|
||||||
|
|
||||||
**Chunk 1 (tokens 4096-8192):**
|
|
||||||
- Layers 0-19: OK (cosine ~1.0)
|
|
||||||
- Layers 20-27: Diverge (cosine 0.83-0.96, max_diff up to 114)
|
|
||||||
- Divergence correlates with later transformer layers
|
|
||||||
|
|
||||||
### 4. Critical Discovery: Single-Chunk Offload Also Fails
|
|
||||||
**Key finding:** Even with input_len=2048 (single chunk, no chunked attention), the model produces garbage output with CPU offload enabled.
|
|
||||||
|
|
||||||
```
|
|
||||||
# Without offload: PASSES
|
|
||||||
python tests/test_needle.py --input-len 2048
|
|
||||||
# Output: "7492" (correct)
|
|
||||||
|
|
||||||
# With offload: FAILS
|
|
||||||
python tests/test_needle.py --enable-offload --input-len 2048
|
|
||||||
# Output: "The Ble White Th G Lopsiswin..." (garbage)
|
|
||||||
```
|
|
||||||
|
|
||||||
**This proves the bug is NOT in:**
|
|
||||||
- Chunked attention logic (merge_attention_outputs)
|
|
||||||
- Multi-chunk KV loading
|
|
||||||
- Ring buffer pipeline
|
|
||||||
|
|
||||||
**The bug IS in:**
|
|
||||||
- The decode path when CPU offload is enabled
|
|
||||||
- How prefilled KV is loaded/used during decode
|
|
||||||
|
|
||||||
### 5. Decode Path Analysis (In Progress)
|
|
||||||
The decode path in CPU offload mode:
|
|
||||||
1. Prefill writes KV to GPU, offloads to CPU
|
|
||||||
2. Decode loads prefilled KV from CPU via `_decode_ring_buffer_pipeline`
|
|
||||||
3. Attend to prefilled KV + accumulated decode tokens
|
|
||||||
4. Merge results
|
|
||||||
|
|
||||||
**Observations:**
|
|
||||||
- `prefilled_blocks` set is empty after decode (should contain block IDs)
|
|
||||||
- CPU cache has valid data (reasonable mean/std values)
|
|
||||||
- Decode buffer has zeros (decode tokens not being stored correctly?)
|
|
||||||
|
|
||||||
## Current Status
|
|
||||||
|
|
||||||
### Working
|
|
||||||
- Stream synchronization fixes
|
|
||||||
- K/V cache offload to CPU (verified alignment)
|
|
||||||
- RoPE implementation
|
|
||||||
- Chunked prefill attention for first chunk
|
|
||||||
|
|
||||||
### Not Working
|
|
||||||
- Decode with CPU offload (even for single-chunk inputs)
|
|
||||||
- Multi-chunk attention (divergence in later layers for chunk 1)
|
|
||||||
|
|
||||||
## Next Steps
|
|
||||||
1. Debug why `prefilled_blocks` is empty after decode
|
|
||||||
2. Check if decode path correctly loads KV from CPU
|
|
||||||
3. Verify decode buffer is being written correctly
|
|
||||||
4. Compare decode attention outputs between offload and non-offload modes
|
|
||||||
|
|
||||||
## Key Files
|
|
||||||
- `nanovllm/layers/attention.py` - Main attention implementation with chunked paths
|
|
||||||
- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine
|
|
||||||
- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks`
|
|
||||||
- `nanovllm/engine/model_runner.py` - Prefill/decode orchestration
|
|
||||||
|
|
||||||
## Hypothesis
|
|
||||||
The decode path fails because:
|
|
||||||
1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return empty
|
|
||||||
2. OR the decode attention is not correctly loading/using the prefilled KV from CPU
|
|
||||||
3. OR there's a stream synchronization issue specific to decode path
|
|
||||||
131
docs/64k_memory_analysis.md
Normal file
131
docs/64k_memory_analysis.md
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
# 64k 推理内存分析
|
||||||
|
|
||||||
|
本文档分析 Llama 3.1 8B 模型在 64k 长度推理时的内存占用,以及 RTX 3090 (24GB) 上的 OOM 问题。
|
||||||
|
|
||||||
|
## 模型配置
|
||||||
|
|
||||||
|
```python
|
||||||
|
hidden_size = 4096
|
||||||
|
intermediate_size = 14336
|
||||||
|
num_layers = 32
|
||||||
|
num_heads = 32
|
||||||
|
num_kv_heads = 8
|
||||||
|
head_dim = 128
|
||||||
|
seq_len = 65536
|
||||||
|
dtype = bfloat16 (2 bytes)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 理论内存占用
|
||||||
|
|
||||||
|
### GPU Only 模式
|
||||||
|
|
||||||
|
| 组件 | 计算公式 | 内存占用 |
|
||||||
|
|------|----------|----------|
|
||||||
|
| 模型权重 | 8.03B × 2 bytes | **16.06 GB** |
|
||||||
|
| KV Cache | 32 × 65536 × 8 × 128 × 2 × 2 | **8.19 GB** |
|
||||||
|
| Prefill 激活值峰值 | max(QKV, MLP) | **~2 GB** |
|
||||||
|
| **总计** | | **~26 GB** |
|
||||||
|
|
||||||
|
**结论**:GPU only 模式需要 ~26 GB,**RTX 3090 (24GB) 无法运行**。
|
||||||
|
|
||||||
|
### CPU Offload 模式
|
||||||
|
|
||||||
|
| 组件 | 计算公式 | 内存占用 |
|
||||||
|
|------|----------|----------|
|
||||||
|
| 模型权重 | 8.03B × 2 bytes | **16.06 GB** |
|
||||||
|
| Ring buffer | num_kv_buffers × seq_len × 128 KB/token | 258-1034 MB |
|
||||||
|
| GPU KV blocks | num_gpu_blocks × block_size × 128 KB/token | 256 MB (2 blocks) |
|
||||||
|
| Per-layer decode buffer | 32 layers × 缓冲 | 128 MB |
|
||||||
|
| 激活值峰值 (chunked) | chunk_size × hidden_size × 2 | ~50 MB |
|
||||||
|
| PyTorch 开销 | CUDA 上下文 + 碎片 | ~5-6 GB |
|
||||||
|
| **理论小计** | | **~17.5 GB** |
|
||||||
|
| **实际需求** | | **~23 GB** |
|
||||||
|
|
||||||
|
**配置参数**:
|
||||||
|
- `num_kv_buffers`: Ring buffer 大小 (1-4),默认 4
|
||||||
|
- `num_gpu_blocks`: GPU 上的 KV cache block 数量
|
||||||
|
- `block_size`: 每个 block 的 token 数
|
||||||
|
|
||||||
|
## OOM 问题分析
|
||||||
|
|
||||||
|
### 实际观测(RTX 3090, num_kv_buffers=1)
|
||||||
|
|
||||||
|
```
|
||||||
|
PyTorch allocated: 22.49 GB
|
||||||
|
PyTorch reserved: 429 MB
|
||||||
|
Free: 306 MB
|
||||||
|
Total available: 735 MB
|
||||||
|
Failed to allocate: 508 MB (torch.cat)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 内存碎片来源
|
||||||
|
|
||||||
|
| 来源 | 说明 | 影响 |
|
||||||
|
|------|------|------|
|
||||||
|
| Binned 分配器 | PyTorch 使用固定大小的内存池 | 中等 |
|
||||||
|
| torch.compile 缓存 | 编译后的 kernel 代码和常量 | 高 (~2-3 GB) |
|
||||||
|
| 频繁分配/释放 | chunked 处理中每个 chunk 的创建销毁 | 高 |
|
||||||
|
| 不同大小张量 | (128,4096), (65536,6144) 等 | 中等 |
|
||||||
|
|
||||||
|
### torch.cat 内存需求
|
||||||
|
|
||||||
|
Chunked MLP 处理(chunk_size=128):
|
||||||
|
```
|
||||||
|
65536 / 128 = 512 chunks
|
||||||
|
每个 chunk 输出: (128, 4096) × 2 bytes = 1 MB
|
||||||
|
torch.cat 拼接需要: (65536, 4096) × 2 bytes = 508 MB (连续)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 已尝试的优化
|
||||||
|
|
||||||
|
| 优化项 | 效果 |
|
||||||
|
|--------|------|
|
||||||
|
| 移除 `@torch.compile` | PyTorch: 23.13 → 22.80 GB (-300 MB) |
|
||||||
|
| 减少 `num_kv_buffers` (4→1) | Ring buffer: 1034 → 258 MB (-776 MB) |
|
||||||
|
| Chunked QKV/MLP/LayerNorm | 峰值激活: ~2 GB → ~50 MB |
|
||||||
|
| 降低 GPU 利用率 (0.9→0.75) | 无明显效果 |
|
||||||
|
| 减小 chunk_size (4096→128) | 峰值降低,但 torch.cat 需要连续内存 |
|
||||||
|
|
||||||
|
### 最终状态
|
||||||
|
|
||||||
|
```
|
||||||
|
理论需求: ~17.5 GB
|
||||||
|
实际分配: 22.49 GB
|
||||||
|
剩余空间: 735 MB (306 MB + 429 MB reserved)
|
||||||
|
分配失败: 508 MB (torch.cat 需要连续内存)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 结论
|
||||||
|
|
||||||
|
### 根本原因
|
||||||
|
|
||||||
|
**不是绝对内存不足,而是内存碎片导致的分配失败**。
|
||||||
|
|
||||||
|
理论需求 17.5 GB < 24 GB,但由于:
|
||||||
|
- PyTorch 开销(CUDA 上下文、碎片):~5-6 GB
|
||||||
|
- torch.compile 缓存:~2-3 GB(已移除)
|
||||||
|
- 内存碎片导致无法分配 508 MB 连续块
|
||||||
|
|
||||||
|
### 硬件限制
|
||||||
|
|
||||||
|
| GPU | 显存 | 64k GPU Only | 64k Offload |
|
||||||
|
|-----|------|--------------|--------------|
|
||||||
|
| RTX 3090 | 24 GB | ❌ | ⚠️ 碎片问题 |
|
||||||
|
| RTX 4090 | 24 GB | ❌ | ⚠️ 碎片问题 |
|
||||||
|
| A100 | 40 GB | ✅ | ✅ |
|
||||||
|
| A100 | 80 GB | ✅ | ✅ |
|
||||||
|
|
||||||
|
### 建议
|
||||||
|
|
||||||
|
1. **64k 推理建议使用 40GB+ 显存的 GPU**
|
||||||
|
2. RTX 3090/4090 适合 32k 或更短的场景
|
||||||
|
3. 如必须在 24GB GPU 上运行 64k:
|
||||||
|
- 使用 RAPIDS RMM 分配器
|
||||||
|
- 预分配 torch.cat 需要的内存
|
||||||
|
- 或使用流式处理避免 torch.cat
|
||||||
|
|
||||||
|
## 参考
|
||||||
|
|
||||||
|
- [PyTorch 内存管理文档](https://docs.pytorch.org/docs/stable/generated/torch.cuda.memory.memory_stats.html)
|
||||||
|
- [PyTorch 内存碎片讨论](https://discuss.pytorch.org/t/how-to-reduce-memory-fragmentation-when-enable-expandable-segments/221805)
|
||||||
|
- [STWeaver - 减少 79% 内存碎片](https://arxiv.org/html/2507.16274v1)
|
||||||
161
docs/64k_mlp_activation_oom.md
Normal file
161
docs/64k_mlp_activation_oom.md
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
# 64K Prefill MLP Activation OOM Issue
|
||||||
|
|
||||||
|
## Problem Summary
|
||||||
|
|
||||||
|
When running RULER benchmark with 64K context length using CPU offload mode, OOM occurs during MLP forward pass in `run_layerwise_offload_prefill`. The KV cache is successfully offloaded to CPU, but MLP intermediate activations exceed available GPU memory.
|
||||||
|
|
||||||
|
## Environment
|
||||||
|
|
||||||
|
- GPU: RTX 3090 (24GB)
|
||||||
|
- Model: LLaMA 3.1 8B
|
||||||
|
- Sequence Length: 65536 tokens
|
||||||
|
- Mode: `enable_cpu_offload=True`, `num_gpu_blocks=2`
|
||||||
|
|
||||||
|
## Error Message
|
||||||
|
|
||||||
|
```
|
||||||
|
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.47 GiB.
|
||||||
|
GPU 0 has a total capacity of 23.57 GiB of which 2.66 GiB is free.
|
||||||
|
Including non-PyTorch memory, this process has 20.88 GiB memory in use.
|
||||||
|
Of the allocated memory 20.51 GiB is allocated by PyTorch, and 32.26 MiB
|
||||||
|
is reserved by PyTorch but unallocated.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Stack Trace
|
||||||
|
|
||||||
|
```
|
||||||
|
File "nanovllm/engine/model_runner.py", line 843, in run_layerwise_offload_prefill
|
||||||
|
hidden_states = layer.mlp(hidden_states)
|
||||||
|
File "nanovllm/models/llama.py", line 103, in forward
|
||||||
|
gate_up = self.gate_up_proj(x)
|
||||||
|
File "nanovllm/layers/linear.py", line 73, in forward
|
||||||
|
return F.linear(x, self.weight, self.bias)
|
||||||
|
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.47 GiB.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Root Cause Analysis
|
||||||
|
|
||||||
|
### Memory Breakdown
|
||||||
|
|
||||||
|
| Component | Calculation | Size |
|
||||||
|
|-----------|-------------|------|
|
||||||
|
| Model weights (BF16) | 8B params × 2 bytes | ~16 GB |
|
||||||
|
| GPU KV cache | 2 blocks × 1024 tokens × 8KB/token | ~16 MB |
|
||||||
|
| **Remaining for activations** | 24 - 16 - overhead | **~6-7 GB** |
|
||||||
|
|
||||||
|
### MLP Activation Memory (per layer)
|
||||||
|
|
||||||
|
For LLaMA 3.1 8B with `hidden_size=4096`, `intermediate_size=14336`:
|
||||||
|
|
||||||
|
| Tensor | Shape | Size (BF16) |
|
||||||
|
|--------|-------|-------------|
|
||||||
|
| MLP input | [65536, 4096] | 512 MB |
|
||||||
|
| gate_up output | [65536, 28672] | **3.47 GB** |
|
||||||
|
| down_proj input | [65536, 14336] | 1.75 GB |
|
||||||
|
| MLP output | [65536, 4096] | 512 MB |
|
||||||
|
|
||||||
|
**Peak MLP memory**: ~3.5-4 GB for intermediate tensors
|
||||||
|
|
||||||
|
### Why OOM Occurs
|
||||||
|
|
||||||
|
1. Model weights consume ~16 GB (loaded on GPU for layer-wise processing)
|
||||||
|
2. Available memory: ~7 GB
|
||||||
|
3. MLP `gate_up_proj` output: 3.47 GB
|
||||||
|
4. Additional tensors (input, gradients, etc.): ~1-2 GB
|
||||||
|
5. **Total required > Available** → OOM
|
||||||
|
|
||||||
|
## Code Location
|
||||||
|
|
||||||
|
The issue is in `nanovllm/engine/model_runner.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Line 843 in run_layerwise_offload_prefill
|
||||||
|
hidden_states = layer.mlp(hidden_states) # <-- OOM here
|
||||||
|
```
|
||||||
|
|
||||||
|
The entire sequence (65536 tokens) is passed through MLP in one shot.
|
||||||
|
|
||||||
|
## Current Configuration
|
||||||
|
|
||||||
|
From `model_wrappers.py` (RULER integration):
|
||||||
|
|
||||||
|
```python
|
||||||
|
llm_kwargs = {
|
||||||
|
"max_model_len": max_model_len, # 128 * 1024
|
||||||
|
"max_num_batched_tokens": max_model_len, # Same as max_model_len
|
||||||
|
"enable_cpu_offload": True,
|
||||||
|
"num_gpu_blocks": 2,
|
||||||
|
...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Setting `max_num_batched_tokens = max_model_len` causes nanovllm to process all tokens at once.
|
||||||
|
|
||||||
|
## Potential Solutions
|
||||||
|
|
||||||
|
### Option 1: Chunked MLP Processing
|
||||||
|
|
||||||
|
Modify `run_layerwise_offload_prefill` to process MLP in chunks:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Instead of:
|
||||||
|
hidden_states = layer.mlp(hidden_states)
|
||||||
|
|
||||||
|
# Do:
|
||||||
|
chunk_size = 8192 # Process 8K tokens at a time
|
||||||
|
chunks = hidden_states.split(chunk_size, dim=0)
|
||||||
|
outputs = []
|
||||||
|
for chunk in chunks:
|
||||||
|
outputs.append(layer.mlp(chunk))
|
||||||
|
hidden_states = torch.cat(outputs, dim=0)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Option 2: Activation Checkpointing
|
||||||
|
|
||||||
|
Use gradient checkpointing to recompute activations instead of storing them:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
hidden_states = checkpoint(layer.mlp, hidden_states, use_reentrant=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Option 3: Reduce Chunk Size via Config
|
||||||
|
|
||||||
|
Add a new config parameter `prefill_chunk_size` to control how many tokens are processed per forward pass.
|
||||||
|
|
||||||
|
## Memory Estimation Formula
|
||||||
|
|
||||||
|
For a given sequence length `S` and model config:
|
||||||
|
|
||||||
|
```
|
||||||
|
MLP_peak_memory = S × intermediate_size × 2 × 2 bytes
|
||||||
|
= S × 14336 × 4 bytes
|
||||||
|
|
||||||
|
For S = 65536:
|
||||||
|
MLP_peak = 65536 × 14336 × 4 = 3.76 GB
|
||||||
|
```
|
||||||
|
|
||||||
|
Maximum safe sequence length for RTX 3090 (24GB):
|
||||||
|
```
|
||||||
|
S_max = available_memory / (intermediate_size × 4)
|
||||||
|
= 6GB / (14336 × 4)
|
||||||
|
≈ 100K tokens (theoretical)
|
||||||
|
≈ 8-16K tokens (practical, with safety margin)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Reproduction Steps
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /home/zijie/Code/COMPASS/eval/RULER/scripts
|
||||||
|
|
||||||
|
# Set SEQ_LENGTHS to 65536 in config_models.sh
|
||||||
|
# Then run:
|
||||||
|
./run.sh llama3.1-8b-nanovllm synthetic --metric full --task niah_single_1
|
||||||
|
```
|
||||||
|
|
||||||
|
## Related Files
|
||||||
|
|
||||||
|
- `nanovllm/engine/model_runner.py`: `run_layerwise_offload_prefill()` (line 751+)
|
||||||
|
- `nanovllm/models/llama.py`: `LlamaMLP.forward()` (line 103)
|
||||||
|
- `nanovllm/config.py`: Config parameters
|
||||||
|
- RULER integration: `eval/RULER/scripts/pred/model_wrappers.py`
|
||||||
191
docs/block_sparse_attention_lib.md
Normal file
191
docs/block_sparse_attention_lib.md
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
# Block-Sparse-Attention Library Reference
|
||||||
|
|
||||||
|
MIT Han Lab 的块稀疏注意力内核库,基于 FlashAttention 2.4.2 修改,支持多种稀疏注意力模式。
|
||||||
|
|
||||||
|
## 库信息
|
||||||
|
|
||||||
|
- **来源**: [MIT-Han-Lab/Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention)
|
||||||
|
- **本地路径**: `3rdparty/Block-Sparse-Attention` (submodule, branch: `tzj/minference`)
|
||||||
|
- **基于**: FlashAttention 2.4.2
|
||||||
|
- **安装位置**: `site-packages/block_sparse_attn`
|
||||||
|
|
||||||
|
## 支持的稀疏模式
|
||||||
|
|
||||||
|
### 1. Dense Attention
|
||||||
|
计算完整注意力矩阵,无稀疏化。
|
||||||
|
|
||||||
|
### 2. Token Streaming (token granularity)
|
||||||
|
固定数量的 sink tokens + local tokens,参考 [StreamingLLM](https://arxiv.org/abs/2309.17453)。
|
||||||
|
|
||||||
|
**适用场景**: 需要精确保留部分关键 token 的长上下文推理
|
||||||
|
|
||||||
|
### 3. Block Streaming (block granularity)
|
||||||
|
Block 粒度的 streaming attention,block_size = 128。
|
||||||
|
|
||||||
|
**适用场景**: 长序列推理,牺牲少量精度换取更大加速
|
||||||
|
|
||||||
|
### 4. Block Sparse
|
||||||
|
基于自定义 block mask 的稀疏注意力。
|
||||||
|
|
||||||
|
**适用场景**: 已知特定 attention 模式的工作负载
|
||||||
|
|
||||||
|
### 混合模式
|
||||||
|
|
||||||
|
**关键特性**: 支持不同 head 使用不同稀疏模式
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 8 个 heads 的混合配置示例
|
||||||
|
head_mask_type = [1, 1, 0, 0, 0, -1, 0, -1]
|
||||||
|
# 含义:
|
||||||
|
# - head 0,1: blocksparse (使用 basemask[0])
|
||||||
|
# - head 2-4,6: dense
|
||||||
|
# - head 5,7: streaming
|
||||||
|
```
|
||||||
|
|
||||||
|
**Mask 类型编码**:
|
||||||
|
- `0` = Dense attention
|
||||||
|
- `-1` = Streaming attention
|
||||||
|
- `1, 2, ...` = Block sparse (使用 basemask[mask_type - 1])
|
||||||
|
|
||||||
|
## API 参考
|
||||||
|
|
||||||
|
### `block_sparse_attn_func`
|
||||||
|
|
||||||
|
通用块稀疏注意力函数,支持所有模式。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from block_sparse_attn import block_sparse_attn_func
|
||||||
|
|
||||||
|
output = block_sparse_attn_func(
|
||||||
|
q, k, v, # [total_tokens, heads, head_dim] unpadded
|
||||||
|
cu_seqlens_q, cu_seqlens_k, # cumulative sequence lengths
|
||||||
|
head_mask_type, # [heads] tensor, 每个头的模式
|
||||||
|
streaming_info, # streaming 配置 (sink/local 数量)
|
||||||
|
base_blockmask, # [q_blocks, k_blocks, n_masks] bool tensor
|
||||||
|
max_seqlen_q, max_seqlen_k, # 最大序列长度
|
||||||
|
p_dropout, # dropout 概率 (推理时设为 0.0)
|
||||||
|
deterministic=False,
|
||||||
|
softmax_scale=None,
|
||||||
|
is_causal=False,
|
||||||
|
exact_streaming=False, # True=token streaming, False=block streaming
|
||||||
|
return_attn_probs=False,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**关键参数**:
|
||||||
|
| 参数 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| `head_mask_type` | Tensor[heads] | 每个头的稀疏模式,0=dense, -1=streaming, 1+=blocksparse |
|
||||||
|
| `streaming_info` | Tensor | [sink_blocks, local_blocks] 或 [sink_tokens, local_tokens] |
|
||||||
|
| `base_blockmask` | Tensor | Block mask,形状 [q_blocks, k_blocks, n_masks] |
|
||||||
|
| `exact_streaming` | bool | True=token 粒度,False=block 粒度 streaming |
|
||||||
|
|
||||||
|
### `block_streaming_attn_func`
|
||||||
|
|
||||||
|
Block 粒度 streaming attention(block_size=128)。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from block_sparse_attn import block_streaming_attn_func
|
||||||
|
|
||||||
|
output = block_streaming_attn_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q, cu_seqlens_k,
|
||||||
|
head_mask_type,
|
||||||
|
streaming_info, # [sink_blocks, local_blocks]
|
||||||
|
max_seqlen_q, max_seqlen_k,
|
||||||
|
p_dropout,
|
||||||
|
deterministic=False,
|
||||||
|
softmax_scale=None,
|
||||||
|
is_causal=True,
|
||||||
|
return_attn_probs=False,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### `token_streaming_attn_func`
|
||||||
|
|
||||||
|
Token 粒度 streaming attention。
|
||||||
|
|
||||||
|
**注意**: 不支持反向传播(仅推理)。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from block_sparse_attn import token_streaming_attn_func
|
||||||
|
|
||||||
|
output = token_streaming_attn_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q, cu_seqlens_k,
|
||||||
|
head_mask_type,
|
||||||
|
streaming_info, # [sink_tokens, local_tokens]
|
||||||
|
max_seqlen_q, max_seqlen_k,
|
||||||
|
deterministic=False,
|
||||||
|
softmax_scale=None,
|
||||||
|
return_attn_probs=False,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 技术规格
|
||||||
|
|
||||||
|
| 特性 | 支持情况 |
|
||||||
|
|------|----------|
|
||||||
|
| **数据类型** | fp16, bf16 (bf16 需要 Ampere/Ada/Hopper GPU) |
|
||||||
|
| **Head 维度** | 32, 64, 128 |
|
||||||
|
| **Block Size** | 128 (固定) |
|
||||||
|
| **CUDA 要求** | 11.6+ |
|
||||||
|
| **PyTorch 要求** | 1.12+ |
|
||||||
|
|
||||||
|
## 性能参考
|
||||||
|
|
||||||
|
测试环境: A100 GPU, head_dim=128, 32 heads, batch_size=1
|
||||||
|
|
||||||
|
### Block Sparse 加速比
|
||||||
|
- 相比 FlashAttention2: 最高 **3-4x** 加速
|
||||||
|
- 加速随序列长度增加而提升
|
||||||
|
|
||||||
|
### Streaming 混合模式加速比
|
||||||
|
- Token streaming: 64 sink + 256 local tokens
|
||||||
|
- Block streaming: 1 sink block + 3 local blocks
|
||||||
|
- **50% Dense + 50% Streaming**: 最高 **2x** 加速
|
||||||
|
|
||||||
|
## 与 nano-vllm 的集成考虑
|
||||||
|
|
||||||
|
### 潜在集成点
|
||||||
|
|
||||||
|
1. **长上下文推理优化**
|
||||||
|
- 使用 block streaming 减少计算量
|
||||||
|
- 在 CPU offload 模式下减少 GPU-CPU 传输
|
||||||
|
|
||||||
|
2. **混合注意力策略**
|
||||||
|
- 部分 head 使用 streaming(减少计算)
|
||||||
|
- 部分 head 使用 dense(保持精度)
|
||||||
|
- 参考 Duo Attention 论文的混合模式
|
||||||
|
|
||||||
|
3. **稀疏 offload**
|
||||||
|
- 只 offload 重要 blocks 的 KV cache
|
||||||
|
- 结合 `requires_block_selection` 接口
|
||||||
|
|
||||||
|
### 实现注意事项
|
||||||
|
|
||||||
|
1. **输入格式**: 库使用 unpadded 格式(`cu_seqlens`),需要与 nano-vllm 的 padded 格式转换
|
||||||
|
2. **Block size 固定**: 库固定 block_size=128,需要适配
|
||||||
|
3. **Streaming info 配置**: 需要根据模型特性调整 sink/local 数量
|
||||||
|
|
||||||
|
## 相关工作
|
||||||
|
|
||||||
|
- [FlashAttention](https://github.com/Dao-AILab/flash-attention) - 基础实现
|
||||||
|
- [StreamingLLM](https://arxiv.org/abs/2309.17453) - Streaming attention 理论基础
|
||||||
|
- [Duo Attention](https://github.com/mit-han-lab/duo-attention) - 混合 dense/streaming 模式
|
||||||
|
- [MInference](https://arxiv.org/abs/2407.02490) - 混合 mask 方法
|
||||||
|
|
||||||
|
## 测试
|
||||||
|
|
||||||
|
库自带测试位于 `3rdparty/Block-Sparse-Attention/block_sparse_tests/`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 正确性测试
|
||||||
|
cd 3rdparty/Block-Sparse-Attention/block_sparse_tests/fwd/test_correctness
|
||||||
|
pytest full_test.py
|
||||||
|
|
||||||
|
# 性能测试
|
||||||
|
cd 3rdparty/Block-Sparse-Attention/block_sparse_tests/fwd/test_performance
|
||||||
|
python token_streaming.py
|
||||||
|
python blocksparse.py
|
||||||
|
```
|
||||||
324
docs/development_notes.md
Normal file
324
docs/development_notes.md
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
# Notes: Sparsity Integration into Layerwise Offload
|
||||||
|
|
||||||
|
## Current Architecture Analysis
|
||||||
|
|
||||||
|
### GPU-Only Path vs Offload Path
|
||||||
|
|
||||||
|
| Aspect | GPU-Only | Layerwise Offload |
|
||||||
|
|--------|----------|-------------------|
|
||||||
|
| KV Storage | GPU blocks (paged) | CPU pinned + GPU ring buffer |
|
||||||
|
| Prefill | All layers → then attention | Per-layer: attention → offload |
|
||||||
|
| Decode | FlashAttn with block table | Ring buffer H2D → FlashAttn |
|
||||||
|
| Sparse Support | MInference via `attention.py` | Not integrated |
|
||||||
|
|
||||||
|
### MInference Flow (GPU-Only)
|
||||||
|
|
||||||
|
```
|
||||||
|
attention.py:101-105:
|
||||||
|
if context.sparse_prefill_policy is not None:
|
||||||
|
o = context.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
|
||||||
|
|
||||||
|
minference.py:sparse_prefill_attention():
|
||||||
|
1. estimate_pattern(q, k, layer_id) -> vertical_indices, slash_indices
|
||||||
|
2. _triton_mixed_sparse_attention(q, k, v, indices)
|
||||||
|
3. return output
|
||||||
|
```
|
||||||
|
|
||||||
|
### Quest Flow (GPU Block Mode)
|
||||||
|
|
||||||
|
```
|
||||||
|
hybrid_manager.py (if using CPU offload with Quest):
|
||||||
|
select_blocks(available_blocks, ctx) -> selected block IDs
|
||||||
|
-> load selected blocks to GPU
|
||||||
|
-> standard FlashAttn with loaded blocks
|
||||||
|
```
|
||||||
|
|
||||||
|
### Layerwise Offload Prefill Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
model_runner.py:run_layerwise_offload_prefill():
|
||||||
|
for layer_id in range(num_layers):
|
||||||
|
# QKV projection
|
||||||
|
q, k, v = qkv_proj(hidden_ln)
|
||||||
|
|
||||||
|
# RoPE
|
||||||
|
q, k = rotary_emb(positions, q, k)
|
||||||
|
|
||||||
|
# FULL attention (no sparsity!)
|
||||||
|
attn_output = flash_attn_varlen_func(q, k, v, ...)
|
||||||
|
|
||||||
|
# MLP
|
||||||
|
hidden_states = mlp(attn_out + residual)
|
||||||
|
|
||||||
|
# Sync offload ALL k, v to CPU
|
||||||
|
for block_id in cpu_block_ids:
|
||||||
|
k_cache_cpu[layer_id, block_id].copy_(k[start:end])
|
||||||
|
v_cache_cpu[layer_id, block_id].copy_(v[start:end])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Layerwise Offload Decode Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
model_runner.py:run_layerwise_offload_decode():
|
||||||
|
# Preload first N layers to ring buffer
|
||||||
|
for i in range(num_buffers):
|
||||||
|
offload_engine.load_layer_kv_to_buffer(i, i, cpu_block_table, valid_tokens)
|
||||||
|
|
||||||
|
for layer_id in range(num_layers):
|
||||||
|
current_buffer = layer_id % num_buffers
|
||||||
|
|
||||||
|
# Wait for buffer load
|
||||||
|
offload_engine.wait_buffer_load(current_buffer)
|
||||||
|
|
||||||
|
# Get prefilled KV from ring buffer (ALL blocks loaded)
|
||||||
|
k_prefill, v_prefill = offload_engine.get_buffer_kv(current_buffer, total_prefill_tokens)
|
||||||
|
|
||||||
|
# QKV for new token
|
||||||
|
q, k_new, v_new = qkv_proj(hidden_ln)
|
||||||
|
|
||||||
|
# Concat and full attention
|
||||||
|
k_full = torch.cat([k_prefill, k_decode_prev, k_new])
|
||||||
|
attn_output = flash_attn_varlen_func(q, k_full, v_full, ...)
|
||||||
|
|
||||||
|
# Start loading next layer
|
||||||
|
offload_engine.load_layer_kv_to_buffer(current_buffer, layer_id + num_buffers, ...)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Integration Points
|
||||||
|
|
||||||
|
### 1. Prefill Sparse Integration Point
|
||||||
|
|
||||||
|
**Location:** `model_runner.py:535-543`
|
||||||
|
|
||||||
|
**Current:**
|
||||||
|
```python
|
||||||
|
attn_output = flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
max_seqlen_q=total_tokens,
|
||||||
|
max_seqlen_k=total_tokens,
|
||||||
|
softmax_scale=layer.self_attn.attn.scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**After Integration:**
|
||||||
|
```python
|
||||||
|
if self.sparse_policy and self.sparse_policy.supports_offload_prefill:
|
||||||
|
attn_output, k_sparse, v_sparse = self.sparse_policy.offload_prefill_attention(
|
||||||
|
q, k, v, layer_id
|
||||||
|
)
|
||||||
|
k_to_offload = k_sparse if k_sparse is not None else k
|
||||||
|
v_to_offload = v_sparse if v_sparse is not None else v
|
||||||
|
else:
|
||||||
|
attn_output = flash_attn_varlen_func(q, k, v, ...)
|
||||||
|
k_to_offload, v_to_offload = k, v
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Decode Sparse Integration Point
|
||||||
|
|
||||||
|
**Location:** `model_runner.py:636-637` and `model_runner.py:704-706`
|
||||||
|
|
||||||
|
**Current (preload):**
|
||||||
|
```python
|
||||||
|
for i in range(num_preload):
|
||||||
|
offload_engine.load_layer_kv_to_buffer(
|
||||||
|
i, i, cpu_block_table, valid_tokens_per_block
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**After Integration:**
|
||||||
|
```python
|
||||||
|
for i in range(num_preload):
|
||||||
|
layer_to_load = i
|
||||||
|
if self.sparse_policy and self.sparse_policy.supports_offload_decode:
|
||||||
|
# Prepare q for this layer (need to compute ahead)
|
||||||
|
# OR: use previous layer's pattern as estimate
|
||||||
|
selected_blocks = self.sparse_policy.select_offload_blocks(
|
||||||
|
None, # q not available yet at preload
|
||||||
|
layer_to_load,
|
||||||
|
cpu_block_table,
|
||||||
|
valid_tokens_per_block
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
selected_blocks = cpu_block_table
|
||||||
|
offload_engine.load_sparse_layer_kv_to_buffer(
|
||||||
|
i, layer_to_load, selected_blocks, valid_tokens_per_block
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Challenge:** Q is not available during preload phase!
|
||||||
|
|
||||||
|
**Solutions:**
|
||||||
|
1. Skip sparse preload, only sparse for non-preloaded layers
|
||||||
|
2. Use previous decode step's pattern as estimate
|
||||||
|
3. Add preload hook to sparse policy
|
||||||
|
|
||||||
|
### 3. Offload Engine Extension
|
||||||
|
|
||||||
|
**New Method in OffloadEngine:**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def load_sparse_layer_kv_to_buffer(
|
||||||
|
self,
|
||||||
|
buffer_idx: int,
|
||||||
|
layer_id: int,
|
||||||
|
selected_cpu_block_ids: List[int],
|
||||||
|
original_valid_tokens: List[int],
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Load only selected blocks from CPU to buffer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total tokens loaded (may be less than full sequence)
|
||||||
|
"""
|
||||||
|
stream = self.layer_load_streams[buffer_idx]
|
||||||
|
|
||||||
|
with torch.cuda.stream(stream):
|
||||||
|
stream.wait_event(self.buffer_compute_done_events[buffer_idx])
|
||||||
|
|
||||||
|
# Build mapping: original block -> selected position
|
||||||
|
offset = 0
|
||||||
|
for i, cpu_block_id in enumerate(selected_cpu_block_ids):
|
||||||
|
# Find original index to get valid tokens
|
||||||
|
valid_tokens = original_valid_tokens[i] # Need mapping
|
||||||
|
|
||||||
|
self.layer_k_cache[buffer_idx, offset:offset+valid_tokens].copy_(
|
||||||
|
self.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens],
|
||||||
|
non_blocking=True
|
||||||
|
)
|
||||||
|
# ... v_cache same
|
||||||
|
|
||||||
|
offset += valid_tokens
|
||||||
|
|
||||||
|
self.buffer_load_events[buffer_idx].record(stream)
|
||||||
|
|
||||||
|
return offset # Caller needs to know actual loaded tokens
|
||||||
|
```
|
||||||
|
|
||||||
|
## Metadata Flow for Quest
|
||||||
|
|
||||||
|
### During Prefill Offload
|
||||||
|
|
||||||
|
**Current:** No metadata collection in offload path
|
||||||
|
|
||||||
|
**Required:** Call `on_prefill_offload()` for each block
|
||||||
|
|
||||||
|
```python
|
||||||
|
# In run_layerwise_offload_prefill()
|
||||||
|
for i, cpu_block_id in enumerate(cpu_block_ids):
|
||||||
|
start = i * block_size
|
||||||
|
end = min(start + block_size, total_tokens)
|
||||||
|
actual_size = end - start
|
||||||
|
|
||||||
|
# BEFORE offload: update Quest metadata
|
||||||
|
if self.sparse_policy and hasattr(self.sparse_policy, 'on_prefill_offload'):
|
||||||
|
self.sparse_policy.on_prefill_offload(
|
||||||
|
cpu_block_id, layer_id, k[start:end], actual_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Offload
|
||||||
|
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
|
||||||
|
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Quest Metadata Shape
|
||||||
|
|
||||||
|
```python
|
||||||
|
# BlockMetadataManager
|
||||||
|
key_min: [num_blocks, num_layers, num_kv_heads, head_dim] # Min key per block per layer
|
||||||
|
key_max: [num_blocks, num_layers, num_kv_heads, head_dim] # Max key per block per layer
|
||||||
|
```
|
||||||
|
|
||||||
|
**Memory:** 2 * num_blocks * num_layers * kv_heads * head_dim * 2 bytes
|
||||||
|
- Example: 1000 blocks * 28 layers * 4 heads * 128 dim * 2 * 2 = ~57 MB
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
### MInference Prefill Overhead
|
||||||
|
|
||||||
|
| Operation | Time (64K seq) |
|
||||||
|
|-----------|----------------|
|
||||||
|
| Pattern estimation (last-64) | ~5ms |
|
||||||
|
| Triton sparse attention | ~80ms |
|
||||||
|
| Full FlashAttention | ~100ms |
|
||||||
|
| **Net Speedup** | ~15-20% |
|
||||||
|
|
||||||
|
### Quest Decode Overhead
|
||||||
|
|
||||||
|
| Operation | Time |
|
||||||
|
|-----------|------|
|
||||||
|
| Block scoring (GPU metadata) | ~0.1ms |
|
||||||
|
| Top-K selection | ~0.05ms |
|
||||||
|
| Sparse H2D load (8 blocks) | ~2ms |
|
||||||
|
| Full H2D load (100 blocks) | ~20ms |
|
||||||
|
| **Net Speedup** | ~10x H2D |
|
||||||
|
|
||||||
|
### Memory Trade-offs
|
||||||
|
|
||||||
|
| Mode | GPU Memory | CPU Memory | H2D Bandwidth |
|
||||||
|
|------|------------|------------|---------------|
|
||||||
|
| Full offload | Ring buffer | Full KV | High |
|
||||||
|
| Sparse offload | Ring buffer | Full KV | Low (subset) |
|
||||||
|
| Aggressive sparse | Ring buffer | Sparse KV | Very low |
|
||||||
|
|
||||||
|
## Edge Cases
|
||||||
|
|
||||||
|
### 1. Short Sequences (< sparse threshold)
|
||||||
|
|
||||||
|
```python
|
||||||
|
if total_tokens < sparse_threshold:
|
||||||
|
# Fall back to full attention
|
||||||
|
use_sparse = False
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. First Decode Step (no previous Q)
|
||||||
|
|
||||||
|
Quest can't score blocks without Q. Options:
|
||||||
|
- Use average embedding as proxy
|
||||||
|
- Load all blocks for first step
|
||||||
|
- Use prefill pattern as estimate
|
||||||
|
|
||||||
|
### 3. Variable Sequence Lengths in Batch
|
||||||
|
|
||||||
|
Layerwise offload currently only supports batch_size=1:
|
||||||
|
```python
|
||||||
|
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
|
||||||
|
```
|
||||||
|
|
||||||
|
Sparse integration should maintain this constraint.
|
||||||
|
|
||||||
|
### 4. Ring Buffer vs Sparse Load Mismatch
|
||||||
|
|
||||||
|
Ring buffer assumes fixed `total_prefill_tokens`:
|
||||||
|
```python
|
||||||
|
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, total_prefill_tokens)
|
||||||
|
```
|
||||||
|
|
||||||
|
Sparse load has variable token count. Need:
|
||||||
|
```python
|
||||||
|
# Track actual loaded tokens per buffer
|
||||||
|
loaded_tokens[buffer_idx] = sparse_load_count
|
||||||
|
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, loaded_tokens[buffer_idx])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing Strategy
|
||||||
|
|
||||||
|
### Unit Tests
|
||||||
|
|
||||||
|
1. `test_sparse_policy_interface.py` - Verify new interface methods
|
||||||
|
2. `test_minference_offload.py` - MInference in offload mode
|
||||||
|
3. `test_quest_offload.py` - Quest block selection in offload mode
|
||||||
|
|
||||||
|
### Integration Tests
|
||||||
|
|
||||||
|
1. `test_offload_sparse_e2e.py` - Full prefill+decode with sparsity
|
||||||
|
2. `test_accuracy_comparison.py` - Compare outputs: full vs sparse
|
||||||
|
|
||||||
|
### Benchmarks
|
||||||
|
|
||||||
|
1. `bench_offload_sparse.py` - Compare:
|
||||||
|
- Full offload (baseline)
|
||||||
|
- MInference prefill + Quest decode
|
||||||
|
- Aggressive sparse offload
|
||||||
597
docs/xattention_analysis.md
Normal file
597
docs/xattention_analysis.md
Normal file
@@ -0,0 +1,597 @@
|
|||||||
|
# COMPASS XAttention Implementation Analysis
|
||||||
|
|
||||||
|
**Analysis Date**: 2026-01-14
|
||||||
|
**Researcher**: Claude Code Agent
|
||||||
|
**Source**: `/home/zijie/Code/COMPASS/compass/src/`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Executive Summary
|
||||||
|
|
||||||
|
COMPASS XAttention is a **block sparse attention** implementation that uses:
|
||||||
|
1. **Approximation phase** (`xattn_estimate`) to compute attention importance and select blocks
|
||||||
|
2. **Computation phase** (`Xattention_prefill`) to compute sparse attention using `block_sparse_attn_func`
|
||||||
|
3. **Triton kernels** for efficient block-wise GEMM and softmax operations
|
||||||
|
|
||||||
|
**Key Integration Constraint**: Requires `block_sparse_attn_func` from flash-attention library, which is a **C++ CUDA extension** that must be compiled separately.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Function: `xattn_estimate()`
|
||||||
|
|
||||||
|
**Purpose**: Estimate attention importance and select which blocks to compute
|
||||||
|
|
||||||
|
### Input Parameters
|
||||||
|
|
||||||
|
| Parameter | Type | Default | Description |
|
||||||
|
|-----------|------|---------|-------------|
|
||||||
|
| `query_states` | Tensor | - | Shape: `(batch, num_heads, q_len, head_dim)` |
|
||||||
|
| `key_states` | Tensor | - | Shape: `(batch, num_kv_heads, k_len, head_dim)` |
|
||||||
|
| `block_size` | int | - | Size of attention blocks (typically 128) |
|
||||||
|
| `stride` | int | - | Downsampling stride for approximation |
|
||||||
|
| `norm` | float | 1 | Normalization factor for attention scaling |
|
||||||
|
| `softmax` | bool | True | Whether to apply softmax in estimation |
|
||||||
|
| `threshold` | float | 0.9 | Block selection threshold (0-1) |
|
||||||
|
| `chunk_size` | int | 16384 | Processing chunk size |
|
||||||
|
| `select_mode` | str | "inverse" | Pattern selection mode |
|
||||||
|
| `use_triton` | bool | True | Use Triton kernels (requires SM 80+) |
|
||||||
|
| `causal` | bool | True | Apply causal masking |
|
||||||
|
| `kdb` | int | 1 | Key downsampling factor |
|
||||||
|
| `keep_sink` | bool | False | Always attend to first token |
|
||||||
|
| `keep_recent` | bool | False | Always attend to recent tokens |
|
||||||
|
|
||||||
|
### Output
|
||||||
|
|
||||||
|
```python
|
||||||
|
returns: (attn_sums, simple_masks)
|
||||||
|
attn_sums: Tensor[float32]
|
||||||
|
Shape: (batch, num_heads, num_q_blocks, num_k_blocks_per_chunk)
|
||||||
|
Contains aggregated attention weights per block
|
||||||
|
|
||||||
|
simple_masks: Tensor[bool]
|
||||||
|
Shape: (batch, num_heads, num_q_blocks, num_k_blocks)
|
||||||
|
Boolean mask indicating which blocks to compute
|
||||||
|
```
|
||||||
|
|
||||||
|
### Algorithm
|
||||||
|
|
||||||
|
#### Step 1: Padding and Chunking
|
||||||
|
```python
|
||||||
|
# Pad sequences to chunk_size boundaries
|
||||||
|
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
|
||||||
|
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
|
||||||
|
|
||||||
|
# Compute number of blocks and chunks
|
||||||
|
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
|
||||||
|
k_block_num = (k_len + k_num_to_pad) // block_size
|
||||||
|
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
|
||||||
|
q_block_num = (q_len + q_num_to_pad) // block_size
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 2: Pattern Selection (stride-based downsampling)
|
||||||
|
|
||||||
|
**Purpose**: Reduce computation by `stride` factor using patterned selection
|
||||||
|
|
||||||
|
**Modes**:
|
||||||
|
1. **`"inverse"`** (default): Inverse stride pattern
|
||||||
|
```python
|
||||||
|
# Key: regular stride [0, stride, 2*stride, ...]
|
||||||
|
# Query: reverse stride [(stride-1), (stride-1-stride), ...]
|
||||||
|
reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)])
|
||||||
|
reshaped_query = torch.cat([query_states[:, :, (stride-1-q)::stride*kdb, :] for q in range(stride)])
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **`"slash"`**: Slash pattern (diagonal)
|
||||||
|
```python
|
||||||
|
# Both use regular stride
|
||||||
|
reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)])
|
||||||
|
reshaped_query = torch.cat([query_states[:, :, q::stride, :] for q in range(stride)])
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **`"random"`**: Random permutation
|
||||||
|
4. **`"double"`, `"triple"`**: Data augmentation modes
|
||||||
|
|
||||||
|
#### Step 3: Chunk-wise Attention Estimation
|
||||||
|
|
||||||
|
For each query chunk:
|
||||||
|
|
||||||
|
**If `use_triton=True`** (fast path):
|
||||||
|
```python
|
||||||
|
# Triton kernel 1: Compute attention scores with fused reshape
|
||||||
|
attn_weights_slice = flat_group_gemm_fuse_reshape(
|
||||||
|
query_chunk, key_states, stride,
|
||||||
|
chunk_start, chunk_end, is_causal=causal
|
||||||
|
)
|
||||||
|
|
||||||
|
# Triton kernel 2: Softmax + block aggregation
|
||||||
|
attn_sum = softmax_fuse_block_sum(
|
||||||
|
attn_weights_slice, reshaped_block_size, segment_size,
|
||||||
|
chunk_start, chunk_end, real_q_len, scale, is_causal
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**If `use_triton=False`** (PyTorch fallback):
|
||||||
|
```python
|
||||||
|
# Standard matrix multiplication
|
||||||
|
attn_weights_slice = torch.matmul(chunked_query, reshaped_key.transpose(2, 3))
|
||||||
|
|
||||||
|
# Scale and apply causal mask
|
||||||
|
attn_weights_slice = attn_weights_slice / sqrt(head_dim) / stride / norm
|
||||||
|
attn_weights_slice = attn_weights_slice + causal_mask
|
||||||
|
|
||||||
|
# Softmax
|
||||||
|
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1)
|
||||||
|
|
||||||
|
# Aggregate to block level
|
||||||
|
attn_sum = attn_weights_slice.view(
|
||||||
|
batch, heads, num_blocks_per_chunk, block_size//kdb, -1, block_size
|
||||||
|
).sum(dim=-1).sum(dim=-2)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 4: Block Selection
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Select blocks based on threshold
|
||||||
|
simple_mask = find_blocks_chunked(
|
||||||
|
attn_sum,
|
||||||
|
current_index, # Starting block index
|
||||||
|
threshold, # 0.9 = select blocks covering 90% of attention mass
|
||||||
|
None, # or num_to_choose for top-k selection
|
||||||
|
decoding=False,
|
||||||
|
mode="prefill",
|
||||||
|
causal=True
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Selection Algorithm** (`find_blocks_chunked`):
|
||||||
|
1. Sort blocks by attention weight (descending)
|
||||||
|
2. Compute cumulative sum
|
||||||
|
3. Select blocks until `cumulative_sum >= total_sum * threshold`
|
||||||
|
4. Enforce causal constraints (no future blocks)
|
||||||
|
5. Always include sink token (first block) if `keep_sink=True`
|
||||||
|
6. Always include diagonal blocks if `keep_recent=True`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Function: `Xattention_prefill()`
|
||||||
|
|
||||||
|
**Purpose**: Compute sparse attention using estimated block mask
|
||||||
|
|
||||||
|
### Input Parameters
|
||||||
|
|
||||||
|
| Parameter | Type | Default | Description |
|
||||||
|
|-----------|------|---------|-------------|
|
||||||
|
| `query_states` | Tensor | - | `(batch, num_heads, q_len, head_dim)` |
|
||||||
|
| `key_states` | Tensor | - | `(batch, num_heads, k_len, head_dim)` |
|
||||||
|
| `value_states` | Tensor | - | `(batch, num_heads, k_len, head_dim)` |
|
||||||
|
| `stride` | int | - | Downsampling stride for estimation |
|
||||||
|
| `norm` | float | 1 | Normalization factor |
|
||||||
|
| `threshold` | float | 0.8 | Block selection threshold |
|
||||||
|
| `block_size` | int | 128 | **MUST be 128** (hardcoded requirement) |
|
||||||
|
| `use_triton` | bool | True | Use Triton kernels in estimation |
|
||||||
|
| `causal` | bool | True | Apply causal masking |
|
||||||
|
| `kdb` | int | 1 | Key downsampling factor |
|
||||||
|
| `chunk_size` | int | None | Auto-computed if None |
|
||||||
|
| `keep_sink` | bool | False | Always attend to first token |
|
||||||
|
| `keep_recent` | bool | False | Always attend to recent tokens |
|
||||||
|
|
||||||
|
### Output
|
||||||
|
|
||||||
|
```python
|
||||||
|
returns: attn_output
|
||||||
|
attn_output: Tensor
|
||||||
|
Shape: (batch, num_heads, q_len, head_dim)
|
||||||
|
Sparse attention output
|
||||||
|
```
|
||||||
|
|
||||||
|
### Algorithm Flow
|
||||||
|
|
||||||
|
#### Step 1: Auto-compute chunk_size
|
||||||
|
```python
|
||||||
|
if chunk_size is None:
|
||||||
|
chunk_size = int(max(
|
||||||
|
min(
|
||||||
|
max(2048, 1 << (k_len - 1).bit_length()), # Round to power of 2
|
||||||
|
128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()), # Memory constraint
|
||||||
|
),
|
||||||
|
2048, # Minimum
|
||||||
|
))
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example**:
|
||||||
|
- `k_len=8192` → `chunk_size=8192`
|
||||||
|
- `k_len=32768` → `chunk_size=16384`
|
||||||
|
- `k_len=65536` → `chunk_size=16384`
|
||||||
|
|
||||||
|
#### Step 2: Estimate attention and select blocks
|
||||||
|
```python
|
||||||
|
attn_sums, approx_simple_mask = xattn_estimate(
|
||||||
|
query_states, key_states,
|
||||||
|
block_size=block_size, stride=stride, norm=norm,
|
||||||
|
threshold=threshold, select_mode="inverse",
|
||||||
|
use_triton=use_triton, causal=causal,
|
||||||
|
chunk_size=chunk_size, kdb=kdb,
|
||||||
|
keep_sink=keep_sink, keep_recent=keep_recent
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 3: Prepare inputs for block_sparse_attn_func
|
||||||
|
```python
|
||||||
|
# Hard constraints
|
||||||
|
assert block_size == 128
|
||||||
|
assert batch_size == 1
|
||||||
|
|
||||||
|
# Reshape to (seq_len, num_heads, head_dim)
|
||||||
|
query_states = query_states.transpose(1, 2).view(q_len, num_heads, head_dim)
|
||||||
|
key_states = key_states.transpose(1, 2).view(k_len, num_heads, head_dim)
|
||||||
|
value_states = value_states.transpose(1, 2).view(k_len, num_heads, head_dim)
|
||||||
|
|
||||||
|
# Cumulative sequence lengths
|
||||||
|
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
|
||||||
|
k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
# Head mask type (all heads use mask)
|
||||||
|
head_mask_type = torch.tensor([1 for _ in range(num_heads)], dtype=torch.int32)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 4: Call block_sparse_attn_func
|
||||||
|
```python
|
||||||
|
attn_output = block_sparse_attn_func(
|
||||||
|
query_states, # (q_len, num_heads, head_dim)
|
||||||
|
key_states, # (k_len, num_heads, head_dim)
|
||||||
|
value_states, # (k_len, num_heads, head_dim)
|
||||||
|
q_cu_seq_lens, # [0, q_len]
|
||||||
|
k_cu_seq_lens, # [0, k_len]
|
||||||
|
head_mask_type, # [1, 1, ..., 1]
|
||||||
|
None, # No custom layout
|
||||||
|
approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(), # Block mask
|
||||||
|
q_len,
|
||||||
|
k_len,
|
||||||
|
p_dropout=0.0,
|
||||||
|
deterministic=True,
|
||||||
|
is_causal=causal
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 5: Reshape output
|
||||||
|
```python
|
||||||
|
attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
|
||||||
|
# Output shape: (batch, num_heads, q_len, head_dim)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Triton Kernel Dependencies
|
||||||
|
|
||||||
|
### Kernel 1: `flat_group_gemm_fuse_reshape_kernel`
|
||||||
|
|
||||||
|
**Purpose**: Compute QK^T with stride-based reshaping
|
||||||
|
|
||||||
|
**Key Features**:
|
||||||
|
- Loads `stride` keys and queries at once
|
||||||
|
- Fused strided access pattern
|
||||||
|
- Causal masking support
|
||||||
|
- Block size auto-selection based on GPU memory
|
||||||
|
|
||||||
|
**Block Size Selection**:
|
||||||
|
```python
|
||||||
|
# RTX 3090 (<30GB): BLOCK_M=64, BLOCK_N=64
|
||||||
|
# A100/H100 (>=30GB): BLOCK_M=128, BLOCK_N=128
|
||||||
|
```
|
||||||
|
|
||||||
|
**Signature**:
|
||||||
|
```python
|
||||||
|
flat_group_gemm_fuse_reshape(
|
||||||
|
query_states, # (batch, heads, q_len, head_dim)
|
||||||
|
key_states, # (batch, heads, k_len, head_dim)
|
||||||
|
stride, # Downsampling factor
|
||||||
|
chunk_start, # Start position in keys
|
||||||
|
chunk_end, # End position in keys
|
||||||
|
is_causal=True
|
||||||
|
)
|
||||||
|
# Returns: (batch, heads, q_len//stride, k_len//stride)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Kernel 2: `softmax_fuse_block_sum_kernel_causal` / `_non_causal`
|
||||||
|
|
||||||
|
**Purpose**: Online softmax with block aggregation
|
||||||
|
|
||||||
|
**Algorithm**:
|
||||||
|
1. **Forward pass** (compute m_i, l_i):
|
||||||
|
```
|
||||||
|
m_i = max(m_i, m_local)
|
||||||
|
alpha = exp(m_i - m_new)
|
||||||
|
l_i = l_i * alpha + sum(exp(X - m_new))
|
||||||
|
```
|
||||||
|
2. **Backward pass** (compute softmax with scaling):
|
||||||
|
```
|
||||||
|
softmax = exp(X - m_i) / l_i
|
||||||
|
aggregate to blocks: sum(softmax) over block_size
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Features**:
|
||||||
|
- Single-pass softmax (no materializing full attention matrix)
|
||||||
|
- Causal masking integrated
|
||||||
|
- Outputs block-level sums directly
|
||||||
|
|
||||||
|
**Signature**:
|
||||||
|
```python
|
||||||
|
softmax_fuse_block_sum(
|
||||||
|
attn_weights_slice, # (batch, heads, q_len, k_len)
|
||||||
|
reshaped_block_size, # Block size (128//stride)
|
||||||
|
segment_size, # Processing segment (min(4096, block_size))
|
||||||
|
chunk_start, # Start position
|
||||||
|
chunk_end, # End position
|
||||||
|
real_q_len, # Actual query length (before padding)
|
||||||
|
scale, # 1.4426950408889634 / sqrt(head_dim) / stride / norm
|
||||||
|
is_causal=True
|
||||||
|
)
|
||||||
|
# Returns: (batch, heads, q_len//block_size, k_len//block_size)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Key Parameters and Their Meanings
|
||||||
|
|
||||||
|
### Critical Parameters
|
||||||
|
|
||||||
|
| Parameter | Meaning | Typical Value | Impact |
|
||||||
|
|-----------|---------|---------------|--------|
|
||||||
|
| `block_size` | Block granularity | 128 | **Fixed at 128**, affects mask granularity |
|
||||||
|
| `stride` | Downsampling factor | 4-16 | Higher = faster but less accurate |
|
||||||
|
| `threshold` | Sparsity level | 0.8-0.9 | Higher = denser mask, more computation |
|
||||||
|
| `chunk_size` | Processing chunk | 16384 | Affects memory and efficiency |
|
||||||
|
| `kdb` | Key downsampling boost | 1 | Experimental, use 1 |
|
||||||
|
| `norm` | Scaling factor | 1.0 | Attention temperature control |
|
||||||
|
|
||||||
|
### Trade-offs
|
||||||
|
|
||||||
|
**Stride (`stride`)**:
|
||||||
|
- `stride=1`: No approximation, same as dense attention
|
||||||
|
- `stride=4`: 4x faster estimation, good accuracy
|
||||||
|
- `stride=8`: 8x faster, moderate accuracy loss
|
||||||
|
- `stride=16`: 16x faster, significant accuracy loss
|
||||||
|
|
||||||
|
**Threshold (`threshold`)**:
|
||||||
|
- `threshold=0.8`: Select blocks covering 80% of attention mass (~20% sparsity)
|
||||||
|
- `threshold=0.9`: Select blocks covering 90% of attention mass (~10% sparsity)
|
||||||
|
- `threshold=0.95`: Very dense, only prunes ~5% of blocks
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Dependencies
|
||||||
|
|
||||||
|
### Required Libraries
|
||||||
|
|
||||||
|
1. **`block_sparse_attn`** (CRITICAL)
|
||||||
|
- Source: `/home/zijie/Code/COMPASS/3rdparty/flash-attention/`
|
||||||
|
- Function: `block_sparse_attn_func`
|
||||||
|
- Type: **C++ CUDA extension**
|
||||||
|
- Build: Requires compilation with `torch.utils.cpp_extension`
|
||||||
|
|
||||||
|
2. **Triton** (optional but recommended)
|
||||||
|
- Required for: `use_triton=True`
|
||||||
|
- GPU requirement: SM 80+ (A100, RTX 3090, H100, etc.)
|
||||||
|
- Check: `torch.cuda.get_device_properties().major >= 8`
|
||||||
|
|
||||||
|
3. **PyTorch**
|
||||||
|
- Version: Compatible with flash-attention
|
||||||
|
- Features: F.pad, matmul, softmax, view, transpose
|
||||||
|
|
||||||
|
### Dependency Tree
|
||||||
|
|
||||||
|
```
|
||||||
|
Xattention_prefill
|
||||||
|
├── xattn_estimate
|
||||||
|
│ ├── flat_group_gemm_fuse_reshape (Triton)
|
||||||
|
│ ├── softmax_fuse_block_sum (Triton)
|
||||||
|
│ └── find_blocks_chunked (PyTorch)
|
||||||
|
└── block_sparse_attn_func (C++ CUDA)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Integration Issues for nano-vllm
|
||||||
|
|
||||||
|
### Critical Issue 1: `block_sparse_attn_func` Dependency
|
||||||
|
|
||||||
|
**Problem**: `block_sparse_attn_func` is a **C++ CUDA extension** that must be compiled from flash-attention source.
|
||||||
|
|
||||||
|
**Options**:
|
||||||
|
1. **Compile flash-attention with block sparse support**
|
||||||
|
```bash
|
||||||
|
cd /home/zijie/Code/COMPASS/3rdparty/flash-attention
|
||||||
|
python setup.py install
|
||||||
|
```
|
||||||
|
- Risk: May conflict with existing flash-attention installation
|
||||||
|
- Complexity: High (C++ compilation)
|
||||||
|
|
||||||
|
2. **Replace with FlashInfer block sparse**
|
||||||
|
- FlashInfer is already a dependency
|
||||||
|
- Has similar block sparse attention
|
||||||
|
- Need to adapt interface
|
||||||
|
|
||||||
|
3. **Custom CUDA kernel**
|
||||||
|
- Implement simplified block sparse attention
|
||||||
|
- High development cost
|
||||||
|
- Maintenance burden
|
||||||
|
|
||||||
|
### Critical Issue 2: Hard-coded Constraints
|
||||||
|
|
||||||
|
```python
|
||||||
|
assert block_size == 128 # Line 358
|
||||||
|
assert batch_size == 1 # Line 359
|
||||||
|
```
|
||||||
|
|
||||||
|
**Impact**:
|
||||||
|
- Cannot process multiple sequences in one batch
|
||||||
|
- Fixed block size limits flexibility
|
||||||
|
- Must work around these constraints
|
||||||
|
|
||||||
|
### Critical Issue 3: Triton GPU Requirement
|
||||||
|
|
||||||
|
```python
|
||||||
|
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||||
|
if props.major < 8:
|
||||||
|
use_triton = False
|
||||||
|
```
|
||||||
|
|
||||||
|
**Impact**:
|
||||||
|
- Triton kernels only work on SM 80+ (A100, RTX 3090, H100)
|
||||||
|
- Older GPUs (V100, T4, RTX 2080) fall back to slow PyTorch implementation
|
||||||
|
- RTX 3090 works but uses smaller block sizes (64 vs 128)
|
||||||
|
|
||||||
|
### Issue 4: Memory Layout
|
||||||
|
|
||||||
|
**XAttention expects**:
|
||||||
|
```python
|
||||||
|
query_states: (batch, num_heads, q_len, head_dim)
|
||||||
|
```
|
||||||
|
|
||||||
|
**nano-vllm uses**:
|
||||||
|
```python
|
||||||
|
query_states: (num_heads, total_tokens, head_dim) # Flattened batch
|
||||||
|
```
|
||||||
|
|
||||||
|
**Required**: Transpose and reshape before/after calling XAttention
|
||||||
|
|
||||||
|
### Issue 5: Chunking Incompatibility
|
||||||
|
|
||||||
|
**XAttention**: Processes in fixed-size chunks (e.g., 16384 tokens)
|
||||||
|
- Requires padding to chunk boundaries
|
||||||
|
- Adds overhead for short sequences
|
||||||
|
|
||||||
|
**nano-vllm**: Processes variable-length requests
|
||||||
|
- No padding requirement
|
||||||
|
- Dynamic batch sizing
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Integration Strategy
|
||||||
|
|
||||||
|
### Recommended Approach: **Wrapper with FlashInfer**
|
||||||
|
|
||||||
|
1. **Keep `xattn_estimate`** (pure PyTorch + Triton)
|
||||||
|
- No external dependencies
|
||||||
|
- Computes block mask
|
||||||
|
|
||||||
|
2. **Replace `block_sparse_attn_func` with FlashInfer**
|
||||||
|
- FlashInfer: `flashinfer.single_prefill_with_kv_cache`
|
||||||
|
- Similar API, already compiled
|
||||||
|
- Supports block sparse
|
||||||
|
|
||||||
|
3. **Adapt mask format**
|
||||||
|
- XAttention: `(batch, heads, q_blocks, k_blocks)` boolean mask
|
||||||
|
- FlashInfer: `(num_qo, num_kv)` boolean mask or custom format
|
||||||
|
|
||||||
|
4. **Handle constraints**
|
||||||
|
- Enforce `batch_size=1` by processing one request at a time
|
||||||
|
- Keep `block_size=128` as requirement
|
||||||
|
|
||||||
|
### Alternative: **Pure PyTorch Implementation**
|
||||||
|
|
||||||
|
1. Extract estimation algorithm
|
||||||
|
2. Implement sparse attention using PyTorch operations
|
||||||
|
3. Use FlashInfer for final computation
|
||||||
|
4. No Triton dependency
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Code Example: Adaptation
|
||||||
|
|
||||||
|
```python
|
||||||
|
def xattention_prefill_adapted(
|
||||||
|
query_states, # (num_heads, q_len, head_dim)
|
||||||
|
key_states, # (num_heads, k_len, head_dim)
|
||||||
|
value_states, # (num_heads, k_len, head_dim)
|
||||||
|
stride=4,
|
||||||
|
threshold=0.9,
|
||||||
|
block_size=128,
|
||||||
|
causal=True,
|
||||||
|
):
|
||||||
|
# Step 1: Add batch dimension
|
||||||
|
q = query_states.unsqueeze(0) # (1, heads, q_len, dim)
|
||||||
|
k = key_states.unsqueeze(0)
|
||||||
|
v = value_states.unsqueeze(0)
|
||||||
|
|
||||||
|
# Step 2: Estimate mask (no external dependency)
|
||||||
|
_, block_mask = xattn_estimate(
|
||||||
|
q, k,
|
||||||
|
block_size=block_size,
|
||||||
|
stride=stride,
|
||||||
|
threshold=threshold,
|
||||||
|
use_triton=True,
|
||||||
|
causal=causal,
|
||||||
|
)
|
||||||
|
# block_mask: (1, heads, q_blocks, k_blocks)
|
||||||
|
|
||||||
|
# Step 3: Convert block mask to token mask
|
||||||
|
q_blocks, k_blocks = block_mask.shape[-2:]
|
||||||
|
token_mask = block_mask.repeat_interleave(block_size, dim=-2)
|
||||||
|
token_mask = token_mask.repeat_interleave(block_size, dim=-1)
|
||||||
|
token_mask = token_mask[:, :, :q.size(2), :k.size(2)] # Trim padding
|
||||||
|
|
||||||
|
# Step 4: Use FlashInfer with mask
|
||||||
|
from flashinfer import single_prefill_with_kv_cache
|
||||||
|
output = single_prefill_with_kv_cache(
|
||||||
|
q.squeeze(0),
|
||||||
|
k.squeeze(0),
|
||||||
|
v.squeeze(0),
|
||||||
|
custom_mask=token_mask.squeeze(0),
|
||||||
|
)
|
||||||
|
|
||||||
|
return output # (num_heads, q_len, head_dim)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Summary of Findings
|
||||||
|
|
||||||
|
### Advantages
|
||||||
|
|
||||||
|
1. **Accurate approximation**: Pattern-based stride selection preserves attention patterns
|
||||||
|
2. **Flexible sparsity**: Threshold-based control over computation
|
||||||
|
3. **GPU optimization**: Triton kernels for estimation phase
|
||||||
|
4. **Proven in practice**: Used in COMPASS system
|
||||||
|
|
||||||
|
### Challenges
|
||||||
|
|
||||||
|
1. **Hard dependency**: `block_sparse_attn_func` requires C++ compilation
|
||||||
|
2. **Rigid constraints**: `block_size=128`, `batch_size=1`
|
||||||
|
3. **GPU-specific**: Triton only on SM 80+
|
||||||
|
4. **Memory layout mismatch**: Requires reshape/transpose
|
||||||
|
5. **Chunking overhead**: Padding to chunk boundaries
|
||||||
|
|
||||||
|
### Integration Complexity
|
||||||
|
|
||||||
|
| Component | Complexity | Risk |
|
||||||
|
|-----------|------------|------|
|
||||||
|
| `xattn_estimate` | Medium | Low (PyTorch + Triton) |
|
||||||
|
| `block_sparse_attn_func` | High | **Critical** (C++ dependency) |
|
||||||
|
| Interface adaptation | Low | Low (reshape) |
|
||||||
|
| Constraint handling | Medium | Medium (workarounds) |
|
||||||
|
|
||||||
|
**Overall Integration Risk**: **HIGH** (due to C++ dependency)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. Next Steps
|
||||||
|
|
||||||
|
1. **Evaluate FlashInfer compatibility**
|
||||||
|
- Can FlashInfer replace `block_sparse_attn_func`?
|
||||||
|
- What mask format does it expect?
|
||||||
|
|
||||||
|
2. **Prototype estimation phase**
|
||||||
|
- Extract `xattn_estimate` function
|
||||||
|
- Test with nano-vllm inputs
|
||||||
|
- Validate mask quality
|
||||||
|
|
||||||
|
3. **Benchmark Triton kernels**
|
||||||
|
- Compare Triton vs PyTorch estimation
|
||||||
|
- Measure speedup on RTX 3090
|
||||||
|
- Profile memory usage
|
||||||
|
|
||||||
|
4. **Design interface**
|
||||||
|
- Define nano-vllm sparse attention API
|
||||||
|
- Specify mask format
|
||||||
|
- Plan integration points
|
||||||
961
docs/xattention_integration.md
Normal file
961
docs/xattention_integration.md
Normal file
@@ -0,0 +1,961 @@
|
|||||||
|
# XAttention 集成指南
|
||||||
|
|
||||||
|
本文档详细记录了将 COMPASS 的 XAttention 算法集成到 nano-vllm 的完整过程,包括算法原理、源码分析、设计决策、实现细节和测试验证。
|
||||||
|
|
||||||
|
## 目录
|
||||||
|
|
||||||
|
1. [背景](#1-背景)
|
||||||
|
2. [XAttention 算法原理](#2-xattention-算法原理)
|
||||||
|
3. [COMPASS 源码分析](#3-compass-源码分析)
|
||||||
|
4. [集成设计决策](#4-集成设计决策)
|
||||||
|
5. [实现细节](#5-实现细节)
|
||||||
|
6. [问题与解决方案](#6-问题与解决方案)
|
||||||
|
7. [测试验证](#7-测试验证)
|
||||||
|
8. [使用指南](#8-使用指南)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 背景
|
||||||
|
|
||||||
|
### 1.1 为什么需要 XAttention
|
||||||
|
|
||||||
|
- **长上下文推理需求**:随着 LLM 上下文长度扩展到 32k、64k 甚至更长,传统注意力机制的计算复杂度 O(n²) 成为瓶颈
|
||||||
|
- **COMPASS 算法**:通过 chunked estimation 和 block sparse attention 实现 O(n) 复杂度
|
||||||
|
- **nano-vllm 集成目标**:在 CPU offload 模式下支持高效的长上下文推理
|
||||||
|
|
||||||
|
### 1.2 集成范围
|
||||||
|
|
||||||
|
**仅关注 offload 执行路径**:
|
||||||
|
- `run_layerwise_offload_prefill()` - layer-wise chunked prefill
|
||||||
|
- CPU offload 模式下的 KV cache 管理
|
||||||
|
- 与 `SparsePolicy` 框架的集成
|
||||||
|
|
||||||
|
### 1.3 参考
|
||||||
|
|
||||||
|
- COMPASS 源码:`/home/zijie/Code/COMPASS/compass/src/`
|
||||||
|
- 关键文件:`Xattention.py`, `kernels.py`, `utils.py`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. XAttention 算法原理
|
||||||
|
|
||||||
|
### 2.1 两阶段设计
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ XAttention 流程 │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ │
|
||||||
|
│ Phase 1: Chunked Estimation │
|
||||||
|
│ ┌─────────────┐ ┌──────────────┐ ┌─────────────┐ │
|
||||||
|
│ │ Query Chunk │ -> │ Triton GEMM │ -> │ Attn Scores │ │
|
||||||
|
│ │ (stride=8) │ │ (fused) │ │ (per block) │ │
|
||||||
|
│ └─────────────┘ └──────────────┘ └─────────────┘ │
|
||||||
|
│ ↓ │
|
||||||
|
│ ┌─────────────┐ │
|
||||||
|
│ │ Block Mask │ │
|
||||||
|
│ │ (threshold) │ │
|
||||||
|
│ └─────────────┘ │
|
||||||
|
│ │
|
||||||
|
│ Phase 2: Block Sparse Attention │
|
||||||
|
│ ┌─────────────┐ ┌──────────────┐ ┌─────────────┐ │
|
||||||
|
│ │ Selected Q │ -> │ Block Sparse │ -> │ Output │ │
|
||||||
|
│ │ + Selected K│ │ Attention │ │ │ │
|
||||||
|
│ └─────────────┘ └──────────────┘ └─────────────┘ │
|
||||||
|
│ │
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.2 关键参数
|
||||||
|
|
||||||
|
| 参数 | 默认值 | 说明 |
|
||||||
|
|------|--------|------|
|
||||||
|
| `stride` | 8 | Q/K 重组步长 |
|
||||||
|
| `block_size` | 128 | Block 大小(tokens) |
|
||||||
|
| `threshold` | 0.9 | Block 选择阈值 (0-1) |
|
||||||
|
| `chunk_size` | 16384 | Estimation chunk 大小 |
|
||||||
|
|
||||||
|
### 2.3 计算流程
|
||||||
|
|
||||||
|
1. **Chunked Estimation**:
|
||||||
|
- 将 Q 分成固定大小的 chunks
|
||||||
|
- 使用 Triton kernels 计算 QK^T(fused GEMM + reshape)
|
||||||
|
- 分块 softmax 并聚合到 block 级别
|
||||||
|
- 根据阈值选择重要 blocks
|
||||||
|
|
||||||
|
2. **Block Sparse Attention**:
|
||||||
|
- 只计算选中 blocks 的注意力
|
||||||
|
- 使用 block sparse kernels 优化
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. COMPASS 源码分析
|
||||||
|
|
||||||
|
### 3.1 核心文件结构
|
||||||
|
|
||||||
|
```
|
||||||
|
COMPASS/compass/src/
|
||||||
|
├── Xattention.py # XAttention 主算法
|
||||||
|
├── kernels.py # Triton kernels
|
||||||
|
├── utils.py # 辅助函数
|
||||||
|
└── block_sparse.py # Block sparse attention
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.2 Xattention.py 分析
|
||||||
|
|
||||||
|
**核心函数**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def xattn_estimate(
|
||||||
|
query_states, key_states, value_states,
|
||||||
|
stride, block_size, threshold, ...
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Phase 1: 估算稀疏注意力模式
|
||||||
|
|
||||||
|
返回:
|
||||||
|
attn_sums: [batch, heads, q_blocks, k_blocks] 重要性分数
|
||||||
|
simple_masks: [batch, heads, q_blocks, k_blocks] 布尔掩码
|
||||||
|
"""
|
||||||
|
# 1. Pad inputs to chunk_size multiples
|
||||||
|
# 2. Reshape with stride
|
||||||
|
# 3. Compute QK^T in chunks (Triton)
|
||||||
|
# 4. Block-wise softmax + aggregation
|
||||||
|
# 5. Threshold-based selection
|
||||||
|
return attn_sums, simple_masks
|
||||||
|
|
||||||
|
|
||||||
|
def Xattention_prefill(
|
||||||
|
query_states, key_states, value_states,
|
||||||
|
stride, threshold, ...
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
完整 XAttention prefill
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. xattn_estimate() - 获取 block mask
|
||||||
|
2. block_sparse_attn_func() - 稀疏注意力计算
|
||||||
|
"""
|
||||||
|
attn_sums, simple_masks = xattn_estimate(...)
|
||||||
|
attn_output = block_sparse_attn_func(
|
||||||
|
query_states, key_states, value_states,
|
||||||
|
simple_masks, block_size
|
||||||
|
)
|
||||||
|
return attn_output
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.3 kernels.py 分析
|
||||||
|
|
||||||
|
**Triton Kernels**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@triton.jit
|
||||||
|
def flat_group_gemm_fuse_reshape_kernel(Q, K, Out, ...):
|
||||||
|
"""
|
||||||
|
Stride-based GEMM with reshape fusion
|
||||||
|
|
||||||
|
关键优化:
|
||||||
|
- Stride 访问模式:每隔 stride 个 token 访问一次
|
||||||
|
- Fused reshape:避免单独的 reshape 操作
|
||||||
|
- Block-level 并行:M×N block tiling
|
||||||
|
"""
|
||||||
|
# Load Q and K with stride
|
||||||
|
for iter in range(STRIDE):
|
||||||
|
q = tl.load(Q_ptrs - iter * stride_qn)
|
||||||
|
k = tl.load(K_ptrs + iter * stride_kn)
|
||||||
|
o += tl.dot(q, k)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def softmax_fuse_block_sum_kernel_causal(In, Out, ...):
|
||||||
|
"""
|
||||||
|
Block-wise softmax with sum aggregation
|
||||||
|
|
||||||
|
关键优化:
|
||||||
|
- Online softmax:避免存储完整注意力矩阵
|
||||||
|
- Block sum:聚合到 block 级别
|
||||||
|
- Causal mask:支持因果注意力
|
||||||
|
"""
|
||||||
|
# Online softmax (m_i, l_i)
|
||||||
|
m_new = tl.maximum(m_i, m_local)
|
||||||
|
alpha = tl.math.exp2(m_i - m_new)
|
||||||
|
l_i = l_i * alpha + l_local
|
||||||
|
m_i = m_new
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.4 utils.py 分析
|
||||||
|
|
||||||
|
**关键函数**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def find_blocks_chunked(
|
||||||
|
input_tensor, # [batch, heads, chunk_q, block_k]
|
||||||
|
current_index,
|
||||||
|
threshold, # 0-1
|
||||||
|
num_to_choose,
|
||||||
|
decoding,
|
||||||
|
mode,
|
||||||
|
causal
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
基于阈值选择重要 blocks
|
||||||
|
|
||||||
|
返回:
|
||||||
|
boolean mask: [batch, heads, chunk_q, block_k]
|
||||||
|
"""
|
||||||
|
# 1. 计算阈值分数
|
||||||
|
score_threshold = input_tensor.max() * threshold
|
||||||
|
|
||||||
|
# 2. 生成布尔掩码
|
||||||
|
masks = (input_tensor >= score_threshold)
|
||||||
|
|
||||||
|
# 3. 应用因果约束
|
||||||
|
if causal:
|
||||||
|
# 只保留下三角区域
|
||||||
|
...
|
||||||
|
|
||||||
|
return masks
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 集成设计决策
|
||||||
|
|
||||||
|
### 4.1 稀疏策略框架
|
||||||
|
|
||||||
|
nano-vllm 使用 `SparsePolicy` 抽象接口:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class SparsePolicy(ABC):
|
||||||
|
"""稀疏注意力策略基类"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_prefill(self) -> bool:
|
||||||
|
"""是否支持 prefill 阶段"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_decode(self) -> bool:
|
||||||
|
"""是否支持 decode 阶段"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_block_selection(self) -> bool:
|
||||||
|
"""是否需要 block selection(用于 KV cache 加载)"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def select_blocks(self, available_blocks, ctx) -> List[int]:
|
||||||
|
"""选择要加载的 KV blocks"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def sparse_prefill_attention(self, q, k, v, layer_id) -> torch.Tensor:
|
||||||
|
"""计算稀疏 prefill 注意力"""
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 XAttention 设计决策
|
||||||
|
|
||||||
|
#### 决策 1:Prefill-Only 策略
|
||||||
|
|
||||||
|
```python
|
||||||
|
class XAttentionPolicy(SparsePolicy):
|
||||||
|
supports_prefill = True
|
||||||
|
supports_decode = False # XAttention 仅用于 prefill
|
||||||
|
requires_block_selection = False # 不影响 KV cache 加载
|
||||||
|
```
|
||||||
|
|
||||||
|
**原因**:
|
||||||
|
- XAttention 是 prefill 阶段的优化算法
|
||||||
|
- Decode 阶段使用其他策略(如 QUEST)
|
||||||
|
- Block selection 不在 XAttention 范围内
|
||||||
|
|
||||||
|
#### 决策 2:CPU Offload 模式简化
|
||||||
|
|
||||||
|
```python
|
||||||
|
def sparse_prefill_attention(self, q, k, v, layer_id):
|
||||||
|
# 使用 FlashAttention 直接计算
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
attn_output = flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
max_seqlen_q=seq_len,
|
||||||
|
max_seqlen_k=seq_len,
|
||||||
|
softmax_scale=1.0 / math.sqrt(head_dim),
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
return attn_output
|
||||||
|
```
|
||||||
|
|
||||||
|
**关键原因**:
|
||||||
|
|
||||||
|
1. **Chunked Prefill 架构限制**:
|
||||||
|
```
|
||||||
|
Offload 模式: run_layerwise_offload_prefill()
|
||||||
|
└─ 每次只处理一个 chunk (2048 tokens)
|
||||||
|
└─ 完整的 key_states 在 CPU,不在当前调用栈
|
||||||
|
└─ 无法进行完整的 chunked estimation
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Estimation 需要完整上下文**:
|
||||||
|
- XAttention 的 estimation 需要访问完整 key_states
|
||||||
|
- Offload 模式下 keys 分层存储在 CPU
|
||||||
|
- 传递所有 keys 会破坏 offload 的内存优势
|
||||||
|
|
||||||
|
3. **FlashAttention 原生支持 GQA**:
|
||||||
|
- GQA (Grouped Query Attention): num_kv_heads < num_heads
|
||||||
|
- FlashAttention 自动处理 head 展开
|
||||||
|
- 避免手动实现的复杂性
|
||||||
|
|
||||||
|
#### 决策 3:保留 Triton Kernels
|
||||||
|
|
||||||
|
虽然 CPU offload 模式使用 FlashAttention,但仍保留 Triton kernels:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# nanovllm/kvcache/sparse/kernels.py
|
||||||
|
# 保留完整的 Triton 实现,供未来 GPU-only 模式使用
|
||||||
|
|
||||||
|
def softmax_fuse_block_sum(attn_weights_slice, ...):
|
||||||
|
"""Triton softmax + block sum wrapper"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def flat_group_gemm_fuse_reshape(query_states, key_states, ...):
|
||||||
|
"""Triton GEMM + reshape wrapper"""
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
**原因**:
|
||||||
|
- 未来可以支持 GPU-only 模式的完整 XAttention
|
||||||
|
- Triton kernels 已实现,无需删除
|
||||||
|
- 保持代码完整性
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 实现细节
|
||||||
|
|
||||||
|
### 5.1 文件结构
|
||||||
|
|
||||||
|
```
|
||||||
|
nanovllm/kvcache/sparse/
|
||||||
|
├── __init__.py # 策略注册
|
||||||
|
├── policy.py # 基类定义
|
||||||
|
├── full_policy.py # Full attention 策略
|
||||||
|
├── quest.py # Quest 策略
|
||||||
|
├── minference.py # MInference 策略
|
||||||
|
├── xattn.py # XAttention 策略(新增)
|
||||||
|
├── utils.py # 工具函数(新增)
|
||||||
|
└── kernels.py # Triton kernels(新增)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 utils.py 实现
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""
|
||||||
|
Sparse attention utility functions.
|
||||||
|
Copied and adapted from COMPASS/compass/src/utils.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def find_blocks_chunked(
|
||||||
|
input_tensor,
|
||||||
|
current_index,
|
||||||
|
threshold,
|
||||||
|
num_to_choose,
|
||||||
|
decoding: bool,
|
||||||
|
mode: str = "both",
|
||||||
|
causal=True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Select blocks based on threshold.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: [batch, heads, q_blocks, k_blocks] importance scores
|
||||||
|
current_index: Current chunk index
|
||||||
|
threshold: Block selection threshold (0-1)
|
||||||
|
num_to_choose: Number of blocks to choose (if None, use threshold)
|
||||||
|
decoding: Whether in decode mode
|
||||||
|
mode: Selection mode ("prefill", "decoding", "both")
|
||||||
|
causal: Apply causal mask
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
boolean mask: [batch, heads, q_blocks, k_blocks]
|
||||||
|
"""
|
||||||
|
batch_size, head_num, chunk_q, block_k = input_tensor.shape
|
||||||
|
|
||||||
|
if num_to_choose is None:
|
||||||
|
# Threshold-based selection
|
||||||
|
score_threshold = input_tensor.max() * threshold
|
||||||
|
masks = (input_tensor >= score_threshold)
|
||||||
|
else:
|
||||||
|
# Top-k selection
|
||||||
|
topk_values, _ = torch.topk(
|
||||||
|
input_tensor.flatten(start_dim=2),
|
||||||
|
k=num_to_choose,
|
||||||
|
dim=-1
|
||||||
|
)
|
||||||
|
score_threshold = topk_values[..., -1:].unsqueeze(-1)
|
||||||
|
masks = (input_tensor >= score_threshold)
|
||||||
|
|
||||||
|
# Causal mask
|
||||||
|
if causal and chunk_q > 1:
|
||||||
|
for q_idx in range(chunk_q):
|
||||||
|
k_start = current_index + q_idx
|
||||||
|
masks[:, :, q_idx, :k_start] = False
|
||||||
|
|
||||||
|
return masks
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.3 kernels.py 实现
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""
|
||||||
|
Triton kernels for XAttention sparse attention.
|
||||||
|
|
||||||
|
Copied and adapted from COMPASS/compass/src/kernels.py
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- Triton >= 2.1.0
|
||||||
|
- CUDA compute capability SM 80+ (RTX 3090, A100, H100, etc.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def softmax_fuse_block_sum_kernel_causal(
|
||||||
|
In, Out, scale,
|
||||||
|
input_stride_0, input_stride_1, input_stride_2,
|
||||||
|
output_stride_0, output_stride_1, output_stride_2,
|
||||||
|
real_q_len, k_len, chunk_start, chunk_end,
|
||||||
|
segment_size: tl.constexpr,
|
||||||
|
block_size: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Causal softmax with block sum aggregation.
|
||||||
|
|
||||||
|
Online softmax algorithm:
|
||||||
|
m_i = max(m_i, m_new)
|
||||||
|
l_i = l_i * exp(m_i - m_new) + l_new
|
||||||
|
"""
|
||||||
|
block_id = tl.program_id(0)
|
||||||
|
head_id = tl.program_id(1)
|
||||||
|
batch_id = tl.program_id(2)
|
||||||
|
|
||||||
|
# ... (完整实现见源码)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def flat_group_gemm_fuse_reshape_kernel(
|
||||||
|
Q, K, Out,
|
||||||
|
stride_qz, stride_qh, stride_qn,
|
||||||
|
stride_kz, stride_kh, stride_kn,
|
||||||
|
stride_oz, stride_oh, stride_on,
|
||||||
|
chunk_start, chunk_end,
|
||||||
|
H: tl.constexpr,
|
||||||
|
STRIDE: tl.constexpr,
|
||||||
|
HEAD_DIM: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
is_causal: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Stride-based GEMM with reshape fusion.
|
||||||
|
"""
|
||||||
|
# ... (完整实现见源码)
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size,
|
||||||
|
segment_size, chunk_start, chunk_end,
|
||||||
|
real_q_len, scale, is_causal=True):
|
||||||
|
"""Wrapper for Triton softmax-fuse-block-sum kernel."""
|
||||||
|
# ... (完整实现见源码)
|
||||||
|
|
||||||
|
|
||||||
|
def flat_group_gemm_fuse_reshape(query_states, key_states, stride,
|
||||||
|
chunk_start, chunk_end, is_causal=True):
|
||||||
|
"""Wrapper for Triton flat-group-gemm-fuse-reshape kernel."""
|
||||||
|
# ... (完整实现见源码)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.4 xattn.py 实现
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""
|
||||||
|
XAttention sparse attention policy for nano-vllm.
|
||||||
|
|
||||||
|
Implements the XAttention algorithm from COMPASS, using chunked estimation
|
||||||
|
and block sparse attention for efficient long-context inference.
|
||||||
|
|
||||||
|
Reference: COMPASS/compass/src/Xattention.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List, Optional
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||||
|
from nanovllm.kvcache.sparse.kernels import (
|
||||||
|
flat_group_gemm_fuse_reshape,
|
||||||
|
softmax_fuse_block_sum,
|
||||||
|
)
|
||||||
|
from nanovllm.kvcache.sparse.utils import find_blocks_chunked
|
||||||
|
|
||||||
|
|
||||||
|
class XAttentionPolicy(SparsePolicy):
|
||||||
|
"""
|
||||||
|
XAttention sparse prefill policy using chunked estimation + block sparse attention.
|
||||||
|
|
||||||
|
Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
supports_prefill = True
|
||||||
|
supports_decode = False # XAttention is prefill-only
|
||||||
|
requires_block_selection = False # Only affects attention computation
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stride: int = 8,
|
||||||
|
threshold: float = 0.9,
|
||||||
|
chunk_size: Optional[int] = None,
|
||||||
|
use_triton: bool = True,
|
||||||
|
keep_sink: bool = False,
|
||||||
|
keep_recent: bool = False,
|
||||||
|
norm: float = 1.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize XAttention policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stride: Stride for reorganizing Q/K (default: 8)
|
||||||
|
threshold: Block selection threshold, 0-1 (default: 0.9)
|
||||||
|
chunk_size: Chunk size for estimation (auto if None)
|
||||||
|
use_triton: Use Triton kernels (requires SM 80+)
|
||||||
|
keep_sink: Always keep first block (sink tokens)
|
||||||
|
keep_recent: Always keep recent diagonal blocks
|
||||||
|
norm: Normalization factor for attention scores
|
||||||
|
"""
|
||||||
|
self.stride = stride
|
||||||
|
self.threshold = threshold
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.use_triton = use_triton
|
||||||
|
self.keep_sink = keep_sink
|
||||||
|
self.keep_recent = keep_recent
|
||||||
|
self.norm = norm
|
||||||
|
|
||||||
|
# Check Triton availability
|
||||||
|
if self.use_triton:
|
||||||
|
try:
|
||||||
|
import triton
|
||||||
|
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||||
|
if props.major < 8:
|
||||||
|
self.use_triton = False
|
||||||
|
print(f"XAttention: Triton requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.")
|
||||||
|
except ImportError:
|
||||||
|
self.use_triton = False
|
||||||
|
print("XAttention: Triton not available. Falling back to PyTorch.")
|
||||||
|
|
||||||
|
def select_blocks(
|
||||||
|
self,
|
||||||
|
available_blocks: List[int],
|
||||||
|
ctx: PolicyContext,
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Select blocks for decode phase.
|
||||||
|
|
||||||
|
XAttention is prefill-only, so this method is only used as a fallback.
|
||||||
|
Returns all available blocks by default.
|
||||||
|
"""
|
||||||
|
# XAttention is prefill-only, but we need to implement this abstract method
|
||||||
|
# Since requires_block_selection=False, this won't be called for loading
|
||||||
|
return available_blocks
|
||||||
|
|
||||||
|
def sparse_prefill_attention(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute XAttention sparse attention for prefill.
|
||||||
|
|
||||||
|
For CPU offload mode, uses FlashAttention directly with native GQA support.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: Current transformer layer index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
seq_len = q.shape[0]
|
||||||
|
num_heads = q.shape[1]
|
||||||
|
head_dim = q.shape[2]
|
||||||
|
num_kv_heads = k.shape[1]
|
||||||
|
|
||||||
|
# Use FlashAttention directly for CPU offload mode
|
||||||
|
# FlashAttention supports GQA natively
|
||||||
|
try:
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
attn_output = flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
max_seqlen_q=seq_len,
|
||||||
|
max_seqlen_k=seq_len,
|
||||||
|
softmax_scale=1.0 / math.sqrt(head_dim),
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Fallback: PyTorch SDPA (supports GQA natively)
|
||||||
|
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
|
||||||
|
attn_output = F.scaled_dot_product_attention(
|
||||||
|
q, k, v,
|
||||||
|
attn_mask=None,
|
||||||
|
is_causal=True,
|
||||||
|
scale=1.0 / math.sqrt(head_dim)
|
||||||
|
)
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset policy state (no state to reset for XAttention)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (f"XAttentionPolicy("
|
||||||
|
f"stride={self.stride}, "
|
||||||
|
f"threshold={self.threshold}, "
|
||||||
|
f"use_triton={self.use_triton})")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.5 框架集成
|
||||||
|
|
||||||
|
**config.py - 添加配置参数**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class SparsePolicyType(Enum):
|
||||||
|
"""Sparse attention policy types."""
|
||||||
|
FULL = auto()
|
||||||
|
QUEST = auto()
|
||||||
|
MINFERENCE = auto()
|
||||||
|
XATTN = auto() # 新增
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Config:
|
||||||
|
# ... 其他配置
|
||||||
|
|
||||||
|
# XAttention configuration
|
||||||
|
xattn_stride: int = 8
|
||||||
|
xattn_threshold: float = 0.9
|
||||||
|
xattn_chunk_size: int = 16384
|
||||||
|
xattn_use_triton: bool = True
|
||||||
|
xattn_keep_sink: bool = False
|
||||||
|
xattn_keep_recent: bool = False
|
||||||
|
xattn_norm: float = 1.0
|
||||||
|
```
|
||||||
|
|
||||||
|
**__init__.py - 注册策略**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
||||||
|
if policy_type == SparsePolicyType.XATTN:
|
||||||
|
return XAttentionPolicy(
|
||||||
|
stride=kwargs.get("stride", 8),
|
||||||
|
threshold=kwargs.get("threshold", 0.9),
|
||||||
|
chunk_size=kwargs.get("chunk_size", 16384),
|
||||||
|
use_triton=kwargs.get("use_triton", True),
|
||||||
|
keep_sink=kwargs.get("keep_sink", False),
|
||||||
|
keep_recent=kwargs.get("keep_recent", False),
|
||||||
|
norm=kwargs.get("norm", 1.0),
|
||||||
|
)
|
||||||
|
# ... 其他策略
|
||||||
|
```
|
||||||
|
|
||||||
|
**model_runner.py - 使用策略**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 在 SparsePolicy 初始化时自动选择
|
||||||
|
if self.config.sparse_policy == SparsePolicyType.XATTN:
|
||||||
|
self.sparse_prefill_policy = XAttentionPolicy(...)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 问题与解决方案
|
||||||
|
|
||||||
|
### 6.1 问题 1: Abstract Method Not Implemented
|
||||||
|
|
||||||
|
**错误**:
|
||||||
|
```python
|
||||||
|
TypeError: Can't instantiate abstract class XAttentionPolicy
|
||||||
|
with abstract method select_blocks
|
||||||
|
```
|
||||||
|
|
||||||
|
**原因**:
|
||||||
|
- `SparsePolicy` 是抽象基类,要求子类实现 `select_blocks()`
|
||||||
|
- XAttention 是 prefill-only 策略,不需要 block selection
|
||||||
|
|
||||||
|
**解决**:
|
||||||
|
```python
|
||||||
|
def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
|
||||||
|
"""
|
||||||
|
Select blocks for decode phase.
|
||||||
|
|
||||||
|
XAttention is prefill-only, so this method is only used as a fallback.
|
||||||
|
Returns all available blocks by default.
|
||||||
|
"""
|
||||||
|
# Since requires_block_selection=False, this won't be called for loading
|
||||||
|
return available_blocks
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 问题 2: CUDA OOM During Estimation
|
||||||
|
|
||||||
|
**错误**:
|
||||||
|
```
|
||||||
|
CUDA out of memory. Tried to allocate 1013.92 GiB
|
||||||
|
```
|
||||||
|
|
||||||
|
**原因**:
|
||||||
|
- `_xattn_estimate()` 使用 `q_len` 计算 `k_block_num`
|
||||||
|
- 但在 chunked prefill 中,`q_len` 是当前 chunk 大小(2048)
|
||||||
|
- 而不是完整上下文长度(32768)
|
||||||
|
- 导致 padding 计算错误
|
||||||
|
|
||||||
|
**原始代码问题**:
|
||||||
|
```python
|
||||||
|
batch_size, num_heads, k_len, head_dim = key_states.shape
|
||||||
|
batch_size, num_heads, q_len, head_dim = query_states.shape
|
||||||
|
|
||||||
|
# 错误:使用 q_len 计算 k_block_num
|
||||||
|
k_block_num = (k_len + k_num_to_pad) // block_size # 应该用完整 k_len
|
||||||
|
```
|
||||||
|
|
||||||
|
**解决**:
|
||||||
|
简化实现,直接使用 FlashAttention:
|
||||||
|
```python
|
||||||
|
def sparse_prefill_attention(self, q, k, v, layer_id):
|
||||||
|
# 使用 FlashAttention 直接计算
|
||||||
|
# 不进行 chunked estimation(与 offload 架构不兼容)
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.3 问题 3: GQA Head Count Mismatch
|
||||||
|
|
||||||
|
**错误**:
|
||||||
|
```
|
||||||
|
ValueError: Number of heads in key/value must divide number of heads in query
|
||||||
|
```
|
||||||
|
|
||||||
|
**原因**:
|
||||||
|
- Llama-3.1-8B 使用 GQA:num_heads=32, num_kv_heads=8
|
||||||
|
- 原始 XAttention 代码手动展开 KV heads:
|
||||||
|
```python
|
||||||
|
# 错误方式
|
||||||
|
if num_kv_heads != num_heads:
|
||||||
|
key_states = key_states.repeat_interleave(num_heads // num_kv_heads, dim=1)
|
||||||
|
```
|
||||||
|
|
||||||
|
**解决**:
|
||||||
|
依赖 FlashAttention 的原生 GQA 支持:
|
||||||
|
```python
|
||||||
|
# FlashAttention 自动处理 GQA,无需手动展开
|
||||||
|
attn_output = flash_attn_varlen_func(
|
||||||
|
q, k, v, # k, v 可以有更少的 heads
|
||||||
|
...
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.4 Bug Fix: kernels.py Line 106
|
||||||
|
|
||||||
|
**原始代码**:
|
||||||
|
```python
|
||||||
|
for iter in range(num_iters_before_causal + 1, num_iters):
|
||||||
|
X = torch.zeros([segment_size // block_size], dtype=torch.float32) # 错误
|
||||||
|
```
|
||||||
|
|
||||||
|
**修复**:
|
||||||
|
```python
|
||||||
|
for iter in range(num_iters_before_causal + 1, num_iters):
|
||||||
|
X = tl.zeros([segment_size // block_size], dtype=torch.float32) # 正确
|
||||||
|
```
|
||||||
|
|
||||||
|
**原因**:
|
||||||
|
- Triton JIT kernel 中必须使用 `tl.zeros` 而不是 `torch.zeros`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 测试验证
|
||||||
|
|
||||||
|
### 7.1 测试环境
|
||||||
|
|
||||||
|
- **模型**: Llama-3.1-8B-Instruct
|
||||||
|
- **GPU**: RTX 3090 (24GB)
|
||||||
|
- **数据集**: RULER 32k benchmark
|
||||||
|
- **模式**: CPU offload enabled
|
||||||
|
|
||||||
|
### 7.2 测试命令
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# NIAH 任务测试
|
||||||
|
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--enable-offload \
|
||||||
|
--sparse-policy XATTN \
|
||||||
|
--num-samples 3 \
|
||||||
|
--datasets niah_single_1,niah_multikey_1,niah_multiquery,niah_multivalue \
|
||||||
|
--max-model-len 32896
|
||||||
|
|
||||||
|
# QA/Recall 任务测试(并行运行)
|
||||||
|
CUDA_VISIBLE_DEVICES=5 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--enable-offload \
|
||||||
|
--sparse-policy XATTN \
|
||||||
|
--num-samples 3 \
|
||||||
|
--datasets qa_1,qa_2,vt,cwe,fwe \
|
||||||
|
--max-model-len 32896
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7.3 测试结果
|
||||||
|
|
||||||
|
#### GPU 4 - NIAH 任务
|
||||||
|
|
||||||
|
| 任务 | 通过/总数 | 准确率 | 平均分 |
|
||||||
|
|------|----------|--------|--------|
|
||||||
|
| niah_single_1 | 3/3 | 100.0% | 1.000 |
|
||||||
|
| niah_multikey_1 | 3/3 | 100.0% | 1.000 |
|
||||||
|
| niah_multiquery | 3/3 | 100.0% | 1.000 |
|
||||||
|
| niah_multivalue | 3/3 | 100.0% | 1.000 |
|
||||||
|
| **NIAH 总计** | **12/12** | **100.0%** | **1.000** |
|
||||||
|
|
||||||
|
#### GPU 5 - QA/Recall 任务
|
||||||
|
|
||||||
|
| 任务 | 通过/总数 | 准确率 | 平均分 |
|
||||||
|
|------|----------|--------|--------|
|
||||||
|
| qa_1 | 2/3 | 66.7% | 0.667 |
|
||||||
|
| qa_2 | 1/3 | 33.3% | 0.333 |
|
||||||
|
| vt | 3/3 | 100.0% | 0.867 |
|
||||||
|
| cwe | 2/3 | 66.7% | 0.467 |
|
||||||
|
| fwe | 3/3 | 100.0% | 0.889 |
|
||||||
|
| **QA/Recall 总计** | **11/15** | **73.3%** | **0.644** |
|
||||||
|
|
||||||
|
#### 总体结果
|
||||||
|
|
||||||
|
- **总计**: 23/27 样本通过 (85.2% 准确率)
|
||||||
|
- **耗时**: GPU 4 (74.9s), GPU 5 (425.1s)
|
||||||
|
- **结论**: XAttention 集成成功,test_ruler.py 全部通过 ✅
|
||||||
|
|
||||||
|
### 7.4 内存使用
|
||||||
|
|
||||||
|
```
|
||||||
|
OffloadEngine initialized: GPU=650.0MB, CPU=4224.0MB
|
||||||
|
Ring buffer GPU cache: 522.0 MB (4 buffers × 33408 tokens)
|
||||||
|
CPU cache: 4224.0 MB (32 layers × 33 blocks)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. 使用指南
|
||||||
|
|
||||||
|
### 8.1 基本用法
|
||||||
|
|
||||||
|
```python
|
||||||
|
from nanovllm import LLM, SamplingParams
|
||||||
|
from nanovllm.config import SparsePolicyType
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model_path="/path/to/model",
|
||||||
|
enable_cpu_offload=True,
|
||||||
|
sparse_policy=SparsePolicyType.XATTN,
|
||||||
|
xattn_threshold=0.9,
|
||||||
|
xattn_stride=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0.1, max_tokens=128)
|
||||||
|
outputs = llm.generate(["Your prompt here"], sampling_params)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8.2 命令行测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# RULER benchmark
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--enable-offload \
|
||||||
|
--sparse-policy XATTN \
|
||||||
|
--max-model-len 32896
|
||||||
|
|
||||||
|
# 单个样本测试
|
||||||
|
python tests/test_needle.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--enable-offload \
|
||||||
|
--sparse-policy XATTN
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8.3 配置参数
|
||||||
|
|
||||||
|
| 参数 | 默认值 | 说明 |
|
||||||
|
|------|--------|------|
|
||||||
|
| `sparse_policy` | `FULL` | 稀疏策略类型 (FULL, QUEST, MINFERENCE, XATTN) |
|
||||||
|
| `xattn_threshold` | 0.9 | Block 选择阈值 (0-1) |
|
||||||
|
| `xattn_stride` | 8 | Q/K 重组步长 |
|
||||||
|
| `xattn_chunk_size` | 16384 | Estimation chunk 大小 |
|
||||||
|
| `xattn_use_triton` | True | 是否使用 Triton kernels |
|
||||||
|
|
||||||
|
### 8.4 与其他策略对比
|
||||||
|
|
||||||
|
| 策略 | 阶段 | 用途 | 优势 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| FULL | prefill + decode | 基线 | 准确率最高 |
|
||||||
|
| QUEST | decode only | Top-K block selection | 适合 decode 优化 |
|
||||||
|
| MINFERENCE | prefill | Vertical + Slash pattern | GPU-only 高效 |
|
||||||
|
| XATTN | prefill only | Chunked estimation + block sparse | 长上下文 prefill |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 附录
|
||||||
|
|
||||||
|
### A. 相关文档
|
||||||
|
|
||||||
|
- [`sparse_attention_guide.md`](sparse_attention_guide.md) - 稀疏注意力方法概述
|
||||||
|
- [`sparse_offload_integration.md`](sparse_offload_integration.md) - 稀疏策略与 offload 集成
|
||||||
|
- [`block_sparse_attention_lib.md`](block_sparse_attention_lib.md) - Block-Sparse-Attention 库参考
|
||||||
|
|
||||||
|
### B. Git 历史
|
||||||
|
|
||||||
|
- `ac1ccbc` - feat: add XAttention sparse policy integration
|
||||||
|
- `57f4e9c` - docs: reorganize documentation files
|
||||||
|
|
||||||
|
### C. 待办事项
|
||||||
|
|
||||||
|
- [ ] GPU-only 模式下的完整 XAttention 实现(使用 Triton kernels)
|
||||||
|
- [ ] 性能基准测试(与 FULL、MINFERENCE 对比)
|
||||||
|
- [ ] 自适应 threshold 调整
|
||||||
|
- [ ] 更多上下文长度测试(64k, 128k)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**作者**: Zijie Tian
|
||||||
|
**日期**: 2026-01-14
|
||||||
|
**版本**: 1.0
|
||||||
288
findings.md
288
findings.md
@@ -1,288 +0,0 @@
|
|||||||
# Findings: nanovllm 多请求状态污染分析
|
|
||||||
|
|
||||||
## 重要说明
|
|
||||||
|
|
||||||
**nanovllm offload 模式不支持 batch**,只能单个 request 顺序执行。问题出在**请求切换**(前一个 request 完成后,开始下一个 request)时状态清理不完整。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 1. 代码架构发现
|
|
||||||
|
|
||||||
### 1.1 请求生命周期 (顺序执行)
|
|
||||||
|
|
||||||
**关键**: offload 模式下,每次只处理**一个 request**,不是 batch。
|
|
||||||
|
|
||||||
```
|
|
||||||
LLMEngine.generate() [llm_engine.py:114-151]
|
|
||||||
├── Observer.complete_reset() # 重置性能统计
|
|
||||||
├── for prompt in prompts:
|
|
||||||
│ └── add_request(prompt, sp) # 添加到 scheduler 队列
|
|
||||||
├── while not is_finished():
|
|
||||||
│ ├── scheduler.schedule() # 获取下一个序列 (offload 模式: 1个)
|
|
||||||
│ ├── model_runner.call("run", seqs, is_prefill) # 执行单个请求
|
|
||||||
│ └── scheduler.postprocess(seqs, token_ids)
|
|
||||||
│ └── if seq.is_finished:
|
|
||||||
│ └── kvcache_manager.deallocate(seq) # 释放资源 ← 问题点
|
|
||||||
│ └── [开始处理下一个请求] # ← 状态切换
|
|
||||||
└── return outputs
|
|
||||||
```
|
|
||||||
|
|
||||||
**请求切换流程**:
|
|
||||||
```
|
|
||||||
Request A (prefill) → Request A (decode × N) → Request A 完成
|
|
||||||
↓
|
|
||||||
deallocate(A) ← 状态清理不完整!
|
|
||||||
↓
|
|
||||||
Request B (prefill) → Request B 读取到 A 的残留状态 → 错误输出
|
|
||||||
```
|
|
||||||
|
|
||||||
### 1.2 OffloadEngine 状态清单
|
|
||||||
|
|
||||||
**位置**: `nanovllm/kvcache/offload_engine.py:40-145`
|
|
||||||
|
|
||||||
| 成员变量 | 类型 | Shape | 生命周期 |
|
|
||||||
|----------|------|-------|----------|
|
|
||||||
| `layer_k_cache` | GPU Tensor | [num_buffers, max_seq_len, kv_heads, head_dim] | 整个引擎 |
|
|
||||||
| `layer_v_cache` | GPU Tensor | [num_buffers, max_seq_len, kv_heads, head_dim] | 整个引擎 |
|
|
||||||
| `decode_k_buffer` | GPU Tensor | [num_layers, block_size, kv_heads, head_dim] | 整个引擎 |
|
|
||||||
| `decode_v_buffer` | GPU Tensor | [num_layers, block_size, kv_heads, head_dim] | 整个引擎 |
|
|
||||||
| `k_cache_cpu` | CPU Tensor (pinned) | [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] | 整个引擎 |
|
|
||||||
| `v_cache_cpu` | CPU Tensor (pinned) | [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] | 整个引擎 |
|
|
||||||
| `compute_stream` | CUDA Stream | - | 整个引擎 |
|
|
||||||
| `prefill_offload_streams` | List[CUDA Stream] | num_layers | 整个引擎 |
|
|
||||||
| `prefill_offload_events` | List[CUDA Event] | num_layers | 整个引擎 |
|
|
||||||
| `layer_load_streams` | List[CUDA Stream] | num_buffers | 整个引擎 |
|
|
||||||
| `buffer_load_events` | List[CUDA Event] | num_buffers | 整个引擎 |
|
|
||||||
| `buffer_compute_done_events` | List[CUDA Event] | num_buffers | 整个引擎 |
|
|
||||||
|
|
||||||
**关键发现**:
|
|
||||||
- **没有 reset() 方法**
|
|
||||||
- **没有任何清理逻辑**
|
|
||||||
- 所有 tensor 在初始化时 `torch.zeros()` 后永不清零
|
|
||||||
|
|
||||||
### 1.3 HybridKVCacheManager 状态清单
|
|
||||||
|
|
||||||
**位置**: `nanovllm/kvcache/hybrid_manager.py`
|
|
||||||
|
|
||||||
| 成员变量 | 作用 | 清理方式 |
|
|
||||||
|----------|------|----------|
|
|
||||||
| `logical_blocks` | 逻辑块列表 | `block.reset()` in deallocate |
|
|
||||||
| `free_logical_ids` | 空闲逻辑块队列 | deallocate 归还 |
|
|
||||||
| `free_cpu_blocks` | 空闲 CPU 块队列 | deallocate 归还 |
|
|
||||||
| `cpu_block_to_logical` | CPU 块→逻辑块映射 | deallocate 删除 |
|
|
||||||
| `prefilled_blocks` | 已 prefill 的块集合 | deallocate 中 discard |
|
|
||||||
| `_decode_start_pos` | 序列→decode起始位置 | `clear_decode_tracking()` |
|
|
||||||
| `_prefill_len` | 序列→prefill长度 | `clear_decode_tracking()` |
|
|
||||||
|
|
||||||
**关键发现**:
|
|
||||||
- `deallocate()` 没有调用 `clear_decode_tracking()`!
|
|
||||||
- `_decode_start_pos` 和 `_prefill_len` 使用 `id(seq)` 作为 key
|
|
||||||
- Python 对象 ID 可能在不同请求间重用
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 2. 请求切换机制分析
|
|
||||||
|
|
||||||
### 2.1 offload 模式的单 request 限制
|
|
||||||
|
|
||||||
代码中明确限制:
|
|
||||||
```python
|
|
||||||
# model_runner.py:757, 880
|
|
||||||
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2.2 请求切换时序
|
|
||||||
|
|
||||||
```
|
|
||||||
时间 →
|
|
||||||
┌─────────────────────────────────────────────────────────────────┐
|
|
||||||
│ Request A: [prefill] → [decode] → [decode] → ... → [完成] │
|
|
||||||
└─────────────────────────────────────────────────────────────────┘
|
|
||||||
↓
|
|
||||||
deallocate(seq_A)
|
|
||||||
- blocks 释放 ✓
|
|
||||||
- tracking 字典未清理 ✗
|
|
||||||
↓
|
|
||||||
┌─────────────────────────────────────────────────────────────────┐
|
|
||||||
│ Request B: [prefill] → [decode] → ... │
|
|
||||||
│ ↑ │
|
|
||||||
│ 如果 id(seq_B) == id(seq_A),读到 A 的残留状态! │
|
|
||||||
└─────────────────────────────────────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2.3 Python 对象 ID 重用
|
|
||||||
|
|
||||||
Python 的内存管理会重用已释放对象的内存地址,导致:
|
|
||||||
```python
|
|
||||||
seq_A = Sequence(...) # id(seq_A) = 0x7f1234567890
|
|
||||||
del seq_A # 对象被释放,但字典中 key 保留
|
|
||||||
|
|
||||||
seq_B = Sequence(...) # id(seq_B) 可能 = 0x7f1234567890(相同地址)
|
|
||||||
# _decode_start_pos[id(seq_B)] 返回 seq_A 的旧值!
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 3. 状态污染机制分析
|
|
||||||
|
|
||||||
### 3.1 decode buffer 污染路径
|
|
||||||
|
|
||||||
**污染写入** (`run_layerwise_offload_decode:1010-1013`):
|
|
||||||
```python
|
|
||||||
# 每次 decode step,将当前 token 的 KV 存入 decode buffer
|
|
||||||
offload_engine.decode_k_buffer[layer_id, pos_in_block].copy_(ring_k[context_len])
|
|
||||||
offload_engine.decode_v_buffer[layer_id, pos_in_block].copy_(ring_v[context_len])
|
|
||||||
```
|
|
||||||
|
|
||||||
**污染读取** (`run_layerwise_offload_decode:969-976`):
|
|
||||||
```python
|
|
||||||
# 如果有之前的 decode tokens,从 decode buffer 读取
|
|
||||||
if num_prev_decode_tokens > 0:
|
|
||||||
k_decode_prev, v_decode_prev = offload_engine.get_decode_kv(
|
|
||||||
layer_id, decode_start_pos, pos_in_block
|
|
||||||
)
|
|
||||||
ring_k[total_prefill_tokens:total_prefill_tokens + num_prev_decode_tokens].copy_(k_decode_prev)
|
|
||||||
```
|
|
||||||
|
|
||||||
**问题场景**:
|
|
||||||
1. 请求 A 的 decode 阶段在 `decode_k_buffer[layer, 0:N]` 写入 KV
|
|
||||||
2. 请求 A 完成,buffer 数据保留
|
|
||||||
3. 请求 B 开始,如果其 `decode_start_pos` 被错误计算为非零
|
|
||||||
4. 请求 B 会读取请求 A 的旧数据
|
|
||||||
|
|
||||||
### 3.2 decode_start_pos 计算逻辑
|
|
||||||
|
|
||||||
**位置**: `hybrid_manager.py:485-505`
|
|
||||||
|
|
||||||
```python
|
|
||||||
def get_decode_start_pos(self, seq: Sequence) -> int:
|
|
||||||
seq_id = id(seq) # Python 对象 ID
|
|
||||||
if seq_id not in self._decode_start_pos:
|
|
||||||
# 第一次调用 - 计算起始位置
|
|
||||||
prefill_len = len(seq) - 1 # 当前长度减去新 token
|
|
||||||
self._decode_start_pos[seq_id] = prefill_len % self._block_size
|
|
||||||
return self._decode_start_pos[seq_id]
|
|
||||||
```
|
|
||||||
|
|
||||||
**问题**:
|
|
||||||
- 如果新请求的 `id(seq)` 恰好等于旧请求的 `id(seq)`(Python 内存重用)
|
|
||||||
- `_decode_start_pos` 中可能存在旧的值
|
|
||||||
- 会返回错误的 decode 起始位置
|
|
||||||
|
|
||||||
### 3.3 clear_decode_tracking 未被调用
|
|
||||||
|
|
||||||
**位置**: `hybrid_manager.py:538-549`
|
|
||||||
|
|
||||||
```python
|
|
||||||
def clear_decode_tracking(self, seq: Sequence) -> None:
|
|
||||||
seq_id = id(seq)
|
|
||||||
self._decode_start_pos.pop(seq_id, None)
|
|
||||||
self._prefill_len.pop(seq_id, None)
|
|
||||||
```
|
|
||||||
|
|
||||||
**问题**:
|
|
||||||
- 这个方法在 `deallocate()` 中**没有被调用**!
|
|
||||||
- 查看 `deallocate()` (218-244 行),没有 `clear_decode_tracking()` 调用
|
|
||||||
- 这导致旧请求的 tracking 数据残留
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 3. 失败模式分析
|
|
||||||
|
|
||||||
### 3.1 观察到的失败模式
|
|
||||||
|
|
||||||
从测试结果:
|
|
||||||
| Sample | Expected | Output | Status |
|
|
||||||
|--------|----------|--------|--------|
|
|
||||||
| 0 | 8930103 | `: 8930103.` | PASS (第一个请求) |
|
|
||||||
| 1 | 4194548 | `: 419 multiplication of 4548.` | **FAIL** |
|
|
||||||
| 2 | 8231838 | `:ное 8231838.` | PASS |
|
|
||||||
|
|
||||||
Sample 1 的输出 "419 multiplication of 4548" 显示数字被"拆分"了。
|
|
||||||
|
|
||||||
**可能原因**:
|
|
||||||
1. 在某个 decode step,attention 计算使用了错误的 KV
|
|
||||||
2. 模型"看到"了旧请求的部分 context
|
|
||||||
3. 导致生成逻辑出错
|
|
||||||
|
|
||||||
### 3.2 为什么第一个请求总是成功?
|
|
||||||
|
|
||||||
1. 第一个请求时,所有 buffer 都是零初始化
|
|
||||||
2. `decode_start_pos` 字典为空,正确计算
|
|
||||||
3. 没有残留数据干扰
|
|
||||||
|
|
||||||
### 3.3 为什么后续请求可能成功?
|
|
||||||
|
|
||||||
某些请求可能成功因为:
|
|
||||||
1. `id(seq)` 没有与之前的请求冲突
|
|
||||||
2. `pos_in_block` 不重叠,没读到旧数据
|
|
||||||
3. 或者旧数据恰好对结果影响不大
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 4. 修复方向
|
|
||||||
|
|
||||||
### 4.1 必须修复: deallocate 时清理状态
|
|
||||||
|
|
||||||
```python
|
|
||||||
# hybrid_manager.py: deallocate()
|
|
||||||
def deallocate(self, seq: Sequence) -> None:
|
|
||||||
# ... 现有逻辑 ...
|
|
||||||
|
|
||||||
# 添加: 清理 decode tracking
|
|
||||||
self.clear_decode_tracking(seq)
|
|
||||||
|
|
||||||
# 添加: 通知 offload engine 清理
|
|
||||||
if self.offload_engine is not None:
|
|
||||||
self.offload_engine.on_sequence_finished()
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4.2 必须修复: OffloadEngine 添加清理方法
|
|
||||||
|
|
||||||
```python
|
|
||||||
# offload_engine.py
|
|
||||||
def on_sequence_finished(self):
|
|
||||||
"""请求完成时的清理"""
|
|
||||||
# 清零 decode buffer
|
|
||||||
self.decode_k_buffer.zero_()
|
|
||||||
self.decode_v_buffer.zero_()
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4.3 可选: 更激进的清理
|
|
||||||
|
|
||||||
```python
|
|
||||||
def reset_all(self):
|
|
||||||
"""完全重置状态"""
|
|
||||||
self.decode_k_buffer.zero_()
|
|
||||||
self.decode_v_buffer.zero_()
|
|
||||||
self.layer_k_cache.zero_()
|
|
||||||
self.layer_v_cache.zero_()
|
|
||||||
# 重置 CUDA events
|
|
||||||
for event in self.buffer_compute_done_events:
|
|
||||||
event.record()
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 5. 待验证假设
|
|
||||||
|
|
||||||
| 假设 | 验证方法 | 优先级 |
|
|
||||||
|------|----------|--------|
|
|
||||||
| decode_buffer 残留导致污染 | 在第二个请求开始时检查 buffer 是否为零 | 高 |
|
|
||||||
| _decode_start_pos 字典残留 | 打印 deallocate 前后的字典内容 | 高 |
|
|
||||||
| id(seq) 重用导致错误 | 打印每个请求的 seq id | 中 |
|
|
||||||
| ring buffer 残留 | 检查每次 decode 前 ring buffer 内容 | 低 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 6. 参考代码位置
|
|
||||||
|
|
||||||
| 功能 | 文件 | 行号 |
|
|
||||||
|------|------|------|
|
|
||||||
| OffloadEngine 初始化 | offload_engine.py | 40-145 |
|
|
||||||
| deallocate | hybrid_manager.py | 218-244 |
|
|
||||||
| clear_decode_tracking | hybrid_manager.py | 538-549 |
|
|
||||||
| get_decode_start_pos | hybrid_manager.py | 485-505 |
|
|
||||||
| run_layerwise_offload_decode | model_runner.py | 867-1057 |
|
|
||||||
| decode buffer 写入 | model_runner.py | 1010-1013 |
|
|
||||||
| decode buffer 读取 | model_runner.py | 969-976 |
|
|
||||||
@@ -10,6 +10,7 @@ class SparsePolicyType(Enum):
|
|||||||
FULL = auto() # No sparse attention (load all blocks)
|
FULL = auto() # No sparse attention (load all blocks)
|
||||||
QUEST = auto() # Query-aware Top-K block selection (decode only)
|
QUEST = auto() # Query-aware Top-K block selection (decode only)
|
||||||
MINFERENCE = auto() # MInference vertical + slash sparse prefill (GPU-only)
|
MINFERENCE = auto() # MInference vertical + slash sparse prefill (GPU-only)
|
||||||
|
XATTN = auto() # XAttention chunked estimation + block-sparse attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -53,6 +54,16 @@ class Config:
|
|||||||
minference_num_sink_tokens: int = 30 # Sink tokens to always keep
|
minference_num_sink_tokens: int = 30 # Sink tokens to always keep
|
||||||
minference_num_recent_diags: int = 100 # Recent diagonals to always keep
|
minference_num_recent_diags: int = 100 # Recent diagonals to always keep
|
||||||
|
|
||||||
|
# XAttention configuration (used when sparse_policy == XATTN)
|
||||||
|
xattn_stride: int = 8 # Stride for reorganizing Q/K
|
||||||
|
xattn_threshold: float = 0.9 # Block selection threshold (0-1)
|
||||||
|
xattn_chunk_size: int = 16384 # Chunk size for estimation (auto if None)
|
||||||
|
xattn_use_triton: bool = True # Use Triton kernels (requires SM 80+)
|
||||||
|
xattn_keep_sink: bool = False # Always keep first block (sink tokens)
|
||||||
|
xattn_keep_recent: bool = False # Always keep recent diagonal blocks
|
||||||
|
xattn_norm: float = 1.0 # Normalization factor for attention scores
|
||||||
|
xattn_use_bsa: bool = True # Use Block Sparse Attention library (requires installation)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert os.path.isdir(self.model)
|
assert os.path.isdir(self.model)
|
||||||
assert self.kvcache_block_size % 256 == 0
|
assert self.kvcache_block_size % 256 == 0
|
||||||
|
|||||||
@@ -57,8 +57,8 @@ class ModelRunner:
|
|||||||
load_model(self.model, config.model)
|
load_model(self.model, config.model)
|
||||||
self.sampler = GreedySampler()
|
self.sampler = GreedySampler()
|
||||||
|
|
||||||
# Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache)
|
# Initialize attention_policy before warmup (will be configured in allocate_kv_cache)
|
||||||
self.sparse_prefill_policy = None
|
self.attention_policy = None
|
||||||
|
|
||||||
#> Disable warmup for debugging
|
#> Disable warmup for debugging
|
||||||
self.warmup_model()
|
self.warmup_model()
|
||||||
@@ -178,23 +178,35 @@ class ModelRunner:
|
|||||||
# Create KV cache manager using factory
|
# Create KV cache manager using factory
|
||||||
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
|
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
|
||||||
|
|
||||||
# Create sparse prefill policy for GPU-only path
|
# Create attention policy (always, including FULL)
|
||||||
# This is separate from CPU offload sparse policy (which uses select_blocks)
|
# In layerwise offload mode, all attention goes through the policy
|
||||||
self.sparse_prefill_policy = None
|
from nanovllm.kvcache.sparse import create_attention_policy
|
||||||
if not config.enable_cpu_offload and config.sparse_policy != SparsePolicyType.FULL:
|
|
||||||
from nanovllm.kvcache.sparse import create_sparse_policy
|
# Get policy-specific parameters based on type
|
||||||
policy = create_sparse_policy(
|
if config.sparse_policy == SparsePolicyType.XATTN:
|
||||||
config.sparse_policy,
|
policy_kwargs = {
|
||||||
vertical_size=config.minference_vertical_size,
|
"stride": config.xattn_stride,
|
||||||
slash_size=config.minference_slash_size,
|
"threshold": config.xattn_threshold,
|
||||||
adaptive_budget=config.minference_adaptive_budget,
|
"chunk_size": config.xattn_chunk_size,
|
||||||
num_sink_tokens=config.minference_num_sink_tokens,
|
"use_triton": config.xattn_use_triton,
|
||||||
num_recent_diags=config.minference_num_recent_diags,
|
"keep_sink": config.xattn_keep_sink,
|
||||||
)
|
"keep_recent": config.xattn_keep_recent,
|
||||||
# Only use if policy supports sparse prefill
|
"norm": config.xattn_norm,
|
||||||
if policy.supports_prefill:
|
"use_bsa": config.xattn_use_bsa,
|
||||||
self.sparse_prefill_policy = policy
|
}
|
||||||
logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}")
|
elif config.sparse_policy == SparsePolicyType.MINFERENCE:
|
||||||
|
policy_kwargs = {
|
||||||
|
"vertical_size": config.minference_vertical_size,
|
||||||
|
"slash_size": config.minference_slash_size,
|
||||||
|
"adaptive_budget": config.minference_adaptive_budget,
|
||||||
|
"num_sink_tokens": config.minference_num_sink_tokens,
|
||||||
|
"num_recent_diags": config.minference_num_recent_diags,
|
||||||
|
}
|
||||||
|
else: # FULL or QUEST
|
||||||
|
policy_kwargs = {}
|
||||||
|
|
||||||
|
self.attention_policy = create_attention_policy(config.sparse_policy, **policy_kwargs)
|
||||||
|
logger.info(f"Attention policy: {self.attention_policy}")
|
||||||
|
|
||||||
# Allocate cache through manager
|
# Allocate cache through manager
|
||||||
self.kvcache_manager.allocate_cache(
|
self.kvcache_manager.allocate_cache(
|
||||||
@@ -380,7 +392,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||||
slot_mapping, None, block_tables,
|
slot_mapping, None, block_tables,
|
||||||
sparse_prefill_policy=self.sparse_prefill_policy)
|
attention_policy=self.attention_policy)
|
||||||
return input_ids, positions
|
return input_ids, positions
|
||||||
|
|
||||||
def prepare_decode(self, seqs: list[Sequence]):
|
def prepare_decode(self, seqs: list[Sequence]):
|
||||||
@@ -577,21 +589,11 @@ class ModelRunner:
|
|||||||
# RoPE
|
# RoPE
|
||||||
q, k = layer.self_attn.rotary_emb(positions, q, k)
|
q, k = layer.self_attn.rotary_emb(positions, q, k)
|
||||||
|
|
||||||
# Sparse or Full attention (uses k, v directly - before store!)
|
# Compute attention using policy (uses k, v directly - before store!)
|
||||||
if self.sparse_prefill_policy is not None:
|
attn_output = self.attention_policy.compute_prefill(
|
||||||
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
|
q, k, v, layer_id,
|
||||||
q, k, v, layer_id
|
softmax_scale=layer.self_attn.attn.scale,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
attn_output = flash_attn_varlen_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q=cu_seqlens,
|
|
||||||
cu_seqlens_k=cu_seqlens,
|
|
||||||
max_seqlen_q=total_tokens,
|
|
||||||
max_seqlen_k=total_tokens,
|
|
||||||
softmax_scale=layer.self_attn.attn.scale,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# O projection
|
# O projection
|
||||||
attn_output = attn_output.view(total_tokens, -1)
|
attn_output = attn_output.view(total_tokens, -1)
|
||||||
@@ -786,15 +788,56 @@ class ModelRunner:
|
|||||||
for layer_id in range(num_layers):
|
for layer_id in range(num_layers):
|
||||||
layer = self.model.model.layers[layer_id]
|
layer = self.model.model.layers[layer_id]
|
||||||
|
|
||||||
# 2a. Input LayerNorm
|
# 2a. Input LayerNorm (chunked for long sequences)
|
||||||
if residual is None:
|
# LayerNorm creates float32 temporaries: seq_len * hidden_size * 4 bytes
|
||||||
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
|
# For 64k: 65536 * 4096 * 4 = ~1 GB per operation
|
||||||
|
# Using chunk_size=4096 reduces peak to ~125 MB
|
||||||
|
layernorm_chunk_size = 128
|
||||||
|
if total_tokens > layernorm_chunk_size:
|
||||||
|
if residual is None:
|
||||||
|
# Chunked input_layernorm
|
||||||
|
hs_chunks = hidden_states.split(layernorm_chunk_size, dim=0)
|
||||||
|
ln_chunks = []
|
||||||
|
res_chunks = []
|
||||||
|
for chunk in hs_chunks:
|
||||||
|
ln, res = layer.input_layernorm(chunk), chunk
|
||||||
|
ln_chunks.append(ln)
|
||||||
|
res_chunks.append(res)
|
||||||
|
hidden_ln = torch.cat(ln_chunks, dim=0)
|
||||||
|
residual = torch.cat(res_chunks, dim=0)
|
||||||
|
else:
|
||||||
|
# Chunked input_layernorm with residual
|
||||||
|
hs_chunks = hidden_states.split(layernorm_chunk_size, dim=0)
|
||||||
|
res_chunks_in = residual.split(layernorm_chunk_size, dim=0)
|
||||||
|
ln_chunks = []
|
||||||
|
res_chunks_out = []
|
||||||
|
for hs_chunk, res_chunk in zip(hs_chunks, res_chunks_in):
|
||||||
|
ln, res = layer.input_layernorm(hs_chunk, res_chunk)
|
||||||
|
ln_chunks.append(ln)
|
||||||
|
res_chunks_out.append(res)
|
||||||
|
hidden_ln = torch.cat(ln_chunks, dim=0)
|
||||||
|
residual = torch.cat(res_chunks_out, dim=0)
|
||||||
else:
|
else:
|
||||||
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
|
if residual is None:
|
||||||
|
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
|
||||||
|
else:
|
||||||
|
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
# 2b. Self-attention (full sequence)
|
# 2b. Self-attention (full sequence)
|
||||||
# QKV projection
|
# Chunked QKV projection to reduce activation memory for long sequences
|
||||||
qkv = layer.self_attn.qkv_proj(hidden_ln)
|
# QKV activation = seq_len * (q_size + 2*kv_size) * 2 bytes
|
||||||
|
# For 64k: 65536 * (4096 + 2*1024) * 2 = ~805 MB
|
||||||
|
# Using chunk_size=2048 reduces peak to ~25 MB
|
||||||
|
qkv_chunk_size = 128
|
||||||
|
if total_tokens > qkv_chunk_size:
|
||||||
|
chunks = hidden_ln.split(qkv_chunk_size, dim=0)
|
||||||
|
qkv_chunks = []
|
||||||
|
for chunk in chunks:
|
||||||
|
qkv_chunks.append(layer.self_attn.qkv_proj(chunk))
|
||||||
|
qkv = torch.cat(qkv_chunks, dim=0)
|
||||||
|
else:
|
||||||
|
qkv = layer.self_attn.qkv_proj(hidden_ln)
|
||||||
|
|
||||||
q, k, v = qkv.split([
|
q, k, v = qkv.split([
|
||||||
layer.self_attn.q_size,
|
layer.self_attn.q_size,
|
||||||
layer.self_attn.kv_size,
|
layer.self_attn.kv_size,
|
||||||
@@ -816,31 +859,50 @@ class ModelRunner:
|
|||||||
# RoPE
|
# RoPE
|
||||||
q, k = layer.self_attn.rotary_emb(positions, q, k)
|
q, k = layer.self_attn.rotary_emb(positions, q, k)
|
||||||
|
|
||||||
# Sparse or Full attention
|
# Compute attention using policy
|
||||||
if self.sparse_prefill_policy is not None:
|
attn_output = self.attention_policy.compute_prefill(
|
||||||
# MInference or other sparse prefill policy
|
q, k, v, layer_id,
|
||||||
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
|
softmax_scale=layer.self_attn.attn.scale,
|
||||||
q, k, v, layer_id
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Full attention using FlashAttention
|
|
||||||
attn_output = flash_attn_varlen_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q=cu_seqlens,
|
|
||||||
cu_seqlens_k=cu_seqlens,
|
|
||||||
max_seqlen_q=total_tokens,
|
|
||||||
max_seqlen_k=total_tokens,
|
|
||||||
softmax_scale=layer.self_attn.attn.scale,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# O projection
|
# O projection
|
||||||
attn_output = attn_output.view(total_tokens, -1)
|
attn_output = attn_output.view(total_tokens, -1)
|
||||||
hidden_states = layer.self_attn.o_proj(attn_output)
|
hidden_states = layer.self_attn.o_proj(attn_output)
|
||||||
|
|
||||||
# 2c. Post-attention LayerNorm + MLP
|
# 2c. Post-attention LayerNorm (chunked for long sequences)
|
||||||
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
|
layernorm_chunk_size = 128
|
||||||
hidden_states = layer.mlp(hidden_states)
|
if total_tokens > layernorm_chunk_size:
|
||||||
|
# Chunked post_attention_layernorm
|
||||||
|
hs_chunks = hidden_states.split(layernorm_chunk_size, dim=0)
|
||||||
|
res_chunks_in = residual.split(layernorm_chunk_size, dim=0)
|
||||||
|
ln_chunks = []
|
||||||
|
res_chunks_out = []
|
||||||
|
for hs_chunk, res_chunk in zip(hs_chunks, res_chunks_in):
|
||||||
|
ln, res = layer.post_attention_layernorm(hs_chunk, res_chunk)
|
||||||
|
ln_chunks.append(ln)
|
||||||
|
res_chunks_out.append(res)
|
||||||
|
hidden_states = torch.cat(ln_chunks, dim=0)
|
||||||
|
residual = torch.cat(res_chunks_out, dim=0)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
# Chunked MLP processing to reduce activation memory for long sequences
|
||||||
|
# MLP activation = seq_len * intermediate_size * 2 bytes
|
||||||
|
# For 64k: 65536 * 14336 * 2 = ~1.75 GB (down_proj input)
|
||||||
|
# Using chunk_size=2048 reduces peak to ~55 MB
|
||||||
|
mlp_chunk_size = 128
|
||||||
|
if total_tokens > mlp_chunk_size:
|
||||||
|
chunks = hidden_states.split(mlp_chunk_size, dim=0)
|
||||||
|
outputs = []
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
outputs.append(layer.mlp(chunk))
|
||||||
|
del chunk
|
||||||
|
torch.cuda.empty_cache() # Clean after every chunk
|
||||||
|
hidden_states = torch.cat(outputs, dim=0)
|
||||||
|
del outputs
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
else:
|
||||||
|
hidden_states = layer.mlp(hidden_states)
|
||||||
|
|
||||||
# 2d. Offload KV to CPU (encapsulated with sparse policy hooks)
|
# 2d. Offload KV to CPU (encapsulated with sparse policy hooks)
|
||||||
offload_engine.offload_layer_kv_sync(layer_id, k, v, cpu_block_ids, total_tokens)
|
offload_engine.offload_layer_kv_sync(layer_id, k, v, cpu_block_ids, total_tokens)
|
||||||
|
|||||||
@@ -1,48 +1,56 @@
|
|||||||
"""
|
"""
|
||||||
Sparse Attention Policy module.
|
Attention Policy module for layerwise offload mode.
|
||||||
|
|
||||||
Provides pluggable policies for selecting which KV blocks to load
|
Provides pluggable policies for attention computation:
|
||||||
during chunked attention with CPU offload.
|
- FullAttentionPolicy: Standard FlashAttention (no sparsity)
|
||||||
|
- XAttentionPolicy: Sparse prefill using XAttention algorithm
|
||||||
|
- MInferencePolicy: MInference sparse attention
|
||||||
|
- QuestPolicy: Quest block selection (for chunked offload)
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
|
from nanovllm.kvcache.sparse import create_attention_policy, SparsePolicyType
|
||||||
|
|
||||||
# Create policy using factory function
|
# Create policy using factory function
|
||||||
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
|
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
|
||||||
|
|
||||||
|
# Use policy for attention
|
||||||
|
attn_output = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
|
||||||
|
|
||||||
# Or create custom policy
|
# Or create custom policy
|
||||||
class MyPolicy(SparsePolicy):
|
class MyPolicy(AttentionPolicy):
|
||||||
supports_prefill = True
|
supports_prefill = True
|
||||||
supports_decode = True
|
supports_decode = True
|
||||||
|
|
||||||
def select_blocks(self, available_blocks, ctx):
|
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
|
||||||
return available_blocks[:5] # Just first 5 blocks
|
# Custom attention computation
|
||||||
|
...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from nanovllm.config import SparsePolicyType
|
from nanovllm.config import SparsePolicyType
|
||||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy, PolicyContext
|
||||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||||
from nanovllm.kvcache.sparse.minference import MInferencePolicy
|
from nanovllm.kvcache.sparse.minference import MInferencePolicy
|
||||||
|
from nanovllm.kvcache.sparse.xattn import XAttentionPolicy
|
||||||
|
|
||||||
|
|
||||||
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> AttentionPolicy:
|
||||||
"""
|
"""
|
||||||
Create a sparse policy instance from an enum type.
|
Create an attention policy instance from an enum type.
|
||||||
|
|
||||||
The returned policy is not yet initialized. Call policy.initialize()
|
All attention (including full attention) goes through a policy in layerwise
|
||||||
or let the framework call it during KV cache allocation.
|
offload mode. The policy is responsible for computing prefill/decode attention.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
policy_type: SparsePolicyType enum value
|
policy_type: SparsePolicyType enum value (FULL, XATTN, MINFERENCE, QUEST)
|
||||||
**kwargs: Policy-specific configuration options
|
**kwargs: Policy-specific configuration options
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SparsePolicy instance (not initialized)
|
AttentionPolicy instance
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
|
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
|
||||||
policy.initialize(num_layers=28, num_kv_heads=8, ...)
|
attn_out = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
|
||||||
"""
|
"""
|
||||||
if policy_type == SparsePolicyType.FULL:
|
if policy_type == SparsePolicyType.FULL:
|
||||||
return FullAttentionPolicy()
|
return FullAttentionPolicy()
|
||||||
@@ -65,18 +73,41 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
|
|||||||
num_recent_diags=kwargs.get("num_recent_diags", 100),
|
num_recent_diags=kwargs.get("num_recent_diags", 100),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif policy_type == SparsePolicyType.XATTN:
|
||||||
|
return XAttentionPolicy(
|
||||||
|
stride=kwargs.get("stride", 8),
|
||||||
|
threshold=kwargs.get("threshold", 0.9),
|
||||||
|
chunk_size=kwargs.get("chunk_size", 16384),
|
||||||
|
use_triton=kwargs.get("use_triton", True),
|
||||||
|
keep_sink=kwargs.get("keep_sink", False),
|
||||||
|
keep_recent=kwargs.get("keep_recent", False),
|
||||||
|
norm=kwargs.get("norm", 1.0),
|
||||||
|
use_bsa=kwargs.get("use_bsa", True),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown policy type: {policy_type}")
|
raise ValueError(f"Unknown policy type: {policy_type}")
|
||||||
|
|
||||||
|
|
||||||
|
# Backward compatibility alias
|
||||||
|
create_sparse_policy = create_attention_policy
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# New interface
|
||||||
|
"AttentionPolicy",
|
||||||
|
"create_attention_policy",
|
||||||
|
# Backward compatibility
|
||||||
"SparsePolicy",
|
"SparsePolicy",
|
||||||
|
"create_sparse_policy",
|
||||||
|
# Common types
|
||||||
"PolicyContext",
|
"PolicyContext",
|
||||||
"SparsePolicyType",
|
"SparsePolicyType",
|
||||||
|
# Policy implementations
|
||||||
"FullAttentionPolicy",
|
"FullAttentionPolicy",
|
||||||
"QuestPolicy",
|
"QuestPolicy",
|
||||||
"QuestConfig",
|
"QuestConfig",
|
||||||
"BlockMetadataManager",
|
"BlockMetadataManager",
|
||||||
"MInferencePolicy",
|
"MInferencePolicy",
|
||||||
"create_sparse_policy",
|
"XAttentionPolicy",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,20 +1,21 @@
|
|||||||
"""
|
"""
|
||||||
Full attention policy - loads all blocks (no sparsity).
|
Full attention policy - standard FlashAttention without sparsity.
|
||||||
|
|
||||||
This serves as a baseline and default policy when sparse
|
This serves as a baseline and default policy when sparse
|
||||||
attention is not needed.
|
attention is not needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import Optional
|
||||||
from .policy import SparsePolicy, PolicyContext
|
import torch
|
||||||
|
from .policy import AttentionPolicy
|
||||||
|
|
||||||
|
|
||||||
class FullAttentionPolicy(SparsePolicy):
|
class FullAttentionPolicy(AttentionPolicy):
|
||||||
"""
|
"""
|
||||||
Full attention policy that loads all available blocks.
|
Full attention policy using FlashAttention (no sparsity).
|
||||||
|
|
||||||
This is the default behavior with no sparsity - all previous
|
This is the default behavior with standard causal attention.
|
||||||
KV cache blocks are loaded for each query chunk.
|
All tokens attend to all previous tokens.
|
||||||
|
|
||||||
Use this as:
|
Use this as:
|
||||||
- A baseline for comparing sparse policies
|
- A baseline for comparing sparse policies
|
||||||
@@ -25,15 +26,55 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
# Full attention supports both prefill and decode
|
# Full attention supports both prefill and decode
|
||||||
supports_prefill = True
|
supports_prefill = True
|
||||||
supports_decode = True
|
supports_decode = True
|
||||||
requires_block_selection = False # Load all blocks, no selective loading
|
|
||||||
|
|
||||||
def select_blocks(
|
def estimate(
|
||||||
self,
|
self,
|
||||||
available_blocks: List[int],
|
q: torch.Tensor,
|
||||||
ctx: PolicyContext,
|
k: torch.Tensor,
|
||||||
) -> List[int]:
|
layer_id: int,
|
||||||
"""Return all blocks - no sparsity."""
|
) -> Optional[torch.Tensor]:
|
||||||
return available_blocks
|
"""
|
||||||
|
Full attention - no sparse mask needed.
|
||||||
|
|
||||||
|
Returns None to indicate full attention should be used.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute full causal attention using FlashAttention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
seq_len = q.shape[0]
|
||||||
|
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
max_seqlen_q=seq_len,
|
||||||
|
max_seqlen_k=seq_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "FullAttentionPolicy()"
|
return "FullAttentionPolicy()"
|
||||||
|
|||||||
320
nanovllm/kvcache/sparse/kernels.py
Normal file
320
nanovllm/kvcache/sparse/kernels.py
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
"""
|
||||||
|
Triton kernels for XAttention sparse attention.
|
||||||
|
|
||||||
|
Copied and adapted from COMPASS/compass/src/kernels.py
|
||||||
|
for XAttention integration in nano-vllm.
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- Triton >= 2.1.0
|
||||||
|
- CUDA compute capability SM 80+ (RTX 3090, A100, H100, etc.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def softmax_fuse_block_sum_kernel_causal(
|
||||||
|
In,
|
||||||
|
Out,
|
||||||
|
scale,
|
||||||
|
input_stride_0,
|
||||||
|
input_stride_1,
|
||||||
|
input_stride_2,
|
||||||
|
output_stride_0,
|
||||||
|
output_stride_1,
|
||||||
|
output_stride_2,
|
||||||
|
real_q_len,
|
||||||
|
k_len,
|
||||||
|
chunk_start,
|
||||||
|
chunk_end,
|
||||||
|
segment_size: tl.constexpr,
|
||||||
|
block_size: tl.constexpr,
|
||||||
|
):
|
||||||
|
block_id = tl.program_id(0)
|
||||||
|
head_id = tl.program_id(1)
|
||||||
|
batch_id = tl.program_id(2)
|
||||||
|
|
||||||
|
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
|
||||||
|
offs_k = tl.arange(0, segment_size)
|
||||||
|
|
||||||
|
num_iters = k_len // segment_size
|
||||||
|
num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1) // segment_size
|
||||||
|
|
||||||
|
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
|
||||||
|
l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0
|
||||||
|
|
||||||
|
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
|
||||||
|
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
|
||||||
|
|
||||||
|
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
|
||||||
|
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
|
||||||
|
|
||||||
|
for iter in range(0, num_iters_before_causal):
|
||||||
|
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||||
|
m_local = tl.max(X, 1)
|
||||||
|
m_new = tl.maximum(m_i, m_local)
|
||||||
|
alpha = tl.math.exp2(m_i - m_new)
|
||||||
|
|
||||||
|
X = X - m_new[:, None]
|
||||||
|
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||||
|
l_i = l_i * alpha + l_local
|
||||||
|
|
||||||
|
m_i = m_new
|
||||||
|
|
||||||
|
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
|
||||||
|
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||||
|
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size)
|
||||||
|
X = tl.where(mask, X, -1.0e6)
|
||||||
|
m_local = tl.max(X, 1)
|
||||||
|
m_new = tl.maximum(m_i, m_local)
|
||||||
|
alpha = tl.math.exp2(m_i - m_new)
|
||||||
|
|
||||||
|
X = X - m_new[:, None]
|
||||||
|
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||||
|
l_i = l_i * alpha + l_local
|
||||||
|
|
||||||
|
m_i = m_new
|
||||||
|
|
||||||
|
l_i_inv = 1.0 / l_i
|
||||||
|
|
||||||
|
sum_mask = offs_q[:, None] < real_q_len
|
||||||
|
|
||||||
|
for iter in range(0, num_iters_before_causal):
|
||||||
|
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||||
|
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
||||||
|
X = tl.where(sum_mask, X, 0)
|
||||||
|
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||||
|
X = tl.sum(X, 2)
|
||||||
|
X = tl.sum(X, 0)
|
||||||
|
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||||
|
|
||||||
|
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
|
||||||
|
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||||
|
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size)
|
||||||
|
X = tl.where(mask, X, -1.0e6)
|
||||||
|
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
||||||
|
X = tl.where(sum_mask, X, 0)
|
||||||
|
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||||
|
X = tl.sum(X, 2)
|
||||||
|
X = tl.sum(X, 0)
|
||||||
|
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||||
|
|
||||||
|
for iter in range(num_iters_before_causal + 1, num_iters):
|
||||||
|
X = tl.zeros([segment_size // block_size], dtype=tl.float32)
|
||||||
|
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def softmax_fuse_block_sum_kernel_non_causal(
|
||||||
|
In,
|
||||||
|
Out,
|
||||||
|
scale,
|
||||||
|
input_stride_0,
|
||||||
|
input_stride_1,
|
||||||
|
input_stride_2,
|
||||||
|
output_stride_0,
|
||||||
|
output_stride_1,
|
||||||
|
output_stride_2,
|
||||||
|
real_q_len,
|
||||||
|
k_len,
|
||||||
|
chunk_start,
|
||||||
|
chunk_end,
|
||||||
|
segment_size: tl.constexpr,
|
||||||
|
block_size: tl.constexpr,
|
||||||
|
):
|
||||||
|
block_id = tl.program_id(0)
|
||||||
|
head_id = tl.program_id(1)
|
||||||
|
batch_id = tl.program_id(2)
|
||||||
|
|
||||||
|
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
|
||||||
|
offs_k = tl.arange(0, segment_size)
|
||||||
|
|
||||||
|
num_iters = k_len // segment_size
|
||||||
|
|
||||||
|
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
|
||||||
|
l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0
|
||||||
|
|
||||||
|
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
|
||||||
|
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
|
||||||
|
|
||||||
|
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
|
||||||
|
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
|
||||||
|
|
||||||
|
for iter in range(0, num_iters):
|
||||||
|
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||||
|
m_local = tl.max(X, 1)
|
||||||
|
m_new = tl.maximum(m_i, m_local)
|
||||||
|
alpha = tl.math.exp2(m_i - m_new)
|
||||||
|
|
||||||
|
X = X - m_new[:, None]
|
||||||
|
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||||
|
l_i = l_i * alpha + l_local
|
||||||
|
|
||||||
|
m_i = m_new
|
||||||
|
|
||||||
|
l_i_inv = 1.0 / l_i
|
||||||
|
|
||||||
|
sum_mask = offs_q[:, None] < real_q_len
|
||||||
|
|
||||||
|
for iter in range(0, num_iters):
|
||||||
|
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||||
|
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
||||||
|
X = tl.where(sum_mask, X, 0)
|
||||||
|
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||||
|
X = tl.sum(X, 2)
|
||||||
|
X = tl.sum(X, 0)
|
||||||
|
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def flat_group_gemm_fuse_reshape_kernel(Q, K, Out,
|
||||||
|
stride_qz, stride_qh, stride_qn,
|
||||||
|
stride_kz, stride_kh, stride_kn,
|
||||||
|
stride_oz, stride_oh, stride_on,
|
||||||
|
chunk_start, chunk_end,
|
||||||
|
H: tl.constexpr,
|
||||||
|
STRIDE: tl.constexpr,
|
||||||
|
HEAD_DIM: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
is_causal: tl.constexpr,
|
||||||
|
):
|
||||||
|
block_m = tl.program_id(0).to(tl.int64)
|
||||||
|
block_n = tl.program_id(1).to(tl.int64)
|
||||||
|
batch_id = tl.program_id(2).to(tl.int64) // H
|
||||||
|
head_id = tl.program_id(2).to(tl.int64) % H
|
||||||
|
|
||||||
|
if is_causal:
|
||||||
|
if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N:
|
||||||
|
return
|
||||||
|
|
||||||
|
Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn
|
||||||
|
K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn
|
||||||
|
|
||||||
|
Q_ptrs = Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + tl.arange(0, HEAD_DIM)[None, :] + stride_qn * (STRIDE - 1)
|
||||||
|
K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + tl.arange(0, HEAD_DIM)[:, None]
|
||||||
|
|
||||||
|
o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||||
|
|
||||||
|
for iter in range(STRIDE):
|
||||||
|
q = tl.load(Q_ptrs - iter * stride_qn)
|
||||||
|
k = tl.load(K_ptrs + iter * stride_kn)
|
||||||
|
o += tl.dot(q, k)
|
||||||
|
|
||||||
|
O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N
|
||||||
|
O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :]
|
||||||
|
|
||||||
|
tl.store(O_ptrs, o.to(Out.type.element_ty))
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size, segment_size, chunk_start, chunk_end, real_q_len, scale, is_causal=True):
|
||||||
|
"""Wrapper for Triton softmax-fuse-block-sum kernel."""
|
||||||
|
batch_size, num_heads, q_len, k_len = attn_weights_slice.shape
|
||||||
|
assert q_len % reshaped_block_size == 0
|
||||||
|
assert k_len % segment_size == 0
|
||||||
|
assert segment_size % reshaped_block_size == 0
|
||||||
|
assert attn_weights_slice.stride(-1) == 1
|
||||||
|
|
||||||
|
output = torch.empty(
|
||||||
|
(batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size),
|
||||||
|
dtype=attn_weights_slice.dtype,
|
||||||
|
device=attn_weights_slice.device
|
||||||
|
)
|
||||||
|
|
||||||
|
grid = (q_len // reshaped_block_size, num_heads, batch_size)
|
||||||
|
|
||||||
|
if is_causal:
|
||||||
|
softmax_fuse_block_sum_kernel_causal[grid](
|
||||||
|
attn_weights_slice,
|
||||||
|
output,
|
||||||
|
scale,
|
||||||
|
attn_weights_slice.stride(0),
|
||||||
|
attn_weights_slice.stride(1),
|
||||||
|
attn_weights_slice.stride(2),
|
||||||
|
output.stride(0),
|
||||||
|
output.stride(1),
|
||||||
|
output.stride(2),
|
||||||
|
real_q_len,
|
||||||
|
k_len,
|
||||||
|
chunk_start,
|
||||||
|
chunk_end,
|
||||||
|
segment_size,
|
||||||
|
reshaped_block_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
softmax_fuse_block_sum_kernel_non_causal[grid](
|
||||||
|
attn_weights_slice,
|
||||||
|
output,
|
||||||
|
scale,
|
||||||
|
attn_weights_slice.stride(0),
|
||||||
|
attn_weights_slice.stride(1),
|
||||||
|
attn_weights_slice.stride(2),
|
||||||
|
output.stride(0),
|
||||||
|
output.stride(1),
|
||||||
|
output.stride(2),
|
||||||
|
real_q_len,
|
||||||
|
k_len,
|
||||||
|
chunk_start,
|
||||||
|
chunk_end,
|
||||||
|
segment_size,
|
||||||
|
reshaped_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def flat_group_gemm_fuse_reshape(query_states, key_states, stride, chunk_start, chunk_end, is_causal=True):
|
||||||
|
"""Wrapper for Triton flat-group-gemm-fuse-reshape kernel."""
|
||||||
|
batch_size, num_heads, q_len, head_dim = query_states.shape
|
||||||
|
kv_len = key_states.shape[2]
|
||||||
|
|
||||||
|
assert key_states.shape[0] == batch_size
|
||||||
|
assert key_states.shape[1] == num_heads
|
||||||
|
assert key_states.shape[3] == head_dim
|
||||||
|
|
||||||
|
output = torch.empty(
|
||||||
|
(batch_size, num_heads, q_len // stride, kv_len // stride),
|
||||||
|
dtype=query_states.dtype,
|
||||||
|
device=query_states.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Adjust block size based on GPU shared memory
|
||||||
|
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||||
|
if props.total_memory < 30 * 1024**3: # Less than 30GB (e.g., RTX 3090 24GB)
|
||||||
|
BLOCK_M = 64
|
||||||
|
BLOCK_N = 64
|
||||||
|
else:
|
||||||
|
BLOCK_M = 128
|
||||||
|
BLOCK_N = 128
|
||||||
|
|
||||||
|
assert q_len % (stride * BLOCK_M) == 0
|
||||||
|
assert kv_len % (stride * BLOCK_N) == 0
|
||||||
|
|
||||||
|
grid = (q_len // stride // BLOCK_M, kv_len // stride // BLOCK_N, batch_size * num_heads)
|
||||||
|
flat_group_gemm_fuse_reshape_kernel[grid](
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
output,
|
||||||
|
query_states.stride(0),
|
||||||
|
query_states.stride(1),
|
||||||
|
query_states.stride(2),
|
||||||
|
key_states.stride(0),
|
||||||
|
key_states.stride(1),
|
||||||
|
key_states.stride(2),
|
||||||
|
output.stride(0),
|
||||||
|
output.stride(1),
|
||||||
|
output.stride(2),
|
||||||
|
chunk_start,
|
||||||
|
chunk_end,
|
||||||
|
num_heads,
|
||||||
|
stride,
|
||||||
|
head_dim,
|
||||||
|
BLOCK_M,
|
||||||
|
BLOCK_N,
|
||||||
|
is_causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
@@ -10,10 +10,10 @@ from typing import List, Tuple, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
from nanovllm.kvcache.sparse.policy import AttentionPolicy, PolicyContext
|
||||||
|
|
||||||
|
|
||||||
class MInferencePolicy(SparsePolicy):
|
class MInferencePolicy(AttentionPolicy):
|
||||||
"""
|
"""
|
||||||
MInference sparse prefill policy using vertical + slash pattern.
|
MInference sparse prefill policy using vertical + slash pattern.
|
||||||
|
|
||||||
@@ -347,6 +347,33 @@ class MInferencePolicy(SparsePolicy):
|
|||||||
|
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute MInference sparse prefill attention.
|
||||||
|
|
||||||
|
This is the new unified interface for attention policies.
|
||||||
|
Delegates to sparse_prefill_attention (ignores softmax_scale as MInference
|
||||||
|
computes it internally from head_dim).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
softmax_scale: Softmax scaling factor (unused, computed internally)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
return self.sparse_prefill_attention(q, k, v, layer_id)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"MInferencePolicy("
|
return (f"MInferencePolicy("
|
||||||
f"adaptive_budget={self.adaptive_budget}, "
|
f"adaptive_budget={self.adaptive_budget}, "
|
||||||
|
|||||||
@@ -1,13 +1,18 @@
|
|||||||
"""
|
"""
|
||||||
Base class for sparse attention policies.
|
Base class for attention policies in layerwise offload mode.
|
||||||
|
|
||||||
Sparse attention policies determine which KV cache blocks to load
|
AttentionPolicy defines the interface for all attention computation,
|
||||||
from CPU for each query chunk during chunked attention computation.
|
including full attention and sparse attention methods like XAttention.
|
||||||
|
|
||||||
|
Key methods:
|
||||||
|
- estimate(): Compute sparse attention mask (optional, returns None for full attention)
|
||||||
|
- compute_prefill(): Compute prefill attention
|
||||||
|
- compute_decode(): Compute decode attention (default implementation provided)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Any
|
from typing import List, Optional, Tuple
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Import SparsePolicyType from config to avoid circular imports
|
# Import SparsePolicyType from config to avoid circular imports
|
||||||
@@ -17,10 +22,10 @@ from nanovllm.config import SparsePolicyType
|
|||||||
@dataclass
|
@dataclass
|
||||||
class PolicyContext:
|
class PolicyContext:
|
||||||
"""
|
"""
|
||||||
Context passed to sparse policy for block selection.
|
Context passed to attention policy for block selection.
|
||||||
|
|
||||||
This dataclass contains all information needed by a sparse policy
|
This dataclass contains all information needed by an attention policy
|
||||||
to decide which blocks to load for the current query chunk.
|
for sparse estimation and attention computation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
query_chunk_idx: int
|
query_chunk_idx: int
|
||||||
@@ -49,40 +54,41 @@ class PolicyContext:
|
|||||||
"""Total KV sequence length so far (for reference)."""
|
"""Total KV sequence length so far (for reference)."""
|
||||||
|
|
||||||
|
|
||||||
class SparsePolicy(ABC):
|
class AttentionPolicy(ABC):
|
||||||
"""
|
"""
|
||||||
Abstract base class for sparse attention policies.
|
Base class for attention policies in layerwise offload mode.
|
||||||
|
|
||||||
Subclass this and implement select_blocks() to create custom
|
All attention computation goes through a policy, including both
|
||||||
sparse attention patterns. The policy receives context about
|
full attention and sparse attention methods.
|
||||||
the current query chunk and returns which KV blocks to load.
|
|
||||||
|
The policy interface is designed for layerwise offload where:
|
||||||
|
- The entire KV cache for a layer is on GPU during computation
|
||||||
|
- No need for block loading from CPU during attention
|
||||||
|
- estimate() returns a sparse mask (or None for full attention)
|
||||||
|
- compute_prefill()/compute_decode() perform the actual attention
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
supports_prefill: Whether this policy can be used for prefill phase.
|
supports_prefill: Whether this policy can be used for prefill phase.
|
||||||
supports_decode: Whether this policy can be used for decode phase.
|
supports_decode: Whether this policy can be used for decode phase.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
class MySparsePolicy(SparsePolicy):
|
class MyPolicy(AttentionPolicy):
|
||||||
supports_prefill = False # decode-only policy
|
supports_prefill = True
|
||||||
supports_decode = True
|
supports_decode = True
|
||||||
|
|
||||||
def select_blocks(self, available_blocks, ctx):
|
def estimate(self, q, k, layer_id):
|
||||||
# Load first block and last 2 blocks
|
# Return sparse mask or None
|
||||||
if len(available_blocks) <= 3:
|
return None
|
||||||
return available_blocks
|
|
||||||
return [available_blocks[0]] + available_blocks[-2:]
|
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
|
||||||
|
# Compute attention
|
||||||
|
return flash_attn_varlen_func(q, k, v, ...)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Compatibility flags - override in subclasses
|
# Compatibility flags - override in subclasses
|
||||||
supports_prefill: bool = True
|
supports_prefill: bool = True
|
||||||
supports_decode: bool = True
|
supports_decode: bool = True
|
||||||
|
|
||||||
# Whether this policy requires selective block loading during decode
|
|
||||||
# If True: OffloadEngine will call select_blocks() before loading KV from CPU
|
|
||||||
# If False: OffloadEngine will load all blocks (select_blocks ignored for load)
|
|
||||||
# Example: MInference=False (only affects attention), Quest=True (affects load)
|
|
||||||
requires_block_selection: bool = False
|
|
||||||
|
|
||||||
def initialize(
|
def initialize(
|
||||||
self,
|
self,
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
@@ -96,7 +102,7 @@ class SparsePolicy(ABC):
|
|||||||
Initialize policy resources.
|
Initialize policy resources.
|
||||||
|
|
||||||
Called by the framework after KV cache is allocated. Override this
|
Called by the framework after KV cache is allocated. Override this
|
||||||
to create metadata structures (e.g., BlockMetadataManager for Quest).
|
to create metadata structures or pre-allocate buffers.
|
||||||
Default implementation does nothing.
|
Default implementation does nothing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -109,76 +115,98 @@ class SparsePolicy(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
def estimate(
|
||||||
def select_blocks(
|
|
||||||
self,
|
self,
|
||||||
available_blocks: List[int],
|
q: torch.Tensor,
|
||||||
ctx: PolicyContext,
|
k: torch.Tensor,
|
||||||
) -> List[int]:
|
layer_id: int,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Select which KV blocks to load for the current query chunk.
|
Estimate sparse attention mask.
|
||||||
|
|
||||||
This is the core method that defines the sparse attention pattern.
|
For sparse policies (e.g., XAttention), computes block-level importance
|
||||||
The returned blocks will be loaded from CPU to GPU for attention
|
and returns a boolean mask indicating which blocks to attend.
|
||||||
computation against the current query chunk.
|
For full attention policy, returns None.
|
||||||
|
|
||||||
|
This corresponds to xattn_estimate() in COMPASS.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
available_blocks: List of CPU block IDs that contain KV cache
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
from previous chunks. These are ordered by
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
their position in the sequence.
|
layer_id: Transformer layer index
|
||||||
ctx: PolicyContext with information about the current query
|
|
||||||
chunk, layer, phase (prefill/decode), etc.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of block IDs to load (must be a subset of available_blocks).
|
sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
|
||||||
The order may affect performance (sequential access is faster).
|
or None for full attention
|
||||||
Returning [] means no previous blocks will be loaded.
|
|
||||||
"""
|
"""
|
||||||
pass
|
return None
|
||||||
|
|
||||||
def on_prefill_offload(
|
@abstractmethod
|
||||||
|
def compute_prefill(
|
||||||
self,
|
self,
|
||||||
cpu_block_id: int,
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
k_cache: torch.Tensor,
|
softmax_scale: float,
|
||||||
num_valid_tokens: int,
|
) -> torch.Tensor:
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Hook called when a block is offloaded during prefill phase.
|
Compute prefill attention.
|
||||||
|
|
||||||
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
|
The entire KV cache for this layer is on GPU. Compute attention
|
||||||
Override this to collect metadata about blocks (e.g., min/max keys
|
between Q and K/V, optionally using sparse mask from estimate().
|
||||||
for Quest-style selection). Default implementation does nothing.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cpu_block_id: The CPU block ID that will be written
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||||
layer_id: Transformer layer index
|
layer_id: Transformer layer index
|
||||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
|
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
|
||||||
num_valid_tokens: Number of valid tokens in this block
|
|
||||||
|
Returns:
|
||||||
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_decode_offload(
|
def compute_decode(
|
||||||
self,
|
self,
|
||||||
cpu_block_id: int,
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
k_cache: torch.Tensor,
|
softmax_scale: float,
|
||||||
num_valid_tokens: int,
|
) -> torch.Tensor:
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Hook called when a block is offloaded during decode phase.
|
Compute decode attention.
|
||||||
|
|
||||||
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
|
KV is provided from ring buffer, containing prefill tokens + decoded tokens.
|
||||||
Override this to update metadata about blocks. Default implementation
|
Default implementation uses FlashAttention.
|
||||||
does nothing.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cpu_block_id: The CPU block ID that will be written
|
q: Query tensor [1, num_heads, head_dim]
|
||||||
|
k: Key tensor [context_len+1, num_kv_heads, head_dim]
|
||||||
|
v: Value tensor [context_len+1, num_kv_heads, head_dim]
|
||||||
layer_id: Transformer layer index
|
layer_id: Transformer layer index
|
||||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
|
softmax_scale: Softmax scaling factor
|
||||||
num_valid_tokens: Number of valid tokens in this block
|
|
||||||
|
Returns:
|
||||||
|
Attention output [1, num_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
pass
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
context_len = k.shape[0]
|
||||||
|
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
|
||||||
|
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=1,
|
||||||
|
max_seqlen_k=context_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -189,32 +217,9 @@ class SparsePolicy(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def sparse_prefill_attention(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
layer_id: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute sparse attention for prefill phase.
|
|
||||||
|
|
||||||
This method is called when supports_prefill=True and the policy
|
|
||||||
is used for GPU-only sparse prefill (no CPU offload).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
|
||||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
layer_id: Current transformer layer index
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
|
||||||
"""
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
|
|
||||||
"Set supports_prefill=False or implement this method."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"{self.__class__.__name__}()"
|
return f"{self.__class__.__name__}()"
|
||||||
|
|
||||||
|
|
||||||
|
# Backward compatibility alias
|
||||||
|
SparsePolicy = AttentionPolicy
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import logging
|
|||||||
import torch
|
import torch
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
from .policy import SparsePolicy, PolicyContext
|
from .policy import AttentionPolicy, PolicyContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -137,7 +137,7 @@ class QuestConfig:
|
|||||||
"""Always include this many recent blocks (last N blocks), in addition to Top-K."""
|
"""Always include this many recent blocks (last N blocks), in addition to Top-K."""
|
||||||
|
|
||||||
|
|
||||||
class QuestPolicy(SparsePolicy):
|
class QuestPolicy(AttentionPolicy):
|
||||||
"""
|
"""
|
||||||
Quest-style Top-K block selection using min/max key bounds.
|
Quest-style Top-K block selection using min/max key bounds.
|
||||||
|
|
||||||
@@ -317,6 +317,25 @@ class QuestPolicy(SparsePolicy):
|
|||||||
if self.metadata is not None:
|
if self.metadata is not None:
|
||||||
self.metadata.reset()
|
self.metadata.reset()
|
||||||
|
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Quest does not support prefill - raises error.
|
||||||
|
|
||||||
|
Quest is a decode-only policy for selective block loading.
|
||||||
|
For prefill, use FullAttentionPolicy or XAttentionPolicy.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"QuestPolicy does not support prefill. "
|
||||||
|
"Use FullAttentionPolicy or XAttentionPolicy for prefill."
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
f"QuestPolicy(topk={self.config.topk_blocks}, "
|
f"QuestPolicy(topk={self.config.topk_blocks}, "
|
||||||
|
|||||||
156
nanovllm/kvcache/sparse/utils.py
Normal file
156
nanovllm/kvcache/sparse/utils.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""
|
||||||
|
Utility functions for sparse attention policies.
|
||||||
|
|
||||||
|
Copied from COMPASS/compass/src/utils.py for XAttention integration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def find_blocks_chunked(
|
||||||
|
input_tensor, current_index, threshold, num_to_choose, decoding: bool, mode: str = "both", causal=True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Finds and selects relevant blocks of attention for transformer-based models based on a
|
||||||
|
threshold or a predefined number of blocks.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- input_tensor (torch.Tensor): The input tensor of shape (batch_size, head_num, chunk_num, block_num).
|
||||||
|
- current_index (int): The current index in the sequence processing.
|
||||||
|
- threshold (float or None): A threshold value used to determine the minimum attention weight sum.
|
||||||
|
- num_to_choose (int or None): The number of blocks to be selected, ensuring sufficient information retrieval.
|
||||||
|
- decoding (bool): If True, operates in decoding mode; otherwise, it's in encoding mode.
|
||||||
|
- mode (str): Defines the processing mode, either 'both', 'prefill', or 'decode'.
|
||||||
|
- causal (bool): If True, applies causal masking to prevent future information leakage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- torch.Tensor: A boolean mask of shape (batch_size, head_num, chunk_num, block_num),
|
||||||
|
indicating which blocks should be attended to.
|
||||||
|
"""
|
||||||
|
assert threshold is None or num_to_choose is None
|
||||||
|
batch_size, head_num, chunk_num, block_num = input_tensor.shape
|
||||||
|
|
||||||
|
if mode == "prefill" and decoding:
|
||||||
|
return torch.ones_like(input_tensor, dtype=torch.bool)
|
||||||
|
if mode == "decode" and not decoding:
|
||||||
|
mask = torch.ones_like(input_tensor, dtype=torch.bool)
|
||||||
|
if causal:
|
||||||
|
mask[:, :, :, current_index : current_index + chunk_num] = torch.tril(
|
||||||
|
torch.ones(1, head_num, chunk_num, chunk_num, device=input_tensor.device)
|
||||||
|
)
|
||||||
|
mask[:, :, current_index + chunk_num :, :] = 0
|
||||||
|
return torch.cat(
|
||||||
|
[
|
||||||
|
torch.ones_like(input_tensor, dtype=torch.bool)[:, :, 0 : current_index + 1],
|
||||||
|
torch.zeros_like(input_tensor, dtype=torch.bool)[:, :, current_index + 1 :],
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return mask
|
||||||
|
|
||||||
|
input_tensor = input_tensor.to(float)
|
||||||
|
|
||||||
|
if threshold is not None:
|
||||||
|
total_sum = input_tensor.sum(dim=-1, keepdim=True)
|
||||||
|
if isinstance(threshold, torch.Tensor):
|
||||||
|
threshold = threshold.to(float)
|
||||||
|
required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze(
|
||||||
|
-1
|
||||||
|
).expand((batch_size, head_num, chunk_num, 1)).to(input_tensor.device)
|
||||||
|
else:
|
||||||
|
required_sum = total_sum * threshold
|
||||||
|
|
||||||
|
if causal:
|
||||||
|
mask = torch.zeros_like(input_tensor, dtype=torch.bool)
|
||||||
|
mask[:, :, :, 0] = 1
|
||||||
|
mask[:, :, :, current_index : current_index + chunk_num] = (
|
||||||
|
torch.eye(chunk_num, device=mask.device)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(1, head_num, chunk_num, chunk_num)
|
||||||
|
)
|
||||||
|
other_values = input_tensor.masked_fill(mask, 0)
|
||||||
|
sorted_values, _ = torch.sort(
|
||||||
|
other_values, dim=-1, descending=True
|
||||||
|
)
|
||||||
|
sorted_values = sorted_values.to(input_tensor.device)
|
||||||
|
|
||||||
|
sorted_values = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros(
|
||||||
|
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
|
||||||
|
),
|
||||||
|
torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True),
|
||||||
|
sorted_values[:, :, :, :-2],
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
_, index = torch.sort(
|
||||||
|
torch.where(mask, 100000 * (1 + input_tensor), input_tensor),
|
||||||
|
dim=-1,
|
||||||
|
descending=True
|
||||||
|
)
|
||||||
|
cumulative_sum_without_self = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros(
|
||||||
|
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
|
||||||
|
),
|
||||||
|
sorted_values[:, :, :, 0:-1],
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
).cumsum(dim=-1)
|
||||||
|
|
||||||
|
index_mask = cumulative_sum_without_self < required_sum
|
||||||
|
index = torch.where(index_mask, index, 0)
|
||||||
|
mask = mask.view(batch_size, head_num * chunk_num, block_num)
|
||||||
|
index = index.view(batch_size, head_num * chunk_num, block_num)
|
||||||
|
mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True
|
||||||
|
mask = mask.view(batch_size, head_num, chunk_num, block_num)
|
||||||
|
else:
|
||||||
|
mask = torch.zeros_like(input_tensor, dtype=torch.bool)
|
||||||
|
sorted_values, index = torch.sort(
|
||||||
|
input_tensor, dim=-1, descending=True
|
||||||
|
)
|
||||||
|
sorted_values = sorted_values.to(input_tensor.device)
|
||||||
|
cumulative_sum_without_self = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros(
|
||||||
|
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
|
||||||
|
),
|
||||||
|
sorted_values[:, :, :, 0:-1],
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
).cumsum(dim=-1)
|
||||||
|
index_mask = cumulative_sum_without_self < required_sum
|
||||||
|
index = torch.where(index_mask, index, 0)
|
||||||
|
mask = mask.view(batch_size, head_num * chunk_num, block_num)
|
||||||
|
index = index.view(batch_size, head_num * chunk_num, block_num)
|
||||||
|
mask[
|
||||||
|
:,
|
||||||
|
torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1),
|
||||||
|
index,
|
||||||
|
] = True
|
||||||
|
mask = mask.view(batch_size, head_num, chunk_num, block_num)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("block num chunk prefill not implemented")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if causal:
|
||||||
|
assert (~mask[:, :, :, current_index + chunk_num :]).all()
|
||||||
|
except:
|
||||||
|
mask[:, :, :, current_index + chunk_num :] = False
|
||||||
|
|
||||||
|
if causal:
|
||||||
|
if decoding:
|
||||||
|
assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all()
|
||||||
|
else:
|
||||||
|
lambda_mask = torch.zeros_like(input_tensor, dtype=bool, device=input_tensor.device)
|
||||||
|
lambda_mask[:, :, :, 0] = 1
|
||||||
|
lambda_mask[:, :, :, current_index:current_index+chunk_num] = torch.eye(
|
||||||
|
chunk_num, device=lambda_mask.device
|
||||||
|
).unsqueeze(0).unsqueeze(0).expand(1, head_num, chunk_num, chunk_num)
|
||||||
|
assert(torch.where(lambda_mask, mask, True).all())
|
||||||
|
|
||||||
|
return mask
|
||||||
310
nanovllm/kvcache/sparse/xattn.py
Normal file
310
nanovllm/kvcache/sparse/xattn.py
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
"""
|
||||||
|
XAttention sparse attention policy for nano-vllm.
|
||||||
|
|
||||||
|
Implements the XAttention algorithm from COMPASS, using chunked estimation
|
||||||
|
and block sparse attention for efficient long-context inference.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
XAttention = Estimate (Triton) + Compute (BSA)
|
||||||
|
- Estimate: xattn_estimate() computes block-level importance scores
|
||||||
|
- Compute: block_sparse_attn_func() executes sparse attention
|
||||||
|
|
||||||
|
Reference: COMPASS/compass/src/Xattention.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from nanovllm.kvcache.sparse.policy import AttentionPolicy
|
||||||
|
|
||||||
|
# BSA block size is fixed at 128 (hardcoded in block_sparse_attn)
|
||||||
|
BSA_BLOCK_SIZE = 128
|
||||||
|
|
||||||
|
|
||||||
|
class XAttentionPolicy(AttentionPolicy):
|
||||||
|
"""
|
||||||
|
XAttention sparse prefill policy using chunked estimation + block sparse attention.
|
||||||
|
|
||||||
|
This policy estimates sparse attention patterns by:
|
||||||
|
1. Chunked QK computation using Triton kernels (via nanovllm.ops.xattn)
|
||||||
|
2. Block-wise softmax with importance scores
|
||||||
|
3. Block selection based on threshold
|
||||||
|
4. Block sparse attention computation using MIT-HAN-LAB BSA library
|
||||||
|
|
||||||
|
The key method is estimate() which calls xattn_estimate() from nanovllm.ops
|
||||||
|
to compute the sparse attention mask.
|
||||||
|
|
||||||
|
Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
|
||||||
|
BSA library: https://github.com/mit-han-lab/Block-Sparse-Attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
supports_prefill = True
|
||||||
|
supports_decode = True # Uses default FlashAttention for decode
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stride: int = 8,
|
||||||
|
threshold: float = 0.9,
|
||||||
|
block_size: int = 128,
|
||||||
|
chunk_size: int = 16384,
|
||||||
|
use_triton: bool = True,
|
||||||
|
keep_sink: bool = False,
|
||||||
|
keep_recent: bool = False,
|
||||||
|
norm: float = 1.0,
|
||||||
|
use_bsa: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize XAttention policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stride: Stride for reorganizing Q/K (default: 8)
|
||||||
|
threshold: Block selection threshold, 0-1 (default: 0.9)
|
||||||
|
block_size: Block size for sparse attention (default: 128, must match BSA)
|
||||||
|
chunk_size: Chunk size for estimation (default: 16384)
|
||||||
|
use_triton: Use Triton kernels (requires SM 80+)
|
||||||
|
keep_sink: Always keep first block (sink tokens)
|
||||||
|
keep_recent: Always keep recent diagonal blocks
|
||||||
|
norm: Normalization factor for attention scores
|
||||||
|
use_bsa: Use Block Sparse Attention library (default: True)
|
||||||
|
"""
|
||||||
|
self.stride = stride
|
||||||
|
self.threshold = threshold
|
||||||
|
self.block_size = block_size
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.use_triton = use_triton
|
||||||
|
self.keep_sink = keep_sink
|
||||||
|
self.keep_recent = keep_recent
|
||||||
|
self.norm = norm
|
||||||
|
self.use_bsa = use_bsa
|
||||||
|
|
||||||
|
# BSA requires block_size = 128
|
||||||
|
if self.use_bsa and self.block_size != BSA_BLOCK_SIZE:
|
||||||
|
print(f"XAttention: BSA requires block_size=128, adjusting from {self.block_size}")
|
||||||
|
self.block_size = BSA_BLOCK_SIZE
|
||||||
|
|
||||||
|
# Check Triton availability
|
||||||
|
if self.use_triton:
|
||||||
|
try:
|
||||||
|
import triton
|
||||||
|
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||||
|
if props.major < 8:
|
||||||
|
self.use_triton = False
|
||||||
|
print(f"XAttention: Triton requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.")
|
||||||
|
except ImportError:
|
||||||
|
self.use_triton = False
|
||||||
|
print("XAttention: Triton not available. Falling back to PyTorch.")
|
||||||
|
|
||||||
|
# Check BSA availability
|
||||||
|
if self.use_bsa:
|
||||||
|
try:
|
||||||
|
from block_sparse_attn import block_sparse_attn_func
|
||||||
|
except ImportError:
|
||||||
|
self.use_bsa = False
|
||||||
|
print("XAttention: block_sparse_attn not available. Falling back to FlashAttention.")
|
||||||
|
|
||||||
|
def estimate(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Estimate sparse attention mask using XAttention algorithm.
|
||||||
|
|
||||||
|
Calls xattn_estimate() from nanovllm.ops.xattn to compute block-level
|
||||||
|
importance scores and generate a sparse boolean mask.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
|
||||||
|
or None if estimation fails (fallback to full attention)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from nanovllm.ops.xattn import xattn_estimate
|
||||||
|
|
||||||
|
seq_len, num_heads, head_dim = q.shape
|
||||||
|
num_kv_heads = k.shape[1]
|
||||||
|
|
||||||
|
# Convert to [batch, heads, seq, dim] format expected by xattn_estimate
|
||||||
|
q_bhsd = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim]
|
||||||
|
k_bhsd = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, seq_len, head_dim]
|
||||||
|
|
||||||
|
# Handle GQA: expand k to match q heads for estimation
|
||||||
|
if num_kv_heads != num_heads:
|
||||||
|
# GQA: expand k by repeating
|
||||||
|
repeat_factor = num_heads // num_kv_heads
|
||||||
|
k_bhsd = k_bhsd.repeat(1, repeat_factor, 1, 1)
|
||||||
|
|
||||||
|
# Call xattn_estimate
|
||||||
|
attn_sums, sparse_mask = xattn_estimate(
|
||||||
|
q_bhsd, k_bhsd,
|
||||||
|
block_size=self.block_size,
|
||||||
|
stride=self.stride,
|
||||||
|
norm=self.norm,
|
||||||
|
threshold=self.threshold,
|
||||||
|
chunk_size=self.chunk_size,
|
||||||
|
use_triton=self.use_triton,
|
||||||
|
causal=True,
|
||||||
|
keep_sink=self.keep_sink,
|
||||||
|
keep_recent=self.keep_recent,
|
||||||
|
)
|
||||||
|
|
||||||
|
return sparse_mask
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# If estimation fails, return None to use full attention
|
||||||
|
print(f"XAttention estimate failed: {e}, falling back to full attention")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute XAttention sparse prefill attention.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. Call estimate() to get sparse mask
|
||||||
|
2. If mask is None or BSA unavailable, use full FlashAttention
|
||||||
|
3. Otherwise, use block_sparse_attn_func with mask
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
softmax_scale: Softmax scaling factor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
# If BSA is disabled, use full attention directly (skip estimation)
|
||||||
|
if not self.use_bsa:
|
||||||
|
return self._full_attention(q, k, v, softmax_scale)
|
||||||
|
|
||||||
|
# Step 1: Estimate sparse mask
|
||||||
|
sparse_mask = self.estimate(q, k, layer_id)
|
||||||
|
|
||||||
|
# Step 2: Compute attention
|
||||||
|
if sparse_mask is None:
|
||||||
|
# Estimation failed, fallback to full FlashAttention
|
||||||
|
return self._full_attention(q, k, v, softmax_scale)
|
||||||
|
|
||||||
|
# Use block sparse attention with mask
|
||||||
|
return self._block_sparse_attention(q, k, v, sparse_mask, softmax_scale)
|
||||||
|
|
||||||
|
def _block_sparse_attention(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
sparse_mask: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute block sparse attention using MIT-HAN-LAB BSA library.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
sparse_mask: Block mask [batch, num_heads, q_blocks, k_blocks]
|
||||||
|
softmax_scale: Softmax scaling factor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
from block_sparse_attn import block_sparse_attn_func
|
||||||
|
|
||||||
|
seq_len, num_heads, head_dim = q.shape
|
||||||
|
num_kv_heads = k.shape[1]
|
||||||
|
|
||||||
|
# Handle GQA: expand K/V to match Q heads
|
||||||
|
if num_kv_heads != num_heads:
|
||||||
|
repeat_factor = num_heads // num_kv_heads
|
||||||
|
k = k.repeat_interleave(repeat_factor, dim=1)
|
||||||
|
v = v.repeat_interleave(repeat_factor, dim=1)
|
||||||
|
|
||||||
|
# Cumulative sequence lengths (batch=1)
|
||||||
|
cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||||
|
cu_seqlens_k = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
# Head mask type: 1 for all heads using block sparse
|
||||||
|
head_mask_type = torch.ones(num_heads, dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
# Trim sparse_mask to actual block counts
|
||||||
|
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
|
||||||
|
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
|
||||||
|
block_mask = sparse_mask[:, :, :q_blocks, :k_blocks].contiguous()
|
||||||
|
|
||||||
|
# Call BSA
|
||||||
|
attn_output = block_sparse_attn_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q, cu_seqlens_k,
|
||||||
|
head_mask_type,
|
||||||
|
None, # streaming_info (left_mask)
|
||||||
|
block_mask,
|
||||||
|
seq_len, seq_len,
|
||||||
|
p_dropout=0.0,
|
||||||
|
deterministic=True,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
def _full_attention(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute full causal attention using FlashAttention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
softmax_scale: Softmax scaling factor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
seq_len = q.shape[0]
|
||||||
|
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
max_seqlen_q=seq_len,
|
||||||
|
max_seqlen_k=seq_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset policy state (no state to reset for XAttention)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (f"XAttentionPolicy("
|
||||||
|
f"stride={self.stride}, "
|
||||||
|
f"threshold={self.threshold}, "
|
||||||
|
f"block_size={self.block_size}, "
|
||||||
|
f"use_triton={self.use_triton}, "
|
||||||
|
f"use_bsa={self.use_bsa})")
|
||||||
@@ -98,10 +98,10 @@ class Attention(nn.Module):
|
|||||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||||
elif context.sparse_prefill_policy is not None:
|
elif context.attention_policy is not None:
|
||||||
# Sparse prefill (GPU-only) - delegate to policy
|
# Attention via policy (GPU-only) - delegate to policy
|
||||||
o = context.sparse_prefill_policy.sparse_prefill_attention(
|
o = context.attention_policy.compute_prefill(
|
||||||
q, k, v, self.layer_id
|
q, k, v, self.layer_id, softmax_scale=self.scale
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
o = flash_attn_varlen_func(q, k, v,
|
o = flash_attn_varlen_func(q, k, v,
|
||||||
|
|||||||
@@ -27,13 +27,13 @@ class RMSNorm(nn.Module):
|
|||||||
x = x.to(orig_dtype).mul_(self.weight)
|
x = x.to(orig_dtype).mul_(self.weight)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@torch.compile
|
|
||||||
def add_rms_forward(
|
def add_rms_forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
|
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
|
||||||
|
# Note: @torch.compile removed due to OOM with 64k sequences (memory fragmentation)
|
||||||
orig_dtype = x.dtype
|
orig_dtype = x.dtype
|
||||||
x = x.float().add_(residual.float())
|
x = x.float().add_(residual.float())
|
||||||
residual = x.to(orig_dtype)
|
residual = x.to(orig_dtype)
|
||||||
|
|||||||
38
nanovllm/ops/__init__.py
Normal file
38
nanovllm/ops/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""
|
||||||
|
Operators module for nano-vLLM.
|
||||||
|
|
||||||
|
This module contains low-level attention operators and kernels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from nanovllm.ops.chunked_attention import (
|
||||||
|
flash_attn_with_lse,
|
||||||
|
merge_attention_outputs,
|
||||||
|
chunked_attention_varlen,
|
||||||
|
ChunkedPrefillState,
|
||||||
|
)
|
||||||
|
|
||||||
|
from nanovllm.ops.xattn import (
|
||||||
|
xattn_estimate,
|
||||||
|
xattn_estimate_chunked,
|
||||||
|
flat_group_gemm_fuse_reshape,
|
||||||
|
softmax_fuse_block_sum,
|
||||||
|
find_blocks_chunked,
|
||||||
|
create_causal_mask,
|
||||||
|
compute_sparsity,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# chunked_attention
|
||||||
|
"flash_attn_with_lse",
|
||||||
|
"merge_attention_outputs",
|
||||||
|
"chunked_attention_varlen",
|
||||||
|
"ChunkedPrefillState",
|
||||||
|
# xattn
|
||||||
|
"xattn_estimate",
|
||||||
|
"xattn_estimate_chunked",
|
||||||
|
"flat_group_gemm_fuse_reshape",
|
||||||
|
"softmax_fuse_block_sum",
|
||||||
|
"find_blocks_chunked",
|
||||||
|
"create_causal_mask",
|
||||||
|
"compute_sparsity",
|
||||||
|
]
|
||||||
624
nanovllm/ops/chunked_attention.py
Normal file
624
nanovllm/ops/chunked_attention.py
Normal file
@@ -0,0 +1,624 @@
|
|||||||
|
"""
|
||||||
|
Chunked attention implementation for CPU KV cache offloading.
|
||||||
|
|
||||||
|
This module implements flash attention with LSE (log-sum-exp) output,
|
||||||
|
enabling proper online softmax merging for chunked prefill.
|
||||||
|
|
||||||
|
Key functions:
|
||||||
|
- flash_attn_with_lse: Flash attention that returns output and LSE
|
||||||
|
- merge_attention_outputs: Merge outputs from multiple KV chunks
|
||||||
|
- chunked_prefill_attention: High-level interface for chunked attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from typing import Tuple, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@triton.heuristics(
|
||||||
|
{
|
||||||
|
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
||||||
|
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
||||||
|
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def _fwd_kernel_with_lse(
|
||||||
|
Q,
|
||||||
|
K,
|
||||||
|
V,
|
||||||
|
Out,
|
||||||
|
Lse,
|
||||||
|
softmax_scale,
|
||||||
|
stride_qb,
|
||||||
|
stride_qh,
|
||||||
|
stride_qm,
|
||||||
|
stride_kb,
|
||||||
|
stride_kh,
|
||||||
|
stride_kn,
|
||||||
|
stride_vb,
|
||||||
|
stride_vh,
|
||||||
|
stride_vn,
|
||||||
|
stride_ob,
|
||||||
|
stride_oh,
|
||||||
|
stride_om,
|
||||||
|
nheads,
|
||||||
|
seqlen_q,
|
||||||
|
seqlen_k,
|
||||||
|
seqlen_q_rounded,
|
||||||
|
headdim,
|
||||||
|
CACHE_KEY_SEQLEN_Q,
|
||||||
|
CACHE_KEY_SEQLEN_K,
|
||||||
|
IS_CAUSAL: tl.constexpr,
|
||||||
|
BLOCK_HEADDIM: tl.constexpr,
|
||||||
|
EVEN_M: tl.constexpr,
|
||||||
|
EVEN_N: tl.constexpr,
|
||||||
|
EVEN_HEADDIM: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Flash attention forward kernel with LSE output.
|
||||||
|
|
||||||
|
Implements standard Flash Attention online softmax algorithm:
|
||||||
|
- m_i: running max of attention scores
|
||||||
|
- l_i: running sum of exp(scores - m_i)
|
||||||
|
- acc_o: running sum of softmax(scores) @ V (unnormalized)
|
||||||
|
|
||||||
|
Final output: acc_o / l_i
|
||||||
|
Final LSE: m_i + log(l_i)
|
||||||
|
"""
|
||||||
|
start_m = tl.program_id(0)
|
||||||
|
off_hb = tl.program_id(1)
|
||||||
|
off_b = off_hb // nheads
|
||||||
|
off_h = off_hb % nheads
|
||||||
|
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
offs_n = tl.arange(0, BLOCK_N)
|
||||||
|
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
||||||
|
|
||||||
|
# Pointers
|
||||||
|
q_ptrs = (
|
||||||
|
Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
|
||||||
|
)
|
||||||
|
k_ptrs = (
|
||||||
|
K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
|
||||||
|
)
|
||||||
|
v_ptrs = (
|
||||||
|
V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize running statistics
|
||||||
|
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # running max
|
||||||
|
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # running sum of exp
|
||||||
|
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # running output (unnormalized)
|
||||||
|
|
||||||
|
# Load Q (once per block)
|
||||||
|
if EVEN_M & EVEN_N:
|
||||||
|
if EVEN_HEADDIM:
|
||||||
|
q = tl.load(q_ptrs)
|
||||||
|
else:
|
||||||
|
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
||||||
|
else:
|
||||||
|
if EVEN_HEADDIM:
|
||||||
|
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
|
||||||
|
else:
|
||||||
|
q = tl.load(
|
||||||
|
q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loop over K, V blocks
|
||||||
|
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
|
||||||
|
for start_n in range(0, end_n, BLOCK_N):
|
||||||
|
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||||
|
|
||||||
|
# Load K
|
||||||
|
if EVEN_N & EVEN_M:
|
||||||
|
if EVEN_HEADDIM:
|
||||||
|
k = tl.load(k_ptrs + start_n * stride_kn)
|
||||||
|
else:
|
||||||
|
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
|
||||||
|
else:
|
||||||
|
if EVEN_HEADDIM:
|
||||||
|
k = tl.load(
|
||||||
|
k_ptrs + start_n * stride_kn,
|
||||||
|
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
k = tl.load(
|
||||||
|
k_ptrs + start_n * stride_kn,
|
||||||
|
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute QK^T * scale
|
||||||
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||||
|
qk += tl.dot(q, tl.trans(k))
|
||||||
|
qk *= softmax_scale
|
||||||
|
|
||||||
|
# Apply masks
|
||||||
|
if not EVEN_N:
|
||||||
|
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
|
||||||
|
if IS_CAUSAL:
|
||||||
|
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
|
||||||
|
|
||||||
|
# Online softmax: compute block max
|
||||||
|
m_ij = tl.max(qk, 1) # [BLOCK_M]
|
||||||
|
|
||||||
|
# New running max
|
||||||
|
m_new = tl.maximum(m_i, m_ij) # [BLOCK_M]
|
||||||
|
|
||||||
|
# Rescale factor for previous accumulator
|
||||||
|
alpha = tl.exp(m_i - m_new) # [BLOCK_M]
|
||||||
|
|
||||||
|
# Compute P = exp(qk - m_new)
|
||||||
|
p = tl.exp(qk - m_new[:, None]) # [BLOCK_M, BLOCK_N]
|
||||||
|
|
||||||
|
# Sum of current block
|
||||||
|
l_ij = tl.sum(p, 1) # [BLOCK_M]
|
||||||
|
|
||||||
|
# Update running sum: l_new = l_i * alpha + l_ij
|
||||||
|
l_new = l_i * alpha + l_ij
|
||||||
|
|
||||||
|
# Rescale previous output and add new contribution
|
||||||
|
acc_o = acc_o * alpha[:, None]
|
||||||
|
|
||||||
|
# Load V
|
||||||
|
if EVEN_N & EVEN_M:
|
||||||
|
if EVEN_HEADDIM:
|
||||||
|
v = tl.load(v_ptrs + start_n * stride_vn)
|
||||||
|
else:
|
||||||
|
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
|
||||||
|
else:
|
||||||
|
if EVEN_HEADDIM:
|
||||||
|
v = tl.load(
|
||||||
|
v_ptrs + start_n * stride_vn,
|
||||||
|
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
v = tl.load(
|
||||||
|
v_ptrs + start_n * stride_vn,
|
||||||
|
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# acc_o += P @ V
|
||||||
|
p = p.to(v.dtype)
|
||||||
|
acc_o += tl.dot(p, v)
|
||||||
|
|
||||||
|
# Update running statistics
|
||||||
|
m_i = m_new
|
||||||
|
l_i = l_new
|
||||||
|
|
||||||
|
# Final normalization: output = acc_o / l_i
|
||||||
|
acc_o = acc_o / l_i[:, None]
|
||||||
|
|
||||||
|
# Compute LSE = m_i + log(l_i)
|
||||||
|
lse_i = m_i + tl.log(l_i)
|
||||||
|
|
||||||
|
# Store LSE
|
||||||
|
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
|
||||||
|
if EVEN_M:
|
||||||
|
tl.store(lse_ptrs, lse_i)
|
||||||
|
else:
|
||||||
|
tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)
|
||||||
|
|
||||||
|
# Store output
|
||||||
|
out_ptrs = (
|
||||||
|
Out
|
||||||
|
+ off_b * stride_ob
|
||||||
|
+ off_h * stride_oh
|
||||||
|
+ (offs_m[:, None] * stride_om + offs_d[None, :])
|
||||||
|
)
|
||||||
|
if EVEN_M:
|
||||||
|
if EVEN_HEADDIM:
|
||||||
|
tl.store(out_ptrs, acc_o)
|
||||||
|
else:
|
||||||
|
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
|
||||||
|
else:
|
||||||
|
if EVEN_HEADDIM:
|
||||||
|
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
|
||||||
|
else:
|
||||||
|
tl.store(
|
||||||
|
out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_with_lse(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
softmax_scale: Optional[float] = None,
|
||||||
|
causal: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Flash attention forward pass that returns both output and LSE.
|
||||||
|
|
||||||
|
Uses flash_attn library which natively supports GQA without memory overhead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [batch, seqlen_q, nheads_q, headdim]
|
||||||
|
k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
|
||||||
|
v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
|
||||||
|
softmax_scale: Scaling factor (default: 1/sqrt(headdim))
|
||||||
|
causal: Whether to apply causal masking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
out: Output tensor [batch, seqlen_q, nheads_q, headdim]
|
||||||
|
lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
|
||||||
|
"""
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_func
|
||||||
|
|
||||||
|
batch, seqlen_q, nheads_q, headdim = q.shape
|
||||||
|
_, seqlen_k, nheads_kv, _ = k.shape
|
||||||
|
|
||||||
|
if softmax_scale is None:
|
||||||
|
softmax_scale = 1.0 / math.sqrt(headdim)
|
||||||
|
|
||||||
|
# Use flash_attn_func which natively supports GQA (no memory overhead)
|
||||||
|
# It returns (output, softmax_lse) when return_attn_probs=True is not set
|
||||||
|
# We need to use the internal function to get LSE
|
||||||
|
out, lse, _ = flash_attn_func(
|
||||||
|
q, k, v,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
|
return_attn_probs=True, # This makes it return (out, softmax_lse, S_dmask)
|
||||||
|
)
|
||||||
|
|
||||||
|
# lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
|
||||||
|
# Trim to actual seqlen_q
|
||||||
|
lse = lse[:, :, :seqlen_q]
|
||||||
|
|
||||||
|
return out, lse
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _merge_lse_kernel(
|
||||||
|
lse1_ptr, lse2_ptr, lse_out_ptr,
|
||||||
|
num_elements: tl.constexpr,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""Fused kernel for merging LSE values.
|
||||||
|
|
||||||
|
IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
|
||||||
|
bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
|
||||||
|
"""
|
||||||
|
# Each program handles BLOCK_SIZE elements
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
block_start = pid * BLOCK_SIZE
|
||||||
|
|
||||||
|
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = offsets < num_elements
|
||||||
|
|
||||||
|
# Load lse values and convert to fp32 for precision
|
||||||
|
lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
|
||||||
|
lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)
|
||||||
|
|
||||||
|
# Compute max for numerical stability (in fp32)
|
||||||
|
max_lse = tl.maximum(lse1, lse2)
|
||||||
|
|
||||||
|
# Compute exp(lse - max_lse) in fp32
|
||||||
|
exp1 = tl.exp(lse1 - max_lse)
|
||||||
|
exp2 = tl.exp(lse2 - max_lse)
|
||||||
|
|
||||||
|
# Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
|
||||||
|
lse_merged = max_lse + tl.log(exp1 + exp2)
|
||||||
|
|
||||||
|
# Store result (convert back to original dtype)
|
||||||
|
tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _merge_output_kernel(
|
||||||
|
o1_ptr, o2_ptr, lse1_ptr, lse2_ptr, o_out_ptr,
|
||||||
|
batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""Fused kernel for merging attention outputs.
|
||||||
|
|
||||||
|
IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
|
||||||
|
This is critical for numerical accuracy in chunked attention.
|
||||||
|
"""
|
||||||
|
# Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
|
||||||
|
pid_batch = tl.program_id(0)
|
||||||
|
pid_seq = tl.program_id(1)
|
||||||
|
pid_head = tl.program_id(2)
|
||||||
|
|
||||||
|
# Compute LSE index: [batch, nheads, seqlen_q]
|
||||||
|
lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq
|
||||||
|
|
||||||
|
# Load LSE values and convert to fp32 for precision
|
||||||
|
lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
|
||||||
|
lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)
|
||||||
|
|
||||||
|
# Compute max and scaling factors in fp32
|
||||||
|
max_lse = tl.maximum(lse1, lse2)
|
||||||
|
exp1 = tl.exp(lse1 - max_lse)
|
||||||
|
exp2 = tl.exp(lse2 - max_lse)
|
||||||
|
sum_exp = exp1 + exp2
|
||||||
|
|
||||||
|
# Process headdim in chunks
|
||||||
|
for d_offset in range(0, headdim, BLOCK_SIZE):
|
||||||
|
d_idx = d_offset + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = d_idx < headdim
|
||||||
|
|
||||||
|
# Compute output index: [batch, seqlen_q, nheads, headdim]
|
||||||
|
base_idx = (pid_batch * seqlen_q * nheads * headdim +
|
||||||
|
pid_seq * nheads * headdim +
|
||||||
|
pid_head * headdim)
|
||||||
|
o_idx = base_idx + d_idx
|
||||||
|
|
||||||
|
# Load o1, o2 and convert to fp32 for weighted sum
|
||||||
|
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
|
||||||
|
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
|
||||||
|
|
||||||
|
# Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
|
||||||
|
o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
|
||||||
|
|
||||||
|
# Store result (Triton will convert back to original dtype)
|
||||||
|
tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_attention_outputs(
|
||||||
|
o1: torch.Tensor,
|
||||||
|
lse1: torch.Tensor,
|
||||||
|
o2: torch.Tensor,
|
||||||
|
lse2: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Merge two attention outputs using online softmax (Triton fused kernel).
|
||||||
|
|
||||||
|
This implements the online softmax merging formula:
|
||||||
|
- m_new = max(lse1, lse2)
|
||||||
|
- o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new))
|
||||||
|
- lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))
|
||||||
|
|
||||||
|
Args:
|
||||||
|
o1: First output [batch, seqlen_q, nheads, headdim]
|
||||||
|
lse1: First LSE [batch, nheads, seqlen_q]
|
||||||
|
o2: Second output [batch, seqlen_q, nheads, headdim]
|
||||||
|
lse2: Second LSE [batch, nheads, seqlen_q]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
o_merged: Merged output [batch, seqlen_q, nheads, headdim]
|
||||||
|
lse_merged: Merged LSE [batch, nheads, seqlen_q]
|
||||||
|
"""
|
||||||
|
batch, seqlen_q, nheads, headdim = o1.shape
|
||||||
|
|
||||||
|
# Allocate output tensors
|
||||||
|
o_merged = torch.empty_like(o1)
|
||||||
|
lse_merged = torch.empty_like(lse1)
|
||||||
|
|
||||||
|
# Launch LSE merge kernel
|
||||||
|
num_lse_elements = batch * nheads * seqlen_q
|
||||||
|
BLOCK_SIZE_LSE = 256
|
||||||
|
grid_lse = (triton.cdiv(num_lse_elements, BLOCK_SIZE_LSE),)
|
||||||
|
_merge_lse_kernel[grid_lse](
|
||||||
|
lse1, lse2, lse_merged,
|
||||||
|
num_lse_elements,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE_LSE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Launch output merge kernel
|
||||||
|
BLOCK_SIZE = 128
|
||||||
|
grid_output = (batch, seqlen_q, nheads)
|
||||||
|
_merge_output_kernel[grid_output](
|
||||||
|
o1, o2, lse1, lse2, o_merged,
|
||||||
|
batch, seqlen_q, nheads, headdim,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o_merged, lse_merged
|
||||||
|
|
||||||
|
|
||||||
|
def chunked_attention_varlen(
|
||||||
|
q: torch.Tensor,
|
||||||
|
kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
cu_seqlens_q: torch.Tensor,
|
||||||
|
cu_seqlens_k_list: List[torch.Tensor],
|
||||||
|
max_seqlen_q: int,
|
||||||
|
max_seqlen_k_list: List[int],
|
||||||
|
softmax_scale: Optional[float] = None,
|
||||||
|
causal_mask_per_chunk: Optional[List[bool]] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute attention with KV split across multiple chunks.
|
||||||
|
|
||||||
|
This is the core function for chunked prefill. It computes attention
|
||||||
|
against each KV chunk and merges results using online softmax.
|
||||||
|
|
||||||
|
For causal attention with chunked KV:
|
||||||
|
- First chunk (current tokens): Apply causal mask
|
||||||
|
- Previous chunks: No causal mask (all previous tokens are valid context)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [total_q_tokens, nheads, headdim]
|
||||||
|
kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
|
||||||
|
cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
|
||||||
|
cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk
|
||||||
|
max_seqlen_q: Maximum query sequence length
|
||||||
|
max_seqlen_k_list: List of maximum key sequence lengths for each chunk
|
||||||
|
softmax_scale: Scaling factor
|
||||||
|
causal_mask_per_chunk: Whether to apply causal mask for each chunk
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
out: Output tensor [total_q_tokens, nheads, headdim]
|
||||||
|
"""
|
||||||
|
if len(kv_chunks) == 0:
|
||||||
|
raise ValueError("Need at least one KV chunk")
|
||||||
|
|
||||||
|
nheads = q.shape[1]
|
||||||
|
headdim = q.shape[2]
|
||||||
|
batch = cu_seqlens_q.shape[0] - 1
|
||||||
|
|
||||||
|
if softmax_scale is None:
|
||||||
|
softmax_scale = 1.0 / math.sqrt(headdim)
|
||||||
|
|
||||||
|
if causal_mask_per_chunk is None:
|
||||||
|
# Default: causal for last chunk only
|
||||||
|
causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]
|
||||||
|
|
||||||
|
# Initialize accumulated output and LSE
|
||||||
|
accumulated_o = None
|
||||||
|
accumulated_lse = None
|
||||||
|
|
||||||
|
for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
|
||||||
|
is_causal = causal_mask_per_chunk[chunk_idx]
|
||||||
|
|
||||||
|
# Reshape Q for batch processing
|
||||||
|
# For varlen, we need to handle each sequence separately
|
||||||
|
# For simplicity, assume single sequence (batch=1) for now
|
||||||
|
q_batched = q.unsqueeze(0) # [1, total_q, nheads, headdim]
|
||||||
|
|
||||||
|
# Compute attention for this chunk
|
||||||
|
chunk_o, chunk_lse = flash_attn_with_lse(
|
||||||
|
q_batched,
|
||||||
|
k_chunk,
|
||||||
|
v_chunk,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge with accumulated
|
||||||
|
if accumulated_o is None:
|
||||||
|
accumulated_o = chunk_o
|
||||||
|
accumulated_lse = chunk_lse
|
||||||
|
else:
|
||||||
|
accumulated_o, accumulated_lse = merge_attention_outputs(
|
||||||
|
accumulated_o, accumulated_lse,
|
||||||
|
chunk_o, chunk_lse,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove batch dimension
|
||||||
|
return accumulated_o.squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkedPrefillState:
|
||||||
|
"""
|
||||||
|
State for tracking chunked prefill progress.
|
||||||
|
|
||||||
|
This class maintains the accumulated attention output and LSE
|
||||||
|
across multiple prefill chunks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device):
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# Per-layer accumulated outputs
|
||||||
|
# Each entry: (accumulated_output, accumulated_lse) or None
|
||||||
|
self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
|
||||||
|
None for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Track which chunks have been processed
|
||||||
|
self.processed_chunks: int = 0
|
||||||
|
|
||||||
|
def update_layer(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
chunk_output: torch.Tensor,
|
||||||
|
chunk_lse: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""Update accumulated state for a layer with a new chunk's output."""
|
||||||
|
if self.layer_states[layer_id] is None:
|
||||||
|
self.layer_states[layer_id] = (chunk_output, chunk_lse)
|
||||||
|
else:
|
||||||
|
acc_o, acc_lse = self.layer_states[layer_id]
|
||||||
|
merged_o, merged_lse = merge_attention_outputs(
|
||||||
|
acc_o, acc_lse,
|
||||||
|
chunk_output, chunk_lse,
|
||||||
|
)
|
||||||
|
self.layer_states[layer_id] = (merged_o, merged_lse)
|
||||||
|
|
||||||
|
def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]:
|
||||||
|
"""Get the final accumulated output for a layer."""
|
||||||
|
if self.layer_states[layer_id] is None:
|
||||||
|
return None
|
||||||
|
return self.layer_states[layer_id][0]
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""Clear all accumulated state."""
|
||||||
|
self.layer_states = [None for _ in range(self.num_layers)]
|
||||||
|
self.processed_chunks = 0
|
||||||
|
|
||||||
|
|
||||||
|
# Test function
|
||||||
|
def _test_chunked_attention():
|
||||||
|
"""Test chunked attention using flash_attn_with_lse and merge_attention_outputs."""
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_func
|
||||||
|
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("Test: Chunked attention vs flash_attn_func (non-causal)")
|
||||||
|
print("=" * 70)
|
||||||
|
print("Splitting K,V into chunks, computing attention per chunk, then merging")
|
||||||
|
print()
|
||||||
|
|
||||||
|
for dtype in [torch.float16, torch.bfloat16]:
|
||||||
|
for num_chunks in [64, 128, 256]:
|
||||||
|
for batch, seqlen, nheads, headdim in [
|
||||||
|
(1, 1024, 32, 128),
|
||||||
|
(1, 2048, 32, 128),
|
||||||
|
(1, 4096, 32, 128),
|
||||||
|
(1, 8192, 32, 128),
|
||||||
|
]:
|
||||||
|
# Generate random Q, K, V
|
||||||
|
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
|
||||||
|
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
|
||||||
|
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
|
||||||
|
|
||||||
|
# Reference: full attention (non-causal)
|
||||||
|
out_ref = flash_attn_func(q, k, v, causal=False)
|
||||||
|
|
||||||
|
# Chunked attention: split K, V into chunks
|
||||||
|
chunk_size = seqlen // num_chunks
|
||||||
|
accumulated_o = None
|
||||||
|
accumulated_lse = None
|
||||||
|
|
||||||
|
for i in range(num_chunks):
|
||||||
|
start = i * chunk_size
|
||||||
|
end = (i + 1) * chunk_size
|
||||||
|
|
||||||
|
k_chunk = k[:, start:end, :, :]
|
||||||
|
v_chunk = v[:, start:end, :, :]
|
||||||
|
|
||||||
|
# Q attends to this K,V chunk (non-causal)
|
||||||
|
chunk_o, chunk_lse = flash_attn_with_lse(
|
||||||
|
q, k_chunk, v_chunk, causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if accumulated_o is None:
|
||||||
|
accumulated_o = chunk_o
|
||||||
|
accumulated_lse = chunk_lse
|
||||||
|
else:
|
||||||
|
# Merge with previous chunks
|
||||||
|
accumulated_o, accumulated_lse = merge_attention_outputs(
|
||||||
|
accumulated_o, accumulated_lse,
|
||||||
|
chunk_o, chunk_lse
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compare
|
||||||
|
out_diff = (out_ref - accumulated_o).abs()
|
||||||
|
out_max_diff = out_diff.max().item()
|
||||||
|
out_mean_diff = out_diff.mean().item()
|
||||||
|
|
||||||
|
status = "PASS" if out_max_diff < 1e-2 else "FAIL"
|
||||||
|
print(
|
||||||
|
f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
|
||||||
|
f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
|
||||||
|
f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("=" * 70)
|
||||||
|
print("Test completed!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
_test_chunked_attention()
|
||||||
1167
nanovllm/ops/xattn.py
Normal file
1167
nanovllm/ops/xattn.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -14,9 +14,9 @@ class Context:
|
|||||||
context_lens: torch.Tensor | None = None
|
context_lens: torch.Tensor | None = None
|
||||||
block_tables: torch.Tensor | None = None
|
block_tables: torch.Tensor | None = None
|
||||||
|
|
||||||
# Sparse prefill attention support (GPU-only path)
|
# Attention policy support (GPU-only path)
|
||||||
# When set, uses policy.sparse_prefill_attention() instead of FlashAttention
|
# When set, uses policy.compute_prefill() instead of FlashAttention
|
||||||
sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True
|
attention_policy: Any = None # AttentionPolicy instance
|
||||||
|
|
||||||
|
|
||||||
_CONTEXT = Context()
|
_CONTEXT = Context()
|
||||||
@@ -35,7 +35,7 @@ def set_context(
|
|||||||
slot_mapping=None,
|
slot_mapping=None,
|
||||||
context_lens=None,
|
context_lens=None,
|
||||||
block_tables=None,
|
block_tables=None,
|
||||||
sparse_prefill_policy=None,
|
attention_policy=None,
|
||||||
):
|
):
|
||||||
global _CONTEXT
|
global _CONTEXT
|
||||||
_CONTEXT = Context(
|
_CONTEXT = Context(
|
||||||
@@ -47,7 +47,7 @@ def set_context(
|
|||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
context_lens=context_lens,
|
context_lens=context_lens,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
sparse_prefill_policy=sparse_prefill_policy,
|
attention_policy=attention_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
406
notes.md
406
notes.md
@@ -1,324 +1,130 @@
|
|||||||
# Notes: Sparsity Integration into Layerwise Offload
|
# Notes: SparsePolicy Refactoring Research
|
||||||
|
|
||||||
## Current Architecture Analysis
|
## Sources
|
||||||
|
|
||||||
### GPU-Only Path vs Offload Path
|
### Source 1: tzj/minference branch - policy.py
|
||||||
|
- 路径: `nanovllm/kvcache/sparse/policy.py`
|
||||||
|
- 关键设计:
|
||||||
|
- `PolicyContext` 数据类包含 query_chunk_idx, num_query_chunks, layer_id, query, is_prefill 等
|
||||||
|
- `select_blocks()` 需要 offload_engine 参数
|
||||||
|
- `compute_chunked_prefill()` 和 `compute_chunked_decode()` 是完整的 attention 流程
|
||||||
|
- `on_prefill_offload()` / `on_decode_offload()` hooks 用于收集元数据
|
||||||
|
|
||||||
| Aspect | GPU-Only | Layerwise Offload |
|
### Source 2: tzj/minference branch - full_policy.py
|
||||||
|--------|----------|-------------------|
|
- 路径: `nanovllm/kvcache/sparse/full_policy.py`
|
||||||
| KV Storage | GPU blocks (paged) | CPU pinned + GPU ring buffer |
|
- 关键实现:
|
||||||
| Prefill | All layers → then attention | Per-layer: attention → offload |
|
- `compute_chunked_prefill()` 内部使用 ring buffer pipeline 加载 blocks
|
||||||
| Decode | FlashAttn with block table | Ring buffer H2D → FlashAttn |
|
- 使用 `flash_attn_with_lse` 和 `merge_attention_outputs` 合并多个 chunk 的 attention
|
||||||
| Sparse Support | MInference via `attention.py` | Not integrated |
|
- `compute_chunked_decode()` 处理 prefilled blocks + decode buffer
|
||||||
|
|
||||||
### MInference Flow (GPU-Only)
|
### Source 3: tzj/layer-offload branch - model_runner.py
|
||||||
|
- 路径: `nanovllm/engine/model_runner.py`
|
||||||
|
- 关键设计:
|
||||||
|
- `run_layerwise_offload_prefill()` 逐层处理,每层计算完整 attention
|
||||||
|
- `sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)` 简单接口
|
||||||
|
- FULL policy 通过 `if sparse_prefill_policy is None` 走 else 分支
|
||||||
|
|
||||||
```
|
### Source 4: tzj/layer-offload branch - xattn.py
|
||||||
attention.py:101-105:
|
- 路径: `nanovllm/kvcache/sparse/xattn.py`
|
||||||
if context.sparse_prefill_policy is not None:
|
- 关键实现:
|
||||||
o = context.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
|
- `sparse_prefill_attention()` 直接使用 FlashAttention(因为 chunked prefill 架构限制)
|
||||||
|
- 保留 Triton kernels 供未来 GPU-only 模式
|
||||||
|
|
||||||
minference.py:sparse_prefill_attention():
|
## Synthesized Findings
|
||||||
1. estimate_pattern(q, k, layer_id) -> vertical_indices, slash_indices
|
|
||||||
2. _triton_mixed_sparse_attention(q, k, v, indices)
|
|
||||||
3. return output
|
|
||||||
```
|
|
||||||
|
|
||||||
### Quest Flow (GPU Block Mode)
|
### 架构差异总结
|
||||||
|
|
||||||
```
|
| 方面 | Chunked Offload | Layerwise Offload |
|
||||||
hybrid_manager.py (if using CPU offload with Quest):
|
|------|-----------------|-------------------|
|
||||||
select_blocks(available_blocks, ctx) -> selected block IDs
|
| **Prefill 流程** | chunk-by-chunk,跨层 | layer-by-layer,完整序列 |
|
||||||
-> load selected blocks to GPU
|
| **KV 存储** | 每 chunk 立即 offload | 每层计算后 offload |
|
||||||
-> standard FlashAttn with loaded blocks
|
| **Attention 计算** | 分多次计算+合并 | 一次完整计算 |
|
||||||
```
|
| **Block 加载** | 需要从 CPU 加载历史 | 不需要,已在 GPU |
|
||||||
|
| **Policy 责任** | 完整 attention 流程 | 仅 attention kernel 选择 |
|
||||||
|
|
||||||
### Layerwise Offload Prefill Flow
|
### Layerwise Offload 的简化点
|
||||||
|
|
||||||
```
|
1. **不需要 block selection**: 整层 KV 都在 GPU,无需选择
|
||||||
model_runner.py:run_layerwise_offload_prefill():
|
2. **不需要 offload_engine 参数**: Policy 不负责加载 KV
|
||||||
for layer_id in range(num_layers):
|
3. **不需要 merge_attention_outputs**: 一次计算完整 attention
|
||||||
# QKV projection
|
4. **不需要 offload hooks**: offload 在 model_runner 统一处理
|
||||||
q, k, v = qkv_proj(hidden_ln)
|
|
||||||
|
|
||||||
# RoPE
|
### 设计建议
|
||||||
q, k = rotary_emb(positions, q, k)
|
|
||||||
|
|
||||||
# FULL attention (no sparsity!)
|
1. **保持接口简单**: 只需要 `compute_prefill_attention()` 和 `compute_decode_attention()`
|
||||||
attn_output = flash_attn_varlen_func(q, k, v, ...)
|
2. **FULL 也实现方法**: 不再通过 `is None` 判断,所有 policy 统一调用
|
||||||
|
3. **移除不必要的参数**: 不需要 offload_engine, kvcache_manager, seq 等
|
||||||
|
4. **统一命名**: 使用 `compute_*_attention` 而不是 `sparse_prefill_attention`
|
||||||
|
|
||||||
# MLP
|
## Code Examples
|
||||||
hidden_states = mlp(attn_out + residual)
|
|
||||||
|
|
||||||
# Sync offload ALL k, v to CPU
|
### 当前调用方式 (model_runner.py:876-891)
|
||||||
for block_id in cpu_block_ids:
|
|
||||||
k_cache_cpu[layer_id, block_id].copy_(k[start:end])
|
|
||||||
v_cache_cpu[layer_id, block_id].copy_(v[start:end])
|
|
||||||
```
|
|
||||||
|
|
||||||
### Layerwise Offload Decode Flow
|
|
||||||
|
|
||||||
```
|
|
||||||
model_runner.py:run_layerwise_offload_decode():
|
|
||||||
# Preload first N layers to ring buffer
|
|
||||||
for i in range(num_buffers):
|
|
||||||
offload_engine.load_layer_kv_to_buffer(i, i, cpu_block_table, valid_tokens)
|
|
||||||
|
|
||||||
for layer_id in range(num_layers):
|
|
||||||
current_buffer = layer_id % num_buffers
|
|
||||||
|
|
||||||
# Wait for buffer load
|
|
||||||
offload_engine.wait_buffer_load(current_buffer)
|
|
||||||
|
|
||||||
# Get prefilled KV from ring buffer (ALL blocks loaded)
|
|
||||||
k_prefill, v_prefill = offload_engine.get_buffer_kv(current_buffer, total_prefill_tokens)
|
|
||||||
|
|
||||||
# QKV for new token
|
|
||||||
q, k_new, v_new = qkv_proj(hidden_ln)
|
|
||||||
|
|
||||||
# Concat and full attention
|
|
||||||
k_full = torch.cat([k_prefill, k_decode_prev, k_new])
|
|
||||||
attn_output = flash_attn_varlen_func(q, k_full, v_full, ...)
|
|
||||||
|
|
||||||
# Start loading next layer
|
|
||||||
offload_engine.load_layer_kv_to_buffer(current_buffer, layer_id + num_buffers, ...)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Integration Points
|
|
||||||
|
|
||||||
### 1. Prefill Sparse Integration Point
|
|
||||||
|
|
||||||
**Location:** `model_runner.py:535-543`
|
|
||||||
|
|
||||||
**Current:**
|
|
||||||
```python
|
```python
|
||||||
attn_output = flash_attn_varlen_func(
|
# Sparse or Full attention
|
||||||
q, k, v,
|
if self.sparse_prefill_policy is not None:
|
||||||
cu_seqlens_q=cu_seqlens,
|
# MInference or other sparse prefill policy
|
||||||
cu_seqlens_k=cu_seqlens,
|
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
|
||||||
max_seqlen_q=total_tokens,
|
q, k, v, layer_id
|
||||||
max_seqlen_k=total_tokens,
|
)
|
||||||
softmax_scale=layer.self_attn.attn.scale,
|
else:
|
||||||
causal=True,
|
# Full attention using FlashAttention
|
||||||
|
attn_output = flash_attn_varlen_func(
|
||||||
|
q, k, v, ...
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 建议的新调用方式
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 所有 policy 统一调用
|
||||||
|
attn_output = self.attention_policy.compute_prefill_attention(
|
||||||
|
q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
**After Integration:**
|
## Questions Resolved
|
||||||
```python
|
|
||||||
if self.sparse_policy and self.sparse_policy.supports_offload_prefill:
|
- Q: 是否需要 PolicyContext?
|
||||||
attn_output, k_sparse, v_sparse = self.sparse_policy.offload_prefill_attention(
|
- A: 可以简化,因为 layerwise 模式下不需要 chunk 信息
|
||||||
q, k, v, layer_id
|
|
||||||
)
|
- Q: decode 阶段如何处理?
|
||||||
k_to_offload = k_sparse if k_sparse is not None else k
|
- A: **Decode 不需要 policy**!当前 `run_layerwise_offload_decode()` 使用标准 `layer(positions, hidden_states, residual)` 调用,走 Attention.forward() 路径
|
||||||
v_to_offload = v_sparse if v_sparse is not None else v
|
|
||||||
else:
|
- Q: 为什么 decode 不需要 sparse?
|
||||||
attn_output = flash_attn_varlen_func(q, k, v, ...)
|
- A: 因为 decode 每次只有 1 个 token,没有稀疏化的意义。KV 从 ring buffer 加载后直接用 flash_attn_with_kvcache
|
||||||
k_to_offload, v_to_offload = k, v
|
|
||||||
|
## Key Insight
|
||||||
|
|
||||||
|
**Layerwise Offload 的 Policy 设计应该只关注 Prefill**:
|
||||||
|
|
||||||
|
```
|
||||||
|
Prefill: 需要 Policy
|
||||||
|
- 整个序列一次计算 attention
|
||||||
|
- 可以使用 sparse attention 方法(如 MInference 的 vertical+slash pattern)
|
||||||
|
- Policy 接收 q, k, v, layer_id, softmax_scale
|
||||||
|
|
||||||
|
Decode: 不需要 Policy
|
||||||
|
- 每次只有 1 个 token query
|
||||||
|
- KV 从 ring buffer 加载
|
||||||
|
- 使用标准 flash_attn_with_kvcache
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. Decode Sparse Integration Point
|
## Interface Comparison Summary
|
||||||
|
|
||||||
**Location:** `model_runner.py:636-637` and `model_runner.py:704-706`
|
| 方面 | tzj/minference | tzj/layer-offload (新设计) |
|
||||||
|
|------|----------------|---------------------------|
|
||||||
|
| 类名 | SparsePolicy | AttentionPolicy |
|
||||||
|
| Prefill 方法 | compute_chunked_prefill() | compute_attention() |
|
||||||
|
| Decode 方法 | compute_chunked_decode() | 不需要(用标准路径) |
|
||||||
|
| 需要 offload_engine | 是 | 否 |
|
||||||
|
| 需要 kvcache_manager | 是 | 否 |
|
||||||
|
| 需要 seq | 是 | 否 |
|
||||||
|
| 支持 FULL | 是 | 是 |
|
||||||
|
|
||||||
**Current (preload):**
|
## Migration Path
|
||||||
```python
|
|
||||||
for i in range(num_preload):
|
|
||||||
offload_engine.load_layer_kv_to_buffer(
|
|
||||||
i, i, cpu_block_table, valid_tokens_per_block
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**After Integration:**
|
1. 保留 `SparsePolicy` 作为 `AttentionPolicy` 的别名
|
||||||
```python
|
2. 保留 `PolicyContext` 供未来扩展
|
||||||
for i in range(num_preload):
|
3. 保留 `select_blocks()` 方法签名(虽然不使用)
|
||||||
layer_to_load = i
|
4. 移除 `requires_block_selection` 属性(不需要)
|
||||||
if self.sparse_policy and self.sparse_policy.supports_offload_decode:
|
|
||||||
# Prepare q for this layer (need to compute ahead)
|
|
||||||
# OR: use previous layer's pattern as estimate
|
|
||||||
selected_blocks = self.sparse_policy.select_offload_blocks(
|
|
||||||
None, # q not available yet at preload
|
|
||||||
layer_to_load,
|
|
||||||
cpu_block_table,
|
|
||||||
valid_tokens_per_block
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
selected_blocks = cpu_block_table
|
|
||||||
offload_engine.load_sparse_layer_kv_to_buffer(
|
|
||||||
i, layer_to_load, selected_blocks, valid_tokens_per_block
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Challenge:** Q is not available during preload phase!
|
|
||||||
|
|
||||||
**Solutions:**
|
|
||||||
1. Skip sparse preload, only sparse for non-preloaded layers
|
|
||||||
2. Use previous decode step's pattern as estimate
|
|
||||||
3. Add preload hook to sparse policy
|
|
||||||
|
|
||||||
### 3. Offload Engine Extension
|
|
||||||
|
|
||||||
**New Method in OffloadEngine:**
|
|
||||||
|
|
||||||
```python
|
|
||||||
def load_sparse_layer_kv_to_buffer(
|
|
||||||
self,
|
|
||||||
buffer_idx: int,
|
|
||||||
layer_id: int,
|
|
||||||
selected_cpu_block_ids: List[int],
|
|
||||||
original_valid_tokens: List[int],
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Load only selected blocks from CPU to buffer.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Total tokens loaded (may be less than full sequence)
|
|
||||||
"""
|
|
||||||
stream = self.layer_load_streams[buffer_idx]
|
|
||||||
|
|
||||||
with torch.cuda.stream(stream):
|
|
||||||
stream.wait_event(self.buffer_compute_done_events[buffer_idx])
|
|
||||||
|
|
||||||
# Build mapping: original block -> selected position
|
|
||||||
offset = 0
|
|
||||||
for i, cpu_block_id in enumerate(selected_cpu_block_ids):
|
|
||||||
# Find original index to get valid tokens
|
|
||||||
valid_tokens = original_valid_tokens[i] # Need mapping
|
|
||||||
|
|
||||||
self.layer_k_cache[buffer_idx, offset:offset+valid_tokens].copy_(
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
# ... v_cache same
|
|
||||||
|
|
||||||
offset += valid_tokens
|
|
||||||
|
|
||||||
self.buffer_load_events[buffer_idx].record(stream)
|
|
||||||
|
|
||||||
return offset # Caller needs to know actual loaded tokens
|
|
||||||
```
|
|
||||||
|
|
||||||
## Metadata Flow for Quest
|
|
||||||
|
|
||||||
### During Prefill Offload
|
|
||||||
|
|
||||||
**Current:** No metadata collection in offload path
|
|
||||||
|
|
||||||
**Required:** Call `on_prefill_offload()` for each block
|
|
||||||
|
|
||||||
```python
|
|
||||||
# In run_layerwise_offload_prefill()
|
|
||||||
for i, cpu_block_id in enumerate(cpu_block_ids):
|
|
||||||
start = i * block_size
|
|
||||||
end = min(start + block_size, total_tokens)
|
|
||||||
actual_size = end - start
|
|
||||||
|
|
||||||
# BEFORE offload: update Quest metadata
|
|
||||||
if self.sparse_policy and hasattr(self.sparse_policy, 'on_prefill_offload'):
|
|
||||||
self.sparse_policy.on_prefill_offload(
|
|
||||||
cpu_block_id, layer_id, k[start:end], actual_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Offload
|
|
||||||
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
|
|
||||||
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
|
|
||||||
```
|
|
||||||
|
|
||||||
### Quest Metadata Shape
|
|
||||||
|
|
||||||
```python
|
|
||||||
# BlockMetadataManager
|
|
||||||
key_min: [num_blocks, num_layers, num_kv_heads, head_dim] # Min key per block per layer
|
|
||||||
key_max: [num_blocks, num_layers, num_kv_heads, head_dim] # Max key per block per layer
|
|
||||||
```
|
|
||||||
|
|
||||||
**Memory:** 2 * num_blocks * num_layers * kv_heads * head_dim * 2 bytes
|
|
||||||
- Example: 1000 blocks * 28 layers * 4 heads * 128 dim * 2 * 2 = ~57 MB
|
|
||||||
|
|
||||||
## Performance Considerations
|
|
||||||
|
|
||||||
### MInference Prefill Overhead
|
|
||||||
|
|
||||||
| Operation | Time (64K seq) |
|
|
||||||
|-----------|----------------|
|
|
||||||
| Pattern estimation (last-64) | ~5ms |
|
|
||||||
| Triton sparse attention | ~80ms |
|
|
||||||
| Full FlashAttention | ~100ms |
|
|
||||||
| **Net Speedup** | ~15-20% |
|
|
||||||
|
|
||||||
### Quest Decode Overhead
|
|
||||||
|
|
||||||
| Operation | Time |
|
|
||||||
|-----------|------|
|
|
||||||
| Block scoring (GPU metadata) | ~0.1ms |
|
|
||||||
| Top-K selection | ~0.05ms |
|
|
||||||
| Sparse H2D load (8 blocks) | ~2ms |
|
|
||||||
| Full H2D load (100 blocks) | ~20ms |
|
|
||||||
| **Net Speedup** | ~10x H2D |
|
|
||||||
|
|
||||||
### Memory Trade-offs
|
|
||||||
|
|
||||||
| Mode | GPU Memory | CPU Memory | H2D Bandwidth |
|
|
||||||
|------|------------|------------|---------------|
|
|
||||||
| Full offload | Ring buffer | Full KV | High |
|
|
||||||
| Sparse offload | Ring buffer | Full KV | Low (subset) |
|
|
||||||
| Aggressive sparse | Ring buffer | Sparse KV | Very low |
|
|
||||||
|
|
||||||
## Edge Cases
|
|
||||||
|
|
||||||
### 1. Short Sequences (< sparse threshold)
|
|
||||||
|
|
||||||
```python
|
|
||||||
if total_tokens < sparse_threshold:
|
|
||||||
# Fall back to full attention
|
|
||||||
use_sparse = False
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. First Decode Step (no previous Q)
|
|
||||||
|
|
||||||
Quest can't score blocks without Q. Options:
|
|
||||||
- Use average embedding as proxy
|
|
||||||
- Load all blocks for first step
|
|
||||||
- Use prefill pattern as estimate
|
|
||||||
|
|
||||||
### 3. Variable Sequence Lengths in Batch
|
|
||||||
|
|
||||||
Layerwise offload currently only supports batch_size=1:
|
|
||||||
```python
|
|
||||||
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
|
|
||||||
```
|
|
||||||
|
|
||||||
Sparse integration should maintain this constraint.
|
|
||||||
|
|
||||||
### 4. Ring Buffer vs Sparse Load Mismatch
|
|
||||||
|
|
||||||
Ring buffer assumes fixed `total_prefill_tokens`:
|
|
||||||
```python
|
|
||||||
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, total_prefill_tokens)
|
|
||||||
```
|
|
||||||
|
|
||||||
Sparse load has variable token count. Need:
|
|
||||||
```python
|
|
||||||
# Track actual loaded tokens per buffer
|
|
||||||
loaded_tokens[buffer_idx] = sparse_load_count
|
|
||||||
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, loaded_tokens[buffer_idx])
|
|
||||||
```
|
|
||||||
|
|
||||||
## Testing Strategy
|
|
||||||
|
|
||||||
### Unit Tests
|
|
||||||
|
|
||||||
1. `test_sparse_policy_interface.py` - Verify new interface methods
|
|
||||||
2. `test_minference_offload.py` - MInference in offload mode
|
|
||||||
3. `test_quest_offload.py` - Quest block selection in offload mode
|
|
||||||
|
|
||||||
### Integration Tests
|
|
||||||
|
|
||||||
1. `test_offload_sparse_e2e.py` - Full prefill+decode with sparsity
|
|
||||||
2. `test_accuracy_comparison.py` - Compare outputs: full vs sparse
|
|
||||||
|
|
||||||
### Benchmarks
|
|
||||||
|
|
||||||
1. `bench_offload_sparse.py` - Compare:
|
|
||||||
- Full offload (baseline)
|
|
||||||
- MInference prefill + Quest decode
|
|
||||||
- Aggressive sparse offload
|
|
||||||
|
|||||||
155
progress.md
155
progress.md
@@ -1,155 +0,0 @@
|
|||||||
# Progress Log: nanovllm 多请求状态污染问题
|
|
||||||
|
|
||||||
## Session: 2026-01-12
|
|
||||||
|
|
||||||
### 资源分配
|
|
||||||
|
|
||||||
| 资源 | 分配 |
|
|
||||||
|------|------|
|
|
||||||
| **GPU** | **1** (严格限制,不可更改) |
|
|
||||||
|
|
||||||
### 任务目标
|
|
||||||
研究 nanovllm CPU offload 模式下多请求之间状态影响导致准确率下降的问题。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 10:00 - 启动分析
|
|
||||||
|
|
||||||
**完成**:
|
|
||||||
- [x] 读取 `docs/offload_accuracy_issue.md` 了解问题背景
|
|
||||||
- [x] 激活 Serena MCP 项目
|
|
||||||
- [x] 获取关键组件符号概览
|
|
||||||
|
|
||||||
**关键文件已分析**:
|
|
||||||
- `nanovllm/kvcache/offload_engine.py` - OffloadEngine 类
|
|
||||||
- `nanovllm/kvcache/hybrid_manager.py` - HybridKVCacheManager 类
|
|
||||||
- `nanovllm/engine/model_runner.py` - ModelRunner 类
|
|
||||||
- `nanovllm/engine/llm_engine.py` - LLMEngine 类
|
|
||||||
- `nanovllm/engine/scheduler.py` - Scheduler 类
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 10:15 - 深入代码分析
|
|
||||||
|
|
||||||
**分析的方法**:
|
|
||||||
|
|
||||||
| 方法 | 文件 | 发现 |
|
|
||||||
|------|------|------|
|
|
||||||
| `OffloadEngine.__init__` | offload_engine.py:40-145 | 初始化所有 buffer,无 reset 方法 |
|
|
||||||
| `deallocate` | hybrid_manager.py:218-244 | 只清理逻辑块,不清理 OffloadEngine |
|
|
||||||
| `clear_decode_tracking` | hybrid_manager.py:538-549 | 清理 tracking 字典,但未被调用 |
|
|
||||||
| `run_layerwise_offload_decode` | model_runner.py:867-1057 | 包含 decode buffer 读写逻辑 |
|
|
||||||
| `generate` | llm_engine.py:114-151 | 请求循环逻辑 |
|
|
||||||
| `postprocess` | scheduler.py:93-99 | 调用 deallocate |
|
|
||||||
|
|
||||||
**关键发现 #1**: OffloadEngine 没有 reset() 方法
|
|
||||||
|
|
||||||
**关键发现 #2**: deallocate() 没有调用 clear_decode_tracking()
|
|
||||||
|
|
||||||
**关键发现 #3**: decode_buffer 在请求间不清理,可能导致状态污染
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 10:30 - 根因定位
|
|
||||||
|
|
||||||
**确认的问题**:
|
|
||||||
|
|
||||||
1. **decode buffer 残留**
|
|
||||||
- 位置: `offload_engine.decode_k_buffer`, `decode_v_buffer`
|
|
||||||
- 写入: `model_runner.py:1010-1013`
|
|
||||||
- 读取: `model_runner.py:969-976`
|
|
||||||
- 问题: 旧请求的 KV 数据可能被新请求读取
|
|
||||||
|
|
||||||
2. **tracking 字典未清理**
|
|
||||||
- 位置: `hybrid_manager._decode_start_pos`, `_prefill_len`
|
|
||||||
- 问题: 使用 `id(seq)` 作为 key,可能重用
|
|
||||||
|
|
||||||
3. **缺失的清理调用**
|
|
||||||
- `clear_decode_tracking()` 在 `deallocate()` 中未被调用
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 10:45 - 创建规划文件
|
|
||||||
|
|
||||||
**创建的文件**:
|
|
||||||
- [x] `task_plan.md` - 完整的任务规划和阶段
|
|
||||||
- [x] `findings.md` - 详细的代码分析发现
|
|
||||||
- [x] `progress.md` - 本文件
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 11:00 - Sequential Thinking 深入分析
|
|
||||||
|
|
||||||
**使用 sequential thinking 验证分析结果**:
|
|
||||||
- 确认 deallocate() 确实没有调用 clear_decode_tracking()
|
|
||||||
- 分析 _decode_start_pos 和 _prefill_len 字典的生命周期
|
|
||||||
- 确定 id(seq) 重用是问题的触发条件
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 11:15 - 完成规划文件
|
|
||||||
|
|
||||||
**更新的文件**:
|
|
||||||
- [x] `task_plan.md` - 添加完整的 debug 方案和实施计划
|
|
||||||
- [x] `findings.md` - 详细的代码分析和修复方向
|
|
||||||
- [x] `progress.md` - 更新到当前进度
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 下一步 (待用户确认)
|
|
||||||
|
|
||||||
**执行顺序**:
|
|
||||||
|
|
||||||
1. **实施修复** - 修改 `deallocate()` 添加 `clear_decode_tracking(seq)`
|
|
||||||
2. **快速验证** - 20 样本连续执行(一次调用,不重启框架)→ 目标 20/20
|
|
||||||
3. **完整验证** - 100 样本 → 目标 100/100 (最终验收)
|
|
||||||
4. **防御性修复** (可选) - 添加 `OffloadEngine.on_sequence_finished()`
|
|
||||||
|
|
||||||
**核心修改** (一行代码):
|
|
||||||
```python
|
|
||||||
# hybrid_manager.py:deallocate() 末尾添加
|
|
||||||
self.clear_decode_tracking(seq)
|
|
||||||
```
|
|
||||||
|
|
||||||
**验收标准**:
|
|
||||||
| 测试 | 样本数 | 通过要求 |
|
|
||||||
|------|--------|----------|
|
|
||||||
| 快速验证 | 20 | 20/20 (100%) |
|
|
||||||
| 完整验证 | 100 | 100/100 (100%) |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 错误记录
|
|
||||||
|
|
||||||
| 时间 | 错误 | 解决方案 |
|
|
||||||
|------|------|----------|
|
|
||||||
| 10:05 | Serena MCP 未激活 | 调用 activate_project |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 文件修改记录
|
|
||||||
|
|
||||||
| 文件 | 操作 | 状态 |
|
|
||||||
|------|------|------|
|
|
||||||
| task_plan.md | 创建+更新 | 完成 |
|
|
||||||
| findings.md | 创建 | 完成 |
|
|
||||||
| progress.md | 创建+更新 | 完成 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 分析结论
|
|
||||||
|
|
||||||
**重要澄清**: nanovllm offload 模式**不支持 batch**,只能单个 request 顺序执行。问题出在**请求切换**时状态清理不完整。
|
|
||||||
|
|
||||||
**根本原因已确认**: `deallocate()` 没有调用 `clear_decode_tracking()`,导致 `_decode_start_pos` 和 `_prefill_len` 字典残留,当 Python 对象 ID 重用时,新请求会错误地使用旧请求的配置。
|
|
||||||
|
|
||||||
**修复方案已设计**: 在 `deallocate()` 末尾添加 `self.clear_decode_tracking(seq)` 调用。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 关键理解
|
|
||||||
|
|
||||||
问题不是 "batch 处理",而是:
|
|
||||||
```
|
|
||||||
Request A 完成 → deallocate(A) [状态未完全清理] → Request B 开始 → B 读到 A 的残留状态
|
|
||||||
```
|
|
||||||
874
task_plan.md
874
task_plan.md
@@ -1,359 +1,549 @@
|
|||||||
# Task Plan: nanovllm CPU Offload 多请求状态污染问题
|
# Task Plan: Refactor SparsePolicy for Layerwise Offload
|
||||||
|
|
||||||
## 问题概述
|
## Goal
|
||||||
|
重构 SparsePolicy 接口,参考 tzj/minference 分支的设计模式,使所有 attention 都可以抽象成 policy,并按统一规范编写。适配当前 layerwise offload 架构特点(整层 KV 在 GPU 上)。
|
||||||
|
|
||||||
**重要说明**: nanovllm offload 模式目前**不支持 batch**,只能单个 request 顺序执行。问题出在**请求切换**时的状态清理。
|
## Background
|
||||||
|
|
||||||
| 模式 | 测试方式 | 准确率 |
|
### 两种 Offload 架构对比
|
||||||
|------|----------|--------|
|
|
||||||
| CPU Offload | 独立进程 (每请求一个进程) | **100%** |
|
|
||||||
| CPU Offload | 同进程顺序多请求 | 66% |
|
|
||||||
| Non-Offload | 同进程顺序多请求 | 100% |
|
|
||||||
|
|
||||||
**结论**: 单请求推理正确,问题在于**请求切换**时状态清理不完整。
|
| 特性 | tzj/minference (Chunked Offload) | tzj/layer-offload (Layerwise Offload) |
|
||||||
|
|------|----------------------------------|---------------------------------------|
|
||||||
|
| 处理粒度 | 每次一个 chunk (block_size tokens) | 每次一整层 (所有 tokens) |
|
||||||
|
| KV 位置 | 历史 chunks 在 CPU,需要加载 | 整层 KV 都在 GPU |
|
||||||
|
| Policy 入口 | `compute_chunked_prefill()/decode()` | `compute_prefill()/decode()` |
|
||||||
|
| 需要 offload_engine | 是(加载 blocks) | 否(KV 已在 GPU) |
|
||||||
|
| Mask 计算 | `select_blocks()` 返回 block IDs | `estimate()` 返回 sparse mask |
|
||||||
|
|
||||||
|
### tzj/minference 的 Policy 接口
|
||||||
|
|
||||||
|
```python
|
||||||
|
class SparsePolicy(ABC):
|
||||||
|
supports_prefill: bool
|
||||||
|
supports_decode: bool
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_chunked_prefill(self, q, k, v, layer_id, ..., offload_engine, ...) -> Tensor
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_chunked_decode(self, q, layer_id, ..., offload_engine, ...) -> Tensor
|
||||||
|
```
|
||||||
|
|
||||||
|
### 当前 branch 的 Policy 接口(重构前)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class SparsePolicy(ABC):
|
||||||
|
supports_prefill: bool
|
||||||
|
supports_decode: bool
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def select_blocks(self, available_blocks, ctx) -> List[int]
|
||||||
|
|
||||||
|
def sparse_prefill_attention(self, q, k, v, layer_id) -> Tensor
|
||||||
|
```
|
||||||
|
|
||||||
|
## Phases
|
||||||
|
|
||||||
|
- [x] Phase 1: 分析差异并设计新接口
|
||||||
|
- [x] **Phase 0: 创建 nanovllm.ops 模块** ✅ 测试通过
|
||||||
|
- [ ] Phase 2: 重构 AttentionPolicy 基类
|
||||||
|
- [ ] Phase 3: 重构 FullAttentionPolicy
|
||||||
|
- [ ] Phase 4: 重构 XAttentionPolicy (含 estimate 方法)
|
||||||
|
- [ ] Phase 5: 更新 model_runner 调用方式
|
||||||
|
- [ ] Phase 6: 测试验证
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Phase 1: 代码分析 (complete)
|
## Phase 0: 创建 nanovllm.ops 模块
|
||||||
|
|
||||||
### 1.1 识别状态管理组件
|
### 目标
|
||||||
|
从 tzj/minference 分支提取 ops 模块,为 XAttention estimate 提供底层算子支持。
|
||||||
**已分析的关键组件**:
|
|
||||||
|
### 步骤
|
||||||
| 组件 | 文件 | 状态数据 |
|
|
||||||
|------|------|----------|
|
1. **创建目录结构**
|
||||||
| `OffloadEngine` | `nanovllm/kvcache/offload_engine.py` | ring buffer, decode buffer, CUDA events |
|
```
|
||||||
| `HybridKVCacheManager` | `nanovllm/kvcache/hybrid_manager.py` | logical blocks, prefilled_blocks, _decode_start_pos, _prefill_len |
|
nanovllm/ops/
|
||||||
| `LLMEngine` | `nanovllm/engine/llm_engine.py` | generate() 循环,请求生命周期 |
|
├── __init__.py
|
||||||
| `Scheduler` | `nanovllm/engine/scheduler.py` | postprocess() 调用 deallocate() |
|
├── xattn.py # xattn_estimate, xattn_estimate_chunked, Triton kernels
|
||||||
|
└── chunked_attention.py # flash_attn_with_lse, merge_attention_outputs (备用)
|
||||||
### 1.2 请求生命周期分析
|
```
|
||||||
|
|
||||||
```
|
2. **从 tzj/minference 提取文件**
|
||||||
generate()
|
```bash
|
||||||
→ 多个请求添加到 scheduler
|
git show tzj/minference:nanovllm/ops/__init__.py > nanovllm/ops/__init__.py
|
||||||
→ while not finished:
|
git show tzj/minference:nanovllm/ops/xattn.py > nanovllm/ops/xattn.py
|
||||||
→ schedule() 获取下一批 seqs
|
git show tzj/minference:nanovllm/ops/chunked_attention.py > nanovllm/ops/chunked_attention.py
|
||||||
→ model_runner.run() 执行推理
|
```
|
||||||
→ postprocess() 处理完成的请求
|
|
||||||
→ 如果完成: kvcache_manager.deallocate(seq)
|
3. **Cherry-pick 测试文件**
|
||||||
```
|
```bash
|
||||||
|
git show tzj/minference:tests/test_xattn_estimate_chunked.py > tests/test_xattn_estimate_chunked.py
|
||||||
---
|
```
|
||||||
|
|
||||||
## Phase 2: 根本原因分析 (complete)
|
4. **运行测试验证**
|
||||||
|
```bash
|
||||||
### 2.1 核心问题: OffloadEngine 缺少 reset() 方法
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/Worktree/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_xattn_estimate_chunked.py
|
||||||
**关键发现**: `OffloadEngine` 没有任何重置/清理方法!
|
```
|
||||||
|
|
||||||
当请求完成时,`HybridKVCacheManager.deallocate()` 被调用,但它只清理:
|
### nanovllm/ops 模块内容
|
||||||
- 逻辑块状态 (`block.reset()`)
|
|
||||||
- 物理块引用 (`free_cpu_blocks`, `cpu_block_to_logical`)
|
| 文件 | 核心函数 | 用途 |
|
||||||
- prefilled_blocks 集合
|
|
||||||
- _decode_start_pos / _prefill_len 字典
|
|
||||||
|
|
||||||
**未被清理的状态** (存在于 OffloadEngine):
|
|
||||||
|
|
||||||
| 状态 | Shape | 问题 |
|
|
||||||
|------|-------|------|
|
|
||||||
| `layer_k_cache` | [num_buffers, max_seq_len, kv_heads, head_dim] | 包含旧请求的 KV |
|
|
||||||
| `layer_v_cache` | [num_buffers, max_seq_len, kv_heads, head_dim] | 包含旧请求的 KV |
|
|
||||||
| `decode_k_buffer` | [num_layers, block_size, kv_heads, head_dim] | 包含旧请求的 decode KV |
|
|
||||||
| `decode_v_buffer` | [num_layers, block_size, kv_heads, head_dim] | 包含旧请求的 decode KV |
|
|
||||||
|
|
||||||
### 2.2 具体污染场景
|
|
||||||
|
|
||||||
在 `run_layerwise_offload_decode()` (model_runner.py:867-1057):
|
|
||||||
|
|
||||||
```python
|
|
||||||
# 第 969-976 行: 读取之前的 decode KV
|
|
||||||
if num_prev_decode_tokens > 0:
|
|
||||||
k_decode_prev, v_decode_prev = offload_engine.get_decode_kv(
|
|
||||||
layer_id, decode_start_pos, pos_in_block
|
|
||||||
)
|
|
||||||
ring_k[...].copy_(k_decode_prev) # 可能读取旧请求的数据!
|
|
||||||
```
|
|
||||||
|
|
||||||
**场景**:
|
|
||||||
1. 请求 A (32K tokens) 完成,decode_buffer 保留其 KV 数据
|
|
||||||
2. 请求 B 开始,其 `decode_start_pos` 可能非零(如果继承了旧状态)
|
|
||||||
3. 请求 B 在第一个 decode step 时错误地读取了请求 A 的 decode buffer 数据
|
|
||||||
|
|
||||||
### 2.3 潜在问题点
|
|
||||||
|
|
||||||
1. **decode_start_pos 计算错误**:
|
|
||||||
- `get_decode_start_pos()` 使用 `id(seq)` 作为 key
|
|
||||||
- Python 对象 ID 可能在请求之间重用
|
|
||||||
- 如果新 seq 对象的 ID 与旧 seq 相同,可能错误继承旧的 start_pos
|
|
||||||
|
|
||||||
2. **decode buffer 残留数据**:
|
|
||||||
- 如果 `pos_in_block` 在新请求中与旧请求重叠
|
|
||||||
- `get_decode_kv()` 会返回旧请求的数据
|
|
||||||
|
|
||||||
3. **ring buffer 残留数据**:
|
|
||||||
- 虽然每次 decode 会从 CPU 加载,但 decode buffer 的数据会被复制过来
|
|
||||||
- 如果 decode buffer 有残留,会污染 ring buffer
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 3: Debug 方案设计 (complete)
|
|
||||||
|
|
||||||
### 3.1 确认的根本原因
|
|
||||||
|
|
||||||
通过代码分析,确认了两个根本原因:
|
|
||||||
|
|
||||||
**根本原因 1 (主要)**: `deallocate()` 不调用 `clear_decode_tracking()`
|
|
||||||
- 位置: `hybrid_manager.py:218-244`
|
|
||||||
- 影响: `_decode_start_pos` 和 `_prefill_len` 字典残留
|
|
||||||
- 后果: 如果 `id(seq)` 重用,返回错误的 decode 配置
|
|
||||||
|
|
||||||
**根本原因 2 (次要)**: decode_buffer 不清理
|
|
||||||
- 位置: `offload_engine.py`
|
|
||||||
- 影响: `decode_k_buffer/v_buffer` 保留旧 KV
|
|
||||||
- 后果: 可能被根本原因 1 触发读取
|
|
||||||
|
|
||||||
### 3.2 Debug 方案 A: 验证字典残留 (推荐先做)
|
|
||||||
|
|
||||||
**目标**: 验证 `_decode_start_pos` 字典是否有残留
|
|
||||||
|
|
||||||
**诊断代码** (添加到 `hybrid_manager.py`):
|
|
||||||
```python
|
|
||||||
# 在 get_decode_start_pos() 开头添加
|
|
||||||
def get_decode_start_pos(self, seq: Sequence) -> int:
|
|
||||||
seq_id = id(seq)
|
|
||||||
# DEBUG: 检查是否命中旧值
|
|
||||||
if seq_id in self._decode_start_pos:
|
|
||||||
logger.warning(f"[DEBUG] get_decode_start_pos: CACHE HIT! seq_id={seq_id}, "
|
|
||||||
f"cached_value={self._decode_start_pos[seq_id]}, "
|
|
||||||
f"expected={(len(seq) - 1) % self._block_size}")
|
|
||||||
# ... 原有逻辑
|
|
||||||
```
|
|
||||||
|
|
||||||
**诊断代码** (添加到 `deallocate()` 末尾):
|
|
||||||
```python
|
|
||||||
def deallocate(self, seq: Sequence) -> None:
|
|
||||||
# ... 现有逻辑 ...
|
|
||||||
|
|
||||||
# DEBUG: 打印未清理的状态
|
|
||||||
seq_id = id(seq)
|
|
||||||
if seq_id in self._decode_start_pos:
|
|
||||||
logger.warning(f"[DEBUG] deallocate: _decode_start_pos NOT CLEARED! "
|
|
||||||
f"seq_id={seq_id}, value={self._decode_start_pos[seq_id]}")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.3 Debug 方案 B: 最小复现测试
|
|
||||||
|
|
||||||
**文件**: `tests/test_multi_request_offload_debug.py`
|
|
||||||
|
|
||||||
```python
|
|
||||||
"""最小复现批量模式失败"""
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
sys.path.insert(0, os.getcwd())
|
|
||||||
|
|
||||||
from nanovllm import LLM
|
|
||||||
from nanovllm.sampling import SamplingParams
|
|
||||||
|
|
||||||
# 使用 RULER NIAH 的两个样本
|
|
||||||
PROMPTS = [
|
|
||||||
# Sample 0 (通常成功)
|
|
||||||
"...", # 从 niah_single_1_32k.jsonl 加载
|
|
||||||
# Sample 1 (通常失败)
|
|
||||||
"...",
|
|
||||||
]
|
|
||||||
EXPECTED = ["8930103", "4194548"]
|
|
||||||
|
|
||||||
def main():
|
|
||||||
llm = LLM(
|
|
||||||
"~/models/Llama-3.1-8B-Instruct",
|
|
||||||
max_model_len=33792,
|
|
||||||
max_num_batched_tokens=33792,
|
|
||||||
enable_cpu_offload=True,
|
|
||||||
num_gpu_blocks=4,
|
|
||||||
kvcache_block_size=1024,
|
|
||||||
enforce_eager=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
params = SamplingParams(temperature=0.1, max_tokens=50)
|
|
||||||
|
|
||||||
# 连续处理两个请求
|
|
||||||
for i, (prompt, expected) in enumerate(zip(PROMPTS, EXPECTED)):
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"Sample {i}: Expected = {expected}")
|
|
||||||
|
|
||||||
# 打印关键状态
|
|
||||||
kvm = llm.model_runner.kvcache_manager
|
|
||||||
print(f" _decode_start_pos 字典大小: {len(kvm._decode_start_pos)}")
|
|
||||||
print(f" _prefill_len 字典大小: {len(kvm._prefill_len)}")
|
|
||||||
|
|
||||||
outputs = llm.generate([prompt], params, use_tqdm=False)
|
|
||||||
output_text = outputs[0]["text"]
|
|
||||||
|
|
||||||
passed = expected in output_text
|
|
||||||
print(f" Output: {output_text[:100]}...")
|
|
||||||
print(f" Status: {'PASS' if passed else 'FAIL'}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.4 Debug 方案 C: 快速修复验证
|
|
||||||
|
|
||||||
**目标**: 验证修复 `deallocate()` 是否解决问题
|
|
||||||
|
|
||||||
**修改** (`hybrid_manager.py:218-244`):
|
|
||||||
```python
|
|
||||||
def deallocate(self, seq: Sequence) -> None:
|
|
||||||
"""Release all blocks for a sequence."""
|
|
||||||
for logical_id in reversed(seq.block_table):
|
|
||||||
# ... 现有逻辑 ...
|
|
||||||
|
|
||||||
seq.num_cached_tokens = 0
|
|
||||||
seq.block_table.clear()
|
|
||||||
|
|
||||||
# === 新增: 清理 decode tracking ===
|
|
||||||
self.clear_decode_tracking(seq)
|
|
||||||
```
|
|
||||||
|
|
||||||
**验证命令**:
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler_niah.py \
|
|
||||||
--model ~/models/Llama-3.1-8B-Instruct \
|
|
||||||
--enable-offload \
|
|
||||||
--sample-indices 0,1,2,3,4 \
|
|
||||||
--verbose
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.5 Debug 方案 D: 添加 OffloadEngine 清理 (防御性)
|
|
||||||
|
|
||||||
**目标**: 进一步隔离请求状态
|
|
||||||
|
|
||||||
**添加方法** (`offload_engine.py`):
|
|
||||||
```python
|
|
||||||
def on_sequence_finished(self):
|
|
||||||
"""清理请求完成后的状态"""
|
|
||||||
# 清零 decode buffer (防止残留数据被读取)
|
|
||||||
self.decode_k_buffer.zero_()
|
|
||||||
self.decode_v_buffer.zero_()
|
|
||||||
logger.debug("OffloadEngine: decode buffer cleared")
|
|
||||||
```
|
|
||||||
|
|
||||||
**调用点** (`hybrid_manager.py:deallocate` 末尾):
|
|
||||||
```python
|
|
||||||
# 清理 OffloadEngine 状态
|
|
||||||
if self.offload_engine is not None:
|
|
||||||
self.offload_engine.on_sequence_finished()
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 4: 实施计划 (pending)
|
|
||||||
|
|
||||||
### 推荐执行顺序
|
|
||||||
|
|
||||||
1. **Step 4.1**: 实施修复
|
|
||||||
- 修改 `hybrid_manager.py:deallocate()` 添加 `clear_decode_tracking(seq)`
|
|
||||||
|
|
||||||
2. **Step 4.2**: 快速验证 (20 样本连续执行)
|
|
||||||
- **一次调用** `test_ruler_niah.py`,连续执行 20 个样本
|
|
||||||
- **不重启框架**,验证请求切换是否正确
|
|
||||||
- 目标: 20/20 全部通过
|
|
||||||
|
|
||||||
3. **Step 4.3**: 完整验证 (100 样本)
|
|
||||||
- 运行 100 个样本的 RULER NIAH 测试
|
|
||||||
- 目标: 100/100 全部通过 (准确率从 66% → 100%)
|
|
||||||
|
|
||||||
4. **Step 4.4**: 防御性修复 (可选)
|
|
||||||
- 添加 `OffloadEngine.on_sequence_finished()` 方法
|
|
||||||
- 清零 decode buffer 作为额外保险
|
|
||||||
|
|
||||||
### 具体修改
|
|
||||||
|
|
||||||
**文件 1**: `nanovllm/kvcache/hybrid_manager.py`
|
|
||||||
|
|
||||||
位置: `deallocate()` 方法末尾 (第 244 行后)
|
|
||||||
|
|
||||||
```python
|
|
||||||
def deallocate(self, seq: Sequence) -> None:
|
|
||||||
"""Release all blocks for a sequence."""
|
|
||||||
for logical_id in reversed(seq.block_table):
|
|
||||||
# ... 现有逻辑 (218-242 行) ...
|
|
||||||
|
|
||||||
seq.num_cached_tokens = 0
|
|
||||||
seq.block_table.clear()
|
|
||||||
|
|
||||||
# ============ 新增: 清理 decode tracking ============
|
|
||||||
self.clear_decode_tracking(seq)
|
|
||||||
```
|
|
||||||
|
|
||||||
**文件 2** (可选): `nanovllm/kvcache/offload_engine.py`
|
|
||||||
|
|
||||||
位置: 在类末尾添加新方法
|
|
||||||
|
|
||||||
```python
|
|
||||||
def on_sequence_finished(self):
|
|
||||||
"""清理请求完成后的状态 (防御性清理)"""
|
|
||||||
self.decode_k_buffer.zero_()
|
|
||||||
self.decode_v_buffer.zero_()
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 关键文件清单
|
|
||||||
|
|
||||||
| 文件 | 相关行号 | 说明 |
|
|
||||||
|------|----------|------|
|
|------|----------|------|
|
||||||
| `nanovllm/kvcache/hybrid_manager.py` | 218-244 | `deallocate()` - **需要修改** |
|
| `xattn.py` | `xattn_estimate()` | 标准 XAttention estimation |
|
||||||
| `nanovllm/kvcache/hybrid_manager.py` | 538-549 | `clear_decode_tracking()` - 已存在 |
|
| `xattn.py` | `xattn_estimate_chunked()` | Chunked prefill 版本 |
|
||||||
| `nanovllm/kvcache/hybrid_manager.py` | 485-505 | `get_decode_start_pos()` - 问题读取点 |
|
| `xattn.py` | `flat_group_gemm_fuse_reshape()` | Triton kernel: fused reshape + GEMM |
|
||||||
| `nanovllm/kvcache/hybrid_manager.py` | 519-537 | `get_prefill_len()` - 问题读取点 |
|
| `xattn.py` | `softmax_fuse_block_sum()` | Triton kernel: softmax + block sum |
|
||||||
| `nanovllm/kvcache/offload_engine.py` | 40-145 | `__init__` - 状态初始化 |
|
| `xattn.py` | `find_blocks_chunked()` | Block selection based on threshold |
|
||||||
| `nanovllm/kvcache/offload_engine.py` | (新增) | `on_sequence_finished()` - 可选防御 |
|
| `chunked_attention.py` | `flash_attn_with_lse()` | Flash attention with LSE output |
|
||||||
| `nanovllm/engine/model_runner.py` | 867-1057 | `run_layerwise_offload_decode()` |
|
| `chunked_attention.py` | `merge_attention_outputs()` | Merge multiple attention chunks |
|
||||||
| `nanovllm/engine/model_runner.py` | 969-976 | decode buffer 读取 (污染点) |
|
|
||||||
|
|
||||||
---
|
### 与 Policy 的关系
|
||||||
|
|
||||||
## 验证命令
|
```
|
||||||
|
XAttentionPolicy.estimate()
|
||||||
**指定 GPU: 1** (严格限制,不可更改)
|
└── 调用 nanovllm.ops.xattn.xattn_estimate()
|
||||||
|
├── flat_group_gemm_fuse_reshape() (Triton)
|
||||||
```bash
|
├── softmax_fuse_block_sum() (Triton)
|
||||||
# 快速验证 (20 样本连续执行,不重启框架)
|
└── find_blocks_chunked()
|
||||||
# 目标: 20/20 通过
|
|
||||||
CUDA_VISIBLE_DEVICES=1 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler_niah.py \
|
|
||||||
--model ~/models/Llama-3.1-8B-Instruct \
|
|
||||||
--enable-offload \
|
|
||||||
--sample-indices 0-19 \
|
|
||||||
--verbose
|
|
||||||
|
|
||||||
# 完整验证 (100 样本)
|
|
||||||
# 目标: 100/100 通过 (最终验收)
|
|
||||||
CUDA_VISIBLE_DEVICES=1 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler_niah.py \
|
|
||||||
--model ~/models/Llama-3.1-8B-Instruct \
|
|
||||||
--enable-offload \
|
|
||||||
--quiet
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**验收标准**:
|
|
||||||
| 测试 | 样本数 | 通过要求 | 说明 |
|
|
||||||
|------|--------|----------|------|
|
|
||||||
| 快速验证 | 20 | 20/20 (100%) | 一次调用,连续执行,验证请求切换 |
|
|
||||||
| 完整验证 | 100 | 100/100 (100%) | 最终验收 |
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 当前状态
|
## Key Questions
|
||||||
|
|
||||||
- [x] Phase 1: 代码分析
|
1. **`select_blocks` 改为什么?**
|
||||||
- [x] Phase 2: 根本原因分析
|
- 改名为 `estimate()`:用于计算 sparse mask
|
||||||
- [x] Phase 3: Debug 方案设计
|
- 对于 XAttention,对应 COMPASS 的 `xattn_estimate()` 函数
|
||||||
- [x] Phase 4: 实施计划 ✅ 100/100 PASSED
|
- FullAttentionPolicy 的 `estimate()` 返回 None(表示 full attention)
|
||||||
|
|
||||||
### 验证结果
|
2. **Policy 接口应该如何设计?**
|
||||||
|
- Prefill: `compute_prefill(q, k, v, layer_id, softmax_scale)`
|
||||||
|
- Decode: `compute_decode(q, k, v, layer_id, softmax_scale)`
|
||||||
|
- Estimate: `estimate(q, k, layer_id)` - 计算 sparse mask
|
||||||
|
|
||||||
| 测试 | 结果 | 日期 |
|
3. **FULL policy 如何处理?**
|
||||||
|------|------|------|
|
- FULL 也实现 `compute_prefill/decode`,使用 FlashAttention
|
||||||
| 20 样本快速验证 | ✅ 20/20 (100%) | 2026-01-13 |
|
- `estimate()` 返回 None(表示不进行稀疏化)
|
||||||
| 100 样本完整验证 | ✅ 100/100 (100%) | 2026-01-13 |
|
|
||||||
|
## Proposed New Interface
|
||||||
|
|
||||||
|
```python
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionPolicy(ABC):
|
||||||
|
"""Layerwise Offload 模式下的 Attention Policy
|
||||||
|
|
||||||
|
所有 attention 计算都通过 policy 进行,包括 Full 和 Sparse。
|
||||||
|
支持 prefill 和 decode 两个阶段。
|
||||||
|
"""
|
||||||
|
|
||||||
|
supports_prefill: bool = True
|
||||||
|
supports_decode: bool = True
|
||||||
|
|
||||||
|
def estimate(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor, # [seq_len, num_heads, head_dim]
|
||||||
|
k: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: int,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
估算 sparse attention mask。
|
||||||
|
|
||||||
|
对于 sparse policy(如 XAttention),计算哪些 blocks 需要 attend。
|
||||||
|
对于 full policy,返回 None 表示使用完整 attention。
|
||||||
|
|
||||||
|
对应 COMPASS 的 xattn_estimate() 函数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask, 或 None
|
||||||
|
"""
|
||||||
|
return None # 默认为 full attention
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor, # [seq_len, num_heads, head_dim]
|
||||||
|
k: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
|
||||||
|
v: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
计算 prefill attention。
|
||||||
|
|
||||||
|
整层 KV 都在 GPU 上,一次计算完整 attention。
|
||||||
|
可以先调用 estimate() 获取 sparse mask,然后应用 block sparse attention。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def compute_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor, # [1, num_heads, head_dim]
|
||||||
|
k: torch.Tensor, # [context_len+1, num_kv_heads, head_dim]
|
||||||
|
v: torch.Tensor, # [context_len+1, num_kv_heads, head_dim]
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
计算 decode attention。
|
||||||
|
|
||||||
|
KV 从 ring buffer 提供,包含 prefill tokens + 已 decode 的 tokens。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [1, num_heads, head_dim]
|
||||||
|
k: Key tensor [context_len+1, num_kv_heads, head_dim]
|
||||||
|
v: Value tensor [context_len+1, num_kv_heads, head_dim]
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
softmax_scale: Softmax scaling factor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [1, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
# 默认实现:使用 FlashAttention
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
context_len = k.shape[0]
|
||||||
|
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
|
||||||
|
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=1,
|
||||||
|
max_seqlen_k=context_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset policy state between sequences."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}()"
|
||||||
|
|
||||||
|
|
||||||
|
# 保留旧名称作为别名
|
||||||
|
SparsePolicy = AttentionPolicy
|
||||||
|
```
|
||||||
|
|
||||||
|
## Implementation Plan
|
||||||
|
|
||||||
|
### Phase 2: 重构 policy.py
|
||||||
|
|
||||||
|
```python
|
||||||
|
# nanovllm/kvcache/sparse/policy.py
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionPolicy(ABC):
|
||||||
|
"""Base class for attention policies in layerwise offload mode."""
|
||||||
|
|
||||||
|
supports_prefill: bool = True
|
||||||
|
supports_decode: bool = True
|
||||||
|
|
||||||
|
def estimate(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Estimate sparse attention mask.
|
||||||
|
|
||||||
|
For sparse policies (e.g., XAttention), computes block-level importance.
|
||||||
|
For full policy, returns None.
|
||||||
|
|
||||||
|
Corresponds to xattn_estimate() in COMPASS.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sparse_mask: [num_heads, q_blocks, k_blocks] or None
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute prefill attention."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def compute_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute decode attention (default: FlashAttention)."""
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
context_len = k.shape[0]
|
||||||
|
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
|
||||||
|
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=1,
|
||||||
|
max_seqlen_k=context_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}()"
|
||||||
|
|
||||||
|
|
||||||
|
# Backward compatibility alias
|
||||||
|
SparsePolicy = AttentionPolicy
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 3: 重构 FullAttentionPolicy
|
||||||
|
|
||||||
|
```python
|
||||||
|
# nanovllm/kvcache/sparse/full_policy.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from .policy import AttentionPolicy
|
||||||
|
|
||||||
|
|
||||||
|
class FullAttentionPolicy(AttentionPolicy):
|
||||||
|
"""Full attention using FlashAttention (no sparsity)."""
|
||||||
|
|
||||||
|
supports_prefill = True
|
||||||
|
supports_decode = True
|
||||||
|
|
||||||
|
def estimate(self, q, k, layer_id):
|
||||||
|
"""Full attention - no sparse mask needed."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
seq_len = q.shape[0]
|
||||||
|
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
max_seqlen_q=seq_len,
|
||||||
|
max_seqlen_k=seq_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "FullAttentionPolicy()"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 4: 重构 XAttentionPolicy
|
||||||
|
|
||||||
|
```python
|
||||||
|
# nanovllm/kvcache/sparse/xattn.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from typing import Optional
|
||||||
|
from .policy import AttentionPolicy
|
||||||
|
|
||||||
|
|
||||||
|
class XAttentionPolicy(AttentionPolicy):
|
||||||
|
"""
|
||||||
|
XAttention sparse prefill policy.
|
||||||
|
|
||||||
|
Uses chunked estimation to compute sparse attention mask,
|
||||||
|
then applies block sparse attention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
supports_prefill = True
|
||||||
|
supports_decode = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stride: int = 8,
|
||||||
|
threshold: float = 0.9,
|
||||||
|
block_size: int = 128,
|
||||||
|
chunk_size: int = 16384,
|
||||||
|
use_triton: bool = True,
|
||||||
|
):
|
||||||
|
self.stride = stride
|
||||||
|
self.threshold = threshold
|
||||||
|
self.block_size = block_size
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.use_triton = use_triton
|
||||||
|
|
||||||
|
def estimate(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
XAttention estimation (xattn_estimate).
|
||||||
|
|
||||||
|
Uses chunked GEMM + softmax to estimate block-level importance,
|
||||||
|
then selects important blocks based on threshold.
|
||||||
|
|
||||||
|
对应 COMPASS 的 xattn_estimate() 函数:
|
||||||
|
1. Pad inputs to chunk_size multiples
|
||||||
|
2. Reshape with stride
|
||||||
|
3. Compute QK^T in chunks (Triton)
|
||||||
|
4. Block-wise softmax + aggregation
|
||||||
|
5. Threshold-based selection
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: [seq_len, num_heads, head_dim]
|
||||||
|
k: [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: transformer layer index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask
|
||||||
|
or None (fallback to full attention)
|
||||||
|
"""
|
||||||
|
# TODO: 实现真正的 xattn_estimate
|
||||||
|
# 当前返回 None 使用 full attention
|
||||||
|
return None
|
||||||
|
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute XAttention sparse prefill.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. Call estimate() to get sparse mask
|
||||||
|
2. If mask is None, use full attention
|
||||||
|
3. Otherwise, apply block sparse attention with mask
|
||||||
|
"""
|
||||||
|
# Step 1: Estimate sparse mask
|
||||||
|
sparse_mask = self.estimate(q, k, layer_id)
|
||||||
|
|
||||||
|
# Step 2: Compute attention
|
||||||
|
if sparse_mask is None:
|
||||||
|
# Fallback to full attention
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
seq_len = q.shape[0]
|
||||||
|
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
max_seqlen_q=seq_len,
|
||||||
|
max_seqlen_k=seq_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Apply block sparse attention with mask
|
||||||
|
# 使用 block_sparse_attn_func(q, k, v, sparse_mask, block_size)
|
||||||
|
raise NotImplementedError("Block sparse attention not yet implemented")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (f"XAttentionPolicy("
|
||||||
|
f"stride={self.stride}, "
|
||||||
|
f"threshold={self.threshold}, "
|
||||||
|
f"block_size={self.block_size})")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 5: 更新 model_runner.py
|
||||||
|
|
||||||
|
```python
|
||||||
|
# model_runner.py - allocate_kv_cache()
|
||||||
|
|
||||||
|
# 改为总是创建 policy(包括 FULL)
|
||||||
|
from nanovllm.kvcache.sparse import create_attention_policy
|
||||||
|
self.attention_policy = create_attention_policy(config.attention_policy, **policy_kwargs)
|
||||||
|
logger.info(f"Attention policy: {self.attention_policy}")
|
||||||
|
|
||||||
|
# run_layerwise_offload_prefill() 和 run_gpu_only_prefill()
|
||||||
|
|
||||||
|
# 旧代码:
|
||||||
|
if self.sparse_prefill_policy is not None:
|
||||||
|
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
|
||||||
|
else:
|
||||||
|
attn_output = flash_attn_varlen_func(...)
|
||||||
|
|
||||||
|
# 新代码:
|
||||||
|
attn_output = self.attention_policy.compute_prefill(
|
||||||
|
q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Method Mapping
|
||||||
|
|
||||||
|
| 旧方法 | 新方法 | 说明 |
|
||||||
|
|--------|--------|------|
|
||||||
|
| `select_blocks()` | `estimate()` | 计算 sparse mask(对应 xattn_estimate) |
|
||||||
|
| `sparse_prefill_attention()` | `compute_prefill()` | Prefill attention |
|
||||||
|
| (无) | `compute_decode()` | Decode attention(默认实现) |
|
||||||
|
| `on_prefill_offload()` | (移除) | Offload 在 model_runner 处理 |
|
||||||
|
|
||||||
|
## Files to Modify
|
||||||
|
|
||||||
|
| File | Changes |
|
||||||
|
|------|---------|
|
||||||
|
| `nanovllm/kvcache/sparse/policy.py` | 新接口:estimate, compute_prefill, compute_decode |
|
||||||
|
| `nanovllm/kvcache/sparse/full_policy.py` | 实现 compute_prefill(), estimate() 返回 None |
|
||||||
|
| `nanovllm/kvcache/sparse/xattn.py` | estimate() 对应 xattn_estimate, compute_prefill() |
|
||||||
|
| `nanovllm/kvcache/sparse/__init__.py` | 更新工厂函数 |
|
||||||
|
| `nanovllm/engine/model_runner.py` | 统一调用 attention_policy.compute_prefill() |
|
||||||
|
| `nanovllm/config.py` | 可选:重命名配置项 |
|
||||||
|
|
||||||
|
## Decisions Made
|
||||||
|
|
||||||
|
1. **方法命名**: `compute_prefill` / `compute_decode` 对应 chunked 版本的命名风格
|
||||||
|
2. **estimate 方法**: 替代 `select_blocks`,返回 sparse mask 而不是 block IDs
|
||||||
|
3. **XAttention**: `estimate()` 对应 COMPASS 的 `xattn_estimate()`
|
||||||
|
4. **Full Policy**: `estimate()` 返回 None 表示使用完整 attention
|
||||||
|
5. **Decode 默认实现**: 基类提供默认的 FlashAttention 实现
|
||||||
|
|
||||||
|
## Errors Encountered
|
||||||
|
- (无)
|
||||||
|
|
||||||
|
## Status
|
||||||
|
**Currently in Phase 1** - 完成分析和接口设计,等待用户确认后进入 Phase 2
|
||||||
|
|||||||
@@ -32,11 +32,14 @@ def run_needle_test(
|
|||||||
enable_cpu_offload: bool = False,
|
enable_cpu_offload: bool = False,
|
||||||
enable_quest: bool = False,
|
enable_quest: bool = False,
|
||||||
enable_minference: bool = False,
|
enable_minference: bool = False,
|
||||||
|
enable_xattn: bool = False,
|
||||||
sparse_topk: int = 8,
|
sparse_topk: int = 8,
|
||||||
sparse_threshold: int = 4,
|
sparse_threshold: int = 4,
|
||||||
minference_budget: float = 0.3,
|
minference_budget: float = 0.3,
|
||||||
minference_vertical: int = 1000,
|
minference_vertical: int = 1000,
|
||||||
minference_slash: int = 6096,
|
minference_slash: int = 6096,
|
||||||
|
xattn_threshold: float = 0.9,
|
||||||
|
xattn_use_bsa: bool = True,
|
||||||
gpu_utilization: float = 0.9,
|
gpu_utilization: float = 0.9,
|
||||||
enforce_eager: bool = True,
|
enforce_eager: bool = True,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
@@ -56,11 +59,14 @@ def run_needle_test(
|
|||||||
enable_cpu_offload: Enable CPU offload mode
|
enable_cpu_offload: Enable CPU offload mode
|
||||||
enable_quest: Enable Quest sparse attention (decode-only Top-K)
|
enable_quest: Enable Quest sparse attention (decode-only Top-K)
|
||||||
enable_minference: Enable MInference sparse prefill (GPU-only)
|
enable_minference: Enable MInference sparse prefill (GPU-only)
|
||||||
|
enable_xattn: Enable XAttention sparse prefill with BSA
|
||||||
sparse_topk: Top-K blocks for Quest
|
sparse_topk: Top-K blocks for Quest
|
||||||
sparse_threshold: Apply sparse only when blocks > threshold
|
sparse_threshold: Apply sparse only when blocks > threshold
|
||||||
minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode)
|
minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode)
|
||||||
minference_vertical: Fixed vertical_size (only used when budget=None)
|
minference_vertical: Fixed vertical_size (only used when budget=None)
|
||||||
minference_slash: Fixed slash_size (only used when budget=None)
|
minference_slash: Fixed slash_size (only used when budget=None)
|
||||||
|
xattn_threshold: XAttention block selection threshold (0-1)
|
||||||
|
xattn_use_bsa: Use Block Sparse Attention library
|
||||||
gpu_utilization: GPU memory utilization fraction
|
gpu_utilization: GPU memory utilization fraction
|
||||||
verbose: Print detailed output
|
verbose: Print detailed output
|
||||||
|
|
||||||
@@ -68,7 +74,9 @@ def run_needle_test(
|
|||||||
True if test passed, False otherwise
|
True if test passed, False otherwise
|
||||||
"""
|
"""
|
||||||
# Determine sparse policy
|
# Determine sparse policy
|
||||||
if enable_minference:
|
if enable_xattn:
|
||||||
|
sparse_policy = SparsePolicyType.XATTN
|
||||||
|
elif enable_minference:
|
||||||
sparse_policy = SparsePolicyType.MINFERENCE
|
sparse_policy = SparsePolicyType.MINFERENCE
|
||||||
elif enable_quest:
|
elif enable_quest:
|
||||||
sparse_policy = SparsePolicyType.QUEST
|
sparse_policy = SparsePolicyType.QUEST
|
||||||
@@ -94,6 +102,8 @@ def run_needle_test(
|
|||||||
print(f" MInference: adaptive (budget={minference_budget})")
|
print(f" MInference: adaptive (budget={minference_budget})")
|
||||||
else:
|
else:
|
||||||
print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})")
|
print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})")
|
||||||
|
if enable_xattn:
|
||||||
|
print(f" XAttention: threshold={xattn_threshold}, use_bsa={xattn_use_bsa}")
|
||||||
print(f"{'='*60}\n")
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
# 1. Initialize LLM
|
# 1. Initialize LLM
|
||||||
@@ -111,7 +121,7 @@ def run_needle_test(
|
|||||||
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
|
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
|
||||||
|
|
||||||
# Set sparse policy (can be used with or without offload)
|
# Set sparse policy (can be used with or without offload)
|
||||||
if enable_minference or enable_quest:
|
if enable_minference or enable_quest or enable_xattn:
|
||||||
llm_kwargs["sparse_policy"] = sparse_policy
|
llm_kwargs["sparse_policy"] = sparse_policy
|
||||||
|
|
||||||
# MInference params (works with both GPU-only and offload mode)
|
# MInference params (works with both GPU-only and offload mode)
|
||||||
@@ -120,6 +130,11 @@ def run_needle_test(
|
|||||||
llm_kwargs["minference_vertical_size"] = minference_vertical
|
llm_kwargs["minference_vertical_size"] = minference_vertical
|
||||||
llm_kwargs["minference_slash_size"] = minference_slash
|
llm_kwargs["minference_slash_size"] = minference_slash
|
||||||
|
|
||||||
|
# XAttention params
|
||||||
|
if enable_xattn:
|
||||||
|
llm_kwargs["xattn_threshold"] = xattn_threshold
|
||||||
|
llm_kwargs["xattn_use_bsa"] = xattn_use_bsa
|
||||||
|
|
||||||
llm = LLM(model_path, **llm_kwargs)
|
llm = LLM(model_path, **llm_kwargs)
|
||||||
|
|
||||||
# 2. Generate needle prompt
|
# 2. Generate needle prompt
|
||||||
@@ -224,6 +239,11 @@ if __name__ == "__main__":
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)"
|
help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-xattn",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable XAttention sparse prefill with Block Sparse Attention"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sparse-topk",
|
"--sparse-topk",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -254,6 +274,17 @@ if __name__ == "__main__":
|
|||||||
default=6096,
|
default=6096,
|
||||||
help="Fixed slash_size (only used when budget=0)"
|
help="Fixed slash_size (only used when budget=0)"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--xattn-threshold",
|
||||||
|
type=float,
|
||||||
|
default=0.9,
|
||||||
|
help="XAttention block selection threshold (0-1, higher=more blocks)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--xattn-no-bsa",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable Block Sparse Attention (use FlashAttention fallback)"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--gpu-utilization",
|
"--gpu-utilization",
|
||||||
type=float,
|
type=float,
|
||||||
@@ -291,11 +322,14 @@ if __name__ == "__main__":
|
|||||||
enable_cpu_offload=args.enable_offload,
|
enable_cpu_offload=args.enable_offload,
|
||||||
enable_quest=args.enable_quest,
|
enable_quest=args.enable_quest,
|
||||||
enable_minference=args.enable_minference,
|
enable_minference=args.enable_minference,
|
||||||
|
enable_xattn=args.enable_xattn,
|
||||||
sparse_topk=args.sparse_topk,
|
sparse_topk=args.sparse_topk,
|
||||||
sparse_threshold=args.sparse_threshold,
|
sparse_threshold=args.sparse_threshold,
|
||||||
minference_budget=minference_budget,
|
minference_budget=minference_budget,
|
||||||
minference_vertical=args.minference_vertical,
|
minference_vertical=args.minference_vertical,
|
||||||
minference_slash=args.minference_slash,
|
minference_slash=args.minference_slash,
|
||||||
|
xattn_threshold=args.xattn_threshold,
|
||||||
|
xattn_use_bsa=not args.xattn_no_bsa,
|
||||||
gpu_utilization=args.gpu_utilization,
|
gpu_utilization=args.gpu_utilization,
|
||||||
enforce_eager=enforce_eager,
|
enforce_eager=enforce_eager,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
|
|||||||
@@ -38,11 +38,11 @@ from nanovllm import LLM, SamplingParams
|
|||||||
# Constants
|
# Constants
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|
||||||
DEFAULT_DATA_DIR = Path(__file__).parent / "data/ruler_32k"
|
DEFAULT_DATA_DIR = Path(__file__).parent / "data/ruler_64k"
|
||||||
DEFAULT_MODEL = os.path.expanduser("~/models/Llama-3.1-8B-Instruct")
|
DEFAULT_MODEL = os.path.expanduser("~/models/Llama-3.1-8B-Instruct")
|
||||||
# Note: max_model_len must be > max_input_len to leave room for output tokens
|
# Note: max_model_len must be > max_input_len to leave room for output tokens
|
||||||
# 32k benchmark has inputs up to 32760 tokens, so we need 32768 + 128 = 32896
|
# 64k benchmark has inputs up to 65536 tokens, so we need 65536 + 128 = 65664
|
||||||
DEFAULT_MAX_MODEL_LEN = 32896
|
DEFAULT_MAX_MODEL_LEN = 65664
|
||||||
DEFAULT_MAX_NEW_TOKENS = 128 # Larger for multi-value tasks
|
DEFAULT_MAX_NEW_TOKENS = 128 # Larger for multi-value tasks
|
||||||
|
|
||||||
# Task categories for evaluation
|
# Task categories for evaluation
|
||||||
@@ -222,9 +222,11 @@ def run_ruler_benchmark(
|
|||||||
enable_cpu_offload: bool = False,
|
enable_cpu_offload: bool = False,
|
||||||
num_gpu_blocks: int = 4,
|
num_gpu_blocks: int = 4,
|
||||||
block_size: int = 1024,
|
block_size: int = 1024,
|
||||||
|
num_kv_buffers: int = 4,
|
||||||
gpu_utilization: float = 0.9,
|
gpu_utilization: float = 0.9,
|
||||||
enforce_eager: bool = True,
|
enforce_eager: bool = True,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
|
sparse_policy: Optional[str] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Run RULER benchmark on multiple tasks.
|
Run RULER benchmark on multiple tasks.
|
||||||
@@ -235,6 +237,7 @@ def run_ruler_benchmark(
|
|||||||
datasets: List of task names to test (None = all)
|
datasets: List of task names to test (None = all)
|
||||||
num_samples: Number of samples per task (None = all)
|
num_samples: Number of samples per task (None = all)
|
||||||
...other LLM config params...
|
...other LLM config params...
|
||||||
|
sparse_policy: Sparse attention policy (FULL, QUEST, MINFERENCE, XATTN)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with overall results and per-task results
|
Dict with overall results and per-task results
|
||||||
@@ -270,6 +273,11 @@ def run_ruler_benchmark(
|
|||||||
}
|
}
|
||||||
if enable_cpu_offload:
|
if enable_cpu_offload:
|
||||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
||||||
|
llm_kwargs["num_kv_buffers"] = num_kv_buffers
|
||||||
|
if sparse_policy:
|
||||||
|
from nanovllm.config import SparsePolicyType
|
||||||
|
sparse_policy_type = SparsePolicyType[sparse_policy]
|
||||||
|
llm_kwargs["sparse_policy"] = sparse_policy_type
|
||||||
|
|
||||||
llm = LLM(model_path, **llm_kwargs)
|
llm = LLM(model_path, **llm_kwargs)
|
||||||
|
|
||||||
@@ -356,12 +364,16 @@ if __name__ == "__main__":
|
|||||||
help="Number of GPU blocks for CPU offload (default: 4)")
|
help="Number of GPU blocks for CPU offload (default: 4)")
|
||||||
parser.add_argument("--block-size", type=int, default=1024,
|
parser.add_argument("--block-size", type=int, default=1024,
|
||||||
help="KV cache block size (default: 1024)")
|
help="KV cache block size (default: 1024)")
|
||||||
|
parser.add_argument("--num-kv-buffers", type=int, default=4,
|
||||||
|
help="Number of KV buffers for ring buffer (default: 4)")
|
||||||
parser.add_argument("--gpu-utilization", type=float, default=0.9,
|
parser.add_argument("--gpu-utilization", type=float, default=0.9,
|
||||||
help="GPU memory utilization (default: 0.9)")
|
help="GPU memory utilization (default: 0.9)")
|
||||||
parser.add_argument("--use-cuda-graph", action="store_true",
|
parser.add_argument("--use-cuda-graph", action="store_true",
|
||||||
help="Enable CUDA graph")
|
help="Enable CUDA graph")
|
||||||
parser.add_argument("--quiet", "-q", action="store_true",
|
parser.add_argument("--quiet", "-q", action="store_true",
|
||||||
help="Quiet mode")
|
help="Quiet mode")
|
||||||
|
parser.add_argument("--sparse-policy", type=str, default="",
|
||||||
|
help="Sparse attention policy (FULL, QUEST, MINFERENCE, XATTN)")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -369,6 +381,9 @@ if __name__ == "__main__":
|
|||||||
datasets = args.datasets.split(",") if args.datasets else None
|
datasets = args.datasets.split(",") if args.datasets else None
|
||||||
num_samples = args.num_samples if args.num_samples > 0 else None
|
num_samples = args.num_samples if args.num_samples > 0 else None
|
||||||
|
|
||||||
|
# Parse sparse policy
|
||||||
|
sparse_policy_str = args.sparse_policy.upper() if args.sparse_policy else None
|
||||||
|
|
||||||
results = run_ruler_benchmark(
|
results = run_ruler_benchmark(
|
||||||
model_path=os.path.expanduser(args.model),
|
model_path=os.path.expanduser(args.model),
|
||||||
data_dir=Path(args.data_dir),
|
data_dir=Path(args.data_dir),
|
||||||
@@ -379,9 +394,11 @@ if __name__ == "__main__":
|
|||||||
enable_cpu_offload=args.enable_offload,
|
enable_cpu_offload=args.enable_offload,
|
||||||
num_gpu_blocks=args.num_gpu_blocks,
|
num_gpu_blocks=args.num_gpu_blocks,
|
||||||
block_size=args.block_size,
|
block_size=args.block_size,
|
||||||
|
num_kv_buffers=args.num_kv_buffers,
|
||||||
gpu_utilization=args.gpu_utilization,
|
gpu_utilization=args.gpu_utilization,
|
||||||
enforce_eager=not args.use_cuda_graph,
|
enforce_eager=not args.use_cuda_graph,
|
||||||
verbose=not args.quiet,
|
verbose=not args.quiet,
|
||||||
|
sparse_policy=sparse_policy_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Exit code
|
# Exit code
|
||||||
|
|||||||
244
tests/test_xattn_estimate_chunked.py
Normal file
244
tests/test_xattn_estimate_chunked.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
"""
|
||||||
|
Test: Compare xattn_estimate vs xattn_estimate_chunked
|
||||||
|
|
||||||
|
Verify that chunked estimation with EXTERNAL chunking produces the same mask
|
||||||
|
as standard estimation. This ensures the chunked version can be used in
|
||||||
|
chunked prefill scenarios without accuracy loss.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_xattn_estimate_chunked.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import torch
|
||||||
|
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Configuration
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# Configuration for xattn_estimate_chunked consistency test.
|
||||||
|
# Key requirements for 100% match:
|
||||||
|
# 1. Use matching chunk_size for both standard and chunked versions
|
||||||
|
# 2. Use same random seed for reproducibility
|
||||||
|
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
|
||||||
|
# floating point precision in cumulative sum calculations.
|
||||||
|
BLOCK_SIZE = 64
|
||||||
|
STRIDE = 4
|
||||||
|
THRESHOLD = 0.9
|
||||||
|
CHUNK_SIZE = 4096 # External chunking size
|
||||||
|
|
||||||
|
# Test sequence lengths
|
||||||
|
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Utility Functions
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
|
||||||
|
"""Compare two masks and report differences."""
|
||||||
|
if mask1.shape != mask2.shape:
|
||||||
|
print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
diff = (mask1 != mask2).sum().item()
|
||||||
|
total = mask1.numel()
|
||||||
|
match_rate = (total - diff) / total * 100
|
||||||
|
|
||||||
|
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
|
||||||
|
|
||||||
|
if diff > 0:
|
||||||
|
diff_indices = torch.where(mask1 != mask2)
|
||||||
|
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
|
||||||
|
|
||||||
|
return diff == 0
|
||||||
|
|
||||||
|
|
||||||
|
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
|
||||||
|
"""
|
||||||
|
Run xattn_estimate_chunked with EXTERNAL chunking.
|
||||||
|
This simulates how chunked prefill should be used in practice.
|
||||||
|
"""
|
||||||
|
batch_size, num_heads, q_len, head_dim = query.shape
|
||||||
|
_, _, k_len, _ = key.shape
|
||||||
|
|
||||||
|
q_block_num = (q_len + block_size - 1) // block_size
|
||||||
|
k_block_num = (k_len + block_size - 1) // block_size
|
||||||
|
|
||||||
|
# If Q fits in one chunk, call directly
|
||||||
|
if q_len <= chunk_size:
|
||||||
|
return xattn_estimate_chunked(
|
||||||
|
query, key,
|
||||||
|
q_start_pos=0,
|
||||||
|
block_size=block_size,
|
||||||
|
stride=stride,
|
||||||
|
threshold=threshold,
|
||||||
|
use_triton=True,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# External chunking: split Q and call for each chunk
|
||||||
|
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
|
||||||
|
print(f" External chunking: {num_q_chunks} chunks")
|
||||||
|
|
||||||
|
combined_attn_sum = torch.zeros(
|
||||||
|
batch_size, num_heads, q_block_num, k_block_num,
|
||||||
|
dtype=query.dtype, device=query.device
|
||||||
|
)
|
||||||
|
combined_mask = torch.zeros(
|
||||||
|
batch_size, num_heads, q_block_num, k_block_num,
|
||||||
|
dtype=torch.bool, device=query.device
|
||||||
|
)
|
||||||
|
|
||||||
|
q_block_offset = 0
|
||||||
|
for q_chunk_idx in range(num_q_chunks):
|
||||||
|
q_chunk_start = q_chunk_idx * chunk_size
|
||||||
|
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
|
||||||
|
|
||||||
|
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
|
||||||
|
|
||||||
|
# For causal attention, K accumulates up to current Q position
|
||||||
|
# q_start_pos=0 means Q starts at position 0 in the full sequence
|
||||||
|
# K is [0, q_chunk_end) for causal attention
|
||||||
|
k_end = q_chunk_end
|
||||||
|
k_chunk = key[:, :, :k_end, :]
|
||||||
|
|
||||||
|
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
|
||||||
|
q_chunk, k_chunk,
|
||||||
|
q_start_pos=q_chunk_start,
|
||||||
|
block_size=block_size,
|
||||||
|
stride=stride,
|
||||||
|
threshold=threshold,
|
||||||
|
use_triton=True,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Place chunk results into combined output
|
||||||
|
chunk_q_blocks = mask_chunk.shape[2]
|
||||||
|
chunk_k_blocks = mask_chunk.shape[3]
|
||||||
|
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
|
||||||
|
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
|
||||||
|
q_block_offset += chunk_q_blocks
|
||||||
|
|
||||||
|
return combined_attn_sum, combined_mask
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
|
||||||
|
"""Test a single sequence length."""
|
||||||
|
print(f"\nTesting seq_len={seq_len}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Generate random Q/K
|
||||||
|
query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
|
||||||
|
key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Run standard xattn_estimate
|
||||||
|
print("[1] Running standard xattn_estimate...")
|
||||||
|
try:
|
||||||
|
attn_sum_std, mask_std = xattn_estimate(
|
||||||
|
query, key,
|
||||||
|
block_size=BLOCK_SIZE,
|
||||||
|
stride=STRIDE,
|
||||||
|
threshold=THRESHOLD,
|
||||||
|
chunk_size=CHUNK_SIZE,
|
||||||
|
use_triton=True,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
density_std = mask_std.float().mean().item()
|
||||||
|
print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ERROR: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Run chunked xattn_estimate with EXTERNAL chunking
|
||||||
|
print("[2] Running chunked xattn_estimate (external chunking)...")
|
||||||
|
try:
|
||||||
|
attn_sum_chunked, mask_chunked = run_chunked_externally(
|
||||||
|
query, key,
|
||||||
|
block_size=BLOCK_SIZE,
|
||||||
|
stride=STRIDE,
|
||||||
|
threshold=THRESHOLD,
|
||||||
|
chunk_size=CHUNK_SIZE,
|
||||||
|
)
|
||||||
|
density_chunked = mask_chunked.float().mean().item()
|
||||||
|
print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ERROR: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
print("[3] Comparing results...")
|
||||||
|
chunked_q_blocks = mask_chunked.shape[2]
|
||||||
|
chunked_k_blocks = mask_chunked.shape[3]
|
||||||
|
|
||||||
|
# Extract comparable region from standard mask
|
||||||
|
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
||||||
|
|
||||||
|
# Compare masks
|
||||||
|
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
|
||||||
|
|
||||||
|
# Compare attn_sums
|
||||||
|
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
||||||
|
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
|
||||||
|
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
|
||||||
|
print(f" Attn sum max diff: {attn_diff:.6f}")
|
||||||
|
else:
|
||||||
|
print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")
|
||||||
|
|
||||||
|
# Clean up GPU memory
|
||||||
|
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return masks_match
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Main Test
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("XAttention Chunked vs Standard Test")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
|
||||||
|
print(f"External chunk_size={CHUNK_SIZE}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Check CUDA availability
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("CUDA not available!")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
print("✓ xattn_estimate imported")
|
||||||
|
print("✓ xattn_estimate_chunked imported")
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
all_passed = True
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for seq_len in TEST_SEQ_LENS:
|
||||||
|
passed = test_single_seq_len(seq_len)
|
||||||
|
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
|
||||||
|
results.append((seq_len, chunks, passed))
|
||||||
|
if not passed:
|
||||||
|
all_passed = False
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("SUMMARY")
|
||||||
|
print("=" * 60)
|
||||||
|
for seq_len, chunks, passed in results:
|
||||||
|
status = "PASSED" if passed else "FAILED"
|
||||||
|
print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
if all_passed:
|
||||||
|
print("ALL TESTS PASSED!")
|
||||||
|
sys.exit(0)
|
||||||
|
else:
|
||||||
|
print("SOME TESTS FAILED!")
|
||||||
|
sys.exit(1)
|
||||||
Reference in New Issue
Block a user