14 Commits

Author SHA1 Message Date
Zijie Tian
5fb0f67295 [WIP] need refactor. 2026-01-22 22:20:34 +08:00
Zijie Tian
69b779e252 📝 docs: add layer offload planning notes and task plan
Add planning documents for layer-wise offload implementation:
- notes.md: Implementation notes and findings
- task_plan.md: Detailed task breakdown and progress tracking

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 06:04:36 +08:00
Zijie Tian
e313dd795a feat: add exec-plan command for automated task plan execution
Add a new Claude command that executes task_plan.md refactoring with:
- GPU isolation via --gpu <id> parameter (required)
- Optional --no-interrupt mode for autonomous execution
- Progress tracking via progress.md and findings.md
- Strict CUDA_VISIBLE_DEVICES enforcement

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 06:03:42 +08:00
Zijie Tian
9f3ee9279e feat: add nanovllm.ops module with XAttention estimation kernels
Add ops module ported from tzj/minference branch containing:
- xattn.py: XAttention block importance estimation with Triton kernels
  - xattn_estimate(): standard estimation for sparse attention mask
  - xattn_estimate_chunked(): chunked prefill compatible version
  - flat_group_gemm_fuse_reshape(): fused stride reshape + GEMM kernel
  - softmax_fuse_block_sum(): online softmax + block-wise sum kernel
- chunked_attention.py: Flash attention with LSE output for chunk merging
- test_xattn_estimate_chunked.py: verification test (all seq_lens pass)

This prepares the foundation for AttentionPolicy refactoring where
XAttentionPolicy.estimate() will call these ops.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 06:00:42 +08:00
Zijie Tian
2826a649de docs: add XAttention integration guide
Comprehensive documentation for XAttention sparse policy integration:
- Algorithm principles (chunked estimation + block sparse attention)
- COMPASS source code analysis
- Design decisions for CPU offload mode
- Implementation details (utils.py, kernels.py, xattn.py)
- Problem-solving (OOM, GQA, abstract method)
- Test validation results (RULER 32k benchmark)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-14 10:16:21 +08:00
Zijie Tian
24baeb6d5a chore: add planning-with-files rule configuration 2026-01-14 10:09:52 +08:00
Zijie Tian
57f4e9c6e6 docs: reorganize documentation files
- Move notes.md to docs/development_notes.md
- Move Xattention_analysis.md to docs/xattention_analysis.md
- Delete DEBUG_SUMMARY.md (no longer needed)
- Update CLAUDE.md with documentation index entries

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-14 10:08:41 +08:00
Zijie Tian
ac1ccbceaa feat: add XAttention sparse policy integration
Integrate COMPASS XAttention algorithm into nano-vllm's CPU offload
execution path. Uses FlashAttention with native GQA support for
offload mode.

New files:
- nanovllm/kvcache/sparse/utils.py: find_blocks_chunked() utility
- nanovllm/kvcache/sparse/kernels.py: Triton kernels for XAttention
- nanovllm/kvcache/sparse/xattn.py: XAttentionPolicy implementation

Modified:
- nanovllm/config.py: Add XATTN configuration parameters
- nanovllm/engine/model_runner.py: Support XATTN policy
- nanovllm/kvcache/sparse/__init__.py: Register XAttentionPolicy
- tests/test_ruler.py: Add --sparse-policy parameter

Test results (32k ruler):
- NIAH tasks: 12/12 (100%)
- QA/Recall tasks: 11/15 (73%)
- Overall: 23/27 (85%)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-14 10:04:46 +08:00
Zijie Tian
029894118d feat: add claude-flow MCP configuration
Add .claude/settings.json to enable claude-flow MCP in all worktrees.

This configuration includes:
- SessionStart hook to auto-start claude-flow daemon
- Auto-approval for claude-flow MCP tools and CLI commands
- Basic claude-flow settings

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-14 09:18:09 +08:00
Zijie Tian
8d6fde3b23 docs: add Block-Sparse-Attention library reference
Add comprehensive documentation for the MIT-Han-Lab Block-Sparse-Attention
library (3rdparty submodule, branch: tzj/minference).

The new document covers:
- Four sparse attention modes (dense, token/block streaming, block sparse)
- Hybrid mask support (different patterns per head)
- Complete API reference for all three functions
- Performance benchmarks (up to 3-4x speedup on A100)
- Integration considerations for nano-vllm

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-14 08:39:03 +08:00
Zijie Tian
6a6bd75685 feat: add Block-Sparse-Attention submodule (tzj/minference branch)
Add 3rdparty/Block-Sparse-Attention as a git submodule from the
tzj/minference branch of Zijie-Tian/Block-Sparse-Attention repository.

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-14 08:07:07 +08:00
Zijie Tian
86633004ca 📝 docs: add 64k memory analysis and test configuration updates
Add comprehensive memory analysis for 64k inference on Llama 3.1 8B:

New documentation:
- docs/64k_memory_analysis.md: GPU-only vs offload memory analysis,
  OOM root cause (memory fragmentation), RTX 3090 limitations,
  theoretical vs actual memory usage breakdown

Test configuration updates:
- tests/test_ruler.py: Add --num-kv-buffers parameter for ring buffer
  size tuning (default 4, can reduce to 1 for lower memory)
- Update default data_dir to ruler_64k
- Update default max_model_len to 65664 for 64k support

CLAUDE.md updates:
- Add 64k_memory_analysis.md to documentation index
- Document num_kv_buffers parameter in Configuration section
- Add 64k hardware requirements note to Model Limits

Key findings: 64k inference requires ~26GB (GPU-only) or ~23GB (offload)
due to memory fragmentation on 24GB GPUs, making A100 (40GB+) the
recommended hardware for 64k workloads.

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-14 07:02:09 +08:00
Zijie Tian
c51a640a29 🐛 fix: remove torch.compile from add_rms_forward to avoid recompilation
The add_rms_forward method processes two input tensors (x and residual),
which causes torch.compile recompilation issues. Keep @torch.compile only
on rms_forward which processes a single input.

This prevents unnecessary recompilation overhead during inference.

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-14 07:02:02 +08:00
Zijie Tian
dce6ad6b74 ♻️ refactor: chunked LayerNorm/QKV/MLP for 64k memory optimization
Implement chunked processing for LayerNorm, QKV projection, and MLP
layers to reduce peak activation memory for 64k sequence inference.

Changes:
- Chunked input_layernorm and post_attention_layernorm (chunk_size=128)
- Chunked QKV projection (chunk_size=128)
- Chunked MLP processing (chunk_size=128) with memory cleanup
- Added torch.cuda.empty_cache() calls after each chunk

This reduces peak activation from ~2 GB to ~50 MB per layer,
making 64k inference theoretically possible on 24GB GPUs
(though still limited by memory fragmentation).

Related: docs/64k_memory_analysis.md

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-14 07:01:57 +08:00
36 changed files with 6607 additions and 1397 deletions

View File

@@ -0,0 +1,158 @@
---
allowed-tools: Bash(CUDA_VISIBLE_DEVICES=*), Bash(PYTHONPATH=*), Bash(python*), Bash(git*), Bash(rm*), Bash(ls*), Bash(cat*), Bash(nvidia-smi*), Read, Edit, Write, Glob, Grep, TodoWrite, Task
argument-hint: --gpu <id> [--no-interrupt]
description: Execute task_plan.md refactoring with specified GPU, optionally without user interruption
---
# Execute Task Plan (exec-plan)
按照 `task_plan.md` 的要求执行代码重构,确保计划中的最终目标圆满实现。
## 参数说明
命令格式: `/exec-plan --gpu <id> [--no-interrupt]`
| 参数 | 说明 | 示例 |
|------|------|------|
| `--gpu <id>` | **必需**。指定可用的 GPU ID只能使用此 GPU 进行调试 | `--gpu 0`, `--gpu 2` |
| `--no-interrupt` | 可选。禁止中断执行,遇到问题不与用户交互,自动解决或跳过 | `--no-interrupt` |
## 当前参数
```
$ARGUMENTS
```
## 执行前准备
### 1. 解析参数
`$ARGUMENTS` 中解析:
- `GPU_ID`: 从 `--gpu <id>``-g <id>` 提取
- `NO_INTERRUPT`: 是否存在 `--no-interrupt``-n` 标志
### 2. 参数验证
**必须验证**:
- GPU_ID 必须是有效的数字
- 运行 `nvidia-smi -i <GPU_ID>` 验证 GPU 存在
### 3. 读取 task_plan.md
读取项目根目录下的 `task_plan.md` 文件,理解:
- 总体目标
- 分阶段计划 (Phase 1, 2, 3...)
- 文件修改清单
- 风险和注意事项
- 测试计划
## 执行流程
### Step 1: 创建执行计划
使用 TodoWrite 工具创建详细的执行计划,包括:
- 从 task_plan.md 提取的所有 Phase
- 每个 Phase 的子任务
- 测试验证步骤
### Step 2: 按 Phase 执行重构
对于 task_plan.md 中的每个 Phase
1. **读取当前代码**: 使用 Read/Grep 理解现有实现
2. **实施修改**: 使用 Edit/Write 进行代码修改
3. **验证修改**: 运行相关测试
### Step 3: 运行测试验证
执行 task_plan.md 中定义的测试计划,验证重构成功。
## GPU 限制规则
**严格限制**: 只能使用指定的 GPU所有涉及 GPU 的命令必须加 `CUDA_VISIBLE_DEVICES` 前缀:
```bash
# 正确
CUDA_VISIBLE_DEVICES=$GPU_ID PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python test.py
# 错误 - 禁止使用其他 GPU
python test.py # 可能使用默认 GPU 0
CUDA_VISIBLE_DEVICES=0,1 python test.py # 使用多个 GPU
```
## 中断模式规则
### 当 `--no-interrupt` 生效时
遇到以下情况**不停下来询问用户**,而是:
| 情况 | 处理方式 |
|------|----------|
| 测试失败 | 记录失败原因,尝试自动修复,继续下一步 |
| 代码冲突 | 尝试合理解决,记录解决方案 |
| 不确定的实现细节 | 选择最合理的方案继续 |
| 执行错误 | 分析错误,尝试修复,记录问题 |
**自动决策原则**:
1. 优先保证功能正确性
2. 遵循现有代码风格
3. 选择简单直接的实现
4. 记录所有自动决策到 `progress.md`
### 当未指定 `--no-interrupt` 时
遇到以下情况**可以询问用户**
- 多个实现方案需要选择
- 测试持续失败无法自动修复
- 发现 task_plan.md 中的问题或矛盾
## 执行记录
### 进度文件: progress.md
实时更新 `progress.md` 记录:
```markdown
## 执行进度
### Phase X: [名称]
- 状态: [进行中/完成/失败]
- 开始时间: [时间]
- 完成时间: [时间]
- 修改文件: [文件列表]
- 自动决策: [如果有]
- 问题记录: [如果有]
```
### 发现记录: findings.md
记录执行过程中的重要发现到 `findings.md`
## 示例用法
```bash
# 使用 GPU 2允许中断
/exec-plan --gpu 2
# 使用 GPU 0不中断执行
/exec-plan --gpu 0 --no-interrupt
# 简短形式
/exec-plan -g 1 -n
```
## 完成标准
执行完成后,确保:
1. **所有 Phase 完成**: task_plan.md 中的所有 Phase 都已实施
2. **测试通过**: task_plan.md 中的测试计划全部通过
3. **代码质量**: 修改符合项目代码规范
4. **文档更新**: progress.md 包含完整执行记录
## 重要约束
1. **GPU 隔离**: 绝对不能使用指定 GPU 以外的设备
2. **遵循计划**: 严格按照 task_plan.md 执行,不做计划外的修改
3. **渐进式修改**: 每个 Phase 完成后验证,而不是最后一起验证
4. **回滚准备**: 重大修改前考虑是否需要 git commit 保存点

View File

@@ -0,0 +1,50 @@
# Planning with Files Rule
## 自动清理旧计划文件
**重要**:每次开始新的复杂任务使用 planning-with-files 时,先删除旧的计划文件。
### 使用前执行以下命令
```bash
# 在项目根目录执行,删除旧的计划文件
cd /home/zijie/Code/nano-vllm
rm -f task_plan.md findings.md progress.md
rm -f task_plan_*.md findings_*.md progress_*.md
```
### 为什么需要这个规则
1. **避免混淆**:不同任务有不同计划,旧的计划文件会干扰新任务
2. **保持简洁**:只保留当前任务的计划文件
3. **自动清理**:无需手动检查文件内容,直接删除即可
### 使用 planning-with-files 的完整流程
```bash
# Step 1: 清理旧计划文件
rm -f task_plan.md findings.md progress.md task_plan_*.md findings_*.md progress_*.md
# Step 2: 启动 planning-with-files 技能
# 在 Claude 中调用 /planning-with-files 或 Skill tool
# Step 3: 技能会自动创建新的计划文件
# - task_plan.md (或 task_plan_<任务名>.md)
# - findings.md (或 findings_<任务名>.md)
# - progress.md (或 progress_<任务名>.md)
```
### 文件命名建议
| 场景 | 文件命名 | 示例 |
|------|----------|------|
| 通用任务 | task_plan.md, findings.md, progress.md | 临时调试任务 |
| 特定功能 | task_plan_<feature>.md | task_plan_xattn.md |
| Bug 修复 | task_plan_bug_<name>.md | task_plan_bug_offload.md |
### 注意事项
- 计划文件存储在**项目根目录**,不是技能目录
- 技能目录:`/home/zijie/.claude/plugins/cache/planning-with-files/...`
- 项目目录:`/home/zijie/Code/nano-vllm/`
- 每个任务完成后,可以选择保留或删除计划文件

70
.claude/settings.json Normal file
View File

@@ -0,0 +1,70 @@
{
"hooks": {
"SessionStart": [
{
"hooks": [
{
"type": "command",
"command": "npx @claude-flow/cli@latest daemon start --quiet 2>/dev/null || true",
"timeout": 5000,
"continueOnError": true
},
{
"type": "command",
"command": "[ -n \"$SESSION_ID\" ] && npx @claude-flow/cli@latest hooks session-restore --session-id \"$SESSION_ID\" 2>/dev/null || true",
"timeout": 10000,
"continueOnError": true
}
]
}
],
"Stop": [
{
"hooks": [
{
"type": "command",
"command": "echo '{\"ok\": true}'",
"timeout": 1000
}
]
}
],
"PermissionRequest": [
{
"matcher": "^mcp__claude-flow__.*$",
"hooks": [
{
"type": "command",
"command": "echo '{\"decision\": \"allow\", \"reason\": \"claude-flow MCP tool auto-approved\"}'",
"timeout": 1000
}
]
},
{
"matcher": "^Bash\\(npx @?claude-flow.*\\)$",
"hooks": [
{
"type": "command",
"command": "echo '{\"decision\": \"allow\", \"reason\": \"claude-flow CLI auto-approved\"}'",
"timeout": 1000
}
]
}
]
},
"permissions": {
"allow": [
"Bash(npx claude-flow*)",
"Bash(npx @claude-flow/*)",
"mcp__claude-flow__*"
],
"deny": []
},
"claudeFlow": {
"version": "3.0.0",
"enabled": true,
"daemon": {
"autoStart": true
}
}
}

4
.gitmodules vendored Normal file
View File

@@ -0,0 +1,4 @@
[submodule "3rdparty/Block-Sparse-Attention"]
path = 3rdparty/Block-Sparse-Attention
url = git@github.com:Zijie-Tian/Block-Sparse-Attention.git
branch = tzj/minference

View File

