Compare commits: 4484a1482c ... tzj/minfer
28 commits
| Author | SHA1 | Date |
|---|---|---|
| | 52b12a89e3 | |
| | d35dd76e09 | |
| | 2b61c5ab57 | |
| | a709551072 | |
| | 11a867f6fb | |
| | af4da454ba | |
| | ef37d4f1a8 | |
| | c8a5ef04c0 | |
| | 1c36d53570 | |
| | 54fd302fa8 | |
| | 1eb7521994 | |
| | 51bd678335 | |
| | 1ea5afd886 | |
| | 829b311c02 | |
| | dd0472aea8 | |
| | a1c68a733e | |
| | dc51972777 | |
| | 232fcf043e | |
| | aeed6ccdfb | |
| | 6c55c4d2a3 | |
| | 6e34efd58a | |
| | 5acd5558d6 | |
| | 193ef55d18 | |
| | f173a3f7f5 | |
| | 8035e4db3d | |
| | 8ab53e7331 | |
| | 2e96d1d97d | |
| | f6ac4ccdde | |
`.claude/rules/test-ruler.md` (new file, 90 lines)
@@ -0,0 +1,90 @@
# test_ruler.py Usage Rules

## Mandatory Rules

**Consult the documentation before running `test_ruler.py`**; do not run `--help` or guess parameters.

| Forbidden | Reason |
|------|------|
| `python tests/test_ruler.py --help` | Wastes an interaction; the docs already cover everything |
| Guessing parameter formats | Error-prone and inefficient |

## Required Reading

**[`docs/test_ruler_usage_guide.md`](../docs/test_ruler_usage_guide.md)** - contains:

- Complete parameter reference
- Verified command examples
- GPU mode selection guide
- max-model-len selection guide

## Quick Reference

### Standard command format

```bash
CUDA_VISIBLE_DEVICES=<GPU> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
  --model ~/models/<MODEL> \
  --data-dir tests/data/ruler_<CTX> \
  --datasets <TASK> \
  --num-samples <N> \
  --max-model-len <LEN> \
  --enable-offload \
  [--sparse-policy XATTN_BSA] \
  [--sparse-threshold 0.9]
```

### Common parameters

| Parameter | Purpose | Example |
|------|------|------|
| `--datasets` | Select tasks | `niah_single_1,qa_1` |
| `--num-samples` | Number of samples | `1`, `10`, `0` (all) |
| `--sample-indices` | Specific sample indices | `0,5,10` |
| `--enable-offload` | CPU offload | required on RTX 3090 |
| `--sparse-policy` | Sparse policy | `XATTN_BSA` |
| `--json-output` | JSON output | for scripting |
| `--quiet` | Quiet mode | less output |

### max-model-len quick reference

| Data directory | max-model-len |
|---------|---------------|
| ruler_32k | 40960 |
| ruler_64k | 72000 |
| ruler_128k | 135000 |
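The mapping above can be applied mechanically; a tiny helper sketch (hypothetical, not part of test_ruler.py) that derives max-model-len from the data directory name:

```python
# Hypothetical helper mirroring the max-model-len table above.
MAX_MODEL_LEN = {"ruler_32k": 40960, "ruler_64k": 72000, "ruler_128k": 135000}

def max_len_for(data_dir: str) -> int:
    # Take the last path component, e.g. "tests/data/ruler_64k" -> "ruler_64k".
    name = data_dir.rstrip("/").rsplit("/", 1)[-1]
    return MAX_MODEL_LEN[name]

print(max_len_for("tests/data/ruler_64k"))  # 72000
```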
### Common command templates

**32K Offload + XAttn**:

```bash
CUDA_VISIBLE_DEVICES=<GPU> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
  --model ~/models/Llama-3.1-8B-Instruct \
  --data-dir tests/data/ruler_32k \
  --datasets niah_single_1 \
  --num-samples 1 \
  --max-model-len 40960 \
  --enable-offload \
  --sparse-policy XATTN_BSA
```

**64K Offload + XAttn**:

```bash
CUDA_VISIBLE_DEVICES=<GPU> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
  --model ~/models/Llama-3.1-8B-Instruct \
  --data-dir tests/data/ruler_64k \
  --datasets niah_single_1 \
  --num-samples 1 \
  --max-model-len 72000 \
  --enable-offload \
  --sparse-policy XATTN_BSA
```

## Pre-Run Checklist

- [ ] Did the user specify a GPU? If not, ask.
- [ ] RTX 3090/4090? `--enable-offload` is required.
- [ ] Do data-dir and max-model-len match?
- [ ] Need density statistics? Add `--sparse-policy XATTN_BSA`.
`.gitignore` (vendored, 1 line added)
@@ -240,3 +240,4 @@ findings_*.md
 progress_*.md
 notes.md
 Snipaste*
+.ralph-tui/session-meta.json
`.ralph-tui/config.toml` (new file, 12 lines)
@@ -0,0 +1,12 @@
# Ralph TUI Configuration
# Generated by setup wizard
# See: ralph-tui config help

configVersion = "2.1"
tracker = "json"
agent = "claude"
maxIterations = 30
autoCommit = false

[trackerOptions]
[agentOptions]
`CLAUDE.md` (20 lines changed)
@@ -16,8 +16,10 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
 | [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention algorithm deep dive: stride reshape, Triton kernels, BSA dependency, block selection algorithm |
 | [`docs/xattn_kernels_guide.md`](docs/xattn_kernels_guide.md) | XAttention Triton kernels: flat_group_gemm (antidiagonal sums), softmax_fuse_block_sum (block aggregation) |
+| [`docs/xattn_kv_chunking_kernels.md`](docs/xattn_kv_chunking_kernels.md) | XAttention KV chunking: three-stage softmax, storage overhead analysis (O(S) vs O(S²)), peak-memory optimization (8x), independent Q/KV chunking |
 | [`docs/xattn_chunked_prefill.md`](docs/xattn_chunked_prefill.md) | XAttention chunked prefill: API, usage, consistency requirements |
 | [`docs/xattn_bsa_policy_design.md`](docs/xattn_bsa_policy_design.md) | XAttention BSA Policy: algorithm design, performance benchmarks (128K), memory management, density statistics |
+| [`docs/xattn_density_benchmark.md`](docs/xattn_density_benchmark.md) | 📊 XAttention density benchmark: 4K-32K context, stride parameter, per-layer density analysis |
 | [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) interface documentation: function signatures, usage examples, constraints |
 | [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
 | [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
@@ -36,6 +38,16 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/estimate_block_size_performance.md`](docs/estimate_block_size_performance.md) | 🔥 PERF: estimate-stage block_size performance analysis, softmax_fuse_block_sum sweet spot (512-1024), current 4096 is 15x slower |
 | [`docs/long_context_models_1m.md`](docs/long_context_models_1m.md) | 📚 REF: list of 1M+ context-length models (Qwen/GLM/InternLM/Llama/VL), recommended models ≤10B |
 | [`docs/new_model_integration_guide.md`](docs/new_model_integration_guide.md) | 🔧 GUIDE: new-model integration guide - config mapping, RoPE variants, EOS handling, weight conversion, verification checklist |
+| [`docs/xattn_density_alignment_analysis.md`](docs/xattn_density_alignment_analysis.md) | 📊 ANALYSIS: GPU-only vs Offload density alignment analysis, chunked-softmax boundary effects, root cause of the 5-7% difference |
+| [`docs/xattn_kv_chunking_density_test.md`](docs/xattn_kv_chunking_density_test.md) | 🧪 TEST: XAttention KV chunking density verification, aligned at threshold=1.0, 10-13% difference at threshold<1.0 |
+| [`docs/gpuonly_density_alignment_test.md`](docs/gpuonly_density_alignment_test.md) | ✅ TEST: density alignment verification (GPU-only + Offload, 4K-64K), xattn_estimate vs KV chunking fully consistent |
+| [`docs/xattn_memory_benchmark.md`](docs/xattn_memory_benchmark.md) | 📊 BENCH: XAttention memory benchmark, Qwen3-0.6B 32K feasible within 24GB VRAM (gpu-util=0.28) |
+| [`docs/xattn_offload_stream_sync_fix.md`](docs/xattn_offload_stream_sync_fix.md) | 🐛 FIX: XAttention Offload stream-sync bug, Pass1/Pass2 K data inconsistency, compute_stream wrapping |
+| [`docs/xattn_density_types.md`](docs/xattn_density_types.md) | 📊 Compute vs comm density: BSA block (128) vs CPU block (4096) granularity, aggregation effect drives comm=100% |
+| [`docs/xattn_density_alignment_verification.md`](docs/xattn_density_alignment_verification.md) | ✅ VERIFIED: GPU-only vs Offload density alignment verification (0.37% difference at 32K, 0.09% at 64K) |
+| [`docs/test_ruler_usage_guide.md`](docs/test_ruler_usage_guide.md) | 📖 GUIDE: test_ruler.py usage guide, RULER benchmark test commands, verified command examples |
+| [`docs/xattn_offload_profiling_32k.md`](docs/xattn_offload_profiling_32k.md) | 📊 PROFILE: XAttn vs Full 32K nsys analysis, estimate 41%, find_blocks 37%, compute only 21% |
+| [`docs/changelog_2026-02-05.md`](docs/changelog_2026-02-05.md) | 📋 CHANGELOG: GQA buffer OOM fix (saves 16GB), tests directory cleanup (-4306 lines) |

 ## Rules Index

@@ -46,6 +58,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`.claude/rules/sparse-policy.md`](.claude/rules/sparse-policy.md) | SparsePolicy implementation requirements |
 | [`.claude/rules/planning-with-files.md`](.claude/rules/planning-with-files.md) | Planning file management for complex tasks |
 | [`.claude/rules/gpu-monitor.md`](.claude/rules/gpu-monitor.md) | **GPU memory monitoring**: must use the gpu-monitor agent; manual nvidia-smi loops are forbidden |
+| [`.claude/rules/test-ruler.md`](.claude/rules/test-ruler.md) | **test_ruler.py rules**: --help is forbidden, the docs must be consulted; includes quick reference and command templates |

 ## GPU Mutex for Multi-Instance Debugging

@@ -100,6 +113,13 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py

 **Files**: `bench.py` (GPU), `bench_offload.py` (CPU offload), `bench_vllm.py` (comparison)

+**GPU-only test model selection**:
+
+| GPU | VRAM | GPU-only test model |
+|-----|------|------------------|
+| RTX 3090 | 24GB | **Qwen3-0.6B** (required; 7B+ models OOM) |
+| A100 | 40GB+ | Qwen3-0.6B / 4B / 7B all fine |
+
 **Offload Mode Constraint**: When using `enable_cpu_offload=True`, only test with context length ≥ 32K. Shorter contexts don't exercise the chunked offload pipeline properly.

 **Common Issues**:
`docs/changelog_2026-02-05.md` (new file, 94 lines)
@@ -0,0 +1,94 @@
# Changelog 2026-02-05

## Bug Fixes

### XAttention Offload GQA Buffer OOM Fix

**Issue**: `docs/issue_xattn_offload_gqa_buffer_oom.md`

**Problem**: In XAttention BSA + CPU Offload mode, `alloc_policy_metadata()` allocated the GQA expansion buffers (`_k_expanded`, `_v_expanded`) that only GPU-only mode needs, causing OOM on a 24GB GPU (RTX 3090).

**Root Cause**:

- GQA buffer size: `2 × num_heads × max_seq_len × head_dim × dtype_size`
- For 1M max_seq_len: 2 × 32 × 1048576 × 128 × 2 = **16 GB**
- Offload mode's `compute_chunked_prefill()` does not need these buffers
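The buffer arithmetic above can be sanity-checked in a couple of lines (plain Python; the helper name is made up for illustration):

```python
def gqa_buffers_bytes(num_heads, max_seq_len, head_dim, dtype_size=2):
    # Two buffers (_k_expanded and _v_expanded), each [1, num_heads, max_seq_len, head_dim].
    return 2 * num_heads * max_seq_len * head_dim * dtype_size

assert gqa_buffers_bytes(32, 1_048_576, 128) == 17_179_869_184  # 16 GiB at 1M max_seq_len
print(gqa_buffers_bytes(32, 72_000, 128) / 2**30)  # 1.0986328125 GiB (~1.1 GiB) at 72K
```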

**Fix** (commit `11a867f`):

1. `nanovllm/kvcache/sparse/policy.py`: add an `enable_cpu_offload` parameter to the base class
2. `nanovllm/kvcache/sparse/xattn_bsa.py`: skip GQA buffer allocation in offload mode
3. `nanovllm/engine/model_runner.py`: pass the `enable_cpu_offload` argument

**Memory Savings**:

| max_model_len | Before | After |
|---------------|--------|--------|
| 72K | +1.1 GB | 0 GB |
| 1M | +16 GB | 0 GB |

**Verification**:

```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
  --model ~/models/Llama-3.1-8B-Instruct \
  --data-dir tests/data/ruler_64k \
  --datasets niah_single_1 \
  --num-samples 1 \
  --max-model-len 72000 \
  --enable-offload \
  --sparse-policy XATTN_BSA
```

- Log shows: `[XAttn] Offload mode: skipping GQA expansion buffers`
- Test result: 100% accuracy

---

## Code Cleanup

### Tests Directory Cleanup

**Commits**: `a709551`, `2b61c5a`, `d35dd76`

Removed 16 redundant or outdated test files, keeping the core tests:

**Kept files** (4):

| File | Purpose |
|------|------|
| `test_ruler.py` | Core RULER benchmark (13 tasks, 100 samples) |
| `test_xattn_estimate_alignment.py` | XAttn kernel consistency verification |
| `utils.py` | Shared utility functions |
| `__init__.py` | Package marker |

**Deleted files** (16, -4306 lines):

| Category | File | Reason for deletion |
|------|------|----------|
| XAttn tests | `test_xattn_bsa.py` | Covered by test_ruler |
| | `test_xattn_chunked.py` | Duplicates estimate_chunked |
| | `test_xattn_estimate_chunked.py` | Chunked prefill verification |
| | `test_xattn_kernels.py` | Triton kernel unit tests |
| | `test_xattn_kv_chunking_batch.py` | Batch verification |
| Needle tests | `test_needle.py` | Covered by test_ruler NIAH tasks |
| | `test_needle_ref.py` | HF reference implementation |
| CUDA Graph | `test_chunk_attention_graph.py` | Superseded by graph_reuse |
| | `test_chunk_attention_graph_reuse.py` | Experimental feature |
| | `test_cudagraph_memory.py` | Memory analysis tool |
| Other | `test_gpuonly_density_alignment.py` | GPU-only density test |
| | `test_hierarchical_estimate.py` | Hierarchical estimate test |
| | `test_quest_policy.py` | Quest policy test |
| | `test_sequential.py` | State isolation test |
| | `bench_estimate_block_size.py` | Performance benchmark |
| | `modeling_qwen3.py` | Qwen3 reference model |

**Note**: all deleted files can be recovered from git history:

```bash
git checkout <commit-hash>^ -- tests/<filename>
```

---

## Summary

| Type | Count | Impact |
|------|------|------|
| Bug fix | 1 | Saves 16GB VRAM (1M seq) |
| Files deleted | 16 | -4306 lines of code |
| New docs | 1 | This file |
`docs/gpuonly_density_alignment_test.md` (new file, 246 lines)
@@ -0,0 +1,246 @@
# Density Alignment Test Results

Verifies the correctness of the three-stage KV chunking flow in both GPU-only and Offload modes.

## Test Configuration

### GPU-only mode

- **Model**: Qwen3-0.6B (28 layers, 16 heads, 8 KV heads, head_dim=128)
- **Threshold**: 0.9
- **Block Size**: 128 tokens (BSA block)
- **Stride**: 8
- **Chunk Size**: 16384 tokens

### Offload mode

- **Model**: Llama-3.1-8B-Instruct (32 layers, 32 heads, 8 KV heads, head_dim=128)
- **Threshold**: 0.9
- **Block Size**: 128 tokens (BSA block)
- **Stride**: 4
- **Chunk Size**: 4096 tokens

## Three-Stage KV Chunking Alignment Test (2026-02-02)

### Purpose

Verify that the high-level `xattn_estimate` API and a manual implementation of the three-stage KV chunking flow produce exactly the same results.

### The three stages

```
┌─────────────────────────────────────────────────────────────┐
│ Stage 1: softmax_compute_partial_stats                      │
│   └── each KV chunk independently computes partial stats    │
│       (m_i, l_i)                                            │
│                                                             │
│ Stage 2: merge_softmax_stats                                │
│   └── host-side merge across all chunks: (m_global, l_global) │
│                                                             │
│ Stage 3: softmax_normalize_and_block_sum                    │
│   └── normalize with global stats and compute block sums    │
└─────────────────────────────────────────────────────────────┘
```
### Results

#### CHUNK_SIZE = 16384 (default)

| Context | Tokens | Q Chunks | KV Chunks | Density | Mask diff | attn_sums diff | Result |
|---------|--------|----------|-----------|---------|-----------|----------------|------|
| 4K | 3,692 | 1 | 1 | 63.84% | 0 | 0.0 | ✅ |
| 8K | 7,892 | 1 | 1 | 64.98% | 0 | 0.0 | ✅ |
| 16K | 15,689 | 1 | 1 | 61.63% | 0 | 0.0 | ✅ |
| 32K | 32,485 | 2 | 2 | 50.21% | 0 | 0.0 | ✅ |
| **64K** | **64,891** | **4** | **4** | **37.00%** | **0** | **0.0** | ✅ |

#### CHUNK_SIZE = 4096 (more chunks)

| Context | Tokens | Q Chunks | KV Chunks | Density | xattn_estimate vs KV chunking | Result |
|---------|--------|----------|-----------|---------|-------------------------------|------|
| 4K | 3,692 | 1 | 1 | 63.84% | 0.000000 | ✅ |
| 8K | 7,892 | 2 | 2 | 63.02% | 0.000000 | ✅ |
| 16K | 15,689 | 4 | 4 | 60.08% | 0.000000 | ✅ |
| 32K | 32,485 | 8 | 8 | 49.84% | 0.000000 | ✅ |
| **64K** | **64,891** | **16** | **16** | **36.91%** | **0.000000** | ✅ |

### 64K detailed verification (CHUNK_SIZE=4096)

With chunk_size=4096, a 64K sequence produces a 16×16 chunk matrix:

```
seq_len: 64891, q_chunk_num: 16, kv_chunk_num: 16

Q chunk 0: merged 16 KV chunks → attn_sum shape=[1, 32, 32, 512]
Q chunk 1: merged 16 KV chunks → attn_sum shape=[1, 32, 32, 512]
...
Q chunk 15: merged 16 KV chunks → attn_sum shape=[1, 32, 32, 512]
```

Each Q chunk merges the softmax stats of 16 KV chunks, thoroughly exercising the correctness of `merge_softmax_stats` for large-scale chunk merging.

### Verification metrics

| Metric | Expected | Actual (all lengths) |
|------|------|------------------|
| attn_sums max diff | 0 | 0.000000e+00 |
| attn_sums mean diff | 0 | 0.000000e+00 |
| mask exact match | True | True |
| density diff | 0% | 0.000000% |

### Conclusion

✅ **Three-stage KV chunking is exactly equivalent to single-pass processing, with no precision loss.**

- When seq_len < CHUNK_SIZE (16384): single-chunk processing
- When seq_len >= CHUNK_SIZE: the sequence is processed chunk by chunk and merged, matching single-pass results exactly

---

## Offload Mode Test (2026-02-02)

Tested against real KV cache data saved in Offload mode.

### Results

| File | Tokens | Layer | Saved Density | Computed Density | Q/KV Chunks | Result |
|------|--------|-------|---------------|------------------|-------------|------|
| `qkv_3688.pt` | 3.7K | 3 | 38.34% | 38.34% | 1/1 | ✅ PASSED |
| `qkv_7888.pt` | 7.9K | 3 | 29.06% | 27.56% | 2/2 | ✅ PASSED |
| `qkv_15685.pt` | 15.7K | 3 | 19.77% | 18.60% | 4/4 | ✅ PASSED |
| `qkv_32485.pt` | 32.5K | 5 | 15.71% | 15.62% | 8/8 | ✅ PASSED |
| `qkv_64891.pt` | 64.9K | 3 | 11.09% | 11.09% | 16/16 | ✅ PASSED |

### Layer 5 GPU-only test (threshold=0.9)

| Metric | Result |
|------|------|
| Q/K shape | `[1, 16, 21001, 128]` (21K tokens) |
| Density | 6.24% |
| xattn_estimate vs KV chunking | identical (0.0000%) |
| Mask diff | 0 / 435600 blocks |
| attn_sums diff | max=0.0, mean=0.0 |

### Observations

1. **Density decreases as context grows**: 3.7K (38%) → 64.9K (11%)
2. **The xattn_estimate API matches the three-stage KV chunking exactly**: 0.0000% difference at all lengths
3. **Saved density differs slightly from computed density**: the saved density may have been recorded under different chunking, so the accumulation differs slightly

---

## Appendix: xattn_bsa vs xattn_estimate alignment

| Context | Tokens | Layer 0 Density | Compute Density | Min Layer | Result |
|---------|--------|-----------------|-----------------|-----------|----------|
| 4k | 3,692 | 63.8% | 52.9% | Layer 3 (31.3%) | ✅ PASSED |
| 8k | 7,892 | 65.0% | 52.5% | Layer 5 (27.3%) | ✅ PASSED |
| 16k | 15,689 | 61.6% | 47.8% | Layer 5 (23.5%) | ✅ PASSED |
| 32k | 32,485 | 50.2% | 40.1% | Layer 5 (18.5%) | ✅ PASSED |
| 64k | 64,891 | 37.0% | 29.6% | Layer 5 (12.4%) | ✅ PASSED |

## Density Formula

### Total (denominator)

```python
# Causal mask: Q block i can only attend to K blocks 0..i
causal_mask[i, j] = (j <= i + q_offset_blocks)

# Total = number of blocks inside the causal region * batch * heads
total = causal_mask.sum() * batch * heads
#     = (n * (n+1) / 2) * 1 * 32   # n = valid_q_blocks
```

### Selected (numerator)

```python
# Number of selected (mask=True) blocks inside the causal region
selected = (mask & causal_mask).sum()
```

### Density

```python
density = selected / total
```
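Putting the three formulas together on a toy 4×4 block grid (batch=1, heads=1; the mask pattern here is invented for illustration):

```python
n = 4  # valid_q_blocks; q_offset_blocks = 0
causal = [[j <= i for j in range(n)] for i in range(n)]
mask = [[j == 0 or j == i for j in range(n)] for i in range(n)]  # sink column + diagonal

total = sum(causal[i][j] for i in range(n) for j in range(n))  # n*(n+1)/2 = 10
selected = sum(mask[i][j] and causal[i][j] for i in range(n) for j in range(n))  # 7
density = selected / total
print(f"selected={selected}, total={total}, density={density:.2f}")  # density=0.70
```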

## Observations

1. **Density decreases as context grows**: 4k (63.8%) → 64k (37.0%), because attention is more dispersed over long sequences

2. **Layer 5 is usually the sparsest layer**: across all tested lengths, Layer 5 has the lowest density

3. **Layer 0 has the highest density**: the first layer's attention pattern is the densest, possibly related to the sink-token effect

4. **Threshold=0.9 corresponds to ~50% density**: at 32k context, threshold=0.9 means selecting the blocks that cover 90% of the attention mass, which lands at roughly 50% actual density
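Point 4 can be made concrete with a sketch of threshold-based block selection (hypothetical helper; the real policy works per-head on GPU block sums): keep the highest-scoring blocks until they cover `threshold` of the attention mass, so skewed rows need far fewer blocks than flat ones.

```python
def select_blocks(block_sums, threshold=0.9):
    # Greedily keep the largest blocks until their mass reaches threshold * total.
    total = sum(block_sums)
    order = sorted(range(len(block_sums)), key=block_sums.__getitem__, reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += block_sums[i]
        if mass >= threshold * total:
            break
    return sorted(kept)

# A skewed row: 90% coverage needs only 3 of 8 blocks (density 37.5%);
# a flatter distribution would need proportionally more blocks.
print(select_blocks([0.5, 0.3, 0.12, 0.03, 0.02, 0.01, 0.01, 0.01]))  # [0, 1, 2]
```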

## Usage

### Step 1: enable debug saving

```python
# nanovllm/kvcache/sparse/xattn_bsa.py
_DEBUG_SAVE_MASK = True  # set to True
```

### Step 2: run GPU-only inference

```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
  --model ~/models/Llama-3.1-8B-Instruct \
  --data-dir tests/data/ruler_32k \
  --datasets niah_single_1 \
  --num-samples 1 \
  --max-model-len 40960 \
  --sparse-policy XATTN_BSA \
  --sparse-threshold 0.9
```

### Step 3: run the KV chunking alignment verification

```bash
# Using data saved in GPU-only mode
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_alignment.py --gpuonly

# Using data saved in Offload mode (default)
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_alignment.py

# Using a custom data file
python tests/test_xattn_estimate_alignment.py --data-file /path/to/data.pt

# Batch-test all Offload data
for f in results/kvcache/qkv_*.pt; do
  echo "Testing: $(basename $f)"
  python tests/test_xattn_estimate_alignment.py --data-file "$f"
done
```

### Batch-testing all lengths

```bash
for ctx in 4k 8k 16k 32k 64k; do
  case $ctx in
    4k) max_len=5000 ;;
    8k) max_len=9000 ;;
    16k) max_len=17000 ;;
    32k) max_len=34000 ;;
    64k) max_len=65664 ;;
  esac

  echo "Testing $ctx..."
  python tests/test_ruler.py \
    --data-dir tests/data/ruler_$ctx \
    --max-model-len $max_len \
    --sparse-policy XATTN_BSA \
    --num-samples 1 --quiet

  python tests/test_xattn_estimate_alignment.py --gpuonly
done
```

## Related Files

- `nanovllm/kvcache/sparse/xattn_bsa.py`: XAttention BSA Policy implementation
- `nanovllm/ops/xattn.py`: the `xattn_estimate` function and the three-stage KV chunking kernels
- `tests/test_xattn_estimate_alignment.py`: KV chunking alignment verification script
`docs/issue_xattn_offload_gqa_buffer_oom.md` (new file, 209 lines)
@@ -0,0 +1,209 @@
# Issue: XAttention Offload Mode GQA Buffer OOM

## Problem

Running large models such as GLM-4-9B with XAttention BSA (Block Sparse Attention) + CPU Offload mode hits a CUDA OOM error.

### Error message

```
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB.
GPU 0 has a total capacity of 23.57 GiB of which 4.19 GiB is free.
```

### Reproduction environment

| Item | Value |
|------|-----|
| Model | GLM-4-9B-Chat-1M |
| GPU | RTX 3090 (24GB) |
| Context Length | 32K |
| sparse_policy | XATTN_BSA |
| enable_cpu_offload | true |
| max_model_len | 1048576 (1M) |

### Error location

```
File "nanovllm/kvcache/sparse/xattn_bsa.py", line 246, in alloc_policy_metadata
    self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
```

---

## Analysis

### Memory allocation breakdown

`alloc_policy_metadata()` allocates the following buffers at KV cache initialization:

| Buffer | Purpose | Size (GLM-4, 1M seq) |
|--------|------|----------------------|
| `_prefill_mask_buffer` | BSA mask | ~32 MB |
| `_m_partial_buffer` | KV chunking m stats | ~32 MB |
| `_l_partial_buffer` | KV chunking l stats | ~32 MB |
| `_block_sums_buffer` | Block sums | ~64 MB |
| **`_k_expanded`** | GQA K expansion | **~8 GB** |
| **`_v_expanded`** | GQA V expansion | **~8 GB** |

### GQA buffer size

```python
shape = (1, num_heads, max_seq_len, head_dim)
#     = (1, 32, 1048576, 128)

# size = 1 * 32 * 1048576 * 128 * 2 bytes (fp16)
#      = 8,589,934,592 bytes
#      = 8 GB per buffer
```
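A quick sanity check of this arithmetic (plain Python):

```python
num_heads, max_seq_len, head_dim = 32, 1_048_576, 128
dtype_size = 2  # fp16
per_buffer = 1 * num_heads * max_seq_len * head_dim * dtype_size
assert per_buffer == 8_589_934_592  # exactly 8 GiB each
print(per_buffer / 2**30, "GiB per buffer")  # 8.0 GiB; 16 GiB across _k_expanded + _v_expanded
```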

### Root cause

1. **Conflicting design intent**: the doc comments for `_k_expanded` and `_v_expanded` explicitly say they are "for GPU-only mode"
2. **Incomplete condition check**: the code only skips the allocation when `num_heads == num_kv_heads`, never checking for offload mode
3. **Offload mode does not need these buffers**: `compute_chunked_prefill()` uses a different compute path that does not rely on the pre-allocated GQA buffers

### Relevant code

```python
# xattn_bsa.py:238-247
# Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads)
if num_heads == num_kv_heads:
    logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
    return  # <-- only checks GQA, never checks offload mode

# Shape: [1, num_heads, max_seq_len, head_dim] for xattn_estimate format
shape = (1, num_heads, max_seq_len, head_dim)
self._k_expanded = torch.empty(shape, dtype=dtype, device=device)  # <-- OOM here
self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
```

---

## Candidate Fixes

### Option 1: skip GQA buffer allocation in offload mode (recommended)

Add an offload-mode check to `alloc_policy_metadata()`:

```python
def alloc_policy_metadata(
    self,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    enable_cpu_offload: bool = False,  # <-- new parameter
) -> None:
    # ... allocate the mask buffer and KV chunking buffers (needed in offload mode)

    # Skip GQA buffers in offload mode
    # Chunked prefill uses compute_chunked_prefill() which doesn't need these
    if enable_cpu_offload:
        logger.info("[XAttn] Offload mode: skipping GQA expansion buffers")
        return

    # GPU-only mode: pre-allocate GQA buffers for compute_prefill()
    if num_heads == num_kv_heads:
        logger.info("[XAttn] No GQA expansion needed")
        return

    shape = (1, num_heads, max_seq_len, head_dim)
    self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
    self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
```

**Files to change**:

1. `nanovllm/kvcache/sparse/xattn_bsa.py` - the `alloc_policy_metadata()` method
2. `nanovllm/engine/model_runner.py` - pass `enable_cpu_offload` when calling `alloc_policy_metadata()`

### Option 2: lazy allocation

Allocate the GQA buffers only on the first `compute_prefill()` call; offload mode goes through `compute_chunked_prefill()` and never triggers the allocation.

```python
def compute_prefill(self, ...):
    # Lazy allocation on first use
    if self._k_expanded is None and num_heads != num_kv_heads:
        self._allocate_gqa_buffers(...)
    ...
```

### Option 3: cap the buffer size at chunk_size

Instead of pre-allocating for max_seq_len, allocate only chunk_size worth:

```python
# Before: max_seq_len (1M tokens) -> 8 GB
# After: chunk_size (16K tokens) -> ~130 MB
buffer_len = self.chunk_size if enable_cpu_offload else max_seq_len
shape = (1, num_heads, buffer_len, head_dim)
```

---

## Verification

Run the following after the fix:

```bash
cd /home/zijie/Code/COMPASS
GPULIST=0 ./scripts/run_ruler.sh glm4-9b-xattn-nanovllm synthetic xattn --task niah_single_1
```

Expected results:

- No more 8GB-allocation OOM errors
- The model loads and completes inference normally

---

## Related Docs

- `docs/xattn_bsa_policy_design.md` - XAttention BSA Policy design doc
- `docs/gpu_only_xattn_guide.md` - GPU-Only XAttention guide

## Priority

**High** - blocks 9B+ models on 24GB GPUs with XAttention + Offload mode

---

## Fix Status

**✅ Fixed** (2026-02-05)

### What changed

Option 1 was adopted: skip GQA buffer allocation in offload mode.

1. `nanovllm/kvcache/sparse/policy.py`: add an `enable_cpu_offload` parameter to the base class
2. `nanovllm/kvcache/sparse/xattn_bsa.py`: implement the offload-mode check and skip the GQA buffers
3. `nanovllm/engine/model_runner.py`: pass the `enable_cpu_offload` argument

### Verification results

```bash
# 64K offload test
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
  --model ~/models/Llama-3.1-8B-Instruct \
  --data-dir tests/data/ruler_64k \
  --datasets niah_single_1 \
  --num-samples 1 \
  --max-model-len 72000 \
  --enable-offload \
  --sparse-policy XATTN_BSA
```

- ✅ Log shows: `[XAttn] Offload mode: skipping GQA expansion buffers`
- ✅ Test passed: 100% accuracy
- ✅ Memory saved: ~16 GB (for 1M max_seq_len)

### Memory comparison

| Config | Before | After |
|------|--------|--------|
| max_model_len=72K | +1.1 GB | 0 GB |
| max_model_len=1M | +16 GB | 0 GB |
338 docs/test_ruler_usage_guide.md (Normal file)
@@ -0,0 +1,338 @@
# test_ruler.py Usage Guide

A comprehensive RULER benchmark harness for evaluating the long-context capability of LLMs.

**Test date**: 2026-02-05
**Test GPU**: RTX 3090 (GPU 4)

---

## Supported Tasks

| Category | Tasks |
|------|------|
| NIAH (Needle-In-A-Haystack) | `niah_single_1/2/3`, `niah_multikey_1/2/3`, `niah_multiquery`, `niah_multivalue` |
| QA (Question Answering) | `qa_1`, `qa_2` |
| Recall | `cwe`, `fwe`, `vt` |

---

## Basic Command Format

```bash
CUDA_VISIBLE_DEVICES=<GPU_ID> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py [OPTIONS]
```

---

## Parameters

### Required

| Parameter | Default | Description |
|------|--------|------|
| `--model` | `~/models/Llama-3.1-8B-Instruct` | Model path |
| `--data-dir` | `tests/data/ruler_64k` | Data directory |
| `--max-model-len` | 65664 | Maximum context length |

### Data selection

| Parameter | Default | Description |
|------|--------|------|
| `--datasets` | all | Comma-separated dataset names |
| `--num-samples` | 0 (all) | Number of samples per dataset |
| `--sample-indices` | - | Specific sample indices (e.g. `0,5,10`) |

### Offload configuration

| Parameter | Default | Description |
|------|--------|------|
| `--enable-offload` | False | Enable CPU offload mode |
| `--num-gpu-blocks` | 4 | Number of KV cache blocks on the GPU |
| `--block-size` | 4096 | KV cache block size (tokens) |
| `--num-kv-buffers` | 4 | Number of ring buffers |
| `--gpu-utilization` | 0.9 | GPU memory utilization |

### Sparse attention configuration

| Parameter | Default | Description |
|------|--------|------|
| `--sparse-policy` | - | Sparse policy: `FULL`, `QUEST`, `XATTN_BSA` |
| `--sparse-threshold` | 0.9 | XAttn cumulative attention threshold |
| `--sparse-samples` | 128 | XAttn samples per chunk |
| `--sparse-stride` | 8 | XAttn Q/K downsampling stride |

### Output control

| Parameter | Description |
|------|------|
| `--quiet` / `-q` | Quiet mode |
| `--json-output` | JSON-formatted output |
| `--fresh-llm` | Re-initialize the LLM for every sample |

### Other

| Parameter | Default | Description |
|------|--------|------|
| `--dtype` | auto | Model dtype (`bfloat16`, `float16`) |
| `--use-cuda-graph` | False | Enable CUDA Graph |
| `--max-new-tokens` | 16 | Maximum number of generated tokens |

---

## Verified Command Examples

All commands below were verified on an RTX 3090 (24 GB).

### 1. Basic offload test (32K)

```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 40960 \
    --enable-offload
```

**Result**: 100% accuracy, ~16 s

### 2. Offload + XAttention BSA (32K)

```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 40960 \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```

**Result**: 100% accuracy, compute density ~50%, ~19 s

### 3. Offload + XAttention BSA (64K)

```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_64k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 72000 \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```

**Result**: 100% accuracy, compute density ~37%, ~52 s

### 4. Multi-dataset, multi-sample test

```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1,qa_1 \
    --num-samples 2 \
    --max-model-len 40960 \
    --enable-offload \
    --sparse-policy XATTN_BSA
```

**Result**: 4/4 (100%), ~71 s

### 5. Testing specific sample indices

```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --sample-indices 0,5,10 \
    --max-model-len 40960 \
    --enable-offload
```

### 6. JSON output mode (for scripting)

```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 40960 \
    --enable-offload \
    --json-output
```

**Output format**:
```json
{
  "total_correct": 1,
  "total_samples": 1,
  "overall_accuracy": 1.0,
  "avg_score": 1.0,
  "time": 30.44,
  "tasks": {"niah_single_1": {"correct": 1, "total": 1, "accuracy": 1.0}},
  "failed_samples": {}
}
```

### 7. Quiet mode

```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 40960 \
    --enable-offload \
    --quiet
```

### 8. Adjusting the number of GPU blocks

```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 40960 \
    --enable-offload \
    --num-gpu-blocks 8 \
    --sparse-policy XATTN_BSA
```

### 9. GLM-4 model test

```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/GLM-4-9B-Chat-1M \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 40960 \
    --enable-offload \
    --dtype bfloat16
```

**Result**: 100% accuracy, ~17 s

---

## Data Directory Layout

```
tests/data/
├── ruler_4k/     # 4K context
├── ruler_8k/     # 8K context
├── ruler_16k/    # 16K context
├── ruler_32k/    # 32K context (recommended for testing)
├── ruler_64k/    # 64K context
├── ruler_128k/   # 128K context
├── ruler_256k/   # 256K context
├── ruler_512k/   # 512K context
├── ruler_768k/   # 768K context
└── ruler_1m/     # 1M context
```

Each directory contains 13 task subdirectories, each with a `validation.jsonl` file.

---

## GPU and Mode Selection

| GPU memory | Recommended mode | Notes |
|---------|---------|------|
| 24GB (3090/4090) | `--enable-offload` | Offload is required |
| 40GB+ (A100) | Either mode | GPU-only can also be tested |

**RTX 3090 limitation**: because of the limited GPU memory, `--enable-offload` is mandatory.

---

## Choosing max-model-len

| Data dir | Recommended max-model-len | Notes |
|---------|-------------------|------|
| ruler_4k | 5000 | Leaves room for output |
| ruler_8k | 9000 | |
| ruler_16k | 17000 | |
| ruler_32k | 40960 | |
| ruler_64k | 72000 | |
| ruler_128k | 135000 | |

**Rule**: `max_model_len >= max_input_len + max_new_tokens`
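The rule above can be wrapped in a one-line helper; `pick_max_model_len` and its `margin` argument are illustrative, not part of `tests/test_ruler.py`:

```python
def pick_max_model_len(max_input_len: int, max_new_tokens: int = 16, margin: int = 0) -> int:
    """Smallest context length satisfying max_model_len >= max_input_len + max_new_tokens."""
    return max_input_len + max_new_tokens + margin

# The longest 64K RULER sample mentioned in these docs is 64891 tokens,
# so the recommended 72000 is comfortably above the minimum:
assert pick_max_model_len(64891) <= 72000
```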
---

## DensityObserver Output

Automatically enabled with `--sparse-policy XATTN_BSA`; sample output:

```
============================================================
Density Statistics (XAttention BSA)
============================================================
[DensityObserver] Mode: offload
  Compute density: 0.3691 (min: 0.3691 @ layer 0)
  Comm density: 1.0000 (CPU block granularity)
  Savings ratio: 0.0% H2D transfer reduction
  Num layers: 1
  Layer 0 density: 0.369052
```

| Metric | Description |
|------|------|
| Compute density | Compute density at BSA block (128-token) granularity |
| Comm density | Communication density at CPU block (4096-token) granularity |
| Savings ratio | Fraction of H2D transfer avoided |

---

## Common Issues

### 1. OOM errors

**Cause**: insufficient GPU memory
**Fix**:
- use `--enable-offload`
- lower `--gpu-utilization`
- reduce `--num-gpu-blocks`

### 2. Model fails to load

**Cause**: incompatible model configuration
**Fix**:
- check the `--dtype` parameter (GLM models need `--dtype bfloat16`)
- verify the model path

### 3. Unexpected accuracy

**Cause**: state leaking between samples
**Fix**: use `--fresh-llm` to re-initialize the LLM for every sample

---

## Related Documents

- [`docs/xattn_density_types.md`](xattn_density_types.md) - Compute vs Comm density explained
- [`docs/xattn_density_alignment_verification.md`](xattn_density_alignment_verification.md) - GPU-only vs Offload alignment verification
- [`docs/ruler_benchmark_results_32k.md`](ruler_benchmark_results_32k.md) - RULER 32K benchmark results
142 docs/xattn_density_alignment_verification.md (Normal file)
@@ -0,0 +1,142 @@
# XAttention Density Alignment Verification

Verifies that GPU-only and Offload modes produce aligned densities.

**Test date**: 2026-02-05
**Model**: Llama-3.1-8B-Instruct
**Task**: RULER niah_single_1

---

## Test Configuration

| Parameter | Value |
|------|-----|
| sparse_policy | XATTN_BSA |
| threshold | 0.9 |
| chunk_size | 4096 (aligned) |
| stride | 8 |
| BSA block_size | 128 |

---

## Results

### 32K Context

| Mode | Layer 0 Density | Overall Density | Accuracy |
|------|-----------------|-----------------|--------|
| GPU-only | 0.502079 | 0.4012 | 100% |
| Offload | 0.498421 | 0.4984 | 100% |
| **Difference** | **0.37%** | - | - |

### 64K Context

| Mode | Layer 0 Density | Overall Density | Accuracy |
|------|-----------------|-----------------|--------|
| GPU-only | 0.369972 | 0.2963 | 100% |
| Offload | 0.369052 | 0.3691 | 100% |
| **Difference** | **0.09%** | - | - |

---

## Key Fix

### Commit 829b311 - chunk_size alignment + stream-sync fix

**Problem**: the density gap between GPU-only and Offload modes previously reached 10-13%.

**Root causes**:
1. GPU-only used `chunk_size=16384` while Offload used `chunk_size=4096`
2. a stream-synchronization bug made the K data seen by Pass 1 and Pass 2 inconsistent

**Fix**:
1. changed the default `XAttentionBSAPolicy.chunk_size` from 16384 to 4096
2. wrapped all compute kernels in the `compute_stream` context

---

## Test Commands

### GPU-only mode

```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 40960 \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```

### Offload mode

```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 40960 \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```

---

## Detailed Logs

### 32K Offload mode, per-chunk density

```
Layer0 chunk: q_len=4096, k_len=4096, density=0.6234
Layer0 chunk: q_len=4096, k_len=8192, density=0.6239
Layer0 chunk: q_len=4096, k_len=12288, density=0.6026
Layer0 chunk: q_len=4096, k_len=16384, density=0.5695
Layer0 chunk: q_len=4096, k_len=20480, density=0.5285
Layer0 chunk: q_len=4096, k_len=24576, density=0.4891
Layer0 chunk: q_len=4096, k_len=28672, density=0.4514
Layer0 chunk: q_len=3813, k_len=32485, density=0.4208
```

### 64K Offload mode, per-chunk density

```
Layer0 chunk: q_len=4096, k_len=4096, density=0.6234
Layer0 chunk: q_len=4096, k_len=8192, density=0.6239
Layer0 chunk: q_len=4096, k_len=12288, density=0.6026
Layer0 chunk: q_len=4096, k_len=16384, density=0.5681
Layer0 chunk: q_len=4096, k_len=20480, density=0.5255
Layer0 chunk: q_len=4096, k_len=24576, density=0.4859
Layer0 chunk: q_len=4096, k_len=28672, density=0.4485
Layer0 chunk: q_len=4096, k_len=32768, density=0.4161
Layer0 chunk: q_len=4096, k_len=36864, density=0.3892
Layer0 chunk: q_len=4096, k_len=40960, density=0.3658
Layer0 chunk: q_len=4096, k_len=45056, density=0.3464
Layer0 chunk: q_len=4096, k_len=49152, density=0.3303
Layer0 chunk: q_len=4096, k_len=53248, density=0.3170
Layer0 chunk: q_len=4096, k_len=57344, density=0.3068
Layer0 chunk: q_len=4096, k_len=61440, density=0.2988
Layer0 chunk: q_len=3451, k_len=64891, density=0.2947
```

---

## Conclusions

1. **Density alignment succeeded**: the gap dropped from 10-13% to <0.5%
2. **Accuracy is identical**: both modes reach 100%
3. **Density decreases with context length**: as expected — longer contexts are sparser

---

## Related Documents

- [`docs/xattn_offload_stream_sync_fix.md`](xattn_offload_stream_sync_fix.md) - stream-synchronization fix details
- [`docs/xattn_density_types.md`](xattn_density_types.md) - Compute vs Comm density
- [`docs/gpuonly_density_alignment_test.md`](gpuonly_density_alignment_test.md) - earlier alignment test
195 docs/xattn_density_benchmark.md (Normal file)
@@ -0,0 +1,195 @@
# XAttention Density Benchmark

Density results for XAttention Block Sparse Attention in GPU-only mode.

## Test Configuration

| Parameter | Value | Notes |
|------|-----|------|
| Model | Llama-3.1-8B-Instruct | 32 layers, 32 heads, 8 KV heads |
| Block Size | 128 tokens | fixed requirement of the BSA kernel |
| Threshold | 0.9 / 0.95 | cumulative attention threshold |
| Stride | 4 / 8 / 16 | Q/K downsampling stride |
| Dataset | RULER niah_single_1 | Sample 0 |
| Mode | GPU-only | no CPU offload |

## Density Definition

```python
import torch

# Density = selected_blocks / total_causal_blocks
# Under causal attention, only blocks in the lower-triangular region count.
# Overall density = mean over all layers.

def compute_density(mask: torch.Tensor, causal: bool = True) -> float:
    """mask: [batch, heads, q_blocks, k_blocks] boolean tensor."""
    batch, heads, q_blocks, k_blocks = mask.shape
    if causal:
        causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, dtype=torch.bool))
        total = causal_mask.sum() * batch * heads
        selected = (mask & causal_mask).sum()
    else:
        total = mask.numel()
        selected = mask.sum()
    return (selected / total).item()
```

## Results

### threshold=0.9

#### Overall density (mean)

| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.5220 (52.2%) | 0.5292 (52.9%) | 0.5430 (54.3%) |
| **8K** | 0.5152 (51.5%) | 0.5252 (52.5%) | 0.5396 (54.0%) |
| **16K** | 0.4682 (46.8%) | 0.4775 (47.8%) | 0.4888 (48.9%) |
| **32K** | 0.3700 (37.0%) | 0.4012 (40.1%) | 0.4196 (42.0%) |

#### Min density (per layer)

| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.2805 (Layer 3) | 0.3132 (Layer 3) | 0.3376 (Layer 5) |
| **8K** | 0.2886 (Layer 5) | 0.2725 (Layer 5) | 0.2995 (Layer 5) |
| **16K** | 0.2247 (Layer 5) | 0.2349 (Layer 5) | 0.2451 (Layer 5) |
| **32K** | 0.1799 (Layer 5) | 0.1846 (Layer 5) | 0.1964 (Layer 5) |

### threshold=0.95

#### Overall density (mean)

| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.6561 (65.6%) | 0.6699 (67.0%) | 0.6815 (68.2%) |
| **8K** | 0.6462 (64.6%) | 0.6584 (65.8%) | 0.6732 (67.3%) |
| **16K** | 0.6004 (60.0%) | 0.6114 (61.1%) | 0.6193 (61.9%) |
| **32K** | 0.4894 (48.9%) | 0.5203 (52.0%) | 0.5385 (53.9%) |

#### Min density (per layer)

| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.3972 (Layer 3) | 0.4348 (Layer 5) | 0.4517 (Layer 4) |
| **8K** | 0.4004 (Layer 5) | 0.3906 (Layer 5) | 0.4239 (Layer 5) |
| **16K** | 0.3331 (Layer 5) | 0.3453 (Layer 5) | 0.3589 (Layer 5) |
| **32K** | 0.2656 (Layer 5) | 0.2784 (Layer 5) | 0.2917 (Layer 5) |

### Threshold comparison (stride=8)

| Context | threshold=0.9 | threshold=0.95 | Gap |
|---------|---------------|----------------|------|
| **4K** | 0.5292 (52.9%) | 0.6699 (67.0%) | -14.1% |
| **8K** | 0.5252 (52.5%) | 0.6584 (65.8%) | -13.3% |
| **16K** | 0.4775 (47.8%) | 0.6114 (61.1%) | -13.4% |
| **32K** | 0.4012 (40.1%) | 0.5203 (52.0%) | -11.9% |

## Key Findings

### 1. Context length matters most

Density drops markedly with context length (threshold=0.9, stride=8):
- 4K: 52.9% density
- 8K: 52.5% density
- 16K: 47.8% density
- 32K: 40.1% density

**Takeaway**: longer sequences offer more sparsification opportunities; XAttention's advantage grows with sequence length.

### 2. Threshold has a large effect

threshold=0.9 gives about 12-14% lower density than 0.95:
- 0.9 is more aggressive and selects fewer blocks
- 0.95 is more conservative and keeps more blocks
- accuracy is unaffected either way (all RULER NIAH runs PASS)

### 3. Stride has a small effect

For a given context, densities across strides differ by only ~2-5%:
- larger stride → slightly higher density (coarser sampling, more conservative selection)
- stride=4 is the most aggressive, stride=16 the most conservative

### 4. Min density concentrates in the middle layers

- the minimum density usually occurs at Layer 5
- middle layers are the sparsest; the first and last layers are relatively dense
- this matches the usual pattern of Transformer attention

### 5. Best sparsification setting

32K + stride=4 + threshold=0.9 reaches the lowest density:
- Overall: **37.0%** (63% of the compute saved)
- Min: **18.0%** (Layer 5)

### 6. Accuracy is stable

RULER NIAH passes (score=1.0) in every configuration, which shows:
- both threshold=0.9 and 0.95 are conservative enough not to hurt accuracy
- stride does not change the final result

## Recommended Configurations

| Scenario | threshold | stride | Notes |
|------|-----------|--------|------|
| Accuracy first | 0.95 | 8 | conservative, density ~52-67% |
| Balanced | 0.9 | 8 | default, density ~40-53% |
| Performance first | 0.9 | 4 | aggressive, density ~37-52% |

## Test Command

```bash
# Basic test
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --sample-indices 0 \
    --max-model-len 33792 \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9 \
    --sparse-stride 8 \
    --gpu-utilization 0.85

# Parameter notes
# --sparse-policy XATTN_BSA   enable XAttention Block Sparse Attention
# --sparse-threshold 0.9      cumulative attention threshold (0.9-0.99)
# --sparse-stride 8           Q/K downsampling stride (4/8/16)
```

## DensityObserver Usage

```python
from nanovllm.utils.density_observer import DensityObserver

# Enable and reset
DensityObserver.enable()
DensityObserver.complete_reset()

# ... run inference (compute_prefill records automatically) ...

# Fetch results
summary = DensityObserver.get_summary()
# {
#   "mode": "gpu_only",
#   "overall_density": 0.40,  # mean over all layers
#   "per_layer_density": {0: 0.55, 1: 0.45, ...},
#   "num_layers": 32
# }

# Lowest density
min_layer, min_density = DensityObserver.get_min_density()

# Print summary
DensityObserver.print_summary()
# [DensityObserver] Mode: gpu_only
#   Overall density: 0.4012
#   Min density: 0.1846 (layer 5)
#   Num layers: 32
```

## Related Files

| File | Description |
|------|------|
| `nanovllm/kvcache/sparse/xattn_bsa.py` | XAttention BSA Policy implementation |
| `nanovllm/utils/density_observer.py` | Density statistics observer |
| `nanovllm/ops/xattn.py` | core xattn_estimate algorithm |
| `tests/test_ruler.py` | RULER benchmark test script |
152 docs/xattn_density_types.md (Normal file)
@@ -0,0 +1,152 @@
# XAttention Density Types: Compute vs Communication

XAttention BSA tracks two densities at different granularities; they reflect different optimization effects.

## Definitions

### 1. Compute density

**Granularity**: BSA block (128 tokens)

**Formula**:
```
compute_density = selected_bsa_blocks / total_causal_bsa_blocks
```

**Meaning**: the fraction of blocks in the causal region whose attention must actually be computed.

**Impact**: determines the reduction in attention compute.

### 2. Communication density

**Granularity**: CPU block (4096 tokens = 32 BSA blocks)

**Formula**:
```
comm_density = selected_cpu_blocks / total_cpu_blocks
```

**Meaning**: the fraction of blocks that must be transferred from CPU to GPU.

**Impact**: determines the reduction in H2D traffic.

## Why Comm Density Is Usually Higher Than Compute Density

### Aggregation effect

Because the CPU block granularity is 32× that of a BSA block, CPU block selection aggregates with `any()`:

```python
# BSA mask: [B, H, Q_bsa, K_bsa]
# Reshape to CPU block level
mask_per_cpu = mask.view(B, H, Q_bsa, num_cpu_blocks, bsa_per_cpu)
# Any BSA block selected -> whole CPU block needed
cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1)
```

As soon as **any** head selects the block, **any** Q position selects it, or **any** BSA sub-block inside it is selected, the whole CPU block must be transferred.
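The aggregation effect can be seen in a minimal NumPy sketch (toy shapes chosen for illustration; the real code above operates on torch tensors with 32 BSA blocks per CPU block): a mask where each (head, Q) pair picks a single scattered BSA block still touches every CPU block.

```python
import numpy as np

H, Q, K_bsa, bsa_per_cpu = 2, 2, 8, 2          # toy sizes, not the real 128-token / 4096-token blocks
n_cpu = K_bsa // bsa_per_cpu

mask = np.zeros((H, Q, K_bsa), dtype=bool)
# Each (head, Q) pair selects one BSA block, scattered across the sequence
mask[0, 0, 0] = mask[0, 1, 2] = mask[1, 0, 4] = mask[1, 1, 6] = True

compute_density = mask.mean()                   # 4 of 32 BSA entries -> 0.125
cpu_needed = mask.reshape(H, Q, n_cpu, bsa_per_cpu).any(axis=-1).any(axis=(0, 1))
comm_density = cpu_needed.mean()                # every CPU block touched -> 1.0
print(compute_density, comm_density)            # 0.125 1.0
```

This is exactly the pattern behind the 37% compute / 100% comm numbers reported below.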
### Example

| Scenario | Compute Density | Comm Density | Notes |
|------|-----------------|--------------|------|
| 64K context, threshold=0.9 | 37% | 100% | sparse blocks spread evenly over all CPU blocks |
| 32K context, threshold=0.9 | 50% | 100% | same |

## Test Results

### Test command

```bash
# Offload mode test
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_64k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 72000 \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```

### Sample output

```
[DensityObserver] Mode: offload
  Compute density: 0.3691 (min: 0.3691 @ layer 0)
  Comm density: 1.0000 (CPU block granularity)
  Savings ratio: 0.0% H2D transfer reduction
  Num layers: 1
  Layer 0 density: 0.369052
```

## Key Findings

### Current limits of XAttention's communication optimization

1. **Compute density drops effectively**: ~37% @ 64K context (63% less compute)
2. **Comm density does not drop**: 100% (no reduction in transfer volume)

### Why

Properties of the attention pattern:
- different heads attend to different positions
- different Q positions attend to different K positions
- the sparse selections are spread over the whole sequence

So although each (head, Q, K) combination selects only a few blocks, the aggregate covers every CPU block.

### Possible optimization directions

1. **Per-head block selection**: each head selects CPU blocks independently
2. **Block clustering**: group related blocks into the same CPU block
3. **Dynamic block size**: adapt the CPU block size to the attention pattern

## DensityObserver API

### Enable and reset

```python
from nanovllm.utils.density_observer import DensityObserver

DensityObserver.enable()
DensityObserver.complete_reset()
DensityObserver.set_mode("offload")  # or "gpu_only"
```

### Recording

```python
# Compute density (recorded automatically in GPU-only mode)
DensityObserver.record(layer_id, mask, causal=True)

# Comm density (recorded in select_blocks in offload mode)
DensityObserver.record_comm_density(layer_id, selected_cpu_blocks, total_cpu_blocks)
```

### Retrieving results

```python
# Overall density
overall_compute = DensityObserver.get_overall_density()
overall_comm = DensityObserver.get_overall_comm_density()

# Per-layer density
per_layer_compute = DensityObserver.get_per_layer_density()
per_layer_comm = DensityObserver.get_per_layer_comm_density()

# Print summary
DensityObserver.print_summary()
```

## Related Files

- `nanovllm/utils/density_observer.py`: DensityObserver implementation
- `nanovllm/kvcache/sparse/xattn_bsa.py`: XAttention BSA Policy (records comm density in select_blocks)
- `tests/test_ruler.py`: RULER benchmark test script
122 docs/xattn_kv_chunking_density_test.md (Normal file)
@@ -0,0 +1,122 @@
# XAttention KV Chunking Density 验证测试
|
||||||
|
|
||||||
|
## 背景
|
||||||
|
|
||||||
|
验证 XAttention Triton kernel 是否只能沿 Q 轴分 chunk,不能沿 KV 轴分 chunk。
|
||||||
|
|
||||||
|
**假设**:`softmax_fuse_block_sum` 需要完整的 K 来计算正确的归一化分母,分 chunk 后的 attention 分布与完整序列不同。
|
||||||
|
|
||||||
|
## 测试方法
|
||||||
|
|
||||||
|
1. **GPU-only 模式**:一次性对完整序列调用 `xattn_estimate`,记录 Layer 0 的 density
|
||||||
|
2. **Offload DEBUG 模式**:分 chunk 调用 `xattn_estimate`,累积 selected/total counts,计算最终 density
|
||||||
|
3. 使用相同的 `_debug_k_full` buffer 收集完整 K cache,确保输入数据一致
|
||||||
|
|
||||||
|
### 关键代码逻辑
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Offload DEBUG: 每个 chunk 累积 selected/total
|
||||||
|
for each chunk:
|
||||||
|
K_full = _debug_k_full[:, :, :total_k_len, :] # 累积的 K
|
||||||
|
_, mask_chunk = xattn_estimate(Q_chunk, K_full, threshold=threshold, causal=True)
|
||||||
|
|
||||||
|
# 裁剪到有效区域,计算正确的 causal mask (考虑 Q 偏移量)
|
||||||
|
q_offset_blocks = k_blocks - q_blocks
|
||||||
|
causal_mask = indices <= (q_indices + q_offset_blocks)
|
||||||
|
|
||||||
|
selected += (mask_valid & causal_mask).sum()
|
||||||
|
total += causal_mask.sum()
|
||||||
|
|
||||||
|
density = selected / total
|
||||||
|
```
|
||||||
|
|
||||||
|
## Test Results

### 64K sequence (niah_single_1, sequence length 64891)

| threshold | GPU-only selected | Offload selected | GPU-only density | Offload density | Diff (selected) |
|-----------|------------------|------------------|------------------|-----------------|-----------------|
| **0.90** | 1,524,617 | 1,330,506 | **0.3700** | **0.3229** | 194,111 (12.7%) |
| **0.95** | 1,955,015 | 1,747,585 | **0.4744** | **0.4241** | 207,430 (10.6%) |
| **1.00** | 4,118,719 | 4,118,896 | **0.9995** | **0.9995** | -177 (~0%) |

- **total**: 4,120,896 (identical in both modes)

### 32K sequence (niah_single_1, sequence length 32485)

| threshold | GPU-only selected | Offload selected | GPU-only density | Offload density | Diff (selected) |
|-----------|------------------|------------------|------------------|-----------------|-----------------|
| **0.90** | 520,314 | 466,937 | **0.5021** | **0.4506** | 53,377 (10.3%) |
| **0.95** | 647,765 | 602,953 | **0.6251** | **0.5818** | 44,812 (6.9%) |
| **1.00** | 1,036,295 | 1,036,264 | **0.9999** | **0.9999** | 31 (~0%) |

- **total**: 1,036,320 (identical in both modes)

### Summary

| Sequence length | threshold | GPU-only density | Offload density | Density diff |
|---------|-----------|------------------|-----------------|--------------|
| 32K | 0.90 | 0.5021 | 0.4506 | 5.2% |
| 64K | 0.90 | 0.3700 | 0.3229 | 4.7% |
| 32K | 0.95 | 0.6251 | 0.5818 | 4.3% |
| 64K | 0.95 | 0.4744 | 0.4241 | 5.0% |
| 32K | 1.00 | 0.9999 | 0.9999 | ~0% |
| 64K | 1.00 | 0.9995 | 0.9995 | ~0% |

## Conclusions

### 1. Softmax normalization itself is correct

With `threshold=1.0` (all blocks selected), GPU-only and Offload densities match almost exactly (diff < 0.01%).

This shows that:
- `_debug_k_full` correctly collects the full K cache
- when `xattn_estimate` is called chunk by chunk, softmax normalization is computed over the correct K sequence
- the Q-offset handling in the causal mask is correct

### 2. The problem is how the threshold is applied

With `threshold < 1.0` the difference is significant (10-13%):

- **GPU-only**: apply the threshold once over the full sequence, selecting blocks whose cumulative attention >= threshold
- **Offload**: apply the threshold independently per chunk and accumulate the selected counts

Applying the threshold independently per chunk means:
- some blocks selected in GPU-only mode are not selected when chunked, because the attention distribution differs
- the accumulated selected count is smaller than the one-shot count

### 3. KV chunking limitation of the XAttention Triton kernel

**Verified conclusion**: XAttention's `xattn_estimate` handles KV chunking correctly (softmax normalization is right), but **threshold-based block selection cannot simply be accumulated**.

To make block selection in Offload mode match GPU-only mode:
1. first accumulate the attention scores of all chunks
2. then apply the threshold once to select blocks

Alternatively, accept the 10-13% density difference; its impact on actual inference accuracy needs further evaluation.
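Why per-chunk thresholding diverges from one-shot thresholding can be shown with a tiny hypothetical illustration (not the project's code): greedily selecting blocks by a cumulative-attention threshold per chunk can pick a different set than applying the same threshold once over the full sequence.

```python
# Hypothetical illustration with made-up block masses.
def select_blocks(scores, threshold):
    """Greedily pick top blocks until their mass reaches threshold * total."""
    need = threshold * sum(scores)
    picked, mass = set(), 0
    for i in sorted(range(len(scores)), key=lambda i: -scores[i]):
        if mass >= need:
            break
        picked.add(i)
        mass += scores[i]
    return picked

scores = [5, 1, 1, 3]                    # per-block attention mass (made up)
full = select_blocks(scores, 0.8)        # one-shot over the full sequence
chunked = select_blocks(scores[:2], 0.8) | {
    i + 2 for i in select_blocks(scores[2:], 0.8)
}
print(full, chunked)                     # the two selections differ
```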
## Test Commands

```bash
# GPU-only mode
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py --dataset niah_single_1 --sample 0 \
    --sparse-policy xattn_bsa --sparse-threshold 0.9

# Offload mode (64K)
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py --dataset niah_single_1 --sample 0 \
    --sparse-policy xattn_bsa --sparse-threshold 0.9 --enable-offload

# Offload mode (32K)
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py --dataset niah_single_1 --sample 0 \
    --sparse-policy xattn_bsa --sparse-threshold 0.9 --enable-offload \
    --data-dir /home/zijie/Code/nano-vllm/tests/data/ruler_32k --max-model-len 34000
```

## Related Files

- `nanovllm/kvcache/sparse/xattn_bsa.py`: location of the DEBUG code
- `nanovllm/ops/xattn.py`: `xattn_estimate` implementation
- `nanovllm/utils/density_observer.py`: DensityObserver implementation
400
docs/xattn_kv_chunking_kernels.md
Normal file
@@ -0,0 +1,400 @@
# XAttention KV Chunking Kernels

## Overview

This document describes the softmax kernels that support chunking along the KV dimension. These kernels allow sparse attention estimation to be computed blockwise along the KV dimension in CPU offload scenarios, without keeping the full raw attention scores on the GPU.

## Background

The original `softmax_fuse_block_sum` kernel needs the complete K sequence to compute the correct softmax normalization denominator:

```
softmax(x_i) = exp(x_i) / Σ_j exp(x_j)
```

With only a partial K (a KV chunk), the denominator `Σ_j exp(x_j)` is incomplete, so the normalization is wrong.
## Solution: Three-Stage Computation

Splitting the softmax computation into three stages makes KV chunking correct:

```
┌─────────────────────────────────────────────────────────────────┐
│                     Three-Stage Pipeline                        │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────────┐  ┌─────────────┐       ┌─────────────┐         │
│  │ KV Chunk 0  │  │ KV Chunk 1  │  ...  │ KV Chunk N  │         │
│  │ attn_scores │  │ attn_scores │       │ attn_scores │         │
│  └──────┬──────┘  └──────┬──────┘       └──────┬──────┘         │
│         │                │                     │                │
│         ▼                ▼                     ▼                │
│  ┌─────────────────────────────────────────────────┐            │
│  │ Stage 1: softmax_compute_partial_stats          │            │
│  │ Compute (m_partial, l_partial) per chunk        │            │
│  └─────────────────────────────────────────────────┘            │
│         │                │                     │                │
│         ▼                ▼                     ▼                │
│     (m_0, l_0)       (m_1, l_1)           (m_N, l_N)            │
│         │                │                     │                │
│         └────────────────┴─────────────────────┘                │
│                          ▼                                      │
│  ┌─────────────────────────────────────────────────┐            │
│  │ Stage 2: merge_softmax_stats                    │            │
│  │ Merge on host → (m_global, l_global)            │            │
│  └─────────────────────────────────────────────────┘            │
│                          │                                      │
│         ┌────────────────┼─────────────────┐                    │
│         ▼                ▼                 ▼                    │
│  ┌─────────────────────────────────────────────────┐            │
│  │ Stage 3: softmax_normalize_and_block_sum        │            │
│  │ Normalize with global stats, compute block sums │            │
│  └─────────────────────────────────────────────────┘            │
│         │                │                     │                │
│         ▼                ▼                     ▼                │
│   block_sums_0     block_sums_1         block_sums_N            │
│         │                │                     │                │
│         └────────────────┴─────────────────────┘                │
│                          │                                      │
│                          ▼                                      │
│                 torch.cat → final mask                          │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```
### Stage 1: `softmax_compute_partial_stats`

Compute partial statistics for each KV chunk:
- `m_partial`: the maximum within the chunk (per query row)
- `l_partial`: the partial sum within the chunk = Σ exp(x - m_partial)

```python
m_partial, l_partial = softmax_compute_partial_stats(
    attn_weights_kv,      # [batch, heads, q_len, k_chunk_len]
    reshaped_block_size,
    segment_size,
    scale,
    chunk_start=chunk_start,
    kv_offset=kv_offset,  # offset of the KV chunk within the full sequence
    is_causal=True,
)
# Output: m_partial, l_partial of shape [batch, heads, q_len]
```
### Stage 2: `merge_softmax_stats`

Merge the statistics of all KV chunks on the host:

```python
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
```

Merge formulas (online softmax):
```
m_new = max(m_global, m_chunk)
l_new = l_global * exp(m_global - m_new) + l_chunk * exp(m_chunk - m_new)
```
### Stage 3: `softmax_normalize_and_block_sum`

Normalize with the global statistics and compute block sums:

```python
attn_sum_kv = softmax_normalize_and_block_sum(
    attn_weights_kv,      # [batch, heads, q_len, k_chunk_len]
    m_global,             # [batch, heads, q_len]
    l_global,             # [batch, heads, q_len]
    reshaped_block_size,
    segment_size,
    chunk_start=chunk_start,
    real_q_len=real_q_len,
    scale=scale,
    kv_offset=kv_offset,
    is_causal=True,
)
# Output: [batch, heads, q_blocks, k_chunk_blocks]
```
## Proof of Mathematical Equivalence

Original softmax computation (full K):
```
softmax(x_i) = exp(x_i - m) / Σ_j exp(x_j - m)
```

Chunked KV computation:
```
Chunk 0: m_0 = max(x[0:N/2]),  l_0 = Σ exp(x[0:N/2] - m_0)
Chunk 1: m_1 = max(x[N/2:N]),  l_1 = Σ exp(x[N/2:N] - m_1)

Merge:
  m_global = max(m_0, m_1)
  l_global = l_0 * exp(m_0 - m_global) + l_1 * exp(m_1 - m_global)
           = Σ exp(x[0:N] - m_global)   # equals the global sum

Normalize:
  softmax(x_i) = exp(x_i - m_global) / l_global   # correct!
```
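The identity can be checked numerically. A minimal pure-Python sketch with arbitrary scores (not the Triton kernels themselves) verifies that the merged stats equal the one-shot stats:

```python
import math

def partial_stats(xs):
    """Stage 1 on one chunk: (m, l) with l = sum(exp(x - m))."""
    m = max(xs)
    return m, sum(math.exp(x - m) for x in xs)

def merge(m0, l0, m1, l1):
    """Stage 2 online-softmax merge."""
    m = max(m0, m1)
    return m, l0 * math.exp(m0 - m) + l1 * math.exp(m1 - m)

x = [0.3, -1.2, 2.5, 0.0, 1.7, -0.4]       # arbitrary scores
m_full, l_full = partial_stats(x)           # one-shot reference
m_g, l_g = merge(*partial_stats(x[:3]), *partial_stats(x[3:]))
assert m_g == m_full and abs(l_g - l_full) < 1e-12
```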
## Causal Mask Handling

Both kernels handle causal attention correctly:

1. **`softmax_partial_stats_kernel`**: uses the `kv_offset` parameter to locate the current KV chunk within the full sequence and compute the correct causal boundary

2. **`softmax_normalize_block_sum_kernel`**: also uses `kv_offset`, and outputs 0 for positions beyond the causal boundary
## Storage Overhead Analysis

### Notation

| Symbol | Meaning | Typical value |
|------|------|--------|
| S | seq_len | 64K |
| B | batch_size | 1 |
| H | num_heads | 32 |
| D | head_dim | 128 |
| T | stride | 4-8 |
| C | chunk_size | 16K |
| n | num_kv_chunks = ceil(S/C) | 4 |

### Original Approach (no KV chunking)

**Peak memory for attn_weights**:
```
[B, H, S/T, S/T] × 4 bytes = B × H × (S/T)² × 4

e.g. S=64K, T=4, B=1, H=32
   = 1 × 32 × 16384² × 4 = 32 GB
```

### Extra Storage for KV Chunking

#### 1. Partial stats (per KV chunk)

```
m_partial: [B, H, C/T] × 4 bytes
l_partial: [B, H, C/T] × 4 bytes

per chunk = 2 × B × H × (C/T) × 4
          = 2 × 1 × 32 × 4096 × 4 = 1 MB
```

#### 2. Global stats

```
m_global: [B, H, S/T] × 4 bytes
l_global: [B, H, S/T] × 4 bytes

= 2 × B × H × (S/T) × 4
= 2 × 1 × 32 × 16384 × 4 = 4 MB
```

#### 3. Total extra overhead

```
total_extra = n × partial_stats + global_stats
            = 4 × 1MB + 4MB = 8 MB
```
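The figures above can be reproduced directly from the notation. A small sketch with the assumed values (B=1, H=32, T=4, S=64K, C=16K, fp32 stats):

```python
# Reproducing the storage figures above.
B, H, T = 1, 32, 4
S, C = 64 * 1024, 16 * 1024
n = -(-S // C)  # ceil(S / C) = 4 KV chunks

partial = 2 * B * H * (C // T) * 4        # (m, l) per chunk, bytes
global_ = 2 * B * H * (S // T) * 4        # (m, l) over the full sequence
attn_weights = B * H * (S // T) ** 2 * 4  # original full score matrix

print(partial // 2**20, "MiB per chunk")                    # 1
print(global_ // 2**20, "MiB global")                       # 4
print((n * partial + global_) // 2**20, "MiB total extra")  # 8
print(attn_weights // 2**30, "GiB original attn_weights")   # 32
```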
### Storage Overhead vs. seqlen

| seqlen | num_chunks | Original attn_weights | Extra stats | Ratio |
|--------|------------|-------------------|------------|------|
| 16K | 1 | 2 GB | 2 MB | 0.1% |
| 32K | 2 | 8 GB | 4 MB | 0.05% |
| 64K | 4 | 32 GB | 8 MB | 0.025% |
| 128K | 8 | 128 GB | 16 MB | 0.012% |

### Complexity Analysis

| Storage component | Complexity | Notes |
|----------|--------|------|
| Original attn_weights | O(S²) | quadratic growth |
| Partial/global stats | O(S) | linear growth |
| **Relative overhead** | O(1/S) | **decreases with seqlen** |

### Peak Memory Reduction

The main benefit of KV chunking is reducing **peak GPU memory** from O(S²) to O(S×C):

```
Original:     O(B × H × (S/T)²)         # full attn_weights
KV chunking:  O(B × H × (S/T) × (C/T))  # only one chunk at a time
```

With S=128K, C=16K:
- original peak: ~128 GB
- KV chunking peak: ~16 GB (an **8×** reduction)
## Supporting Different Q/KV Chunk Sizes

The three-stage pipeline supports different chunk sizes for Q and KV:

```python
q_chunk_size = 8192    # Q chunk size
kv_chunk_size = 16384  # KV chunk size

for q_chunk_idx in range(q_chunk_num):
    Q_chunk = Q[:, :, q_start:q_end, :]  # [B, H, q_chunk_size, D]

    for kv_chunk_idx in range(kv_chunk_num):
        K_chunk = K[:, :, kv_start:kv_end, :]  # [B, H, kv_chunk_size, D]
        # ... three-stage processing
```

### Validation Results

| Config | seq_len | Q chunks | KV chunks | density | Aligned |
|--------|---------|----------|-----------|---------|------|
| Q=16K, KV=16K | 64891 | 4 | 4 | 0.1117 | ✓ 100% |
| Q=8K, KV=16K | 64891 | 8 | 4 | 0.1112 | ✓ 100% |
| Q=16K, KV=8K | 64891 | 4 | 8 | 0.1117 | ✓ 100% |
| Q=8K, KV=8K | 64891 | 8 | 8 | 0.1112 | ✓ 100% |
| Q=4K, KV=16K | 64891 | 16 | 4 | 0.1109 | ✓ 100% |
## API Reference

### `softmax_compute_partial_stats`

```python
def softmax_compute_partial_stats(
    attn_weights_slice: torch.Tensor,  # [batch, heads, q_len, k_chunk_len]
    reshaped_block_size: int,
    segment_size: int,
    scale: float,
    chunk_start: int = 0,  # Q chunk start position (reshaped space)
    kv_offset: int = 0,    # KV chunk offset (reshaped space)
    is_causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (m, l) partial stats."""
```

### `merge_softmax_stats`

```python
def merge_softmax_stats(
    m_chunks: list,  # List of [batch, heads, q_len] tensors
    l_chunks: list,  # List of [batch, heads, q_len] tensors
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (m_global, l_global)."""
```

### `softmax_normalize_and_block_sum`

```python
def softmax_normalize_and_block_sum(
    attn_weights_slice: torch.Tensor,  # [batch, heads, q_len, k_chunk_len]
    m_global: torch.Tensor,            # [batch, heads, q_len]
    l_global: torch.Tensor,            # [batch, heads, q_len]
    reshaped_block_size: int,
    segment_size: int,
    chunk_start: int,
    real_q_len: int,
    scale: float,
    kv_offset: int = 0,
    is_causal: bool = False,
) -> torch.Tensor:
    """Return block sums [batch, heads, q_blocks, k_chunk_blocks]."""
```
## Usage Example

```python
from nanovllm.ops.xattn import (
    flat_group_gemm_fuse_reshape,
    softmax_compute_partial_stats,
    softmax_normalize_and_block_sum,
    merge_softmax_stats,
    find_blocks_chunked,
)

# For each Q chunk
for q_chunk_idx in range(q_chunk_num):
    Q_chunk = Q_padded[:, :, q_start:q_end, :]

    # Stage 1: compute partial stats per KV chunk
    m_chunks, l_chunks = [], []
    attn_weights_chunks = []

    for kv_chunk_idx in range(kv_chunk_num):
        K_chunk = K_padded[:, :, kv_start:kv_end, :]
        kv_offset = kv_chunk_idx * kv_chunk_size // STRIDE

        # Compute raw scores
        attn_weights = flat_group_gemm_fuse_reshape(
            Q_chunk, K_chunk, STRIDE,
            chunk_start=chunk_start,
            chunk_end=chunk_end,
            is_causal=False,  # K is incomplete
        )
        attn_weights_chunks.append(attn_weights)

        # Compute partial stats
        m, l = softmax_compute_partial_stats(
            attn_weights, block_size, segment_size, scale,
            chunk_start=chunk_start,
            kv_offset=kv_offset,
            is_causal=True,
        )
        m_chunks.append(m)
        l_chunks.append(l)

    # Stage 2: merge stats
    m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)

    # Stage 3: normalize and compute block sums
    block_sums_list = []
    for kv_chunk_idx, attn_weights in enumerate(attn_weights_chunks):
        kv_offset = kv_chunk_idx * kv_chunk_size // STRIDE
        block_sums = softmax_normalize_and_block_sum(
            attn_weights, m_global, l_global,
            block_size, segment_size, chunk_start, real_q_len, scale,
            kv_offset=kv_offset,
            is_causal=True,
        )
        block_sums_list.append(block_sums)

    # Concatenate and select blocks
    attn_sum = torch.cat(block_sums_list, dim=-1)
    mask = find_blocks_chunked(attn_sum, ...)
```
## Performance Comparison

| Aspect | Original | KV chunking |
|------|---------|-----------------|
| Kernel count | 1 | 2 (stats + normalize) |
| Raw-score reads | 2 | 2 |
| Extra memory | 0 | O(B × H × S/T × 2) for (m, l) |
| Host compute | none | merge stats (lightweight) |
| **Peak memory** | O(S²) | **O(S × C)** |

## Validation Tests

### Batch Test Results

The test script `tests/test_xattn_kv_chunking_batch.py` verifies consistency across sequence lengths:

```
| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff     | mask_diff | status |
|---------|--------|-----------|-----------|-------------|------------|----------|-----------|--------|
| 3688    | 4      | 0.90      | 1         | 0.383405    | 0.383405   | 0.000000 | 0.0000%   | PASS   |
| 7888    | 4      | 0.90      | 1         | 0.290611    | 0.290611   | 0.000000 | 0.0000%   | PASS   |
| 15685   | 4      | 0.90      | 1         | 0.197724    | 0.197724   | 0.000000 | 0.0000%   | PASS   |
| 32485   | 4      | 0.90      | 2         | 0.159023    | 0.159023   | 0.000000 | 0.0000%   | PASS   |
| 64891   | 4      | 0.90      | 4         | 0.111656    | 0.111656   | 0.000000 | 0.0000%   | PASS   |
```

### Key Conclusions

1. **Mathematical equivalence**: density_diff = 0.000000 for all tests
2. **Exact mask alignment**: mask_diff = 0.0000% for all tests
3. **Arbitrary Q/KV chunk size combinations are supported**

## Related Files

- `nanovllm/ops/xattn.py`: kernel implementations
- `tests/test_xattn_estimate_alignment.py`: single-file validation test
- `tests/test_xattn_kv_chunking_batch.py`: batch validation test
- `docs/xattn_kernels_guide.md`: original kernel documentation
154
docs/xattn_memory_benchmark.md
Normal file
@@ -0,0 +1,154 @@
# XAttention Memory Benchmark

Memory usage analysis of XAttention in GPU-only mode.

## Test Configuration

### Hardware
- **GPU**: NVIDIA A100 80GB (used for the benchmark)
- **Goal**: verify feasibility on RTX 3090/4090 (24GB)

### Model
- **Model**: Qwen3-0.6B (28 layers, 16 heads, 8 KV heads, head_dim=128)
- **Context Length**: 32K (max_model_len=40960)

### XAttention Configuration
- **Sparse Policy**: XATTN_BSA
- **Threshold**: 0.9
- **Block Size**: 128 tokens (BSA block)
- **Stride**: 8

---

## Memory Usage Analysis

### Baseline (gpu-utilization=0.9)

| Metric | Value |
|------|------|
| KV Cache | 157 blocks × 448 MB = 70.3 GB |
| **Peak memory** | **73,949 MiB (72.2 GB)** |
| GPU utilization | 90.2% |

### 24GB VRAM Feasibility Test

| gpu-utilization | KV Cache Blocks | KV Cache Size | Peak memory | Result |
|-----------------|-----------------|---------------|----------|----------|
| 0.25 | 39 blocks | 17.5 GB | **20.6 GB** | ✅ 5/5 PASSED |
| 0.28 | 44 blocks | 19.7 GB | **22.8 GB** | ✅ 5/5 PASSED |

---
## Recommended Configuration for 24GB VRAM

For **RTX 3090 / RTX 4090 (24GB)**:

```bash
CUDA_VISIBLE_DEVICES=0 python tests/test_ruler.py \
    --model ~/models/Qwen3-0.6B \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 5 \
    --max-model-len 40960 \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9 \
    --gpu-utilization 0.28
```

### Configuration Notes

| Parameter | Value | Notes |
|------|-----|------|
| `--gpu-utilization` | 0.28 | cap GPU memory usage at ~23GB |
| `--max-model-len` | 40960 | supports 32K+ context |
| `--sparse-policy` | XATTN_BSA | enable XAttention sparse attention |
| `--sparse-threshold` | 0.9 | select blocks covering 90% of attention |

---
## Memory Breakdown

### Qwen3-0.6B @ 32K Context

| Component | Formula | Size |
|------|----------|------|
| Model weights | 0.6B × 2 bytes | ~1.2 GB |
| KV Cache (per-token) | 2 × 28 layers × 8 kv_heads × 128 head_dim × 2 bytes | 112 KB |
| KV Cache (per-block) | 112 KB × 4096 tokens | 448 MB |
| KV Cache (44 blocks) | 448 MB × 44 | 19.7 GB |
| XAttention buffers | GQA + mask + KV chunking | ~0.3 GB |
| Intermediate activations | allocated at runtime | ~1.5 GB |
| **Total** | | **~22.8 GB** |

---
## Performance Metrics

### RULER niah_single_1 (5 samples)

| Metric | gpu-util=0.25 | gpu-util=0.28 | gpu-util=0.9 |
|------|---------------|---------------|--------------|
| Accuracy | 100% (5/5) | 100% (5/5) | 100% (5/5) |
| Time | 11.4s | 11.5s | 11.6s |
| Compute Density | 24.77% | 24.77% | 24.77% |
| Min Layer Density | 4.29% (Layer 5) | 4.29% (Layer 5) | 4.29% (Layer 5) |

**Conclusion**: lowering gpu-utilization affects neither accuracy nor speed; it only limits the maximum supported sequence length.

---
## Estimates for Other Models

### KV Cache Formula

```
KV Cache per-token = 2 × num_layers × num_kv_heads × head_dim × dtype_size
KV Cache per-block = per-token × block_size
```
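Evaluating the formula for Qwen3-0.6B (28 layers, 8 KV heads, head_dim=128, fp16) with block_size=4096 reproduces the figures used in the memory breakdown above:

```python
# Sketch evaluating the KV cache formula for Qwen3-0.6B.
num_layers, num_kv_heads, head_dim, dtype_size = 28, 8, 128, 2
block_size = 4096

per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_size  # bytes
per_block = per_token * block_size

print(per_token // 1024, "KB per token")                    # 112
print(per_block // 2**20, "MB per block")                   # 448
print(44 * per_block // 2**20, "MB for 44 blocks")          # 19712 (~19.7 GB)
```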
### Common Model Estimates (32K context, block_size=4096)

| Model | Layers | KV Heads | Head Dim | Per-Token | 32K Tokens | Fits in 24GB? |
|------|--------|----------|----------|-----------|------------|------------|
| Qwen3-0.6B | 28 | 8 | 128 | 112 KB | 3.5 GB | ✅ yes |
| Qwen3-4B | 36 | 8 | 128 | 144 KB | 4.5 GB | ✅ yes |
| Llama-3.1-8B | 32 | 8 | 128 | 128 KB | 4.0 GB | ⚠️ needs offload |
| Qwen2.5-7B | 28 | 4 | 128 | 56 KB | 1.8 GB | ✅ yes |

Note: an 8B model's weights are about 16GB; together with the KV cache this exceeds 24GB, so CPU offload is required.

---

## Usage Recommendations

### RTX 3090/4090 (24GB)

1. **Small models (≤4B)**: GPU-only + XAttention works directly
   ```bash
   --gpu-utilization 0.28 --sparse-policy XATTN_BSA
   ```

2. **Large models (7B-8B)**: CPU offload required
   ```bash
   --enable-offload --num-gpu-blocks 4 --num-cpu-blocks 32
   ```

### A100 (40GB/80GB)

1. **All models**: GPU-only mode works
   ```bash
   --gpu-utilization 0.9 --sparse-policy XATTN_BSA
   ```

---

## Related Files

- `tests/test_ruler.py`: RULER test script
- `nanovllm/kvcache/sparse/xattn_bsa.py`: XAttention BSA Policy implementation
- `docs/gpuonly_density_alignment_test.md`: density alignment verification

---

**Date**: 2026-02-02
**Author**: Zijie Tian
184
docs/xattn_offload_profiling_32k.md
Normal file
@@ -0,0 +1,184 @@
# XAttention Offload Profiling - 32K Context

Nsys profiling comparing XAttention vs full attention performance in offload mode.

**Date**: 2026-02-05
**Model**: Llama-3.1-8B-Instruct
**Context**: 32K tokens
**GPU**: A100-80GB (GPU 0)

---

## Test Configuration

| Parameter | Full | XAttention |
|------|------|------------|
| Policy | FULL | XATTN_BSA |
| Block size | 4096 | 4096 |
| GPU blocks | 4 | 4 |
| Threshold | - | 0.95 |
| Density | 100% | ~50% |

---

## XAttention Per-Stage Timing

### NVTX Markers Summary

| Stage | Total (ms) | Calls | Avg (ms) | Description |
|------|------------|----------|--------------|------|
| xattn_find_blocks | 1155.1 | 256 | 4.51 | block selection (threshold-based) |
| xattn_estimate_pass1 | 588.3 | 256 | 2.30 | pass 1: partial stats |
| xattn_compute_historical | 512.0 | 224 | 2.29 | historical KV attention |
| xattn_estimate_pass2 | 501.6 | 256 | 1.96 | pass 2: block sums |
| xattn_estimate_merge | 197.9 | 256 | 0.77 | merge softmax stats |
| xattn_compute_merge | 93.8 | 256 | 0.37 | merge of compute results |
| xattn_compute_current | 59.2 | 256 | 0.23 | current-chunk attention |

### Time Breakdown

```
Total XAttention overhead: 3108 ms

Estimate stage:  1288 ms (41.4%)
  - pass1: 588 ms
  - pass2: 502 ms
  - merge: 198 ms

Find blocks:     1155 ms (37.2%)

Compute stage:    665 ms (21.4%)
  - historical: 512 ms
  - merge:       94 ms
  - current:     59 ms
```
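The breakdown can be recomputed from the NVTX totals in the table; a small sketch:

```python
# Reproducing the time breakdown from the NVTX totals (values in ms).
nvtx = {
    "find_blocks": 1155.1,
    "estimate_pass1": 588.3, "estimate_pass2": 501.6, "estimate_merge": 197.9,
    "compute_historical": 512.0, "compute_merge": 93.8, "compute_current": 59.2,
}
total = sum(nvtx.values())
estimate = nvtx["estimate_pass1"] + nvtx["estimate_pass2"] + nvtx["estimate_merge"]
compute = nvtx["compute_historical"] + nvtx["compute_merge"] + nvtx["compute_current"]

print(f"total    {total:.0f} ms")                                     # 3108 ms
print(f"estimate {estimate:.0f} ms ({100 * estimate / total:.1f}%)")  # 1288 ms (41.4%)
print(f"find     {nvtx['find_blocks']:.0f} ms ({100 * nvtx['find_blocks'] / total:.1f}%)")
print(f"compute  {compute:.0f} ms ({100 * compute / total:.1f}%)")    # 665 ms (21.4%)
```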
---

## Chunk 7 (last chunk) Comparison

### Per-Layer Time

| Policy | Layer 0 | Layer 1 | ... | Layer 31 | Avg |
|--------|---------|---------|-----|----------|-----|
| Full | 36.5 ms | 33.6 ms | ... | 32.7 ms | ~35 ms |
| XAttn | 39.7 ms | 39.3 ms | ... | 38.5 ms | ~38 ms |

### Analysis

Chunk 7 covers the last ~4K tokens of the sequence (3813 tokens), at which point:
- K length: 32485 tokens
- Density: 42.08%

**Conclusion**: XAttention is about 8% slower than full attention on chunk 7 because:
1. the estimate overhead is not offset by the savings from sparse computation
2. a density of 42% is still fairly high, limiting the sparsity benefit

---

## Full Attention Chunk 7 Details

```
Layer  Time(ms)
L0     36.5
L1     44.3
L2     43.7
L3     38.7
L4     34.2
L5     45.2
...
L31    32.7
Avg    ~35
```

---

## XAttention Chunk 7 Details

```
Layer  Time(ms)
L0     39.7
L1     39.3
L2     37.1
L3     39.1
L4     38.7
L5     39.4
...
L31    38.5
Avg    ~38
```

---

## Bottleneck Analysis

### 1. xattn_find_blocks is too expensive

- 4.51 ms per call on average
- 37.2% of total time
- cause: threshold-based block selection involves sorting and cumulative sums

### 2. Two-pass estimate overhead

- Pass 1 + Pass 2 take 1090 ms combined
- all KV chunks must be traversed twice
- optimization direction: single-pass estimate

### 3. The compute stage is relatively efficient

- only 21.4% of total time
- the BSA sparse computation itself performs well

---

## Optimization Suggestions

### Short term

1. **Reduce find_blocks overhead**
   - use top-k instead of a threshold
   - pre-allocate the mask buffer to avoid dynamic allocation

2. **Fuse the two estimate passes**
   - compute stats and block sums in a single pass

### Medium term

1. **Use a smaller block_size for the estimate stage**
   - the current block_size=4096 is unfriendly to estimate
   - see `docs/estimate_block_size_performance.md`

2. **Pipeline estimate and H2D**
   - overlap estimate with the next chunk's H2D transfer

### Long term

1. **Predictive block selection**
   - predict the next chunk's important blocks from historical patterns
   - reduces estimate overhead

---

## Related Files

- `results/nsys/full_offload_32k_blk4096_20260205_023257.nsys-rep`
- `results/nsys/xattn_offload_32k_blk4096_20260205_023435.nsys-rep`

---

## Commands

### Profile Full
```bash
bash scripts/profile_offload.sh --policy full --ctx-len 32k --gpu 0 --model ~/models/Llama-3.1-8B-Instruct
```

### Profile XAttention
```bash
bash scripts/profile_offload.sh --policy xattn --ctx-len 32k --gpu 0 --model ~/models/Llama-3.1-8B-Instruct
```

### Analyze NVTX
```bash
nsys stats --report nvtx_pushpop_sum <file>.nsys-rep
```
307
docs/xattn_offload_stream_sync_fix.md
Normal file
@@ -0,0 +1,307 @@
# XAttention Offload Stream Synchronization Fix

Fix for a CUDA stream synchronization bug in the XAttention BSA policy under offload mode.

**Fix date**: 2026-02-05
**Commit**: `829b311`
**Affected files**: `nanovllm/kvcache/sparse/xattn_bsa.py`, `nanovllm/kvcache/offload_engine.py`

---

## Problem Description

### Symptom

When running the RULER benchmark in offload mode, Pass 1 and Pass 2 in XAttention BSA's `select_blocks` method loaded inconsistent K data from the **same CPU block**:

```
Pass 1: K_chunk sum = 745472.00  (correct)
Pass 2: K_chunk sum = 0.00       (wrong: data not yet loaded)
```

This corrupted the attention results and lowered RULER accuracy.

### Reproduction Conditions

- Mode: offload (`--enable-offload`)
- Context: ≥ 32K tokens
- Sparse policy: `--sparse-policy XATTN_BSA`
## Root Cause Analysis

### Stream Configuration Recap

nano-vllm's CPU offload uses multiple CUDA streams to build a pipeline:

| Stream | Purpose |
|--------|------|
| `slot_transfer_streams[i]` | H2D transfers (CPU → GPU slot) |
| `compute_stream` | attention computation |
| `prefill_offload_streams[i]` | D2H transfers (GPU → CPU cache) |

### Synchronization Mechanism

`wait_slot_layer(slot)` synchronizes via events:

```python
def wait_slot_layer(self, slot_idx: int):
    """Make compute_stream wait for H2D transfer completion."""
    self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])
```

### Root Cause

In the `select_blocks` method:

1. H2D transfers run on `slot_transfer_streams`
2. `wait_slot_layer` makes `compute_stream` wait for the transfer
3. **But** the subsequent compute kernels run on the **default stream**, not on `compute_stream`

```python
# Buggy code
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)  # compute_stream waits

# These kernels run on the default stream and never wait for the H2D!
k_block = offload_engine.get_k_for_slot(slot)
K_chunk = k_block.transpose(1, 2)
# ... subsequent computation ...
```

### Timeline

```
slot_transfer_stream: [====H2D====]
compute_stream:                    |wait|
default_stream:          [kernel1][kernel2]   ← no waiting!
                            ↑
                     data not ready
```
---

## Fix

### Core Change

Wrap all estimate-phase compute kernels in `with torch.cuda.stream(compute_stream):`:

```python
# Fixed code
compute_stream = offload_engine.compute_stream

offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)  # compute_stream waits

# All computation runs on compute_stream
with torch.cuda.stream(compute_stream):
    k_block = offload_engine.get_k_for_slot(slot)
    K_chunk = k_block.transpose(1, 2)
    # ... subsequent computation ...
```
### Fix Locations

Six call sites in `select_blocks` needed the fix:

| Location | Stage | What is wrapped |
|------|------|----------|
| Pass 1, historical blocks | `xattn_estimate_pass1` | Historical KV chunk processing |
| Pass 1, current chunk | `xattn_estimate_pass1` | Processing of the current chunk's K already on GPU |
| Step 2, merge | `merge_softmax_stats` | Merging of softmax stats |
| Pass 2, historical blocks | `xattn_estimate_pass2` | block_sum with global stats |
| Pass 2, current chunk | `xattn_estimate_pass2` | block_sum for the current chunk |
| Step 4, block selection | `find_blocks_chunked` | Final block selection |
### Timing Diagram (After the Fix)

```
slot_transfer_stream:  [====H2D====]
compute_stream:                     |wait|[kernel1][kernel2]
                                       ↑
                                  data ready
```
---

## Code Changes in Detail

### 1. Pass 1: Historical Block Processing

```python
# Before (bug)
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
    offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
    offload_engine.wait_slot_layer(slot)

    k_block = offload_engine.get_k_for_slot(slot)  # default stream
    K_chunk = k_block.transpose(1, 2)
    # ... compute ...

# After (fixed)
compute_stream = offload_engine.compute_stream

for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
    offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
    offload_engine.wait_slot_layer(slot)

    with torch.cuda.stream(compute_stream):  # explicitly select the stream
        k_block = offload_engine.get_k_for_slot(slot)
        K_chunk = k_block.transpose(1, 2)
        # ... compute ...
```
### 2. Removal of the STRONG SYNC

The unnecessary hard synchronization was removed from `offload_engine.py`:

```python
# Removed from load_to_slot_layer() and load_k_only_to_slot_layer()
# STRONG SYNC: Synchronize all prefill offload streams before H2D
# for offload_stream in self.prefill_offload_streams:
#     offload_stream.synchronize()
```

This ordering is now handled correctly by the event mechanism, so the blocking synchronization is no longer needed.
### 3. Other Cleanup

- Removed DEBUG print statements
- Removed `torch.save()` debug code
- Merged several fallback conditions into one
- Changed the default `chunk_size` from 16384 to 4096 (matching the offload Q chunk size)

---
## Test Verification

### Test Commands

**GPU 0 - offload mode test**:

```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 10 \
    --max-model-len 40960 \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```
**GPU 1 - GPU-only mode test**:

```bash
CUDA_VISIBLE_DEVICES=1 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Qwen3-0.6B \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 10 \
    --max-model-len 40960 \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```
### Test Results

| Mode | Model | Context | Samples | Pass Rate | Density |
|------|------|---------|---------|-----------|---------|
| Offload | Llama-3.1-8B | 32K | 10/10 | **100%** | 9.53% |
| GPU-only | Qwen3-0.6B | 32K | 10/10 | **100%** | 9.84% |
### Density Alignment Verification

| Mode | Layer 0 Density | Difference |
|------|-----------------|------|
| GPU-only | 9.84% | - |
| Offload | 9.53% | ~3% |

The ~3% difference is expected, because the two modes accumulate KV differently:

- GPU-only: processes all KV in a single pass
- Offload: processes KV chunk by chunk; each chunk computes its softmax stats independently, and the stats are merged afterwards
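As a toy illustration of the density metric compared in these tables (the fraction of selected blocks inside the causal region of the block mask), here is a minimal sketch; `causal_density` is a hypothetical helper, not code from this repo:

```python
def causal_density(mask):
    """mask[qb][kb] in {0, 1}; only kb <= qb lies in the causal region."""
    selected = sum(mask[qb][kb] for qb in range(len(mask)) for kb in range(qb + 1))
    total = len(mask) * (len(mask) + 1) // 2  # number of causal block pairs
    return selected / total

# 4x4 block mask with 7 of the 10 causal blocks selected
mask = [
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [0, 1, 1, 0],
    [1, 0, 0, 1],
]
print(causal_density(mask))  # 0.7
```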
---

## Technical Details

### Three-Stage KV Chunking Flow
```
┌──────────────────────────────────────────────────────────┐
│ Stage 1: softmax_compute_partial_stats                   │
│   └── each KV chunk computes partial stats (m_i, l_i)    │
│                                                          │
│ Stage 2: merge_softmax_stats                             │
│   └── merge all chunks on the host: (m_global, l_global) │
│                                                          │
│ Stage 3: softmax_normalize_and_block_sum                 │
│   └── normalize with global stats and compute block sums │
└──────────────────────────────────────────────────────────┘
```
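The three stages above can be sketched in scalar form for a single query row. This is a toy sketch, not the repo's kernels; `stage1`, `stage2`, and `stage3` are illustrative stand-ins for `softmax_compute_partial_stats`, `merge_softmax_stats`, and `softmax_normalize_and_block_sum`:

```python
import math

def stage1(chunk):
    """Per-chunk partial stats: running max m and exp-sum l."""
    m = max(chunk)
    return m, sum(math.exp(s - m) for s in chunk)

def stage2(stats):
    """Merge per-chunk (m_i, l_i) into global (m, l) via the log-sum-exp identity."""
    m_g = max(m for m, _ in stats)
    return m_g, sum(l * math.exp(m - m_g) for m, l in stats)

def stage3(chunk, m_g, l_g):
    """Block sum of globally normalized probabilities for one chunk."""
    return sum(math.exp(s - m_g) for s in chunk) / l_g

chunks = [[0.3, 2.1, -1.0], [0.7, 1.5, -0.2]]   # one query row, two KV chunks
m_g, l_g = stage2([stage1(c) for c in chunks])
block_sums = [stage3(c, m_g, l_g) for c in chunks]

# Globally normalized block sums partition the softmax row: they sum to 1
assert abs(sum(block_sums) - 1.0) < 1e-12
```

Because the merged (m, l) equal the stats of a single pass over all scores, the chunked path reproduces the one-pass normalization exactly; density differences can only come from chunk-boundary effects.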

### Stream Configuration Requirements

| Operation | Stream | Reason |
|----------|--------|------|
| H2D transfer | `slot_transfer_streams` | Async transfer; does not block compute |
| D2H transfer | `prefill_offload_streams` | Async offload; does not block compute |
| Estimate kernels | `compute_stream` | Shared with attention compute to guarantee ordering |
| Attention kernels | `compute_stream` | Main compute stream |
### Event Synchronization Mechanism

```python
# Record an event after issuing the H2D transfer
self.ring_slot_ready[slot_idx].record(slot_transfer_stream)

# Before computing, wait for the H2D transfer to complete
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])

# Record an event after compute (used by the next H2D round)
self.ring_slot_compute_done[slot_idx].record(compute_stream)
```
---

## Related Documentation

- [`docs/architecture_guide.md`](architecture_guide.md): stream configuration and ring buffer architecture
- [`docs/xattn_kv_chunking_kernels.md`](xattn_kv_chunking_kernels.md): three-stage softmax kernels
- [`docs/gpuonly_density_alignment_test.md`](gpuonly_density_alignment_test.md): density alignment test
- [`docs/xattn_bsa_policy_design.md`](xattn_bsa_policy_design.md): XAttention BSA policy design
---

## Lessons Learned

### 1. Stream Synchronization Bugs Are Subtle

CUDA stream synchronization bugs are hard to spot:

- The data may be correct "most of the time", depending on timing
- The error shows up as random, intermittent deviations in results
- Precise debug logging is needed to pinpoint the cause
### 2. Event vs. Synchronize

| Method | Pros | Cons |
|------|------|------|
| `stream.wait_event()` | Non-blocking; keeps the pipeline going | Only synchronizes the specified stream |
| `stream.synchronize()` | Guarantees completion | Blocks the whole stream and breaks the pipeline |

**Best practice**: use events for precise synchronization and avoid blocking `synchronize()` calls.
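The contrast can be sketched with plain PyTorch streams and events. This is a minimal sketch, not code from this repo; the names are illustrative and it needs a CUDA device. The consumer stream waits only on the producer's event, while the host thread is never stalled (unlike `stream.synchronize()`):

```python
import torch

def load_then_compute(n: int = 1 << 20) -> "torch.Tensor":
    """Copy on one stream, compute on another, ordered by a single event."""
    copy_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()
    copied = torch.cuda.Event()

    src = torch.ones(n, pin_memory=True)
    with torch.cuda.stream(copy_stream):
        dst = src.to("cuda", non_blocking=True)  # async H2D on copy_stream
        copied.record(copy_stream)               # mark the copy's completion point

    with torch.cuda.stream(compute_stream):
        compute_stream.wait_event(copied)        # non-blocking: the host is not stalled
        out = dst * 2                            # ordered after the H2D copy

    torch.cuda.synchronize()                     # only to inspect the result below
    return out

if torch.cuda.is_available():
    result = load_then_compute()
    print(result.sum().item())  # 2097152.0 for the default n
```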
### 3. Debugging Techniques

```python
# Print tensor sums to verify data consistency
print(f"K_chunk sum = {K_chunk.sum().item()}")

# Save intermediate results for offline comparison
# (pass_id names the estimation pass; "pass" itself is a reserved word)
torch.save({'K': K_chunk, 'layer': layer_id}, f'/tmp/debug_{pass_id}_{chunk}.pt')
```
```diff
@@ -227,18 +227,19 @@ class ModelRunner:
             device=torch.device("cuda"),
         )
 
-        # GPU-only mode: pre-allocate policy metadata buffers
-        # This avoids dynamic GPU memory allocation during forward pass
-        if not config.enable_cpu_offload:
-            num_heads = hf_config.num_attention_heads // self.world_size
-            self.kvcache_manager.sparse_policy.alloc_policy_metadata(
-                num_heads=num_heads,
-                num_kv_heads=num_kv_heads,
-                head_dim=head_dim,
-                max_seq_len=config.max_model_len,
-                dtype=hf_config.torch_dtype,
-                device=torch.device("cuda"),
-            )
+        # Pre-allocate policy metadata buffers
+        # - Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
+        # - GPU-only mode: additionally allocate GQA expansion buffers
+        num_heads = hf_config.num_attention_heads // self.world_size
+        self.kvcache_manager.sparse_policy.alloc_policy_metadata(
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
+            max_seq_len=config.max_model_len,
+            dtype=hf_config.torch_dtype,
+            device=torch.device("cuda"),
+            enable_cpu_offload=config.enable_cpu_offload,
+        )
 
         # Log policy info (handle both enum and None cases)
         policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"
```
```diff
@@ -47,6 +47,8 @@ class FullAttentionPolicy(SparsePolicy):
         available_blocks: List[int],
         offload_engine: "OffloadEngine",
         ctx: PolicyContext,
+        q: torch.Tensor,
+        k: torch.Tensor,
     ) -> List[int]:
         """Return all blocks - no sparsity."""
         # Update statistics (only for layer 0 to avoid overcounting)
```
```diff
@@ -116,13 +116,15 @@ class SparsePolicy(ABC):
         max_seq_len: int,
         dtype: torch.dtype,
         device: torch.device,
+        enable_cpu_offload: bool = False,
     ) -> None:
         """
         Pre-allocate GPU buffers for policy computation.
 
-        Called by the framework after KV cache allocation, but ONLY for GPU-only
-        mode (not CPU offload mode). Override this to pre-allocate buffers that
-        would otherwise be dynamically allocated during forward pass.
+        Called by the framework after KV cache allocation. Implementations should
+        use enable_cpu_offload to decide which buffers to allocate:
+        - Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
+        - GPU-only mode: additionally allocate GQA expansion buffers
 
         This is separate from initialize() which is used for CPU offload metadata.
@@ -133,6 +135,7 @@ class SparsePolicy(ABC):
             max_seq_len: Maximum sequence length (for buffer sizing)
             dtype: Data type (typically float16/bfloat16)
             device: Target device (cuda)
+            enable_cpu_offload: Whether CPU offload is enabled
         """
         pass
```
```diff
@@ -142,6 +145,8 @@
         available_blocks: List[int],
         offload_engine: "OffloadEngine",
         ctx: PolicyContext,
+        q: torch.Tensor,
+        k: torch.Tensor,
     ) -> List[int]:
         """
         Select which KV blocks to load for the current query chunk.
@@ -158,6 +163,8 @@
                 to load KV to make selection decisions).
             ctx: PolicyContext with information about the current query
                 chunk, layer, phase (prefill/decode), etc.
+            q: Query tensor [seq_len, num_heads, head_dim] for current chunk
+            k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk
 
         Returns:
             List of block IDs to load (must be a subset of available_blocks).
```
```diff
@@ -191,13 +191,26 @@ class QuestPolicy(SparsePolicy):
     def select_blocks(
         self,
         available_blocks: List[int],
+        offload_engine: "OffloadEngine",
         ctx: PolicyContext,
+        q: torch.Tensor,
+        k: torch.Tensor,
     ) -> List[int]:
         """
         Select Top-K blocks based on query-key similarity bounds.
 
         If query is not available (some prefill scenarios), falls back
         to loading all blocks.
+
+        Args:
+            available_blocks: List of CPU block IDs
+            offload_engine: OffloadEngine for loading KV (unused in Quest)
+            ctx: PolicyContext with metadata
+            q: Query tensor [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
+            k: Key tensor [seq_len, num_kv_heads, head_dim] (unused in Quest, uses metadata instead)
+
+        Returns:
+            Selected block IDs
         """
         if self.metadata is None:
             raise RuntimeError(
@@ -211,7 +224,7 @@ class QuestPolicy(SparsePolicy):
         if n <= self.config.threshold_blocks:
             return available_blocks
 
-        if ctx.query is None:
+        if q is None:
             # No query available - cannot compute scores
             return available_blocks
@@ -221,11 +234,10 @@ class QuestPolicy(SparsePolicy):
         )
 
         # Metadata is already on GPU, same device as query
-        device = ctx.query.device
+        device = q.device
 
         # Compute upper bound scores
-        # query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
-        q = ctx.query
+        # query shape: [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
         if q.dim() == 4:
             # Prefill: use mean over sequence length
             q = q.mean(dim=1)  # [1, num_heads, head_dim]
```
```diff
@@ -17,6 +17,7 @@ import torch.cuda.nvtx as nvtx
 from typing import List, Tuple, TYPE_CHECKING
 
 from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
+from nanovllm.utils.density_observer import DensityObserver
 
 if TYPE_CHECKING:
     from nanovllm.kvcache.offload_engine import OffloadEngine
@@ -25,6 +26,10 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# Global storage for mask debugging
+_DEBUG_SAVE_MASK = False  # Set to True to save masks for comparison
+_DEBUG_MASK_STORAGE = {}
+
 # Check BSA availability
 try:
     from block_sparse_attn import block_sparse_attn_func
@@ -91,7 +96,7 @@ class XAttentionBSAPolicy(SparsePolicy):
         self,
         threshold: float = 0.95,  # High threshold for accuracy testing
         stride: int = 8,
-        chunk_size: int = 16384,
+        chunk_size: int = 4096,  # Match offload Q chunk size for density alignment
         block_size: int = 128,
         samples_per_chunk: int = 128,
         use_triton: bool = True,
@@ -134,6 +139,34 @@ class XAttentionBSAPolicy(SparsePolicy):
         self._v_expanded: torch.Tensor | None = None
         self._max_seq_len: int = 0
 
+        # Pre-allocated mask buffer for chunked prefill (offload mode)
+        # Stores BSA-level mask from select_blocks for use in compute_chunked_prefill
+        # Shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
+        self._prefill_mask_buffer: torch.Tensor | None = None
+        self._current_mask_q_bsa: int = 0  # Current Q BSA blocks in buffer
+        self._current_mask_k_bsa: int = 0  # Current K BSA blocks in buffer
+
+        # Selected block indices for mask extraction in compute_chunked_prefill
+        # Stores the indices of selected CPU blocks in available_blocks
+        self._selected_cpu_indices: List[int] = []
+        self._bsa_per_cpu: int = 0  # BSA blocks per CPU block
+
+        # =====================================================================
+        # Pre-allocated buffers for 3-stage KV chunking (offload mode)
+        # =====================================================================
+        # Partial softmax stats: m (max) and l (exp sum) for each KV chunk
+        # Shape: [max_kv_chunks, batch, heads, q_reshaped_len]
+        self._m_partial_buffer: torch.Tensor | None = None
+        self._l_partial_buffer: torch.Tensor | None = None
+
+        # Block sums buffer: normalized attention sums for all K blocks
+        # Shape: [batch, heads, max_q_bsa_blocks, max_k_bsa_blocks]
+        self._block_sums_buffer: torch.Tensor | None = None
+
+        # Configuration for KV chunking
+        self._max_kv_chunks: int = 0
+        self._cpu_block_size: int = 0  # Tokens per CPU block (set at runtime)
+
     def alloc_policy_metadata(
         self,
         num_heads: int,
@@ -142,6 +175,7 @@
         max_seq_len: int,
         dtype: torch.dtype,
         device: torch.device,
+        enable_cpu_offload: bool = False,
     ) -> None:
         """
         Pre-allocate GQA expansion buffers for GPU-only mode.
@@ -161,6 +195,54 @@
             dtype: Data type
             device: Target device
         """
+        # Pre-allocate mask buffer for chunked prefill (offload mode)
+        # mask shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
+        # This is needed regardless of GQA
+        max_q_bsa_blocks = self.chunk_size // self.BSA_BLOCK_SIZE
+        max_k_bsa_blocks = max_seq_len // self.BSA_BLOCK_SIZE
+        mask_shape = (1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks)
+        self._prefill_mask_buffer = torch.empty(mask_shape, dtype=torch.bool, device=device)
+        mask_memory_mb = num_heads * max_q_bsa_blocks * max_k_bsa_blocks / (1024 * 1024)
+        logger.info(f"[XAttn] Pre-allocated mask buffer: shape={mask_shape}, memory={mask_memory_mb:.1f} MB")
+
+        # =====================================================================
+        # Pre-allocate buffers for 3-stage KV chunking (offload mode)
+        # =====================================================================
+        # Calculate max KV chunks: historical blocks + current chunk
+        # Use cpu_block_size as KV chunk granularity (will be set at runtime)
+        # For now, estimate based on chunk_size (actual cpu_block_size may differ)
+        estimated_cpu_block_size = 4096  # Default, will be overwritten
+        max_kv_chunks = (max_seq_len // estimated_cpu_block_size) + 1  # +1 for current chunk
+
+        # Q reshaped length for one chunk
+        q_reshaped_len = self.chunk_size // self.stride
+        kv_chunk_reshaped_len = estimated_cpu_block_size // self.stride
+
+        # Partial stats buffers: [max_kv_chunks, batch=1, heads, q_reshaped_len]
+        m_partial_shape = (max_kv_chunks, 1, num_heads, q_reshaped_len)
+        self._m_partial_buffer = torch.empty(m_partial_shape, dtype=torch.float32, device=device)
+        self._l_partial_buffer = torch.empty(m_partial_shape, dtype=torch.float32, device=device)
+
+        # Block sums buffer: [batch=1, heads, max_q_bsa_blocks, max_k_bsa_blocks]
+        block_sums_shape = (1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks)
+        self._block_sums_buffer = torch.empty(block_sums_shape, dtype=dtype, device=device)
+
+        self._max_kv_chunks = max_kv_chunks
+
+        # Memory calculation
+        m_l_memory_mb = 2 * max_kv_chunks * num_heads * q_reshaped_len * 4 / (1024 * 1024)
+        block_sums_memory_mb = num_heads * max_q_bsa_blocks * max_k_bsa_blocks * dtype.itemsize / (1024 * 1024)
+        logger.info(f"[XAttn] Pre-allocated KV chunking buffers: "
+                    f"m/l shape={m_partial_shape} ({m_l_memory_mb:.1f} MB), "
+                    f"block_sums shape={block_sums_shape} ({block_sums_memory_mb:.1f} MB)")
+
+        # Skip GQA buffers in offload mode
+        # Chunked prefill uses compute_chunked_prefill() which handles GQA inline
+        if enable_cpu_offload:
+            logger.info("[XAttn] Offload mode: skipping GQA expansion buffers (saves ~16GB for 1M seq)")
+            return
+
+        # GPU-only mode: pre-allocate GQA buffers for compute_prefill()
         # Only allocate if GQA (num_heads != num_kv_heads)
         if num_heads == num_kv_heads:
             logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
```
```diff
@@ -215,9 +297,11 @@
         Returns:
             Attention output [total_q, num_heads, head_dim]
         """
-        # When block_tables is provided (paged KV cache / prefix cache),
-        # fallback to flash_attn as XAttention expects contiguous K, V
-        if block_tables is not None:
+        # Fallback to flash attention when:
+        # 1. block_tables provided (paged KV cache / prefix cache) - XAttention expects contiguous K, V
+        # 2. BSA kernel not available
+        # 3. xattn_estimate not available
+        if block_tables is not None or not BSA_AVAILABLE or not XATTN_AVAILABLE:
             from flash_attn import flash_attn_varlen_func
             return flash_attn_varlen_func(
                 q, k, v,
@@ -230,34 +314,12 @@
                 block_table=block_tables,
             )
 
-        if not BSA_AVAILABLE:
-            # Fallback to flash attention if BSA not available
-            from flash_attn import flash_attn_varlen_func
-            return flash_attn_varlen_func(
-                q, k, v,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_q,
-                max_seqlen_k=max_seqlen_k,
-                softmax_scale=softmax_scale,
-                causal=True,
-            )
-
-        if not XATTN_AVAILABLE:
-            # Fallback to flash attention if xattn not available
-            from flash_attn import flash_attn_varlen_func
-            return flash_attn_varlen_func(
-                q, k, v,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_q,
-                max_seqlen_k=max_seqlen_k,
-                softmax_scale=softmax_scale,
-                causal=True,
-            )
-
         from nanovllm.ops.xattn import xattn_estimate
 
+        # Set DensityObserver mode on first layer
+        if layer_id == 0:
+            DensityObserver.set_mode("gpu_only")
+
         # Get dimensions
         total_q, num_heads, head_dim = q.shape
         total_kv, num_kv_heads, _ = k.shape
@@ -311,15 +373,46 @@
 
         # Estimate block importance and get sparse mask
         with nvtx.range("xattn_estimate"):
-            _, mask = xattn_estimate(
+            attn_sums, mask = xattn_estimate(
                 Q, K_exp,
                 chunk_size=self.chunk_size,
                 block_size=self.BSA_BLOCK_SIZE,
+                stride=self.stride,
                 threshold=self.threshold,
                 use_triton=self.use_triton,
                 causal=True,
             )
 
+        # Debug: Save Q, K, mask, attn_sums for external verification
+        if _DEBUG_SAVE_MASK and layer_id == 0:
+            import os
+            valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
+            valid_k_blocks = (k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
+            mask_valid = mask[:, :, :valid_q_blocks, :valid_k_blocks]
+            attn_sums_valid = attn_sums[:, :, :valid_q_blocks, :valid_k_blocks]
+            save_dir = "/home/zijie/Code/nano-vllm/results/mask_alignment"
+            os.makedirs(save_dir, exist_ok=True)
+            save_path = f"{save_dir}/gpuonly_layer{layer_id}.pt"
+            torch.save({
+                # Input tensors (GQA-expanded)
+                "Q": Q.clone().cpu(),  # [1, num_heads, q_len, head_dim]
+                "K": K_exp.clone().cpu(),  # [1, num_heads, k_len, head_dim]
+                # xattn_estimate parameters
+                "chunk_size": self.chunk_size,
+                "block_size": self.BSA_BLOCK_SIZE,
+                "stride": self.stride,
+                "threshold": self.threshold,
+                # Output for comparison
+                "mask": mask_valid.clone().cpu(),
+                "attn_sums": attn_sums_valid.clone().cpu(),
+                # Metadata
+                "q_len": q_len,
+                "k_len": k_len,
+                "valid_q_blocks": valid_q_blocks,
+                "valid_k_blocks": valid_k_blocks,
+            }, save_path)
+            logger.info(f"[DEBUG] Saved Q/K/mask to {save_path}, Q={Q.shape}, K={K_exp.shape}")
+
         # Compute block counts
         q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
         k_block_num = (k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
@@ -360,13 +453,16 @@
             is_causal=True,
         )
 
-        # Update statistics (layer 0 only to avoid overcounting)
-        if layer_id == 0:
-            selected_blocks = mask_trimmed.sum().item()
-            total_blocks = q_block_num * k_block_num * num_heads
-            density = selected_blocks / total_blocks if total_blocks > 0 else 1.0
-            logger.debug(f"[XAttn GPU-only] layer={layer_id}, q_blocks={q_block_num}, "
-                         f"k_blocks={k_block_num}, density={density:.1%}")
+        # Record density for all layers via DensityObserver
+        if layer_id == 0:
+            # DEBUG: print mask details for GPU-only layer 0
+            q_bk = mask_trimmed.shape[2]
+            k_bk = mask_trimmed.shape[3]
+            causal_total = q_bk * (q_bk + 1) // 2 * mask_trimmed.shape[0] * mask_trimmed.shape[1]
+            causal_mask = torch.tril(torch.ones(q_bk, k_bk, device=mask_trimmed.device, dtype=torch.bool))
+            selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
+
+        DensityObserver.record(layer_id, mask_trimmed, causal=True)
 
         return output
```
@@ -400,46 +496,72 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
available_blocks: List[int],
|
available_blocks: List[int],
|
||||||
offload_engine: "OffloadEngine",
|
offload_engine: "OffloadEngine",
|
||||||
ctx: PolicyContext,
|
ctx: PolicyContext,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Compute attention scores for all available blocks using flat_group_gemm,
|
Select important blocks using 3-stage KV chunking algorithm.
|
||||||
then use softmax_fuse_block_sum and find_blocks_chunked to select important blocks.
|
|
||||||
|
|
||||||
This method:
|
This method implements the same algorithm as tests/test_xattn_estimate_alignment.py:
|
||||||
1. Loads each K block from CPU
|
1. For each KV chunk: compute attention scores and partial softmax stats
|
||||||
2. Computes Q@K^T attention scores using XAttention stride reshape
|
2. Merge all partial stats to get global m and l
|
||||||
3. Applies softmax_fuse_block_sum to get block-level attention
|
3. For each KV chunk: normalize with global stats and compute block sums
|
||||||
4. Uses find_blocks_chunked to select blocks based on threshold
|
4. Use find_blocks_chunked to select important blocks
|
||||||
|
|
```diff
+        This approach:
+        - Uses O(S×C) peak memory instead of O(S²)
+        - Produces identical density to GPU-only xattn_estimate
+        - Supports ultra-long contexts
+
         Args:
-            available_blocks: List of CPU block IDs
+            available_blocks: List of CPU block IDs (historical blocks only)
             offload_engine: OffloadEngine for loading blocks
-            ctx: PolicyContext with query tensor and metadata
+            ctx: PolicyContext with metadata
+            q: Query tensor [seq_len, num_heads, head_dim] for current chunk
+            k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk
+
         Returns:
             Selected block IDs based on attention threshold
         """
-        if not available_blocks or ctx.query is None:
+        if q is None:
             return available_blocks

-        from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum, find_blocks_chunked
+        # CRITICAL: Wait for all previous prefill offloads to complete before loading from CPU.
+        # This ensures that the K data we load from k_cache_cpu is actually valid.
+        # Without this sync, we may load stale/uninitialized data because the async offload
+        # from the previous chunk hasn't finished yet.
+        if available_blocks and offload_engine is not None:
+            offload_engine.wait_all_prefill_offloads()
+
+        from nanovllm.ops.xattn import (
+            flat_group_gemm_fuse_reshape,
+            softmax_compute_partial_stats,
+            softmax_normalize_and_block_sum,
+            merge_softmax_stats,
+            find_blocks_chunked,
+        )
         import math

         layer_id = ctx.layer_id
-        q = ctx.query  # [seq_len, num_heads, head_dim]
+
+        # Set DensityObserver mode on first layer
+        if layer_id == 0:
+            DensityObserver.set_mode("offload")
+
+        # ================================================================
+        # Step 0: Setup parameters
+        # ================================================================
         # Convert Q to [batch, heads, seq_len, head_dim]
-        # q: [seq_len, num_heads, head_dim] -> [1, num_heads, seq_len, head_dim]
-        Q = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, seq_len, head_dim]
+        Q = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, q_len, head_dim]

         num_heads = Q.shape[1]
         head_dim = Q.shape[3]
         q_len = Q.shape[2]

-        # flat_group_gemm requires q_len to be divisible by stride * BLOCK_M (typically 8 * 128 = 1024)
-        # Pad Q if necessary
+        # Alignment requirements
         BLOCK_M = 128  # Triton block size
-        alignment = self.stride * BLOCK_M
+        alignment = self.stride * BLOCK_M  # 8 * 128 = 1024

         if q_len < alignment:
             # Q too short, skip estimation and return all blocks
             logger.debug(f"[XAttn] select_blocks: q_len={q_len} < alignment={alignment}, skipping estimation")
```
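The alignment and padding arithmetic above is the usual ceil-to-multiple round-up. A minimal sketch (the helper name is illustrative, not from the repo):

```python
def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n (ceil-division trick)."""
    return ((n + multiple - 1) // multiple) * multiple

# stride * BLOCK_M = 8 * 128 = 1024, as in the code above
alignment = 8 * 128
assert round_up(1, alignment) == 1024      # shorter inputs pad up to one block
assert round_up(1024, alignment) == 1024   # already aligned: unchanged
assert round_up(1025, alignment) == 2048   # one token over: next multiple
```

The pad amount is then simply `round_up(q_len, alignment) - q_len`, which is what the diff computes as `q_pad_size`.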
```diff
@@ -447,150 +569,337 @@ class XAttentionBSAPolicy(SparsePolicy):

         # Pad Q to alignment
         padded_q_len = ((q_len + alignment - 1) // alignment) * alignment
-        if padded_q_len != q_len:
-            pad_size = padded_q_len - q_len
-            Q = torch.nn.functional.pad(Q, (0, 0, 0, pad_size), value=0)
+        q_pad_size = padded_q_len - q_len
+        if q_pad_size > 0:
+            Q = torch.nn.functional.pad(Q, (0, 0, 0, q_pad_size), value=0)

+        # Get CPU block size from context
+        cpu_block_size = ctx.block_size  # e.g., 4096 tokens per CPU block
+        self._cpu_block_size = cpu_block_size
+
+        # KV chunk parameters (use CPU block as KV chunk unit)
+        num_historical_blocks = len(available_blocks)
+        historical_k_len = num_historical_blocks * cpu_block_size
+        total_k_len = historical_k_len + q_len  # Include current chunk
+
+        # Reshaped dimensions
+        reshaped_block_size = self.BSA_BLOCK_SIZE // self.stride  # 128/8 = 16
         q_reshaped_len = padded_q_len // self.stride
+        kv_chunk_reshaped = cpu_block_size // self.stride

-        # Use a single slot for loading (synchronous mode for simplicity)
-        slot = 0
-        attn_scores_list = []
-
-        # Get block size from context
-        block_size = ctx.block_size  # tokens per CPU block (e.g., 1024)
-        reshaped_block_size = block_size // self.stride  # e.g., 1024/8 = 128
+        # BSA blocks per CPU block
+        bsa_per_cpu = cpu_block_size // self.BSA_BLOCK_SIZE  # 4096/128 = 32

-        with nvtx.range("xattn_estimate_gemm"):
-            for cpu_block_id in available_blocks:
-                # Load only K from CPU to GPU (V not needed for estimate)
-                # This saves 50% communication in the estimate phase
+        # Global K position parameters
+        # Q's position in the global K sequence (following the logic of test_xattn_estimate_alignment.py)
+        # For chunked softmax we need Q's BSA-block offset within the whole sequence
+        # k_block_num = total BSA blocks (padded), q_block_num = Q's BSA blocks (padded)
+        padded_total_k_len = ((total_k_len + alignment - 1) // alignment) * alignment
+        k_block_num = padded_total_k_len // self.BSA_BLOCK_SIZE
+        q_block_num = padded_q_len // self.BSA_BLOCK_SIZE
+        chunk_start = (k_block_num - q_block_num) * reshaped_block_size  # Q's start in reshaped space
+        chunk_end = chunk_start + q_reshaped_len
+
+        # real_q_len: effective Q length used for softmax normalization
+        k_reshaped_seq_len = padded_total_k_len // self.stride
+        k_reshaped_num_to_pad = (padded_total_k_len - total_k_len) // self.stride
+
+        # Softmax scale
+        norm = 1.0
+        scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm
+        segment_size = min(4096, reshaped_block_size)
+
+        # ================================================================
+        # Step 1: First pass - compute partial stats for all KV chunks
+        # ================================================================
+        m_chunks = []
+        l_chunks = []
+        num_kv_chunks = num_historical_blocks + 1  # +1 for current chunk
+
+        # Get compute_stream for all compute kernels (like attention computation)
+        compute_stream = offload_engine.compute_stream
+
+        with nvtx.range("xattn_estimate_pass1"):
+            slot = 0
+
+            # Process historical blocks (from CPU)
+            for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
+                # Load K from CPU (on slot_transfer_stream)
                 offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
+                # wait_slot_layer makes compute_stream wait for H2D transfer
                 offload_engine.wait_slot_layer(slot)

-                # Get K only: [1, block_size, num_kv_heads, head_dim]
-                k_block = offload_engine.get_k_for_slot(slot)
-
-                # Convert K to [batch, heads, k_len, head_dim]
-                # k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim]
-                K_chunk = k_block.transpose(1, 2)
-
-                # Handle GQA: expand K heads to match Q heads
-                num_kv_heads = K_chunk.shape[1]
+                # All compute kernels run on compute_stream (like attention computation)
+                with torch.cuda.stream(compute_stream):
+                    k_block = offload_engine.get_k_for_slot(slot)  # [1, block_size, num_kv_heads, head_dim]
+                    K_chunk = k_block.transpose(1, 2)  # [1, num_kv_heads, block_size, head_dim]
+
+                    # GQA expansion
+                    num_kv_heads = K_chunk.shape[1]
+                    if num_heads != num_kv_heads:
+                        num_groups = num_heads // num_kv_heads
+                        K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
+
+                    # KV offset in reshaped space
+                    kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
+
+                    # Compute raw attention scores
+                    attn_weights_kv = flat_group_gemm_fuse_reshape(
+                        Q, K_chunk, self.stride,
+                        chunk_start=chunk_start,
+                        chunk_end=chunk_end,
+                        is_causal=False,  # K is incomplete here, so causal cannot be applied yet
+                    )
+
+                    # Compute partial stats (with causal mask)
+                    m_partial, l_partial = softmax_compute_partial_stats(
+                        attn_weights_kv,
+                        reshaped_block_size,
+                        segment_size,
+                        scale,
+                        chunk_start=chunk_start,
+                        kv_offset=kv_offset_reshaped,
+                        is_causal=True,
+                    )
+                    m_chunks.append(m_partial)
+                    l_chunks.append(l_partial)
+
+                offload_engine.record_slot_compute_done(slot)
+                del attn_weights_kv
+
+            # Process current chunk K (already on GPU) on compute_stream
+            with torch.cuda.stream(compute_stream):
+                # k: [seq_len, num_kv_heads, head_dim] -> [1, num_kv_heads, seq_len, head_dim]
+                K_current = k.unsqueeze(0).transpose(1, 2)
+
+                # GQA expansion for current chunk
+                num_kv_heads = K_current.shape[1]
                 if num_heads != num_kv_heads:
                     num_groups = num_heads // num_kv_heads
-                    K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
+                    K_current = K_current.repeat_interleave(num_groups, dim=1)

-                # Pad K if necessary (k_len must be divisible by stride * BLOCK_N)
-                k_len = K_chunk.shape[2]
-                BLOCK_N = 128
-                k_alignment = self.stride * BLOCK_N
-                if k_len < k_alignment:
-                    # K too short, pad it
-                    pad_size = k_alignment - k_len
-                    K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
+                # Pad current K to alignment
+                curr_k_len = K_current.shape[2]
+                padded_curr_k_len = ((curr_k_len + alignment - 1) // alignment) * alignment
+                if padded_curr_k_len != curr_k_len:
+                    K_current = torch.nn.functional.pad(K_current, (0, 0, 0, padded_curr_k_len - curr_k_len), value=0)

-                # Compute attention scores using flat_group_gemm_fuse_reshape
-                # Output: [batch, heads, q_len/stride, k_len/stride]
-                attn_chunk = flat_group_gemm_fuse_reshape(
-                    Q, K_chunk, self.stride,
-                    chunk_start=0,
-                    chunk_end=q_reshaped_len,
-                    is_causal=False
+                # KV offset for current chunk
+                kv_offset_current = num_historical_blocks * kv_chunk_reshaped
+
+                # Compute attention scores for current chunk
+                attn_weights_curr = flat_group_gemm_fuse_reshape(
+                    Q, K_current, self.stride,
+                    chunk_start=chunk_start,
+                    chunk_end=chunk_end,
+                    is_causal=False,
                 )
-                attn_scores_list.append(attn_chunk)

-                # Mark slot as done for reuse
-                offload_engine.record_slot_compute_done(slot)
+                # Compute partial stats for current chunk
+                m_partial_curr, l_partial_curr = softmax_compute_partial_stats(
+                    attn_weights_curr,
+                    reshaped_block_size,
+                    segment_size,
+                    scale,
+                    chunk_start=chunk_start,
+                    kv_offset=kv_offset_current,
+                    is_causal=True,
+                )
+                m_chunks.append(m_partial_curr)
+                l_chunks.append(l_partial_curr)

-        # Concatenate all attention scores along K dimension
-        # Each chunk: [1, heads, q_reshaped_len, block_reshaped_len]
-        # Result: [1, heads, q_reshaped_len, total_k_reshaped_len]
-        if not attn_scores_list:
-            return available_blocks
+                del attn_weights_curr

-        attn_scores = torch.cat(attn_scores_list, dim=-1)
-        # Free intermediate list immediately
-        del attn_scores_list
+        # ================================================================
+        # Step 2: Merge all partial stats (on compute_stream)
+        # ================================================================
+        with torch.cuda.stream(compute_stream):
+            with nvtx.range("xattn_estimate_merge"):
+                m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)

-        # Step 2: Apply softmax_fuse_block_sum with hierarchical aggregation
-        # Use smaller estimate_block_size (1024) for 15x faster softmax kernel,
-        # then aggregate to CPU block level (4096).
-        #
-        # Hierarchical approach:
-        # 1. softmax_fuse_block_sum with estimate_block_size (1024) -> fine-grained scores
-        # 2. Aggregate: reshape + sum -> CPU block level scores
-        # 3. Select blocks based on score + threshold (NOT mask + voting)
-        cpu_block_size = block_size  # e.g., 4096
-        estimate_bs = self.estimate_block_size  # e.g., 1024 (15x faster)
-        ratio = cpu_block_size // estimate_bs  # e.g., 4
+        del m_chunks, l_chunks

-        # Use estimate_block_size for softmax kernel (optimized)
-        reshaped_est_bs = estimate_bs // self.stride  # e.g., 1024/8 = 128
-        norm = 1.0  # Normalization factor
-        scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm  # log2(e) with scaling
-        segment_size = min(4096, reshaped_est_bs)
+        # ================================================================
+        # Step 3: Second pass - normalize and compute block sums
+        # ================================================================
+        attn_sum_per_kv = []

-        with nvtx.range("xattn_estimate_softmax"):
-            block_sums_fine = softmax_fuse_block_sum(
-                attn_scores,
-                reshaped_est_bs,  # Use optimized estimate block size (128 vs 512)
-                segment_size,
-                chunk_start=0,
-                chunk_end=q_reshaped_len,
-                real_q_len=q_reshaped_len,
-                scale=scale,
-                is_causal=False,  # Historical blocks are all before current chunk
+        with nvtx.range("xattn_estimate_pass2"):
+            slot = 0
+
+            # Process historical blocks again
+            for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
+                # Load K from CPU (on slot_transfer_stream)
+                offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
+                # wait_slot_layer makes compute_stream wait for H2D transfer
+                offload_engine.wait_slot_layer(slot)
+
+                # All compute kernels run on compute_stream
+                with torch.cuda.stream(compute_stream):
+                    k_block = offload_engine.get_k_for_slot(slot)
+                    K_chunk = k_block.transpose(1, 2)
+
+                    num_kv_heads = K_chunk.shape[1]
+                    if num_heads != num_kv_heads:
+                        num_groups = num_heads // num_kv_heads
+                        K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
+
+                    kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
+
+                    # Recompute attention scores (trade-off: compute vs memory)
+                    attn_weights_kv = flat_group_gemm_fuse_reshape(
+                        Q, K_chunk, self.stride,
+                        chunk_start=chunk_start,
+                        chunk_end=chunk_end,
+                        is_causal=False,
+                    )
+
+                    # Normalize with global stats and compute block sums
+                    block_sum_kv = softmax_normalize_and_block_sum(
+                        attn_weights_kv,
+                        m_global,
+                        l_global,
+                        reshaped_block_size,
+                        segment_size,
+                        chunk_start=chunk_start,
+                        real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
+                        scale=scale,
+                        kv_offset=kv_offset_reshaped,
+                        is_causal=True,
+                    )
+                    attn_sum_per_kv.append(block_sum_kv)
+
+                offload_engine.record_slot_compute_done(slot)
+                del attn_weights_kv
+
+            # Process current chunk on compute_stream
+            with torch.cuda.stream(compute_stream):
+                # Recompute attention scores for current chunk
+                attn_weights_curr = flat_group_gemm_fuse_reshape(
+                    Q, K_current, self.stride,
+                    chunk_start=chunk_start,
+                    chunk_end=chunk_end,
+                    is_causal=False,
+                )
+
+                block_sum_curr = softmax_normalize_and_block_sum(
+                    attn_weights_curr,
+                    m_global,
+                    l_global,
+                    reshaped_block_size,
+                    segment_size,
+                    chunk_start=chunk_start,
+                    real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
+                    scale=scale,
+                    kv_offset=kv_offset_current,
+                    is_causal=True,
+                )
+                attn_sum_per_kv.append(block_sum_curr)
+                del attn_weights_curr, K_current
+
+        # ================================================================
+        # Step 4: Concatenate block sums and select blocks (on compute_stream)
+        # ================================================================
+        with torch.cuda.stream(compute_stream):
+            attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
+            del attn_sum_per_kv, m_global, l_global
+
+            # Calculate q_block offset for find_blocks_chunked
+            # This is the number of BSA blocks before Q in the full sequence
+            num_blocks_per_chunk = q_reshaped_len // reshaped_block_size
+            current_index = k_block_num - q_block_num  # Q starts at this BSA block index
+
+            with nvtx.range("xattn_find_blocks"):
+                mask = find_blocks_chunked(
+                    attn_sum_concat,
+                    current_index=current_index,
+                    threshold=self.threshold,
+                    num_to_choose=None,
+                    decoding=False,
+                    mode="prefill",
+                    causal=True,
+                )
+
+            # Apply causal mask post-processing (same as xattn.py lines 1300-1306)
+            mask[:, :, -q_block_num:, -q_block_num:] = torch.where(
+                torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=mask.device), diagonal=0),
+                mask[:, :, -q_block_num:, -q_block_num:],
+                False,
             )
-        # block_sums_fine shape: [batch, heads, q_est_blocks, k_est_blocks]
-        # where k_est_blocks = len(available_blocks) * ratio

-        # Step 3: Aggregate to CPU block level (hierarchical sum)
-        # This is mathematically equivalent to direct computation but much faster
-        batch_size_bs, num_heads_bs, q_est_blocks, k_est_blocks = block_sums_fine.shape
-        num_cpu_blocks = len(available_blocks)
+            # ================================================================
+            # Step 5: Record density (only on layer 0)
+            # ================================================================
+            if layer_id == 0:
+                # Trim mask to valid region
+                valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
+                valid_k_blocks = (total_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
+                mask_valid = mask[:, :, :valid_q_blocks, :valid_k_blocks]
+                attn_sums_valid = attn_sum_concat[:, :, :valid_q_blocks, :valid_k_blocks]

-        with nvtx.range("xattn_estimate_aggregate"):
-            # Reshape: [batch, heads, q_est, k_est] -> [batch, heads, q_est, num_cpu, ratio]
-            block_sums_coarse = block_sums_fine.view(
-                batch_size_bs, num_heads_bs, q_est_blocks, num_cpu_blocks, ratio
-            ).sum(dim=-1)  # [batch, heads, q_est_blocks, num_cpu_blocks]
+                # Compute causal mask for density calculation
+                q_offset_blocks = valid_k_blocks - valid_q_blocks
+                indices = torch.arange(valid_k_blocks, device=mask.device).unsqueeze(0)
+                q_indices = torch.arange(valid_q_blocks, device=mask.device).unsqueeze(1)
+                causal_mask = indices <= (q_indices + q_offset_blocks)

-        # Sum over Q dimension to get total attention from Q chunk to each K block
-        cpu_block_scores = block_sums_coarse.sum(dim=2)  # [batch, heads, num_cpu_blocks]
+                chunk_total = causal_mask.sum().item() * mask_valid.shape[0] * mask_valid.shape[1]
+                chunk_selected = (mask_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()

-        # Step 4: Select blocks using score + threshold (replaces mask + majority voting)
-        # This is simpler and more direct than the original mask-based approach
-        with nvtx.range("xattn_estimate_select"):
-            # Average scores across heads (GQA-aware: all heads contribute equally)
-            scores_per_block = cpu_block_scores.mean(dim=(0, 1))  # [num_cpu_blocks]
+                DensityObserver.record_counts(layer_id, chunk_selected, chunk_total)
+                logger.info(f"[XAttn Offload] Layer0 chunk: q_len={q_len}, k_len={total_k_len}, "
+                            f"valid_q_blocks={valid_q_blocks}, valid_k_blocks={valid_k_blocks}, "
+                            f"q_offset={q_offset_blocks}, selected={chunk_selected}, total={chunk_total}, "
+                            f"density={chunk_selected/chunk_total:.4f}")

-            # Normalize to get attention distribution
-            total_score = scores_per_block.sum()
-            if total_score > 0:
-                score_ratio = scores_per_block / total_score
-            else:
-                # Edge case: all zeros, select all blocks
-                selected_block_ids = list(available_blocks)
-                if layer_id == 0 and available_blocks:
-                    self._stats_total_available_blocks += len(available_blocks)
-                    self._stats_total_selected_blocks += len(selected_block_ids)
-                    self._stats_num_chunks += 1
-                return selected_block_ids
+                # Debug: Save mask and attention sums for comparison
+                if _DEBUG_SAVE_MASK:
+                    import os
+                    chunk_idx = ctx.query_chunk_idx if ctx else 0
+                    save_dir = "/home/zijie/Code/nano-vllm/results/mask_alignment"
+                    os.makedirs(save_dir, exist_ok=True)
+                    save_path = f"{save_dir}/offload_layer{layer_id}_chunk{chunk_idx}.pt"
+                    torch.save({
+                        "mask": mask_valid.clone().cpu(),
+                        "attn_sums": attn_sums_valid.clone().cpu(),
+                        "q_len": q_len,
+                        "k_len": total_k_len,
+                        "valid_q_blocks": valid_q_blocks,
+                        "valid_k_blocks": valid_k_blocks,
+                        "current_index": current_index,
+                        "chunk_start": chunk_start,
+                    }, save_path)
+                    logger.info(f"[DEBUG] Saved mask to {save_path}")

-            # Sort by score (descending) and select until threshold is reached
-            sorted_indices = torch.argsort(score_ratio, descending=True)
-            cumsum = 0.0
-            selected_indices = set()
+            del attn_sum_concat

-            for idx in sorted_indices.tolist():
-                selected_indices.add(idx)
-                cumsum += score_ratio[idx].item()
-                if cumsum >= self.threshold:
-                    break
+            # ================================================================
+            # Step 6: Extract historical mask and aggregate to CPU blocks
+            # ================================================================
+            B, H, Q_bsa, K_bsa_total = mask.shape
+            historical_k_bsa = num_historical_blocks * bsa_per_cpu

-            # Map indices back to block IDs
-            selected_block_ids = [available_blocks[i] for i in sorted(selected_indices)]
+            # Save mask to buffer for compute_chunked_prefill (if needed later)
+            if self._prefill_mask_buffer is not None and historical_k_bsa > 0:
+                self._prefill_mask_buffer[:, :, :Q_bsa, :historical_k_bsa].copy_(
+                    mask[:, :, :, :historical_k_bsa]
+                )
+                self._current_mask_q_bsa = Q_bsa
+                self._current_mask_k_bsa = historical_k_bsa
+
+            # Aggregate to CPU block level (union across heads, Q blocks, BSA blocks per CPU)
+            if num_historical_blocks == 0:
+                return []
+
+            mask_historical = mask[:, :, :, :historical_k_bsa]
+            mask_per_cpu = mask_historical.view(B, H, Q_bsa, num_historical_blocks, bsa_per_cpu)
+            cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1)  # [B, num_cpu]
+
+            selected_indices = cpu_needed[0].nonzero().squeeze(-1).tolist()
+            if isinstance(selected_indices, int):
+                selected_indices = [selected_indices]
+
+            selected_block_ids = [available_blocks[i] for i in selected_indices]

         # Always include first block (sink) and last block for safety
         if available_blocks and available_blocks[0] not in selected_block_ids:
```
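The two-pass estimate hinges on the fact that per-chunk softmax statistics `(m, l)` can be merged exactly after the fact. A minimal pure-Python sketch of that online-softmax merge rule, using base-2 exponentials as the kernels do (function names here are illustrative, not the repo's API):

```python
def partial_stats(scores):
    """Per-chunk stats: max m and sum of 2**(x - m), the (m, l) pair
    a chunk-local pass produces (base-2, mirroring tl.math.exp2)."""
    m = max(scores)
    l = sum(2.0 ** (x - m) for x in scores)
    return m, l

def merge_stats(stats):
    """Merge (m, l) pairs from several disjoint KV chunks:
    rescale each l by 2**(m_chunk - m_global), then sum."""
    m_g = max(m for m, _ in stats)
    l_g = sum(l * 2.0 ** (m - m_g) for m, l in stats)
    return m_g, l_g

row = [0.5, 2.0, -1.0, 3.0, 0.0, 1.5]          # one query row's scores
chunks = [row[:2], row[2:4], row[4:]]          # split across 3 "KV chunks"
m_g, l_g = merge_stats([partial_stats(c) for c in chunks])
m_d, l_d = partial_stats(row)                  # direct single-pass reference
assert abs(m_g - m_d) < 1e-12 and abs(l_g - l_d) < 1e-9
```

This is why pass 1 can stream one CPU block at a time and still end up with exactly the global `(m_global, l_global)` a monolithic softmax would have used.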
```diff
@@ -604,14 +913,20 @@ class XAttentionBSAPolicy(SparsePolicy):
             self._stats_total_selected_blocks += len(selected_block_ids)
             self._stats_num_chunks += 1

+        # Record communication density to DensityObserver
+        # Comm density = selected_cpu_blocks / available_cpu_blocks
+        # This is different from compute density (BSA block granularity)
+        DensityObserver.record_comm_density(
+            layer_id=layer_id,
+            selected_cpu_blocks=len(selected_block_ids),
+            total_cpu_blocks=len(available_blocks),
+        )
+
         # Log per-chunk density
         chunk_density = len(selected_block_ids) / len(available_blocks)
         logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, available={len(available_blocks)}, "
                      f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")

-        # Free intermediate tensors to prevent memory leak
-        del attn_scores, block_sums_fine, block_sums_coarse, cpu_block_scores, scores_per_block
-
         return selected_block_ids

     def compute_chunked_prefill(
```
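The Step 6 reduction from a per-head BSA-block mask to per-CPU-block load decisions is a plain reshape-and-union. A NumPy sketch with made-up small shapes (the real code does the same reduction with `torch.Tensor.view` and chained `.any()` calls over the same axes):

```python
import numpy as np

# Hypothetical shapes: B=1 batch, H=2 heads, Q_bsa=4 query blocks,
# 3 CPU blocks x 2 BSA blocks per CPU block = 6 historical BSA columns.
B, H, Q_bsa, num_cpu, bsa_per_cpu = 1, 2, 4, 3, 2
mask = np.zeros((B, H, Q_bsa, num_cpu * bsa_per_cpu), dtype=bool)
mask[0, 1, 2, 3] = True   # one head wants one BSA block -> BSA column 3

# Group BSA columns by owning CPU block, then union over
# (bsa_per_cpu, Q blocks, heads) -- mirrors .any(dim=-1).any(dim=2).any(dim=1)
per_cpu = mask.reshape(B, H, Q_bsa, num_cpu, bsa_per_cpu)
cpu_needed = per_cpu.any(axis=-1).any(axis=2).any(axis=1)   # [B, num_cpu]
selected = np.nonzero(cpu_needed[0])[0].tolist()
assert selected == [1]    # BSA column 3 // bsa_per_cpu == CPU block 1
```

The union semantics matter: a CPU block is transferred if *any* head of *any* query block needs *any* of its BSA blocks, so communication density is an upper bound on compute density.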
```diff
@@ -636,6 +951,10 @@ class XAttentionBSAPolicy(SparsePolicy):
         2. Compute attention to current chunk
         3. Merge all results

+        Note: The BSA-level mask is saved in self._prefill_mask_buffer by select_blocks().
+        Currently we use flash_attn_with_lse for computation (supports LSE merge).
+        TODO: Optimize to use BSA kernel with the saved mask for per-head sparse attention.
+
         Args:
             q: Query tensor [seq_len, num_heads, head_dim]
             k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
```
```diff
@@ -666,6 +985,11 @@ class XAttentionBSAPolicy(SparsePolicy):
         # Use the pre-selected blocks directly
         cpu_block_table = selected_blocks

+        # Note: BSA mask is available in self._prefill_mask_buffer (saved by select_blocks)
+        # Mask shape: [1, num_heads, Q_bsa, K_bsa] where Q_bsa = self._current_mask_q_bsa
+        # Selected indices: self._selected_cpu_indices, bsa_per_cpu: self._bsa_per_cpu
+        # TODO: Use this mask with BSA kernel for per-head sparse attention optimization
+
         if cpu_block_table:
             with nvtx.range("xattn_compute_historical"):
                 load_slots = list(range(offload_engine.num_ring_slots))
```
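The docstring above says historical-chunk and current-chunk attention outputs are combined via `flash_attn_with_lse`, i.e. by log-sum-exp merging. The merge itself is standard reweighting; here is a NumPy sketch under assumed conventions (natural-log LSE, a single 1-D query; the real kernel's tensor layout and log base may differ):

```python
import numpy as np

def merge_lse(o1, lse1, o2, lse2):
    """Merge two partial attention outputs computed over disjoint KV sets,
    given each part's log-sum-exp of scores."""
    lse = np.logaddexp(lse1, lse2)        # combined normalizer log(Z1 + Z2)
    return np.exp(lse1 - lse) * o1 + np.exp(lse2 - lse) * o2, lse

rng = np.random.default_rng(0)
q = rng.normal(size=4)
K = rng.normal(size=(6, 4))
V = rng.normal(size=(6, 4))

s = K @ q                                 # attention scores
p = np.exp(s - s.max()); p = p / p.sum()
ref = p @ V                               # full softmax attention (reference)

def part(idx):
    """Attention restricted to one KV subset, plus its LSE."""
    si = s[idx]
    m = si.max()
    lse = m + np.log(np.exp(si - m).sum())
    return np.exp(si - lse) @ V[idx], lse

o1, lse1 = part(slice(0, 3))              # "historical" keys
o2, lse2 = part(slice(3, 6))              # "current chunk" keys
o, lse = merge_lse(o1, lse1, o2, lse2)
assert np.allclose(o, ref)
```

Because the merge is exact, the historical blocks can be processed slot by slot and folded into the running result without ever materializing the full score matrix.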
```diff
@@ -221,20 +221,19 @@ class Attention(nn.Module):
         cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

         # Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
-        selected_blocks = []
-        if cpu_block_table:
-            num_chunks = current_chunk_idx + 1
-            policy_ctx = PolicyContext(
-                query_chunk_idx=current_chunk_idx,
-                num_query_chunks=num_chunks,
-                layer_id=self.layer_id,
-                query=q,  # Pass query for sparse policies that need it
-                is_prefill=True,
-                block_size=kvcache_manager.block_size,
-                total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
-            )
-            selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
-            logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
+        # Always call select_blocks even for first chunk (cpu_block_table may be empty)
+        num_chunks = current_chunk_idx + 1
+        policy_ctx = PolicyContext(
+            query_chunk_idx=current_chunk_idx,
+            num_query_chunks=num_chunks,
+            layer_id=self.layer_id,
+            query=q,  # Pass query for sparse policies that need it
+            is_prefill=True,
+            block_size=kvcache_manager.block_size,
+            total_kv_len=len(cpu_block_table) * kvcache_manager.block_size if cpu_block_table else 0,
+        )
+        selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
+        logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")

         # [DEBUG] Verify execution path
         logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
```
```diff
@@ -320,7 +319,7 @@ class Attention(nn.Module):
             block_size=kvcache_manager.block_size,
             total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
         )
-        selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
+        selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
         logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")

         # [DEBUG] Verify execution path
```
@@ -218,6 +218,209 @@ def softmax_fuse_block_sum_kernel_non_causal(
|
|||||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# KV Chunking Support Kernels
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def softmax_partial_stats_kernel(
|
||||||
|
In,
|
||||||
|
M_out, # max per row
|
||||||
|
L_out, # sum per row (normalized by M_out)
|
||||||
|
scale,
|
||||||
|
input_stride_0,
|
||||||
|
input_stride_1,
|
||||||
|
input_stride_2,
|
||||||
|
stats_stride_0,
|
||||||
|
stats_stride_1,
|
||||||
|
k_len,
|
||||||
|
chunk_start, # Q start position (for causal)
|
||||||
|
kv_offset, # KV chunk offset (for causal)
|
||||||
|
    segment_size: tl.constexpr,
    block_size: tl.constexpr,
    is_causal: tl.constexpr,
):
    """
    Compute partial softmax statistics for a KV chunk.

    For each query row, computes:
    - m: max value in this chunk
    - l: sum of exp(x - m) in this chunk

    These can be merged across chunks using the online softmax formula.

    Input shape: [batch, heads, q_len, k_chunk_len]
    Output shapes: M[batch, heads, q_len], L[batch, heads, q_len]
    """
    block_id = tl.program_id(0)
    head_id = tl.program_id(1)
    batch_id = tl.program_id(2)

    offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
    offs_k = tl.arange(0, segment_size)

    num_iters = k_len // segment_size

    # For causal: compute boundary
    if is_causal:
        # Causal boundary: Q position where this KV chunk starts to be valid.
        # Q[i] can attend K[j] if i >= j; for a KV chunk at kv_offset,
        # Q[i] can attend if i >= kv_offset.
        num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1 - kv_offset) // segment_size
        num_iters_before_causal = tl.minimum(num_iters_before_causal, num_iters)
        num_iters_before_causal = tl.maximum(num_iters_before_causal, 0)
    else:
        num_iters_before_causal = num_iters

    # Online softmax state
    m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([block_size], dtype=tl.float32)

    # Input pointer
    input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
    input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2

    # Compute max and sum (before causal boundary)
    for iter in range(0, num_iters_before_causal):
        X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
        m_local = tl.max(X, 1)
        m_new = tl.maximum(m_i, m_local)
        alpha = tl.math.exp2(m_i - m_new)

        X = X - m_new[:, None]
        l_local = tl.sum(tl.math.exp2(X), 1)
        l_i = l_i * alpha + l_local

        m_i = m_new

    # Handle causal boundary
    if is_causal:
        for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
            if iter < num_iters:
                X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
                # causal mask: Q[i] >= K[j] + kv_offset
                mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size + kv_offset)
                X = tl.where(mask, X, -1.0e6)
                m_local = tl.max(X, 1)
                m_new = tl.maximum(m_i, m_local)
                alpha = tl.math.exp2(m_i - m_new)

                X = X - m_new[:, None]
                l_local = tl.sum(tl.math.exp2(X), 1)
                l_i = l_i * alpha + l_local

                m_i = m_new

    # Output pointers
    m_ptr = M_out + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
    l_ptr = L_out + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size

    offs = tl.arange(0, block_size)
    tl.store(m_ptr + offs, m_i.to(M_out.type.element_ty))
    tl.store(l_ptr + offs, l_i.to(L_out.type.element_ty))

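The two-pass design above, per-chunk (m, l) statistics that are later merged, rests on the online-softmax identity. A pure-Python sketch with toy scores, using the base-2 exponent like the kernels (all values here are illustrative):

```python
def partial_stats(scores):
    """Per-row (m, l) for one KV chunk: m = max, l = sum of 2**(x - m)."""
    m = max(scores)
    l = sum(2.0 ** (x - m) for x in scores)
    return m, l

# one query row, scores split into two KV chunks
row = [0.5, -1.0, 2.0, 0.25, 1.5, -0.5]
m1, l1 = partial_stats(row[:3])
m2, l2 = partial_stats(row[3:])

# merged stats equal the stats of the full row (online-softmax identity)
m = max(m1, m2)
l = l1 * 2.0 ** (m1 - m) + l2 * 2.0 ** (m2 - m)
m_full, l_full = partial_stats(row)
assert m == m_full and abs(l - l_full) < 1e-12
```

Because each chunk's sum is rescaled by `2**(m_chunk - m)`, the merge is exact regardless of how the row is split across chunks.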
@triton.jit
def softmax_normalize_block_sum_kernel(
    In,
    Out,
    M_global,  # global max per row
    L_global,  # global sum per row
    scale,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_stride_0,
    output_stride_1,
    output_stride_2,
    stats_stride_0,
    stats_stride_1,
    real_q_len,
    k_len,
    chunk_start,
    kv_offset,  # KV chunk offset (for causal)
    segment_size: tl.constexpr,
    block_size: tl.constexpr,
    is_causal: tl.constexpr,
):
    """
    Normalize with global stats and compute block sums for a KV chunk.

    Uses pre-computed global m and l to correctly normalize softmax
    across all KV chunks.

    Input shape: [batch, heads, q_len, k_chunk_len]
    Output shape: [batch, heads, q_blocks, k_chunk_blocks]
    """
    block_id = tl.program_id(0)
    head_id = tl.program_id(1)
    batch_id = tl.program_id(2)

    offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
    offs_k = tl.arange(0, segment_size)

    num_iters = k_len // segment_size

    # For causal: compute boundary
    if is_causal:
        num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1 - kv_offset) // segment_size
        num_iters_before_causal = tl.minimum(num_iters_before_causal, num_iters)
        num_iters_before_causal = tl.maximum(num_iters_before_causal, 0)
    else:
        num_iters_before_causal = num_iters

    # Load global stats
    m_ptr = M_global + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
    l_ptr = L_global + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size

    offs = tl.arange(0, block_size)
    m_global = tl.load(m_ptr + offs).to(tl.float32)
    l_global = tl.load(l_ptr + offs).to(tl.float32)
    # Handle l_global = 0 (when all positions are masked)
    l_global_safe = tl.where(l_global > 0, l_global, 1.0)
    l_global_inv = 1.0 / l_global_safe

    # Input pointer
    input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
    input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2

    # Output pointer
    output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
    output_ptr = output_ptr + tl.arange(0, segment_size // block_size)

    sum_mask = offs_q[:, None] < real_q_len

    # Normalize and compute block sums (before causal boundary)
    for iter in range(0, num_iters_before_causal):
        X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
        X = tl.exp2(X - m_global[:, None]) * l_global_inv[:, None]
        X = tl.where(sum_mask, X, 0)
        X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
        X = tl.sum(X, 2)
        X = tl.sum(X, 0)
        tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))

    # Handle causal boundary
    if is_causal:
        for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
            if iter < num_iters:
                X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
                # causal mask: Q[i] >= K[j] + kv_offset
                causal_mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size + kv_offset)
                X = tl.where(causal_mask, X, -1.0e6)
                X = tl.exp2(X - m_global[:, None]) * l_global_inv[:, None]
                X = tl.where(sum_mask, X, 0)
                X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
                X = tl.sum(X, 2)
                X = tl.sum(X, 0)
                tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))

        # Zero out future blocks
        for iter in range(num_iters_before_causal + 1, num_iters):
            X = tl.zeros([segment_size // block_size], dtype=tl.float32)
            tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))

@triton.jit
def flat_group_gemm_fuse_reshape_kernel(
    Q, K, Out,
@@ -380,6 +583,194 @@ def softmax_fuse_block_sum(
    return output

def softmax_compute_partial_stats(
    attn_weights_slice: torch.Tensor,
    reshaped_block_size: int,
    segment_size: int,
    scale: float,
    chunk_start: int = 0,
    kv_offset: int = 0,
    is_causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute partial softmax statistics for a KV chunk.

    This is the first step for KV-chunked softmax computation.
    For each query row, computes:
    - m: max value in this chunk
    - l: sum of exp(x - m) in this chunk

    These partial stats can be merged across KV chunks using
    `merge_softmax_stats()`, then used with `softmax_normalize_and_block_sum()`.

    Args:
        attn_weights_slice: Raw attention scores [batch, heads, q_len, k_chunk_len]
        reshaped_block_size: Block size in reshaped space
        segment_size: Processing segment size
        scale: Softmax scale factor
        chunk_start: Q chunk start position (in reshaped space)
        kv_offset: KV chunk offset (in reshaped space, for causal masking)
        is_causal: Whether to apply causal masking

    Returns:
        Tuple of (m, l) where:
        - m: [batch, heads, q_len] max values per row
        - l: [batch, heads, q_len] partial sums per row
    """
    batch_size, num_heads, q_len, k_len = attn_weights_slice.shape

    assert q_len % reshaped_block_size == 0
    assert k_len % segment_size == 0
    assert attn_weights_slice.stride(-1) == 1

    m_out = torch.empty(
        (batch_size, num_heads, q_len),
        dtype=torch.float32,
        device=attn_weights_slice.device
    )
    l_out = torch.empty(
        (batch_size, num_heads, q_len),
        dtype=torch.float32,
        device=attn_weights_slice.device
    )

    grid = (q_len // reshaped_block_size, num_heads, batch_size)

    softmax_partial_stats_kernel[grid](
        attn_weights_slice,
        m_out,
        l_out,
        scale,
        attn_weights_slice.stride(0),
        attn_weights_slice.stride(1),
        attn_weights_slice.stride(2),
        m_out.stride(0),
        m_out.stride(1),
        k_len,
        chunk_start,
        kv_offset,
        segment_size,
        reshaped_block_size,
        is_causal,
    )

    return m_out, l_out

def merge_softmax_stats(
    m_chunks: list,
    l_chunks: list,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge partial softmax statistics from multiple KV chunks.

    Uses the online softmax merging formula:
        m_new = max(m1, m2)
        l_new = l1 * exp(m1 - m_new) + l2 * exp(m2 - m_new)

    Args:
        m_chunks: List of max tensors [batch, heads, q_len] from each chunk
        l_chunks: List of sum tensors [batch, heads, q_len] from each chunk

    Returns:
        Tuple of (m_global, l_global) with same shape as inputs
    """
    assert len(m_chunks) == len(l_chunks)
    assert len(m_chunks) > 0

    m_global = m_chunks[0].clone()
    l_global = l_chunks[0].clone()

    for i in range(1, len(m_chunks)):
        m_chunk = m_chunks[i]
        l_chunk = l_chunks[i]

        m_new = torch.maximum(m_global, m_chunk)
        # The kernels accumulate in base 2 (exp2), so merge with 2^(m - m_new)
        l_global = l_global * torch.pow(2.0, m_global - m_new) + l_chunk * torch.pow(2.0, m_chunk - m_new)
        m_global = m_new

    return m_global, l_global

def softmax_normalize_and_block_sum(
    attn_weights_slice: torch.Tensor,
    m_global: torch.Tensor,
    l_global: torch.Tensor,
    reshaped_block_size: int,
    segment_size: int,
    chunk_start: int,
    real_q_len: int,
    scale: float,
    kv_offset: int = 0,
    is_causal: bool = False,
) -> torch.Tensor:
    """
    Normalize with global stats and compute block sums for a KV chunk.

    This is the second step for KV-chunked softmax computation.
    Uses pre-computed global m and l (from `merge_softmax_stats()`)
    to correctly normalize softmax values and compute block sums.

    Args:
        attn_weights_slice: Raw attention scores [batch, heads, q_len, k_chunk_len]
        m_global: Global max values [batch, heads, q_len]
        l_global: Global sum values [batch, heads, q_len]
        reshaped_block_size: Block size in reshaped space
        segment_size: Processing segment size
        chunk_start: Start position for this chunk (for masking)
        real_q_len: Actual Q length (before padding)
        scale: Softmax scale factor
        kv_offset: KV chunk offset (in reshaped space, for causal masking)
        is_causal: Whether to apply causal masking

    Returns:
        Block-level attention sums [batch, heads, q_blocks, k_chunk_blocks]
    """
    batch_size, num_heads, q_len, k_len = attn_weights_slice.shape

    assert q_len % reshaped_block_size == 0
    assert k_len % segment_size == 0
    assert segment_size % reshaped_block_size == 0
    assert attn_weights_slice.stride(-1) == 1

    output = torch.empty(
        (batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size),
        dtype=attn_weights_slice.dtype,
        device=attn_weights_slice.device
    )

    grid = (q_len // reshaped_block_size, num_heads, batch_size)

    softmax_normalize_block_sum_kernel[grid](
        attn_weights_slice,
        output,
        m_global,
        l_global,
        scale,
        attn_weights_slice.stride(0),
        attn_weights_slice.stride(1),
        attn_weights_slice.stride(2),
        output.stride(0),
        output.stride(1),
        output.stride(2),
        m_global.stride(0),
        m_global.stride(1),
        real_q_len,
        k_len,
        chunk_start,
        kv_offset,
        segment_size,
        reshaped_block_size,
        is_causal,
    )

    return output

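The two-step pipeline (partial stats, merge, normalize with global stats, block sums) can be sanity-checked in plain Python with toy numbers; the shapes and values below are invented, and the base-2 exponent mirrors the kernels:

```python
row = [0.5, -1.0, 2.0, 0.25, 1.5, -0.5, 0.0, 1.0]   # one query row, 8 keys
chunks = [row[:4], row[4:]]                          # two KV chunks
block = 2                                            # block size for the sums

# step 1: partial stats per chunk, merged into global (m, l)
stats = [(max(c), sum(2.0 ** (x - max(c)) for x in c)) for c in chunks]
m = max(s[0] for s in stats)
l = sum(s[1] * 2.0 ** (s[0] - m) for s in stats)

# step 2: normalize every chunk with the *global* stats, then sum per block
probs = [2.0 ** (x - m) / l for c in chunks for x in c]
block_sums = [sum(probs[i:i + block]) for i in range(0, len(probs), block)]

# reference: full-row softmax (base 2), block-summed
full = [2.0 ** (x - max(row)) for x in row]
ref = [v / sum(full) for v in full]
ref_sums = [sum(ref[i:i + block]) for i in range(0, len(ref), block)]
assert all(abs(a - b) < 1e-12 for a, b in zip(block_sums, ref_sums))
assert abs(sum(block_sums) - 1.0) < 1e-12
```

The key point: each chunk is normalized with the merged global (m, l), not its own, so the per-block sums agree with a softmax over the whole row.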
def flat_group_gemm_fuse_reshape(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
327	nanovllm/utils/density_observer.py	Normal file
@@ -0,0 +1,327 @@
"""
|
||||||
|
DensityObserver - Sparse Attention Density 统计 Observer。
|
||||||
|
|
||||||
|
统计两种 density:
|
||||||
|
1. Compute Density (计算密度): 基于 BSA block size (128)
|
||||||
|
- density = selected_bsa_blocks / total_causal_bsa_blocks
|
||||||
|
- GPU-only 和 Offload 模式应该一致
|
||||||
|
|
||||||
|
2. Communication Density (通信密度): 基于 CPU block size (如 4096)
|
||||||
|
- comm_density = selected_cpu_blocks / total_cpu_blocks
|
||||||
|
- 仅用于 Offload 模式,由于粒度更粗,必然 >= compute density
|
||||||
|
|
||||||
|
统计位置:
|
||||||
|
- GPU-only: xattn_bsa.py compute_prefill() - 只记录 compute density
|
||||||
|
- Offload: xattn_bsa.py select_blocks() - 记录两种 density
|
||||||
|
|
||||||
|
对于 Offload 模式的 Density 计算:
|
||||||
|
- 不是简单的 avg 或 min
|
||||||
|
- 而是 sum(selected) / sum(total),正确处理不同 chunk 大小的权重
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Dict, Optional, Tuple
|
||||||
|
import torch
|
||||||
|
from nanovllm.utils.observer import Observer
|
||||||
|
|
||||||
|
|
||||||
|
class DensityObserver(Observer):
    """
    Sparse attention density observer.

    Records per-layer density, used to verify consistency between
    GPU-only and offload modes.

    Usage:
        DensityObserver.enable()
        DensityObserver.complete_reset()
        # ... run inference ...
        DensityObserver.record(layer_id, mask, causal=True)
        # or, in accumulating mode (offload):
        DensityObserver.record_counts(layer_id, selected, total)
        # ...
        DensityObserver.print_summary()
    """

    _enabled: bool = False  # disabled by default

    # Per-layer compute density records (BSA block granularity).
    # key: layer_id, value: list of density values (one per prefill chunk)
    _layer_densities: Dict[int, List[float]] = {}

    # Per-layer communication density records (CPU block granularity, offload mode only)
    _layer_comm_densities: Dict[int, List[float]] = {}

    # Accumulating mode: record selected/total counts (for offload mode).
    # This allows computing density = sum(selected) / sum(total) correctly
    # once all chunks have been processed.
    _layer_selected_counts: Dict[int, List[int]] = {}
    _layer_total_counts: Dict[int, List[int]] = {}

    # Mask shape records (for debugging)
    _last_q_blocks: int = 0
    _last_k_blocks: int = 0

    # Mode flag
    _mode: str = "unknown"  # "gpu_only" or "offload"

    @classmethod
    def set_mode(cls, mode: str) -> None:
        """Set the current mode (gpu_only / offload)."""
        cls._mode = mode

    @classmethod
    def record(
        cls,
        layer_id: int,
        mask: torch.Tensor,
        causal: bool = True,
    ) -> float:
        """
        Record one layer's density (GPU-only mode).

        Args:
            layer_id: layer ID
            mask: [batch, heads, q_blocks, k_blocks] boolean tensor
            causal: whether to respect the causal mask (count only the lower triangle)

        Returns:
            the density value
        """
        if not cls._enabled:
            return 0.0

        density = cls._compute_density(mask, causal)

        # Record
        if layer_id not in cls._layer_densities:
            cls._layer_densities[layer_id] = []
        cls._layer_densities[layer_id].append(density)

        # Record the mask shape
        cls._last_q_blocks = mask.shape[2]
        cls._last_k_blocks = mask.shape[3]

        return density

    @classmethod
    def record_counts(
        cls,
        layer_id: int,
        selected_blocks: int,
        total_blocks: int,
    ) -> None:
        """
        Record one layer's selected/total block counts (offload accumulating mode).

        Accumulates counts instead of computing a density directly, so that
        after all chunks are processed the overall density can be computed as:
            overall_density = sum(selected) / sum(total)

        This is more accurate than avg(density), because chunks differ in
        Q and K length.

        Args:
            layer_id: layer ID
            selected_blocks: number of blocks selected for this chunk
            total_blocks: total number of possible blocks for this chunk
        """
        if not cls._enabled:
            return

        # Initialize lists
        if layer_id not in cls._layer_selected_counts:
            cls._layer_selected_counts[layer_id] = []
        if layer_id not in cls._layer_total_counts:
            cls._layer_total_counts[layer_id] = []

        # Accumulate
        cls._layer_selected_counts[layer_id].append(selected_blocks)
        cls._layer_total_counts[layer_id].append(total_blocks)

    @classmethod
    def record_comm_density(
        cls,
        layer_id: int,
        selected_cpu_blocks: int,
        total_cpu_blocks: int,
    ) -> float:
        """
        Record one layer's communication density (CPU block granularity).

        Args:
            layer_id: layer ID
            selected_cpu_blocks: number of CPU blocks selected
            total_cpu_blocks: total number of CPU blocks

        Returns:
            the communication density value
        """
        if not cls._enabled:
            return 0.0

        if total_cpu_blocks == 0:
            return 1.0

        comm_density = selected_cpu_blocks / total_cpu_blocks

        # Record
        if layer_id not in cls._layer_comm_densities:
            cls._layer_comm_densities[layer_id] = []
        cls._layer_comm_densities[layer_id].append(comm_density)

        return comm_density

    @classmethod
    def _compute_density(cls, mask: torch.Tensor, causal: bool) -> float:
        """Compute the density of a mask."""
        batch, heads, q_blocks, k_blocks = mask.shape

        if causal:
            # Count only the lower-triangular region
            causal_mask = torch.tril(
                torch.ones(q_blocks, k_blocks, device=mask.device, dtype=torch.bool)
            )
            total_blocks = causal_mask.sum().item() * batch * heads
            selected_blocks = (mask & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
        else:
            total_blocks = mask.numel()
            selected_blocks = mask.sum().item()

        if total_blocks == 0:
            return 1.0

        return selected_blocks / total_blocks

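The density computation above, restricted to the causal (lower-triangular) region, can be rendered in plain Python with a toy mask as nested lists; the mask values and shapes here are made up for illustration:

```python
def compute_density(mask, causal=True):
    """mask: [q_blocks][k_blocks] booleans; fraction of selected blocks
    within the causal (lower-triangular) region if causal=True."""
    q_blocks, k_blocks = len(mask), len(mask[0])
    selected = total = 0
    for i in range(q_blocks):
        for j in range(k_blocks):
            if causal and j > i:
                continue                  # upper triangle never counts
            total += 1
            selected += bool(mask[i][j])
    return selected / total if total else 1.0

mask = [[True,  False, False],
        [True,  True,  False],
        [False, True,  True]]
assert compute_density(mask, causal=True) == 5 / 6    # 5 of 6 causal blocks selected
assert abs(compute_density(mask, causal=False) - 5 / 9) < 1e-12
```

Note how the same mask yields a higher causal density: upper-triangular positions are excluded from the denominator because they can never be attended.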
    @classmethod
    def complete_reset(cls) -> None:
        """Reset all statistics."""
        cls._layer_densities = {}
        cls._layer_comm_densities = {}
        cls._layer_selected_counts = {}
        cls._layer_total_counts = {}
        cls._last_q_blocks = 0
        cls._last_k_blocks = 0
        cls._mode = "unknown"

    @classmethod
    def get_per_layer_density(cls) -> Dict[int, float]:
        """
        Get the density of each layer.

        Accumulating mode (offload): density = sum(selected) / sum(total)
        Direct-record mode (gpu_only): density = avg(density_values)
        """
        result = {}

        # Prefer the accumulating mode (offload)
        if cls._layer_selected_counts:
            for layer_id in cls._layer_selected_counts:
                selected_list = cls._layer_selected_counts.get(layer_id, [])
                total_list = cls._layer_total_counts.get(layer_id, [])
                total_selected = sum(selected_list)
                total_total = sum(total_list)
                if total_total > 0:
                    result[layer_id] = total_selected / total_total
        else:
            # Direct-record mode (gpu_only)
            for layer_id, densities in cls._layer_densities.items():
                if densities:
                    result[layer_id] = sum(densities) / len(densities)

        return result

    @classmethod
    def get_overall_density(cls) -> float:
        """
        Get the overall compute density across all layers.

        Accumulating mode (offload): density = sum(all_selected) / sum(all_total)
        Direct-record mode (gpu_only): density = avg(all_density_values)

        Note: the overall density is not simply avg(per_layer_density) but
        sum(all_selected) / sum(all_total), which weights the counts correctly.
        """
        # Prefer the accumulating mode (offload)
        if cls._layer_selected_counts:
            total_selected = 0
            total_total = 0
            for layer_id in cls._layer_selected_counts:
                total_selected += sum(cls._layer_selected_counts[layer_id])
                total_total += sum(cls._layer_total_counts.get(layer_id, []))
            if total_total > 0:
                return total_selected / total_total
            return 0.0

        # Direct-record mode (gpu_only)
        all_densities = []
        for densities in cls._layer_densities.values():
            all_densities.extend(densities)
        if not all_densities:
            return 0.0
        return sum(all_densities) / len(all_densities)

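Why the accumulating mode divides summed counts instead of averaging per-chunk ratios: a toy example with two chunks of very different size (the counts are invented):

```python
# chunk 1: tiny but dense; chunk 2: large but sparse
selected = [10, 100]
total    = [10, 1000]

avg_of_ratios = sum(s / t for s, t in zip(selected, total)) / len(total)
weighted      = sum(selected) / sum(total)

assert abs(avg_of_ratios - 0.55) < 1e-12   # (1.0 + 0.1) / 2 -- overweights the tiny chunk
assert weighted == 110 / 1010              # the true fraction of selected blocks
```

The unweighted average reports roughly five times more density than the true block fraction here, which is exactly the distortion sum(selected) / sum(total) avoids.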
    @classmethod
    def get_overall_comm_density(cls) -> float:
        """Get the average communication density across all layers."""
        all_densities = []
        for densities in cls._layer_comm_densities.values():
            all_densities.extend(densities)
        if not all_densities:
            return 0.0
        return sum(all_densities) / len(all_densities)

    @classmethod
    def get_per_layer_comm_density(cls) -> Dict[int, float]:
        """
        Get each layer's communication density (CPU block granularity).

        Returns:
            Dict[layer_id, avg_comm_density]
        """
        result = {}
        for layer_id, densities in cls._layer_comm_densities.items():
            if densities:
                result[layer_id] = sum(densities) / len(densities)
        return result

    @classmethod
    def get_summary(cls) -> dict:
        """Return a summary of the statistics."""
        per_layer = cls.get_per_layer_density()
        per_layer_comm = cls.get_per_layer_comm_density()
        return {
            "mode": cls._mode,
            "overall_compute_density": cls.get_overall_density(),
            "overall_comm_density": cls.get_overall_comm_density(),
            "per_layer_compute_density": per_layer,
            "per_layer_comm_density": per_layer_comm,
            "num_layers": len(per_layer),
            "last_mask_shape": {
                "q_blocks": cls._last_q_blocks,
                "k_blocks": cls._last_k_blocks,
            },
        }

    @classmethod
    def get_min_density(cls) -> Tuple[int, float]:
        """Get the layer with the lowest density and its value."""
        per_layer = cls.get_per_layer_density()
        if not per_layer:
            return -1, 0.0
        min_layer = min(per_layer, key=per_layer.get)
        return min_layer, per_layer[min_layer]

    @classmethod
    def print_summary(cls) -> None:
        """Print a human-readable summary."""
        per_layer = cls.get_per_layer_density()
        overall = cls.get_overall_density()
        min_layer, min_density = cls.get_min_density()
        overall_comm = cls.get_overall_comm_density()

        print(f"[DensityObserver] Mode: {cls._mode}")
        print(f"  Compute density: {overall:.4f} (min: {min_density:.4f} @ layer {min_layer})")
        if overall_comm > 0:
            # Offload mode: show both densities with explanation
            print(f"  Comm density: {overall_comm:.4f} (CPU block granularity)")
            print(f"  Savings ratio: {1 - overall_comm:.1%} H2D transfer reduction")
        print(f"  Num layers: {len(per_layer)}")
        # Print layer 0's density for comparison
        if 0 in per_layer:
            print(f"  Layer 0 density: {per_layer[0]:.6f}")
@@ -1,314 +0,0 @@
"""
Benchmark: block_size impact on XAttention estimate phase performance.

This script tests how different block_size values affect the performance of:
1. flat_group_gemm_fuse_reshape (estimate GEMM)
2. softmax_fuse_block_sum (estimate softmax + block aggregation)

Key insight: the current select_blocks uses the global kvcache_block_size for
estimation, which may not be optimal for the Triton kernels.
"""

import sys
import os
import torch
import time
import math

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum

# ============================================================
# Configuration
# ============================================================

# Test configurations
BLOCK_SIZES = [64, 128, 256, 512]  # BSA optimal is 128
STRIDE = 8
NUM_WARMUP = 3
NUM_RUNS = 10

# Model dimensions (Llama-3.1-8B-Instruct)
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128

# Context lengths to test
CONTEXT_LENGTHS = [16384, 32768, 65536]  # 16K, 32K, 64K

# ============================================================
# Benchmark Functions
# ============================================================

def benchmark_flat_group_gemm(Q, K, stride, block_size, num_warmup=3, num_runs=10):
    """
    Benchmark the flat_group_gemm_fuse_reshape kernel.

    Args:
        Q: [batch, heads, q_len, head_dim]
        K: [batch, heads, k_len, head_dim]
        stride: Stride for reshape
        block_size: Block size (affects alignment requirements)

    Returns:
        (avg_time_ms, output_tensor)
    """
    q_len = Q.shape[2]
    k_len = K.shape[2]

    # Compute reshaped dimensions
    reshaped_q_len = q_len // stride
    reshaped_k_len = k_len // stride
    reshaped_block_size = block_size // stride

    # Warmup
    for _ in range(num_warmup):
        _ = flat_group_gemm_fuse_reshape(
            Q, K, stride,
            chunk_start=0,
            chunk_end=reshaped_q_len,
            is_causal=False,
        )
    torch.cuda.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(num_runs):
        output = flat_group_gemm_fuse_reshape(
            Q, K, stride,
            chunk_start=0,
            chunk_end=reshaped_q_len,
            is_causal=False,
        )
    torch.cuda.synchronize()
    end = time.perf_counter()

    avg_time_ms = (end - start) / num_runs * 1000
    return avg_time_ms, output


def benchmark_softmax_fuse_block_sum(attn_weights, reshaped_block_size, num_warmup=3, num_runs=10):
    """
    Benchmark the softmax_fuse_block_sum kernel.

    Args:
        attn_weights: [batch, heads, q_len, k_len] attention weights
        reshaped_block_size: Block size in reshaped space

    Returns:
        avg_time_ms
    """
    batch_size, num_heads, q_len, k_len = attn_weights.shape
    head_dim = HEAD_DIM
    stride = STRIDE
    norm = 1.0

    # segment_size must divide k_len and be >= reshaped_block_size
    segment_size = min(4096, reshaped_block_size)

    # Ensure k_len is divisible by segment_size
    if k_len % segment_size != 0:
        # Pad k_len
        pad_size = segment_size - (k_len % segment_size)
        attn_weights = torch.nn.functional.pad(attn_weights, (0, pad_size), value=0)
        k_len = attn_weights.shape[3]

    # Scale factor
    scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm

    # Warmup
    for _ in range(num_warmup):
        _ = softmax_fuse_block_sum(
            attn_weights,
            reshaped_block_size,
            segment_size,
            chunk_start=0,
            chunk_end=q_len,
            real_q_len=q_len,
            scale=scale,
            is_causal=False,
        )
    torch.cuda.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(num_runs):
        output = softmax_fuse_block_sum(
            attn_weights,
            reshaped_block_size,
            segment_size,
            chunk_start=0,
            chunk_end=q_len,
            real_q_len=q_len,
            scale=scale,
            is_causal=False,
        )
    torch.cuda.synchronize()
    end = time.perf_counter()

    avg_time_ms = (end - start) / num_runs * 1000
    return avg_time_ms


def run_estimate_benchmark(q_len, k_len, block_size, stride=STRIDE):
    """
    Run the full estimate benchmark for a given configuration.

    Args:
        q_len: Query length
        k_len: Key length (usually same as q_len for the current-chunk scenario)
        block_size: Block size to test
        stride: Stride for reshape

    Returns:
        dict with timing results
    """
    # Create random Q and K tensors
    # Shape: [batch, heads, seq_len, head_dim]
    Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
    K = torch.randn(1, NUM_HEADS, k_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda")

    reshaped_block_size = block_size // stride
    reshaped_q_len = q_len // stride
    reshaped_k_len = k_len // stride

    # Benchmark GEMM
    gemm_time, attn_weights = benchmark_flat_group_gemm(
        Q, K, stride, block_size,
        num_warmup=NUM_WARMUP, num_runs=NUM_RUNS
    )

    # Benchmark softmax + block sum
    softmax_time = benchmark_softmax_fuse_block_sum(
        attn_weights, reshaped_block_size,
        num_warmup=NUM_WARMUP, num_runs=NUM_RUNS
    )

    # Clean up
    del Q, K, attn_weights
    torch.cuda.empty_cache()

    return {
        "q_len": q_len,
        "k_len": k_len,
        "block_size": block_size,
        "reshaped_block_size": reshaped_block_size,
        "gemm_time_ms": gemm_time,
        "softmax_time_ms": softmax_time,
        "total_time_ms": gemm_time + softmax_time,
    }


# ============================================================
# Main Benchmark
# ============================================================

def main():
    import argparse
    parser = argparse.ArgumentParser(description="Benchmark block_size impact on estimate phase")
    parser.add_argument("--gpu", type=int, default=0, help="GPU to use")
    parser.add_argument("--ctx-len", type=int, default=None,
                        help="Single context length to test (default: test multiple)")
    args = parser.parse_args()

    # Set GPU
    torch.cuda.set_device(args.gpu)
    device_name = torch.cuda.get_device_name(args.gpu)
    print(f"Using GPU {args.gpu}: {device_name}")
    print()

    # Determine context lengths to test
    if args.ctx_len:
        context_lengths = [args.ctx_len]
    else:
        context_lengths = CONTEXT_LENGTHS

    print("=" * 80)
    print("Benchmark: block_size impact on XAttention estimate phase")
    print("=" * 80)
    print(f"Configuration:")
    print(f"  NUM_HEADS: {NUM_HEADS}")
    print(f"  NUM_KV_HEADS: {NUM_KV_HEADS}")
    print(f"  HEAD_DIM: {HEAD_DIM}")
    print(f"  STRIDE: {STRIDE}")
    print(f"  BLOCK_SIZES: {BLOCK_SIZES}")
    print(f"  NUM_WARMUP: {NUM_WARMUP}")
    print(f"  NUM_RUNS: {NUM_RUNS}")
    print()

    all_results = []
|
|
||||||
|
|
||||||
for ctx_len in context_lengths:
|
|
||||||
print(f"\n{'='*80}")
|
|
||||||
print(f"Context Length: {ctx_len // 1024}K ({ctx_len} tokens)")
|
|
||||||
print(f"{'='*80}")
|
|
||||||
|
|
||||||
# Pad to alignment
|
|
||||||
alignment = STRIDE * 128 # Triton BLOCK_M requirement
|
|
||||||
padded_len = ((ctx_len + alignment - 1) // alignment) * alignment
|
|
||||||
print(f"Padded to: {padded_len} tokens (alignment={alignment})")
|
|
||||||
print()
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for block_size in BLOCK_SIZES:
|
|
||||||
print(f"Testing block_size={block_size} (reshaped={block_size // STRIDE})...", end=" ")
|
|
||||||
try:
|
|
||||||
result = run_estimate_benchmark(padded_len, padded_len, block_size)
|
|
||||||
results.append(result)
|
|
||||||
print(f"GEMM={result['gemm_time_ms']:.2f}ms, "
|
|
||||||
f"Softmax={result['softmax_time_ms']:.2f}ms, "
|
|
||||||
f"Total={result['total_time_ms']:.2f}ms")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
if results:
|
|
||||||
all_results.extend(results)
|
|
||||||
|
|
||||||
# Print summary table for this context length
|
|
||||||
print(f"\n--- Summary for {ctx_len // 1024}K context ---")
|
|
||||||
print(f"{'block_size':>12} {'reshaped':>10} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12} {'Speedup':>10}")
|
|
||||||
print("-" * 74)
|
|
||||||
|
|
||||||
baseline_total = results[0]["total_time_ms"]
|
|
||||||
for r in results:
|
|
||||||
speedup = baseline_total / r["total_time_ms"]
|
|
||||||
print(f"{r['block_size']:>12} {r['reshaped_block_size']:>10} "
|
|
||||||
f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} "
|
|
||||||
f"{r['total_time_ms']:>12.2f} {speedup:>9.2f}x")
|
|
||||||
|
|
||||||
# Final summary across all context lengths
|
|
||||||
if len(context_lengths) > 1:
|
|
||||||
print(f"\n{'='*80}")
|
|
||||||
print("OVERALL SUMMARY")
|
|
||||||
print(f"{'='*80}")
|
|
||||||
print(f"{'ctx_len':>10} {'block_size':>12} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12}")
|
|
||||||
print("-" * 64)
|
|
||||||
for r in all_results:
|
|
||||||
print(f"{r['q_len']//1024:>9}K {r['block_size']:>12} "
|
|
||||||
f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} "
|
|
||||||
f"{r['total_time_ms']:>12.2f}")
|
|
||||||
|
|
||||||
# Find optimal block_size for softmax
|
|
||||||
print(f"\n{'='*80}")
|
|
||||||
print("ANALYSIS: Optimal block_size for softmax_fuse_block_sum")
|
|
||||||
print(f"{'='*80}")
|
|
||||||
|
|
||||||
for ctx_len in context_lengths:
|
|
||||||
ctx_results = [r for r in all_results if r["q_len"] == ((ctx_len + STRIDE * 128 - 1) // (STRIDE * 128)) * (STRIDE * 128)]
|
|
||||||
if ctx_results:
|
|
||||||
best = min(ctx_results, key=lambda x: x["softmax_time_ms"])
|
|
||||||
worst = max(ctx_results, key=lambda x: x["softmax_time_ms"])
|
|
||||||
improvement = worst["softmax_time_ms"] / best["softmax_time_ms"]
|
|
||||||
print(f"Context {ctx_len // 1024}K:")
|
|
||||||
print(f" Best: block_size={best['block_size']} ({best['softmax_time_ms']:.2f}ms)")
|
|
||||||
print(f" Worst: block_size={worst['block_size']} ({worst['softmax_time_ms']:.2f}ms)")
|
|
||||||
print(f" Potential improvement: {improvement:.2f}x")
|
|
||||||
|
|
||||||
print("\nbench_estimate_block_size: DONE")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
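The padding step in `main()` rounds each context length up to a multiple of `STRIDE * 128` (the Triton BLOCK_M requirement noted in the comment) before benchmarking. A minimal pure-Python sketch of that ceil-to-multiple rounding; the concrete stride and lengths here are illustrative, not taken from the benchmark's config:

```python
def pad_to_multiple(n: int, alignment: int) -> int:
    # Round n up to the nearest multiple of alignment (integer ceil division).
    return ((n + alignment - 1) // alignment) * alignment

stride = 16
alignment = stride * 128  # 2048, mirroring the STRIDE * 128 alignment in main()

print(pad_to_multiple(131072, alignment))  # 128K is already aligned -> 131072
print(pad_to_multiple(100000, alignment))  # rounds up -> 100352
```

Already-aligned lengths pass through unchanged, which is why power-of-two context lengths report `Padded to:` the same token count.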
@@ -1,757 +0,0 @@
"""
Custom Qwen3 implementation using only torch and transformers.
This file provides a clean reference implementation for understanding the model computation graph.

Computation Graph:
==================

Input: token_ids [batch, seq_len]
             │
             ▼
      ┌─────────────┐
      │  Embedding  │  embed_tokens: [vocab_size, hidden_size]
      └─────────────┘
             │
             ▼
hidden_states [batch, seq_len, hidden_size]
             │
             ▼
┌─────────────────────────────────────────────────────────┐
│                   Decoder Layer (x N)                   │
│  ┌───────────────────────────────────────────────────┐  │
│  │               Self Attention Block                │  │
│  │                                                   │  │
│  │  input_layernorm (RMSNorm)                        │  │
│  │        │                                          │  │
│  │        ▼                                          │  │
│  │  ┌─────────────────────────────────────────────┐  │  │
│  │  │ Qwen3Attention                              │  │  │
│  │  │   Q = q_proj(x) → q_norm → reshape          │  │  │
│  │  │   K = k_proj(x) → k_norm → reshape          │  │  │
│  │  │   V = v_proj(x) → reshape                   │  │  │
│  │  │        │                                    │  │  │
│  │  │        ▼                                    │  │  │
│  │  │  Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│  │  │
│  │  │        │                                    │  │  │
│  │  │        ▼                                    │  │  │
│  │  │  attn_output = attention(Q, K, V)           │  │  │
│  │  │        │                                    │  │  │
│  │  │        ▼                                    │  │  │
│  │  │  output = o_proj(attn_output)               │  │  │
│  │  └─────────────────────────────────────────────┘  │  │
│  │        │                                          │  │
│  │        ▼                                          │  │
│  │  hidden_states = residual + attn_output           │  │
│  └───────────────────────────────────────────────────┘  │
│                            │                            │
│                            ▼                            │
│  ┌───────────────────────────────────────────────────┐  │
│  │                     MLP Block                     │  │
│  │                                                   │  │
│  │  post_attention_layernorm (RMSNorm)               │  │
│  │        │                                          │  │
│  │        ▼                                          │  │
│  │  ┌─────────────────────────────────────────────┐  │  │
│  │  │ Qwen3MLP                                    │  │  │
│  │  │   gate = gate_proj(x)                       │  │  │
│  │  │   up = up_proj(x)                           │  │  │
│  │  │   output = down_proj(silu(gate) * up)       │  │  │
│  │  └─────────────────────────────────────────────┘  │  │
│  │        │                                          │  │
│  │        ▼                                          │  │
│  │  hidden_states = residual + mlp_output            │  │
│  └───────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────┘
             │
             ▼
      ┌─────────────┐
      │    norm     │  final RMSNorm
      └─────────────┘
             │
             ▼
      ┌─────────────┐
      │   lm_head   │  [hidden_size, vocab_size]
      └─────────────┘
             │
             ▼
logits [batch, seq_len, vocab_size]
"""

import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F


class Qwen3RMSNorm(nn.Module):
    """RMSNorm implementation."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)

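`Qwen3RMSNorm` divides by the root mean square of the last dimension instead of subtracting a mean as LayerNorm does. A torch-free numeric sketch of the same formula, with the learnable weight implicitly 1.0 and illustrative inputs:

```python
import math

def rms_norm(xs, eps=1e-6):
    # x / sqrt(mean(x^2) + eps), elementwise; weight is implicitly 1.0 here.
    variance = sum(v * v for v in xs) / len(xs)
    inv = 1.0 / math.sqrt(variance + eps)
    return [v * inv for v in xs]

out = rms_norm([3.0, 4.0])  # mean(x^2) = 12.5, rms ≈ 3.5355
print([round(v, 4) for v in out])
```

Note the mean of the input is not removed, only its scale, which is what makes RMSNorm cheaper than LayerNorm while keeping activations bounded.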
class Qwen3RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE)."""

    def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Compute inverse frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch, seq_len, num_heads, head_dim] or similar
            position_ids: Position indices [batch, seq_len]

        Returns:
            cos, sin: [batch, seq_len, head_dim]
        """
        # inv_freq: [dim/2]
        # position_ids: [batch, seq_len]
        inv_freq_expanded = self.inv_freq[None, :, None].float()  # [1, dim/2, 1]
        position_ids_expanded = position_ids[:, None, :].float()  # [batch, 1, seq_len]

        # freqs: [batch, dim/2, seq_len]
        freqs = inv_freq_expanded @ position_ids_expanded
        # freqs: [batch, seq_len, dim/2]
        freqs = freqs.transpose(1, 2)

        # Duplicate for full head_dim: [batch, seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)

        cos = emb.cos().to(x.dtype)
        sin = emb.sin().to(x.dtype)

        return cos, sin


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to Q and K.

    Args:
        q: [batch, num_heads, seq_len, head_dim]
        k: [batch, num_kv_heads, seq_len, head_dim]
        cos: [batch, seq_len, head_dim]
        sin: [batch, seq_len, head_dim]

    Returns:
        q_embed, k_embed with same shapes as inputs
    """
    # Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

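Per coordinate pair `(x1, x2)` split by `rotate_half`, `apply_rotary_pos_emb` is a plain 2-D rotation: `x1·cosθ − x2·sinθ` and `x2·cosθ + x1·sinθ`. A torch-free sketch on a single 4-dim vector; it simplifies real RoPE by using one angle for all pairs (RoPE uses a different frequency per pair), and the angle is illustrative:

```python
import math

def rotate_half(xs):
    # (-x2, x1): x1 is the first half, x2 the second half of the vector.
    half = len(xs) // 2
    x1, x2 = xs[:half], xs[half:]
    return [-v for v in x2] + x1

def apply_rope(xs, theta):
    # x*cos(theta) + rotate_half(x)*sin(theta), one shared angle (simplified).
    cos, sin = math.cos(theta), math.sin(theta)
    rot = rotate_half(xs)
    return [v * cos + r * sin for v, r in zip(xs, rot)]

vec = [1.0, 0.0, 0.0, 0.0]
out = apply_rope(vec, math.pi / 2)
print([round(v, 6) for v in out])  # the (vec[0], vec[2]) pair rotated by 90°
```

Because each pair undergoes a rotation, the vector norm (and hence attention logit scale) is preserved for any angle.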
class Qwen3Attention(nn.Module):
    """
    Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.

    Data Flow:
    ----------
    hidden_states [batch, seq_len, hidden_size]
        │
        ├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
        ├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
        └──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
                │
                ▼
        apply_rotary_pos_emb(Q, K)
                │
                ▼
        attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim]
                │
                ▼
        reshape ──► o_proj ──► output [batch, seq_len, hidden_size]
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        max_position_embeddings: int = 32768,
        rope_theta: float = 10000.0,
        attention_bias: bool = False,
        rms_norm_eps: float = 1e-6,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.num_kv_heads = num_key_value_heads
        self.head_dim = head_dim
        self.num_kv_groups = num_attention_heads // num_key_value_heads
        self.layer_idx = layer_idx

        # Scaling factor
        self.scaling = head_dim ** -0.5

        # QKV projections
        self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)

        # QK normalization (Qwen3 specific)
        self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)

        # Rotary embeddings
        self.rotary_emb = Qwen3RotaryEmbedding(
            head_dim,
            max_position_embeddings=max_position_embeddings,
            base=rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_qkv: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            position_ids: [batch, seq_len]
            attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask)
            past_key_value: (k_cache, v_cache) from previous steps
            use_cache: Whether to return updated cache
            output_qkv: Whether to output Q, K, V tensors for debugging

        Returns:
            output: [batch, seq_len, hidden_size]
            past_key_value: Updated cache (if use_cache=True)
            qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
        """
        batch_size, seq_len, _ = hidden_states.shape

        # === QKV Projections ===
        q = self.q_proj(hidden_states)  # [batch, seq_len, num_heads * head_dim]
        k = self.k_proj(hidden_states)  # [batch, seq_len, num_kv_heads * head_dim]
        v = self.v_proj(hidden_states)  # [batch, seq_len, num_kv_heads * head_dim]

        # Reshape to [batch, seq_len, num_heads, head_dim]
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)

        # === QK Normalization (Qwen3 specific) ===
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Transpose to [batch, num_heads, seq_len, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # === Rotary Position Embeddings ===
        cos, sin = self.rotary_emb(v, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # === KV Cache Update ===
        if past_key_value is not None:
            k_cache, v_cache = past_key_value
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)

        new_past_key_value = (k, v) if use_cache else None

        # === Grouped Query Attention (expand KV heads if needed) ===
        if self.num_kv_groups > 1:
            # Repeat KV for each query group
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        # === Attention Computation (using SDPA for memory efficiency) ===
        # Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
        # is_causal only works when q_len == kv_len (prefill), not during decode
        q_len, kv_len = q.shape[2], k.shape[2]
        is_causal = (q_len == kv_len) and (q_len > 1)

        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=is_causal,
            scale=self.scaling,
        )  # [batch, num_heads, seq_len, head_dim]

        # === Output Projection ===
        # Transpose back and reshape
        attn_output = attn_output.transpose(1, 2).contiguous()  # [batch, seq_len, num_heads, head_dim]
        attn_output = attn_output.view(batch_size, seq_len, -1)  # [batch, seq_len, hidden_size]
        output = self.o_proj(attn_output)

        # Optional QKV output for debugging
        qkv_dict = None
        if output_qkv:
            qkv_dict = {
                "q": q,  # [batch, num_heads, seq_len, head_dim] (post-RoPE)
                "k": k,  # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
                "v": v,  # [batch, num_heads, kv_seq_len, head_dim] (expanded)
            }

        return output, new_past_key_value, qkv_dict


class Qwen3MLP(nn.Module):
    """
    Qwen3 MLP with SwiGLU activation.

    Data Flow:
    ----------
    hidden_states [batch, seq_len, hidden_size]
        │
        ├──► gate_proj ──► gate [batch, seq_len, intermediate_size]
        │
        └──► up_proj ──► up [batch, seq_len, intermediate_size]
                │
                ▼
        silu(gate) * up
                │
                ▼
        down_proj ──► output [batch, seq_len, hidden_size]
    """

    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        return self.down_proj(F.silu(gate) * up)

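`Qwen3MLP` gates the `up` branch with `silu(gate)`, i.e. the SwiGLU pattern. A torch-free scalar sketch of that activation; the input values are illustrative:

```python
import math

def silu(x):
    # silu(x) = x * sigmoid(x) = x / (1 + e^(-x))
    return x / (1.0 + math.exp(-x))

def swiglu(gate, up):
    # The elementwise gating used before down_proj.
    return silu(gate) * up

print(round(silu(0.0), 6))        # silu is zero at zero
print(round(swiglu(1.0, 2.0), 6))
```

Unlike ReLU, `silu` is smooth and slightly negative for small negative inputs, and the multiplicative `up` branch lets the network scale each channel independently of the gate's nonlinearity.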
class Qwen3DecoderLayer(nn.Module):
    """Single Qwen3 Decoder Layer."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        max_position_embeddings: int = 32768,
        rope_theta: float = 10000.0,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        mlp_bias: bool = False,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.layer_idx = layer_idx

        # Pre-attention LayerNorm
        self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        # Self-attention
        self.self_attn = Qwen3Attention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            head_dim=head_dim,
            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
            attention_bias=attention_bias,
            rms_norm_eps=rms_norm_eps,
            layer_idx=layer_idx,
        )

        # Post-attention LayerNorm
        self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        # MLP
        self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_qkv: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            position_ids: [batch, seq_len]
            attention_mask: Causal attention mask
            past_key_value: KV cache for this layer
            use_cache: Whether to return updated cache
            output_qkv: Whether to output Q, K, V for debugging

        Returns:
            hidden_states: [batch, seq_len, hidden_size]
            past_key_value: Updated cache
            qkv_dict: QKV tensors (if output_qkv=True)
        """
        # === Self Attention Block ===
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attn_output, new_past_key_value, qkv_dict = self.self_attn(
            hidden_states=hidden_states,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_qkv=output_qkv,
        )

        hidden_states = residual + attn_output

        # === MLP Block ===
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, new_past_key_value, qkv_dict


class Qwen3Model(nn.Module):
    """Qwen3 Transformer Model (without LM head)."""

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        intermediate_size: int,
        num_hidden_layers: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        max_position_embeddings: int = 32768,
        rope_theta: float = 10000.0,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        mlp_bias: bool = False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_hidden_layers = num_hidden_layers

        # Token embeddings
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

        # Decoder layers
        self.layers = nn.ModuleList([
            Qwen3DecoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                head_dim=head_dim,
                max_position_embeddings=max_position_embeddings,
                rope_theta=rope_theta,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                mlp_bias=mlp_bias,
                layer_idx=i,
            )
            for i in range(num_hidden_layers)
        ])

        # Final LayerNorm
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        output_qkv_layers: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
        """
        Args:
            input_ids: [batch, seq_len]
            position_ids: [batch, seq_len]
            attention_mask: [batch, seq_len] or pre-computed 4D mask
            past_key_values: List of (k, v) tuples for each layer
            use_cache: Whether to return new cache
            output_qkv_layers: List of layer indices to output QKV for

        Returns:
            hidden_states: [batch, seq_len, hidden_size]
            new_past_key_values: Updated cache
            qkv_outputs: {layer_idx: qkv_dict}
        """
        batch_size, seq_len = input_ids.shape

        # Embedding
        hidden_states = self.embed_tokens(input_ids)

        # Position IDs
        if position_ids is None:
            past_len = past_key_values[0][0].shape[2] if past_key_values else 0
            position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Attention mask (create causal mask if not provided)
        if attention_mask is None or attention_mask.dim() == 2:
            kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
            causal_mask = torch.triu(
                torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
                diagonal=kv_seq_len - seq_len + 1,
            )
            attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, kv_seq_len]

        # Initialize cache list
        new_past_key_values = [] if use_cache else None
        qkv_outputs = {} if output_qkv_layers else None

        # Decoder layers
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values else None
            output_qkv = output_qkv_layers is not None and i in output_qkv_layers

            hidden_states, new_kv, qkv_dict = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_value=past_kv,
                use_cache=use_cache,
                output_qkv=output_qkv,
            )

            if use_cache:
                new_past_key_values.append(new_kv)
            if qkv_dict is not None:
                qkv_outputs[i] = qkv_dict

        # Final norm
        hidden_states = self.norm(hidden_states)

        return hidden_states, new_past_key_values, qkv_outputs

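`Qwen3Model.forward` builds its causal mask with `torch.triu(..., diagonal=kv_seq_len - seq_len + 1)`, so that during decode (one query over many cached keys) nothing is masked, while during prefill the strict upper triangle is. A torch-free sketch of that mask shape, with 0.0 meaning "attend" and `-inf` meaning "masked"; the sizes are illustrative:

```python
def causal_mask(seq_len, kv_seq_len):
    neg_inf = float("-inf")
    diagonal = kv_seq_len - seq_len + 1
    # Entry (i, j) is masked when j >= i + diagonal, mirroring torch.triu's
    # "keep elements on and above the given diagonal" rule.
    return [[neg_inf if j >= i + diagonal else 0.0 for j in range(kv_seq_len)]
            for i in range(seq_len)]

# Decode step: 1 new query over 4 cached + 1 new key -> nothing masked.
print(causal_mask(1, 5))
# Prefill: 3 queries over 3 keys -> strict upper triangle masked.
for row in causal_mask(3, 3):
    print(row)
```

The `+ 1` offset is what keeps the main diagonal (each token attending to itself) unmasked in the prefill case.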
class Qwen3ForCausalLM(nn.Module):
    """Qwen3 Model with Language Modeling head."""

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        intermediate_size: int,
        num_hidden_layers: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        max_position_embeddings: int = 32768,
        rope_theta: float = 10000.0,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        mlp_bias: bool = False,
        tie_word_embeddings: bool = True,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.tie_word_embeddings = tie_word_embeddings

        # Transformer model
        self.model = Qwen3Model(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            head_dim=head_dim,
            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            mlp_bias=mlp_bias,
        )

        # LM head
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        output_qkv_layers: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
        """
        Args:
            input_ids: [batch, seq_len]
            ... (same as Qwen3Model)

        Returns:
            logits: [batch, seq_len, vocab_size]
            past_key_values: Updated KV cache
            qkv_outputs: QKV tensors for specified layers
        """
        hidden_states, new_past_key_values, qkv_outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_qkv_layers=output_qkv_layers,
        )

        logits = self.lm_head(hidden_states)

        return logits, new_past_key_values, qkv_outputs

    @classmethod
    def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
        """
        Load weights from a pretrained Qwen3 model.

        Args:
            model_path: Path to model directory containing config.json and model weights
            dtype: Data type for model weights

        Returns:
            Initialized Qwen3ForCausalLM model
        """
        import json
        import os
        from safetensors.torch import load_file

        # Load config
        config_path = os.path.join(model_path, "config.json")
        with open(config_path) as f:
            config = json.load(f)

        # Create model
        model = cls(
            vocab_size=config["vocab_size"],
            hidden_size=config["hidden_size"],
            intermediate_size=config["intermediate_size"],
            num_hidden_layers=config["num_hidden_layers"],
            num_attention_heads=config["num_attention_heads"],
            num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
            head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
            max_position_embeddings=config.get("max_position_embeddings", 32768),
            rope_theta=config.get("rope_theta", 10000.0),
            rms_norm_eps=config.get("rms_norm_eps", 1e-6),
            attention_bias=config.get("attention_bias", False),
            mlp_bias=config.get("mlp_bias", False),
            tie_word_embeddings=config.get("tie_word_embeddings", True),
        )

        # Load weights
        weight_files = sorted([
            f for f in os.listdir(model_path)
            if f.endswith(".safetensors")
        ])

        state_dict = {}
        for wf in weight_files:
            state_dict.update(load_file(os.path.join(model_path, wf)))

        # Load into model
        model.load_state_dict(state_dict, strict=False)

        # Tie lm_head weights to embed_tokens if configured
        if model.tie_word_embeddings:
            model.lm_head.weight = model.model.embed_tokens.weight

        model = model.to(dtype)

        return model

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 32,
        temperature: float = 1.0,
        do_sample: bool = True,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Simple autoregressive generation."""
        device = input_ids.device
        batch_size, seq_len = input_ids.shape
        past_key_values = None
        generated = input_ids.clone()

        for _ in range(max_new_tokens):
            if past_key_values is None:
                current_input = generated
            else:
                current_input = generated[:, -1:]

            logits, past_key_values, _ = self(
                input_ids=current_input,
                past_key_values=past_key_values,
                use_cache=True,
            )

            next_token_logits = logits[:, -1, :]
            if temperature > 0 and do_sample:
                next_token_logits = next_token_logits / temperature
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)
|
||||||
|
|
||||||
generated = torch.cat([generated, next_token], dim=1)
|
|
||||||
|
|
||||||
if eos_token_id is not None and (next_token == eos_token_id).all():
|
|
||||||
break
|
|
||||||
|
|
||||||
return generated
|
|
||||||
|
|
||||||
|
|
||||||
def print_computation_graph():
|
|
||||||
"""Print the computation graph for reference."""
|
|
||||||
print(__doc__)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print_computation_graph()
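The sampling branch of `generate` above scales logits by `1/temperature` before the softmax: lower temperatures sharpen the distribution toward argmax, higher ones flatten it toward uniform. A minimal pure-Python sketch of just that step (a hypothetical standalone helper, independent of torch):

```python
import math

def softmax_with_temperature(logits, temperature):
    """Softmax over raw logits after temperature scaling (temperature > 0)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]
sharp = softmax_with_temperature(logits, 0.1)   # low temperature → near one-hot
flat = softmax_with_temperature(logits, 10.0)   # high temperature → near uniform
```

With `temperature → 0` the sampled token converges to the `argmax` branch of `generate`.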
@@ -1,151 +0,0 @@
#!/usr/bin/env python3
"""
Test: Pre-allocated chunk pair graphs for block sparse attention.

Each (Q_chunk, K_chunk) pair has its own captured CUDA graph.
Zero copy_() during replay - all data pre-filled.

Usage:
    CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph.py
"""

from dataclasses import dataclass
from typing import List, Optional

import torch

from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs


@dataclass
class ChunkAttentionGraph:
    """Container for a captured chunk attention graph."""
    graph: torch.cuda.CUDAGraph
    static_q: torch.Tensor
    static_k: torch.Tensor
    static_v: torch.Tensor
    static_output: torch.Tensor
    static_lse: torch.Tensor
    causal: bool


def capture_chunk_attention_graph(
    chunk_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    scale: float,
    device: torch.device,
    dtype: torch.dtype,
    causal: bool = False,
) -> ChunkAttentionGraph:
    """Capture a CUDA graph for single chunk attention."""
    static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
    static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
    static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)

    static_q.normal_()
    static_k.normal_()
    static_v.normal_()

    # Warmup
    with torch.inference_mode():
        for _ in range(3):
            _ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
    torch.cuda.synchronize()

    # Capture
    graph = torch.cuda.CUDAGraph()
    with torch.inference_mode():
        with torch.cuda.graph(graph):
            static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)

    torch.cuda.synchronize()

    return ChunkAttentionGraph(
        graph=graph,
        static_q=static_q,
        static_k=static_k,
        static_v=static_v,
        static_output=static_output,
        static_lse=static_lse,
        causal=causal,
    )


def main():
    device = torch.device("cuda")
    dtype = torch.bfloat16

    chunk_size = 64
    num_chunks = 4
    num_heads = 8
    num_kv_heads = 8
    head_dim = 64
    scale = 1.0 / (head_dim ** 0.5)
    seq_len = chunk_size * num_chunks

    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}")
    print(f"Total graphs: {num_chunks * (num_chunks + 1) // 2}")

    # Test data
    full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
    full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
    full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)

    # Reference
    with torch.inference_mode():
        full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)

    # Capture all graphs
    graphs: List[List[Optional[ChunkAttentionGraph]]] = [[None] * num_chunks for _ in range(num_chunks)]
    for q_idx in range(num_chunks):
        for k_idx in range(q_idx + 1):
            graphs[q_idx][k_idx] = capture_chunk_attention_graph(
                chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype,
                causal=(k_idx == q_idx)
            )
    print("All graphs captured")

    # Pre-fill static tensors
    for q_idx in range(num_chunks):
        for k_idx in range(q_idx + 1):
            g = graphs[q_idx][k_idx]
            g.static_q.copy_(full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size])
            g.static_k.copy_(full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size])
            g.static_v.copy_(full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size])
    print("Static tensors pre-filled")

    # Replay and merge
    chunked_output = torch.zeros_like(full_output)
    for q_idx in range(num_chunks):
        acc_out, acc_lse = None, None
        for k_idx in range(q_idx + 1):
            g = graphs[q_idx][k_idx]
            g.graph.replay()
            out, lse = g.static_output.clone(), g.static_lse.clone()
            if acc_out is None:
                acc_out, acc_lse = out, lse
            else:
                with torch.inference_mode():
                    acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
        chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out

    torch.cuda.synchronize()

    # Compare
    all_pass = True
    for q_idx in range(num_chunks):
        s, e = q_idx * chunk_size, (q_idx + 1) * chunk_size
        diff = (full_output[:, s:e] - chunked_output[:, s:e]).abs().max().item()
        status = "✅" if diff < 1e-2 else "❌"
        print(f"Q[{q_idx}]: max_diff={diff:.2e} {status}")
        if diff >= 1e-2:
            all_pass = False

    print("✅ PASSED" if all_pass else "❌ FAILED")


if __name__ == "__main__":
    main()
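`merge_attention_outputs` above combines partial attention results computed over different K-chunks. The underlying math is the standard log-sum-exp merge used by flash-attention-style kernels: each partial result carries its softmax-weighted output and the log of its softmax denominator, and re-weighting by the denominators recovers a single softmax over all scores. A scalar pure-Python sketch of that math (not the actual kernel, whose signature is only visible from its call site here):

```python
import math

def merge_partial(out1, lse1, out2, lse2):
    """Merge two partial attention results (output, log-sum-exp) into one."""
    m = max(lse1, lse2)            # subtract the max so the exponentials stay bounded
    w1 = math.exp(lse1 - m)
    w2 = math.exp(lse2 - m)
    merged_out = (out1 * w1 + out2 * w2) / (w1 + w2)
    merged_lse = m + math.log(w1 + w2)
    return merged_out, merged_lse

# One query attending to values 10/20/30 with scores 1/2/3, split into chunks {1} and {2, 3}:
o1, l1 = 10.0, 1.0                              # chunk 1: single score, denominator e^1
d2 = math.exp(2.0) + math.exp(3.0)              # chunk 2 softmax denominator
o2 = (20.0 * math.exp(2.0) + 30.0 * math.exp(3.0)) / d2
merged_out, merged_lse = merge_partial(o1, l1, o2, math.log(d2))
```

The merged result matches a single softmax over all three scores, which is why the per-chunk replay loop above can accumulate pairwise.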
@@ -1,156 +0,0 @@
#!/usr/bin/env python3
"""
Test: Reuse a single CUDA Graph across all layers and all chunk pairs.

Key insight: LLM layers have identical computation structure.
We only need 2 graphs (causal + non-causal), reused for all (layer, Q_i, K_j) combinations.

Usage:
    CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph_reuse.py
"""

from dataclasses import dataclass

import torch

from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs


@dataclass
class ReusableChunkGraph:
    """A single graph that can be reused with copy_() updates."""
    graph: torch.cuda.CUDAGraph
    static_q: torch.Tensor
    static_k: torch.Tensor
    static_v: torch.Tensor
    static_output: torch.Tensor
    static_lse: torch.Tensor


def capture_reusable_graph(
    chunk_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    scale: float,
    device: torch.device,
    dtype: torch.dtype,
    causal: bool,
) -> ReusableChunkGraph:
    """Capture ONE graph to be reused for all chunk pairs."""
    static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
    static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
    static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)

    static_q.normal_()
    static_k.normal_()
    static_v.normal_()

    # Warmup
    with torch.inference_mode():
        for _ in range(3):
            _ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
    torch.cuda.synchronize()

    # Capture
    graph = torch.cuda.CUDAGraph()
    with torch.inference_mode():
        with torch.cuda.graph(graph):
            static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)

    torch.cuda.synchronize()

    return ReusableChunkGraph(
        graph=graph,
        static_q=static_q,
        static_k=static_k,
        static_v=static_v,
        static_output=static_output,
        static_lse=static_lse,
    )


def replay_with_copy(graph: ReusableChunkGraph, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Replay graph after updating static tensors with copy_()."""
    graph.static_q.copy_(q)
    graph.static_k.copy_(k)
    graph.static_v.copy_(v)
    graph.graph.replay()
    return graph.static_output.clone(), graph.static_lse.clone()


def main():
    device = torch.device("cuda")
    dtype = torch.bfloat16

    chunk_size = 64
    num_chunks = 4
    num_layers = 3  # Simulate multiple layers
    num_heads = 8
    num_kv_heads = 8
    head_dim = 64
    scale = 1.0 / (head_dim ** 0.5)
    seq_len = chunk_size * num_chunks

    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}, Num layers: {num_layers}")
    print("Only 2 graphs (causal + non-causal) for ALL layer × chunk combinations")

    # Capture only 2 graphs
    graph_causal = capture_reusable_graph(
        chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=True
    )
    graph_non_causal = capture_reusable_graph(
        chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=False
    )
    print("2 graphs captured (causal + non-causal)")

    all_pass = True

    for layer_id in range(num_layers):
        # Different Q/K/V for each layer (simulating different layer outputs)
        full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
        full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
        full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)

        # Reference: full causal attention
        with torch.inference_mode():
            full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)

        # Chunked with graph reuse
        chunked_output = torch.zeros_like(full_output)

        for q_idx in range(num_chunks):
            q_chunk = full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size]
            acc_out, acc_lse = None, None

            for k_idx in range(q_idx + 1):
                k_chunk = full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
                v_chunk = full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size]

                # Reuse graph with copy_()
                graph = graph_causal if k_idx == q_idx else graph_non_causal
                out, lse = replay_with_copy(graph, q_chunk, k_chunk, v_chunk)

                if acc_out is None:
                    acc_out, acc_lse = out, lse
                else:
                    with torch.inference_mode():
                        acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)

            chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out

        torch.cuda.synchronize()

        # Compare
        max_diff = (full_output - chunked_output).abs().max().item()
        status = "✅" if max_diff < 1e-2 else "❌"
        print(f"Layer {layer_id}: max_diff={max_diff:.2e} {status}")
        if max_diff >= 1e-2:
            all_pass = False

    print("✅ PASSED - Single graph reuse across layers works!" if all_pass else "❌ FAILED")


if __name__ == "__main__":
    main()
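Comparing the two tests: per-pair capture (previous file) needs one graph per lower-triangular (Q, K) chunk pair, while this file captures two graphs total, with the causal one used only on diagonal pairs; the replay count is the same in both schemes. A small pure-Python sketch of that bookkeeping (hypothetical helper names):

```python
def graph_plan(num_chunks, num_layers):
    """Graphs captured: one per lower-triangular (Q, K) pair vs. the 2-graph reuse scheme."""
    per_pair = num_chunks * (num_chunks + 1) // 2   # distinct (q_idx, k_idx) pairs with k <= q
    reuse = 2                                        # one causal + one non-causal, shared by all layers
    replays = num_layers * per_pair                  # replays are identical in both schemes
    return per_pair, reuse, replays

def pair_is_causal(q_idx, k_idx):
    """Only the diagonal block needs the triangular mask; lower blocks see all keys."""
    return q_idx == k_idx

plan = graph_plan(num_chunks=4, num_layers=3)
```

For the test's configuration this is 10 captured graphs versus 2, a gap that grows quadratically with chunk count and linearly with layer count if graphs were captured per layer.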
@@ -1,357 +0,0 @@
#!/usr/bin/env python3
"""
CUDA Graph Memory Analysis Test

This script analyzes the memory overhead of CUDA Graph at each stage:
1. Model loading
2. StaticCache allocation
3. Warmup runs
4. Graph capture
5. Graph replay

Usage:
    CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py
    CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --model ~/models/Qwen3-0.6B
    CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --max-cache-len 2048
"""

import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache


def get_memory_mb():
    """Get current allocated memory in MB."""
    return torch.cuda.memory_allocated() / 1024**2


def get_memory_gb():
    """Get current allocated memory in GB."""
    return torch.cuda.memory_allocated() / 1024**3


def get_peak_memory_gb():
    """Get peak allocated memory in GB."""
    return torch.cuda.max_memory_allocated() / 1024**3


def print_separator(title=None):
    """Print a separator line."""
    if title:
        print(f"\n{'=' * 70}")
        print(f" {title}")
        print(f"{'=' * 70}")
    else:
        print("-" * 70)


def test_memory_stages(model_path: str, max_cache_len: int, batch_size: int = 1):
    """
    Test memory usage at each stage of CUDA Graph setup.

    Args:
        model_path: Path to the model
        max_cache_len: Maximum cache length for StaticCache
        batch_size: Batch size for inference
    """
    print_separator("CUDA Graph Memory Analysis")
    print(f"Model: {model_path}")
    print(f"Max cache length: {max_cache_len}")
    print(f"Batch size: {batch_size}")

    results = {}

    # Stage 0: Initial
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    results["initial"] = get_memory_mb()

    # Stage 1: Load model
    print_separator("Stage 1: Model Loading")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.eval()

    results["after_model"] = get_memory_mb()
    model_size = results["after_model"] - results["initial"]
    print(f"  Memory: {results['after_model']:.0f} MB")
    print(f"  Model size: {model_size:.0f} MB ({model_size/1024:.2f} GB)")

    config = model.config
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    # Stage 2: Allocate StaticCache
    print_separator("Stage 2: StaticCache Allocation")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()

    static_cache = StaticCache(
        config=config,
        max_batch_size=batch_size,
        max_cache_len=max_cache_len,
        device=device,
        dtype=dtype,
    )

    results["after_cache"] = get_memory_mb()
    cache_size = results["after_cache"] - before
    print(f"  Memory: {results['after_cache']:.0f} MB")
    print(f"  StaticCache size: {cache_size:.0f} MB")

    # Calculate theoretical cache size
    num_layers = config.num_hidden_layers
    num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
    head_dim = config.hidden_size // config.num_attention_heads
    dtype_size = 2  # bfloat16

    theoretical_cache = (
        num_layers * 2 * batch_size * num_kv_heads * max_cache_len * head_dim * dtype_size
    ) / (1024**2)
    print(f"  Theoretical: {theoretical_cache:.0f} MB")
    print(f"  Overhead: {cache_size - theoretical_cache:.0f} MB ({(cache_size/theoretical_cache - 1)*100:.1f}%)")

    # Stage 3: Prepare static tensors
    print_separator("Stage 3: Static Tensor Allocation")
    before = get_memory_mb()

    static_input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
    static_position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
    static_cache_position = torch.tensor([0], dtype=torch.long, device=device)

    results["after_tensors"] = get_memory_mb()
    tensor_size = results["after_tensors"] - before
    print(f"  Memory: {results['after_tensors']:.0f} MB")
    print(f"  Static tensors: {tensor_size:.2f} MB (negligible)")

    # Stage 4: Warmup runs
    print_separator("Stage 4: Warmup Runs (3 iterations)")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()

    with torch.inference_mode():
        for i in range(3):
            _ = model(
                input_ids=static_input_ids,
                position_ids=static_position_ids,
                past_key_values=static_cache,
                cache_position=static_cache_position,
                use_cache=True,
            )
    torch.cuda.synchronize()

    results["after_warmup"] = get_memory_mb()
    results["warmup_peak"] = get_peak_memory_gb() * 1024
    warmup_size = results["after_warmup"] - before
    print(f"  Memory: {results['after_warmup']:.0f} MB")
    print(f"  Peak: {results['warmup_peak']:.0f} MB")
    print(f"  Warmup overhead: {warmup_size:.0f} MB")

    # Stage 5: CUDA Graph capture
    print_separator("Stage 5: CUDA Graph Capture")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()

    graph = torch.cuda.CUDAGraph()
    with torch.inference_mode():
        with torch.cuda.graph(graph):
            outputs = model(
                input_ids=static_input_ids,
                position_ids=static_position_ids,
                past_key_values=static_cache,
                cache_position=static_cache_position,
                use_cache=True,
            )
            static_logits = outputs.logits
    torch.cuda.synchronize()

    results["after_capture"] = get_memory_mb()
    results["capture_peak"] = get_peak_memory_gb() * 1024
    capture_size = results["after_capture"] - before
    print(f"  Memory: {results['after_capture']:.0f} MB")
    print(f"  Peak: {results['capture_peak']:.0f} MB")
    print(f"  Graph capture overhead: {capture_size:.0f} MB")

    # Stage 6: Graph replay
    print_separator("Stage 6: Graph Replay (10 iterations)")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()

    with torch.inference_mode():
        for _ in range(10):
            static_input_ids.fill_(1)
            static_cache_position.fill_(0)
            graph.replay()
    torch.cuda.synchronize()

    results["after_replay"] = get_memory_mb()
    results["replay_peak"] = get_peak_memory_gb() * 1024
    replay_change = results["after_replay"] - before
    print(f"  Memory: {results['after_replay']:.0f} MB")
    print(f"  Peak: {results['replay_peak']:.0f} MB")
    print(f"  Replay memory change: {replay_change:.0f} MB (should be ~0)")

    # Summary
    print_separator("SUMMARY")
    total_overhead = results["after_capture"] - results["after_model"]

    print(f"{'Stage':<25} {'Memory (MB)':>12} {'Delta (MB)':>12}")
    print("-" * 50)
    print(f"{'Model loaded':<25} {results['after_model']:>12.0f} {model_size:>+12.0f}")
    print(f"{'StaticCache allocated':<25} {results['after_cache']:>12.0f} {cache_size:>+12.0f}")
    print(f"{'After warmup':<25} {results['after_warmup']:>12.0f} {warmup_size:>+12.0f}")
    print(f"{'After graph capture':<25} {results['after_capture']:>12.0f} {capture_size:>+12.0f}")
    print(f"{'After graph replay':<25} {results['after_replay']:>12.0f} {replay_change:>+12.0f}")
    print("-" * 50)
    print(f"{'Total (excl. model)':<25} {'':<12} {total_overhead:>+12.0f}")

    print_separator("KEY FINDINGS")
    print(f"  1. Model size: {model_size/1024:.2f} GB")
    print(f"  2. StaticCache: {cache_size:.0f} MB (main overhead, scales with cache_len)")
    print(f"  3. Graph capture: {capture_size:.0f} MB (small, stores kernel sequence)")
    print(f"  4. Graph replay: {replay_change:.0f} MB (zero allocation, reuses memory)")
    print(f"  5. Total CUDA Graph overhead: {total_overhead:.0f} MB")

    return results


def test_cache_length_scaling(model_path: str, cache_lengths: list):
    """
    Test how memory scales with different cache lengths.

    Args:
        model_path: Path to the model
        cache_lengths: List of cache lengths to test
    """
    print_separator("Cache Length Scaling Test")
    print(f"Model: {model_path}")
    print(f"Cache lengths: {cache_lengths}")

    # Load model once
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.eval()

    config = model.config
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    model_mem = get_memory_mb()

    results = []
    for cache_len in cache_lengths:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Create cache and capture graph
        static_cache = StaticCache(
            config=config,
            max_batch_size=1,
            max_cache_len=cache_len,
            device=device,
            dtype=dtype,
        )

        static_input_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
        static_position_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
        static_cache_position = torch.tensor([0], dtype=torch.long, device=device)

        with torch.inference_mode():
            # Warmup
            for _ in range(3):
                _ = model(
                    input_ids=static_input_ids,
                    position_ids=static_position_ids,
                    past_key_values=static_cache,
                    cache_position=static_cache_position,
                    use_cache=True,
                )
            torch.cuda.synchronize()

            # Capture
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                outputs = model(
                    input_ids=static_input_ids,
                    position_ids=static_position_ids,
                    past_key_values=static_cache,
                    cache_position=static_cache_position,
                    use_cache=True,
                )
            torch.cuda.synchronize()

        total_mem = get_memory_mb()
        overhead = total_mem - model_mem
        results.append((cache_len, total_mem, overhead))

        del static_cache, graph
        torch.cuda.empty_cache()

    # Print results
    print()
    print(f"{'Cache Length':>12} | {'Total (MB)':>12} | {'Overhead (MB)':>14} | {'Per 1K tokens':>14}")
    print("-" * 60)
    for cache_len, total, overhead in results:
        per_1k = overhead / (cache_len / 1000)
        print(f"{cache_len:>12} | {total:>12.0f} | {overhead:>14.0f} | {per_1k:>14.1f}")

    return results


def main():
    parser = argparse.ArgumentParser(description="CUDA Graph Memory Analysis")
    parser.add_argument(
        "--model",
        type=str,
        default="~/models/Qwen3-4B-Instruct-2507",
        help="Model path",
    )
    parser.add_argument(
        "--max-cache-len",
        type=int,
        default=1024,
        help="Maximum cache length",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size",
    )
    parser.add_argument(
        "--test-scaling",
        action="store_true",
        help="Test cache length scaling",
    )
    args = parser.parse_args()

    model_path = os.path.expanduser(args.model)

    if not torch.cuda.is_available():
        print("CUDA is not available!")
        return

    print(f"Device: cuda:{torch.cuda.current_device()}")
    print(f"GPU: {torch.cuda.get_device_name()}")

    if args.test_scaling:
        cache_lengths = [256, 512, 1024, 2048, 4096]
        test_cache_length_scaling(model_path, cache_lengths)
    else:
        test_memory_stages(model_path, args.max_cache_len, args.batch_size)

    print("\ntest_cudagraph_memory: PASSED")


if __name__ == "__main__":
    main()
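The `theoretical_cache` expression in `test_memory_stages` is the standard KV-cache accounting: two tensors (K and V) per layer, each holding `batch × kv_heads × cache_len × head_dim` elements. The same arithmetic as a standalone sketch, with assumed illustrative shapes rather than values read from a real `config.json`:

```python
def kv_cache_mb(num_layers, batch_size, num_kv_heads, cache_len, head_dim, dtype_size=2):
    """Theoretical StaticCache footprint in MB; dtype_size=2 bytes for bfloat16."""
    # K and V tensors per layer, each [batch, kv_heads, cache_len, head_dim]
    elements = num_layers * 2 * batch_size * num_kv_heads * cache_len * head_dim
    return elements * dtype_size / 1024**2

# Assumed GQA shapes (illustrative only; the script reads the real config):
size_mb = kv_cache_mb(num_layers=36, batch_size=1, num_kv_heads=8, cache_len=1024, head_dim=128)
```

Because the footprint is linear in `cache_len`, the `--test-scaling` mode's "Per 1K tokens" column should be roughly constant once allocator overhead is amortized.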
@@ -1,442 +0,0 @@
"""
|
|
||||||
Test: Hierarchical Block Sum Estimation for XAttention
|
|
||||||
|
|
||||||
Verify that hierarchical estimation (small estimate_block_size + aggregation)
|
|
||||||
produces equivalent results to direct estimation (large block_size), while
|
|
||||||
being significantly faster.
|
|
||||||
|
|
||||||
Key changes validated:
|
|
||||||
1. Hierarchical block sum: estimate_block_size=1024 → aggregate to cpu_block_size=4096
|
|
||||||
2. Selection strategy: score + threshold (NOT mask + majority voting)
|
|
||||||
|
|
||||||
This test uses pure torch + xattn kernels, independent of nanovllm framework.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import math
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Configuration
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
# Model dimensions (Llama-3.1-8B-Instruct style)
|
|
||||||
NUM_HEADS = 32
|
|
||||||
NUM_KV_HEADS = 8
|
|
||||||
HEAD_DIM = 128
|
|
||||||
STRIDE = 8
|
|
||||||
|
|
||||||
# Block sizes
|
|
||||||
CPU_BLOCK_SIZE = 4096 # External CPU block size (fixed, for overlap)
|
|
||||||
ESTIMATE_BLOCK_SIZE = 1024 # Internal estimate block size (optimized)
|
|
||||||
|
|
||||||
# Selection parameters
|
|
||||||
THRESHOLD = 0.95 # Cumulative attention threshold
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Hierarchical Estimation Implementation
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def compute_attention_scores(Q, K_blocks, stride):
|
|
||||||
"""
|
|
||||||
Compute attention scores for Q against multiple K blocks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
Q: [1, num_heads, q_len, head_dim]
|
|
||||||
K_blocks: List of K tensors, each [1, num_heads, block_size, head_dim]
|
|
||||||
stride: Stride for reshape
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
attn_scores: [1, num_heads, q_reshaped, total_k_reshaped]
|
|
||||||
"""
|
|
||||||
q_len = Q.shape[2]
|
|
||||||
q_reshaped = q_len // stride
|
|
||||||
|
|
||||||
attn_chunks = []
|
|
||||||
for K_block in K_blocks:
|
|
||||||
# flat_group_gemm_fuse_reshape
|
|
||||||
attn_chunk = flat_group_gemm_fuse_reshape(
|
|
||||||
Q, K_block, stride,
|
|
||||||
chunk_start=0,
|
|
||||||
chunk_end=q_reshaped,
|
|
||||||
is_causal=False,
|
|
||||||
)
|
|
||||||
attn_chunks.append(attn_chunk)
|
|
||||||
|
|
||||||
# Concatenate along K dimension
|
|
||||||
attn_scores = torch.cat(attn_chunks, dim=-1)
|
|
||||||
return attn_scores
|
|
||||||
|
|
||||||
|
|
||||||
def hierarchical_block_sum(
    attn_scores,
    estimate_block_size,
    cpu_block_size,
    stride,
    head_dim,
):
    """
    Compute hierarchical block sums: fine-grained → aggregated to CPU block level.

    Args:
        attn_scores: [batch, heads, q_reshaped, k_reshaped]
        estimate_block_size: Small block size for efficient softmax (e.g., 1024)
        cpu_block_size: External CPU block size (e.g., 4096)
        stride: Stride used in reshape
        head_dim: Head dimension for scale computation

    Returns:
        cpu_block_scores: [batch, heads, num_cpu_blocks] - attention score per CPU block
        block_sums_fine: [batch, heads, q_est_blocks, k_est_blocks] - fine-grained sums
    """
    batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape

    # Compute reshaped block size
    reshaped_est_bs = estimate_block_size // stride  # 1024/8 = 128

    # Scale factor (log2(e) / sqrt(head_dim), corrected for stride)
    norm = 1.0
    scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm

    # Segment size for softmax kernel
    segment_size = min(4096, reshaped_est_bs)

    # Step 1: Fine-grained softmax + block sum
    block_sums_fine = softmax_fuse_block_sum(
        attn_scores,
        reshaped_est_bs,
        segment_size,
        chunk_start=0,
        chunk_end=q_reshaped,
        real_q_len=q_reshaped,
        scale=scale,
        is_causal=False,
    )
    # block_sums_fine: [batch, heads, q_est_blocks, k_est_blocks]

    q_est_blocks = block_sums_fine.shape[2]
    k_est_blocks = block_sums_fine.shape[3]

    # Step 2: Aggregate to CPU block level
    # ratio = cpu_block_size / estimate_block_size = 4
    ratio = cpu_block_size // estimate_block_size
    num_cpu_blocks = k_est_blocks // ratio

    # Reshape and sum along K dimension
    # [batch, heads, q_est, k_est] → [batch, heads, q_est, num_cpu, ratio]
    block_sums_coarse = block_sums_fine.view(
        batch_size, num_heads, q_est_blocks, num_cpu_blocks, ratio
    ).sum(dim=-1)  # [batch, heads, q_est_blocks, num_cpu_blocks]

    # Step 3: Sum over Q dimension (total attention from Q chunk to each K block)
    cpu_block_scores = block_sums_coarse.sum(dim=2)  # [batch, heads, num_cpu_blocks]

    return cpu_block_scores, block_sums_fine

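Step 2 of `hierarchical_block_sum` relies on the fact that summing fine-grained block sums in groups of `ratio` yields exactly the coarse (CPU-block) sums. A pure-Python sketch with hypothetical toy values, no kernels involved:

```python
# Fine-grained per-block sums (hypothetical toy values, 8 estimate blocks).
fine_sums = [1, 2, 3, 4, 5, 6, 7, 8]
ratio = 4  # cpu_block_size // estimate_block_size (4096 // 1024)

# Group-sum to CPU-block granularity, mirroring .view(...).sum(dim=-1).
coarse = [sum(fine_sums[i:i + ratio]) for i in range(0, len(fine_sums), ratio)]
print(coarse)  # [10, 26]

# Summing directly at CPU-block granularity gives the same result.
direct = [sum(fine_sums[0:4]), sum(fine_sums[4:8])]
assert coarse == direct
```

The summation step itself is exact; in the real code the hierarchical and direct paths only agree approximately (Test 1 accepts diff < 0.01) because the segmented softmax normalizes slightly differently at the two block sizes.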
def direct_block_sum(
    attn_scores,
    cpu_block_size,
    stride,
    head_dim,
):
    """
    Compute block sums directly with CPU block size (baseline for comparison).

    Args:
        attn_scores: [batch, heads, q_reshaped, k_reshaped]
        cpu_block_size: Block size (e.g., 4096)
        stride: Stride used in reshape
        head_dim: Head dimension for scale computation

    Returns:
        cpu_block_scores: [batch, heads, num_cpu_blocks]
    """
    batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape

    reshaped_cpu_bs = cpu_block_size // stride  # 4096/8 = 512

    norm = 1.0
    scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
    segment_size = min(4096, reshaped_cpu_bs)

    block_sums = softmax_fuse_block_sum(
        attn_scores,
        reshaped_cpu_bs,
        segment_size,
        chunk_start=0,
        chunk_end=q_reshaped,
        real_q_len=q_reshaped,
        scale=scale,
        is_causal=False,
    )
    # block_sums: [batch, heads, q_cpu_blocks, k_cpu_blocks]

    # Sum over Q dimension
    cpu_block_scores = block_sums.sum(dim=2)  # [batch, heads, num_cpu_blocks]

    return cpu_block_scores

def select_blocks_by_score(
    cpu_block_scores,
    threshold=0.95,
    always_include_first=True,
    always_include_last=True,
):
    """
    Select CPU blocks based on score + threshold.

    ⚠️ IMPORTANT: This replaces the original mask + majority voting strategy.
    This change should be documented in the final implementation.

    Args:
        cpu_block_scores: [batch, heads, num_cpu_blocks]
        threshold: Cumulative attention threshold (e.g., 0.95)
        always_include_first: Always include first block (sink)
        always_include_last: Always include last block (safety)

    Returns:
        selected_block_ids: List of selected block indices
        density: Fraction of blocks selected
    """
    # Average scores across batch and heads
    scores_per_block = cpu_block_scores.mean(dim=(0, 1))  # [num_cpu_blocks]
    num_blocks = scores_per_block.shape[0]

    # Normalize to get attention distribution
    total_score = scores_per_block.sum()
    score_ratio = scores_per_block / total_score

    # Sort by score (descending)
    sorted_indices = torch.argsort(score_ratio, descending=True)

    # Select blocks until cumulative threshold is reached
    cumsum = 0.0
    selected = set()

    for idx in sorted_indices.tolist():
        selected.add(idx)
        cumsum += score_ratio[idx].item()
        if cumsum >= threshold:
            break

    # Always include first and last blocks
    if always_include_first:
        selected.add(0)
    if always_include_last:
        selected.add(num_blocks - 1)

    selected_block_ids = sorted(selected)
    density = len(selected_block_ids) / num_blocks

    return selected_block_ids, density

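The cumulative-threshold loop in `select_blocks_by_score` can be sketched without torch. The scores below are hypothetical; the forced inclusion at the end mirrors the sink and safety blocks:

```python
# Hypothetical normalized-ish block scores for 6 CPU blocks.
scores = [0.50, 0.05, 0.02, 0.30, 0.10, 0.03]
total = sum(scores)
order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

# Greedily add highest-scoring blocks until their share reaches the threshold.
selected, cum = set(), 0.0
for i in order:
    selected.add(i)
    cum += scores[i] / total
    if cum >= 0.85:
        break
selected |= {0, len(scores) - 1}  # sink + safety blocks
print(sorted(selected))  # [0, 3, 4, 5]
```

With threshold 0.85, blocks 0, 3, 4 cover 90% of the mass; block 5 rides in only via the always-include-last rule.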
# ============================================================
# Test Cases
# ============================================================

def test_equivalence():
    """
    Test that hierarchical estimation produces equivalent scores to direct estimation.
    """
    print("=" * 60)
    print("Test 1: Hierarchical vs Direct - Equivalence")
    print("=" * 60)

    # Create random Q and multiple K blocks
    q_len = CPU_BLOCK_SIZE  # 4096
    num_k_blocks = 4

    # Q: [1, num_heads, q_len, head_dim]
    Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')

    # K blocks: each [1, num_heads, cpu_block_size, head_dim]
    K_blocks = [
        torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        for _ in range(num_k_blocks)
    ]

    # Compute attention scores
    attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
    print(f"attn_scores shape: {attn_scores.shape}")

    # Method 1: Hierarchical (fast)
    scores_hier, _ = hierarchical_block_sum(
        attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
    )
    print(f"scores_hier shape: {scores_hier.shape}")

    # Method 2: Direct (slow)
    scores_direct = direct_block_sum(
        attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
    )
    print(f"scores_direct shape: {scores_direct.shape}")

    # Compare
    diff = (scores_hier - scores_direct).abs().max().item()
    print(f"\nMax difference: {diff:.6f}")

    # Per-block comparison
    print("\nPer-block scores comparison:")
    for i in range(num_k_blocks):
        h_val = scores_hier[0, 0, i].item()
        d_val = scores_direct[0, 0, i].item()
        print(f"  Block {i}: hierarchical={h_val:.4f}, direct={d_val:.4f}, diff={abs(h_val-d_val):.6f}")

    passed = diff < 0.01
    print(f"\nTest 1: {'PASSED' if passed else 'FAILED'}")
    return passed

def test_selection():
    """
    Test the score + threshold selection strategy.
    """
    print("\n" + "=" * 60)
    print("Test 2: Score + Threshold Selection")
    print("=" * 60)

    # Create Q and K blocks with varying importance
    q_len = CPU_BLOCK_SIZE
    num_k_blocks = 8

    Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')

    # Create K blocks - make some more important than others
    K_blocks = []
    for i in range(num_k_blocks):
        # Blocks 0, 3, 4 are more important (higher values)
        importance = 2.0 if i in [0, 3, 4] else 1.0
        K = torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        K = K * importance
        K_blocks.append(K)

    # Compute scores
    attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
    scores, _ = hierarchical_block_sum(
        attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
    )

    # Print scores per block
    print("\nCPU block scores (head 0):")
    for i in range(num_k_blocks):
        print(f"  Block {i}: {scores[0, 0, i].item():.4f}")

    # Select blocks with different thresholds
    for thresh in [0.9, 0.95, 0.99]:
        selected, density = select_blocks_by_score(scores, threshold=thresh)
        print(f"\nThreshold {thresh}: selected {len(selected)}/{num_k_blocks} blocks ({density:.1%})")
        print(f"  Selected: {selected}")

    print("\nTest 2: PASSED (visual inspection)")
    return True

def test_performance():
    """
    Benchmark hierarchical vs direct estimation performance.
    """
    print("\n" + "=" * 60)
    print("Test 3: Performance Benchmark")
    print("=" * 60)

    import time

    NUM_WARMUP = 3
    NUM_RUNS = 10

    # Larger test case
    q_len = CPU_BLOCK_SIZE
    num_k_blocks = 16  # 64K context

    Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
    K_blocks = [
        torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        for _ in range(num_k_blocks)
    ]

    # Compute attention scores (shared)
    attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
    print(f"attn_scores shape: {attn_scores.shape}")
    print(f"Context: {num_k_blocks * CPU_BLOCK_SIZE // 1024}K tokens")

    # Warmup and benchmark hierarchical
    for _ in range(NUM_WARMUP):
        _ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(NUM_RUNS):
        _ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
    torch.cuda.synchronize()
    hier_time = (time.perf_counter() - start) / NUM_RUNS * 1000

    # Warmup and benchmark direct
    for _ in range(NUM_WARMUP):
        _ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(NUM_RUNS):
        _ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
    torch.cuda.synchronize()
    direct_time = (time.perf_counter() - start) / NUM_RUNS * 1000

    speedup = direct_time / hier_time

    print(f"\nResults:")
    print(f"  Hierarchical (bs=1024): {hier_time:.2f} ms")
    print(f"  Direct (bs=4096): {direct_time:.2f} ms")
    print(f"  Speedup: {speedup:.2f}x")

    passed = speedup > 5.0  # Expect at least 5x speedup
    print(f"\nTest 3: {'PASSED' if passed else 'FAILED'} (speedup > 5x expected)")
    return passed

# ============================================================
# Main
# ============================================================

if __name__ == "__main__":
    print("=" * 60)
    print("Hierarchical Block Sum Estimation Test")
    print("=" * 60)
    print(f"\nConfiguration:")
    print(f"  NUM_HEADS: {NUM_HEADS}")
    print(f"  NUM_KV_HEADS: {NUM_KV_HEADS}")
    print(f"  HEAD_DIM: {HEAD_DIM}")
    print(f"  STRIDE: {STRIDE}")
    print(f"  CPU_BLOCK_SIZE: {CPU_BLOCK_SIZE}")
    print(f"  ESTIMATE_BLOCK_SIZE: {ESTIMATE_BLOCK_SIZE}")
    print(f"  THRESHOLD: {THRESHOLD}")
    print()

    results = []

    results.append(("Equivalence", test_equivalence()))
    results.append(("Selection", test_selection()))
    results.append(("Performance", test_performance()))

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for name, passed in results:
        status = "PASSED" if passed else "FAILED"
        print(f"  {name}: {status}")

    all_passed = all(p for _, p in results)
    print("=" * 60)
    if all_passed:
        print("test_hierarchical_estimate: ALL PASSED")
        sys.exit(0)
    else:
        print("test_hierarchical_estimate: SOME FAILED")
        sys.exit(1)
@@ -1,254 +0,0 @@
"""
Needle-in-a-haystack test for LLM.

Tests: Long context retrieval capability with configurable sequence length.

NOTE: CPU offload mode has a known bug that causes incorrect outputs for
sequences longer than ~200 tokens. Use --no-offload for correctness testing.
"""

import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"

import argparse
from nanovllm import LLM, SamplingParams
from nanovllm.config import SparsePolicyType
from utils import generate_needle_prompt, check_needle_answer


# ============================================================
# Main Test
# ============================================================

def run_needle_test(
    model_path: str,
    max_model_len: int,
    input_len: int,
    num_gpu_blocks: int = 4,
    block_size: int = 1024,
    needle_position: float = 0.5,
    needle_value: str = "7492",
    max_new_tokens: int = 32,
    enable_cpu_offload: bool = False,
    enable_quest: bool = False,
    enable_xattn_bsa: bool = False,
    sparse_topk: int = 8,
    sparse_threshold: int = 4,
    sparse_samples: int = 128,
    verbose: bool = True,
) -> bool:
    """
    Run a needle-in-haystack test.

    Args:
        model_path: Path to model
        max_model_len: Maximum model context length
        input_len: Target input sequence length
        num_gpu_blocks: Number of GPU blocks for offload
        block_size: KV cache block size
        needle_position: Where to place needle (0.0-1.0)
        needle_value: The secret value to find
        max_new_tokens: Maximum tokens to generate
        enable_cpu_offload: Enable CPU offload mode
        enable_quest: Enable Quest sparse attention (decode-only Top-K)
        enable_xattn_bsa: Enable XAttention BSA sparse attention (prefill-only)
        sparse_topk: Top-K blocks for Quest
        sparse_threshold: Threshold for sparse selection (Quest/XAttention BSA)
        sparse_samples: Samples per chunk for XAttention BSA estimation
        verbose: Print detailed output

    Returns:
        True if test passed, False otherwise
    """
    # Determine sparse policy
    if enable_xattn_bsa:
        sparse_policy = SparsePolicyType.XATTN_BSA
    elif enable_quest:
        sparse_policy = SparsePolicyType.QUEST
    else:
        sparse_policy = SparsePolicyType.FULL

    if verbose:
        print(f"\n{'='*60}")
        print(f"Needle-in-Haystack Test")
        print(f"{'='*60}")
        print(f"Model: {model_path}")
        print(f"Max model len: {max_model_len}")
        print(f"Input length: {input_len}")
        print(f"Block size: {block_size}")
        print(f"Needle position: {needle_position:.0%}")
        print(f"Needle value: {needle_value}")
        print(f"CPU offload: {enable_cpu_offload}")
        if enable_cpu_offload:
            print(f"Sparse policy: {sparse_policy.name}")
            if sparse_policy == SparsePolicyType.QUEST:
                print(f"  Quest: topk={sparse_topk}, threshold={sparse_threshold}")
            elif sparse_policy == SparsePolicyType.XATTN_BSA:
                print(f"  XAttention BSA: threshold={sparse_threshold}, samples={sparse_samples}")
        print(f"{'='*60}\n")

    # 1. Initialize LLM
    llm_kwargs = {
        "enforce_eager": True,
        "max_model_len": max_model_len,
        "max_num_batched_tokens": max_model_len,
        "enable_cpu_offload": enable_cpu_offload,
        "kvcache_block_size": block_size,
    }
    if enable_cpu_offload:
        llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
        llm_kwargs["sparse_policy"] = sparse_policy
        if sparse_policy == SparsePolicyType.QUEST:
            llm_kwargs["sparse_topk_blocks"] = sparse_topk
            llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
        elif sparse_policy == SparsePolicyType.XATTN_BSA:
            llm_kwargs["sparse_threshold"] = float(sparse_threshold) / 10.0  # Convert to 0.0-1.0 range
            llm_kwargs["sparse_samples_per_chunk"] = sparse_samples

    llm = LLM(model_path, **llm_kwargs)

    # 2. Generate needle prompt
    prompt, expected = generate_needle_prompt(
        tokenizer=llm.tokenizer,
        target_length=input_len,
        needle_position=needle_position,
        needle_value=needle_value,
    )

    # 3. Generate output
    sampling_params = SamplingParams(
        temperature=0.6,  # Moderate temperature
        max_tokens=max_new_tokens,
    )
    outputs = llm.generate([prompt], sampling_params, use_tqdm=True)

    # 4. Check result
    output_text = outputs[0]["text"]
    output_token_ids = outputs[0]["token_ids"]
    passed = check_needle_answer(output_text, expected)

    if verbose:
        print(f"\n{'='*60}")
        print(f"Result")
        print(f"{'='*60}")
        print(f"Expected: {expected}")
        print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
        print(f"Output: {output_text[:200]}...")
        print(f"Status: {'PASSED' if passed else 'FAILED'}")
        print(f"{'='*60}\n")

    return passed

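`generate_needle_prompt` lives in `utils` and its internals are not shown in this diff. As a hypothetical illustration only, placing a needle at a fractional position in filler text can be as simple as splicing it into a list of sentences:

```python
# Hypothetical needle placement sketch (not the real generate_needle_prompt).
filler = [f"Filler sentence {i}." for i in range(10)]
needle = "The secret number is 7492."
pos = 0.5  # needle_position: 0.0 = start, 1.0 = end

idx = int(len(filler) * pos)
haystack = filler[:idx] + [needle] + filler[idx:]
print(haystack[5])  # The secret number is 7492.
```

The real helper additionally tokenizes and pads to `target_length`, so positions are measured in tokens rather than sentences.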
# ============================================================
# CLI Entry Point
# ============================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Needle-in-haystack test for long context LLM")
    parser.add_argument(
        "--model", "-m",
        type=str,
        default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
        help="Path to model"
    )
    parser.add_argument(
        "--max-model-len",
        type=int,
        default=128 * 1024,
        help="Maximum model context length"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=8 * 1024,
        help="Target input sequence length"
    )
    parser.add_argument(
        "--num-gpu-blocks",
        type=int,
        default=2,
        help="Number of GPU blocks for CPU offload"
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=1024,
        help="KV cache block size"
    )
    parser.add_argument(
        "--needle-position",
        type=float,
        default=0.5,
        help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
    )
    parser.add_argument(
        "--needle-value",
        type=str,
        default="7492",
        help="The secret value to hide"
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=32,
        help="Maximum tokens to generate"
    )
    parser.add_argument(
        "--enable-offload",
        action="store_true",
        help="Enable CPU offload (has known bug for long sequences)"
    )
    parser.add_argument(
        "--enable-quest",
        action="store_true",
        help="Enable Quest sparse attention (decode-only Top-K selection)"
    )
    parser.add_argument(
        "--enable-xattn-bsa",
        action="store_true",
        help="Enable XAttention BSA sparse attention (prefill-only)"
    )
    parser.add_argument(
        "--sparse-topk",
        type=int,
        default=8,
        help="Top-K blocks for Quest sparse attention"
    )
    parser.add_argument(
        "--sparse-threshold",
        type=int,
        default=4,
        help="Apply sparse only when blocks > threshold (Quest) or attention threshold 0-9 (XAttention BSA)"
    )
    parser.add_argument(
        "--sparse-samples",
        type=int,
        default=128,
        help="Samples per chunk for XAttention BSA estimation"
    )
    args = parser.parse_args()

    passed = run_needle_test(
        model_path=args.model,
        max_model_len=args.max_model_len,
        input_len=args.input_len,
        num_gpu_blocks=args.num_gpu_blocks,
        block_size=args.block_size,
        needle_position=args.needle_position,
        needle_value=args.needle_value,
        max_new_tokens=args.max_new_tokens,
        enable_cpu_offload=args.enable_offload,
        enable_quest=args.enable_quest,
        enable_xattn_bsa=args.enable_xattn_bsa,
        sparse_topk=args.sparse_topk,
        sparse_threshold=args.sparse_threshold,
        sparse_samples=args.sparse_samples,
        verbose=True,
    )

    if passed:
        print("test_needle: PASSED")
    else:
        print("test_needle: FAILED")
        exit(1)
@@ -1,176 +0,0 @@
"""
Needle-in-a-haystack reference test using pure torch + transformers.

This is a reference implementation for comparison with nanovllm.
Uses standard HuggingFace inference (no custom KV cache, no offload).
"""

import os
import argparse
import torch
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from utils import generate_needle_prompt, check_needle_answer


# ============================================================
# Main Test
# ============================================================

def run_needle_test(
    model_path: str,
    input_len: int,
    needle_position: float = 0.5,
    needle_value: str = "7492",
    max_new_tokens: int = 32,
    dtype: str = "auto",
    verbose: bool = True,
) -> bool:
    """
    Run a needle-in-haystack test using standard transformers inference.

    Args:
        model_path: Path to model
        input_len: Target input sequence length
        needle_position: Where to place needle (0.0-1.0)
        needle_value: The secret value to find
        max_new_tokens: Maximum tokens to generate
        dtype: Model dtype ("auto", "float16", "bfloat16")
        verbose: Print detailed output

    Returns:
        True if test passed, False otherwise
    """
    if verbose:
        print(f"\n{'='*60}")
        print(f"Needle-in-Haystack Reference Test (torch + transformers)")
        print(f"{'='*60}")
        print(f"Model: {model_path}")
        print(f"Input length: {input_len}")
        print(f"Needle position: {needle_position:.0%}")
        print(f"Needle value: {needle_value}")
        print(f"Dtype: {dtype}")
        print(f"{'='*60}\n")

    # 1. Load tokenizer
    print("[1/4] Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # 2. Generate needle prompt
    print("[2/4] Generating needle prompt...")
    prompt, expected = generate_needle_prompt(
        tokenizer=tokenizer,
        target_length=input_len,
        needle_position=needle_position,
        needle_value=needle_value,
    )

    # 3. Load model
    print("[3/4] Loading model...")
    torch_dtype = {
        "auto": torch.float16,  # default to float16 for custom model
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }.get(dtype, torch.float16)

    model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    # 4. Generate output
    print("[4/4] Running inference...")
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    print(f"  Input shape: {input_ids.shape}")

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.6,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the new tokens
    new_token_ids = output_ids[0, input_ids.shape[1]:]
    output_text = tokenizer.decode(new_token_ids, skip_special_tokens=False)

    # 5. Check result
    passed = check_needle_answer(output_text, expected)

    if verbose:
        print(f"\n{'='*60}")
        print(f"Result")
        print(f"{'='*60}")
        print(f"Expected: {expected}")
        print(f"Output tokens ({len(new_token_ids)}): {new_token_ids[:20].tolist()}")
        print(f"Output: {output_text[:200]}...")
        print(f"Status: {'PASSED' if passed else 'FAILED'}")
        print(f"{'='*60}\n")

    return passed

# ============================================================
# CLI Entry Point
# ============================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Needle-in-haystack reference test (torch + transformers)"
    )
    parser.add_argument(
        "--model", "-m",
        type=str,
        default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
        help="Path to model"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=8 * 1024,
        help="Target input sequence length"
    )
    parser.add_argument(
        "--needle-position",
        type=float,
        default=0.5,
        help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
    )
    parser.add_argument(
        "--needle-value",
        type=str,
        default="7492",
        help="The secret value to hide"
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=32,
        help="Maximum tokens to generate"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "bfloat16"],
        help="Model dtype"
    )
    args = parser.parse_args()

    passed = run_needle_test(
        model_path=args.model,
        input_len=args.input_len,
        needle_position=args.needle_position,
        needle_value=args.needle_value,
        max_new_tokens=args.max_new_tokens,
        dtype=args.dtype,
        verbose=True,
    )

    if passed:
        print("test_needle_ref: PASSED")
    else:
        print("test_needle_ref: FAILED")
        exit(1)
@@ -1,136 +0,0 @@
"""
Test for QuestPolicy block selection with GQA (Grouped Query Attention).

Demonstrates the key limitation: scores are AVERAGED across heads,
so blocks strongly needed by one head but not others may be dropped.

This is the expected Quest behavior - not a bug.
"""

import torch
from nanovllm.kvcache.sparse import (
    create_sparse_policy,
    SparsePolicyType,
    PolicyContext,
)

# ============================================================
# Test: Per-Head Score Averaging in GQA
# ============================================================

# Determine device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running test on device: {device}")

# Setup: 2 KV heads, 4 query heads (GQA group_size=2)
# topk=2 to make selection competitive

quest = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=2, threshold_blocks=0)
quest.initialize(
    num_layers=1,
    num_kv_heads=2,
    head_dim=4,
    num_cpu_blocks=6,
    dtype=torch.float32,
    device=device,  # Metadata stored on GPU
)

metadata = quest.metadata


def set_key(block_id, head_id, values):
    """Set both key_min and key_max to the same values for deterministic scoring."""
    # Values need to be on the same device as metadata
    tensor = torch.tensor(values, device=device)
    metadata.key_min[block_id, 0, head_id, :] = tensor
    metadata.key_max[block_id, 0, head_id, :] = tensor

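Quest scores a block from its per-channel key min/max by taking, per channel, the larger of `q*kmin` and `q*kmax` as an upper bound on the query-key dot product. A pure-Python sketch of that bound (`quest_score` is a hypothetical standalone helper, not the policy's kernel); with `set_key` making min == max, it reduces to `sum(q * key)`:

```python
def quest_score(q, kmin, kmax):
    # Per-channel upper bound on q·k for any key inside [kmin, kmax].
    return sum(max(qi * lo, qi * hi) for qi, lo, hi in zip(q, kmin, kmax))

q = [1.0, 1.0, 1.0, 1.0]
print(quest_score(q, [1.0] * 4, [1.0] * 4))   # 4.0 (min == max → sum(key))
print(quest_score(q, [-1.0] * 4, [1.0] * 4))  # 4.0 (bound picks the max side)
```

This is why the comment table below can read scores directly off the key values: with an all-ones query and degenerate min/max ranges, a block's score is just the sum of its key channels.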
# ============================================================
# Design: Different heads want different blocks
# ============================================================
#
# Query = [1,1,1,1] for all heads, so score = sum(key values)
#
# Block | Head 0 | Head 1 | Average | Result
# ------|--------|--------|---------|--------
#   0   |   +4   |   -4   |    0    | Head0 wants, Head1 doesn't → DROPPED
#   1   |   -4   |   +4   |    0    | Head1 wants, Head0 doesn't → DROPPED
#   2   |   +4   |   +4   |   +4    | Both want → SELECTED (rank 1)
#   3   |   +3   |   +3   |   +3    | Both want → SELECTED (rank 2)
#   4   |   +4   |    0   |   +2    | Head0 strongly wants, Head1 neutral → rank 3
#   5   |    0   |   +4   |   +2    | Head1 strongly wants, Head0 neutral → rank 3

# Block 0: Head 0 strongly wants, Head 1 strongly rejects
set_key(0, 0, [1.0, 1.0, 1.0, 1.0])      # head0: +4
set_key(0, 1, [-1.0, -1.0, -1.0, -1.0])  # head1: -4

# Block 1: Head 1 strongly wants, Head 0 strongly rejects
set_key(1, 0, [-1.0, -1.0, -1.0, -1.0])  # head0: -4
set_key(1, 1, [1.0, 1.0, 1.0, 1.0])      # head1: +4

# Block 2: Both heads want equally (highest average)
set_key(2, 0, [1.0, 1.0, 1.0, 1.0])  # head0: +4
set_key(2, 1, [1.0, 1.0, 1.0, 1.0])  # head1: +4

# Block 3: Both heads want moderately
set_key(3, 0, [0.75, 0.75, 0.75, 0.75])  # head0: +3
set_key(3, 1, [0.75, 0.75, 0.75, 0.75])  # head1: +3

# Block 4: Head 0 strongly wants, Head 1 neutral
set_key(4, 0, [1.0, 1.0, 1.0, 1.0])  # head0: +4
set_key(4, 1, [0.0, 0.0, 0.0, 0.0])  # head1: 0

# Block 5: Head 1 strongly wants, Head 0 neutral
set_key(5, 0, [0.0, 0.0, 0.0, 0.0])  # head0: 0
set_key(5, 1, [1.0, 1.0, 1.0, 1.0])  # head1: +4

# ============================================================
|
|
||||||
# Run selection
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
# Query on same device as metadata
|
|
||||||
query = torch.ones(1, 4, 4, device=device) # GQA: 4 query heads → 2 KV heads
|
|
||||||
|
|
||||||
ctx = PolicyContext(
|
|
||||||
query_chunk_idx=0,
|
|
||||||
num_query_chunks=1,
|
|
||||||
layer_id=0,
|
|
||||||
query=query,
|
|
||||||
is_prefill=False,
|
|
||||||
block_size=1024,
|
|
||||||
total_kv_len=6144,
|
|
||||||
)
|
|
||||||
|
|
||||||
available = list(range(6))
|
|
||||||
selected = quest.select_blocks(available, ctx)
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Verify: Averaging behavior
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
# topk=2, so only blocks 2 (+4 avg) and 3 (+3 avg) should be selected
|
|
||||||
assert len(selected) == 2, f"Expected 2 blocks, got {len(selected)}"
|
|
||||||
assert selected == [2, 3], f"Expected [2, 3], got {selected}"
|
|
||||||
|
|
||||||
# Key insight: blocks 0 and 1 have score +4 for ONE head,
|
|
||||||
# but they cancel out due to averaging with the other head's -4
|
|
||||||
assert 0 not in selected, "Block 0 should NOT be selected (head scores cancel out)"
|
|
||||||
assert 1 not in selected, "Block 1 should NOT be selected (head scores cancel out)"
|
|
||||||
|
|
||||||
# Blocks 4 and 5 have +4 for one head, 0 for other → avg=+2
|
|
||||||
# But +2 < +3 (block 3), so they don't make the top-2
|
|
||||||
assert 4 not in selected, "Block 4 avg=+2 < block 3 avg=+3"
|
|
||||||
assert 5 not in selected, "Block 5 avg=+2 < block 3 avg=+3"
|
|
||||||
|
|
||||||
print("✓ Block 2 selected: both heads want it (+4, +4) → avg=+4")
|
|
||||||
print("✓ Block 3 selected: both heads want it (+3, +3) → avg=+3")
|
|
||||||
print("✓ Block 0 NOT selected: head0=+4, head1=-4 → avg=0 (cancel out)")
|
|
||||||
print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
|
|
||||||
print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
|
|
||||||
print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")
|
|
||||||
|
|
||||||
# Verify metadata is on correct device
|
|
||||||
assert metadata.key_min.device.type == device.type, f"key_min on wrong device: {metadata.key_min.device}"
|
|
||||||
assert metadata.key_max.device.type == device.type, f"key_max on wrong device: {metadata.key_max.device}"
|
|
||||||
print(f"✓ Metadata stored on {device.type.upper()}")
|
|
||||||
|
|
||||||
print("\ntest_quest_policy: PASSED")
|
|
||||||
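The averaging behavior the test above verifies can be shown in isolation. The following is a hypothetical, dependency-free sketch (not the nanovllm API): each block keeps elementwise key min/max bounds per KV head, a block's score for one head is the upper bound on q·k achievable within those bounds (Quest-style), and all heads share one block list by averaging their per-head scores. The names `block_score` and `select_blocks` are illustrative.

```python
def block_score(query, key_min, key_max):
    """Upper bound on q·k over any key lying inside [key_min, key_max]."""
    return sum(q * (kmax if q > 0 else kmin)
               for q, kmin, kmax in zip(query, key_min, key_max))

def select_blocks(queries, bounds, topk):
    """queries: one query vector per KV head; bounds: per block, per head (min, max)."""
    scores = []
    for block_id, per_head in enumerate(bounds):
        per_head_scores = [block_score(q, mn, mx)
                           for q, (mn, mx) in zip(queries, per_head)]
        # Heads share one block list: average their scores
        scores.append((sum(per_head_scores) / len(per_head_scores), block_id))
    scores.sort(key=lambda s: (-s[0], s[1]))  # highest average wins
    return sorted(block_id for _, block_id in scores[:topk])

# Mirror the test's design: per-head scores (+4,-4), (-4,+4), (+4,+4), (+3,+3), (+4,0), (0,+4)
q = [1.0, 1.0, 1.0, 1.0]
mk = lambda v: ([v] * 4, [v] * 4)  # key_min == key_max == v for determinism
bounds = [
    [mk(1.0), mk(-1.0)],   # block 0: avg 0 → dropped
    [mk(-1.0), mk(1.0)],   # block 1: avg 0 → dropped
    [mk(1.0), mk(1.0)],    # block 2: avg +4 → selected
    [mk(0.75), mk(0.75)],  # block 3: avg +3 → selected
    [mk(1.0), mk(0.0)],    # block 4: avg +2 → below top-2
    [mk(0.0), mk(1.0)],    # block 5: avg +2 → below top-2
]
print(select_blocks([q, q], bounds, topk=2))  # → [2, 3]
```

Because the scores are averaged before ranking, a block that one head strongly wants and another strongly rejects (blocks 0 and 1) loses to a block both heads mildly agree on (block 3), which is exactly the trade-off the test pins down.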
@@ -41,6 +41,7 @@ from pathlib import Path
 from typing import List, Dict, Tuple, Optional
 
 from nanovllm import LLM, SamplingParams
+from nanovllm.utils.density_observer import DensityObserver
 
 
 # ============================================================
@@ -381,6 +382,16 @@ def run_ruler_benchmark(
     print(f"Fresh LLM mode: {fresh_llm}")
     print(f"{'='*60}")
 
+    # Enable DensityObserver for XAttention BSA
+    if sparse_policy and sparse_policy.upper() == "XATTN_BSA":
+        DensityObserver.enable()
+        DensityObserver.complete_reset()
+        # Set mode for correct density interpretation
+        DensityObserver.set_mode("offload" if enable_cpu_offload else "gpu_only")
+        if not json_output:
+            mode_str = "offload" if enable_cpu_offload else "gpu_only"
+            print(f"[DensityObserver] Enabled for XAttention BSA (mode: {mode_str})")
+
     # LLM initialization kwargs
     llm_kwargs = {
         "max_model_len": max_model_len,
@@ -471,6 +482,14 @@ def run_ruler_benchmark(
     print(f"{'-'*54}")
     print(f"{'TOTAL':<20} {total_correct}/{total_samples:<7} {overall_accuracy*100:>6.1f}% {avg_score:.3f}")
     print(f"\nTime: {total_time:.1f}s")
 
+    # Print DensityObserver summary if enabled
+    if sparse_policy and sparse_policy.upper() == "XATTN_BSA" and DensityObserver.is_enabled():
+        print(f"\n{'='*60}")
+        print("Density Statistics (XAttention BSA)")
+        print(f"{'='*60}")
+        DensityObserver.print_summary()
+
     print(f"{'='*60}\n")
 
     results = {
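The hunks above only show the call sites of `DensityObserver` (`enable`, `complete_reset`, `set_mode`, `is_enabled`, `print_summary`); its implementation is not part of this diff. A minimal sketch of the class-level observer pattern those calls imply might look like the following — the class name `Observer`, the `record` method, and the summary format are assumptions, not the real nanovllm API.

```python
class Observer:
    """Process-wide observer with class-level state, toggled on/off globally."""
    _enabled = False
    _mode = "gpu_only"
    _densities = []

    @classmethod
    def enable(cls):
        cls._enabled = True

    @classmethod
    def is_enabled(cls):
        return cls._enabled

    @classmethod
    def complete_reset(cls):
        cls._densities = []

    @classmethod
    def set_mode(cls, mode):
        # "offload" vs "gpu_only" changes how recorded densities are interpreted
        cls._mode = mode

    @classmethod
    def record(cls, density):
        if cls._enabled:
            cls._densities.append(density)

    @classmethod
    def print_summary(cls):
        if cls._densities:
            avg = sum(cls._densities) / len(cls._densities)
            print(f"mode={cls._mode} avg_density={avg:.2%} n={len(cls._densities)}")

Observer.enable()
Observer.set_mode("offload")
Observer.record(0.25)
Observer.record(0.35)
Observer.print_summary()  # prints: mode=offload avg_density=30.00% n=2
```

Class-level state is what lets the benchmark driver enable the observer once while instrumented attention layers record into it without threading a handle through every call.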
@@ -1,199 +0,0 @@
"""
Sequential inference test for LLM.

Tests: After completing one prompt, the system can correctly handle
subsequent prompts with a clean state (the previous prompt's KV cache deallocated).
"""

import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"

import argparse
from nanovllm import LLM, SamplingParams
from utils import generate_needle_prompt, check_needle_answer


def run_sequential_test(
    model_path: str,
    max_model_len: int,
    input_len: int,
    num_gpu_blocks: int = 4,
    block_size: int = 1024,
    enable_cpu_offload: bool = False,
    verbose: bool = True,
) -> bool:
    """
    Run sequential inference tests with three different prompts.

    Each prompt has a different needle value. All must be retrieved correctly.
    """
    if verbose:
        print(f"\n{'='*60}")
        print(f"Sequential Inference Test")
        print(f"{'='*60}")
        print(f"Model: {model_path}")
        print(f"Max model len: {max_model_len}")
        print(f"Input length: {input_len}")
        print(f"Block size: {block_size}")
        print(f"CPU offload: {enable_cpu_offload}")
        print(f"{'='*60}\n")

    # Initialize LLM once
    llm_kwargs = {
        "enforce_eager": True,
        "max_model_len": max_model_len,
        "max_num_batched_tokens": max_model_len,
        "enable_cpu_offload": enable_cpu_offload,
        "kvcache_block_size": block_size,
    }
    if enable_cpu_offload:
        llm_kwargs["num_gpu_blocks"] = num_gpu_blocks

    llm = LLM(model_path, **llm_kwargs)

    sampling_params = SamplingParams(
        temperature=0.6,
        max_tokens=32,
    )

    # ============================================================
    # Test 1: First prompt with needle value "1234"
    # ============================================================
    needle_value_1 = "1234"
    if verbose:
        print(f"\n[Test 1] Generating prompt with needle value: {needle_value_1}")

    prompt_1, expected_1 = generate_needle_prompt(
        tokenizer=llm.tokenizer,
        target_length=input_len,
        needle_position=0.5,
        needle_value=needle_value_1,
    )

    outputs_1 = llm.generate([prompt_1], sampling_params, use_tqdm=True)
    output_text_1 = outputs_1[0]["text"]
    passed_1 = check_needle_answer(output_text_1, expected_1)

    if verbose:
        print(f"  Expected: {expected_1}")
        print(f"  Output: {output_text_1[:100]}...")
        print(f"  Status: {'PASSED' if passed_1 else 'FAILED'}")

    # ============================================================
    # Test 2: Second prompt with needle value "5678"
    # ============================================================
    needle_value_2 = "5678"
    if verbose:
        print(f"\n[Test 2] Generating prompt with needle value: {needle_value_2}")

    prompt_2, expected_2 = generate_needle_prompt(
        tokenizer=llm.tokenizer,
        target_length=input_len,
        needle_position=0.5,
        needle_value=needle_value_2,
    )

    outputs_2 = llm.generate([prompt_2], sampling_params, use_tqdm=True)
    output_text_2 = outputs_2[0]["text"]
    passed_2 = check_needle_answer(output_text_2, expected_2)

    if verbose:
        print(f"  Expected: {expected_2}")
        print(f"  Output: {output_text_2[:100]}...")
        print(f"  Status: {'PASSED' if passed_2 else 'FAILED'}")

    # ============================================================
    # Test 3: Third prompt with a fresh needle value to ensure no cross-contamination
    # ============================================================
    needle_value_3 = "9999"
    if verbose:
        print(f"\n[Test 3] Generating prompt with needle value: {needle_value_3}")

    prompt_3, expected_3 = generate_needle_prompt(
        tokenizer=llm.tokenizer,
        target_length=input_len,
        needle_position=0.5,
        needle_value=needle_value_3,
    )

    outputs_3 = llm.generate([prompt_3], sampling_params, use_tqdm=True)
    output_text_3 = outputs_3[0]["text"]
    passed_3 = check_needle_answer(output_text_3, expected_3)

    if verbose:
        print(f"  Expected: {expected_3}")
        print(f"  Output: {output_text_3[:100]}...")
        print(f"  Status: {'PASSED' if passed_3 else 'FAILED'}")

    # ============================================================
    # Summary
    # ============================================================
    all_passed = passed_1 and passed_2 and passed_3

    if verbose:
        print(f"\n{'='*60}")
        print(f"Summary")
        print(f"{'='*60}")
        print(f"Test 1 (needle={needle_value_1}): {'PASSED' if passed_1 else 'FAILED'}")
        print(f"Test 2 (needle={needle_value_2}): {'PASSED' if passed_2 else 'FAILED'}")
        print(f"Test 3 (needle={needle_value_3}): {'PASSED' if passed_3 else 'FAILED'}")
        print(f"Overall: {'PASSED' if all_passed else 'FAILED'}")
        print(f"{'='*60}\n")

    return all_passed


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Sequential inference test")
    parser.add_argument(
        "--model", "-m",
        type=str,
        default=os.path.expanduser("~/models/Qwen3-0.6B/"),
        help="Path to model"
    )
    parser.add_argument(
        "--max-model-len",
        type=int,
        default=36 * 1024,
        help="Maximum model context length"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=8 * 1024,
        help="Target input sequence length"
    )
    parser.add_argument(
        "--num-gpu-blocks",
        type=int,
        default=2,
        help="Number of GPU blocks for CPU offload"
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=1024,
        help="KV cache block size"
    )
    parser.add_argument(
        "--enable-offload",
        action="store_true",
        help="Enable CPU offload"
    )
    args = parser.parse_args()

    passed = run_sequential_test(
        model_path=args.model,
        max_model_len=args.max_model_len,
        input_len=args.input_len,
        num_gpu_blocks=args.num_gpu_blocks,
        block_size=args.block_size,
        enable_cpu_offload=args.enable_offload,
        verbose=True,
    )

    if passed:
        print("test_sequential: PASSED")
    else:
        print("test_sequential: FAILED")
        exit(1)
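The deleted test above imports `generate_needle_prompt` and `check_needle_answer` from a local `utils` module that is not shown in this diff. A hedged sketch of what such needle-in-a-haystack helpers typically do — the real signatures take a tokenizer and target token length, which this simplified version replaces with a list of filler sentences — might look like:

```python
def generate_needle_prompt_sketch(filler_sentences, needle_value, needle_position=0.5):
    """Build a long prompt with a 'needle' fact buried at a relative position.

    Returns (prompt, expected_answer). Simplified: length is controlled by the
    number of filler sentences rather than by a tokenizer token count.
    """
    needle = f"The special magic number is {needle_value}."
    insert_at = int(len(filler_sentences) * needle_position)
    parts = filler_sentences[:insert_at] + [needle] + filler_sentences[insert_at:]
    question = "What is the special magic number?"
    return " ".join(parts) + f"\n{question}", needle_value

def check_needle_answer_sketch(output_text, expected):
    """Pass if the expected needle value appears anywhere in the output."""
    return expected in output_text

filler = ["The sky is blue."] * 10
prompt, expected = generate_needle_prompt_sketch(filler, "1234")
```

Running the same retrieval check with a different needle per prompt is what lets the sequential test distinguish a stale KV cache (old needle leaking into the new answer) from a clean reset.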
@@ -1,334 +0,0 @@
"""
Test XAttention + BSA with RULER benchmark data.

Tests XAttention sparse attention correctness using RULER NIAH task.

Attention methods:
- Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
- Decode: FlashAttention (always, since q_len=1)

Usage (in compass conda env with BSA available):
    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
        python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct

    # Test with XAttention + BSA for prefill (default)
    python tests/test_xattn_bsa.py --prefill-method xattn

    # Test with FlashAttention for prefill (baseline)
    python tests/test_xattn_bsa.py --prefill-method flash

    # Test specific sample(s)
    python tests/test_xattn_bsa.py --sample-id 0
    python tests/test_xattn_bsa.py --sample-ids 0,1,2

Note: Compatible with transformers 4.53+ (handles both old `past_key_value`
and new `past_key_values` API).
"""

import argparse
import json
import os
import sys
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

from nanovllm.ops.xattn import xattn_estimate
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


# ============================================================
# XAttention + BSA Functions
# ============================================================

def expand_kv_for_gqa(key_states, value_states, num_heads):
    """Expand KV for Grouped Query Attention."""
    num_kv_heads = key_states.shape[1]
    if num_heads == num_kv_heads:
        return key_states, value_states
    num_groups = num_heads // num_kv_heads
    return key_states.repeat_interleave(num_groups, dim=1), value_states.repeat_interleave(num_groups, dim=1)


def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
    """Standard FlashAttention."""
    from flash_attn import flash_attn_func
    q = query_states.transpose(1, 2)
    k = key_states.transpose(1, 2)
    v = value_states.transpose(1, 2)
    return flash_attn_func(q, k, v, causal=is_causal).transpose(1, 2)


def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
    """XAttention + BSA sparse attention."""
    from block_sparse_attn import block_sparse_attn_func

    batch_size, num_heads, q_len, head_dim = query_states.shape
    k_len = key_states.shape[2]

    _, mask = xattn_estimate(
        query_states, key_states,
        chunk_size=16384, block_size=128, threshold=threshold,
        use_triton=True, causal=True,
    )

    q_block_num = (q_len + 127) // 128
    k_block_num = (k_len + 127) // 128

    q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
    k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
    v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)

    output = block_sparse_attn_func(
        q, k, v,
        torch.tensor([0, q_len], dtype=torch.int32, device=q.device),
        torch.tensor([0, k_len], dtype=torch.int32, device=k.device),
        torch.ones(num_heads, dtype=torch.int32, device=q.device),
        None,
        mask[:, :, :q_block_num, :k_block_num].contiguous(),
        q_len, k_len,
        p_dropout=0.0, deterministic=True, is_causal=True,
    )
    return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)


DEBUG = False  # Set to True to enable debugging

def create_patched_forward(prefill_method="xattn", threshold=0.9):
    """Create patched forward with configurable prefill method.

    Args:
        prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense)
        threshold: XAttention threshold for block selection (only used when prefill_method="xattn")

    Note:
        - Prefill (q_len > 1): Uses specified prefill_method
        - Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query)
    """
    call_count = [0]  # Mutable to track calls across layers

    def patched_forward(
        self,
        hidden_states,
        position_embeddings=None,
        attention_mask=None,
        past_key_value=None,   # Old API (transformers < 4.57)
        past_key_values=None,  # New API (transformers >= 4.57)
        cache_position=None,
        **kwargs
    ):
        # Handle both old and new transformers API
        kv_cache = past_key_values if past_key_values is not None else past_key_value

        bsz, q_len, _ = hidden_states.size()
        num_heads = self.config.num_attention_heads
        num_kv_heads = self.config.num_key_value_heads
        head_dim = self.head_dim

        # Compute Q, K, V projections
        query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)

        # Apply rotary position embedding
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Handle KV cache
        if kv_cache is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = kv_cache.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Expand KV for GQA
        key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)

        # Debug output
        if DEBUG and self.layer_idx == 0:
            call_count[0] += 1
            if call_count[0] <= 5:
                phase = "prefill" if q_len > 1 else "decode"
                print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
                print(f"  kv_cache is None: {kv_cache is None}")

        # Choose attention method:
        # - Prefill (q_len > 1): Use prefill_method (xattn or flash)
        # - Decode (q_len = 1): Always use FlashAttention
        is_prefill = q_len > 1

        if is_prefill and prefill_method == "xattn":
            # Prefill with XAttention + BSA (sparse)
            attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
        else:
            # Prefill with FlashAttention (dense) OR Decode (always FlashAttention)
            # Note: For decode (q_len=1), causal=False since single query attends to all KV
            attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)

        attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
        return attn_output, None

    return patched_forward


# ============================================================
# Data & Evaluation
# ============================================================

def load_samples(filepath, indices=None):
    """Load samples from JSONL file."""
    samples = []
    with open(filepath) as f:
        for i, line in enumerate(f):
            if indices is None or i in indices:
                sample = json.loads(line)
                sample["_idx"] = i
                samples.append(sample)
    return samples


def string_match_all(output_text, expected_list):
    """RULER metric: fraction of expected values found in output."""
    output_lower = output_text.lower().replace('\n', ' ')
    if not expected_list:
        return 1.0
    return sum(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list) / len(expected_list)


# ============================================================
# Test
# ============================================================

def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
    """Test attention methods using RULER data.

    Args:
        prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
    """
    prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"

    print("=" * 60)
    print("RULER NIAH Attention Test")
    print("=" * 60)
    print(f"Data: {data_file}")
    print(f"Samples: {sample_ids}")
    print(f"Prefill method: {prefill_desc}")
    print(f"Decode method: FlashAttention (always)")
    if prefill_method == "xattn":
        print(f"XAttention threshold: {threshold}")

    samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
    if not samples:
        print("No samples found!")
        return False
    print(f"Loaded {len(samples)} samples")

    # Load model
    print(f"\nLoading model: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="cuda",
        attn_implementation="eager",  # Will be patched
    )
    model.eval()

    # Patch all layers
    print(f"Patching attention layers...")
    print(f"  - Prefill: {prefill_desc}")
    print(f"  - Decode: FlashAttention")
    for idx, layer in enumerate(model.model.layers):
        layer.self_attn.layer_idx = idx  # Ensure layer_idx is set
        layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
            layer.self_attn, type(layer.self_attn)
        )

    total_score = 0.0
    results = []

    for sample in samples:
        idx = sample["_idx"]
        prompt = sample["input"]
        expected = sample["outputs"]

        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        num_tokens = inputs["input_ids"].shape[1]
        print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
        print(f"Expected: {expected}")

        with torch.no_grad():
            output = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
        score = string_match_all(output_text, expected)
        total_score += score

        status = "✓ PASS" if score >= 0.5 else "✗ FAIL"
        print(f"Output: '{output_text[:100]}...'")
        print(f"Result: {status} (score={score:.2f})")
        results.append({"idx": idx, "score": score, "passed": score >= 0.5})

    avg_score = total_score / len(samples)
    passed = sum(1 for r in results if r["passed"])

    print(f"\n{'='*60}")
    print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
    print(f"{'='*60}")

    return avg_score >= 0.5


def main():
    parser = argparse.ArgumentParser(
        description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark"
    )
    parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
    parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
    parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index")
    parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
    parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
                        help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
    parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
    parser.add_argument("--max-new-tokens", type=int, default=50)
    # Keep old option for backwards compatibility
    parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
    args = parser.parse_args()

    model_path = os.path.expanduser(args.model)

    # Handle deprecated --no-xattn option
    prefill_method = args.prefill_method
    if args.no_xattn:
        prefill_method = "flash"
        print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")

    if args.sample_id is not None:
        sample_ids = [args.sample_id]
    elif args.sample_ids:
        sample_ids = [int(x) for x in args.sample_ids.split(",")]
    else:
        sample_ids = [0]

    # Check BSA availability if using xattn
    if prefill_method == "xattn":
        try:
            from block_sparse_attn import block_sparse_attn_func
            print("✓ BSA (Block Sparse Attention) available")
        except ImportError:
            print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
            sys.exit(1)

    if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
        print("\ntest_xattn_bsa: PASSED")
    else:
        print("\ntest_xattn_bsa: FAILED")
        sys.exit(1)


if __name__ == "__main__":
    main()
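`xattn_bsa_forward` above derives its block counts with `(len + 127) // 128` before slicing the estimated mask. A small self-contained sketch of that 128-token block bookkeeping follows; the causal block-mask construction is illustrative (the exact mask layout `block_sparse_attn_func` expects is assumed, not documented here), but the ceil-division and the lower-triangular shape match the code above.

```python
def ceil_div(a, b):
    """Number of fixed-size blocks needed to cover a tokens."""
    return (a + b - 1) // b

def causal_block_mask(q_len, k_len, block=128):
    """True where a query block may attend to a key block under causality.

    When k_len > q_len (queries appended after existing context), the extra
    key blocks on the left are always visible.
    """
    qb, kb = ceil_div(q_len, block), ceil_div(k_len, block)
    offset = kb - qb  # key blocks of prior context, fully visible
    return [[j <= i + offset for j in range(kb)] for i in range(qb)]

# 300 tokens → ceil(300/128) = 3 blocks; self-attention gives a
# lower-triangular 3×3 block mask.
mask = causal_block_mask(q_len=300, k_len=300)
```

Keeping sparsity decisions at block granularity is what makes the mask small enough to enumerate: a 32K-token prompt needs only a 256×256 block mask per head rather than a 32K×32K token mask.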
@@ -1,259 +0,0 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked
Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation.

Uses real QKV data captured from model inference.
"""

import sys
import os
import torch
import warnings

from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked

# ============================================================
# Configuration
# ============================================================

BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096

# Default QKV data directory (relative to project root)
DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache")

# ============================================================
# Utility Functions
# ============================================================

def load_qkv(path):
    """Load saved QKV data."""
    data = torch.load(path, map_location="cpu", weights_only=False)
    print(f"Loaded: {path}")
    print(f"  Query shape: {data['query'].shape}")
    print(f"  Key shape: {data['key'].shape}")
    print(f"  Layer: {data['layer_id']}, Density: {data['density']:.2%}")
    return data


def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Compare two masks and report differences."""
    if mask1.shape != mask2.shape:
        print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False

    diff = (mask1 != mask2).sum().item()
    total = mask1.numel()
    match_rate = (total - diff) / total * 100

    print(f"  Match rate: {match_rate:.4f}% ({total - diff}/{total})")

    if diff > 0:
        diff_indices = torch.where(mask1 != mask2)
        print(f"  First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")

    return diff == 0


def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size):
    """
    Run xattn_estimate_chunked with EXTERNAL chunking.
    This simulates how chunked prefill should be used in practice.
    """
    batch_size, num_heads, q_len, head_dim = query.shape
    _, _, k_len, _ = key.shape

    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size

    # If Q fits in one chunk, call directly
    if q_len <= chunk_size:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return xattn_estimate_chunked(
                query, key,
                q_start_pos=q_start_pos,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )

    # External chunking: split Q and call for each chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f"  External chunking: {num_q_chunks} chunks")

    combined_attn_sum = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=query.dtype, device=query.device
    )
    combined_mask = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=torch.bool, device=query.device
    )

    q_block_offset = 0
    for q_chunk_idx in range(num_q_chunks):
        q_chunk_start = q_chunk_idx * chunk_size
        q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)

        q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]

        # For causal attention, K accumulates up to current Q position
        k_end = q_start_pos + q_chunk_end
        k_chunk = key[:, :, :k_end, :]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
                q_chunk, k_chunk,
                q_start_pos=q_start_pos + q_chunk_start,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )

        # Place chunk results into combined output
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
|
|
||||||
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
|
|
||||||
q_block_offset += chunk_q_blocks
|
|
||||||
|
|
||||||
return combined_attn_sum, combined_mask
|
|
||||||
|
|
||||||
|
|
||||||
def test_single_qkv(qkv_path):
|
|
||||||
"""Test a single QKV file."""
|
|
||||||
data = load_qkv(qkv_path)
|
|
||||||
query = data["query"].cuda().to(torch.bfloat16)
|
|
||||||
key = data["key"].cuda().to(torch.bfloat16)
|
|
||||||
|
|
||||||
seq_len = query.shape[2]
|
|
||||||
print(f"\nTesting with seq_len={seq_len}")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Run standard xattn_estimate
|
|
||||||
print("[1] Running standard xattn_estimate...")
|
|
||||||
try:
|
|
||||||
attn_sum_std, mask_std = xattn_estimate(
|
|
||||||
query, key,
|
|
||||||
block_size=BLOCK_SIZE,
|
|
||||||
stride=STRIDE,
|
|
||||||
threshold=THRESHOLD,
|
|
||||||
chunk_size=CHUNK_SIZE,
|
|
||||||
use_triton=True,
|
|
||||||
)
|
|
||||||
print(f" mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ERROR: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Run chunked xattn_estimate with EXTERNAL chunking
|
|
||||||
print("[2] Running chunked xattn_estimate (external chunking)...")
|
|
||||||
try:
|
|
||||||
attn_sum_chunked, mask_chunked = run_chunked_externally(
|
|
||||||
query, key,
|
|
||||||
q_start_pos=0,
|
|
||||||
block_size=BLOCK_SIZE,
|
|
||||||
stride=STRIDE,
|
|
||||||
threshold=THRESHOLD,
|
|
||||||
chunk_size=CHUNK_SIZE,
|
|
||||||
)
|
|
||||||
print(f" mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ERROR: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Compare results
|
|
||||||
print("[3] Comparing results...")
|
|
||||||
chunked_q_blocks = mask_chunked.shape[2]
|
|
||||||
chunked_k_blocks = mask_chunked.shape[3]
|
|
||||||
|
|
||||||
# Extract comparable region from standard mask
|
|
||||||
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
|
||||||
|
|
||||||
# Compare masks
|
|
||||||
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
|
|
||||||
|
|
||||||
# Compare attn_sums
|
|
||||||
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
|
||||||
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
|
|
||||||
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
|
|
||||||
print(f" Attn sum max diff: {attn_diff:.6f}")
|
|
||||||
else:
|
|
||||||
print(f" Attn sum shape mismatch")
|
|
||||||
|
|
||||||
# Clean up GPU memory
|
|
||||||
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return masks_match
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Main Test
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import argparse
|
|
||||||
parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked")
|
|
||||||
parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR,
|
|
||||||
help="Directory containing QKV files")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# QKV files to test
|
|
||||||
qkv_files = [
|
|
||||||
os.path.join(args.qkv_dir, "qkv_3688.pt"), # ~4K
|
|
||||||
os.path.join(args.qkv_dir, "qkv_7888.pt"), # ~8K
|
|
||||||
os.path.join(args.qkv_dir, "qkv_15685.pt"), # ~16K
|
|
||||||
os.path.join(args.qkv_dir, "qkv_32485.pt"), # ~32K
|
|
||||||
os.path.join(args.qkv_dir, "qkv_64891.pt"), # ~64K
|
|
||||||
]
|
|
||||||
|
|
||||||
available_files = [p for p in qkv_files if os.path.exists(p)]
|
|
||||||
|
|
||||||
if not available_files:
|
|
||||||
print(f"No QKV file found in {args.qkv_dir}.")
|
|
||||||
print(f"Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
print(f"Found {len(available_files)} QKV files to test")
|
|
||||||
print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})")
|
|
||||||
print(f"Using Triton kernels")
|
|
||||||
|
|
||||||
all_passed = True
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for qkv_path in available_files:
|
|
||||||
passed = test_single_qkv(qkv_path)
|
|
||||||
seq_len = int(os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", ""))
|
|
||||||
results.append((seq_len, passed))
|
|
||||||
if not passed:
|
|
||||||
all_passed = False
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("SUMMARY")
|
|
||||||
print("=" * 60)
|
|
||||||
for seq_len, passed in results:
|
|
||||||
status = "PASSED" if passed else "FAILED"
|
|
||||||
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
|
|
||||||
print(f" seq_len={seq_len} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}")
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
if all_passed:
|
|
||||||
print("test_xattn_chunked: PASSED")
|
|
||||||
sys.exit(0)
|
|
||||||
else:
|
|
||||||
print("test_xattn_chunked: FAILED")
|
|
||||||
sys.exit(1)
|
|
||||||
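The `q_block_offset` bookkeeping in `run_chunked_externally` can be sanity-checked with a small pure-Python model (a hypothetical sketch, not part of the test suite). It assumes `chunk_size` is a multiple of `block_size`, which the configuration used here (4096 and 64) satisfies:

```python
def ceil_div(a, b):
    return (a + b - 1) // b

def chunk_block_layout(q_len, chunk_size, block_size):
    # Mirror the q_block_offset accumulation: each Q chunk contributes
    # ceil(chunk_len / block_size) block rows at a running offset.
    offsets = []
    off = 0
    for start in range(0, q_len, chunk_size):
        end = min(start + chunk_size, q_len)
        blocks = ceil_div(end - start, block_size)
        offsets.append((off, blocks))
        off += blocks
    return offsets, off

# When chunk_size is a multiple of block_size, the per-chunk block rows
# tile the full ceil(q_len / block_size) rows with no gap or overlap.
offsets, total = chunk_block_layout(q_len=10000, chunk_size=4096, block_size=64)
assert total == ceil_div(10000, 64)
```

This is why the chunked mask can be written into `combined_mask` by simple offset slicing: the chunk boundaries never split a block row.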
365 tests/test_xattn_estimate_alignment.py Normal file
@@ -0,0 +1,365 @@
"""
Test: verify consistency between xattn_estimate and the KV chunking kernels.

Using real KV cache data, compare:
1. xattn_estimate (high-level API)
2. Three-stage KV chunking (softmax_compute_partial_stats + merge + normalize)

Three-stage KV chunking flow:
1. softmax_compute_partial_stats: compute (m, l) for each KV chunk
2. merge_softmax_stats: merge the stats of all chunks on the host
3. softmax_normalize_and_block_sum: normalize using the global stats

Two saved-data formats are supported:
1. offload mode: {"query", "key", "stride", "threshold", "density", "layer_id"}
2. GPU-only mode: {"Q", "K", "chunk_size", "block_size", "stride", "threshold", "mask", "attn_sums", ...}

Usage:
    # Use offload-mode data
    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
        python tests/test_xattn_estimate_alignment.py

    # Use GPU-only-mode data
    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
        python tests/test_xattn_estimate_alignment.py --gpuonly
"""
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")

import argparse
import torch
import math
from nanovllm.ops.xattn import (
    xattn_estimate,
    flat_group_gemm_fuse_reshape,
    softmax_compute_partial_stats,
    softmax_normalize_and_block_sum,
    merge_softmax_stats,
    find_blocks_chunked,
)

# ============================================================
# Command-line arguments
# ============================================================
parser = argparse.ArgumentParser()
parser.add_argument("--gpuonly", action="store_true", help="Use data saved in GPU-only mode")
parser.add_argument("--data-file", type=str, default=None, help="Path to the data file")
parser.add_argument("--chunk-size", type=int, default=None, help="Override CHUNK_SIZE (to test different chunk sizes)")
args = parser.parse_args()

# ============================================================
# Configuration
# ============================================================
if args.gpuonly:
    DATA_FILE = args.data_file or "/home/zijie/Code/nano-vllm/results/mask_alignment/gpuonly_layer0.pt"
else:
    DATA_FILE = args.data_file or "/home/zijie/Code/nano-vllm/results/kvcache/qkv_32485.pt"

device = "cuda"

# ============================================================
# Step 1: load real data
# ============================================================
print("=" * 60)
print("Step 1: load real KV cache data")
print("=" * 60)

data = torch.load(DATA_FILE, map_location="cpu")

# Detect the data format and load accordingly
if "Q" in data:
    # Format saved in GPU-only mode
    print("[INFO] Detected GPU-only data format")
    Q = data["Q"].to(device)
    K = data["K"].to(device)
    BSA_BLOCK_SIZE = data.get("block_size", 128)
    CHUNK_SIZE = data.get("chunk_size", 4096)
    STRIDE = data.get("stride", 8)
    THRESHOLD = data.get("threshold", 0.9)
    if isinstance(THRESHOLD, torch.Tensor):
        THRESHOLD = THRESHOLD.item()
    # GPU-only mode also saved mask and attn_sums, usable for verification
    saved_mask = data.get("mask", None)
    saved_attn_sums = data.get("attn_sums", None)
    saved_density = None  # GPU-only mode does not save density
    layer_id = 0  # GPU-only mode only saves layer 0
else:
    # Format saved in offload mode
    print("[INFO] Detected offload data format")
    Q = data["query"].to(device)
    K = data["key"].to(device)
    BSA_BLOCK_SIZE = 128
    CHUNK_SIZE = 4096
    STRIDE = data["stride"]
    THRESHOLD = data["threshold"]
    if isinstance(THRESHOLD, torch.Tensor):
        THRESHOLD = THRESHOLD[0].item()
    saved_mask = None
    saved_attn_sums = None
    saved_density = data.get("density", None)
    layer_id = data.get("layer_id", 0)

batch_size, num_heads, seq_len, head_dim = Q.shape

# Command-line override of CHUNK_SIZE
if args.chunk_size is not None:
    CHUNK_SIZE = args.chunk_size
    print(f"[INFO] Using CHUNK_SIZE={CHUNK_SIZE} from the command line")

print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
if saved_density is not None:
    print(f"Data layer_id: {layer_id}, saved density: {saved_density:.4f}")
else:
    print(f"Data layer_id: {layer_id}")
print(f"Parameters: STRIDE={STRIDE}, THRESHOLD={THRESHOLD}, CHUNK_SIZE={CHUNK_SIZE}, BSA_BLOCK_SIZE={BSA_BLOCK_SIZE}")
print()

# ============================================================
# Step 2: call the high-level xattn_estimate API
# ============================================================
print("=" * 60)
print("Step 2: call xattn_estimate (high-level API)")
print("=" * 60)

attn_sums_api, mask_api = xattn_estimate(
    Q, K,
    block_size=BSA_BLOCK_SIZE,
    stride=STRIDE,
    threshold=THRESHOLD,
    chunk_size=CHUNK_SIZE,
    causal=True,
)

# Crop to the valid region
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
mask_api_valid = mask_api[:, :, :q_blocks, :k_blocks]

# Compute density (causal)
causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=device, dtype=torch.bool))
total_api = causal_mask.sum().item() * batch_size * num_heads
selected_api = (mask_api_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_api = selected_api / total_api

print(f"mask_api shape (padded): {mask_api.shape}")
print(f"mask_api_valid shape: {mask_api_valid.shape}")
print(f"[xattn_estimate] density: {density_api:.6f} (selected={selected_api}, total={total_api})")
print()

# ============================================================
# Step 3: three-stage KV chunking
# ============================================================
print("=" * 60)
print("Step 3: three-stage KV chunking")
print("=" * 60)
print("  1) compute partial stats for each KV chunk")
print("  2) merge the stats on the host")
print("  3) normalize with the global stats and compute block sums")
print()

# Compute padding parameters
k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
kv_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE

k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE

reshaped_chunk_size = CHUNK_SIZE // STRIDE
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
k_reshaped_num_to_pad = k_num_to_pad // STRIDE
num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
kv_reshaped_chunk_size = CHUNK_SIZE // STRIDE

print(f"seq_len: {seq_len}, q_chunk_num: {q_chunk_num}, kv_chunk_num: {kv_chunk_num}")
print()

# Padding
if k_num_to_pad > 0:
    K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
else:
    K_padded = K

if q_num_to_pad > 0:
    Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0)
else:
    Q_padded = Q

# Softmax scale
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm

simple_mask_list = []

for q_chunk_idx in range(q_chunk_num):
    q_start = q_chunk_idx * reshaped_chunk_size * STRIDE
    q_end = q_start + reshaped_chunk_size * STRIDE
    Q_chunk = Q_padded[:, :, q_start:q_end, :]

    chunk_start = (k_block_num - q_block_num) * reshaped_block_size + q_chunk_idx * reshaped_chunk_size
    chunk_end = chunk_start + reshaped_chunk_size

    # Stage 1: compute partial stats and raw scores for each KV chunk
    m_chunks = []
    l_chunks = []
    attn_weights_chunks = []

    for kv_chunk_idx in range(kv_chunk_num):
        kv_start = kv_chunk_idx * CHUNK_SIZE
        kv_end = kv_start + CHUNK_SIZE
        K_chunk = K_padded[:, :, kv_start:kv_end, :]

        # KV offset in reshaped space
        kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size

        # Compute raw attention scores
        attn_weights_kv = flat_group_gemm_fuse_reshape(
            Q_chunk, K_chunk, STRIDE,
            chunk_start=chunk_start,
            chunk_end=chunk_end,
            is_causal=False,  # K is incomplete here, so causal cannot be applied yet
        )
        attn_weights_chunks.append(attn_weights_kv)

        # Compute partial stats (with causal mask)
        m_partial, l_partial = softmax_compute_partial_stats(
            attn_weights_kv,
            reshaped_block_size,
            min(4096, reshaped_block_size),
            scale,
            chunk_start=chunk_start,
            kv_offset=kv_offset_reshaped,
            is_causal=True,
        )
        m_chunks.append(m_partial)
        l_chunks.append(l_partial)

    # Stage 2: merge the stats on the host
    m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)

    # Stage 3: normalize with the global stats and compute block sums
    attn_sum_per_kv = []
    for kv_chunk_idx, attn_weights_kv in enumerate(attn_weights_chunks):
        kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
        attn_sum_kv = softmax_normalize_and_block_sum(
            attn_weights_kv,
            m_global,
            l_global,
            reshaped_block_size,
            min(4096, reshaped_block_size),
            chunk_start=chunk_start,
            real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
            scale=scale,
            kv_offset=kv_offset_reshaped,
            is_causal=True,
        )
        attn_sum_per_kv.append(attn_sum_kv)

    # Concatenate the block sums of each KV chunk
    attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)

    # Select blocks
    simple_mask = find_blocks_chunked(
        attn_sum_concat,
        current_index=k_block_num - q_block_num + q_chunk_idx * num_blocks_per_chunk,
        threshold=THRESHOLD,
        num_to_choose=None,
        decoding=False,
        mode="prefill",
        causal=True,
    )
    simple_mask_list.append(simple_mask)

    print(f"  Q chunk {q_chunk_idx}: merged {kv_chunk_num} KV chunks, attn_sum shape={attn_sum_concat.shape}")

mask_kv_chunking = torch.cat(simple_mask_list, dim=2)

# Apply the same causal-mask post-processing as xattn_estimate (xattn.py lines 1300-1306)
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
    torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=device), diagonal=0),
    mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
    False,
)

mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_kv = selected_kv / total_api

print()
print(f"[KV chunking] density: {density_kv:.6f} (selected={selected_kv}, total={total_api})")
print()

# ============================================================
# Step 4: compare the results
# ============================================================
print("=" * 60)
print("Step 4: compare the results")
print("=" * 60)
print()

mask_total = mask_api_valid.numel()
mask_diff = (mask_api_valid != mask_kv_chunking_valid).sum().item()

print("| Method | density | Diff vs API | Mask diff |")
print("|--------|---------|-------------|-----------|")
print(f"| xattn_estimate API | {density_api:.6f} | - | - |")
print(f"| KV chunking | {density_kv:.6f} | {abs(density_api - density_kv):.6f} | {100*mask_diff/mask_total:.4f}% |")
print()

passed = abs(density_api - density_kv) < 1e-6 and mask_diff / mask_total < 0.001

# ============================================================
# Step 5: compare against the data saved in GPU-only mode (if any)
# ============================================================
if saved_mask is not None or saved_attn_sums is not None:
    print("=" * 60)
    print("Step 5: compare against GPU-only saved data")
    print("=" * 60)
    print()

    if saved_mask is not None:
        saved_mask_gpu = saved_mask.to(device)
        # Compare masks
        mask_saved_diff = (mask_api_valid != saved_mask_gpu).sum().item()
        mask_saved_total = saved_mask_gpu.numel()
        print(f"| xattn_estimate vs GPU-only saved mask | diff blocks: {mask_saved_diff} / {mask_saved_total} ({100*mask_saved_diff/mask_saved_total:.4f}%) |")

        if mask_saved_diff == 0:
            print("✅ mask matches the GPU-only saved data exactly")
        else:
            print("❌ mask differs from the GPU-only saved data")
            passed = False

    if saved_attn_sums is not None:
        saved_attn_sums_gpu = saved_attn_sums.to(device)
        # We need the attn_sums from xattn_estimate;
        # call it once more to obtain them
        attn_sums_check, _ = xattn_estimate(
            Q, K,
            block_size=BSA_BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            causal=True,
        )
        attn_sums_check_valid = attn_sums_check[:, :, :q_blocks, :k_blocks]

        max_diff = (attn_sums_check_valid - saved_attn_sums_gpu).abs().max().item()
        mean_diff = (attn_sums_check_valid - saved_attn_sums_gpu).abs().mean().item()
        print(f"| xattn_estimate vs GPU-only saved attn_sums | max diff: {max_diff:.6e}, mean diff: {mean_diff:.6e} |")

        if max_diff < 1e-5:
            print("✅ attn_sums match the GPU-only saved data")
        else:
            print("❌ attn_sums differ from the GPU-only saved data")
            passed = False

    print()

if passed:
    print("test_xattn_estimate_alignment: PASSED")
else:
    print("test_xattn_estimate_alignment: FAILED")
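The three-stage flow above hinges on the online-softmax identity behind `merge_softmax_stats`: given per-chunk stats (m_i, l_i), the global stats are m_g = max_i m_i and l_g = sum_i l_i * exp(m_i - m_g). A minimal pure-Python sketch of this merge (hypothetical helpers, not the library's kernels):

```python
import math

def partial_stats(scores):
    # Per-chunk stats for one row: running max m and sum of exp(s - m).
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    return m, l

def merge_stats(stats):
    # Combine per-chunk (m, l) into global stats:
    # m_g = max_i m_i,  l_g = sum_i l_i * exp(m_i - m_g)
    m_g = max(m for m, _ in stats)
    l_g = sum(l * math.exp(m - m_g) for m, l in stats)
    return m_g, l_g

scores = [0.3, -1.2, 2.5, 0.0, 1.1, -0.4]
chunks = [scores[:3], scores[3:]]
m_g, l_g = merge_stats([partial_stats(c) for c in chunks])

# Probabilities normalized with the merged stats match a direct softmax
probs_merged = [math.exp(s - m_g) / l_g for s in scores]
m_d = max(scores)
l_d = sum(math.exp(s - m_d) for s in scores)
probs_direct = [math.exp(s - m_d) / l_d for s in scores]
assert all(abs(a - b) < 1e-12 for a, b in zip(probs_merged, probs_direct))
```

This is why the chunked path can defer normalization: the raw scores of each KV chunk stay valid, and only the cheap (m, l) pairs have to be exchanged before stage 3.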
@@ -1,244 +0,0 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked

Verify that chunked estimation with EXTERNAL chunking produces the same mask
as standard estimation. This ensures the chunked version can be used in
chunked prefill scenarios without accuracy loss.

Usage:
    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
        python tests/test_xattn_estimate_chunked.py
"""

import sys
import traceback
import torch
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked

# ============================================================
# Configuration
# ============================================================

# Configuration for the xattn_estimate_chunked consistency test.
# Key requirements for a 100% match:
# 1. Use a matching chunk_size for both the standard and chunked versions
# 2. Use the same random seed for reproducibility
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
# floating-point precision in cumulative sum calculations.
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096  # External chunking size

# Test sequence lengths
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]

# ============================================================
# Utility Functions
# ============================================================

def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Compare two masks and report differences."""
    if mask1.shape != mask2.shape:
        print(f"  Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False

    diff = (mask1 != mask2).sum().item()
    total = mask1.numel()
    match_rate = (total - diff) / total * 100

    print(f"  Match rate: {match_rate:.4f}% ({total - diff}/{total})")

    if diff > 0:
        diff_indices = torch.where(mask1 != mask2)
        print(f"  First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")

    return diff == 0


def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
    """
    Run xattn_estimate_chunked with EXTERNAL chunking.
    This simulates how chunked prefill should be used in practice.
    """
    batch_size, num_heads, q_len, head_dim = query.shape
    _, _, k_len, _ = key.shape

    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size

    # If Q fits in one chunk, call directly
    if q_len <= chunk_size:
        return xattn_estimate_chunked(
            query, key,
            q_start_pos=0,
            block_size=block_size,
            stride=stride,
            threshold=threshold,
            use_triton=True,
            chunk_size=chunk_size,
        )

    # External chunking: split Q and call for each chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f"  External chunking: {num_q_chunks} chunks")

    combined_attn_sum = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=query.dtype, device=query.device
    )
    combined_mask = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=torch.bool, device=query.device
    )

    q_block_offset = 0
    for q_chunk_idx in range(num_q_chunks):
        q_chunk_start = q_chunk_idx * chunk_size
        q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)

        q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]

        # For causal attention, K accumulates up to current Q position
        # q_start_pos=0 means Q starts at position 0 in the full sequence
        # K is [0, q_chunk_end) for causal attention
        k_end = q_chunk_end
        k_chunk = key[:, :, :k_end, :]

        attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
            q_chunk, k_chunk,
            q_start_pos=q_chunk_start,
            block_size=block_size,
            stride=stride,
            threshold=threshold,
            use_triton=True,
            chunk_size=chunk_size,
        )

        # Place chunk results into combined output
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
        combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
        q_block_offset += chunk_q_blocks

    return combined_attn_sum, combined_mask


def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
    """Test a single sequence length."""
    print(f"\nTesting seq_len={seq_len}")
    print("=" * 60)

    # Generate random Q/K
    query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)

    # Run standard xattn_estimate
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
            causal=True,
        )
        density_std = mask_std.float().mean().item()
        print(f"  mask shape: {mask_std.shape}, density: {density_std:.4f}")
    except Exception as e:
        print(f"  ERROR: {e}")
        traceback.print_exc()
        return False

    # Run chunked xattn_estimate with EXTERNAL chunking
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        density_chunked = mask_chunked.float().mean().item()
        print(f"  mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
    except Exception as e:
        print(f"  ERROR: {e}")
        traceback.print_exc()
        return False

    # Compare results
    print("[3] Comparing results...")
    chunked_q_blocks = mask_chunked.shape[2]
    chunked_k_blocks = mask_chunked.shape[3]

    # Extract comparable region from standard mask
    mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]

    # Compare masks
    masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")

    # Compare attn_sums
    attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f"  Attn sum max diff: {attn_diff:.6f}")
    else:
        print(f"  Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")

    # Clean up GPU memory
    del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
    torch.cuda.empty_cache()

    return masks_match


# ============================================================
# Main Test
# ============================================================

if __name__ == "__main__":
    print("XAttention Chunked vs Standard Test")
    print("=" * 60)
    print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
    print(f"External chunk_size={CHUNK_SIZE}")
    print()

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA not available!")
        sys.exit(1)

    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print("✓ xattn_estimate imported")
    print("✓ xattn_estimate_chunked imported")

    # Run tests
    all_passed = True
    results = []

    for seq_len in TEST_SEQ_LENS:
        passed = test_single_seq_len(seq_len)
        chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
        results.append((seq_len, chunks, passed))
        if not passed:
            all_passed = False

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for seq_len, chunks, passed in results:
        status = "PASSED" if passed else "FAILED"
        print(f"  seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")

    print("=" * 60)
    if all_passed:
        print("ALL TESTS PASSED!")
        sys.exit(0)
    else:
        print("SOME TESTS FAILED!")
        sys.exit(1)
@@ -1,129 +0,0 @@
"""
Test: XAttention Triton kernels

Demonstrates the two core XAttention Triton kernels:
1. flat_group_gemm_fuse_reshape: computes attention scores on the stride-reshaped
   inputs (anti-diagonal summation)
2. softmax_fuse_block_sum: applies softmax to the attention scores, then sums per block

Data flow:
Q [batch, heads, q_len, head_dim]
K [batch, heads, kv_len, head_dim]
  ↓ flat_group_gemm_fuse_reshape
attn_scores [batch, heads, q_len/stride, kv_len/stride]
  ↓ softmax_fuse_block_sum
block_sums [batch, heads, q_blocks, k_blocks]
"""
import torch
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum

# ============================================================
# Configuration
# ============================================================

# Triton constraints: q_len >= stride * BLOCK_M, kv_len >= stride * BLOCK_N
# A100: BLOCK_M = BLOCK_N = 128, so the minimum is 4 * 128 = 512
# RTX 3090: BLOCK_M = BLOCK_N = 64, so the minimum is 4 * 64 = 256
q_len = 512
kv_len = 2048
head_dim = 128
stride = 4
block_size = 128    # softmax block size (in the reshaped space)
segment_size = 128  # the Triton kernel requires segment_size >= block_size
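The size constraints quoted above can be checked explicitly (the per-GPU BLOCK_M/BLOCK_N values are taken from the comments and assumed correct):

```python
# Verify q_len/kv_len satisfy the stated Triton minimums for both GPUs
stride = 4
q_len, kv_len = 512, 2048
for gpu, block in (("A100", 128), ("RTX 3090", 64)):
    ok = q_len >= stride * block and kv_len >= stride * block
    print(gpu, ok)  # both print True for this configuration
```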


# ============================================================
# Build inputs: even positions = 1, odd positions = 2
# ============================================================

Q = torch.zeros(1, 1, q_len, head_dim, dtype=torch.bfloat16).cuda()
K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()

for i in range(q_len):
    if i % 2 == 0:
        Q[0, 0, i, :] = 1
    else:
        Q[0, 0, i, :] = 2

for i in range(kv_len):
    if i % 2 == 0:
        K[0, 0, i, :] = 1
    else:
        K[0, 0, i, :] = 2

# ============================================================
# Step 1: flat_group_gemm_fuse_reshape (chunked along K)
# ============================================================

q_reshaped_len = q_len // stride    # 128
kv_reshaped_len = kv_len // stride  # 512

# Split K into chunks along the sequence-length dimension
k_chunk_size = 512                     # 512 tokens per chunk
num_k_chunks = kv_len // k_chunk_size  # 4 chunks

attn_scores_list = []
for k_chunk_idx in range(num_k_chunks):
    k_start = k_chunk_idx * k_chunk_size
    k_end = k_start + k_chunk_size
    K_chunk = K[:, :, k_start:k_end, :]  # [1, 1, k_chunk_size, head_dim]

    # Call flat_group_gemm_fuse_reshape on each K chunk
    # Output: [batch, heads, q_len/stride, k_chunk_size/stride]
    attn_chunk = flat_group_gemm_fuse_reshape(
        Q, K_chunk, stride,
        chunk_start=0,
        chunk_end=q_reshaped_len,
        is_causal=False
    )
    attn_scores_list.append(attn_chunk)

# Concatenate the results of all K chunks
# Each chunk: [1, 1, q_reshaped_len, k_chunk_size/stride]
# After concatenation: [1, 1, q_reshaped_len, kv_reshaped_len]
attn_scores = torch.cat(attn_scores_list, dim=-1)

# Verify the shape: [batch, heads, q_len/stride, kv_len/stride]
assert attn_scores.shape == (1, 1, q_reshaped_len, kv_reshaped_len), \
    f"shape mismatch: {attn_scores.shape} != (1, 1, {q_reshaped_len}, {kv_reshaped_len})"

# Verify the anti-diagonal summation:
# each stride x stride tile pairs Q[odd]*K[even] + Q[even]*K[odd] = 2*1 + 1*2 = 4
# the anti-diagonal holds stride/2 such pairs, each scaled by head_dim
expected_gemm = (2*1 + 1*2) * (stride // 2) * head_dim
actual_gemm = attn_scores[0, 0, 0, 0].item()
assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expected_gemm}"
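A pure-PyTorch sketch of the anti-diagonal estimate that this assertion checks. The flip-based pairing below is inferred from the comments (each stride group of K is reversed so a plain sum over the group axis realizes the anti-diagonal); it is an assumption about the kernel semantics, not the Triton code itself:

```python
import torch

q_len, kv_len, head_dim, stride = 512, 2048, 128, 4

# Same alternating 1/2 pattern as the test inputs, without batch/head dims
Q = torch.ones(q_len, head_dim)
Q[1::2] = 2
K = torch.ones(kv_len, head_dim)
K[1::2] = 2

# Group tokens by stride; flipping each K group turns the anti-diagonal
# pairing into an ordinary sum over the group axis s
Qg = Q.view(q_len // stride, stride, head_dim)
Kg = K.view(kv_len // stride, stride, head_dim).flip(1)
scores = torch.einsum("qsd,ksd->qk", Qg, Kg)
print(scores[0, 0].item())  # 1024.0 = (2*1 + 1*2) * (stride // 2) * head_dim
```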


# ============================================================
# Step 2: softmax_fuse_block_sum
# ============================================================

scale = 1.4426950408889634  # log2(e), since the kernel uses exp2

block_sums = softmax_fuse_block_sum(
    attn_scores,
    block_size,
    segment_size,
    chunk_start=0,
    chunk_end=q_reshaped_len,
    real_q_len=q_reshaped_len,
    scale=scale,
    is_causal=False
)

# Verify the shape: [batch, heads, q_blocks, k_blocks]
q_blocks = q_reshaped_len // block_size   # 128 / 128 = 1
k_blocks = kv_reshaped_len // block_size  # 512 / 128 = 4
assert block_sums.shape == (1, 1, q_blocks, k_blocks), \
    f"shape mismatch: {block_sums.shape} != (1, 1, {q_blocks}, {k_blocks})"

# Verify the per-block sums of the softmax output:
# all attn_scores are equal → the softmax is uniform
# each row contributes block_size / kv_reshaped_len to one K block
# each Q block spans block_size rows, so
# block_sum = block_size * (block_size / kv_reshaped_len)
expected_sum = block_size * block_size / kv_reshaped_len
actual_sum = block_sums[0, 0, 0, 0].item()
assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"
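The uniform-softmax arithmetic behind `expected_sum` can be replayed in plain PyTorch (a sketch of the expected value only; the real kernel fuses the softmax with the block summation):

```python
import torch

kv_reshaped_len, block_size = 512, 128

row = torch.full((kv_reshaped_len,), 4.0)  # all reshaped scores are equal
probs = torch.softmax(row, dim=-1)         # uniform: 1/512 per column
per_row = probs[:block_size].sum()         # mass of one K block: 128/512 = 0.25
block_sum = block_size * per_row           # 128 rows per Q block
print(block_sum.item())  # 32.0
```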

print("test_xattn_kernels: PASSED")