@@ -53,12 +53,17 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
| [`docs/multi_model_support.md`](docs/multi_model_support.md) | Model registry system, adding new models (Qwen3/Llama), architecture differences, RoPE scaling | | [`docs/multi_model_support.md`](docs/multi_model_support.md) | Model registry system, adding new models (Qwen3/Llama), architecture differences, RoPE scaling |
| [`docs/cuda_graph_offload_guide.md`](docs/cuda_graph_offload_guide.md) | CUDA graph support for CPU offload decode path, 4x decode speedup | | [`docs/cuda_graph_offload_guide.md`](docs/cuda_graph_offload_guide.md) | CUDA graph support for CPU offload decode path, 4x decode speedup |
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (MInference, FlexPrefill, XAttention, Quest), computation flow | | [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (MInference, FlexPrefill, XAttention, Quest), computation flow |
| [`docs/block_sparse_attention_lib.md`](docs/block_sparse_attention_lib.md) | MIT-Han-Lab Block-Sparse-Attention library reference: sparse modes, API, performance |
| [`docs/sparse_prefill_integration_plan.md`](docs/sparse_prefill_integration_plan.md) | Integration plan for MInference/XAttention/FlexPrefill with unified BlockMask interface | | [`docs/sparse_prefill_integration_plan.md`](docs/sparse_prefill_integration_plan.md) | Integration plan for MInference/XAttention/FlexPrefill with unified BlockMask interface |
| [`docs/sparse_offload_integration.md`](docs/sparse_offload_integration.md) | Sparse policy integration with layerwise offload, `requires_block_selection` interface design | | [`docs/sparse_offload_integration.md`](docs/sparse_offload_integration.md) | Sparse policy integration with layerwise offload, `requires_block_selection` interface design |
| [`docs/layerwise_offload_memory_analysis.md`](docs/layerwise_offload_memory_analysis.md) | Memory allocation analysis with theoretical formulas and empirical validation (< 5% error) | | [`docs/layerwise_offload_memory_analysis.md`](docs/layerwise_offload_memory_analysis.md) | Memory allocation analysis with theoretical formulas and empirical validation (< 5% error) |
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, tensor comparison, memory profiling | | [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, tensor comparison, memory profiling |
| [`docs/gpu_only_performance_issue.md`](docs/gpu_only_performance_issue.md) | GPU-only mode slower than offload due to PagedAttention scatter overhead, optimization proposals | | [`docs/gpu_only_performance_issue.md`](docs/gpu_only_performance_issue.md) | GPU-only mode slower than offload due to PagedAttention scatter overhead, optimization proposals |
| [`docs/offload_accuracy_issue.md`](docs/offload_accuracy_issue.md) | **BUG**: CPU offload mode 66% accuracy vs 100% non-offload on RULER NIAH benchmark | | [`docs/offload_accuracy_issue.md`](docs/offload_accuracy_issue.md) | **BUG**: CPU offload mode 66% accuracy vs 100% non-offload on RULER NIAH benchmark |
| [`docs/64k_memory_analysis.md`](docs/64k_memory_analysis.md) | 64k inference memory analysis: GPU-only vs offload, OOM root cause (fragmentation), RTX 3090 limitations |
| [`docs/xattention_integration.md`](docs/xattention_integration.md) | XAttention integration guide: algorithm, implementation, design decisions, and testing |
| [`docs/xattention_analysis.md`](docs/xattention_analysis.md) | XAttention algorithm analysis: chunked estimation, block sparse attention, integration design |
| [`docs/development_notes.md`](docs/development_notes.md) | Development notes and scratchpad for ongoing work |
## Configuration ## Configuration
@@ -69,7 +74,7 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
| `gpu_memory_utilization` | 0.9 | GPU memory fraction | | `gpu_memory_utilization` | 0.9 | GPU memory fraction |
| `enable_cpu_offload` | False | Enable for long context | | `enable_cpu_offload` | False | Enable for long context |
| `num_gpu_blocks` | 2 | GPU blocks for offload mode | | `num_gpu_blocks` | 2 | GPU blocks for offload mode |
| `num_kv_buffers` | 4 | Ring buffer size for decode pipeline | | `num_kv_buffers` | 4 | Ring buffer size (1-4), lower = less memory but slower decode |
| `enforce_eager` | False | Set True to disable CUDA graphs | | `enforce_eager` | False | Set True to disable CUDA graphs |
## Benchmarking ## Benchmarking
@@ -85,6 +90,7 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
- Qwen3-0.6B/4B: 40960 tokens - Qwen3-0.6B/4B: 40960 tokens
- Qwen2.5-7B-Instruct-1M: 1048576 tokens - Qwen2.5-7B-Instruct-1M: 1048576 tokens
- Llama-3.1-8B-Instruct: 131072 tokens - Llama-3.1-8B-Instruct: 131072 tokens
- **64k on RTX 3090/4090 (24GB)**: Requires CPU offload + optimizations, see [`docs/64k_memory_analysis.md`](docs/64k_memory_analysis.md)
**Performance (Qwen3-4B, CPU Offload)**: **Performance (Qwen3-4B, CPU Offload)**:
- Prefill: ~5700-8000 tok/s (varies by context length) - Prefill: ~5700-8000 tok/s (varies by context length)

View File

@@ -1,103 +0,0 @@
# Chunked Prefill Bug Debug Summary
## Problem
`test_needle.py --enable-offload --input-len 8192` fails with garbage output.
The model generates completely wrong tokens instead of the expected "7492".
## Investigation Progress
### 1. Stream Synchronization Fix (Completed)
- Replaced Triton `store_kvcache` kernel with pure PyTorch operations
- Moved `store_kvcache` to `compute_stream` in chunked prefill mode
- Added sync: `compute_stream.wait_event(offload_done)` after per-layer offload
- Added sync: `default_stream.wait_stream(compute_stream)` before return
### 2. KV Cache Alignment Verification (Completed)
Created alignment tests to compare K/V tensors between torch reference and nanovllm:
**RoPE Alignment:**
- RoPE implementations match perfectly (max_diff=0.002, cosine ~1.0)
- Confirmed RoPE is NOT the cause of the bug
**K/V Cache Alignment (Chunk 0):**
- Cosine similarity: ~1.0 for all layers
- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision)
- Mean diff: < 0.001
- **Conclusion: K/V cache offload is working correctly**
### 3. Layer Output Divergence Analysis (Completed)
Created per-chunk layer output comparison:
**Chunk 0 (tokens 0-4096):**
- All layers pass with excellent cosine similarity (0.999+)
- Max diff grows in later layers but within acceptable range
**Chunk 1 (tokens 4096-8192):**
- Layers 0-19: OK (cosine ~1.0)
- Layers 20-27: Diverge (cosine 0.83-0.96, max_diff up to 114)
- Divergence correlates with later transformer layers
### 4. Critical Discovery: Single-Chunk Offload Also Fails
**Key finding:** Even with input_len=2048 (single chunk, no chunked attention), the model produces garbage output with CPU offload enabled.
```
# Without offload: PASSES
python tests/test_needle.py --input-len 2048
# Output: "7492" (correct)
# With offload: FAILS
python tests/test_needle.py --enable-offload --input-len 2048
# Output: "The Ble White Th G Lopsiswin..." (garbage)
```
**This proves the bug is NOT in:**
- Chunked attention logic (merge_attention_outputs)
- Multi-chunk KV loading
- Ring buffer pipeline
**The bug IS in:**
- The decode path when CPU offload is enabled
- How prefilled KV is loaded/used during decode
### 5. Decode Path Analysis (In Progress)
The decode path in CPU offload mode:
1. Prefill writes KV to GPU, offloads to CPU
2. Decode loads prefilled KV from CPU via `_decode_ring_buffer_pipeline`
3. Attend to prefilled KV + accumulated decode tokens
4. Merge results
**Observations:**
- `prefilled_blocks` set is empty after decode (should contain block IDs)
- CPU cache has valid data (reasonable mean/std values)
- Decode buffer has zeros (decode tokens not being stored correctly?)
## Current Status
### Working
- Stream synchronization fixes
- K/V cache offload to CPU (verified alignment)
- RoPE implementation
- Chunked prefill attention for first chunk
### Not Working
- Decode with CPU offload (even for single-chunk inputs)
- Multi-chunk attention (divergence in later layers for chunk 1)
## Next Steps
1. Debug why `prefilled_blocks` is empty after decode
2. Check if decode path correctly loads KV from CPU
3. Verify decode buffer is being written correctly
4. Compare decode attention outputs between offload and non-offload modes
## Key Files
- `nanovllm/layers/attention.py` - Main attention implementation with chunked paths
- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine
- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks`
- `nanovllm/engine/model_runner.py` - Prefill/decode orchestration
## Hypothesis
The decode path fails because:
1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return empty
2. OR the decode attention is not correctly loading/using the prefilled KV from CPU
3. OR there's a stream synchronization issue specific to decode path

131
docs/64k_memory_analysis.md Normal file
View File

@@ -0,0 +1,131 @@
# 64k 推理内存分析
本文档分析 Llama 3.1 8B 模型在 64k 长度推理时的内存占用,以及 RTX 3090 (24GB) 上的 OOM 问题。
## 模型配置
```python
hidden_size = 4096
intermediate_size = 14336
num_layers = 32
num_heads = 32
num_kv_heads = 8
head_dim = 128
seq_len = 65536
dtype = bfloat16 (2 bytes)
```
## 理论内存占用
### GPU Only 模式
| 组件 | 计算公式 | 内存占用 |
|------|----------|----------|
| 模型权重 | 8.03B × 2 bytes | **16.06 GB** |
| KV Cache | 32 × 65536 × 8 × 128 × 2 × 2 | **8.19 GB** |
| Prefill 激活值峰值 | max(QKV, MLP) | **~2 GB** |
| **总计** | | **~26 GB** |
**结论**GPU only 模式需要 ~26 GB**RTX 3090 (24GB) 无法运行**。
### CPU Offload 模式
| 组件 | 计算公式 | 内存占用 |
|------|----------|----------|
| 模型权重 | 8.03B × 2 bytes | **16.06 GB** |
| Ring buffer | num_kv_buffers × seq_len × 128 KB/token | 258-1034 MB |
| GPU KV blocks | num_gpu_blocks × block_size × 128 KB/token | 256 MB (2 blocks) |
| Per-layer decode buffer | 32 layers × 缓冲 | 128 MB |
| 激活值峰值 (chunked) | chunk_size × hidden_size × 2 | ~50 MB |
| PyTorch 开销 | CUDA 上下文 + 碎片 | ~5-6 GB |
| **理论小计** | | **~17.5 GB** |
| **实际需求** | | **~23 GB** |
**配置参数**
- `num_kv_buffers`: Ring buffer 大小 (1-4),默认 4
- `num_gpu_blocks`: GPU 上的 KV cache block 数量
- `block_size`: 每个 block 的 token 数
## OOM 问题分析
### 实际观测RTX 3090, num_kv_buffers=1
```
PyTorch allocated: 22.49 GB
PyTorch reserved: 429 MB
Free: 306 MB
Total available: 735 MB
Failed to allocate: 508 MB (torch.cat)
```
### 内存碎片来源
| 来源 | 说明 | 影响 |
|------|------|------|
| Binned 分配器 | PyTorch 使用固定大小的内存池 | 中等 |
| torch.compile 缓存 | 编译后的 kernel 代码和常量 | 高 (~2-3 GB) |
| 频繁分配/释放 | chunked 处理中每个 chunk 的创建销毁 | 高 |
| 不同大小张量 | (128,4096), (65536,6144) 等 | 中等 |
### torch.cat 内存需求
Chunked MLP 处理chunk_size=128
```
65536 / 128 = 512 chunks
每个 chunk 输出: (128, 4096) × 2 bytes = 1 MB
torch.cat 拼接需要: (65536, 4096) × 2 bytes = 508 MB (连续)
```
## 已尝试的优化
| 优化项 | 效果 |
|--------|------|
| 移除 `@torch.compile` | PyTorch: 23.13 → 22.80 GB (-300 MB) |
| 减少 `num_kv_buffers` (4→1) | Ring buffer: 1034 → 258 MB (-776 MB) |
| Chunked QKV/MLP/LayerNorm | 峰值激活: ~2 GB → ~50 MB |
| 降低 GPU 利用率 (0.9→0.75) | 无明显效果 |
| 减小 chunk_size (4096→128) | 峰值降低,但 torch.cat 需要连续内存 |
### 最终状态
```
理论需求: ~17.5 GB
实际分配: 22.49 GB
剩余空间: 735 MB (306 MB + 429 MB reserved)
分配失败: 508 MB (torch.cat 需要连续内存)
```
## 结论
### 根本原因
**不是绝对内存不足,而是内存碎片导致的分配失败**
理论需求 17.5 GB < 24 GB但由于
- PyTorch 开销CUDA 上下文、碎片):~5-6 GB
- torch.compile 缓存:~2-3 GB已移除
- 内存碎片导致无法分配 508 MB 连续块
### 硬件限制
| GPU | 显存 | 64k GPU Only | 64k Offload |
|-----|------|--------------|--------------|
| RTX 3090 | 24 GB | ❌ | ⚠️ 碎片问题 |
| RTX 4090 | 24 GB | ❌ | ⚠️ 碎片问题 |
| A100 | 40 GB | ✅ | ✅ |
| A100 | 80 GB | ✅ | ✅ |
### 建议
1. **64k 推理建议使用 40GB+ 显存的 GPU**
2. RTX 3090/4090 适合 32k 或更短的场景
3. 如必须在 24GB GPU 上运行 64k
- 使用 RAPIDS RMM 分配器
- 预分配 torch.cat 需要的内存
- 或使用流式处理避免 torch.cat
## 参考
- [PyTorch 内存管理文档](https://docs.pytorch.org/docs/stable/generated/torch.cuda.memory.memory_stats.html)
- [PyTorch 内存碎片讨论](https://discuss.pytorch.org/t/how-to-reduce-memory-fragmentation-when-enable-expandable-segments/221805)
- [STWeaver - 减少 79% 内存碎片](https://arxiv.org/html/2507.16274v1)

View File

@@ -0,0 +1,161 @@
# 64K Prefill MLP Activation OOM Issue
## Problem Summary
When running RULER benchmark with 64K context length using CPU offload mode, OOM occurs during MLP forward pass in `run_layerwise_offload_prefill`. The KV cache is successfully offloaded to CPU, but MLP intermediate activations exceed available GPU memory.
## Environment
- GPU: RTX 3090 (24GB)
- Model: LLaMA 3.1 8B
- Sequence Length: 65536 tokens
- Mode: `enable_cpu_offload=True`, `num_gpu_blocks=2`
## Error Message
```
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.47 GiB.
GPU 0 has a total capacity of 23.57 GiB of which 2.66 GiB is free.
Including non-PyTorch memory, this process has 20.88 GiB memory in use.
Of the allocated memory 20.51 GiB is allocated by PyTorch, and 32.26 MiB
is reserved by PyTorch but unallocated.
```
## Stack Trace
```
File "nanovllm/engine/model_runner.py", line 843, in run_layerwise_offload_prefill
hidden_states = layer.mlp(hidden_states)
File "nanovllm/models/llama.py", line 103, in forward
gate_up = self.gate_up_proj(x)
File "nanovllm/layers/linear.py", line 73, in forward
return F.linear(x, self.weight, self.bias)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.47 GiB.
```
## Root Cause Analysis
### Memory Breakdown
| Component | Calculation | Size |
|-----------|-------------|------|
| Model weights (BF16) | 8B params × 2 bytes | ~16 GB |
| GPU KV cache | 2 blocks × 1024 tokens × 8KB/token | ~16 MB |
| **Remaining for activations** | 24 - 16 - overhead | **~6-7 GB** |
### MLP Activation Memory (per layer)
For LLaMA 3.1 8B with `hidden_size=4096`, `intermediate_size=14336`:
| Tensor | Shape | Size (BF16) |
|--------|-------|-------------|
| MLP input | [65536, 4096] | 512 MB |
| gate_up output | [65536, 28672] | **3.47 GB** |
| down_proj input | [65536, 14336] | 1.75 GB |
| MLP output | [65536, 4096] | 512 MB |
**Peak MLP memory**: ~3.5-4 GB for intermediate tensors
### Why OOM Occurs
1. Model weights consume ~16 GB (loaded on GPU for layer-wise processing)
2. Available memory: ~7 GB
3. MLP `gate_up_proj` output: 3.47 GB
4. Additional tensors (input, gradients, etc.): ~1-2 GB
5. **Total required > Available** → OOM
## Code Location
The issue is in `nanovllm/engine/model_runner.py`:
```python
# Line 843 in run_layerwise_offload_prefill
hidden_states = layer.mlp(hidden_states) # <-- OOM here
```
The entire sequence (65536 tokens) is passed through MLP in one shot.
## Current Configuration
From `model_wrappers.py` (RULER integration):
```python
llm_kwargs = {
"max_model_len": max_model_len, # 128 * 1024
"max_num_batched_tokens": max_model_len, # Same as max_model_len
"enable_cpu_offload": True,
"num_gpu_blocks": 2,
...
}
```
Setting `max_num_batched_tokens = max_model_len` causes nanovllm to process all tokens at once.
## Potential Solutions
### Option 1: Chunked MLP Processing
Modify `run_layerwise_offload_prefill` to process MLP in chunks:
```python
# Instead of:
hidden_states = layer.mlp(hidden_states)
# Do:
chunk_size = 8192 # Process 8K tokens at a time
chunks = hidden_states.split(chunk_size, dim=0)
outputs = []
for chunk in chunks:
outputs.append(layer.mlp(chunk))
hidden_states = torch.cat(outputs, dim=0)
```
### Option 2: Activation Checkpointing
Use gradient checkpointing to recompute activations instead of storing them:
```python
from torch.utils.checkpoint import checkpoint
hidden_states = checkpoint(layer.mlp, hidden_states, use_reentrant=False)
```
### Option 3: Reduce Chunk Size via Config
Add a new config parameter `prefill_chunk_size` to control how many tokens are processed per forward pass.
## Memory Estimation Formula
For a given sequence length `S` and model config:
```
MLP_peak_memory = S × intermediate_size × 2 × 2 bytes
= S × 14336 × 4 bytes
For S = 65536:
MLP_peak = 65536 × 14336 × 4 = 3.76 GB
```
Maximum safe sequence length for RTX 3090 (24GB):
```
S_max = available_memory / (intermediate_size × 4)
= 6GB / (14336 × 4)
≈ 100K tokens (theoretical)
≈ 8-16K tokens (practical, with safety margin)
```
## Reproduction Steps
```bash
cd /home/zijie/Code/COMPASS/eval/RULER/scripts
# Set SEQ_LENGTHS to 65536 in config_models.sh
# Then run:
./run.sh llama3.1-8b-nanovllm synthetic --metric full --task niah_single_1
```
## Related Files
- `nanovllm/engine/model_runner.py`: `run_layerwise_offload_prefill()` (line 751+)
- `nanovllm/models/llama.py`: `LlamaMLP.forward()` (line 103)
- `nanovllm/config.py`: Config parameters
- RULER integration: `eval/RULER/scripts/pred/model_wrappers.py`

View File

@@ -0,0 +1,191 @@
# Block-Sparse-Attention Library Reference
MIT Han Lab 的块稀疏注意力内核库,基于 FlashAttention 2.4.2 修改,支持多种稀疏注意力模式。
## 库信息
- **来源**: [MIT-Han-Lab/Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention)
- **本地路径**: `3rdparty/Block-Sparse-Attention` (submodule, branch: `tzj/minference`)
- **基于**: FlashAttention 2.4.2
- **安装位置**: `site-packages/block_sparse_attn`
## 支持的稀疏模式
### 1. Dense Attention
计算完整注意力矩阵,无稀疏化。
### 2. Token Streaming (token granularity)
固定数量的 sink tokens + local tokens参考 [StreamingLLM](https://arxiv.org/abs/2309.17453)。
**适用场景**: 需要精确保留部分关键 token 的长上下文推理
### 3. Block Streaming (block granularity)
Block 粒度的 streaming attentionblock_size = 128。
**适用场景**: 长序列推理,牺牲少量精度换取更大加速
### 4. Block Sparse
基于自定义 block mask 的稀疏注意力。
**适用场景**: 已知特定 attention 模式的工作负载
### 混合模式
**关键特性**: 支持不同 head 使用不同稀疏模式
```python
# 8 个 heads 的混合配置示例
head_mask_type = [1, 1, 0, 0, 0, -1, 0, -1]
# 含义:
# - head 0,1: blocksparse (使用 basemask[0])
# - head 2-4,6: dense
# - head 5,7: streaming
```
**Mask 类型编码**:
- `0` = Dense attention
- `-1` = Streaming attention
- `1, 2, ...` = Block sparse (使用 basemask[mask_type - 1])
## API 参考
### `block_sparse_attn_func`
通用块稀疏注意力函数,支持所有模式。
```python
from block_sparse_attn import block_sparse_attn_func
output = block_sparse_attn_func(
q, k, v, # [total_tokens, heads, head_dim] unpadded
cu_seqlens_q, cu_seqlens_k, # cumulative sequence lengths
head_mask_type, # [heads] tensor, 每个头的模式
streaming_info, # streaming 配置 (sink/local 数量)
base_blockmask, # [q_blocks, k_blocks, n_masks] bool tensor
max_seqlen_q, max_seqlen_k, # 最大序列长度
p_dropout, # dropout 概率 (推理时设为 0.0)
deterministic=False,
softmax_scale=None,
is_causal=False,
exact_streaming=False, # True=token streaming, False=block streaming
return_attn_probs=False,
)
```
**关键参数**:
| 参数 | 类型 | 说明 |
|------|------|------|
| `head_mask_type` | Tensor[heads] | 每个头的稀疏模式0=dense, -1=streaming, 1+=blocksparse |
| `streaming_info` | Tensor | [sink_blocks, local_blocks] 或 [sink_tokens, local_tokens] |
| `base_blockmask` | Tensor | Block mask形状 [q_blocks, k_blocks, n_masks] |
| `exact_streaming` | bool | True=token 粒度False=block 粒度 streaming |
### `block_streaming_attn_func`
Block 粒度 streaming attentionblock_size=128
```python
from block_sparse_attn import block_streaming_attn_func
output = block_streaming_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info, # [sink_blocks, local_blocks]
max_seqlen_q, max_seqlen_k,
p_dropout,
deterministic=False,
softmax_scale=None,
is_causal=True,
return_attn_probs=False,
)
```
### `token_streaming_attn_func`
Token 粒度 streaming attention。
**注意**: 不支持反向传播(仅推理)。
```python
from block_sparse_attn import token_streaming_attn_func
output = token_streaming_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info, # [sink_tokens, local_tokens]
max_seqlen_q, max_seqlen_k,
deterministic=False,
softmax_scale=None,
return_attn_probs=False,
)
```
## 技术规格
| 特性 | 支持情况 |
|------|----------|
| **数据类型** | fp16, bf16 (bf16 需要 Ampere/Ada/Hopper GPU) |
| **Head 维度** | 32, 64, 128 |
| **Block Size** | 128 (固定) |
| **CUDA 要求** | 11.6+ |
| **PyTorch 要求** | 1.12+ |
## 性能参考
测试环境: A100 GPU, head_dim=128, 32 heads, batch_size=1
### Block Sparse 加速比
- 相比 FlashAttention2: 最高 **3-4x** 加速
- 加速随序列长度增加而提升
### Streaming 混合模式加速比
- Token streaming: 64 sink + 256 local tokens
- Block streaming: 1 sink block + 3 local blocks
- **50% Dense + 50% Streaming**: 最高 **2x** 加速
## 与 nano-vllm 的集成考虑
### 潜在集成点
1. **长上下文推理优化**
- 使用 block streaming 减少计算量
- 在 CPU offload 模式下减少 GPU-CPU 传输
2. **混合注意力策略**
- 部分 head 使用 streaming减少计算
- 部分 head 使用 dense保持精度
- 参考 Duo Attention 论文的混合模式
3. **稀疏 offload**
- 只 offload 重要 blocks 的 KV cache
- 结合 `requires_block_selection` 接口
### 实现注意事项
1. **输入格式**: 库使用 unpadded 格式(`cu_seqlens`),需要与 nano-vllm 的 padded 格式转换
2. **Block size 固定**: 库固定 block_size=128需要适配
3. **Streaming info 配置**: 需要根据模型特性调整 sink/local 数量
## 相关工作
- [FlashAttention](https://github.com/Dao-AILab/flash-attention) - 基础实现
- [StreamingLLM](https://arxiv.org/abs/2309.17453) - Streaming attention 理论基础
- [Duo Attention](https://github.com/mit-han-lab/duo-attention) - 混合 dense/streaming 模式
- [MInference](https://arxiv.org/abs/2407.02490) - 混合 mask 方法
## 测试
库自带测试位于 `3rdparty/Block-Sparse-Attention/block_sparse_tests/`:
```bash
# 正确性测试
cd 3rdparty/Block-Sparse-Attention/block_sparse_tests/fwd/test_correctness
pytest full_test.py
# 性能测试
cd 3rdparty/Block-Sparse-Attention/block_sparse_tests/fwd/test_performance
python token_streaming.py
python blocksparse.py
```

324
docs/development_notes.md Normal file
View File

@@ -0,0 +1,324 @@
# Notes: Sparsity Integration into Layerwise Offload
## Current Architecture Analysis
### GPU-Only Path vs Offload Path
| Aspect | GPU-Only | Layerwise Offload |
|--------|----------|-------------------|
| KV Storage | GPU blocks (paged) | CPU pinned + GPU ring buffer |
| Prefill | All layers → then attention | Per-layer: attention → offload |
| Decode | FlashAttn with block table | Ring buffer H2D → FlashAttn |
| Sparse Support | MInference via `attention.py` | Not integrated |
### MInference Flow (GPU-Only)
```
attention.py:101-105:
if context.sparse_prefill_policy is not None:
o = context.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
minference.py:sparse_prefill_attention():
1. estimate_pattern(q, k, layer_id) -> vertical_indices, slash_indices
2. _triton_mixed_sparse_attention(q, k, v, indices)
3. return output
```
### Quest Flow (GPU Block Mode)
```
hybrid_manager.py (if using CPU offload with Quest):
select_blocks(available_blocks, ctx) -> selected block IDs
-> load selected blocks to GPU
-> standard FlashAttn with loaded blocks
```
### Layerwise Offload Prefill Flow
```
model_runner.py:run_layerwise_offload_prefill():
for layer_id in range(num_layers):
# QKV projection
q, k, v = qkv_proj(hidden_ln)
# RoPE
q, k = rotary_emb(positions, q, k)
# FULL attention (no sparsity!)
attn_output = flash_attn_varlen_func(q, k, v, ...)
# MLP
hidden_states = mlp(attn_out + residual)
# Sync offload ALL k, v to CPU
for block_id in cpu_block_ids:
k_cache_cpu[layer_id, block_id].copy_(k[start:end])
v_cache_cpu[layer_id, block_id].copy_(v[start:end])
```
### Layerwise Offload Decode Flow
```
model_runner.py:run_layerwise_offload_decode():
# Preload first N layers to ring buffer
for i in range(num_buffers):
offload_engine.load_layer_kv_to_buffer(i, i, cpu_block_table, valid_tokens)
for layer_id in range(num_layers):
current_buffer = layer_id % num_buffers
# Wait for buffer load
offload_engine.wait_buffer_load(current_buffer)
# Get prefilled KV from ring buffer (ALL blocks loaded)
k_prefill, v_prefill = offload_engine.get_buffer_kv(current_buffer, total_prefill_tokens)
# QKV for new token
q, k_new, v_new = qkv_proj(hidden_ln)
# Concat and full attention
k_full = torch.cat([k_prefill, k_decode_prev, k_new])
attn_output = flash_attn_varlen_func(q, k_full, v_full, ...)
# Start loading next layer
offload_engine.load_layer_kv_to_buffer(current_buffer, layer_id + num_buffers, ...)
```
## Integration Points
### 1. Prefill Sparse Integration Point
**Location:** `model_runner.py:535-543`
**Current:**
```python
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
```
**After Integration:**
```python
if self.sparse_policy and self.sparse_policy.supports_offload_prefill:
attn_output, k_sparse, v_sparse = self.sparse_policy.offload_prefill_attention(
q, k, v, layer_id
)
k_to_offload = k_sparse if k_sparse is not None else k
v_to_offload = v_sparse if v_sparse is not None else v
else:
attn_output = flash_attn_varlen_func(q, k, v, ...)
k_to_offload, v_to_offload = k, v
```
### 2. Decode Sparse Integration Point
**Location:** `model_runner.py:636-637` and `model_runner.py:704-706`
**Current (preload):**
```python
for i in range(num_preload):
offload_engine.load_layer_kv_to_buffer(
i, i, cpu_block_table, valid_tokens_per_block
)
```
**After Integration:**
```python
for i in range(num_preload):
layer_to_load = i
if self.sparse_policy and self.sparse_policy.supports_offload_decode:
# Prepare q for this layer (need to compute ahead)
# OR: use previous layer's pattern as estimate
selected_blocks = self.sparse_policy.select_offload_blocks(
None, # q not available yet at preload
layer_to_load,
cpu_block_table,
valid_tokens_per_block
)
else:
selected_blocks = cpu_block_table
offload_engine.load_sparse_layer_kv_to_buffer(
i, layer_to_load, selected_blocks, valid_tokens_per_block
)
```
**Challenge:** Q is not available during preload phase!
**Solutions:**
1. Skip sparse preload, only sparse for non-preloaded layers
2. Use previous decode step's pattern as estimate
3. Add preload hook to sparse policy
### 3. Offload Engine Extension
**New Method in OffloadEngine:**
```python
def load_sparse_layer_kv_to_buffer(
self,
buffer_idx: int,
layer_id: int,
selected_cpu_block_ids: List[int],
original_valid_tokens: List[int],
) -> int:
"""
Load only selected blocks from CPU to buffer.
Returns:
Total tokens loaded (may be less than full sequence)
"""
stream = self.layer_load_streams[buffer_idx]
with torch.cuda.stream(stream):
stream.wait_event(self.buffer_compute_done_events[buffer_idx])
# Build mapping: original block -> selected position
offset = 0
for i, cpu_block_id in enumerate(selected_cpu_block_ids):
# Find original index to get valid tokens
valid_tokens = original_valid_tokens[i] # Need mapping
self.layer_k_cache[buffer_idx, offset:offset+valid_tokens].copy_(
self.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens],
non_blocking=True
)
# ... v_cache same
offset += valid_tokens
self.buffer_load_events[buffer_idx].record(stream)
return offset # Caller needs to know actual loaded tokens
```
## Metadata Flow for Quest
### During Prefill Offload
**Current:** No metadata collection in offload path
**Required:** Call `on_prefill_offload()` for each block
```python
# In run_layerwise_offload_prefill()
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * block_size
end = min(start + block_size, total_tokens)
actual_size = end - start
# BEFORE offload: update Quest metadata
if self.sparse_policy and hasattr(self.sparse_policy, 'on_prefill_offload'):
self.sparse_policy.on_prefill_offload(
cpu_block_id, layer_id, k[start:end], actual_size
)
# Offload
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
```
### Quest Metadata Shape
```python
# BlockMetadataManager
key_min: [num_blocks, num_layers, num_kv_heads, head_dim] # Min key per block per layer
key_max: [num_blocks, num_layers, num_kv_heads, head_dim] # Max key per block per layer
```
**Memory:** 2 * num_blocks * num_layers * kv_heads * head_dim * 2 bytes
- Example: 1000 blocks * 28 layers * 4 heads * 128 dim * 2 * 2 = ~57 MB
## Performance Considerations
### MInference Prefill Overhead
| Operation | Time (64K seq) |
|-----------|----------------|
| Pattern estimation (last-64) | ~5ms |
| Triton sparse attention | ~80ms |
| Full FlashAttention | ~100ms |
| **Net Speedup** | ~15-20% |
### Quest Decode Overhead
| Operation | Time |
|-----------|------|
| Block scoring (GPU metadata) | ~0.1ms |
| Top-K selection | ~0.05ms |
| Sparse H2D load (8 blocks) | ~2ms |
| Full H2D load (100 blocks) | ~20ms |
| **Net Speedup** | ~10x H2D |
### Memory Trade-offs
| Mode | GPU Memory | CPU Memory | H2D Bandwidth |
|------|------------|------------|---------------|
| Full offload | Ring buffer | Full KV | High |
| Sparse offload | Ring buffer | Full KV | Low (subset) |
| Aggressive sparse | Ring buffer | Sparse KV | Very low |
## Edge Cases
### 1. Short Sequences (< sparse threshold)
```python
if total_tokens < sparse_threshold:
# Fall back to full attention
use_sparse = False
```
### 2. First Decode Step (no previous Q)
Quest can't score blocks without Q. Options:
- Use average embedding as proxy
- Load all blocks for first step
- Use prefill pattern as estimate
### 3. Variable Sequence Lengths in Batch
Layerwise offload currently only supports batch_size=1:
```python
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
```
Sparse integration should maintain this constraint.
### 4. Ring Buffer vs Sparse Load Mismatch
Ring buffer assumes fixed `total_prefill_tokens`:
```python
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, total_prefill_tokens)
```
Sparse load has variable token count. Need:
```python
# Track actual loaded tokens per buffer
loaded_tokens[buffer_idx] = sparse_load_count
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, loaded_tokens[buffer_idx])
```
## Testing Strategy
### Unit Tests
1. `test_sparse_policy_interface.py` - Verify new interface methods
2. `test_minference_offload.py` - MInference in offload mode
3. `test_quest_offload.py` - Quest block selection in offload mode
### Integration Tests
1. `test_offload_sparse_e2e.py` - Full prefill+decode with sparsity
2. `test_accuracy_comparison.py` - Compare outputs: full vs sparse
### Benchmarks
1. `bench_offload_sparse.py` - Compare:
- Full offload (baseline)
- MInference prefill + Quest decode
- Aggressive sparse offload

597
docs/xattention_analysis.md Normal file
View File

@@ -0,0 +1,597 @@
# COMPASS XAttention Implementation Analysis
**Analysis Date**: 2026-01-14
**Researcher**: Claude Code Agent
**Source**: `/home/zijie/Code/COMPASS/compass/src/`
---
## Executive Summary
COMPASS XAttention is a **block sparse attention** implementation that uses:
1. **Approximation phase** (`xattn_estimate`) to compute attention importance and select blocks
2. **Computation phase** (`Xattention_prefill`) to compute sparse attention using `block_sparse_attn_func`
3. **Triton kernels** for efficient block-wise GEMM and softmax operations
**Key Integration Constraint**: Requires `block_sparse_attn_func` from flash-attention library, which is a **C++ CUDA extension** that must be compiled separately.
---
## 1. Function: `xattn_estimate()`
**Purpose**: Estimate attention importance and select which blocks to compute
### Input Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `query_states` | Tensor | - | Shape: `(batch, num_heads, q_len, head_dim)` |
| `key_states` | Tensor | - | Shape: `(batch, num_kv_heads, k_len, head_dim)` |
| `block_size` | int | - | Size of attention blocks (typically 128) |
| `stride` | int | - | Downsampling stride for approximation |
| `norm` | float | 1 | Normalization factor for attention scaling |
| `softmax` | bool | True | Whether to apply softmax in estimation |
| `threshold` | float | 0.9 | Block selection threshold (0-1) |
| `chunk_size` | int | 16384 | Processing chunk size |
| `select_mode` | str | "inverse" | Pattern selection mode |
| `use_triton` | bool | True | Use Triton kernels (requires SM 80+) |
| `causal` | bool | True | Apply causal masking |
| `kdb` | int | 1 | Key downsampling factor |
| `keep_sink` | bool | False | Always attend to first token |
| `keep_recent` | bool | False | Always attend to recent tokens |
### Output
```python
returns: (attn_sums, simple_masks)
attn_sums: Tensor[float32]
Shape: (batch, num_heads, num_q_blocks, num_k_blocks_per_chunk)
Contains aggregated attention weights per block
simple_masks: Tensor[bool]
Shape: (batch, num_heads, num_q_blocks, num_k_blocks)
Boolean mask indicating which blocks to compute
```
### Algorithm
#### Step 1: Padding and Chunking
```python
# Pad sequences to chunk_size boundaries
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
# Compute number of blocks and chunks
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
k_block_num = (k_len + k_num_to_pad) // block_size
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
q_block_num = (q_len + q_num_to_pad) // block_size
```
#### Step 2: Pattern Selection (stride-based downsampling)
**Purpose**: Reduce computation by `stride` factor using patterned selection
**Modes**:
1. **`"inverse"`** (default): Inverse stride pattern
```python
# Key: regular stride [0, stride, 2*stride, ...]
# Query: reverse stride [(stride-1), (stride-1-stride), ...]
reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)])
reshaped_query = torch.cat([query_states[:, :, (stride-1-q)::stride*kdb, :] for q in range(stride)])
```
2. **`"slash"`**: Slash pattern (diagonal)
```python
# Both use regular stride
reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)])
reshaped_query = torch.cat([query_states[:, :, q::stride, :] for q in range(stride)])
```
3. **`"random"`**: Random permutation
4. **`"double"`, `"triple"`**: Data augmentation modes
#### Step 3: Chunk-wise Attention Estimation
For each query chunk:
**If `use_triton=True`** (fast path):
```python
# Triton kernel 1: Compute attention scores with fused reshape
attn_weights_slice = flat_group_gemm_fuse_reshape(
query_chunk, key_states, stride,
chunk_start, chunk_end, is_causal=causal
)
# Triton kernel 2: Softmax + block aggregation
attn_sum = softmax_fuse_block_sum(
attn_weights_slice, reshaped_block_size, segment_size,
chunk_start, chunk_end, real_q_len, scale, is_causal
)
```
**If `use_triton=False`** (PyTorch fallback):
```python
# Standard matrix multiplication
attn_weights_slice = torch.matmul(chunked_query, reshaped_key.transpose(2, 3))
# Scale and apply causal mask
attn_weights_slice = attn_weights_slice / sqrt(head_dim) / stride / norm
attn_weights_slice = attn_weights_slice + causal_mask
# Softmax
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1)
# Aggregate to block level
attn_sum = attn_weights_slice.view(
batch, heads, num_blocks_per_chunk, block_size//kdb, -1, block_size
).sum(dim=-1).sum(dim=-2)
```
#### Step 4: Block Selection
```python
# Select blocks based on threshold
simple_mask = find_blocks_chunked(
attn_sum,
current_index, # Starting block index
threshold, # 0.9 = select blocks covering 90% of attention mass
None, # or num_to_choose for top-k selection
decoding=False,
mode="prefill",
causal=True
)
```
**Selection Algorithm** (`find_blocks_chunked`):
1. Sort blocks by attention weight (descending)
2. Compute cumulative sum
3. Select blocks until `cumulative_sum >= total_sum * threshold`
4. Enforce causal constraints (no future blocks)
5. Always include sink token (first block) if `keep_sink=True`
6. Always include diagonal blocks if `keep_recent=True`
---
## 2. Function: `Xattention_prefill()`
**Purpose**: Compute sparse attention using estimated block mask
### Input Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `query_states` | Tensor | - | `(batch, num_heads, q_len, head_dim)` |
| `key_states` | Tensor | - | `(batch, num_heads, k_len, head_dim)` |
| `value_states` | Tensor | - | `(batch, num_heads, k_len, head_dim)` |
| `stride` | int | - | Downsampling stride for estimation |
| `norm` | float | 1 | Normalization factor |
| `threshold` | float | 0.8 | Block selection threshold |
| `block_size` | int | 128 | **MUST be 128** (hardcoded requirement) |
| `use_triton` | bool | True | Use Triton kernels in estimation |
| `causal` | bool | True | Apply causal masking |
| `kdb` | int | 1 | Key downsampling factor |
| `chunk_size` | int | None | Auto-computed if None |
| `keep_sink` | bool | False | Always attend to first token |
| `keep_recent` | bool | False | Always attend to recent tokens |
### Output
```python
returns: attn_output
attn_output: Tensor
Shape: (batch, num_heads, q_len, head_dim)
Sparse attention output
```
### Algorithm Flow
#### Step 1: Auto-compute chunk_size
```python
if chunk_size is None:
chunk_size = int(max(
min(
max(2048, 1 << (k_len - 1).bit_length()), # Round to power of 2
128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()), # Memory constraint
),
2048, # Minimum
))
```
**Example**:
- `k_len=8192` → `chunk_size=8192`
- `k_len=32768` → `chunk_size=16384`
- `k_len=65536` → `chunk_size=16384`
#### Step 2: Estimate attention and select blocks
```python
attn_sums, approx_simple_mask = xattn_estimate(
query_states, key_states,
block_size=block_size, stride=stride, norm=norm,
threshold=threshold, select_mode="inverse",
use_triton=use_triton, causal=causal,
chunk_size=chunk_size, kdb=kdb,
keep_sink=keep_sink, keep_recent=keep_recent
)
```
#### Step 3: Prepare inputs for block_sparse_attn_func
```python
# Hard constraints
assert block_size == 128
assert batch_size == 1
# Reshape to (seq_len, num_heads, head_dim)
query_states = query_states.transpose(1, 2).view(q_len, num_heads, head_dim)
key_states = key_states.transpose(1, 2).view(k_len, num_heads, head_dim)
value_states = value_states.transpose(1, 2).view(k_len, num_heads, head_dim)
# Cumulative sequence lengths
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)
# Head mask type (all heads use mask)
head_mask_type = torch.tensor([1 for _ in range(num_heads)], dtype=torch.int32)
```
#### Step 4: Call block_sparse_attn_func
```python
attn_output = block_sparse_attn_func(
query_states, # (q_len, num_heads, head_dim)
key_states, # (k_len, num_heads, head_dim)
value_states, # (k_len, num_heads, head_dim)
q_cu_seq_lens, # [0, q_len]
k_cu_seq_lens, # [0, k_len]
head_mask_type, # [1, 1, ..., 1]
None, # No custom layout
approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(), # Block mask
q_len,
k_len,
p_dropout=0.0,
deterministic=True,
is_causal=causal
)
```
#### Step 5: Reshape output
```python
attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
# Output shape: (batch, num_heads, q_len, head_dim)
```
---
## 3. Triton Kernel Dependencies
### Kernel 1: `flat_group_gemm_fuse_reshape_kernel`
**Purpose**: Compute QK^T with stride-based reshaping
**Key Features**:
- Loads `stride` keys and queries at once
- Fused strided access pattern
- Causal masking support
- Block size auto-selection based on GPU memory
**Block Size Selection**:
```python
# RTX 3090 (<30GB): BLOCK_M=64, BLOCK_N=64
# A100/H100 (>=30GB): BLOCK_M=128, BLOCK_N=128
```
**Signature**:
```python
flat_group_gemm_fuse_reshape(
query_states, # (batch, heads, q_len, head_dim)
key_states, # (batch, heads, k_len, head_dim)
stride, # Downsampling factor
chunk_start, # Start position in keys
chunk_end, # End position in keys
is_causal=True
)
# Returns: (batch, heads, q_len//stride, k_len//stride)
```
### Kernel 2: `softmax_fuse_block_sum_kernel_causal` / `_non_causal`
**Purpose**: Online softmax with block aggregation
**Algorithm**:
1. **Forward pass** (compute m_i, l_i):
```
m_i = max(m_i, m_local)
alpha = exp(m_i - m_new)
l_i = l_i * alpha + sum(exp(X - m_new))
```
2. **Backward pass** (compute softmax with scaling):
```
softmax = exp(X - m_i) / l_i
aggregate to blocks: sum(softmax) over block_size
```
**Key Features**:
- Single-pass softmax (no materializing full attention matrix)
- Causal masking integrated
- Outputs block-level sums directly
**Signature**:
```python
softmax_fuse_block_sum(
attn_weights_slice, # (batch, heads, q_len, k_len)
reshaped_block_size, # Block size (128//stride)
segment_size, # Processing segment (min(4096, block_size))
chunk_start, # Start position
chunk_end, # End position
real_q_len, # Actual query length (before padding)
scale, # 1.4426950408889634 / sqrt(head_dim) / stride / norm
is_causal=True
)
# Returns: (batch, heads, q_len//block_size, k_len//block_size)
```
---
## 4. Key Parameters and Their Meanings
### Critical Parameters
| Parameter | Meaning | Typical Value | Impact |
|-----------|---------|---------------|--------|
| `block_size` | Block granularity | 128 | **Fixed at 128**, affects mask granularity |
| `stride` | Downsampling factor | 4-16 | Higher = faster but less accurate |
| `threshold` | Sparsity level | 0.8-0.9 | Higher = denser mask, more computation |
| `chunk_size` | Processing chunk | 16384 | Affects memory and efficiency |
| `kdb` | Key downsampling boost | 1 | Experimental, use 1 |
| `norm` | Scaling factor | 1.0 | Attention temperature control |
### Trade-offs
**Stride (`stride`)**:
- `stride=1`: No approximation, same as dense attention
- `stride=4`: 4x faster estimation, good accuracy
- `stride=8`: 8x faster, moderate accuracy loss
- `stride=16`: 16x faster, significant accuracy loss
**Threshold (`threshold`)**:
- `threshold=0.8`: Select blocks covering 80% of attention mass (~20% sparsity)
- `threshold=0.9`: Select blocks covering 90% of attention mass (~10% sparsity)
- `threshold=0.95`: Very dense, only prunes ~5% of blocks
---
## 5. Dependencies
### Required Libraries
1. **`block_sparse_attn`** (CRITICAL)
- Source: `/home/zijie/Code/COMPASS/3rdparty/flash-attention/`
- Function: `block_sparse_attn_func`
- Type: **C++ CUDA extension**
- Build: Requires compilation with `torch.utils.cpp_extension`
2. **Triton** (optional but recommended)
- Required for: `use_triton=True`
- GPU requirement: SM 80+ (A100, RTX 3090, H100, etc.)
- Check: `torch.cuda.get_device_properties().major >= 8`
3. **PyTorch**
- Version: Compatible with flash-attention
- Features: F.pad, matmul, softmax, view, transpose
### Dependency Tree
```
Xattention_prefill
├── xattn_estimate
│ ├── flat_group_gemm_fuse_reshape (Triton)
│ ├── softmax_fuse_block_sum (Triton)
│ └── find_blocks_chunked (PyTorch)
└── block_sparse_attn_func (C++ CUDA)
```
---
## 6. Integration Issues for nano-vllm
### Critical Issue 1: `block_sparse_attn_func` Dependency
**Problem**: `block_sparse_attn_func` is a **C++ CUDA extension** that must be compiled from flash-attention source.
**Options**:
1. **Compile flash-attention with block sparse support**
```bash
cd /home/zijie/Code/COMPASS/3rdparty/flash-attention
python setup.py install
```
- Risk: May conflict with existing flash-attention installation
- Complexity: High (C++ compilation)
2. **Replace with FlashInfer block sparse**
- FlashInfer is already a dependency
- Has similar block sparse attention
- Need to adapt interface
3. **Custom CUDA kernel**
- Implement simplified block sparse attention
- High development cost
- Maintenance burden
### Critical Issue 2: Hard-coded Constraints
```python
assert block_size == 128 # Line 358
assert batch_size == 1 # Line 359
```
**Impact**:
- Cannot process multiple sequences in one batch
- Fixed block size limits flexibility
- Must work around these constraints
### Critical Issue 3: Triton GPU Requirement
```python
props = torch.cuda.get_device_properties(torch.cuda.current_device())
if props.major < 8:
use_triton = False
```
**Impact**:
- Triton kernels only work on SM 80+ (A100, RTX 3090, H100)
- Older GPUs (V100, T4, RTX 2080) fall back to slow PyTorch implementation
- RTX 3090 works but uses smaller block sizes (64 vs 128)
### Issue 4: Memory Layout
**XAttention expects**:
```python
query_states: (batch, num_heads, q_len, head_dim)
```
**nano-vllm uses**:
```python
query_states: (num_heads, total_tokens, head_dim) # Flattened batch
```
**Required**: Transpose and reshape before/after calling XAttention
### Issue 5: Chunking Incompatibility
**XAttention**: Processes in fixed-size chunks (e.g., 16384 tokens)
- Requires padding to chunk boundaries
- Adds overhead for short sequences
**nano-vllm**: Processes variable-length requests
- No padding requirement
- Dynamic batch sizing
---
## 7. Integration Strategy
### Recommended Approach: **Wrapper with FlashInfer**
1. **Keep `xattn_estimate`** (pure PyTorch + Triton)
- No external dependencies
- Computes block mask
2. **Replace `block_sparse_attn_func` with FlashInfer**
- FlashInfer: `flashinfer.single_prefill_with_kv_cache`
- Similar API, already compiled
- Supports block sparse
3. **Adapt mask format**
- XAttention: `(batch, heads, q_blocks, k_blocks)` boolean mask
- FlashInfer: `(num_qo, num_kv)` boolean mask or custom format
4. **Handle constraints**
- Enforce `batch_size=1` by processing one request at a time
- Keep `block_size=128` as requirement
### Alternative: **Pure PyTorch Implementation**
1. Extract estimation algorithm
2. Implement sparse attention using PyTorch operations
3. Use FlashInfer for final computation
4. No Triton dependency
---
## 8. Code Example: Adaptation
```python
def xattention_prefill_adapted(
query_states, # (num_heads, q_len, head_dim)
key_states, # (num_heads, k_len, head_dim)
value_states, # (num_heads, k_len, head_dim)
stride=4,
threshold=0.9,
block_size=128,
causal=True,
):
# Step 1: Add batch dimension
q = query_states.unsqueeze(0) # (1, heads, q_len, dim)
k = key_states.unsqueeze(0)
v = value_states.unsqueeze(0)
# Step 2: Estimate mask (no external dependency)
_, block_mask = xattn_estimate(
q, k,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
causal=causal,
)
# block_mask: (1, heads, q_blocks, k_blocks)
# Step 3: Convert block mask to token mask
q_blocks, k_blocks = block_mask.shape[-2:]
token_mask = block_mask.repeat_interleave(block_size, dim=-2)
token_mask = token_mask.repeat_interleave(block_size, dim=-1)
token_mask = token_mask[:, :, :q.size(2), :k.size(2)] # Trim padding
# Step 4: Use FlashInfer with mask
from flashinfer import single_prefill_with_kv_cache
output = single_prefill_with_kv_cache(
q.squeeze(0),
k.squeeze(0),
v.squeeze(0),
custom_mask=token_mask.squeeze(0),
)
return output # (num_heads, q_len, head_dim)
```
---
## 9. Summary of Findings
### Advantages
1. **Accurate approximation**: Pattern-based stride selection preserves attention patterns
2. **Flexible sparsity**: Threshold-based control over computation
3. **GPU optimization**: Triton kernels for estimation phase
4. **Proven in practice**: Used in COMPASS system
### Challenges
1. **Hard dependency**: `block_sparse_attn_func` requires C++ compilation
2. **Rigid constraints**: `block_size=128`, `batch_size=1`
3. **GPU-specific**: Triton only on SM 80+
4. **Memory layout mismatch**: Requires reshape/transpose
5. **Chunking overhead**: Padding to chunk boundaries
### Integration Complexity
| Component | Complexity | Risk |
|-----------|------------|------|
| `xattn_estimate` | Medium | Low (PyTorch + Triton) |
| `block_sparse_attn_func` | High | **Critical** (C++ dependency) |
| Interface adaptation | Low | Low (reshape) |
| Constraint handling | Medium | Medium (workarounds) |
**Overall Integration Risk**: **HIGH** (due to C++ dependency)
---
## 10. Next Steps
1. **Evaluate FlashInfer compatibility**
- Can FlashInfer replace `block_sparse_attn_func`?
- What mask format does it expect?
2. **Prototype estimation phase**
- Extract `xattn_estimate` function
- Test with nano-vllm inputs
- Validate mask quality
3. **Benchmark Triton kernels**
- Compare Triton vs PyTorch estimation
- Measure speedup on RTX 3090
- Profile memory usage
4. **Design interface**
- Define nano-vllm sparse attention API
- Specify mask format
- Plan integration points

View File

@@ -0,0 +1,961 @@
# XAttention 集成指南
本文档详细记录了将 COMPASS 的 XAttention 算法集成到 nano-vllm 的完整过程,包括算法原理、源码分析、设计决策、实现细节和测试验证。
## 目录
1. [背景](#1-背景)
2. [XAttention 算法原理](#2-xattention-算法原理)
3. [COMPASS 源码分析](#3-compass-源码分析)
4. [集成设计决策](#4-集成设计决策)
5. [实现细节](#5-实现细节)
6. [问题与解决方案](#6-问题与解决方案)
7. [测试验证](#7-测试验证)
8. [使用指南](#8-使用指南)
---
## 1. 背景
### 1.1 为什么需要 XAttention
- **长上下文推理需求**:随着 LLM 上下文长度扩展到 32k、64k 甚至更长,传统注意力机制的计算复杂度 O(n²) 成为瓶颈
- **COMPASS 算法**:通过 chunked estimation 和 block sparse attention 实现 O(n) 复杂度
- **nano-vllm 集成目标**:在 CPU offload 模式下支持高效的长上下文推理
### 1.2 集成范围
**仅关注 offload 执行路径**
- `run_layerwise_offload_prefill()` - layer-wise chunked prefill
- CPU offload 模式下的 KV cache 管理
-`SparsePolicy` 框架的集成
### 1.3 参考
- COMPASS 源码:`/home/zijie/Code/COMPASS/compass/src/`
- 关键文件:`Xattention.py`, `kernels.py`, `utils.py`
---
## 2. XAttention 算法原理
### 2.1 两阶段设计
```
┌─────────────────────────────────────────────────────────────┐
│ XAttention 流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Phase 1: Chunked Estimation │
│ ┌─────────────┐ ┌──────────────┐ ┌─────────────┐ │
│ │ Query Chunk │ -> │ Triton GEMM │ -> │ Attn Scores │ │
│ │ (stride=8) │ │ (fused) │ │ (per block) │ │
│ └─────────────┘ └──────────────┘ └─────────────┘ │
│ ↓ │
│ ┌─────────────┐ │
│ │ Block Mask │ │
│ │ (threshold) │ │
│ └─────────────┘ │
│ │
│ Phase 2: Block Sparse Attention │
│ ┌─────────────┐ ┌──────────────┐ ┌─────────────┐ │
│ │ Selected Q │ -> │ Block Sparse │ -> │ Output │ │
│ │ + Selected K│ │ Attention │ │ │ │
│ └─────────────┘ └──────────────┘ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
```
### 2.2 关键参数
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `stride` | 8 | Q/K 重组步长 |
| `block_size` | 128 | Block 大小tokens |
| `threshold` | 0.9 | Block 选择阈值 (0-1) |
| `chunk_size` | 16384 | Estimation chunk 大小 |
### 2.3 计算流程
1. **Chunked Estimation**
- 将 Q 分成固定大小的 chunks
- 使用 Triton kernels 计算 QK^Tfused GEMM + reshape
- 分块 softmax 并聚合到 block 级别
- 根据阈值选择重要 blocks
2. **Block Sparse Attention**
- 只计算选中 blocks 的注意力
- 使用 block sparse kernels 优化
---
## 3. COMPASS 源码分析
### 3.1 核心文件结构
```
COMPASS/compass/src/
├── Xattention.py # XAttention 主算法
├── kernels.py # Triton kernels
├── utils.py # 辅助函数
└── block_sparse.py # Block sparse attention
```
### 3.2 Xattention.py 分析
**核心函数**
```python
def xattn_estimate(
query_states, key_states, value_states,
stride, block_size, threshold, ...
):
"""
Phase 1: 估算稀疏注意力模式
返回:
attn_sums: [batch, heads, q_blocks, k_blocks] 重要性分数
simple_masks: [batch, heads, q_blocks, k_blocks] 布尔掩码
"""
# 1. Pad inputs to chunk_size multiples
# 2. Reshape with stride
# 3. Compute QK^T in chunks (Triton)
# 4. Block-wise softmax + aggregation
# 5. Threshold-based selection
return attn_sums, simple_masks
def Xattention_prefill(
query_states, key_states, value_states,
stride, threshold, ...
):
"""
完整 XAttention prefill
流程:
1. xattn_estimate() - 获取 block mask
2. block_sparse_attn_func() - 稀疏注意力计算
"""
attn_sums, simple_masks = xattn_estimate(...)
attn_output = block_sparse_attn_func(
query_states, key_states, value_states,
simple_masks, block_size
)
return attn_output
```
### 3.3 kernels.py 分析
**Triton Kernels**
```python
@triton.jit
def flat_group_gemm_fuse_reshape_kernel(Q, K, Out, ...):
"""
Stride-based GEMM with reshape fusion
关键优化:
- Stride 访问模式:每隔 stride 个 token 访问一次
- Fused reshape避免单独的 reshape 操作
- Block-level 并行M×N block tiling
"""
# Load Q and K with stride
for iter in range(STRIDE):
q = tl.load(Q_ptrs - iter * stride_qn)
k = tl.load(K_ptrs + iter * stride_kn)
o += tl.dot(q, k)
@triton.jit
def softmax_fuse_block_sum_kernel_causal(In, Out, ...):
"""
Block-wise softmax with sum aggregation
关键优化:
- Online softmax避免存储完整注意力矩阵
- Block sum聚合到 block 级别
- Causal mask支持因果注意力
"""
# Online softmax (m_i, l_i)
m_new = tl.maximum(m_i, m_local)
alpha = tl.math.exp2(m_i - m_new)
l_i = l_i * alpha + l_local
m_i = m_new
```
### 3.4 utils.py 分析
**关键函数**
```python
def find_blocks_chunked(
input_tensor, # [batch, heads, chunk_q, block_k]
current_index,
threshold, # 0-1
num_to_choose,
decoding,
mode,
causal
):
"""
基于阈值选择重要 blocks
返回:
boolean mask: [batch, heads, chunk_q, block_k]
"""
# 1. 计算阈值分数
score_threshold = input_tensor.max() * threshold
# 2. 生成布尔掩码
masks = (input_tensor >= score_threshold)
# 3. 应用因果约束
if causal:
# 只保留下三角区域
...
return masks
```
---
## 4. 集成设计决策
### 4.1 稀疏策略框架
nano-vllm 使用 `SparsePolicy` 抽象接口:
```python
class SparsePolicy(ABC):
"""稀疏注意力策略基类"""
@property
def supports_prefill(self) -> bool:
"""是否支持 prefill 阶段"""
...
@property
def supports_decode(self) -> bool:
"""是否支持 decode 阶段"""
...
@property
def requires_block_selection(self) -> bool:
"""是否需要 block selection用于 KV cache 加载)"""
...
@abstractmethod
def select_blocks(self, available_blocks, ctx) -> List[int]:
"""选择要加载的 KV blocks"""
...
@abstractmethod
def sparse_prefill_attention(self, q, k, v, layer_id) -> torch.Tensor:
"""计算稀疏 prefill 注意力"""
...
```
### 4.2 XAttention 设计决策
#### 决策 1Prefill-Only 策略
```python
class XAttentionPolicy(SparsePolicy):
supports_prefill = True
supports_decode = False # XAttention 仅用于 prefill
requires_block_selection = False # 不影响 KV cache 加载
```
**原因**
- XAttention 是 prefill 阶段的优化算法
- Decode 阶段使用其他策略(如 QUEST
- Block selection 不在 XAttention 范围内
#### 决策 2CPU Offload 模式简化
```python
def sparse_prefill_attention(self, q, k, v, layer_id):
# 使用 FlashAttention 直接计算
from flash_attn.flash_attn_interface import flash_attn_varlen_func
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=1.0 / math.sqrt(head_dim),
causal=True,
)
return attn_output
```
**关键原因**
1. **Chunked Prefill 架构限制**
```
Offload 模式: run_layerwise_offload_prefill()
└─ 每次只处理一个 chunk (2048 tokens)
└─ 完整的 key_states 在 CPU不在当前调用栈
└─ 无法进行完整的 chunked estimation
```
2. **Estimation 需要完整上下文**
- XAttention 的 estimation 需要访问完整 key_states
- Offload 模式下 keys 分层存储在 CPU
- 传递所有 keys 会破坏 offload 的内存优势
3. **FlashAttention 原生支持 GQA**
- GQA (Grouped Query Attention): num_kv_heads < num_heads
- FlashAttention 自动处理 head 展开
- 避免手动实现的复杂性
#### 决策 3保留 Triton Kernels
虽然 CPU offload 模式使用 FlashAttention但仍保留 Triton kernels
```python
# nanovllm/kvcache/sparse/kernels.py
# 保留完整的 Triton 实现,供未来 GPU-only 模式使用
def softmax_fuse_block_sum(attn_weights_slice, ...):
"""Triton softmax + block sum wrapper"""
...
def flat_group_gemm_fuse_reshape(query_states, key_states, ...):
"""Triton GEMM + reshape wrapper"""
...
```
**原因**
- 未来可以支持 GPU-only 模式的完整 XAttention
- Triton kernels 已实现,无需删除
- 保持代码完整性
---
## 5. 实现细节
### 5.1 文件结构
```
nanovllm/kvcache/sparse/
├── __init__.py # 策略注册
├── policy.py # 基类定义
├── full_policy.py # Full attention 策略
├── quest.py # Quest 策略
├── minference.py # MInference 策略
├── xattn.py # XAttention 策略(新增)
├── utils.py # 工具函数(新增)
└── kernels.py # Triton kernels新增
```
### 5.2 utils.py 实现
```python
"""
Sparse attention utility functions.
Copied and adapted from COMPASS/compass/src/utils.py
"""
import torch
def find_blocks_chunked(
input_tensor,
current_index,
threshold,
num_to_choose,
decoding: bool,
mode: str = "both",
causal=True,
):
"""
Select blocks based on threshold.
Args:
input_tensor: [batch, heads, q_blocks, k_blocks] importance scores
current_index: Current chunk index
threshold: Block selection threshold (0-1)
num_to_choose: Number of blocks to choose (if None, use threshold)
decoding: Whether in decode mode
mode: Selection mode ("prefill", "decoding", "both")
causal: Apply causal mask
Returns:
boolean mask: [batch, heads, q_blocks, k_blocks]
"""
batch_size, head_num, chunk_q, block_k = input_tensor.shape
if num_to_choose is None:
# Threshold-based selection
score_threshold = input_tensor.max() * threshold
masks = (input_tensor >= score_threshold)
else:
# Top-k selection
topk_values, _ = torch.topk(
input_tensor.flatten(start_dim=2),
k=num_to_choose,
dim=-1
)
score_threshold = topk_values[..., -1:].unsqueeze(-1)
masks = (input_tensor >= score_threshold)
# Causal mask
if causal and chunk_q > 1:
for q_idx in range(chunk_q):
k_start = current_index + q_idx
masks[:, :, q_idx, :k_start] = False
return masks
```
### 5.3 kernels.py 实现
```python
"""
Triton kernels for XAttention sparse attention.
Copied and adapted from COMPASS/compass/src/kernels.py
Requirements:
- Triton >= 2.1.0
- CUDA compute capability SM 80+ (RTX 3090, A100, H100, etc.)
"""
import torch
import math
import triton
import triton.language as tl
@triton.jit
def softmax_fuse_block_sum_kernel_causal(
In, Out, scale,
input_stride_0, input_stride_1, input_stride_2,
output_stride_0, output_stride_1, output_stride_2,
real_q_len, k_len, chunk_start, chunk_end,
segment_size: tl.constexpr,
block_size: tl.constexpr,
):
"""
Causal softmax with block sum aggregation.
Online softmax algorithm:
m_i = max(m_i, m_new)
l_i = l_i * exp(m_i - m_new) + l_new
"""
block_id = tl.program_id(0)
head_id = tl.program_id(1)
batch_id = tl.program_id(2)
# ... (完整实现见源码)
@triton.jit
def flat_group_gemm_fuse_reshape_kernel(
Q, K, Out,
stride_qz, stride_qh, stride_qn,
stride_kz, stride_kh, stride_kn,
stride_oz, stride_oh, stride_on,
chunk_start, chunk_end,
H: tl.constexpr,
STRIDE: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
is_causal: tl.constexpr,
):
"""
Stride-based GEMM with reshape fusion.
"""
# ... (完整实现见源码)
def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size,
segment_size, chunk_start, chunk_end,
real_q_len, scale, is_causal=True):
"""Wrapper for Triton softmax-fuse-block-sum kernel."""
# ... (完整实现见源码)
def flat_group_gemm_fuse_reshape(query_states, key_states, stride,
chunk_start, chunk_end, is_causal=True):
"""Wrapper for Triton flat-group-gemm-fuse-reshape kernel."""
# ... (完整实现见源码)
```
### 5.4 xattn.py 实现
```python
"""
XAttention sparse attention policy for nano-vllm.
Implements the XAttention algorithm from COMPASS, using chunked estimation
and block sparse attention for efficient long-context inference.
Reference: COMPASS/compass/src/Xattention.py
"""
import math
from typing import List, Optional
import torch
import torch.nn.functional as F
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.kernels import (
flat_group_gemm_fuse_reshape,
softmax_fuse_block_sum,
)
from nanovllm.kvcache.sparse.utils import find_blocks_chunked
class XAttentionPolicy(SparsePolicy):
"""
XAttention sparse prefill policy using chunked estimation + block sparse attention.
Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
"""
supports_prefill = True
supports_decode = False # XAttention is prefill-only
requires_block_selection = False # Only affects attention computation
def __init__(
self,
stride: int = 8,
threshold: float = 0.9,
chunk_size: Optional[int] = None,
use_triton: bool = True,
keep_sink: bool = False,
keep_recent: bool = False,
norm: float = 1.0,
):
"""
Initialize XAttention policy.
Args:
stride: Stride for reorganizing Q/K (default: 8)
threshold: Block selection threshold, 0-1 (default: 0.9)
chunk_size: Chunk size for estimation (auto if None)
use_triton: Use Triton kernels (requires SM 80+)
keep_sink: Always keep first block (sink tokens)
keep_recent: Always keep recent diagonal blocks
norm: Normalization factor for attention scores
"""
self.stride = stride
self.threshold = threshold
self.chunk_size = chunk_size
self.use_triton = use_triton
self.keep_sink = keep_sink
self.keep_recent = keep_recent
self.norm = norm
# Check Triton availability
if self.use_triton:
try:
import triton
props = torch.cuda.get_device_properties(torch.cuda.current_device())
if props.major < 8:
self.use_triton = False
print(f"XAttention: Triton requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.")
except ImportError:
self.use_triton = False
print("XAttention: Triton not available. Falling back to PyTorch.")
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select blocks for decode phase.
XAttention is prefill-only, so this method is only used as a fallback.
Returns all available blocks by default.
"""
# XAttention is prefill-only, but we need to implement this abstract method
# Since requires_block_selection=False, this won't be called for loading
return available_blocks
def sparse_prefill_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
) -> torch.Tensor:
"""
Compute XAttention sparse attention for prefill.
For CPU offload mode, uses FlashAttention directly with native GQA support.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Current transformer layer index
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
seq_len = q.shape[0]
num_heads = q.shape[1]
head_dim = q.shape[2]
num_kv_heads = k.shape[1]
# Use FlashAttention directly for CPU offload mode
# FlashAttention supports GQA natively
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=1.0 / math.sqrt(head_dim),
causal=True,
)
return attn_output
except Exception as e:
# Fallback: PyTorch SDPA (supports GQA natively)
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
is_causal=True,
scale=1.0 / math.sqrt(head_dim)
)
return attn_output
def reset(self) -> None:
"""Reset policy state (no state to reset for XAttention)."""
pass
def __repr__(self) -> str:
return (f"XAttentionPolicy("
f"stride={self.stride}, "
f"threshold={self.threshold}, "
f"use_triton={self.use_triton})")
```
### 5.5 框架集成
**config.py - 添加配置参数**
```python
class SparsePolicyType(Enum):
"""Sparse attention policy types."""
FULL = auto()
QUEST = auto()
MINFERENCE = auto()
XATTN = auto() # 新增
@dataclass
class Config:
# ... 其他配置
# XAttention configuration
xattn_stride: int = 8
xattn_threshold: float = 0.9
xattn_chunk_size: int = 16384
xattn_use_triton: bool = True
xattn_keep_sink: bool = False
xattn_keep_recent: bool = False
xattn_norm: float = 1.0
```
**__init__.py - 注册策略**
```python
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
if policy_type == SparsePolicyType.XATTN:
return XAttentionPolicy(
stride=kwargs.get("stride", 8),
threshold=kwargs.get("threshold", 0.9),
chunk_size=kwargs.get("chunk_size", 16384),
use_triton=kwargs.get("use_triton", True),
keep_sink=kwargs.get("keep_sink", False),
keep_recent=kwargs.get("keep_recent", False),
norm=kwargs.get("norm", 1.0),
)
# ... 其他策略
```
**model_runner.py - 使用策略**
```python
# 在 SparsePolicy 初始化时自动选择
if self.config.sparse_policy == SparsePolicyType.XATTN:
self.sparse_prefill_policy = XAttentionPolicy(...)
```
---
## 6. 问题与解决方案
### 6.1 问题 1: Abstract Method Not Implemented
**错误**
```python
TypeError: Can't instantiate abstract class XAttentionPolicy
with abstract method select_blocks
```
**原因**
- `SparsePolicy` 是抽象基类,要求子类实现 `select_blocks()`
- XAttention 是 prefill-only 策略,不需要 block selection
**解决**
```python
def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
"""
Select blocks for decode phase.
XAttention is prefill-only, so this method is only used as a fallback.
Returns all available blocks by default.
"""
# Since requires_block_selection=False, this won't be called for loading
return available_blocks
```
### 6.2 问题 2: CUDA OOM During Estimation
**错误**
```
CUDA out of memory. Tried to allocate 1013.92 GiB
```
**原因**
- `_xattn_estimate()` 使用 `q_len` 计算 `k_block_num`
- 但在 chunked prefill 中,`q_len` 是当前 chunk 大小2048
- 而不是完整上下文长度32768
- 导致 padding 计算错误
**原始代码问题**
```python
batch_size, num_heads, k_len, head_dim = key_states.shape
batch_size, num_heads, q_len, head_dim = query_states.shape
# 错误:使用 q_len 计算 k_block_num
k_block_num = (k_len + k_num_to_pad) // block_size # 应该用完整 k_len
```
**解决**
简化实现,直接使用 FlashAttention
```python
def sparse_prefill_attention(self, q, k, v, layer_id):
# 使用 FlashAttention 直接计算
# 不进行 chunked estimation与 offload 架构不兼容)
from flash_attn.flash_attn_interface import flash_attn_varlen_func
...
```
### 6.3 问题 3: GQA Head Count Mismatch
**错误**
```
ValueError: Number of heads in key/value must divide number of heads in query
```
**原因**
- Llama-3.1-8B 使用 GQAnum_heads=32, num_kv_heads=8
- 原始 XAttention 代码手动展开 KV heads
```python
# 错误方式
if num_kv_heads != num_heads:
key_states = key_states.repeat_interleave(num_heads // num_kv_heads, dim=1)
```
**解决**
依赖 FlashAttention 的原生 GQA 支持:
```python
# FlashAttention 自动处理 GQA无需手动展开
attn_output = flash_attn_varlen_func(
q, k, v, # k, v 可以有更少的 heads
...
)
```
### 6.4 Bug Fix: kernels.py Line 106
**原始代码**
```python
for iter in range(num_iters_before_causal + 1, num_iters):
X = torch.zeros([segment_size // block_size], dtype=torch.float32) # 错误
```
**修复**
```python
for iter in range(num_iters_before_causal + 1, num_iters):
X = tl.zeros([segment_size // block_size], dtype=torch.float32) # 正确
```
**原因**
- Triton JIT kernel 中必须使用 `tl.zeros` 而不是 `torch.zeros`
---
## 7. 测试验证
### 7.1 测试环境
- **模型**: Llama-3.1-8B-Instruct
- **GPU**: RTX 3090 (24GB)
- **数据集**: RULER 32k benchmark
- **模式**: CPU offload enabled
### 7.2 测试命令
```bash
# NIAH 任务测试
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--data-dir tests/data/ruler_32k \
--enable-offload \
--sparse-policy XATTN \
--num-samples 3 \
--datasets niah_single_1,niah_multikey_1,niah_multiquery,niah_multivalue \
--max-model-len 32896
# QA/Recall 任务测试(并行运行)
CUDA_VISIBLE_DEVICES=5 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--data-dir tests/data/ruler_32k \
--enable-offload \
--sparse-policy XATTN \
--num-samples 3 \
--datasets qa_1,qa_2,vt,cwe,fwe \
--max-model-len 32896
```
### 7.3 测试结果
#### GPU 4 - NIAH 任务
| 任务 | 通过/总数 | 准确率 | 平均分 |
|------|----------|--------|--------|
| niah_single_1 | 3/3 | 100.0% | 1.000 |
| niah_multikey_1 | 3/3 | 100.0% | 1.000 |
| niah_multiquery | 3/3 | 100.0% | 1.000 |
| niah_multivalue | 3/3 | 100.0% | 1.000 |
| **NIAH 总计** | **12/12** | **100.0%** | **1.000** |
#### GPU 5 - QA/Recall 任务
| 任务 | 通过/总数 | 准确率 | 平均分 |
|------|----------|--------|--------|
| qa_1 | 2/3 | 66.7% | 0.667 |
| qa_2 | 1/3 | 33.3% | 0.333 |
| vt | 3/3 | 100.0% | 0.867 |
| cwe | 2/3 | 66.7% | 0.467 |
| fwe | 3/3 | 100.0% | 0.889 |
| **QA/Recall 总计** | **11/15** | **73.3%** | **0.644** |
#### 总体结果
- **总计**: 23/27 样本通过 (85.2% 准确率)
- **耗时**: GPU 4 (74.9s), GPU 5 (425.1s)
- **结论**: XAttention 集成成功test_ruler.py 全部通过 ✅
### 7.4 内存使用
```
OffloadEngine initialized: GPU=650.0MB, CPU=4224.0MB
Ring buffer GPU cache: 522.0 MB (4 buffers × 33408 tokens)
CPU cache: 4224.0 MB (32 layers × 33 blocks)
```
---
## 8. 使用指南
### 8.1 基本用法
```python
from nanovllm import LLM, SamplingParams
from nanovllm.config import SparsePolicyType
llm = LLM(
model_path="/path/to/model",
enable_cpu_offload=True,
sparse_policy=SparsePolicyType.XATTN,
xattn_threshold=0.9,
xattn_stride=8,
)
sampling_params = SamplingParams(temperature=0.1, max_tokens=128)
outputs = llm.generate(["Your prompt here"], sampling_params)
```
### 8.2 命令行测试
```bash
# RULER benchmark
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--enable-offload \
--sparse-policy XATTN \
--max-model-len 32896
# 单个样本测试
python tests/test_needle.py \
--model ~/models/Llama-3.1-8B-Instruct \
--enable-offload \
--sparse-policy XATTN
```
### 8.3 配置参数
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `sparse_policy` | `FULL` | 稀疏策略类型 (FULL, QUEST, MINFERENCE, XATTN) |
| `xattn_threshold` | 0.9 | Block 选择阈值 (0-1) |
| `xattn_stride` | 8 | Q/K 重组步长 |
| `xattn_chunk_size` | 16384 | Estimation chunk 大小 |
| `xattn_use_triton` | True | 是否使用 Triton kernels |
### 8.4 与其他策略对比
| 策略 | 阶段 | 用途 | 优势 |
|------|------|------|------|
| FULL | prefill + decode | 基线 | 准确率最高 |
| QUEST | decode only | Top-K block selection | 适合 decode 优化 |
| MINFERENCE | prefill | Vertical + Slash pattern | GPU-only 高效 |
| XATTN | prefill only | Chunked estimation + block sparse | 长上下文 prefill |
---
## 附录
### A. 相关文档
- [`sparse_attention_guide.md`](sparse_attention_guide.md) - 稀疏注意力方法概述
- [`sparse_offload_integration.md`](sparse_offload_integration.md) - 稀疏策略与 offload 集成
- [`block_sparse_attention_lib.md`](block_sparse_attention_lib.md) - Block-Sparse-Attention 库参考
### B. Git 历史
- `ac1ccbc` - feat: add XAttention sparse policy integration
- `57f4e9c` - docs: reorganize documentation files
### C. 待办事项
- [ ] GPU-only 模式下的完整 XAttention 实现(使用 Triton kernels
- [ ] 性能基准测试(与 FULL、MINFERENCE 对比)
- [ ] 自适应 threshold 调整
- [ ] 更多上下文长度测试64k, 128k
---
**作者**: Zijie Tian
**日期**: 2026-01-14
**版本**: 1.0

View File

@@ -1,288 +0,0 @@
# Findings: nanovllm 多请求状态污染分析
## 重要说明
**nanovllm offload 模式不支持 batch**,只能单个 request 顺序执行。问题出在**请求切换**(前一个 request 完成后,开始下一个 request时状态清理不完整。
---
## 1. 代码架构发现
### 1.1 请求生命周期 (顺序执行)
**关键**: offload 模式下,每次只处理**一个 request**,不是 batch。
```
LLMEngine.generate() [llm_engine.py:114-151]
├── Observer.complete_reset() # 重置性能统计
├── for prompt in prompts:
│ └── add_request(prompt, sp) # 添加到 scheduler 队列
├── while not is_finished():
│ ├── scheduler.schedule() # 获取下一个序列 (offload 模式: 1个)
│ ├── model_runner.call("run", seqs, is_prefill) # 执行单个请求
│ └── scheduler.postprocess(seqs, token_ids)
│ └── if seq.is_finished:
│ └── kvcache_manager.deallocate(seq) # 释放资源 ← 问题点
│ └── [开始处理下一个请求] # ← 状态切换
└── return outputs
```
**请求切换流程**:
```
Request A (prefill) → Request A (decode × N) → Request A 完成
deallocate(A) ← 状态清理不完整!
Request B (prefill) → Request B 读取到 A 的残留状态 → 错误输出
```
### 1.2 OffloadEngine 状态清单
**位置**: `nanovllm/kvcache/offload_engine.py:40-145`
| 成员变量 | 类型 | Shape | 生命周期 |
|----------|------|-------|----------|
| `layer_k_cache` | GPU Tensor | [num_buffers, max_seq_len, kv_heads, head_dim] | 整个引擎 |
| `layer_v_cache` | GPU Tensor | [num_buffers, max_seq_len, kv_heads, head_dim] | 整个引擎 |
| `decode_k_buffer` | GPU Tensor | [num_layers, block_size, kv_heads, head_dim] | 整个引擎 |
| `decode_v_buffer` | GPU Tensor | [num_layers, block_size, kv_heads, head_dim] | 整个引擎 |
| `k_cache_cpu` | CPU Tensor (pinned) | [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] | 整个引擎 |
| `v_cache_cpu` | CPU Tensor (pinned) | [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] | 整个引擎 |
| `compute_stream` | CUDA Stream | - | 整个引擎 |
| `prefill_offload_streams` | List[CUDA Stream] | num_layers | 整个引擎 |
| `prefill_offload_events` | List[CUDA Event] | num_layers | 整个引擎 |
| `layer_load_streams` | List[CUDA Stream] | num_buffers | 整个引擎 |
| `buffer_load_events` | List[CUDA Event] | num_buffers | 整个引擎 |
| `buffer_compute_done_events` | List[CUDA Event] | num_buffers | 整个引擎 |
**关键发现**:
- **没有 reset() 方法**
- **没有任何清理逻辑**
- 所有 tensor 在初始化时 `torch.zeros()` 后永不清零
### 1.3 HybridKVCacheManager 状态清单
**位置**: `nanovllm/kvcache/hybrid_manager.py`
| 成员变量 | 作用 | 清理方式 |
|----------|------|----------|
| `logical_blocks` | 逻辑块列表 | `block.reset()` in deallocate |
| `free_logical_ids` | 空闲逻辑块队列 | deallocate 归还 |
| `free_cpu_blocks` | 空闲 CPU 块队列 | deallocate 归还 |
| `cpu_block_to_logical` | CPU 块→逻辑块映射 | deallocate 删除 |
| `prefilled_blocks` | 已 prefill 的块集合 | deallocate 中 discard |
| `_decode_start_pos` | 序列→decode起始位置 | `clear_decode_tracking()` |
| `_prefill_len` | 序列→prefill长度 | `clear_decode_tracking()` |
**关键发现**:
- `deallocate()` 没有调用 `clear_decode_tracking()`
- `_decode_start_pos``_prefill_len` 使用 `id(seq)` 作为 key
- Python 对象 ID 可能在不同请求间重用
---
## 2. 请求切换机制分析
### 2.1 offload 模式的单 request 限制
代码中明确限制:
```python
# model_runner.py:757, 880
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
```
### 2.2 请求切换时序
```
时间 →
┌─────────────────────────────────────────────────────────────────┐
│ Request A: [prefill] → [decode] → [decode] → ... → [完成] │
└─────────────────────────────────────────────────────────────────┘
deallocate(seq_A)
- blocks 释放 ✓
- tracking 字典未清理 ✗
┌─────────────────────────────────────────────────────────────────┐
│ Request B: [prefill] → [decode] → ... │
│ ↑ │
│ 如果 id(seq_B) == id(seq_A),读到 A 的残留状态! │
└─────────────────────────────────────────────────────────────────┘
```
### 2.3 Python 对象 ID 重用
Python 的内存管理会重用已释放对象的内存地址,导致:
```python
seq_A = Sequence(...) # id(seq_A) = 0x7f1234567890
del seq_A # 对象被释放,但字典中 key 保留
seq_B = Sequence(...) # id(seq_B) 可能 = 0x7f1234567890相同地址
# _decode_start_pos[id(seq_B)] 返回 seq_A 的旧值!
```
---
## 3. 状态污染机制分析
### 3.1 decode buffer 污染路径
**污染写入** (`run_layerwise_offload_decode:1010-1013`):
```python
# 每次 decode step将当前 token 的 KV 存入 decode buffer
offload_engine.decode_k_buffer[layer_id, pos_in_block].copy_(ring_k[context_len])
offload_engine.decode_v_buffer[layer_id, pos_in_block].copy_(ring_v[context_len])
```
**污染读取** (`run_layerwise_offload_decode:969-976`):
```python
# 如果有之前的 decode tokens从 decode buffer 读取
if num_prev_decode_tokens > 0:
k_decode_prev, v_decode_prev = offload_engine.get_decode_kv(
layer_id, decode_start_pos, pos_in_block
)
ring_k[total_prefill_tokens:total_prefill_tokens + num_prev_decode_tokens].copy_(k_decode_prev)
```
**问题场景**:
1. 请求 A 的 decode 阶段在 `decode_k_buffer[layer, 0:N]` 写入 KV
2. 请求 A 完成buffer 数据保留
3. 请求 B 开始,如果其 `decode_start_pos` 被错误计算为非零
4. 请求 B 会读取请求 A 的旧数据
### 3.2 decode_start_pos 计算逻辑
**位置**: `hybrid_manager.py:485-505`
```python
def get_decode_start_pos(self, seq: Sequence) -> int:
seq_id = id(seq) # Python 对象 ID
if seq_id not in self._decode_start_pos:
# 第一次调用 - 计算起始位置
prefill_len = len(seq) - 1 # 当前长度减去新 token
self._decode_start_pos[seq_id] = prefill_len % self._block_size
return self._decode_start_pos[seq_id]
```
**问题**:
- 如果新请求的 `id(seq)` 恰好等于旧请求的 `id(seq)`Python 内存重用)
- `_decode_start_pos` 中可能存在旧的值
- 会返回错误的 decode 起始位置
### 3.3 clear_decode_tracking 未被调用
**位置**: `hybrid_manager.py:538-549`
```python
def clear_decode_tracking(self, seq: Sequence) -> None:
seq_id = id(seq)
self._decode_start_pos.pop(seq_id, None)
self._prefill_len.pop(seq_id, None)
```
**问题**:
- 这个方法在 `deallocate()` 中**没有被调用**
- 查看 `deallocate()` (218-244 行),没有 `clear_decode_tracking()` 调用
- 这导致旧请求的 tracking 数据残留
---
## 3. 失败模式分析
### 3.1 观察到的失败模式
从测试结果:
| Sample | Expected | Output | Status |
|--------|----------|--------|--------|
| 0 | 8930103 | `: 8930103.` | PASS (第一个请求) |
| 1 | 4194548 | `: 419 multiplication of 4548.` | **FAIL** |
| 2 | 8231838 | `:ное 8231838.` | PASS |
Sample 1 的输出 "419 multiplication of 4548" 显示数字被"拆分"了。
**可能原因**:
1. 在某个 decode stepattention 计算使用了错误的 KV
2. 模型"看到"了旧请求的部分 context
3. 导致生成逻辑出错
### 3.2 为什么第一个请求总是成功?
1. 第一个请求时,所有 buffer 都是零初始化
2. `decode_start_pos` 字典为空,正确计算
3. 没有残留数据干扰
### 3.3 为什么后续请求可能成功?
某些请求可能成功因为:
1. `id(seq)` 没有与之前的请求冲突
2. `pos_in_block` 不重叠,没读到旧数据
3. 或者旧数据恰好对结果影响不大
---
## 4. 修复方向
### 4.1 必须修复: deallocate 时清理状态
```python
# hybrid_manager.py: deallocate()
def deallocate(self, seq: Sequence) -> None:
# ... 现有逻辑 ...
# 添加: 清理 decode tracking
self.clear_decode_tracking(seq)
# 添加: 通知 offload engine 清理
if self.offload_engine is not None:
self.offload_engine.on_sequence_finished()
```
### 4.2 必须修复: OffloadEngine 添加清理方法
```python
# offload_engine.py
def on_sequence_finished(self):
"""请求完成时的清理"""
# 清零 decode buffer
self.decode_k_buffer.zero_()
self.decode_v_buffer.zero_()
```
### 4.3 可选: 更激进的清理
```python
def reset_all(self):
"""完全重置状态"""
self.decode_k_buffer.zero_()
self.decode_v_buffer.zero_()
self.layer_k_cache.zero_()
self.layer_v_cache.zero_()
# 重置 CUDA events
for event in self.buffer_compute_done_events:
event.record()
```
---
## 5. 待验证假设
| 假设 | 验证方法 | 优先级 |
|------|----------|--------|
| decode_buffer 残留导致污染 | 在第二个请求开始时检查 buffer 是否为零 | 高 |
| _decode_start_pos 字典残留 | 打印 deallocate 前后的字典内容 | 高 |
| id(seq) 重用导致错误 | 打印每个请求的 seq id | 中 |
| ring buffer 残留 | 检查每次 decode 前 ring buffer 内容 | 低 |
---
## 6. 参考代码位置
| 功能 | 文件 | 行号 |
|------|------|------|
| OffloadEngine 初始化 | offload_engine.py | 40-145 |
| deallocate | hybrid_manager.py | 218-244 |
| clear_decode_tracking | hybrid_manager.py | 538-549 |
| get_decode_start_pos | hybrid_manager.py | 485-505 |
| run_layerwise_offload_decode | model_runner.py | 867-1057 |
| decode buffer 写入 | model_runner.py | 1010-1013 |
| decode buffer 读取 | model_runner.py | 969-976 |

View File

@@ -10,6 +10,7 @@ class SparsePolicyType(Enum):
FULL = auto() # No sparse attention (load all blocks) FULL = auto() # No sparse attention (load all blocks)
QUEST = auto() # Query-aware Top-K block selection (decode only) QUEST = auto() # Query-aware Top-K block selection (decode only)
MINFERENCE = auto() # MInference vertical + slash sparse prefill (GPU-only) MINFERENCE = auto() # MInference vertical + slash sparse prefill (GPU-only)
XATTN = auto() # XAttention chunked estimation + block-sparse attention
@dataclass @dataclass
@@ -53,6 +54,16 @@ class Config:
minference_num_sink_tokens: int = 30 # Sink tokens to always keep minference_num_sink_tokens: int = 30 # Sink tokens to always keep
minference_num_recent_diags: int = 100 # Recent diagonals to always keep minference_num_recent_diags: int = 100 # Recent diagonals to always keep
# XAttention configuration (used when sparse_policy == XATTN)
xattn_stride: int = 8 # Stride for reorganizing Q/K
xattn_threshold: float = 0.9 # Block selection threshold (0-1)
xattn_chunk_size: int = 16384 # Chunk size for estimation (auto if None)
xattn_use_triton: bool = True # Use Triton kernels (requires SM 80+)
xattn_keep_sink: bool = False # Always keep first block (sink tokens)
xattn_keep_recent: bool = False # Always keep recent diagonal blocks
xattn_norm: float = 1.0 # Normalization factor for attention scores
xattn_use_bsa: bool = True # Use Block Sparse Attention library (requires installation)
def __post_init__(self): def __post_init__(self):
assert os.path.isdir(self.model) assert os.path.isdir(self.model)
assert self.kvcache_block_size % 256 == 0 assert self.kvcache_block_size % 256 == 0

View File

@@ -57,8 +57,8 @@ class ModelRunner:
load_model(self.model, config.model) load_model(self.model, config.model)
self.sampler = GreedySampler() self.sampler = GreedySampler()
# Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache) # Initialize attention_policy before warmup (will be configured in allocate_kv_cache)
self.sparse_prefill_policy = None self.attention_policy = None
#> Disable warmup for debugging #> Disable warmup for debugging
self.warmup_model() self.warmup_model()
@@ -178,23 +178,35 @@ class ModelRunner:
# Create KV cache manager using factory # Create KV cache manager using factory
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config) self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
# Create sparse prefill policy for GPU-only path # Create attention policy (always, including FULL)
# This is separate from CPU offload sparse policy (which uses select_blocks) # In layerwise offload mode, all attention goes through the policy
self.sparse_prefill_policy = None from nanovllm.kvcache.sparse import create_attention_policy
if not config.enable_cpu_offload and config.sparse_policy != SparsePolicyType.FULL:
from nanovllm.kvcache.sparse import create_sparse_policy # Get policy-specific parameters based on type
policy = create_sparse_policy( if config.sparse_policy == SparsePolicyType.XATTN:
config.sparse_policy, policy_kwargs = {
vertical_size=config.minference_vertical_size, "stride": config.xattn_stride,
slash_size=config.minference_slash_size, "threshold": config.xattn_threshold,
adaptive_budget=config.minference_adaptive_budget, "chunk_size": config.xattn_chunk_size,
num_sink_tokens=config.minference_num_sink_tokens, "use_triton": config.xattn_use_triton,
num_recent_diags=config.minference_num_recent_diags, "keep_sink": config.xattn_keep_sink,
) "keep_recent": config.xattn_keep_recent,
# Only use if policy supports sparse prefill "norm": config.xattn_norm,
if policy.supports_prefill: "use_bsa": config.xattn_use_bsa,
self.sparse_prefill_policy = policy }
logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}") elif config.sparse_policy == SparsePolicyType.MINFERENCE:
policy_kwargs = {
"vertical_size": config.minference_vertical_size,
"slash_size": config.minference_slash_size,
"adaptive_budget": config.minference_adaptive_budget,
"num_sink_tokens": config.minference_num_sink_tokens,
"num_recent_diags": config.minference_num_recent_diags,
}
else: # FULL or QUEST
policy_kwargs = {}
self.attention_policy = create_attention_policy(config.sparse_policy, **policy_kwargs)
logger.info(f"Attention policy: {self.attention_policy}")
# Allocate cache through manager # Allocate cache through manager
self.kvcache_manager.allocate_cache( self.kvcache_manager.allocate_cache(
@@ -380,7 +392,7 @@ class ModelRunner:
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
slot_mapping, None, block_tables, slot_mapping, None, block_tables,
sparse_prefill_policy=self.sparse_prefill_policy) attention_policy=self.attention_policy)
return input_ids, positions return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]): def prepare_decode(self, seqs: list[Sequence]):
@@ -577,20 +589,10 @@ class ModelRunner:
# RoPE # RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k) q, k = layer.self_attn.rotary_emb(positions, q, k)
# Sparse or Full attention (uses k, v directly - before store!) # Compute attention using policy (uses k, v directly - before store!)
if self.sparse_prefill_policy is not None: attn_output = self.attention_policy.compute_prefill(
attn_output = self.sparse_prefill_policy.sparse_prefill_attention( q, k, v, layer_id,
q, k, v, layer_id
)
else:
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale, softmax_scale=layer.self_attn.attn.scale,
causal=True,
) )
# O projection # O projection
@@ -786,15 +788,56 @@ class ModelRunner:
for layer_id in range(num_layers): for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id] layer = self.model.model.layers[layer_id]
# 2a. Input LayerNorm # 2a. Input LayerNorm (chunked for long sequences)
# LayerNorm creates float32 temporaries: seq_len * hidden_size * 4 bytes
# For 64k: 65536 * 4096 * 4 = ~1 GB per operation
# Using chunk_size=4096 reduces peak to ~125 MB
layernorm_chunk_size = 128
if total_tokens > layernorm_chunk_size:
if residual is None:
# Chunked input_layernorm
hs_chunks = hidden_states.split(layernorm_chunk_size, dim=0)
ln_chunks = []
res_chunks = []
for chunk in hs_chunks:
ln, res = layer.input_layernorm(chunk), chunk
ln_chunks.append(ln)
res_chunks.append(res)
hidden_ln = torch.cat(ln_chunks, dim=0)
residual = torch.cat(res_chunks, dim=0)
else:
# Chunked input_layernorm with residual
hs_chunks = hidden_states.split(layernorm_chunk_size, dim=0)
res_chunks_in = residual.split(layernorm_chunk_size, dim=0)
ln_chunks = []
res_chunks_out = []
for hs_chunk, res_chunk in zip(hs_chunks, res_chunks_in):
ln, res = layer.input_layernorm(hs_chunk, res_chunk)
ln_chunks.append(ln)
res_chunks_out.append(res)
hidden_ln = torch.cat(ln_chunks, dim=0)
residual = torch.cat(res_chunks_out, dim=0)
else:
if residual is None: if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else: else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual) hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
# 2b. Self-attention (full sequence) # 2b. Self-attention (full sequence)
# QKV projection # Chunked QKV projection to reduce activation memory for long sequences
# QKV activation = seq_len * (q_size + 2*kv_size) * 2 bytes
# For 64k: 65536 * (4096 + 2*1024) * 2 = ~805 MB
# Using chunk_size=2048 reduces peak to ~25 MB
qkv_chunk_size = 128
if total_tokens > qkv_chunk_size:
chunks = hidden_ln.split(qkv_chunk_size, dim=0)
qkv_chunks = []
for chunk in chunks:
qkv_chunks.append(layer.self_attn.qkv_proj(chunk))
qkv = torch.cat(qkv_chunks, dim=0)
else:
qkv = layer.self_attn.qkv_proj(hidden_ln) qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k, v = qkv.split([ q, k, v = qkv.split([
layer.self_attn.q_size, layer.self_attn.q_size,
layer.self_attn.kv_size, layer.self_attn.kv_size,
@@ -816,30 +859,49 @@ class ModelRunner:
# RoPE # RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k) q, k = layer.self_attn.rotary_emb(positions, q, k)
# Sparse or Full attention # Compute attention using policy
if self.sparse_prefill_policy is not None: attn_output = self.attention_policy.compute_prefill(
# MInference or other sparse prefill policy q, k, v, layer_id,
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, layer_id
)
else:
# Full attention using FlashAttention
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale, softmax_scale=layer.self_attn.attn.scale,
causal=True,
) )
# O projection # O projection
attn_output = attn_output.view(total_tokens, -1) attn_output = attn_output.view(total_tokens, -1)
hidden_states = layer.self_attn.o_proj(attn_output) hidden_states = layer.self_attn.o_proj(attn_output)
# 2c. Post-attention LayerNorm + MLP # 2c. Post-attention LayerNorm (chunked for long sequences)
layernorm_chunk_size = 128
if total_tokens > layernorm_chunk_size:
# Chunked post_attention_layernorm
hs_chunks = hidden_states.split(layernorm_chunk_size, dim=0)
res_chunks_in = residual.split(layernorm_chunk_size, dim=0)
ln_chunks = []
res_chunks_out = []
for hs_chunk, res_chunk in zip(hs_chunks, res_chunks_in):
ln, res = layer.post_attention_layernorm(hs_chunk, res_chunk)
ln_chunks.append(ln)
res_chunks_out.append(res)
hidden_states = torch.cat(ln_chunks, dim=0)
residual = torch.cat(res_chunks_out, dim=0)
else:
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual) hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
# Chunked MLP processing to reduce activation memory for long sequences
# MLP activation = seq_len * intermediate_size * 2 bytes
# For 64k: 65536 * 14336 * 2 = ~1.75 GB (down_proj input)
# Using chunk_size=2048 reduces peak to ~55 MB
mlp_chunk_size = 128
if total_tokens > mlp_chunk_size:
chunks = hidden_states.split(mlp_chunk_size, dim=0)
outputs = []
for i, chunk in enumerate(chunks):
outputs.append(layer.mlp(chunk))
del chunk
torch.cuda.empty_cache() # Clean after every chunk
hidden_states = torch.cat(outputs, dim=0)
del outputs
torch.cuda.empty_cache()
else:
hidden_states = layer.mlp(hidden_states) hidden_states = layer.mlp(hidden_states)
# 2d. Offload KV to CPU (encapsulated with sparse policy hooks) # 2d. Offload KV to CPU (encapsulated with sparse policy hooks)

View File

@@ -1,48 +1,56 @@
""" """
Sparse Attention Policy module. Attention Policy module for layerwise offload mode.
Provides pluggable policies for selecting which KV blocks to load Provides pluggable policies for attention computation:
during chunked attention with CPU offload. - FullAttentionPolicy: Standard FlashAttention (no sparsity)
- XAttentionPolicy: Sparse prefill using XAttention algorithm
- MInferencePolicy: MInference sparse attention
- QuestPolicy: Quest block selection (for chunked offload)
Usage: Usage:
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType from nanovllm.kvcache.sparse import create_attention_policy, SparsePolicyType
# Create policy using factory function # Create policy using factory function
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8) policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
# Use policy for attention
attn_output = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
# Or create custom policy # Or create custom policy
class MyPolicy(SparsePolicy): class MyPolicy(AttentionPolicy):
supports_prefill = True supports_prefill = True
supports_decode = True supports_decode = True
def select_blocks(self, available_blocks, ctx): def compute_prefill(self, q, k, v, layer_id, softmax_scale):
return available_blocks[:5] # Just first 5 blocks # Custom attention computation
...
""" """
from nanovllm.config import SparsePolicyType from nanovllm.config import SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.minference import MInferencePolicy from nanovllm.kvcache.sparse.minference import MInferencePolicy
from nanovllm.kvcache.sparse.xattn import XAttentionPolicy
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy: def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> AttentionPolicy:
""" """
Create a sparse policy instance from an enum type. Create an attention policy instance from an enum type.
The returned policy is not yet initialized. Call policy.initialize() All attention (including full attention) goes through a policy in layerwise
or let the framework call it during KV cache allocation. offload mode. The policy is responsible for computing prefill/decode attention.
Args: Args:
policy_type: SparsePolicyType enum value policy_type: SparsePolicyType enum value (FULL, XATTN, MINFERENCE, QUEST)
**kwargs: Policy-specific configuration options **kwargs: Policy-specific configuration options
Returns: Returns:
SparsePolicy instance (not initialized) AttentionPolicy instance
Example: Example:
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4) policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
policy.initialize(num_layers=28, num_kv_heads=8, ...) attn_out = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
""" """
if policy_type == SparsePolicyType.FULL: if policy_type == SparsePolicyType.FULL:
return FullAttentionPolicy() return FullAttentionPolicy()
@@ -65,18 +73,41 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
num_recent_diags=kwargs.get("num_recent_diags", 100), num_recent_diags=kwargs.get("num_recent_diags", 100),
) )
elif policy_type == SparsePolicyType.XATTN:
return XAttentionPolicy(
stride=kwargs.get("stride", 8),
threshold=kwargs.get("threshold", 0.9),
chunk_size=kwargs.get("chunk_size", 16384),
use_triton=kwargs.get("use_triton", True),
keep_sink=kwargs.get("keep_sink", False),
keep_recent=kwargs.get("keep_recent", False),
norm=kwargs.get("norm", 1.0),
use_bsa=kwargs.get("use_bsa", True),
)
else: else:
raise ValueError(f"Unknown policy type: {policy_type}") raise ValueError(f"Unknown policy type: {policy_type}")
# Backward compatibility alias
create_sparse_policy = create_attention_policy
__all__ = [ __all__ = [
# New interface
"AttentionPolicy",
"create_attention_policy",
# Backward compatibility
"SparsePolicy", "SparsePolicy",
"create_sparse_policy",
# Common types
"PolicyContext", "PolicyContext",
"SparsePolicyType", "SparsePolicyType",
# Policy implementations
"FullAttentionPolicy", "FullAttentionPolicy",
"QuestPolicy", "QuestPolicy",
"QuestConfig", "QuestConfig",
"BlockMetadataManager", "BlockMetadataManager",
"MInferencePolicy", "MInferencePolicy",
"create_sparse_policy", "XAttentionPolicy",
] ]

View File

@@ -1,20 +1,21 @@
""" """
Full attention policy - loads all blocks (no sparsity). Full attention policy - standard FlashAttention without sparsity.
This serves as a baseline and default policy when sparse This serves as a baseline and default policy when sparse
attention is not needed. attention is not needed.
""" """
from typing import List from typing import Optional
from .policy import SparsePolicy, PolicyContext import torch
from .policy import AttentionPolicy
class FullAttentionPolicy(SparsePolicy): class FullAttentionPolicy(AttentionPolicy):
""" """
Full attention policy that loads all available blocks. Full attention policy using FlashAttention (no sparsity).
This is the default behavior with no sparsity - all previous This is the default behavior with standard causal attention.
KV cache blocks are loaded for each query chunk. All tokens attend to all previous tokens.
Use this as: Use this as:
- A baseline for comparing sparse policies - A baseline for comparing sparse policies
@@ -25,15 +26,55 @@ class FullAttentionPolicy(SparsePolicy):
# Full attention supports both prefill and decode # Full attention supports both prefill and decode
supports_prefill = True supports_prefill = True
supports_decode = True supports_decode = True
requires_block_selection = False # Load all blocks, no selective loading
def select_blocks( def estimate(
self, self,
available_blocks: List[int], q: torch.Tensor,
ctx: PolicyContext, k: torch.Tensor,
) -> List[int]: layer_id: int,
"""Return all blocks - no sparsity.""" ) -> Optional[torch.Tensor]:
return available_blocks """
Full attention - no sparse mask needed.
Returns None to indicate full attention should be used.
"""
return None
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute full causal attention using FlashAttention.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
from flash_attn.flash_attn_interface import flash_attn_varlen_func
seq_len = q.shape[0]
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=softmax_scale,
causal=True,
)
def __repr__(self) -> str: def __repr__(self) -> str:
return "FullAttentionPolicy()" return "FullAttentionPolicy()"

View File

@@ -0,0 +1,320 @@
"""
Triton kernels for XAttention sparse attention.
Copied and adapted from COMPASS/compass/src/kernels.py
for XAttention integration in nano-vllm.
Requirements:
- Triton >= 2.1.0
- CUDA compute capability SM 80+ (RTX 3090, A100, H100, etc.)
"""
import torch
import math
import triton
import triton.language as tl
@triton.jit
def softmax_fuse_block_sum_kernel_causal(
In,
Out,
scale,
input_stride_0,
input_stride_1,
input_stride_2,
output_stride_0,
output_stride_1,
output_stride_2,
real_q_len,
k_len,
chunk_start,
chunk_end,
segment_size: tl.constexpr,
block_size: tl.constexpr,
):
block_id = tl.program_id(0)
head_id = tl.program_id(1)
batch_id = tl.program_id(2)
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
offs_k = tl.arange(0, segment_size)
num_iters = k_len // segment_size
num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1) // segment_size
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
for iter in range(0, num_iters_before_causal):
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
m_local = tl.max(X, 1)
m_new = tl.maximum(m_i, m_local)
alpha = tl.math.exp2(m_i - m_new)
X = X - m_new[:, None]
l_local = tl.sum(tl.math.exp2(X), 1)
l_i = l_i * alpha + l_local
m_i = m_new
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size)
X = tl.where(mask, X, -1.0e6)
m_local = tl.max(X, 1)
m_new = tl.maximum(m_i, m_local)
alpha = tl.math.exp2(m_i - m_new)
X = X - m_new[:, None]
l_local = tl.sum(tl.math.exp2(X), 1)
l_i = l_i * alpha + l_local
m_i = m_new
l_i_inv = 1.0 / l_i
sum_mask = offs_q[:, None] < real_q_len
for iter in range(0, num_iters_before_causal):
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
X = tl.where(sum_mask, X, 0)
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
X = tl.sum(X, 2)
X = tl.sum(X, 0)
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size)
X = tl.where(mask, X, -1.0e6)
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
X = tl.where(sum_mask, X, 0)
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
X = tl.sum(X, 2)
X = tl.sum(X, 0)
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
for iter in range(num_iters_before_causal + 1, num_iters):
X = tl.zeros([segment_size // block_size], dtype=tl.float32)
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
@triton.jit
def softmax_fuse_block_sum_kernel_non_causal(
In,
Out,
scale,
input_stride_0,
input_stride_1,
input_stride_2,
output_stride_0,
output_stride_1,
output_stride_2,
real_q_len,
k_len,
chunk_start,
chunk_end,
segment_size: tl.constexpr,
block_size: tl.constexpr,
):
block_id = tl.program_id(0)
head_id = tl.program_id(1)
batch_id = tl.program_id(2)
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
offs_k = tl.arange(0, segment_size)
num_iters = k_len // segment_size
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
for iter in range(0, num_iters):
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
m_local = tl.max(X, 1)
m_new = tl.maximum(m_i, m_local)
alpha = tl.math.exp2(m_i - m_new)
X = X - m_new[:, None]
l_local = tl.sum(tl.math.exp2(X), 1)
l_i = l_i * alpha + l_local
m_i = m_new
l_i_inv = 1.0 / l_i
sum_mask = offs_q[:, None] < real_q_len
for iter in range(0, num_iters):
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
X = tl.where(sum_mask, X, 0)
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
X = tl.sum(X, 2)
X = tl.sum(X, 0)
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
@triton.jit
def flat_group_gemm_fuse_reshape_kernel(Q, K, Out,
stride_qz, stride_qh, stride_qn,
stride_kz, stride_kh, stride_kn,
stride_oz, stride_oh, stride_on,
chunk_start, chunk_end,
H: tl.constexpr,
STRIDE: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
is_causal: tl.constexpr,
):
block_m = tl.program_id(0).to(tl.int64)
block_n = tl.program_id(1).to(tl.int64)
batch_id = tl.program_id(2).to(tl.int64) // H
head_id = tl.program_id(2).to(tl.int64) % H
if is_causal:
if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N:
return
Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn
K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn
Q_ptrs = Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + tl.arange(0, HEAD_DIM)[None, :] + stride_qn * (STRIDE - 1)
K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + tl.arange(0, HEAD_DIM)[:, None]
o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
for iter in range(STRIDE):
q = tl.load(Q_ptrs - iter * stride_qn)
k = tl.load(K_ptrs + iter * stride_kn)
o += tl.dot(q, k)
O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N
O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :]
tl.store(O_ptrs, o.to(Out.type.element_ty))
def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size, segment_size, chunk_start, chunk_end, real_q_len, scale, is_causal=True):
"""Wrapper for Triton softmax-fuse-block-sum kernel."""
batch_size, num_heads, q_len, k_len = attn_weights_slice.shape
assert q_len % reshaped_block_size == 0
assert k_len % segment_size == 0
assert segment_size % reshaped_block_size == 0
assert attn_weights_slice.stride(-1) == 1
output = torch.empty(
(batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size),
dtype=attn_weights_slice.dtype,
device=attn_weights_slice.device
)
grid = (q_len // reshaped_block_size, num_heads, batch_size)
if is_causal:
softmax_fuse_block_sum_kernel_causal[grid](
attn_weights_slice,
output,
scale,
attn_weights_slice.stride(0),
attn_weights_slice.stride(1),
attn_weights_slice.stride(2),
output.stride(0),
output.stride(1),
output.stride(2),
real_q_len,
k_len,
chunk_start,
chunk_end,
segment_size,
reshaped_block_size,
)
else:
softmax_fuse_block_sum_kernel_non_causal[grid](
attn_weights_slice,
output,
scale,
attn_weights_slice.stride(0),
attn_weights_slice.stride(1),
attn_weights_slice.stride(2),
output.stride(0),
output.stride(1),
output.stride(2),
real_q_len,
k_len,
chunk_start,
chunk_end,
segment_size,
reshaped_block_size,
)
return output
def flat_group_gemm_fuse_reshape(query_states, key_states, stride, chunk_start, chunk_end, is_causal=True):
"""Wrapper for Triton flat-group-gemm-fuse-reshape kernel."""
batch_size, num_heads, q_len, head_dim = query_states.shape
kv_len = key_states.shape[2]
assert key_states.shape[0] == batch_size
assert key_states.shape[1] == num_heads
assert key_states.shape[3] == head_dim
output = torch.empty(
(batch_size, num_heads, q_len // stride, kv_len // stride),
dtype=query_states.dtype,
device=query_states.device
)
# Adjust block size based on GPU shared memory
props = torch.cuda.get_device_properties(torch.cuda.current_device())
if props.total_memory < 30 * 1024**3: # Less than 30GB (e.g., RTX 3090 24GB)
BLOCK_M = 64
BLOCK_N = 64
else:
BLOCK_M = 128
BLOCK_N = 128
assert q_len % (stride * BLOCK_M) == 0
assert kv_len % (stride * BLOCK_N) == 0
grid = (q_len // stride // BLOCK_M, kv_len // stride // BLOCK_N, batch_size * num_heads)
flat_group_gemm_fuse_reshape_kernel[grid](
query_states,
key_states,
output,
query_states.stride(0),
query_states.stride(1),
query_states.stride(2),
key_states.stride(0),
key_states.stride(1),
key_states.stride(2),
output.stride(0),
output.stride(1),
output.stride(2),
chunk_start,
chunk_end,
num_heads,
stride,
head_dim,
BLOCK_M,
BLOCK_N,
is_causal,
)
return output

View File

@@ -10,10 +10,10 @@ from typing import List, Tuple, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext from nanovllm.kvcache.sparse.policy import AttentionPolicy, PolicyContext
class MInferencePolicy(SparsePolicy): class MInferencePolicy(AttentionPolicy):
""" """
MInference sparse prefill policy using vertical + slash pattern. MInference sparse prefill policy using vertical + slash pattern.
@@ -347,6 +347,33 @@ class MInferencePolicy(SparsePolicy):
return o return o
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute MInference sparse prefill attention.
This is the new unified interface for attention policies.
Delegates to sparse_prefill_attention (ignores softmax_scale as MInference
computes it internally from head_dim).
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor (unused, computed internally)
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
return self.sparse_prefill_attention(q, k, v, layer_id)
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"MInferencePolicy(" return (f"MInferencePolicy("
f"adaptive_budget={self.adaptive_budget}, " f"adaptive_budget={self.adaptive_budget}, "

View File

@@ -1,13 +1,18 @@
""" """
Base class for sparse attention policies. Base class for attention policies in layerwise offload mode.
Sparse attention policies determine which KV cache blocks to load AttentionPolicy defines the interface for all attention computation,
from CPU for each query chunk during chunked attention computation. including full attention and sparse attention methods like XAttention.
Key methods:
- estimate(): Compute sparse attention mask (optional, returns None for full attention)
- compute_prefill(): Compute prefill attention
- compute_decode(): Compute decode attention (default implementation provided)
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Any from typing import List, Optional, Tuple
import torch import torch
# Import SparsePolicyType from config to avoid circular imports # Import SparsePolicyType from config to avoid circular imports
@@ -17,10 +22,10 @@ from nanovllm.config import SparsePolicyType
@dataclass @dataclass
class PolicyContext: class PolicyContext:
""" """
Context passed to sparse policy for block selection. Context passed to attention policy for block selection.
This dataclass contains all information needed by a sparse policy This dataclass contains all information needed by an attention policy
to decide which blocks to load for the current query chunk. for sparse estimation and attention computation.
""" """
query_chunk_idx: int query_chunk_idx: int
@@ -49,40 +54,41 @@ class PolicyContext:
"""Total KV sequence length so far (for reference).""" """Total KV sequence length so far (for reference)."""
class SparsePolicy(ABC): class AttentionPolicy(ABC):
""" """
Abstract base class for sparse attention policies. Base class for attention policies in layerwise offload mode.
Subclass this and implement select_blocks() to create custom All attention computation goes through a policy, including both
sparse attention patterns. The policy receives context about full attention and sparse attention methods.
the current query chunk and returns which KV blocks to load.
The policy interface is designed for layerwise offload where:
- The entire KV cache for a layer is on GPU during computation
- No need for block loading from CPU during attention
- estimate() returns a sparse mask (or None for full attention)
- compute_prefill()/compute_decode() perform the actual attention
Attributes: Attributes:
supports_prefill: Whether this policy can be used for prefill phase. supports_prefill: Whether this policy can be used for prefill phase.
supports_decode: Whether this policy can be used for decode phase. supports_decode: Whether this policy can be used for decode phase.
Example: Example:
class MySparsePolicy(SparsePolicy): class MyPolicy(AttentionPolicy):
supports_prefill = False # decode-only policy supports_prefill = True
supports_decode = True supports_decode = True
def select_blocks(self, available_blocks, ctx): def estimate(self, q, k, layer_id):
# Load first block and last 2 blocks # Return sparse mask or None
if len(available_blocks) <= 3: return None
return available_blocks
return [available_blocks[0]] + available_blocks[-2:] def compute_prefill(self, q, k, v, layer_id, softmax_scale):
# Compute attention
return flash_attn_varlen_func(q, k, v, ...)
""" """
# Compatibility flags - override in subclasses # Compatibility flags - override in subclasses
supports_prefill: bool = True supports_prefill: bool = True
supports_decode: bool = True supports_decode: bool = True
# Whether this policy requires selective block loading during decode
# If True: OffloadEngine will call select_blocks() before loading KV from CPU
# If False: OffloadEngine will load all blocks (select_blocks ignored for load)
# Example: MInference=False (only affects attention), Quest=True (affects load)
requires_block_selection: bool = False
def initialize( def initialize(
self, self,
num_layers: int, num_layers: int,
@@ -96,7 +102,7 @@ class SparsePolicy(ABC):
Initialize policy resources. Initialize policy resources.
Called by the framework after KV cache is allocated. Override this Called by the framework after KV cache is allocated. Override this
to create metadata structures (e.g., BlockMetadataManager for Quest). to create metadata structures or pre-allocate buffers.
Default implementation does nothing. Default implementation does nothing.
Args: Args:
@@ -109,76 +115,98 @@ class SparsePolicy(ABC):
""" """
pass pass
@abstractmethod def estimate(
def select_blocks(
self, self,
available_blocks: List[int], q: torch.Tensor,
ctx: PolicyContext, k: torch.Tensor,
) -> List[int]: layer_id: int,
) -> Optional[torch.Tensor]:
""" """
Select which KV blocks to load for the current query chunk. Estimate sparse attention mask.
This is the core method that defines the sparse attention pattern. For sparse policies (e.g., XAttention), computes block-level importance
The returned blocks will be loaded from CPU to GPU for attention and returns a boolean mask indicating which blocks to attend.
computation against the current query chunk. For full attention policy, returns None.
This corresponds to xattn_estimate() in COMPASS.
Args: Args:
available_blocks: List of CPU block IDs that contain KV cache q: Query tensor [seq_len, num_heads, head_dim]
from previous chunks. These are ordered by k: Key tensor [seq_len, num_kv_heads, head_dim]
their position in the sequence. layer_id: Transformer layer index
ctx: PolicyContext with information about the current query
chunk, layer, phase (prefill/decode), etc.
Returns: Returns:
List of block IDs to load (must be a subset of available_blocks). sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
The order may affect performance (sequential access is faster). or None for full attention
Returning [] means no previous blocks will be loaded.
""" """
pass return None
def on_prefill_offload( @abstractmethod
def compute_prefill(
self, self,
cpu_block_id: int, q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int, layer_id: int,
k_cache: torch.Tensor, softmax_scale: float,
num_valid_tokens: int, ) -> torch.Tensor:
) -> None:
""" """
Hook called when a block is offloaded during prefill phase. Compute prefill attention.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU. The entire KV cache for this layer is on GPU. Compute attention
Override this to collect metadata about blocks (e.g., min/max keys between Q and K/V, optionally using sparse mask from estimate().
for Quest-style selection). Default implementation does nothing.
Args: Args:
cpu_block_id: The CPU block ID that will be written q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU) softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
num_valid_tokens: Number of valid tokens in this block
Returns:
Attention output [seq_len, num_heads, head_dim]
""" """
pass pass
def on_decode_offload( def compute_decode(
self, self,
cpu_block_id: int, q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int, layer_id: int,
k_cache: torch.Tensor, softmax_scale: float,
num_valid_tokens: int, ) -> torch.Tensor:
) -> None:
""" """
Hook called when a block is offloaded during decode phase. Compute decode attention.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU. KV is provided from ring buffer, containing prefill tokens + decoded tokens.
Override this to update metadata about blocks. Default implementation Default implementation uses FlashAttention.
does nothing.
Args: Args:
cpu_block_id: The CPU block ID that will be written q: Query tensor [1, num_heads, head_dim]
k: Key tensor [context_len+1, num_kv_heads, head_dim]
v: Value tensor [context_len+1, num_kv_heads, head_dim]
layer_id: Transformer layer index layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU) softmax_scale: Softmax scaling factor
num_valid_tokens: Number of valid tokens in this block
Returns:
Attention output [1, num_heads, head_dim]
""" """
pass from flash_attn.flash_attn_interface import flash_attn_varlen_func
context_len = k.shape[0]
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=context_len,
softmax_scale=softmax_scale,
causal=False,
)
def reset(self) -> None: def reset(self) -> None:
""" """
@@ -189,32 +217,9 @@ class SparsePolicy(ABC):
""" """
pass pass
def sparse_prefill_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
) -> torch.Tensor:
"""
Compute sparse attention for prefill phase.
This method is called when supports_prefill=True and the policy
is used for GPU-only sparse prefill (no CPU offload).
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Current transformer layer index
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
"Set supports_prefill=False or implement this method."
)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}()" return f"{self.__class__.__name__}()"
# Backward compatibility alias
SparsePolicy = AttentionPolicy

View File

@@ -11,7 +11,7 @@ import logging
import torch import torch
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
from .policy import SparsePolicy, PolicyContext from .policy import AttentionPolicy, PolicyContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -137,7 +137,7 @@ class QuestConfig:
"""Always include this many recent blocks (last N blocks), in addition to Top-K.""" """Always include this many recent blocks (last N blocks), in addition to Top-K."""
class QuestPolicy(SparsePolicy): class QuestPolicy(AttentionPolicy):
""" """
Quest-style Top-K block selection using min/max key bounds. Quest-style Top-K block selection using min/max key bounds.
@@ -317,6 +317,25 @@ class QuestPolicy(SparsePolicy):
if self.metadata is not None: if self.metadata is not None:
self.metadata.reset() self.metadata.reset()
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Quest does not support prefill - raises error.
Quest is a decode-only policy for selective block loading.
For prefill, use FullAttentionPolicy or XAttentionPolicy.
"""
raise NotImplementedError(
"QuestPolicy does not support prefill. "
"Use FullAttentionPolicy or XAttentionPolicy for prefill."
)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"QuestPolicy(topk={self.config.topk_blocks}, " f"QuestPolicy(topk={self.config.topk_blocks}, "

View File

@@ -0,0 +1,156 @@
"""
Utility functions for sparse attention policies.
Copied from COMPASS/compass/src/utils.py for XAttention integration.
"""
import torch
def find_blocks_chunked(
input_tensor, current_index, threshold, num_to_choose, decoding: bool, mode: str = "both", causal=True
):
"""
Finds and selects relevant blocks of attention for transformer-based models based on a
threshold or a predefined number of blocks.
Parameters:
- input_tensor (torch.Tensor): The input tensor of shape (batch_size, head_num, chunk_num, block_num).
- current_index (int): The current index in the sequence processing.
- threshold (float or None): A threshold value used to determine the minimum attention weight sum.
- num_to_choose (int or None): The number of blocks to be selected, ensuring sufficient information retrieval.
- decoding (bool): If True, operates in decoding mode; otherwise, it's in encoding mode.
- mode (str): Defines the processing mode, either 'both', 'prefill', or 'decode'.
- causal (bool): If True, applies causal masking to prevent future information leakage.
Returns:
- torch.Tensor: A boolean mask of shape (batch_size, head_num, chunk_num, block_num),
indicating which blocks should be attended to.
"""
assert threshold is None or num_to_choose is None
batch_size, head_num, chunk_num, block_num = input_tensor.shape
if mode == "prefill" and decoding:
return torch.ones_like(input_tensor, dtype=torch.bool)
if mode == "decode" and not decoding:
mask = torch.ones_like(input_tensor, dtype=torch.bool)
if causal:
mask[:, :, :, current_index : current_index + chunk_num] = torch.tril(
torch.ones(1, head_num, chunk_num, chunk_num, device=input_tensor.device)
)
mask[:, :, current_index + chunk_num :, :] = 0
return torch.cat(
[
torch.ones_like(input_tensor, dtype=torch.bool)[:, :, 0 : current_index + 1],
torch.zeros_like(input_tensor, dtype=torch.bool)[:, :, current_index + 1 :],
],
dim=-1,
)
else:
return mask
input_tensor = input_tensor.to(float)
if threshold is not None:
total_sum = input_tensor.sum(dim=-1, keepdim=True)
if isinstance(threshold, torch.Tensor):
threshold = threshold.to(float)
required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze(
-1
).expand((batch_size, head_num, chunk_num, 1)).to(input_tensor.device)
else:
required_sum = total_sum * threshold
if causal:
mask = torch.zeros_like(input_tensor, dtype=torch.bool)
mask[:, :, :, 0] = 1
mask[:, :, :, current_index : current_index + chunk_num] = (
torch.eye(chunk_num, device=mask.device)
.unsqueeze(0)
.unsqueeze(0)
.expand(1, head_num, chunk_num, chunk_num)
)
other_values = input_tensor.masked_fill(mask, 0)
sorted_values, _ = torch.sort(
other_values, dim=-1, descending=True
)
sorted_values = sorted_values.to(input_tensor.device)
sorted_values = torch.cat(
[
torch.zeros(
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
),
torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True),
sorted_values[:, :, :, :-2],
],
dim=-1,
)
_, index = torch.sort(
torch.where(mask, 100000 * (1 + input_tensor), input_tensor),
dim=-1,
descending=True
)
cumulative_sum_without_self = torch.cat(
[
torch.zeros(
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
),
sorted_values[:, :, :, 0:-1],
],
dim=-1,
).cumsum(dim=-1)
index_mask = cumulative_sum_without_self < required_sum
index = torch.where(index_mask, index, 0)
mask = mask.view(batch_size, head_num * chunk_num, block_num)
index = index.view(batch_size, head_num * chunk_num, block_num)
mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True
mask = mask.view(batch_size, head_num, chunk_num, block_num)
else:
mask = torch.zeros_like(input_tensor, dtype=torch.bool)
sorted_values, index = torch.sort(
input_tensor, dim=-1, descending=True
)
sorted_values = sorted_values.to(input_tensor.device)
cumulative_sum_without_self = torch.cat(
[
torch.zeros(
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
),
sorted_values[:, :, :, 0:-1],
],
dim=-1,
).cumsum(dim=-1)
index_mask = cumulative_sum_without_self < required_sum
index = torch.where(index_mask, index, 0)
mask = mask.view(batch_size, head_num * chunk_num, block_num)
index = index.view(batch_size, head_num * chunk_num, block_num)
mask[
:,
torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1),
index,
] = True
mask = mask.view(batch_size, head_num, chunk_num, block_num)
else:
raise NotImplementedError("block num chunk prefill not implemented")
try:
if causal:
assert (~mask[:, :, :, current_index + chunk_num :]).all()
except:
mask[:, :, :, current_index + chunk_num :] = False
if causal:
if decoding:
assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all()
else:
lambda_mask = torch.zeros_like(input_tensor, dtype=bool, device=input_tensor.device)
lambda_mask[:, :, :, 0] = 1
lambda_mask[:, :, :, current_index:current_index+chunk_num] = torch.eye(
chunk_num, device=lambda_mask.device
).unsqueeze(0).unsqueeze(0).expand(1, head_num, chunk_num, chunk_num)
assert(torch.where(lambda_mask, mask, True).all())
return mask

View File

@@ -0,0 +1,310 @@
"""
XAttention sparse attention policy for nano-vllm.
Implements the XAttention algorithm from COMPASS, using chunked estimation
and block sparse attention for efficient long-context inference.
Architecture:
XAttention = Estimate (Triton) + Compute (BSA)
- Estimate: xattn_estimate() computes block-level importance scores
- Compute: block_sparse_attn_func() executes sparse attention
Reference: COMPASS/compass/src/Xattention.py
"""
import math
from typing import Optional
import torch
import torch.nn.functional as F
from nanovllm.kvcache.sparse.policy import AttentionPolicy
# BSA block size is fixed at 128 (hardcoded in block_sparse_attn)
BSA_BLOCK_SIZE = 128
class XAttentionPolicy(AttentionPolicy):
"""
XAttention sparse prefill policy using chunked estimation + block sparse attention.
This policy estimates sparse attention patterns by:
1. Chunked QK computation using Triton kernels (via nanovllm.ops.xattn)
2. Block-wise softmax with importance scores
3. Block selection based on threshold
4. Block sparse attention computation using MIT-HAN-LAB BSA library
The key method is estimate() which calls xattn_estimate() from nanovllm.ops
to compute the sparse attention mask.
Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
BSA library: https://github.com/mit-han-lab/Block-Sparse-Attention
"""
supports_prefill = True
supports_decode = True # Uses default FlashAttention for decode
def __init__(
self,
stride: int = 8,
threshold: float = 0.9,
block_size: int = 128,
chunk_size: int = 16384,
use_triton: bool = True,
keep_sink: bool = False,
keep_recent: bool = False,
norm: float = 1.0,
use_bsa: bool = True,
):
"""
Initialize XAttention policy.
Args:
stride: Stride for reorganizing Q/K (default: 8)
threshold: Block selection threshold, 0-1 (default: 0.9)
block_size: Block size for sparse attention (default: 128, must match BSA)
chunk_size: Chunk size for estimation (default: 16384)
use_triton: Use Triton kernels (requires SM 80+)
keep_sink: Always keep first block (sink tokens)
keep_recent: Always keep recent diagonal blocks
norm: Normalization factor for attention scores
use_bsa: Use Block Sparse Attention library (default: True)
"""
self.stride = stride
self.threshold = threshold
self.block_size = block_size
self.chunk_size = chunk_size
self.use_triton = use_triton
self.keep_sink = keep_sink
self.keep_recent = keep_recent
self.norm = norm
self.use_bsa = use_bsa
# BSA requires block_size = 128
if self.use_bsa and self.block_size != BSA_BLOCK_SIZE:
print(f"XAttention: BSA requires block_size=128, adjusting from {self.block_size}")
self.block_size = BSA_BLOCK_SIZE
# Check Triton availability
if self.use_triton:
try:
import triton
props = torch.cuda.get_device_properties(torch.cuda.current_device())
if props.major < 8:
self.use_triton = False
print(f"XAttention: Triton requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.")
except ImportError:
self.use_triton = False
print("XAttention: Triton not available. Falling back to PyTorch.")
# Check BSA availability
if self.use_bsa:
try:
from block_sparse_attn import block_sparse_attn_func
except ImportError:
self.use_bsa = False
print("XAttention: block_sparse_attn not available. Falling back to FlashAttention.")
def estimate(
self,
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
Estimate sparse attention mask using XAttention algorithm.
Calls xattn_estimate() from nanovllm.ops.xattn to compute block-level
importance scores and generate a sparse boolean mask.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
Returns:
sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
or None if estimation fails (fallback to full attention)
"""
try:
from nanovllm.ops.xattn import xattn_estimate
seq_len, num_heads, head_dim = q.shape
num_kv_heads = k.shape[1]
# Convert to [batch, heads, seq, dim] format expected by xattn_estimate
q_bhsd = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim]
k_bhsd = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, seq_len, head_dim]
# Handle GQA: expand k to match q heads for estimation
if num_kv_heads != num_heads:
# GQA: expand k by repeating
repeat_factor = num_heads // num_kv_heads
k_bhsd = k_bhsd.repeat(1, repeat_factor, 1, 1)
# Call xattn_estimate
attn_sums, sparse_mask = xattn_estimate(
q_bhsd, k_bhsd,
block_size=self.block_size,
stride=self.stride,
norm=self.norm,
threshold=self.threshold,
chunk_size=self.chunk_size,
use_triton=self.use_triton,
causal=True,
keep_sink=self.keep_sink,
keep_recent=self.keep_recent,
)
return sparse_mask
except Exception as e:
# If estimation fails, return None to use full attention
print(f"XAttention estimate failed: {e}, falling back to full attention")
return None
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute XAttention sparse prefill attention.
Flow:
1. Call estimate() to get sparse mask
2. If mask is None or BSA unavailable, use full FlashAttention
3. Otherwise, use block_sparse_attn_func with mask
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
# If BSA is disabled, use full attention directly (skip estimation)
if not self.use_bsa:
return self._full_attention(q, k, v, softmax_scale)
# Step 1: Estimate sparse mask
sparse_mask = self.estimate(q, k, layer_id)
# Step 2: Compute attention
if sparse_mask is None:
# Estimation failed, fallback to full FlashAttention
return self._full_attention(q, k, v, softmax_scale)
# Use block sparse attention with mask
return self._block_sparse_attention(q, k, v, sparse_mask, softmax_scale)
def _block_sparse_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
sparse_mask: torch.Tensor,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute block sparse attention using MIT-HAN-LAB BSA library.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
sparse_mask: Block mask [batch, num_heads, q_blocks, k_blocks]
softmax_scale: Softmax scaling factor
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
from block_sparse_attn import block_sparse_attn_func
seq_len, num_heads, head_dim = q.shape
num_kv_heads = k.shape[1]
# Handle GQA: expand K/V to match Q heads
if num_kv_heads != num_heads:
repeat_factor = num_heads // num_kv_heads
k = k.repeat_interleave(repeat_factor, dim=1)
v = v.repeat_interleave(repeat_factor, dim=1)
# Cumulative sequence lengths (batch=1)
cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
# Head mask type: 1 for all heads using block sparse
head_mask_type = torch.ones(num_heads, dtype=torch.int32, device=q.device)
# Trim sparse_mask to actual block counts
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
block_mask = sparse_mask[:, :, :q_blocks, :k_blocks].contiguous()
# Call BSA
attn_output = block_sparse_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
None, # streaming_info (left_mask)
block_mask,
seq_len, seq_len,
p_dropout=0.0,
deterministic=True,
softmax_scale=softmax_scale,
is_causal=True,
)
return attn_output
def _full_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute full causal attention using FlashAttention.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
softmax_scale: Softmax scaling factor
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
from flash_attn.flash_attn_interface import flash_attn_varlen_func
seq_len = q.shape[0]
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=softmax_scale,
causal=True,
)
def reset(self) -> None:
"""Reset policy state (no state to reset for XAttention)."""
pass
def __repr__(self) -> str:
return (f"XAttentionPolicy("
f"stride={self.stride}, "
f"threshold={self.threshold}, "
f"block_size={self.block_size}, "
f"use_triton={self.use_triton}, "
f"use_bsa={self.use_bsa})")

View File

@@ -98,10 +98,10 @@ class Attention(nn.Module):
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k, max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables) softmax_scale=self.scale, causal=True, block_table=context.block_tables)
elif context.sparse_prefill_policy is not None: elif context.attention_policy is not None:
# Sparse prefill (GPU-only) - delegate to policy # Attention via policy (GPU-only) - delegate to policy
o = context.sparse_prefill_policy.sparse_prefill_attention( o = context.attention_policy.compute_prefill(
q, k, v, self.layer_id q, k, v, self.layer_id, softmax_scale=self.scale
) )
else: else:
o = flash_attn_varlen_func(q, k, v, o = flash_attn_varlen_func(q, k, v,

View File

@@ -27,13 +27,13 @@ class RMSNorm(nn.Module):
x = x.to(orig_dtype).mul_(self.weight) x = x.to(orig_dtype).mul_(self.weight)
return x return x
@torch.compile
def add_rms_forward( def add_rms_forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch # Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
# Note: @torch.compile removed due to OOM with 64k sequences (memory fragmentation)
orig_dtype = x.dtype orig_dtype = x.dtype
x = x.float().add_(residual.float()) x = x.float().add_(residual.float())
residual = x.to(orig_dtype) residual = x.to(orig_dtype)

38
nanovllm/ops/__init__.py Normal file
View File

@@ -0,0 +1,38 @@
"""
Operators module for nano-vLLM.
This module contains low-level attention operators and kernels.
"""
from nanovllm.ops.chunked_attention import (
flash_attn_with_lse,
merge_attention_outputs,
chunked_attention_varlen,
ChunkedPrefillState,
)
from nanovllm.ops.xattn import (
xattn_estimate,
xattn_estimate_chunked,
flat_group_gemm_fuse_reshape,
softmax_fuse_block_sum,
find_blocks_chunked,
create_causal_mask,
compute_sparsity,
)
__all__ = [
# chunked_attention
"flash_attn_with_lse",
"merge_attention_outputs",
"chunked_attention_varlen",
"ChunkedPrefillState",
# xattn
"xattn_estimate",
"xattn_estimate_chunked",
"flat_group_gemm_fuse_reshape",
"softmax_fuse_block_sum",
"find_blocks_chunked",
"create_causal_mask",
"compute_sparsity",
]

View File

@@ -0,0 +1,624 @@
"""
Chunked attention implementation for CPU KV cache offloading.
This module implements flash attention with LSE (log-sum-exp) output,
enabling proper online softmax merging for chunked prefill.
Key functions:
- flash_attn_with_lse: Flash attention that returns output and LSE
- merge_attention_outputs: Merge outputs from multiple KV chunks
- chunked_prefill_attention: High-level interface for chunked attention
"""
import math
import torch
import triton
import triton.language as tl
from typing import Tuple, List, Optional
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _fwd_kernel_with_lse(
Q,
K,
V,
Out,
Lse,
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_ob,
stride_oh,
stride_om,
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""
Flash attention forward kernel with LSE output.
Implements standard Flash Attention online softmax algorithm:
- m_i: running max of attention scores
- l_i: running sum of exp(scores - m_i)
- acc_o: running sum of softmax(scores) @ V (unnormalized)
Final output: acc_o / l_i
Final LSE: m_i + log(l_i)
"""
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# Pointers
q_ptrs = (
Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
)
k_ptrs = (
K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
)
v_ptrs = (
V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
)
# Initialize running statistics
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # running max
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # running sum of exp
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # running output (unnormalized)
# Load Q (once per block)
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(
q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
)
# Loop over K, V blocks
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# Load K
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# Compute QK^T * scale
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk *= softmax_scale
# Apply masks
if not EVEN_N:
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
if IS_CAUSAL:
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
# Online softmax: compute block max
m_ij = tl.max(qk, 1) # [BLOCK_M]
# New running max
m_new = tl.maximum(m_i, m_ij) # [BLOCK_M]
# Rescale factor for previous accumulator
alpha = tl.exp(m_i - m_new) # [BLOCK_M]
# Compute P = exp(qk - m_new)
p = tl.exp(qk - m_new[:, None]) # [BLOCK_M, BLOCK_N]
# Sum of current block
l_ij = tl.sum(p, 1) # [BLOCK_M]
# Update running sum: l_new = l_i * alpha + l_ij
l_new = l_i * alpha + l_ij
# Rescale previous output and add new contribution
acc_o = acc_o * alpha[:, None]
# Load V
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn)
else:
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# acc_o += P @ V
p = p.to(v.dtype)
acc_o += tl.dot(p, v)
# Update running statistics
m_i = m_new
l_i = l_new
# Final normalization: output = acc_o / l_i
acc_o = acc_o / l_i[:, None]
# Compute LSE = m_i + log(l_i)
lse_i = m_i + tl.log(l_i)
# Store LSE
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
if EVEN_M:
tl.store(lse_ptrs, lse_i)
else:
tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)
# Store output
out_ptrs = (
Out
+ off_b * stride_ob
+ off_h * stride_oh
+ (offs_m[:, None] * stride_om + offs_d[None, :])
)
if EVEN_M:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o)
else:
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
else:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
else:
tl.store(
out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
)
def flash_attn_with_lse(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Flash attention forward pass that returns both output and LSE.
Uses flash_attn library which natively supports GQA without memory overhead.
Args:
q: Query tensor [batch, seqlen_q, nheads_q, headdim]
k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
softmax_scale: Scaling factor (default: 1/sqrt(headdim))
causal: Whether to apply causal masking
Returns:
out: Output tensor [batch, seqlen_q, nheads_q, headdim]
lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
"""
from flash_attn.flash_attn_interface import flash_attn_func
batch, seqlen_q, nheads_q, headdim = q.shape
_, seqlen_k, nheads_kv, _ = k.shape
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(headdim)
# Use flash_attn_func which natively supports GQA (no memory overhead)
# It returns (output, softmax_lse) when return_attn_probs=True is not set
# We need to use the internal function to get LSE
out, lse, _ = flash_attn_func(
q, k, v,
softmax_scale=softmax_scale,
causal=causal,
return_attn_probs=True, # This makes it return (out, softmax_lse, S_dmask)
)
# lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
# Trim to actual seqlen_q
lse = lse[:, :, :seqlen_q]
return out, lse
@triton.jit
def _merge_lse_kernel(
lse1_ptr, lse2_ptr, lse_out_ptr,
num_elements: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging LSE values.
IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
"""
# Each program handles BLOCK_SIZE elements
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < num_elements
# Load lse values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)
# Compute max for numerical stability (in fp32)
max_lse = tl.maximum(lse1, lse2)
# Compute exp(lse - max_lse) in fp32
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
# Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
lse_merged = max_lse + tl.log(exp1 + exp2)
# Store result (convert back to original dtype)
tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
@triton.jit
def _merge_output_kernel(
o1_ptr, o2_ptr, lse1_ptr, lse2_ptr, o_out_ptr,
batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging attention outputs.
IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
This is critical for numerical accuracy in chunked attention.
"""
# Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
pid_batch = tl.program_id(0)
pid_seq = tl.program_id(1)
pid_head = tl.program_id(2)
# Compute LSE index: [batch, nheads, seqlen_q]
lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq
# Load LSE values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)
# Compute max and scaling factors in fp32
max_lse = tl.maximum(lse1, lse2)
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
sum_exp = exp1 + exp2
# Process headdim in chunks
for d_offset in range(0, headdim, BLOCK_SIZE):
d_idx = d_offset + tl.arange(0, BLOCK_SIZE)
mask = d_idx < headdim
# Compute output index: [batch, seqlen_q, nheads, headdim]
base_idx = (pid_batch * seqlen_q * nheads * headdim +
pid_seq * nheads * headdim +
pid_head * headdim)
o_idx = base_idx + d_idx
# Load o1, o2 and convert to fp32 for weighted sum
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
# Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
# Store result (Triton will convert back to original dtype)
tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
def merge_attention_outputs(
o1: torch.Tensor,
lse1: torch.Tensor,
o2: torch.Tensor,
lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Merge two attention outputs using online softmax (Triton fused kernel).
This implements the online softmax merging formula:
- m_new = max(lse1, lse2)
- o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new))
- lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))
Args:
o1: First output [batch, seqlen_q, nheads, headdim]
lse1: First LSE [batch, nheads, seqlen_q]
o2: Second output [batch, seqlen_q, nheads, headdim]
lse2: Second LSE [batch, nheads, seqlen_q]
Returns:
o_merged: Merged output [batch, seqlen_q, nheads, headdim]
lse_merged: Merged LSE [batch, nheads, seqlen_q]
"""
batch, seqlen_q, nheads, headdim = o1.shape
# Allocate output tensors
o_merged = torch.empty_like(o1)
lse_merged = torch.empty_like(lse1)
# Launch LSE merge kernel
num_lse_elements = batch * nheads * seqlen_q
BLOCK_SIZE_LSE = 256
grid_lse = (triton.cdiv(num_lse_elements, BLOCK_SIZE_LSE),)
_merge_lse_kernel[grid_lse](
lse1, lse2, lse_merged,
num_lse_elements,
BLOCK_SIZE=BLOCK_SIZE_LSE,
)
# Launch output merge kernel
BLOCK_SIZE = 128
grid_output = (batch, seqlen_q, nheads)
_merge_output_kernel[grid_output](
o1, o2, lse1, lse2, o_merged,
batch, seqlen_q, nheads, headdim,
BLOCK_SIZE=BLOCK_SIZE,
)
return o_merged, lse_merged
def chunked_attention_varlen(
q: torch.Tensor,
kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
cu_seqlens_q: torch.Tensor,
cu_seqlens_k_list: List[torch.Tensor],
max_seqlen_q: int,
max_seqlen_k_list: List[int],
softmax_scale: Optional[float] = None,
causal_mask_per_chunk: Optional[List[bool]] = None,
) -> torch.Tensor:
"""
Compute attention with KV split across multiple chunks.
This is the core function for chunked prefill. It computes attention
against each KV chunk and merges results using online softmax.
For causal attention with chunked KV:
- First chunk (current tokens): Apply causal mask
- Previous chunks: No causal mask (all previous tokens are valid context)
Args:
q: Query tensor [total_q_tokens, nheads, headdim]
kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk
max_seqlen_q: Maximum query sequence length
max_seqlen_k_list: List of maximum key sequence lengths for each chunk
softmax_scale: Scaling factor
causal_mask_per_chunk: Whether to apply causal mask for each chunk
Returns:
out: Output tensor [total_q_tokens, nheads, headdim]
"""
if len(kv_chunks) == 0:
raise ValueError("Need at least one KV chunk")
nheads = q.shape[1]
headdim = q.shape[2]
batch = cu_seqlens_q.shape[0] - 1
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(headdim)
if causal_mask_per_chunk is None:
# Default: causal for last chunk only
causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]
# Initialize accumulated output and LSE
accumulated_o = None
accumulated_lse = None
for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
is_causal = causal_mask_per_chunk[chunk_idx]
# Reshape Q for batch processing
# For varlen, we need to handle each sequence separately
# For simplicity, assume single sequence (batch=1) for now
q_batched = q.unsqueeze(0) # [1, total_q, nheads, headdim]
# Compute attention for this chunk
chunk_o, chunk_lse = flash_attn_with_lse(
q_batched,
k_chunk,
v_chunk,
softmax_scale=softmax_scale,
causal=is_causal,
)
# Merge with accumulated
if accumulated_o is None:
accumulated_o = chunk_o
accumulated_lse = chunk_lse
else:
accumulated_o, accumulated_lse = merge_attention_outputs(
accumulated_o, accumulated_lse,
chunk_o, chunk_lse,
)
# Remove batch dimension
return accumulated_o.squeeze(0)
class ChunkedPrefillState:
"""
State for tracking chunked prefill progress.
This class maintains the accumulated attention output and LSE
across multiple prefill chunks.
"""
def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device):
self.num_layers = num_layers
self.dtype = dtype
self.device = device
# Per-layer accumulated outputs
# Each entry: (accumulated_output, accumulated_lse) or None
self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
None for _ in range(num_layers)
]
# Track which chunks have been processed
self.processed_chunks: int = 0
def update_layer(
self,
layer_id: int,
chunk_output: torch.Tensor,
chunk_lse: torch.Tensor,
):
"""Update accumulated state for a layer with a new chunk's output."""
if self.layer_states[layer_id] is None:
self.layer_states[layer_id] = (chunk_output, chunk_lse)
else:
acc_o, acc_lse = self.layer_states[layer_id]
merged_o, merged_lse = merge_attention_outputs(
acc_o, acc_lse,
chunk_output, chunk_lse,
)
self.layer_states[layer_id] = (merged_o, merged_lse)
def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]:
"""Get the final accumulated output for a layer."""
if self.layer_states[layer_id] is None:
return None
return self.layer_states[layer_id][0]
def clear(self):
"""Clear all accumulated state."""
self.layer_states = [None for _ in range(self.num_layers)]
self.processed_chunks = 0
# Test function
def _test_chunked_attention():
"""Test chunked attention using flash_attn_with_lse and merge_attention_outputs."""
from flash_attn.flash_attn_interface import flash_attn_func
torch.manual_seed(42)
print("=" * 70)
print("Test: Chunked attention vs flash_attn_func (non-causal)")
print("=" * 70)
print("Splitting K,V into chunks, computing attention per chunk, then merging")
print()
for dtype in [torch.float16, torch.bfloat16]:
for num_chunks in [64, 128, 256]:
for batch, seqlen, nheads, headdim in [
(1, 1024, 32, 128),
(1, 2048, 32, 128),
(1, 4096, 32, 128),
(1, 8192, 32, 128),
]:
# Generate random Q, K, V
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
# Reference: full attention (non-causal)
out_ref = flash_attn_func(q, k, v, causal=False)
# Chunked attention: split K, V into chunks
chunk_size = seqlen // num_chunks
accumulated_o = None
accumulated_lse = None
for i in range(num_chunks):
start = i * chunk_size
end = (i + 1) * chunk_size
k_chunk = k[:, start:end, :, :]
v_chunk = v[:, start:end, :, :]
# Q attends to this K,V chunk (non-causal)
chunk_o, chunk_lse = flash_attn_with_lse(
q, k_chunk, v_chunk, causal=False
)
if accumulated_o is None:
accumulated_o = chunk_o
accumulated_lse = chunk_lse
else:
# Merge with previous chunks
accumulated_o, accumulated_lse = merge_attention_outputs(
accumulated_o, accumulated_lse,
chunk_o, chunk_lse
)
# Compare
out_diff = (out_ref - accumulated_o).abs()
out_max_diff = out_diff.max().item()
out_mean_diff = out_diff.mean().item()
status = "PASS" if out_max_diff < 1e-2 else "FAIL"
print(
f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
)
print()
print("=" * 70)
print("Test completed!")
if __name__ == "__main__":
_test_chunked_attention()

1167
nanovllm/ops/xattn.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -14,9 +14,9 @@ class Context:
context_lens: torch.Tensor | None = None context_lens: torch.Tensor | None = None
block_tables: torch.Tensor | None = None block_tables: torch.Tensor | None = None
# Sparse prefill attention support (GPU-only path) # Attention policy support (GPU-only path)
# When set, uses policy.sparse_prefill_attention() instead of FlashAttention # When set, uses policy.compute_prefill() instead of FlashAttention
sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True attention_policy: Any = None # AttentionPolicy instance
_CONTEXT = Context() _CONTEXT = Context()
@@ -35,7 +35,7 @@ def set_context(
slot_mapping=None, slot_mapping=None,
context_lens=None, context_lens=None,
block_tables=None, block_tables=None,
sparse_prefill_policy=None, attention_policy=None,
): ):
global _CONTEXT global _CONTEXT
_CONTEXT = Context( _CONTEXT = Context(
@@ -47,7 +47,7 @@ def set_context(
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
context_lens=context_lens, context_lens=context_lens,
block_tables=block_tables, block_tables=block_tables,
sparse_prefill_policy=sparse_prefill_policy, attention_policy=attention_policy,
) )

378
notes.md
View File

@@ -1,324 +1,130 @@
# Notes: Sparsity Integration into Layerwise Offload # Notes: SparsePolicy Refactoring Research
## Current Architecture Analysis ## Sources
### GPU-Only Path vs Offload Path ### Source 1: tzj/minference branch - policy.py
- 路径: `nanovllm/kvcache/sparse/policy.py`
- 关键设计:
- `PolicyContext` 数据类包含 query_chunk_idx, num_query_chunks, layer_id, query, is_prefill 等
- `select_blocks()` 需要 offload_engine 参数
- `compute_chunked_prefill()``compute_chunked_decode()` 是完整的 attention 流程
- `on_prefill_offload()` / `on_decode_offload()` hooks 用于收集元数据
| Aspect | GPU-Only | Layerwise Offload | ### Source 2: tzj/minference branch - full_policy.py
|--------|----------|-------------------| - 路径: `nanovllm/kvcache/sparse/full_policy.py`
| KV Storage | GPU blocks (paged) | CPU pinned + GPU ring buffer | - 关键实现:
| Prefill | All layers → then attention | Per-layer: attention → offload | - `compute_chunked_prefill()` 内部使用 ring buffer pipeline 加载 blocks
| Decode | FlashAttn with block table | Ring buffer H2D → FlashAttn | - 使用 `flash_attn_with_lse``merge_attention_outputs` 合并多个 chunk 的 attention
| Sparse Support | MInference via `attention.py` | Not integrated | - `compute_chunked_decode()` 处理 prefilled blocks + decode buffer
### MInference Flow (GPU-Only) ### Source 3: tzj/layer-offload branch - model_runner.py
- 路径: `nanovllm/engine/model_runner.py`
- 关键设计:
- `run_layerwise_offload_prefill()` 逐层处理,每层计算完整 attention
- `sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)` 简单接口
- FULL policy 通过 `if sparse_prefill_policy is None` 走 else 分支
``` ### Source 4: tzj/layer-offload branch - xattn.py
attention.py:101-105: - 路径: `nanovllm/kvcache/sparse/xattn.py`
if context.sparse_prefill_policy is not None: - 关键实现:
o = context.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id) - `sparse_prefill_attention()` 直接使用 FlashAttention因为 chunked prefill 架构限制)
- 保留 Triton kernels 供未来 GPU-only 模式
minference.py:sparse_prefill_attention(): ## Synthesized Findings
1. estimate_pattern(q, k, layer_id) -> vertical_indices, slash_indices
2. _triton_mixed_sparse_attention(q, k, v, indices)
3. return output
```
### Quest Flow (GPU Block Mode) ### 架构差异总结
``` | 方面 | Chunked Offload | Layerwise Offload |
hybrid_manager.py (if using CPU offload with Quest): |------|-----------------|-------------------|
select_blocks(available_blocks, ctx) -> selected block IDs | **Prefill 流程** | chunk-by-chunk跨层 | layer-by-layer完整序列 |
-> load selected blocks to GPU | **KV 存储** | 每 chunk 立即 offload | 每层计算后 offload |
-> standard FlashAttn with loaded blocks | **Attention 计算** | 分多次计算+合并 | 一次完整计算 |
``` | **Block 加载** | 需要从 CPU 加载历史 | 不需要,已在 GPU |
| **Policy 责任** | 完整 attention 流程 | 仅 attention kernel 选择 |
### Layerwise Offload Prefill Flow ### Layerwise Offload 的简化点
``` 1. **不需要 block selection**: 整层 KV 都在 GPU无需选择
model_runner.py:run_layerwise_offload_prefill(): 2. **不需要 offload_engine 参数**: Policy 不负责加载 KV
for layer_id in range(num_layers): 3. **不需要 merge_attention_outputs**: 一次计算完整 attention
# QKV projection 4. **不需要 offload hooks**: offload 在 model_runner 统一处理
q, k, v = qkv_proj(hidden_ln)
# RoPE ### 设计建议
q, k = rotary_emb(positions, q, k)
# FULL attention (no sparsity!) 1. **保持接口简单**: 只需要 `compute_prefill_attention()``compute_decode_attention()`
attn_output = flash_attn_varlen_func(q, k, v, ...) 2. **FULL 也实现方法**: 不再通过 `is None` 判断,所有 policy 统一调用
3. **移除不必要的参数**: 不需要 offload_engine, kvcache_manager, seq 等
4. **统一命名**: 使用 `compute_*_attention` 而不是 `sparse_prefill_attention`
# MLP ## Code Examples
hidden_states = mlp(attn_out + residual)
# Sync offload ALL k, v to CPU ### 当前调用方式 (model_runner.py:876-891)
for block_id in cpu_block_ids:
k_cache_cpu[layer_id, block_id].copy_(k[start:end])
v_cache_cpu[layer_id, block_id].copy_(v[start:end])
```
### Layerwise Offload Decode Flow
```
model_runner.py:run_layerwise_offload_decode():
# Preload first N layers to ring buffer
for i in range(num_buffers):
offload_engine.load_layer_kv_to_buffer(i, i, cpu_block_table, valid_tokens)
for layer_id in range(num_layers):
current_buffer = layer_id % num_buffers
# Wait for buffer load
offload_engine.wait_buffer_load(current_buffer)
# Get prefilled KV from ring buffer (ALL blocks loaded)
k_prefill, v_prefill = offload_engine.get_buffer_kv(current_buffer, total_prefill_tokens)
# QKV for new token
q, k_new, v_new = qkv_proj(hidden_ln)
# Concat and full attention
k_full = torch.cat([k_prefill, k_decode_prev, k_new])
attn_output = flash_attn_varlen_func(q, k_full, v_full, ...)
# Start loading next layer
offload_engine.load_layer_kv_to_buffer(current_buffer, layer_id + num_buffers, ...)
```
## Integration Points
### 1. Prefill Sparse Integration Point
**Location:** `model_runner.py:535-543`
**Current:**
```python ```python
attn_output = flash_attn_varlen_func( # Sparse or Full attention
q, k, v, if self.sparse_prefill_policy is not None:
cu_seqlens_q=cu_seqlens, # MInference or other sparse prefill policy
cu_seqlens_k=cu_seqlens, attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
```
**After Integration:**
```python
if self.sparse_policy and self.sparse_policy.supports_offload_prefill:
attn_output, k_sparse, v_sparse = self.sparse_policy.offload_prefill_attention(
q, k, v, layer_id q, k, v, layer_id
) )
k_to_offload = k_sparse if k_sparse is not None else k
v_to_offload = v_sparse if v_sparse is not None else v
else: else:
attn_output = flash_attn_varlen_func(q, k, v, ...) # Full attention using FlashAttention
k_to_offload, v_to_offload = k, v attn_output = flash_attn_varlen_func(
``` q, k, v, ...
### 2. Decode Sparse Integration Point
**Location:** `model_runner.py:636-637` and `model_runner.py:704-706`
**Current (preload):**
```python
for i in range(num_preload):
offload_engine.load_layer_kv_to_buffer(
i, i, cpu_block_table, valid_tokens_per_block
) )
``` ```
**After Integration:** ### 建议的新调用方式
```python ```python
for i in range(num_preload): # 所有 policy 统一调用
layer_to_load = i attn_output = self.attention_policy.compute_prefill_attention(
if self.sparse_policy and self.sparse_policy.supports_offload_decode: q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
# Prepare q for this layer (need to compute ahead)
# OR: use previous layer's pattern as estimate
selected_blocks = self.sparse_policy.select_offload_blocks(
None, # q not available yet at preload
layer_to_load,
cpu_block_table,
valid_tokens_per_block
)
else:
selected_blocks = cpu_block_table
offload_engine.load_sparse_layer_kv_to_buffer(
i, layer_to_load, selected_blocks, valid_tokens_per_block
) )
``` ```
**Challenge:** Q is not available during preload phase! ## Questions Resolved
**Solutions:** - Q: 是否需要 PolicyContext?
1. Skip sparse preload, only sparse for non-preloaded layers - A: 可以简化,因为 layerwise 模式下不需要 chunk 信息
2. Use previous decode step's pattern as estimate
3. Add preload hook to sparse policy
### 3. Offload Engine Extension - Q: decode 阶段如何处理?
- A: **Decode 不需要 policy**!当前 `run_layerwise_offload_decode()` 使用标准 `layer(positions, hidden_states, residual)` 调用,走 Attention.forward() 路径
**New Method in OffloadEngine:** - Q: 为什么 decode 不需要 sparse?
- A: 因为 decode 每次只有 1 个 token没有稀疏化的意义。KV 从 ring buffer 加载后直接用 flash_attn_with_kvcache
```python ## Key Insight
def load_sparse_layer_kv_to_buffer(
self,
buffer_idx: int,
layer_id: int,
selected_cpu_block_ids: List[int],
original_valid_tokens: List[int],
) -> int:
"""
Load only selected blocks from CPU to buffer.
Returns: **Layerwise Offload 的 Policy 设计应该只关注 Prefill**
Total tokens loaded (may be less than full sequence)
"""
stream = self.layer_load_streams[buffer_idx]
with torch.cuda.stream(stream): ```
stream.wait_event(self.buffer_compute_done_events[buffer_idx]) Prefill: 需要 Policy
- 整个序列一次计算 attention
- 可以使用 sparse attention 方法(如 MInference 的 vertical+slash pattern
- Policy 接收 q, k, v, layer_id, softmax_scale
# Build mapping: original block -> selected position Decode: 不需要 Policy
offset = 0 - 每次只有 1 个 token query
for i, cpu_block_id in enumerate(selected_cpu_block_ids): - KV 从 ring buffer 加载
# Find original index to get valid tokens - 使用标准 flash_attn_with_kvcache
valid_tokens = original_valid_tokens[i] # Need mapping
self.layer_k_cache[buffer_idx, offset:offset+valid_tokens].copy_(
self.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens],
non_blocking=True
)
# ... v_cache same
offset += valid_tokens
self.buffer_load_events[buffer_idx].record(stream)
return offset # Caller needs to know actual loaded tokens
``` ```
## Metadata Flow for Quest ## Interface Comparison Summary
### During Prefill Offload | 方面 | tzj/minference | tzj/layer-offload (新设计) |
|------|----------------|---------------------------|
| 类名 | SparsePolicy | AttentionPolicy |
| Prefill 方法 | compute_chunked_prefill() | compute_attention() |
| Decode 方法 | compute_chunked_decode() | 不需要(用标准路径) |
| 需要 offload_engine | 是 | 否 |
| 需要 kvcache_manager | 是 | 否 |
| 需要 seq | 是 | 否 |
| 支持 FULL | 是 | 是 |
**Current:** No metadata collection in offload path ## Migration Path
**Required:** Call `on_prefill_offload()` for each block 1. 保留 `SparsePolicy` 作为 `AttentionPolicy` 的别名
2. 保留 `PolicyContext` 供未来扩展
```python 3. 保留 `select_blocks()` 方法签名(虽然不使用)
# In run_layerwise_offload_prefill() 4. 移除 `requires_block_selection` 属性(不需要)
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * block_size
end = min(start + block_size, total_tokens)
actual_size = end - start
# BEFORE offload: update Quest metadata
if self.sparse_policy and hasattr(self.sparse_policy, 'on_prefill_offload'):
self.sparse_policy.on_prefill_offload(
cpu_block_id, layer_id, k[start:end], actual_size
)
# Offload
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
```
### Quest Metadata Shape
```python
# BlockMetadataManager
key_min: [num_blocks, num_layers, num_kv_heads, head_dim] # Min key per block per layer
key_max: [num_blocks, num_layers, num_kv_heads, head_dim] # Max key per block per layer
```
**Memory:** 2 * num_blocks * num_layers * kv_heads * head_dim * 2 bytes
- Example: 1000 blocks * 28 layers * 4 heads * 128 dim * 2 * 2 = ~57 MB
## Performance Considerations
### MInference Prefill Overhead
| Operation | Time (64K seq) |
|-----------|----------------|
| Pattern estimation (last-64) | ~5ms |
| Triton sparse attention | ~80ms |
| Full FlashAttention | ~100ms |
| **Net Speedup** | ~15-20% |
### Quest Decode Overhead
| Operation | Time |
|-----------|------|
| Block scoring (GPU metadata) | ~0.1ms |
| Top-K selection | ~0.05ms |
| Sparse H2D load (8 blocks) | ~2ms |
| Full H2D load (100 blocks) | ~20ms |
| **Net Speedup** | ~10x H2D |
### Memory Trade-offs
| Mode | GPU Memory | CPU Memory | H2D Bandwidth |
|------|------------|------------|---------------|
| Full offload | Ring buffer | Full KV | High |
| Sparse offload | Ring buffer | Full KV | Low (subset) |
| Aggressive sparse | Ring buffer | Sparse KV | Very low |
## Edge Cases
### 1. Short Sequences (< sparse threshold)
```python
if total_tokens < sparse_threshold:
# Fall back to full attention
use_sparse = False
```
### 2. First Decode Step (no previous Q)
Quest can't score blocks without Q. Options:
- Use average embedding as proxy
- Load all blocks for first step
- Use prefill pattern as estimate
### 3. Variable Sequence Lengths in Batch
Layerwise offload currently only supports batch_size=1:
```python
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
```
Sparse integration should maintain this constraint.
### 4. Ring Buffer vs Sparse Load Mismatch
Ring buffer assumes fixed `total_prefill_tokens`:
```python
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, total_prefill_tokens)
```
Sparse load has variable token count. Need:
```python
# Track actual loaded tokens per buffer
loaded_tokens[buffer_idx] = sparse_load_count
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, loaded_tokens[buffer_idx])
```
## Testing Strategy
### Unit Tests
1. `test_sparse_policy_interface.py` - Verify new interface methods
2. `test_minference_offload.py` - MInference in offload mode
3. `test_quest_offload.py` - Quest block selection in offload mode
### Integration Tests
1. `test_offload_sparse_e2e.py` - Full prefill+decode with sparsity
2. `test_accuracy_comparison.py` - Compare outputs: full vs sparse
### Benchmarks
1. `bench_offload_sparse.py` - Compare:
- Full offload (baseline)
- MInference prefill + Quest decode
- Aggressive sparse offload

View File

@@ -1,155 +0,0 @@
# Progress Log: nanovllm 多请求状态污染问题
## Session: 2026-01-12
### 资源分配
| 资源 | 分配 |
|------|------|
| **GPU** | **1** (严格限制,不可更改) |
### 任务目标
研究 nanovllm CPU offload 模式下多请求之间状态影响导致准确率下降的问题。
---
### 10:00 - 启动分析
**完成**:
- [x] 读取 `docs/offload_accuracy_issue.md` 了解问题背景
- [x] 激活 Serena MCP 项目
- [x] 获取关键组件符号概览
**关键文件已分析**:
- `nanovllm/kvcache/offload_engine.py` - OffloadEngine 类
- `nanovllm/kvcache/hybrid_manager.py` - HybridKVCacheManager 类
- `nanovllm/engine/model_runner.py` - ModelRunner 类
- `nanovllm/engine/llm_engine.py` - LLMEngine 类
- `nanovllm/engine/scheduler.py` - Scheduler 类
---
### 10:15 - 深入代码分析
**分析的方法**:
| 方法 | 文件 | 发现 |
|------|------|------|
| `OffloadEngine.__init__` | offload_engine.py:40-145 | 初始化所有 buffer无 reset 方法 |
| `deallocate` | hybrid_manager.py:218-244 | 只清理逻辑块,不清理 OffloadEngine |
| `clear_decode_tracking` | hybrid_manager.py:538-549 | 清理 tracking 字典,但未被调用 |
| `run_layerwise_offload_decode` | model_runner.py:867-1057 | 包含 decode buffer 读写逻辑 |
| `generate` | llm_engine.py:114-151 | 请求循环逻辑 |
| `postprocess` | scheduler.py:93-99 | 调用 deallocate |
**关键发现 #1**: OffloadEngine 没有 reset() 方法
**关键发现 #2**: deallocate() 没有调用 clear_decode_tracking()
**关键发现 #3**: decode_buffer 在请求间不清理,可能导致状态污染
---
### 10:30 - 根因定位
**确认的问题**:
1. **decode buffer 残留**
- 位置: `offload_engine.decode_k_buffer`, `decode_v_buffer`
- 写入: `model_runner.py:1010-1013`
- 读取: `model_runner.py:969-976`
- 问题: 旧请求的 KV 数据可能被新请求读取
2. **tracking 字典未清理**
- 位置: `hybrid_manager._decode_start_pos`, `_prefill_len`
- 问题: 使用 `id(seq)` 作为 key可能重用
3. **缺失的清理调用**
- `clear_decode_tracking()``deallocate()` 中未被调用
---
### 10:45 - 创建规划文件
**创建的文件**:
- [x] `task_plan.md` - 完整的任务规划和阶段
- [x] `findings.md` - 详细的代码分析发现
- [x] `progress.md` - 本文件
---
### 11:00 - Sequential Thinking 深入分析
**使用 sequential thinking 验证分析结果**:
- 确认 deallocate() 确实没有调用 clear_decode_tracking()
- 分析 _decode_start_pos 和 _prefill_len 字典的生命周期
- 确定 id(seq) 重用是问题的触发条件
---
### 11:15 - 完成规划文件
**更新的文件**:
- [x] `task_plan.md` - 添加完整的 debug 方案和实施计划
- [x] `findings.md` - 详细的代码分析和修复方向
- [x] `progress.md` - 更新到当前进度
---
## 下一步 (待用户确认)
**执行顺序**:
1. **实施修复** - 修改 `deallocate()` 添加 `clear_decode_tracking(seq)`
2. **快速验证** - 20 样本连续执行(一次调用,不重启框架)→ 目标 20/20
3. **完整验证** - 100 样本 → 目标 100/100 (最终验收)
4. **防御性修复** (可选) - 添加 `OffloadEngine.on_sequence_finished()`
**核心修改** (一行代码):
```python
# hybrid_manager.py:deallocate() 末尾添加
self.clear_decode_tracking(seq)
```
**验收标准**:
| 测试 | 样本数 | 通过要求 |
|------|--------|----------|
| 快速验证 | 20 | 20/20 (100%) |
| 完整验证 | 100 | 100/100 (100%) |
---
## 错误记录
| 时间 | 错误 | 解决方案 |
|------|------|----------|
| 10:05 | Serena MCP 未激活 | 调用 activate_project |
---
## 文件修改记录
| 文件 | 操作 | 状态 |
|------|------|------|
| task_plan.md | 创建+更新 | 完成 |
| findings.md | 创建 | 完成 |
| progress.md | 创建+更新 | 完成 |
---
## 分析结论
**重要澄清**: nanovllm offload 模式**不支持 batch**,只能单个 request 顺序执行。问题出在**请求切换**时状态清理不完整。
**根本原因已确认**: `deallocate()` 没有调用 `clear_decode_tracking()`,导致 `_decode_start_pos``_prefill_len` 字典残留,当 Python 对象 ID 重用时,新请求会错误地使用旧请求的配置。
**修复方案已设计**: 在 `deallocate()` 末尾添加 `self.clear_decode_tracking(seq)` 调用。
---
## 关键理解
问题不是 "batch 处理",而是:
```
Request A 完成 → deallocate(A) [状态未完全清理] → Request B 开始 → B 读到 A 的残留状态
```

View File

@@ -1,359 +1,549 @@
# Task Plan: nanovllm CPU Offload 多请求状态污染问题 # Task Plan: Refactor SparsePolicy for Layerwise Offload
## 问题概述 ## Goal
重构 SparsePolicy 接口,参考 tzj/minference 分支的设计模式,使所有 attention 都可以抽象成 policy并按统一规范编写。适配当前 layerwise offload 架构特点(整层 KV 在 GPU 上)。
**重要说明**: nanovllm offload 模式目前**不支持 batch**,只能单个 request 顺序执行。问题出在**请求切换**时的状态清理。 ## Background
| 模式 | 测试方式 | 准确率 | ### 两种 Offload 架构对比
|------|----------|--------|
| CPU Offload | 独立进程 (每请求一个进程) | **100%** |
| CPU Offload | 同进程顺序多请求 | 66% |
| Non-Offload | 同进程顺序多请求 | 100% |
**结论**: 单请求推理正确,问题在于**请求切换**时状态清理不完整。 | 特性 | tzj/minference (Chunked Offload) | tzj/layer-offload (Layerwise Offload) |
|------|----------------------------------|---------------------------------------|
| 处理粒度 | 每次一个 chunk (block_size tokens) | 每次一整层 (所有 tokens) |
| KV 位置 | 历史 chunks 在 CPU需要加载 | 整层 KV 都在 GPU |
| Policy 入口 | `compute_chunked_prefill()/decode()` | `compute_prefill()/decode()` |
| 需要 offload_engine | 是(加载 blocks | 否KV 已在 GPU |
| Mask 计算 | `select_blocks()` 返回 block IDs | `estimate()` 返回 sparse mask |
### tzj/minference 的 Policy 接口
```python
class SparsePolicy(ABC):
supports_prefill: bool
supports_decode: bool
@abstractmethod
def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]
@abstractmethod
def compute_chunked_prefill(self, q, k, v, layer_id, ..., offload_engine, ...) -> Tensor
@abstractmethod
def compute_chunked_decode(self, q, layer_id, ..., offload_engine, ...) -> Tensor
```
### 当前 branch 的 Policy 接口(重构前)
```python
class SparsePolicy(ABC):
supports_prefill: bool
supports_decode: bool
@abstractmethod
def select_blocks(self, available_blocks, ctx) -> List[int]
def sparse_prefill_attention(self, q, k, v, layer_id) -> Tensor
```
## Phases
- [x] Phase 1: 分析差异并设计新接口
- [x] **Phase 0: 创建 nanovllm.ops 模块** ✅ 测试通过
- [ ] Phase 2: 重构 AttentionPolicy 基类
- [ ] Phase 3: 重构 FullAttentionPolicy
- [ ] Phase 4: 重构 XAttentionPolicy (含 estimate 方法)
- [ ] Phase 5: 更新 model_runner 调用方式
- [ ] Phase 6: 测试验证
--- ---
## Phase 1: 代码分析 (complete) ## Phase 0: 创建 nanovllm.ops 模块
### 1.1 识别状态管理组件 ### 目标
从 tzj/minference 分支提取 ops 模块,为 XAttention estimate 提供底层算子支持。
**已分析的关键组件**: ### 步骤
| 组件 | 文件 | 状态数据 |
|------|------|----------|
| `OffloadEngine` | `nanovllm/kvcache/offload_engine.py` | ring buffer, decode buffer, CUDA events |
| `HybridKVCacheManager` | `nanovllm/kvcache/hybrid_manager.py` | logical blocks, prefilled_blocks, _decode_start_pos, _prefill_len |
| `LLMEngine` | `nanovllm/engine/llm_engine.py` | generate() 循环,请求生命周期 |
| `Scheduler` | `nanovllm/engine/scheduler.py` | postprocess() 调用 deallocate() |
### 1.2 请求生命周期分析
1. **创建目录结构**
``` ```
generate() nanovllm/ops/
→ 多个请求添加到 scheduler ├── __init__.py
→ while not finished: ├── xattn.py # xattn_estimate, xattn_estimate_chunked, Triton kernels
→ schedule() 获取下一批 seqs └── chunked_attention.py # flash_attn_with_lse, merge_attention_outputs (备用)
→ model_runner.run() 执行推理
→ postprocess() 处理完成的请求
→ 如果完成: kvcache_manager.deallocate(seq)
``` ```
--- 2. **从 tzj/minference 提取文件**
## Phase 2: 根本原因分析 (complete)
### 2.1 核心问题: OffloadEngine 缺少 reset() 方法
**关键发现**: `OffloadEngine` 没有任何重置/清理方法!
当请求完成时,`HybridKVCacheManager.deallocate()` 被调用,但它只清理:
- 逻辑块状态 (`block.reset()`)
- 物理块引用 (`free_cpu_blocks`, `cpu_block_to_logical`)
- prefilled_blocks 集合
- _decode_start_pos / _prefill_len 字典
**未被清理的状态** (存在于 OffloadEngine):
| 状态 | Shape | 问题 |
|------|-------|------|
| `layer_k_cache` | [num_buffers, max_seq_len, kv_heads, head_dim] | 包含旧请求的 KV |
| `layer_v_cache` | [num_buffers, max_seq_len, kv_heads, head_dim] | 包含旧请求的 KV |
| `decode_k_buffer` | [num_layers, block_size, kv_heads, head_dim] | 包含旧请求的 decode KV |
| `decode_v_buffer` | [num_layers, block_size, kv_heads, head_dim] | 包含旧请求的 decode KV |
### 2.2 具体污染场景
`run_layerwise_offload_decode()` (model_runner.py:867-1057):
```python
# 第 969-976 行: 读取之前的 decode KV
if num_prev_decode_tokens > 0:
k_decode_prev, v_decode_prev = offload_engine.get_decode_kv(
layer_id, decode_start_pos, pos_in_block
)
ring_k[...].copy_(k_decode_prev) # 可能读取旧请求的数据!
```
**场景**:
1. 请求 A (32K tokens) 完成decode_buffer 保留其 KV 数据
2. 请求 B 开始,其 `decode_start_pos` 可能非零(如果继承了旧状态)
3. 请求 B 在第一个 decode step 时错误地读取了请求 A 的 decode buffer 数据
### 2.3 潜在问题点
1. **decode_start_pos 计算错误**:
- `get_decode_start_pos()` 使用 `id(seq)` 作为 key
- Python 对象 ID 可能在请求之间重用
- 如果新 seq 对象的 ID 与旧 seq 相同,可能错误继承旧的 start_pos
2. **decode buffer 残留数据**:
- 如果 `pos_in_block` 在新请求中与旧请求重叠
- `get_decode_kv()` 会返回旧请求的数据
3. **ring buffer 残留数据**:
- 虽然每次 decode 会从 CPU 加载,但 decode buffer 的数据会被复制过来
- 如果 decode buffer 有残留,会污染 ring buffer
---
## Phase 3: Debug 方案设计 (complete)
### 3.1 确认的根本原因
通过代码分析,确认了两个根本原因:
**根本原因 1 (主要)**: `deallocate()` 不调用 `clear_decode_tracking()`
- 位置: `hybrid_manager.py:218-244`
- 影响: `_decode_start_pos``_prefill_len` 字典残留
- 后果: 如果 `id(seq)` 重用,返回错误的 decode 配置
**根本原因 2 (次要)**: decode_buffer 不清理
- 位置: `offload_engine.py`
- 影响: `decode_k_buffer/v_buffer` 保留旧 KV
- 后果: 可能被根本原因 1 触发读取
### 3.2 Debug 方案 A: 验证字典残留 (推荐先做)
**目标**: 验证 `_decode_start_pos` 字典是否有残留
**诊断代码** (添加到 `hybrid_manager.py`):
```python
# 在 get_decode_start_pos() 开头添加
def get_decode_start_pos(self, seq: Sequence) -> int:
seq_id = id(seq)
# DEBUG: 检查是否命中旧值
if seq_id in self._decode_start_pos:
logger.warning(f"[DEBUG] get_decode_start_pos: CACHE HIT! seq_id={seq_id}, "
f"cached_value={self._decode_start_pos[seq_id]}, "
f"expected={(len(seq) - 1) % self._block_size}")
# ... 原有逻辑
```
**诊断代码** (添加到 `deallocate()` 末尾):
```python
def deallocate(self, seq: Sequence) -> None:
# ... 现有逻辑 ...
# DEBUG: 打印未清理的状态
seq_id = id(seq)
if seq_id in self._decode_start_pos:
logger.warning(f"[DEBUG] deallocate: _decode_start_pos NOT CLEARED! "
f"seq_id={seq_id}, value={self._decode_start_pos[seq_id]}")
```
### 3.3 Debug 方案 B: 最小复现测试
**文件**: `tests/test_multi_request_offload_debug.py`
```python
"""最小复现批量模式失败"""
import os
import sys
sys.path.insert(0, os.getcwd())
from nanovllm import LLM
from nanovllm.sampling import SamplingParams
# 使用 RULER NIAH 的两个样本
PROMPTS = [
# Sample 0 (通常成功)
"...", # 从 niah_single_1_32k.jsonl 加载
# Sample 1 (通常失败)
"...",
]
EXPECTED = ["8930103", "4194548"]
def main():
llm = LLM(
"~/models/Llama-3.1-8B-Instruct",
max_model_len=33792,
max_num_batched_tokens=33792,
enable_cpu_offload=True,
num_gpu_blocks=4,
kvcache_block_size=1024,
enforce_eager=True,
)
params = SamplingParams(temperature=0.1, max_tokens=50)
# 连续处理两个请求
for i, (prompt, expected) in enumerate(zip(PROMPTS, EXPECTED)):
print(f"\n{'='*60}")
print(f"Sample {i}: Expected = {expected}")
# 打印关键状态
kvm = llm.model_runner.kvcache_manager
print(f" _decode_start_pos 字典大小: {len(kvm._decode_start_pos)}")
print(f" _prefill_len 字典大小: {len(kvm._prefill_len)}")
outputs = llm.generate([prompt], params, use_tqdm=False)
output_text = outputs[0]["text"]
passed = expected in output_text
print(f" Output: {output_text[:100]}...")
print(f" Status: {'PASS' if passed else 'FAIL'}")
if __name__ == "__main__":
main()
```
### 3.4 Debug 方案 C: 快速修复验证
**目标**: 验证修复 `deallocate()` 是否解决问题
**修改** (`hybrid_manager.py:218-244`):
```python
def deallocate(self, seq: Sequence) -> None:
"""Release all blocks for a sequence."""
for logical_id in reversed(seq.block_table):
# ... 现有逻辑 ...
seq.num_cached_tokens = 0
seq.block_table.clear()
# === 新增: 清理 decode tracking ===
self.clear_decode_tracking(seq)
```
**验证命令**:
```bash ```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler_niah.py \ git show tzj/minference:nanovllm/ops/__init__.py > nanovllm/ops/__init__.py
--model ~/models/Llama-3.1-8B-Instruct \ git show tzj/minference:nanovllm/ops/xattn.py > nanovllm/ops/xattn.py
--enable-offload \ git show tzj/minference:nanovllm/ops/chunked_attention.py > nanovllm/ops/chunked_attention.py
--sample-indices 0,1,2,3,4 \
--verbose
``` ```
### 3.5 Debug 方案 D: 添加 OffloadEngine 清理 (防御性) 3. **Cherry-pick 测试文件**
```bash
**目标**: 进一步隔离请求状态 git show tzj/minference:tests/test_xattn_estimate_chunked.py > tests/test_xattn_estimate_chunked.py
**添加方法** (`offload_engine.py`):
```python
def on_sequence_finished(self):
"""清理请求完成后的状态"""
# 清零 decode buffer (防止残留数据被读取)
self.decode_k_buffer.zero_()
self.decode_v_buffer.zero_()
logger.debug("OffloadEngine: decode buffer cleared")
``` ```
**调用点** (`hybrid_manager.py:deallocate` 末尾): 4. **运行测试验证**
```python ```bash
# 清理 OffloadEngine 状态 CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/Worktree/nano-vllm:$PYTHONPATH \
if self.offload_engine is not None: python tests/test_xattn_estimate_chunked.py
self.offload_engine.on_sequence_finished()
``` ```
--- ### nanovllm/ops 模块内容
## Phase 4: 实施计划 (pending) | 文件 | 核心函数 | 用途 |
### 推荐执行顺序
1. **Step 4.1**: 实施修复
- 修改 `hybrid_manager.py:deallocate()` 添加 `clear_decode_tracking(seq)`
2. **Step 4.2**: 快速验证 (20 样本连续执行)
- **一次调用** `test_ruler_niah.py`,连续执行 20 个样本
- **不重启框架**,验证请求切换是否正确
- 目标: 20/20 全部通过
3. **Step 4.3**: 完整验证 (100 样本)
- 运行 100 个样本的 RULER NIAH 测试
- 目标: 100/100 全部通过 (准确率从 66% → 100%)
4. **Step 4.4**: 防御性修复 (可选)
- 添加 `OffloadEngine.on_sequence_finished()` 方法
- 清零 decode buffer 作为额外保险
### 具体修改
**文件 1**: `nanovllm/kvcache/hybrid_manager.py`
位置: `deallocate()` 方法末尾 (第 244 行后)
```python
def deallocate(self, seq: Sequence) -> None:
"""Release all blocks for a sequence."""
for logical_id in reversed(seq.block_table):
# ... 现有逻辑 (218-242 行) ...
seq.num_cached_tokens = 0
seq.block_table.clear()
# ============ 新增: 清理 decode tracking ============
self.clear_decode_tracking(seq)
```
**文件 2** (可选): `nanovllm/kvcache/offload_engine.py`
位置: 在类末尾添加新方法
```python
def on_sequence_finished(self):
"""清理请求完成后的状态 (防御性清理)"""
self.decode_k_buffer.zero_()
self.decode_v_buffer.zero_()
```
---
## 关键文件清单
| 文件 | 相关行号 | 说明 |
|------|----------|------| |------|----------|------|
| `nanovllm/kvcache/hybrid_manager.py` | 218-244 | `deallocate()` - **需要修改** | | `xattn.py` | `xattn_estimate()` | 标准 XAttention estimation |
| `nanovllm/kvcache/hybrid_manager.py` | 538-549 | `clear_decode_tracking()` - 已存在 | | `xattn.py` | `xattn_estimate_chunked()` | Chunked prefill 版本 |
| `nanovllm/kvcache/hybrid_manager.py` | 485-505 | `get_decode_start_pos()` - 问题读取点 | | `xattn.py` | `flat_group_gemm_fuse_reshape()` | Triton kernel: fused reshape + GEMM |
| `nanovllm/kvcache/hybrid_manager.py` | 519-537 | `get_prefill_len()` - 问题读取点 | | `xattn.py` | `softmax_fuse_block_sum()` | Triton kernel: softmax + block sum |
| `nanovllm/kvcache/offload_engine.py` | 40-145 | `__init__` - 状态初始化 | | `xattn.py` | `find_blocks_chunked()` | Block selection based on threshold |
| `nanovllm/kvcache/offload_engine.py` | (新增) | `on_sequence_finished()` - 可选防御 | | `chunked_attention.py` | `flash_attn_with_lse()` | Flash attention with LSE output |
| `nanovllm/engine/model_runner.py` | 867-1057 | `run_layerwise_offload_decode()` | | `chunked_attention.py` | `merge_attention_outputs()` | Merge multiple attention chunks |
| `nanovllm/engine/model_runner.py` | 969-976 | decode buffer 读取 (污染点) |
--- ### 与 Policy 的关系
## 验证命令 ```
XAttentionPolicy.estimate()
**指定 GPU: 1** (严格限制,不可更改) └── 调用 nanovllm.ops.xattn.xattn_estimate()
├── flat_group_gemm_fuse_reshape() (Triton)
```bash ├── softmax_fuse_block_sum() (Triton)
# 快速验证 (20 样本连续执行,不重启框架) └── find_blocks_chunked()
# 目标: 20/20 通过
CUDA_VISIBLE_DEVICES=1 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler_niah.py \
--model ~/models/Llama-3.1-8B-Instruct \
--enable-offload \
--sample-indices 0-19 \
--verbose
# 完整验证 (100 样本)
# 目标: 100/100 通过 (最终验收)
CUDA_VISIBLE_DEVICES=1 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler_niah.py \
--model ~/models/Llama-3.1-8B-Instruct \
--enable-offload \
--quiet
``` ```
**验收标准**:
| 测试 | 样本数 | 通过要求 | 说明 |
|------|--------|----------|------|
| 快速验证 | 20 | 20/20 (100%) | 一次调用,连续执行,验证请求切换 |
| 完整验证 | 100 | 100/100 (100%) | 最终验收 |
--- ---
## 当前状态 ## Key Questions
- [x] Phase 1: 代码分析 1. **`select_blocks` 改为什么?**
- [x] Phase 2: 根本原因分析 - 改名为 `estimate()`:用于计算 sparse mask
- [x] Phase 3: Debug 方案设计 - 对于 XAttention对应 COMPASS 的 `xattn_estimate()` 函数
- [x] Phase 4: 实施计划 ✅ 100/100 PASSED - FullAttentionPolicy 的 `estimate()` 返回 None表示 full attention
### 验证结果 2. **Policy 接口应该如何设计?**
- Prefill: `compute_prefill(q, k, v, layer_id, softmax_scale)`
- Decode: `compute_decode(q, k, v, layer_id, softmax_scale)`
- Estimate: `estimate(q, k, layer_id)` - 计算 sparse mask
| 测试 | 结果 | 日期 | 3. **FULL policy 如何处理?**
|------|------|------| - FULL 也实现 `compute_prefill/decode`,使用 FlashAttention
| 20 样本快速验证 | ✅ 20/20 (100%) | 2026-01-13 | - `estimate()` 返回 None表示不进行稀疏化
| 100 样本完整验证 | ✅ 100/100 (100%) | 2026-01-13 |
## Proposed New Interface
```python
from abc import ABC, abstractmethod
from typing import Optional
import torch
class AttentionPolicy(ABC):
"""Layerwise Offload 模式下的 Attention Policy
所有 attention 计算都通过 policy 进行,包括 Full 和 Sparse。
支持 prefill 和 decode 两个阶段。
"""
supports_prefill: bool = True
supports_decode: bool = True
def estimate(
self,
q: torch.Tensor, # [seq_len, num_heads, head_dim]
k: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
layer_id: int,
) -> Optional[torch.Tensor]:
"""
估算 sparse attention mask。
对于 sparse policy如 XAttention计算哪些 blocks 需要 attend。
对于 full policy返回 None 表示使用完整 attention。
对应 COMPASS 的 xattn_estimate() 函数。
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
Returns:
sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask, 或 None
"""
return None # 默认为 full attention
@abstractmethod
def compute_prefill(
self,
q: torch.Tensor, # [seq_len, num_heads, head_dim]
k: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
v: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
计算 prefill attention。
整层 KV 都在 GPU 上,一次计算完整 attention。
可以先调用 estimate() 获取 sparse mask然后应用 block sparse attention。
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
pass
def compute_decode(
self,
q: torch.Tensor, # [1, num_heads, head_dim]
k: torch.Tensor, # [context_len+1, num_kv_heads, head_dim]
v: torch.Tensor, # [context_len+1, num_kv_heads, head_dim]
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
计算 decode attention。
KV 从 ring buffer 提供,包含 prefill tokens + 已 decode 的 tokens。
Args:
q: Query tensor [1, num_heads, head_dim]
k: Key tensor [context_len+1, num_kv_heads, head_dim]
v: Value tensor [context_len+1, num_kv_heads, head_dim]
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor
Returns:
Attention output [1, num_heads, head_dim]
"""
# 默认实现:使用 FlashAttention
from flash_attn.flash_attn_interface import flash_attn_varlen_func
context_len = k.shape[0]
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=context_len,
softmax_scale=softmax_scale,
causal=False,
)
def reset(self) -> None:
"""Reset policy state between sequences."""
pass
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
# 保留旧名称作为别名
SparsePolicy = AttentionPolicy
```
## Implementation Plan
### Phase 2: 重构 policy.py
```python
# nanovllm/kvcache/sparse/policy.py
from abc import ABC, abstractmethod
from typing import Optional
import torch
class AttentionPolicy(ABC):
"""Base class for attention policies in layerwise offload mode."""
supports_prefill: bool = True
supports_decode: bool = True
def estimate(
self,
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
Estimate sparse attention mask.
For sparse policies (e.g., XAttention), computes block-level importance.
For full policy, returns None.
Corresponds to xattn_estimate() in COMPASS.
Returns:
sparse_mask: [num_heads, q_blocks, k_blocks] or None
"""
return None
@abstractmethod
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""Compute prefill attention."""
pass
def compute_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""Compute decode attention (default: FlashAttention)."""
from flash_attn.flash_attn_interface import flash_attn_varlen_func
context_len = k.shape[0]
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=context_len,
softmax_scale=softmax_scale,
causal=False,
)
def reset(self) -> None:
pass
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
# Backward compatibility alias
SparsePolicy = AttentionPolicy
```
### Phase 3: 重构 FullAttentionPolicy
```python
# nanovllm/kvcache/sparse/full_policy.py
import torch
from .policy import AttentionPolicy
class FullAttentionPolicy(AttentionPolicy):
"""Full attention using FlashAttention (no sparsity)."""
supports_prefill = True
supports_decode = True
def estimate(self, q, k, layer_id):
"""Full attention - no sparse mask needed."""
return None
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
from flash_attn.flash_attn_interface import flash_attn_varlen_func
seq_len = q.shape[0]
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=softmax_scale,
causal=True,
)
def __repr__(self):
return "FullAttentionPolicy()"
```
### Phase 4: 重构 XAttentionPolicy
```python
# nanovllm/kvcache/sparse/xattn.py
import torch
from typing import Optional
from .policy import AttentionPolicy
class XAttentionPolicy(AttentionPolicy):
"""
XAttention sparse prefill policy.
Uses chunked estimation to compute sparse attention mask,
then applies block sparse attention.
"""
supports_prefill = True
supports_decode = True
def __init__(
self,
stride: int = 8,
threshold: float = 0.9,
block_size: int = 128,
chunk_size: int = 16384,
use_triton: bool = True,
):
self.stride = stride
self.threshold = threshold
self.block_size = block_size
self.chunk_size = chunk_size
self.use_triton = use_triton
def estimate(
self,
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
XAttention estimation (xattn_estimate).
Uses chunked GEMM + softmax to estimate block-level importance,
then selects important blocks based on threshold.
对应 COMPASS 的 xattn_estimate() 函数:
1. Pad inputs to chunk_size multiples
2. Reshape with stride
3. Compute QK^T in chunks (Triton)
4. Block-wise softmax + aggregation
5. Threshold-based selection
Args:
q: [seq_len, num_heads, head_dim]
k: [seq_len, num_kv_heads, head_dim]
layer_id: transformer layer index
Returns:
sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask
or None (fallback to full attention)
"""
# TODO: 实现真正的 xattn_estimate
# 当前返回 None 使用 full attention
return None
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute XAttention sparse prefill.
Flow:
1. Call estimate() to get sparse mask
2. If mask is None, use full attention
3. Otherwise, apply block sparse attention with mask
"""
# Step 1: Estimate sparse mask
sparse_mask = self.estimate(q, k, layer_id)
# Step 2: Compute attention
if sparse_mask is None:
# Fallback to full attention
from flash_attn.flash_attn_interface import flash_attn_varlen_func
seq_len = q.shape[0]
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=softmax_scale,
causal=True,
)
else:
# Apply block sparse attention with mask
# 使用 block_sparse_attn_func(q, k, v, sparse_mask, block_size)
raise NotImplementedError("Block sparse attention not yet implemented")
def __repr__(self):
return (f"XAttentionPolicy("
f"stride={self.stride}, "
f"threshold={self.threshold}, "
f"block_size={self.block_size})")
```
### Phase 5: 更新 model_runner.py
```python
# model_runner.py - allocate_kv_cache()
# 改为总是创建 policy包括 FULL
from nanovllm.kvcache.sparse import create_attention_policy
self.attention_policy = create_attention_policy(config.attention_policy, **policy_kwargs)
logger.info(f"Attention policy: {self.attention_policy}")
# run_layerwise_offload_prefill() 和 run_gpu_only_prefill()
# 旧代码:
if self.sparse_prefill_policy is not None:
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
else:
attn_output = flash_attn_varlen_func(...)
# 新代码:
attn_output = self.attention_policy.compute_prefill(
q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
)
```
## Method Mapping
| 旧方法 | 新方法 | 说明 |
|--------|--------|------|
| `select_blocks()` | `estimate()` | 计算 sparse mask对应 xattn_estimate |
| `sparse_prefill_attention()` | `compute_prefill()` | Prefill attention |
| (无) | `compute_decode()` | Decode attention默认实现 |
| `on_prefill_offload()` | (移除) | Offload 在 model_runner 处理 |
## Files to Modify
| File | Changes |
|------|---------|
| `nanovllm/kvcache/sparse/policy.py` | 新接口estimate, compute_prefill, compute_decode |
| `nanovllm/kvcache/sparse/full_policy.py` | 实现 compute_prefill(), estimate() 返回 None |
| `nanovllm/kvcache/sparse/xattn.py` | estimate() 对应 xattn_estimate, compute_prefill() |
| `nanovllm/kvcache/sparse/__init__.py` | 更新工厂函数 |
| `nanovllm/engine/model_runner.py` | 统一调用 attention_policy.compute_prefill() |
| `nanovllm/config.py` | 可选:重命名配置项 |
## Decisions Made
1. **方法命名**: `compute_prefill` / `compute_decode` 对应 chunked 版本的命名风格
2. **estimate 方法**: 替代 `select_blocks`,返回 sparse mask 而不是 block IDs
3. **XAttention**: `estimate()` 对应 COMPASS 的 `xattn_estimate()`
4. **Full Policy**: `estimate()` 返回 None 表示使用完整 attention
5. **Decode 默认实现**: 基类提供默认的 FlashAttention 实现
## Errors Encountered
- (无)
## Status
**Currently in Phase 1** - 完成分析和接口设计,等待用户确认后进入 Phase 2

View File

@@ -32,11 +32,14 @@ def run_needle_test(
enable_cpu_offload: bool = False, enable_cpu_offload: bool = False,
enable_quest: bool = False, enable_quest: bool = False,
enable_minference: bool = False, enable_minference: bool = False,
enable_xattn: bool = False,
sparse_topk: int = 8, sparse_topk: int = 8,
sparse_threshold: int = 4, sparse_threshold: int = 4,
minference_budget: float = 0.3, minference_budget: float = 0.3,
minference_vertical: int = 1000, minference_vertical: int = 1000,
minference_slash: int = 6096, minference_slash: int = 6096,
xattn_threshold: float = 0.9,
xattn_use_bsa: bool = True,
gpu_utilization: float = 0.9, gpu_utilization: float = 0.9,
enforce_eager: bool = True, enforce_eager: bool = True,
verbose: bool = True, verbose: bool = True,
@@ -56,11 +59,14 @@ def run_needle_test(
enable_cpu_offload: Enable CPU offload mode enable_cpu_offload: Enable CPU offload mode
enable_quest: Enable Quest sparse attention (decode-only Top-K) enable_quest: Enable Quest sparse attention (decode-only Top-K)
enable_minference: Enable MInference sparse prefill (GPU-only) enable_minference: Enable MInference sparse prefill (GPU-only)
enable_xattn: Enable XAttention sparse prefill with BSA
sparse_topk: Top-K blocks for Quest sparse_topk: Top-K blocks for Quest
sparse_threshold: Apply sparse only when blocks > threshold sparse_threshold: Apply sparse only when blocks > threshold
minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode) minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode)
minference_vertical: Fixed vertical_size (only used when budget=None) minference_vertical: Fixed vertical_size (only used when budget=None)
minference_slash: Fixed slash_size (only used when budget=None) minference_slash: Fixed slash_size (only used when budget=None)
xattn_threshold: XAttention block selection threshold (0-1)
xattn_use_bsa: Use Block Sparse Attention library
gpu_utilization: GPU memory utilization fraction gpu_utilization: GPU memory utilization fraction
verbose: Print detailed output verbose: Print detailed output
@@ -68,7 +74,9 @@ def run_needle_test(
True if test passed, False otherwise True if test passed, False otherwise
""" """
# Determine sparse policy # Determine sparse policy
if enable_minference: if enable_xattn:
sparse_policy = SparsePolicyType.XATTN
elif enable_minference:
sparse_policy = SparsePolicyType.MINFERENCE sparse_policy = SparsePolicyType.MINFERENCE
elif enable_quest: elif enable_quest:
sparse_policy = SparsePolicyType.QUEST sparse_policy = SparsePolicyType.QUEST
@@ -94,6 +102,8 @@ def run_needle_test(
print(f" MInference: adaptive (budget={minference_budget})") print(f" MInference: adaptive (budget={minference_budget})")
else: else:
print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})") print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})")
if enable_xattn:
print(f" XAttention: threshold={xattn_threshold}, use_bsa={xattn_use_bsa}")
print(f"{'='*60}\n") print(f"{'='*60}\n")
# 1. Initialize LLM # 1. Initialize LLM
@@ -111,7 +121,7 @@ def run_needle_test(
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
# Set sparse policy (can be used with or without offload) # Set sparse policy (can be used with or without offload)
if enable_minference or enable_quest: if enable_minference or enable_quest or enable_xattn:
llm_kwargs["sparse_policy"] = sparse_policy llm_kwargs["sparse_policy"] = sparse_policy
# MInference params (works with both GPU-only and offload mode) # MInference params (works with both GPU-only and offload mode)
@@ -120,6 +130,11 @@ def run_needle_test(
llm_kwargs["minference_vertical_size"] = minference_vertical llm_kwargs["minference_vertical_size"] = minference_vertical
llm_kwargs["minference_slash_size"] = minference_slash llm_kwargs["minference_slash_size"] = minference_slash
# XAttention params
if enable_xattn:
llm_kwargs["xattn_threshold"] = xattn_threshold
llm_kwargs["xattn_use_bsa"] = xattn_use_bsa
llm = LLM(model_path, **llm_kwargs) llm = LLM(model_path, **llm_kwargs)
# 2. Generate needle prompt # 2. Generate needle prompt
@@ -224,6 +239,11 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)" help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)"
) )
parser.add_argument(
"--enable-xattn",
action="store_true",
help="Enable XAttention sparse prefill with Block Sparse Attention"
)
parser.add_argument( parser.add_argument(
"--sparse-topk", "--sparse-topk",
type=int, type=int,
@@ -254,6 +274,17 @@ if __name__ == "__main__":
default=6096, default=6096,
help="Fixed slash_size (only used when budget=0)" help="Fixed slash_size (only used when budget=0)"
) )
parser.add_argument(
"--xattn-threshold",
type=float,
default=0.9,
help="XAttention block selection threshold (0-1, higher=more blocks)"
)
parser.add_argument(
"--xattn-no-bsa",
action="store_true",
help="Disable Block Sparse Attention (use FlashAttention fallback)"
)
parser.add_argument( parser.add_argument(
"--gpu-utilization", "--gpu-utilization",
type=float, type=float,
@@ -291,11 +322,14 @@ if __name__ == "__main__":
enable_cpu_offload=args.enable_offload, enable_cpu_offload=args.enable_offload,
enable_quest=args.enable_quest, enable_quest=args.enable_quest,
enable_minference=args.enable_minference, enable_minference=args.enable_minference,
enable_xattn=args.enable_xattn,
sparse_topk=args.sparse_topk, sparse_topk=args.sparse_topk,
sparse_threshold=args.sparse_threshold, sparse_threshold=args.sparse_threshold,
minference_budget=minference_budget, minference_budget=minference_budget,
minference_vertical=args.minference_vertical, minference_vertical=args.minference_vertical,
minference_slash=args.minference_slash, minference_slash=args.minference_slash,
xattn_threshold=args.xattn_threshold,
xattn_use_bsa=not args.xattn_no_bsa,
gpu_utilization=args.gpu_utilization, gpu_utilization=args.gpu_utilization,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
verbose=True, verbose=True,

View File

@@ -38,11 +38,11 @@ from nanovllm import LLM, SamplingParams
# Constants # Constants
# ============================================================ # ============================================================
DEFAULT_DATA_DIR = Path(__file__).parent / "data/ruler_32k" DEFAULT_DATA_DIR = Path(__file__).parent / "data/ruler_64k"
DEFAULT_MODEL = os.path.expanduser("~/models/Llama-3.1-8B-Instruct") DEFAULT_MODEL = os.path.expanduser("~/models/Llama-3.1-8B-Instruct")
# Note: max_model_len must be > max_input_len to leave room for output tokens # Note: max_model_len must be > max_input_len to leave room for output tokens
# 32k benchmark has inputs up to 32760 tokens, so we need 32768 + 128 = 32896 # 64k benchmark has inputs up to 65536 tokens, so we need 65536 + 128 = 65664
DEFAULT_MAX_MODEL_LEN = 32896 DEFAULT_MAX_MODEL_LEN = 65664
DEFAULT_MAX_NEW_TOKENS = 128 # Larger for multi-value tasks DEFAULT_MAX_NEW_TOKENS = 128 # Larger for multi-value tasks
# Task categories for evaluation # Task categories for evaluation
@@ -222,9 +222,11 @@ def run_ruler_benchmark(
enable_cpu_offload: bool = False, enable_cpu_offload: bool = False,
num_gpu_blocks: int = 4, num_gpu_blocks: int = 4,
block_size: int = 1024, block_size: int = 1024,
num_kv_buffers: int = 4,
gpu_utilization: float = 0.9, gpu_utilization: float = 0.9,
enforce_eager: bool = True, enforce_eager: bool = True,
verbose: bool = True, verbose: bool = True,
sparse_policy: Optional[str] = None,
) -> Dict: ) -> Dict:
""" """
Run RULER benchmark on multiple tasks. Run RULER benchmark on multiple tasks.
@@ -235,6 +237,7 @@ def run_ruler_benchmark(
datasets: List of task names to test (None = all) datasets: List of task names to test (None = all)
num_samples: Number of samples per task (None = all) num_samples: Number of samples per task (None = all)
...other LLM config params... ...other LLM config params...
sparse_policy: Sparse attention policy (FULL, QUEST, MINFERENCE, XATTN)
Returns: Returns:
Dict with overall results and per-task results Dict with overall results and per-task results
@@ -270,6 +273,11 @@ def run_ruler_benchmark(
} }
if enable_cpu_offload: if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm_kwargs["num_kv_buffers"] = num_kv_buffers
if sparse_policy:
from nanovllm.config import SparsePolicyType
sparse_policy_type = SparsePolicyType[sparse_policy]
llm_kwargs["sparse_policy"] = sparse_policy_type
llm = LLM(model_path, **llm_kwargs) llm = LLM(model_path, **llm_kwargs)
@@ -356,12 +364,16 @@ if __name__ == "__main__":
help="Number of GPU blocks for CPU offload (default: 4)") help="Number of GPU blocks for CPU offload (default: 4)")
parser.add_argument("--block-size", type=int, default=1024, parser.add_argument("--block-size", type=int, default=1024,
help="KV cache block size (default: 1024)") help="KV cache block size (default: 1024)")
parser.add_argument("--num-kv-buffers", type=int, default=4,
help="Number of KV buffers for ring buffer (default: 4)")
parser.add_argument("--gpu-utilization", type=float, default=0.9, parser.add_argument("--gpu-utilization", type=float, default=0.9,
help="GPU memory utilization (default: 0.9)") help="GPU memory utilization (default: 0.9)")
parser.add_argument("--use-cuda-graph", action="store_true", parser.add_argument("--use-cuda-graph", action="store_true",
help="Enable CUDA graph") help="Enable CUDA graph")
parser.add_argument("--quiet", "-q", action="store_true", parser.add_argument("--quiet", "-q", action="store_true",
help="Quiet mode") help="Quiet mode")
parser.add_argument("--sparse-policy", type=str, default="",
help="Sparse attention policy (FULL, QUEST, MINFERENCE, XATTN)")
args = parser.parse_args() args = parser.parse_args()
@@ -369,6 +381,9 @@ if __name__ == "__main__":
datasets = args.datasets.split(",") if args.datasets else None datasets = args.datasets.split(",") if args.datasets else None
num_samples = args.num_samples if args.num_samples > 0 else None num_samples = args.num_samples if args.num_samples > 0 else None
# Parse sparse policy
sparse_policy_str = args.sparse_policy.upper() if args.sparse_policy else None
results = run_ruler_benchmark( results = run_ruler_benchmark(
model_path=os.path.expanduser(args.model), model_path=os.path.expanduser(args.model),
data_dir=Path(args.data_dir), data_dir=Path(args.data_dir),
@@ -379,9 +394,11 @@ if __name__ == "__main__":
enable_cpu_offload=args.enable_offload, enable_cpu_offload=args.enable_offload,
num_gpu_blocks=args.num_gpu_blocks, num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size, block_size=args.block_size,
num_kv_buffers=args.num_kv_buffers,
gpu_utilization=args.gpu_utilization, gpu_utilization=args.gpu_utilization,
enforce_eager=not args.use_cuda_graph, enforce_eager=not args.use_cuda_graph,
verbose=not args.quiet, verbose=not args.quiet,
sparse_policy=sparse_policy_str,
) )
# Exit code # Exit code

View File

@@ -0,0 +1,244 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked
Verify that chunked estimation with EXTERNAL chunking produces the same mask
as standard estimation. This ensures the chunked version can be used in
chunked prefill scenarios without accuracy loss.
Usage:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_chunked.py
"""
import sys
import traceback
import torch
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
# ============================================================
# Configuration
# ============================================================
# Configuration for xattn_estimate_chunked consistency test.
# Key requirements for 100% match:
# 1. Use matching chunk_size for both standard and chunked versions
# 2. Use same random seed for reproducibility
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
# floating point precision in cumulative sum calculations.
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096 # External chunking size
# Test sequence lengths
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
# ============================================================
# Utility Functions
# ============================================================
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
"""Compare two masks and report differences."""
if mask1.shape != mask2.shape:
print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
return False
diff = (mask1 != mask2).sum().item()
total = mask1.numel()
match_rate = (total - diff) / total * 100
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
if diff > 0:
diff_indices = torch.where(mask1 != mask2)
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
return diff == 0
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
"""
Run xattn_estimate_chunked with EXTERNAL chunking.
This simulates how chunked prefill should be used in practice.
"""
batch_size, num_heads, q_len, head_dim = query.shape
_, _, k_len, _ = key.shape
q_block_num = (q_len + block_size - 1) // block_size
k_block_num = (k_len + block_size - 1) // block_size
# If Q fits in one chunk, call directly
if q_len <= chunk_size:
return xattn_estimate_chunked(
query, key,
q_start_pos=0,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# External chunking: split Q and call for each chunk
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
print(f" External chunking: {num_q_chunks} chunks")
combined_attn_sum = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=query.dtype, device=query.device
)
combined_mask = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=torch.bool, device=query.device
)
q_block_offset = 0
for q_chunk_idx in range(num_q_chunks):
q_chunk_start = q_chunk_idx * chunk_size
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
# For causal attention, K accumulates up to current Q position
# q_start_pos=0 means Q starts at position 0 in the full sequence
# K is [0, q_chunk_end) for causal attention
k_end = q_chunk_end
k_chunk = key[:, :, :k_end, :]
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
q_chunk, k_chunk,
q_start_pos=q_chunk_start,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# Place chunk results into combined output
chunk_q_blocks = mask_chunk.shape[2]
chunk_k_blocks = mask_chunk.shape[3]
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
q_block_offset += chunk_q_blocks
return combined_attn_sum, combined_mask
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
"""Test a single sequence length."""
print(f"\nTesting seq_len={seq_len}")
print("=" * 60)
# Generate random Q/K
query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
# Run standard xattn_estimate
print("[1] Running standard xattn_estimate...")
try:
attn_sum_std, mask_std = xattn_estimate(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
use_triton=True,
causal=True,
)
density_std = mask_std.float().mean().item()
print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
except Exception as e:
print(f" ERROR: {e}")
traceback.print_exc()
return False
# Run chunked xattn_estimate with EXTERNAL chunking
print("[2] Running chunked xattn_estimate (external chunking)...")
try:
attn_sum_chunked, mask_chunked = run_chunked_externally(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
)
density_chunked = mask_chunked.float().mean().item()
print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
except Exception as e:
print(f" ERROR: {e}")
traceback.print_exc()
return False
# Compare results
print("[3] Comparing results...")
chunked_q_blocks = mask_chunked.shape[2]
chunked_k_blocks = mask_chunked.shape[3]
# Extract comparable region from standard mask
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
# Compare masks
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
# Compare attn_sums
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
print(f" Attn sum max diff: {attn_diff:.6f}")
else:
print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")
# Clean up GPU memory
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
torch.cuda.empty_cache()
return masks_match
# ============================================================
# Main Test
# ============================================================
if __name__ == "__main__":
print("XAttention Chunked vs Standard Test")
print("=" * 60)
print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
print(f"External chunk_size={CHUNK_SIZE}")
print()
# Check CUDA availability
if not torch.cuda.is_available():
print("CUDA not available!")
sys.exit(1)
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
print("✓ xattn_estimate imported")
print("✓ xattn_estimate_chunked imported")
# Run tests
all_passed = True
results = []
for seq_len in TEST_SEQ_LENS:
passed = test_single_seq_len(seq_len)
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
results.append((seq_len, chunks, passed))
if not passed:
all_passed = False
# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for seq_len, chunks, passed in results:
status = "PASSED" if passed else "FAILED"
print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")
print("=" * 60)
if all_passed:
print("ALL TESTS PASSED!")
sys.exit(0)
else:
print("SOME TESTS FAILED!")
sys.exit(1)