15 Commits

Author SHA1 Message Date
Zijie Tian
52b12a89e3 📋 docs: add changelog for 2026-02-05
Document today's changes:
- GQA buffer OOM fix (saves 16GB for 1M seq in offload mode)
- Tests directory cleanup (removed 16 files, -4306 lines)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 03:16:39 +08:00
Zijie Tian
d35dd76e09 🗑️ chore: clean up tests directory to essential files only
Keep only core test files:
- test_ruler.py - main RULER benchmark
- test_xattn_estimate_alignment.py - XAttn kernel validation
- utils.py - shared utilities

Remove 8 files (recoverable from git history):
- bench_estimate_block_size.py
- modeling_qwen3.py
- test_chunk_attention_graph_reuse.py
- test_cudagraph_memory.py
- test_gpuonly_density_alignment.py
- test_hierarchical_estimate.py
- test_quest_policy.py
- test_sequential.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 03:13:50 +08:00
Zijie Tian
2b61c5ab57 🗑️ chore: remove test_needle* files
Remove needle tests (validation now covered by test_ruler.py):
- test_needle.py - basic needle-in-haystack test
- test_needle_ref.py - HuggingFace reference implementation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 03:11:28 +08:00
Zijie Tian
a709551072 🗑️ chore: remove redundant XAttention test files
Remove 6 obsolete test files:
- test_xattn_bsa.py - XAttn+BSA integration (covered by test_ruler)
- test_xattn_chunked.py - duplicate of test_xattn_estimate_chunked
- test_xattn_estimate_chunked.py - chunked prefill validation
- test_xattn_kernels.py - Triton kernel unit tests
- test_xattn_kv_chunking_batch.py - batch KV chunking validation
- test_chunk_attention_graph.py - superseded by graph_reuse version

Retained: test_xattn_estimate_alignment.py (critical kernel validation)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 03:11:21 +08:00
Zijie Tian
11a867f6fb 🐛 fix: skip GQA buffer allocation in XAttention offload mode
In offload mode, GQA expansion buffers (_k_expanded, _v_expanded) are not
needed since compute_chunked_prefill() handles GQA inline. Previously,
these buffers were always allocated based on max_model_len, causing OOM
on 24GB GPUs (e.g., RTX 3090) when max_model_len=1M (16GB buffer).

Changes:
- Add enable_cpu_offload parameter to alloc_policy_metadata() in base class
- Skip GQA buffer allocation when enable_cpu_offload=True in XAttentionBSAPolicy
- Pass enable_cpu_offload from model_runner to policy

Memory savings: ~16GB for 1M seq, ~1.1GB for 72K seq

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 02:57:18 +08:00
Zijie Tian
af4da454ba 📊 docs: add XAttention offload profiling analysis for 32K context
- Profile XAttn vs Full attention using nsys NVTX markers
- Key finding: estimate (41%) + find_blocks (37%) dominate, compute only 21%
- Chunk7 comparison: XAttn (38ms) vs Full (35ms) - XAttn slightly slower
- Identify optimization opportunities: reduce find_blocks overhead, merge estimate passes

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-05 02:49:59 +08:00
Zijie Tian
ef37d4f1a8 🐛 docs: document XAttention offload GQA buffer OOM issue
Document OOM issue when using XAttention BSA + CPU offload
with large models (GLM-4-9B) on 24GB GPUs.

Issue: 8GB allocation for k_expanded buffer fails due to
using num_heads instead of num_kv_heads in GQA models.

Root cause analysis and proposed fix included.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 02:46:50 +08:00
Zijie Tian
c8a5ef04c0 📝 docs: add test_ruler.py usage guide and rule
- Add comprehensive test_ruler.py usage guide with verified commands
- Add .claude/rules/test-ruler.md to enforce documentation-first approach
- Update CLAUDE.md documentation index

Tested commands on RTX 3090 (GPU 4):
- 32K/64K offload + XAttn BSA
- Multi-dataset, JSON output, quiet mode
- GLM-4 model support

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 02:46:44 +08:00
Zijie Tian
1c36d53570 🙈 chore: add ralph-tui session file to gitignore
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 02:00:44 +08:00
Zijie Tian
54fd302fa8 📝 docs: add XAttention density alignment verification results
- Add verification doc comparing GPU-only vs Offload mode density
- Test results: 32K (0.37% diff), 64K (0.09% diff) - alignment successful
- Both modes achieve 100% accuracy on RULER niah_single_1

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-05 01:59:11 +08:00
Zijie Tian
1eb7521994 📝 docs: add XAttention density types documentation
Document the difference between compute density (BSA block level)
and communication density (CPU block level).

Key finding: Even with 37% compute density, comm density can be 100%
due to any() aggregation across heads/Q-positions spreading sparse
blocks across all CPU blocks.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 01:44:11 +08:00
Zijie Tian
51bd678335 📊 feat: distinguish compute density and communication density in DensityObserver
- Add record_comm_density() call in select_blocks to track CPU block selection
- Add get_per_layer_comm_density() method for detailed analysis
- Update print_summary() to show both densities and H2D savings ratio
- Set DensityObserver mode (offload/gpu_only) in test_ruler.py
- Update get_summary() to return both density types

Key insight: Comm density can be 100% even when compute density is ~37%
because sparse BSA blocks are distributed across all CPU blocks.
Since CPU block granularity is 32x coarser (4096 vs 128 tokens),
any() aggregation across heads/Q-blocks results in all CPU blocks being needed.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 01:43:17 +08:00
Zijie Tian
1ea5afd886 📝 docs: add XAttention offload stream sync fix documentation
- Document the CUDA stream synchronization bug in XAttention BSA
- Include root cause analysis with stream timing diagrams
- Add test commands and verification results (100% accuracy)
- Update CLAUDE.md documentation index

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 01:32:50 +08:00
Zijie Tian
829b311c02 🐛 fix: stream synchronization for XAttention estimate kernels in offload mode
- Wrap all compute kernels in select_blocks with compute_stream context
  (Pass 1 historical blocks, Pass 1 current chunk, Step 2 merge,
   Pass 2 historical blocks, Pass 2 current chunk, Step 4 block selection)
- Fix K data mismatch between Pass 1 and Pass 2 by ensuring wait_slot_layer
  syncs with compute_stream where kernels actually run
- Remove STRONG SYNC code from offload_engine.py (now handled by events)
- Remove debug print statements and torch.save code
- Consolidate fallback conditions in compute_with_xattn
- Change default chunk_size from 16384 to 4096 for density alignment

The bug caused Pass 1 and Pass 2 to see different K data from the same
CPU block because compute kernels ran on default stream while
wait_slot_layer only synced compute_stream.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 01:30:23 +08:00
Zijie Tian
dd0472aea8 [plugin] Added ralph-tui setup. 2026-02-05 01:27:53 +08:00
32 changed files with 1750 additions and 4487 deletions

View File

@@ -0,0 +1,90 @@
# test_ruler.py 使用规则
## 强制规则
**执行 `test_ruler.py` 前必须查阅文档**,禁止运行 `--help` 或猜测参数。
| 禁止 | 原因 |
|------|------|
| `python tests/test_ruler.py --help` | 浪费交互,文档已有完整说明 |
| 猜测参数格式 | 容易出错,降低效率 |
## 必读文档
**[`docs/test_ruler_usage_guide.md`](../docs/test_ruler_usage_guide.md)** - 包含:
- 完整参数说明
- 已验证的命令示例
- GPU 模式选择指南
- max-model-len 设置指南
## 快速参考
### 标准命令格式
```bash
CUDA_VISIBLE_DEVICES=<GPU> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/<MODEL> \
--data-dir tests/data/ruler_<CTX> \
--datasets <TASK> \
--num-samples <N> \
--max-model-len <LEN> \
--enable-offload \
[--sparse-policy XATTN_BSA] \
[--sparse-threshold 0.9]
```
### 常用参数速查
| 参数 | 用途 | 示例 |
|------|------|------|
| `--datasets` | 指定任务 | `niah_single_1,qa_1` |
| `--num-samples` | 样本数 | `1`, `10`, `0`(全部) |
| `--sample-indices` | 指定索引 | `0,5,10` |
| `--enable-offload` | CPU offload | RTX 3090 必须 |
| `--sparse-policy` | 稀疏策略 | `XATTN_BSA` |
| `--json-output` | JSON 输出 | 脚本使用 |
| `--quiet` | 安静模式 | 减少输出 |
### max-model-len 速查
| 数据目录 | max-model-len |
|---------|---------------|
| ruler_32k | 40960 |
| ruler_64k | 72000 |
| ruler_128k | 135000 |
### 常用命令模板
**32K Offload + XAttn**:
```bash
CUDA_VISIBLE_DEVICES=<GPU> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--sparse-policy XATTN_BSA
```
**64K Offload + XAttn**:
```bash
CUDA_VISIBLE_DEVICES=<GPU> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_64k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 72000 \
--enable-offload \
--sparse-policy XATTN_BSA
```
## 执行前检查清单
- [ ] 用户指定了 GPU否则询问
- [ ] RTX 3090/4090必须 `--enable-offload`
- [ ] data-dir 与 max-model-len 匹配?
- [ ] 需要 density 统计?添加 `--sparse-policy XATTN_BSA`

1
.gitignore vendored
View File

@@ -240,3 +240,4 @@ findings_*.md
progress_*.md progress_*.md
notes.md notes.md
Snipaste* Snipaste*
.ralph-tui/session-meta.json

12
.ralph-tui/config.toml Normal file
View File

@@ -0,0 +1,12 @@
# Ralph TUI Configuration
# Generated by setup wizard
# See: ralph-tui config help
configVersion = "2.1"
tracker = "json"
agent = "claude"
maxIterations = 30
autoCommit = false
[trackerOptions]
[agentOptions]

View File

@@ -42,6 +42,12 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
| [`docs/xattn_kv_chunking_density_test.md`](docs/xattn_kv_chunking_density_test.md) | 🧪 TEST: XAttention KV chunking density 验证threshold=1.0 对齐threshold<1.0 差异 10-13% | | [`docs/xattn_kv_chunking_density_test.md`](docs/xattn_kv_chunking_density_test.md) | 🧪 TEST: XAttention KV chunking density 验证threshold=1.0 对齐threshold<1.0 差异 10-13% |
| [`docs/gpuonly_density_alignment_test.md`](docs/gpuonly_density_alignment_test.md) | ✅ TEST: Density 对齐验证 (GPU-only + Offload, 4K-64K)xattn_estimate vs KV chunking 完全一致 | | [`docs/gpuonly_density_alignment_test.md`](docs/gpuonly_density_alignment_test.md) | ✅ TEST: Density 对齐验证 (GPU-only + Offload, 4K-64K)xattn_estimate vs KV chunking 完全一致 |
| [`docs/xattn_memory_benchmark.md`](docs/xattn_memory_benchmark.md) | 📊 BENCH: XAttention 内存基准测试Qwen3-0.6B 32K 在 24GB 显存可行 (gpu-util=0.28) | | [`docs/xattn_memory_benchmark.md`](docs/xattn_memory_benchmark.md) | 📊 BENCH: XAttention 内存基准测试Qwen3-0.6B 32K 在 24GB 显存可行 (gpu-util=0.28) |
| [`docs/xattn_offload_stream_sync_fix.md`](docs/xattn_offload_stream_sync_fix.md) | 🐛 FIX: XAttention Offload stream 同步 bugPass1/Pass2 K 数据不一致compute_stream 包装 |
| [`docs/xattn_density_types.md`](docs/xattn_density_types.md) | 📊 Compute vs Comm density: BSA block (128) vs CPU block (4096) 粒度,聚合效应导致 comm=100% |
| [`docs/xattn_density_alignment_verification.md`](docs/xattn_density_alignment_verification.md) | ✅ VERIFIED: GPU-only vs Offload density 对齐验证 (32K 差异 0.37%, 64K 差异 0.09%) |
| [`docs/test_ruler_usage_guide.md`](docs/test_ruler_usage_guide.md) | 📖 GUIDE: test_ruler.py 使用指南RULER benchmark 测试命令,已验证的命令示例 |
| [`docs/xattn_offload_profiling_32k.md`](docs/xattn_offload_profiling_32k.md) | 📊 PROFILE: XAttn vs Full 32K nsys 分析estimate 占 41%find_blocks 占 37%compute 仅 21% |
| [`docs/changelog_2026-02-05.md`](docs/changelog_2026-02-05.md) | 📋 CHANGELOG: GQA buffer OOM 修复 (节省 16GB)tests 目录清理 (-4306 行) |
## Rules Index ## Rules Index
@@ -52,6 +58,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
| [`.claude/rules/sparse-policy.md`](.claude/rules/sparse-policy.md) | SparsePolicy implementation requirements | | [`.claude/rules/sparse-policy.md`](.claude/rules/sparse-policy.md) | SparsePolicy implementation requirements |
| [`.claude/rules/planning-with-files.md`](.claude/rules/planning-with-files.md) | Planning file management for complex tasks | | [`.claude/rules/planning-with-files.md`](.claude/rules/planning-with-files.md) | Planning file management for complex tasks |
| [`.claude/rules/gpu-monitor.md`](.claude/rules/gpu-monitor.md) | **GPU memory monitoring**: 必须使用 gpu-monitor agent禁止手动 nvidia-smi 循环 | | [`.claude/rules/gpu-monitor.md`](.claude/rules/gpu-monitor.md) | **GPU memory monitoring**: 必须使用 gpu-monitor agent禁止手动 nvidia-smi 循环 |
| [`.claude/rules/test-ruler.md`](.claude/rules/test-ruler.md) | **test_ruler.py 规则**: 禁止 --help必须查阅文档含快速参考和命令模板 |
## GPU Mutex for Multi-Instance Debugging ## GPU Mutex for Multi-Instance Debugging
@@ -106,6 +113,13 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
**Files**: `bench.py` (GPU), `bench_offload.py` (CPU offload), `bench_vllm.py` (comparison) **Files**: `bench.py` (GPU), `bench_offload.py` (CPU offload), `bench_vllm.py` (comparison)
**GPU-only 测试模型选择**:
| GPU | 显存 | GPU-only 测试模型 |
|-----|------|------------------|
| RTX 3090 | 24GB | **Qwen3-0.6B** (必须7B+ 模型会 OOM) |
| A100 | 40GB+ | Qwen3-0.6B / 4B / 7B 均可 |
**Offload Mode Constraint**: When using `enable_cpu_offload=True`, only test with context length ≥ 32K. Shorter contexts don't exercise the chunked offload pipeline properly. **Offload Mode Constraint**: When using `enable_cpu_offload=True`, only test with context length ≥ 32K. Shorter contexts don't exercise the chunked offload pipeline properly.
**Common Issues**: **Common Issues**:

View File

@@ -0,0 +1,94 @@
# Changelog 2026-02-05
## Bug Fixes
### XAttention Offload GQA Buffer OOM Fix
**Issue**: `docs/issue_xattn_offload_gqa_buffer_oom.md`
**Problem**: 在 XAttention BSA + CPU Offload 模式下,`alloc_policy_metadata()` 分配了只有 GPU-only 模式才需要的 GQA expansion buffers (`_k_expanded`, `_v_expanded`),导致 24GB GPU (RTX 3090) 上 OOM。
**Root Cause**:
- GQA buffer 大小: `2 × num_heads × max_seq_len × head_dim × dtype_size`
- 对于 1M max_seq_len: 2 × 32 × 1048576 × 128 × 2 = **16 GB**
- Offload 模式的 `compute_chunked_prefill()` 不需要这些 buffer
**Fix** (commit `11a867f`):
1. `nanovllm/kvcache/sparse/policy.py`: 基类添加 `enable_cpu_offload` 参数
2. `nanovllm/kvcache/sparse/xattn_bsa.py`: offload 模式跳过 GQA buffer 分配
3. `nanovllm/engine/model_runner.py`: 传入 `enable_cpu_offload` 参数
**Memory Savings**:
| max_model_len | 修复前 | 修复后 |
|---------------|--------|--------|
| 72K | +1.1 GB | 0 GB |
| 1M | +16 GB | 0 GB |
**Verification**:
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_64k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 72000 \
--enable-offload \
--sparse-policy XATTN_BSA
```
- 日志显示: `[XAttn] Offload mode: skipping GQA expansion buffers`
- 测试结果: 100% 准确率
---
## Code Cleanup
### Tests Directory Cleanup
**Commits**: `a709551`, `2b61c5a`, `d35dd76`
删除了 16 个冗余/过时的测试文件,保留核心测试:
**保留的文件** (4 个):
| 文件 | 用途 |
|------|------|
| `test_ruler.py` | 核心 RULER benchmark (13 tasks, 100 samples) |
| `test_xattn_estimate_alignment.py` | XAttn kernel 一致性验证 |
| `utils.py` | 共享工具函数 |
| `__init__.py` | 包标记 |
**删除的文件** (16 个, -4306 行):
| 类别 | 文件 | 删除原因 |
|------|------|----------|
| XAttn 测试 | `test_xattn_bsa.py` | 功能被 test_ruler 覆盖 |
| | `test_xattn_chunked.py` | 与 estimate_chunked 重复 |
| | `test_xattn_estimate_chunked.py` | chunked prefill 验证 |
| | `test_xattn_kernels.py` | Triton kernel 单元测试 |
| | `test_xattn_kv_chunking_batch.py` | batch 验证 |
| Needle 测试 | `test_needle.py` | 被 test_ruler NIAH 任务覆盖 |
| | `test_needle_ref.py` | HF 参考实现 |
| CUDA Graph | `test_chunk_attention_graph.py` | 被 graph_reuse 取代 |
| | `test_chunk_attention_graph_reuse.py` | 实验性功能 |
| | `test_cudagraph_memory.py` | 内存分析工具 |
| 其他 | `test_gpuonly_density_alignment.py` | GPU-only 密度测试 |
| | `test_hierarchical_estimate.py` | 分层估计测试 |
| | `test_quest_policy.py` | Quest 策略测试 |
| | `test_sequential.py` | 状态隔离测试 |
| | `bench_estimate_block_size.py` | 性能 benchmark |
| | `modeling_qwen3.py` | Qwen3 参考模型 |
**Note**: 所有删除的文件可从 git 历史恢复:
```bash
git checkout <commit-hash>^ -- tests/<filename>
```
---
## Summary
| 类型 | 数量 | 影响 |
|------|------|------|
| Bug Fix | 1 | 节省 16GB 显存 (1M seq) |
| 文件删除 | 16 | -4306 行代码 |
| 新增文档 | 1 | 本文件 |

View File

@@ -0,0 +1,209 @@
# Issue: XAttention Offload Mode GQA Buffer OOM
## 问题描述
在使用 XAttention BSA (Block Sparse Attention) + CPU Offload 模式运行 GLM-4-9B 等大模型时,出现 CUDA OOM 错误。
### 错误信息
```
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB.
GPU 0 has a total capacity of 23.57 GiB of which 4.19 GiB is free.
```
### 复现环境
| 项目 | 值 |
|------|-----|
| 模型 | GLM-4-9B-Chat-1M |
| GPU | RTX 3090 (24GB) |
| Context Length | 32K |
| sparse_policy | XATTN_BSA |
| enable_cpu_offload | true |
| max_model_len | 1048576 (1M) |
### 错误位置
```
File "nanovllm/kvcache/sparse/xattn_bsa.py", line 246, in alloc_policy_metadata
self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
```
---
## 问题分析
### 内存分配分析
`alloc_policy_metadata()` 在 KV cache 初始化时分配以下 buffer
| Buffer | 用途 | 大小 (GLM-4, 1M seq) |
|--------|------|----------------------|
| `_prefill_mask_buffer` | BSA mask | ~32 MB |
| `_m_partial_buffer` | KV chunking m stats | ~32 MB |
| `_l_partial_buffer` | KV chunking l stats | ~32 MB |
| `_block_sums_buffer` | Block sums | ~64 MB |
| **`_k_expanded`** | GQA K 扩展 | **~8 GB** |
| **`_v_expanded`** | GQA V 扩展 | **~8 GB** |
### GQA Buffer 计算
```python
shape = (1, num_heads, max_seq_len, head_dim)
= (1, 32, 1048576, 128)
size = 1 × 32 × 1048576 × 128 × 2 bytes (fp16)
= 8,589,934,592 bytes
= 8 GB per buffer
```
### 根本原因
1. **设计意图冲突**`_k_expanded``_v_expanded` 的文档注释明确说是 "for GPU-only mode"
2. **条件检查不完整**:代码只检查了 `num_heads == num_kv_heads` 来跳过分配,没有检查 offload 模式
3. **Offload 模式不需要这些 buffer**`compute_chunked_prefill()` 使用不同的计算路径,不依赖预分配的 GQA buffer
### 相关代码
```python
# xattn_bsa.py:238-247
# Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads)
if num_heads == num_kv_heads:
logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
return # <-- 只检查了 GQA没检查 offload 模式
# Shape: [1, num_heads, max_seq_len, head_dim] for xattn_estimate format
shape = (1, num_heads, max_seq_len, head_dim)
self._k_expanded = torch.empty(shape, dtype=dtype, device=device) # <-- OOM here
self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
```
---
## 解决思路
### 方案 1: 在 Offload 模式下跳过 GQA Buffer 分配 (推荐)
`alloc_policy_metadata()` 中添加 offload 模式检查:
```python
def alloc_policy_metadata(
self,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_seq_len: int,
dtype: torch.dtype,
device: torch.device,
enable_cpu_offload: bool = False, # <-- 新增参数
) -> None:
# ... 分配 mask buffer 和 KV chunking buffers (offload 模式需要)
# Skip GQA buffers in offload mode
# Chunked prefill uses compute_chunked_prefill() which doesn't need these
if enable_cpu_offload:
logger.info("[XAttn] Offload mode: skipping GQA expansion buffers")
return
# GPU-only mode: pre-allocate GQA buffers for compute_prefill()
if num_heads == num_kv_heads:
logger.info(f"[XAttn] No GQA expansion needed")
return
shape = (1, num_heads, max_seq_len, head_dim)
self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
```
**需要修改的文件**
1. `nanovllm/kvcache/sparse/xattn_bsa.py` - `alloc_policy_metadata()` 方法
2. `nanovllm/engine/model_runner.py` - 调用 `alloc_policy_metadata()` 时传入 `enable_cpu_offload`
### 方案 2: 延迟分配 (Lazy Allocation)
只在 `compute_prefill()` 首次调用时分配 GQA bufferoffload 模式走 `compute_chunked_prefill()` 不会触发分配。
```python
def compute_prefill(self, ...):
# Lazy allocation on first use
if self._k_expanded is None and num_heads != num_kv_heads:
self._allocate_gqa_buffers(...)
...
```
### 方案 3: 基于 chunk_size 限制 buffer 大小
不预分配 max_seq_len 大小,而是只分配 chunk_size 大小:
```python
# 原来: max_seq_len (1M tokens) -> 8 GB
# 修改后: chunk_size (16K tokens) -> ~130 MB
buffer_len = self.chunk_size if enable_cpu_offload else max_seq_len
shape = (1, num_heads, buffer_len, head_dim)
```
---
## 验证方法
修复后运行以下命令验证:
```bash
cd /home/zijie/Code/COMPASS
GPULIST=0 ./scripts/run_ruler.sh glm4-9b-xattn-nanovllm synthetic xattn --task niah_single_1
```
预期结果:
- 不再出现 8GB allocation 的 OOM 错误
- 模型正常加载并完成推理
---
## 相关文档
- `docs/xattn_bsa_policy_design.md` - XAttention BSA Policy 设计文档
- `docs/gpu_only_xattn_guide.md` - GPU-Only XAttention 指南
## 优先级
**High** - 阻塞 9B+ 模型在 24GB 显存 GPU 上使用 XAttention + Offload 模式
---
## 修复状态
**✅ 已修复** (2026-02-05)
### 修复内容
采用方案 1在 offload 模式下跳过 GQA buffer 分配:
1. `nanovllm/kvcache/sparse/policy.py`: 基类添加 `enable_cpu_offload` 参数
2. `nanovllm/kvcache/sparse/xattn_bsa.py`: 实现 offload 模式检查,跳过 GQA buffer
3. `nanovllm/engine/model_runner.py`: 传入 `enable_cpu_offload` 参数
### 验证结果
```bash
# 64K offload 测试
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_64k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 72000 \
--enable-offload \
--sparse-policy XATTN_BSA
```
- ✅ 日志显示: `[XAttn] Offload mode: skipping GQA expansion buffers`
- ✅ 测试通过: 100% 准确率
- ✅ 内存节省: ~16 GB (for 1M max_seq_len)
### 内存对比
| 配置 | 修复前 | 修复后 |
|------|--------|--------|
| max_model_len=72K | +1.1 GB | 0 GB |
| max_model_len=1M | +16 GB | 0 GB |

View File

@@ -0,0 +1,338 @@
# test_ruler.py 使用指南
RULER benchmark 综合测试工具,用于评估 LLM 长上下文能力。
**测试日期**: 2026-02-05
**测试 GPU**: RTX 3090 (GPU 4)
---
## 支持的任务
| 类别 | 任务 |
|------|------|
| NIAH (Needle-In-A-Haystack) | `niah_single_1/2/3`, `niah_multikey_1/2/3`, `niah_multiquery`, `niah_multivalue` |
| QA (Question Answering) | `qa_1`, `qa_2` |
| Recall | `cwe`, `fwe`, `vt` |
---
## 基本命令格式
```bash
CUDA_VISIBLE_DEVICES=<GPU_ID> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py [OPTIONS]
```
---
## 参数说明
### 必要参数
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `--model` | `~/models/Llama-3.1-8B-Instruct` | 模型路径 |
| `--data-dir` | `tests/data/ruler_64k` | 数据目录 |
| `--max-model-len` | 65664 | 最大上下文长度 |
### 数据选择
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `--datasets` | 全部 | 逗号分隔的数据集名 |
| `--num-samples` | 0 (全部) | 每个数据集测试样本数 |
| `--sample-indices` | - | 指定样本索引 (如 `0,5,10`) |
### Offload 配置
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `--enable-offload` | False | 启用 CPU offload 模式 |
| `--num-gpu-blocks` | 4 | GPU 上的 KV cache blocks 数量 |
| `--block-size` | 4096 | KV cache block 大小 (tokens) |
| `--num-kv-buffers` | 4 | Ring buffer 数量 |
| `--gpu-utilization` | 0.9 | GPU 显存利用率 |
### Sparse Attention 配置
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `--sparse-policy` | - | 稀疏策略: `FULL`, `QUEST`, `XATTN_BSA` |
| `--sparse-threshold` | 0.9 | XAttn cumulative attention 阈值 |
| `--sparse-samples` | 128 | XAttn 每 chunk 采样数 |
| `--sparse-stride` | 8 | XAttn Q/K 下采样步长 |
### 输出控制
| 参数 | 说明 |
|------|------|
| `--quiet` / `-q` | 安静模式 |
| `--json-output` | JSON 格式输出 |
| `--fresh-llm` | 每个样本重新初始化 LLM |
### 其他
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `--dtype` | auto | 模型数据类型 (`bfloat16`, `float16`) |
| `--use-cuda-graph` | False | 启用 CUDA Graph |
| `--max-new-tokens` | 16 | 最大生成 token 数 |
---
## 已验证的命令示例
以下命令均在 RTX 3090 (24GB) 上测试通过。
### 1. 基础 Offload 测试 (32K)
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload
```
**结果**: 100% 准确率, 耗时 ~16s
### 2. Offload + XAttention BSA (32K)
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
**结果**: 100% 准确率, compute density ~50%, 耗时 ~19s
### 3. Offload + XAttention BSA (64K)
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_64k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 72000 \
--enable-offload \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
**结果**: 100% 准确率, compute density ~37%, 耗时 ~52s
### 4. 多数据集多样本测试
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1,qa_1 \
--num-samples 2 \
--max-model-len 40960 \
--enable-offload \
--sparse-policy XATTN_BSA
```
**结果**: 4/4 (100%), 耗时 ~71s
### 5. 指定样本索引测试
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--sample-indices 0,5,10 \
--max-model-len 40960 \
--enable-offload
```
### 6. JSON 输出模式 (用于脚本)
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--json-output
```
**输出格式**:
```json
{
"total_correct": 1,
"total_samples": 1,
"overall_accuracy": 1.0,
"avg_score": 1.0,
"time": 30.44,
"tasks": {"niah_single_1": {"correct": 1, "total": 1, "accuracy": 1.0}},
"failed_samples": {}
}
```
### 7. 安静模式
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--quiet
```
### 8. 调整 GPU blocks 数量
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--num-gpu-blocks 8 \
--sparse-policy XATTN_BSA
```
### 9. GLM-4 模型测试
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/GLM-4-9B-Chat-1M \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--dtype bfloat16
```
**结果**: 100% 准确率, 耗时 ~17s
---
## 数据目录结构
```
tests/data/
├── ruler_4k/ # 4K context
├── ruler_8k/ # 8K context
├── ruler_16k/ # 16K context
├── ruler_32k/ # 32K context (推荐测试)
├── ruler_64k/ # 64K context
├── ruler_128k/ # 128K context
├── ruler_256k/ # 256K context
├── ruler_512k/ # 512K context
├── ruler_768k/ # 768K context
└── ruler_1m/ # 1M context
```
每个目录包含 13 个任务子目录,每个任务有 `validation.jsonl` 文件。
---
## GPU 与模式选择
| GPU 显存 | 推荐模式 | 说明 |
|---------|---------|------|
| 24GB (3090/4090) | `--enable-offload` | 必须使用 offload |
| 40GB+ (A100) | 两种模式均可 | 可测试 GPU-only |
**RTX 3090 限制**: 由于显存限制,必须使用 `--enable-offload` 参数。
---
## max-model-len 设置指南
| 数据目录 | 推荐 max-model-len | 说明 |
|---------|-------------------|------|
| ruler_4k | 5000 | 留出 output 空间 |
| ruler_8k | 9000 | |
| ruler_16k | 17000 | |
| ruler_32k | 40960 | |
| ruler_64k | 72000 | |
| ruler_128k | 135000 | |
**公式**: `max_model_len >= max_input_len + max_new_tokens`
---
## DensityObserver 输出
使用 `--sparse-policy XATTN_BSA` 时自动启用,输出示例:
```
============================================================
Density Statistics (XAttention BSA)
============================================================
[DensityObserver] Mode: offload
Compute density: 0.3691 (min: 0.3691 @ layer 0)
Comm density: 1.0000 (CPU block granularity)
Savings ratio: 0.0% H2D transfer reduction
Num layers: 1
Layer 0 density: 0.369052
```
| 指标 | 说明 |
|------|------|
| Compute density | BSA block (128 tokens) 粒度的计算密度 |
| Comm density | CPU block (4096 tokens) 粒度的通信密度 |
| Savings ratio | H2D 传输减少比例 |
---
## 常见问题
### 1. OOM 错误
**原因**: 显存不足
**解决**:
- 使用 `--enable-offload`
- 降低 `--gpu-utilization`
- 减少 `--num-gpu-blocks`
### 2. 模型加载失败
**原因**: 模型配置不兼容
**解决**:
- 检查 `--dtype` 参数 (GLM 模型需要 `--dtype bfloat16`)
- 确认模型路径正确
### 3. 准确率异常
**原因**: 状态泄漏
**解决**: 使用 `--fresh-llm` 参数为每个样本重新初始化 LLM
---
## 相关文档
- [`docs/xattn_density_types.md`](xattn_density_types.md) - Compute vs Comm density 解释
- [`docs/xattn_density_alignment_verification.md`](xattn_density_alignment_verification.md) - GPU-only vs Offload 对齐验证
- [`docs/ruler_benchmark_results_32k.md`](ruler_benchmark_results_32k.md) - RULER 32K 基准测试结果

View File

@@ -0,0 +1,142 @@
# XAttention Density Alignment Verification
验证 GPU-only 和 Offload 模式的 density 对齐情况。
**测试日期**: 2026-02-05
**测试模型**: Llama-3.1-8B-Instruct
**测试任务**: RULER niah_single_1
---
## 测试配置
| 参数 | 值 |
|------|-----|
| sparse_policy | XATTN_BSA |
| threshold | 0.9 |
| chunk_size | 4096 (已对齐) |
| stride | 8 |
| BSA block_size | 128 |
---
## 测试结果
### 32K Context
| 模式 | Layer 0 Density | Overall Density | 准确率 |
|------|-----------------|-----------------|--------|
| GPU-only | 0.502079 | 0.4012 | 100% |
| Offload | 0.498421 | 0.4984 | 100% |
| **差异** | **0.37%** | - | - |
### 64K Context
| 模式 | Layer 0 Density | Overall Density | 准确率 |
|------|-----------------|-----------------|--------|
| GPU-only | 0.369972 | 0.2963 | 100% |
| Offload | 0.369052 | 0.3691 | 100% |
| **差异** | **0.09%** | - | - |
---
## 关键修复
### Commit 829b311 - chunk_size 对齐 + Stream 同步修复
**问题**: 之前 GPU-only 和 Offload 模式的 density 差异达 10-13%
**根因**:
1. GPU-only 使用 `chunk_size=16384`Offload 使用 `chunk_size=4096`
2. Stream 同步 bug 导致 Pass 1/2 K 数据不一致
**修复**:
1.`XAttentionBSAPolicy.chunk_size` 默认值从 16384 改为 4096
2. 所有 compute kernels 包装在 `compute_stream` context 中
---
## 测试命令
### GPU-only 模式
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
### Offload 模式
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
---
## 详细日志
### 32K Offload 模式 Per-Chunk Density
```
Layer0 chunk: q_len=4096, k_len=4096, density=0.6234
Layer0 chunk: q_len=4096, k_len=8192, density=0.6239
Layer0 chunk: q_len=4096, k_len=12288, density=0.6026
Layer0 chunk: q_len=4096, k_len=16384, density=0.5695
Layer0 chunk: q_len=4096, k_len=20480, density=0.5285
Layer0 chunk: q_len=4096, k_len=24576, density=0.4891
Layer0 chunk: q_len=4096, k_len=28672, density=0.4514
Layer0 chunk: q_len=3813, k_len=32485, density=0.4208
```
### 64K Offload 模式 Per-Chunk Density
```
Layer0 chunk: q_len=4096, k_len=4096, density=0.6234
Layer0 chunk: q_len=4096, k_len=8192, density=0.6239
Layer0 chunk: q_len=4096, k_len=12288, density=0.6026
Layer0 chunk: q_len=4096, k_len=16384, density=0.5681
Layer0 chunk: q_len=4096, k_len=20480, density=0.5255
Layer0 chunk: q_len=4096, k_len=24576, density=0.4859
Layer0 chunk: q_len=4096, k_len=28672, density=0.4485
Layer0 chunk: q_len=4096, k_len=32768, density=0.4161
Layer0 chunk: q_len=4096, k_len=36864, density=0.3892
Layer0 chunk: q_len=4096, k_len=40960, density=0.3658
Layer0 chunk: q_len=4096, k_len=45056, density=0.3464
Layer0 chunk: q_len=4096, k_len=49152, density=0.3303
Layer0 chunk: q_len=4096, k_len=53248, density=0.3170
Layer0 chunk: q_len=4096, k_len=57344, density=0.3068
Layer0 chunk: q_len=4096, k_len=61440, density=0.2988
Layer0 chunk: q_len=3451, k_len=64891, density=0.2947
```
---
## 结论
1. **Density 对齐成功**: 差异从 10-13% 降到 <0.5%
2. **准确率一致**: 两种模式都达到 100% 准确率
3. **Density 随 context 增长下降**: 符合预期,更长的 context 稀疏性更高
---
## 相关文档
- [`docs/xattn_offload_stream_sync_fix.md`](xattn_offload_stream_sync_fix.md) - Stream 同步修复详情
- [`docs/xattn_density_types.md`](xattn_density_types.md) - Compute vs Comm density
- [`docs/gpuonly_density_alignment_test.md`](gpuonly_density_alignment_test.md) - 早期对齐测试

152
docs/xattn_density_types.md Normal file
View File

@@ -0,0 +1,152 @@
# XAttention Density Types: Compute vs Communication
XAttention BSA 统计两种不同粒度的 density它们反映不同的优化效果。
## 两种 Density 的定义
### 1. Compute Density计算密度
**粒度**: BSA block (128 tokens)
**公式**:
```
compute_density = selected_bsa_blocks / total_causal_bsa_blocks
```
**含义**: 实际需要计算 attention 的 blocks 占 causal 区域的比例。
**影响**: 决定 attention 计算量的减少。
### 2. Communication Density通信密度
**粒度**: CPU block (4096 tokens = 32 BSA blocks)
**公式**:
```
comm_density = selected_cpu_blocks / total_cpu_blocks
```
**含义**: 需要从 CPU 传输到 GPU 的 blocks 占总 blocks 的比例。
**影响**: 决定 H2D 传输量的减少。
## 为什么 Comm Density 通常高于 Compute Density
### 聚合效应
由于 CPU block 粒度是 BSA block 的 32 倍CPU block 选择使用 `any()` 聚合:
```python
# BSA mask: [B, H, Q_bsa, K_bsa]
# Reshape to CPU block level
mask_per_cpu = mask.view(B, H, Q_bsa, num_cpu_blocks, bsa_per_cpu)
# Any BSA block selected -> whole CPU block needed
cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1)
```
只要 CPU block 中**任意一个**:
- Head 选择了该 block
- Q position 选择了该 block
- BSA sub-block 被选中
则整个 CPU block 都需要传输。
### 示例
| 场景 | Compute Density | Comm Density | 说明 |
|------|-----------------|--------------|------|
| 64K context, threshold=0.9 | 37% | 100% | 稀疏 blocks 均匀分布在所有 CPU blocks |
| 32K context, threshold=0.9 | 50% | 100% | 同上 |
## 测试结果
### 测试命令
```bash
# Offload 模式测试
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_64k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 72000 \
--enable-offload \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
### 输出示例
```
[DensityObserver] Mode: offload
Compute density: 0.3691 (min: 0.3691 @ layer 0)
Comm density: 1.0000 (CPU block granularity)
Savings ratio: 0.0% H2D transfer reduction
Num layers: 1
Layer 0 density: 0.369052
```
## 关键发现
### 当前 XAttention 的通信优化局限
1. **Compute density 有效降低**: ~37% @ 64K context计算量减少 63%
2. **Comm density 没有降低**: 100%(通信量没有减少)
### 原因分析
Attention pattern 的特点:
- 不同 heads 关注不同位置
- 不同 Q positions 关注不同 K positions
- 稀疏选择分布在整个 sequence 上
这导致虽然每个 (head, Q, K) 组合只选择少量 blocks但聚合后覆盖了所有 CPU blocks。
### 潜在优化方向
1. **Per-head block selection**: 每个 head 独立选择 CPU blocks
2. **Block clustering**: 将相关 blocks 聚合到同一 CPU block
3. **Dynamic block size**: 根据 attention pattern 动态调整 CPU block 大小
## DensityObserver API
### 启用和重置
```python
from nanovllm.utils.density_observer import DensityObserver
DensityObserver.enable()
DensityObserver.complete_reset()
DensityObserver.set_mode("offload") # or "gpu_only"
```
### 记录
```python
# Compute density (GPU-only 模式自动记录)
DensityObserver.record(layer_id, mask, causal=True)
# Comm density (Offload 模式在 select_blocks 中记录)
DensityObserver.record_comm_density(layer_id, selected_cpu_blocks, total_cpu_blocks)
```
### 获取结果
```python
# 总体 density
overall_compute = DensityObserver.get_overall_density()
overall_comm = DensityObserver.get_overall_comm_density()
# Per-layer density
per_layer_compute = DensityObserver.get_per_layer_density()
per_layer_comm = DensityObserver.get_per_layer_comm_density()
# 打印摘要
DensityObserver.print_summary()
```
## 相关文件
- `nanovllm/utils/density_observer.py`: DensityObserver 实现
- `nanovllm/kvcache/sparse/xattn_bsa.py`: XAttention BSA Policyselect_blocks 中记录 comm density
- `tests/test_ruler.py`: RULER benchmark 测试脚本

View File

@@ -0,0 +1,184 @@
# XAttention Offload Profiling - 32K Context
Nsys profiling 分析 XAttention vs Full Attention 在 Offload 模式下的性能。
**测试日期**: 2026-02-05
**测试模型**: Llama-3.1-8B-Instruct
**Context**: 32K tokens
**GPU**: A100-80GB (GPU 0)
---
## 测试配置
| 参数 | Full | XAttention |
|------|------|------------|
| Policy | FULL | XATTN_BSA |
| Block size | 4096 | 4096 |
| GPU blocks | 4 | 4 |
| Threshold | - | 0.95 |
| Density | 100% | ~50% |
---
## XAttention 各阶段时间统计
### NVTX Markers Summary
| 阶段 | 总时间(ms) | 调用次数 | 平均时间(ms) | 说明 |
|------|------------|----------|--------------|------|
| xattn_find_blocks | 1155.1 | 256 | 4.51 | 块选择 (threshold-based) |
| xattn_estimate_pass1 | 588.3 | 256 | 2.30 | 第一轮: partial stats |
| xattn_compute_historical | 512.0 | 224 | 2.29 | 历史 KV attention |
| xattn_estimate_pass2 | 501.6 | 256 | 1.96 | 第二轮: block sums |
| xattn_estimate_merge | 197.9 | 256 | 0.77 | 合并 softmax stats |
| xattn_compute_merge | 93.8 | 256 | 0.37 | 计算结果合并 |
| xattn_compute_current | 59.2 | 256 | 0.23 | 当前 chunk attention |
### 时间分配
```
Total XAttention overhead: 3108 ms
Estimate 阶段: 1288 ms (41.4%)
- pass1: 588 ms
- pass2: 502 ms
- merge: 198 ms
Find blocks: 1155 ms (37.2%)
Compute 阶段: 665 ms (21.4%)
- historical: 512 ms
- merge: 94 ms
- current: 59 ms
```
---
## Chunk7 (最后一个 chunk) 对比
### Per-Layer 时间
| Policy | Layer 0 | Layer 1 | ... | Layer 31 | Avg |
|--------|---------|---------|-----|----------|-----|
| Full | 36.5 ms | 33.6 ms | ... | 32.7 ms | ~35 ms |
| XAttn | 39.7 ms | 39.3 ms | ... | 38.5 ms | ~38 ms |
### 分析
Chunk7 是序列的最后 ~4K tokens (3813 tokens),此时:
- K 长度: 32485 tokens
- Density: 42.08%
**结论**: XAttention 在 Chunk7 比 Full 慢约 8%,原因:
1. Estimate 开销无法被稀疏计算收益抵消
2. 42% density 仍然较高,稀疏收益有限
---
## Full Attention Chunk7 详细数据
```
Layer Time(ms)
L0 36.5
L1 44.3
L2 43.7
L3 38.7
L4 34.2
L5 45.2
...
L31 32.7
Avg ~35
```
---
## XAttention Chunk7 详细数据
```
Layer Time(ms)
L0 39.7
L1 39.3
L2 37.1
L3 39.1
L4 38.7
L5 39.4
...
L31 38.5
Avg ~38
```
---
## 性能瓶颈分析
### 1. xattn_find_blocks 开销过高
- 平均 4.51 ms per call
- 占总时间 37.2%
- 原因: threshold-based 块选择涉及排序和累积求和
### 2. 两轮 estimate 开销
- Pass1 + Pass2 共 1090 ms
- 需要遍历所有 KV chunks 两次
- 可优化方向: 单轮 estimate
### 3. Compute 阶段相对高效
- 只占 21.4%
- 说明 BSA 稀疏计算本身效率不错
---
## 优化建议
### 短期
1. **减少 find_blocks 开销**
- 使用 top-k 而不是 threshold
- 预分配 mask buffer 避免动态分配
2. **合并 estimate 两轮**
- 在单轮中同时计算 stats 和 block sums
### 中期
1. **estimate 阶段使用更小的 block_size**
- 当前 block_size=4096 对 estimate 不友好
- 参考 `docs/estimate_block_size_performance.md`
2. **Pipeline estimate 和 H2D**
- 将 estimate 与下一个 chunk 的 H2D 重叠
### 长期
1. **预测式块选择**
- 基于历史 pattern 预测下一个 chunk 的重要 blocks
- 减少 estimate 开销
---
## 相关文件
- `results/nsys/full_offload_32k_blk4096_20260205_023257.nsys-rep`
- `results/nsys/xattn_offload_32k_blk4096_20260205_023435.nsys-rep`
---
## 命令
### Profile Full
```bash
bash scripts/profile_offload.sh --policy full --ctx-len 32k --gpu 0 --model ~/models/Llama-3.1-8B-Instruct
```
### Profile XAttention
```bash
bash scripts/profile_offload.sh --policy xattn --ctx-len 32k --gpu 0 --model ~/models/Llama-3.1-8B-Instruct
```
### 分析 NVTX
```bash
nsys stats --report nvtx_pushpop_sum <file>.nsys-rep
```

View File

@@ -0,0 +1,307 @@
# XAttention Offload Stream Synchronization Fix
修复 XAttention BSA Policy 在 Offload 模式下的 CUDA stream 同步 bug。
**修复日期**: 2026-02-05
**Commit**: `829b311`
**影响文件**: `nanovllm/kvcache/sparse/xattn_bsa.py`, `nanovllm/kvcache/offload_engine.py`
---
## 问题描述
### 症状
在 Offload 模式下运行 RULER benchmark 时XAttention BSA 的 `select_blocks` 方法中 Pass 1 和 Pass 2 从**同一个 CPU block** 加载的 K 数据不一致:
```
Pass 1: K_chunk sum = 745472.00 (正确)
Pass 2: K_chunk sum = 0.00 (错误,数据未加载完成)
```
这导致 attention 计算结果错误RULER 准确率下降。
### 复现条件
- 模式: Offload (`--enable-offload`)
- Context: ≥ 32K tokens
- 稀疏策略: `--sparse-policy XATTN_BSA`
---
## 根因分析
### Stream 配置回顾
nano-vllm 的 CPU offload 使用多个 CUDA streams 实现 pipeline
| Stream | 用途 |
|--------|------|
| `slot_transfer_streams[i]` | H2D 传输 (CPU → GPU slot) |
| `compute_stream` | Attention 计算 |
| `prefill_offload_streams[i]` | D2H 传输 (GPU → CPU cache) |
### 同步机制
`wait_slot_layer(slot)` 使用 event 机制同步:
```python
def wait_slot_layer(self, slot_idx: int):
"""Make compute_stream wait for H2D transfer completion."""
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])
```
### Bug 根因
`select_blocks` 方法中:
1. H2D 传输在 `slot_transfer_streams` 上执行
2. `wait_slot_layer``compute_stream` 等待传输完成
3. **但是** 后续的 compute kernels 在**默认 stream** 上执行,而不是 `compute_stream`
```python
# Bug 代码
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot) # compute_stream 等待
# 这些 kernel 在默认 stream 上运行,没有等待 H2D 完成!
k_block = offload_engine.get_k_for_slot(slot)
K_chunk = k_block.transpose(1, 2)
# ... 后续计算 ...
```
### 时序图
```
slot_transfer_stream: [====H2D====]
compute_stream: |wait|
default_stream: [kernel1][kernel2] ← 没有等待!
数据未就绪
```
---
## 修复方案
### 核心修改
将所有 estimate 阶段的 compute kernels 包装在 `with torch.cuda.stream(compute_stream):` 中:
```python
# 修复后代码
compute_stream = offload_engine.compute_stream
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot) # compute_stream 等待
# 所有计算在 compute_stream 上执行
with torch.cuda.stream(compute_stream):
k_block = offload_engine.get_k_for_slot(slot)
K_chunk = k_block.transpose(1, 2)
# ... 后续计算 ...
```
### 修复位置
`select_blocks` 方法中共 6 处需要修复:
| 位置 | 阶段 | 修复内容 |
|------|------|----------|
| Pass 1 历史 blocks | `xattn_estimate_pass1` | 历史 KV chunk 处理 |
| Pass 1 当前 chunk | `xattn_estimate_pass1` | 当前 GPU 上的 K 处理 |
| Step 2 合并 | `merge_softmax_stats` | softmax stats 合并 |
| Pass 2 历史 blocks | `xattn_estimate_pass2` | 带全局 stats 的 block_sum |
| Pass 2 当前 chunk | `xattn_estimate_pass2` | 当前 chunk 的 block_sum |
| Step 4 block 选择 | `find_blocks_chunked` | 最终 block 选择 |
### 时序图(修复后)
```
slot_transfer_stream: [====H2D====]
compute_stream: |wait|[kernel1][kernel2]
数据已就绪
```
---
## 代码变更详情
### 1. Pass 1 历史 blocks 处理
```python
# Before (bug)
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
k_block = offload_engine.get_k_for_slot(slot) # 默认 stream
K_chunk = k_block.transpose(1, 2)
# ... compute ...
# After (fixed)
compute_stream = offload_engine.compute_stream
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream): # 显式指定 stream
k_block = offload_engine.get_k_for_slot(slot)
K_chunk = k_block.transpose(1, 2)
# ... compute ...
```
### 2. 移除 STRONG SYNC
`offload_engine.py` 中移除了不必要的强同步:
```python
# Removed from load_to_slot_layer() and load_k_only_to_slot_layer()
# STRONG SYNC: Synchronize all prefill offload streams before H2D
# for offload_stream in self.prefill_offload_streams:
# offload_stream.synchronize()
```
这些同步现在由 event 机制正确处理,不再需要阻塞式同步。
### 3. 其他清理
- 移除 DEBUG print 语句
- 移除 `torch.save()` debug 代码
- 合并多个 fallback 条件
-`chunk_size` 默认值从 16384 改为 4096匹配 offload Q chunk size
---
## 测试验证
### 测试命令
**GPU 0 - Offload 模式测试**:
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 10 \
--max-model-len 40960 \
--enable-offload \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
**GPU 1 - GPU-only 模式测试**:
```bash
CUDA_VISIBLE_DEVICES=1 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Qwen3-0.6B \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 10 \
--max-model-len 40960 \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
### 测试结果
| 模式 | 模型 | Context | Samples | Pass Rate | Density |
|------|------|---------|---------|-----------|---------|
| Offload | Llama-3.1-8B | 32K | 10/10 | **100%** | 9.53% |
| GPU-only | Qwen3-0.6B | 32K | 10/10 | **100%** | 9.84% |
### Density 对齐验证
| 模式 | Layer 0 Density | 差异 |
|------|-----------------|------|
| GPU-only | 9.84% | - |
| Offload | 9.53% | ~3% |
~3% 的差异是预期的,因为两种模式的 KV 累积模式不同:
- GPU-only: 一次性处理所有 KV
- Offload: 分 chunk 处理,每个 chunk 独立计算 softmax stats 后合并
---
## 技术细节
### 三阶段 KV Chunking 流程
```
┌─────────────────────────────────────────────────────────────┐
│ Stage 1: softmax_compute_partial_stats │
│ └── 每个 KV chunk 独立计算 partial stats (m_i, l_i) │
│ │
│ Stage 2: merge_softmax_stats │
│ └── Host 端合并所有 chunks: (m_global, l_global) │
│ │
│ Stage 3: softmax_normalize_and_block_sum │
│ └── 使用全局 stats 归一化并计算 block sums │
└─────────────────────────────────────────────────────────────┘
```
### Stream 配置要求
| 操作类型 | Stream | 原因 |
|----------|--------|------|
| H2D 传输 | `slot_transfer_streams` | 异步传输,不阻塞计算 |
| D2H 传输 | `prefill_offload_streams` | 异步 offload不阻塞计算 |
| Estimate kernels | `compute_stream` | 与 attention 计算共享,确保同步 |
| Attention kernels | `compute_stream` | 主计算流 |
### Event 同步机制
```python
# H2D 传输完成后记录 event
self.ring_slot_ready[slot_idx].record(slot_transfer_stream)
# 计算前等待 H2D 完成
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])
# 计算完成后记录 event用于下一轮 H2D
self.ring_slot_compute_done[slot_idx].record(compute_stream)
```
---
## 相关文档
- [`docs/architecture_guide.md`](architecture_guide.md): Stream 配置和 ring buffer 架构
- [`docs/xattn_kv_chunking_kernels.md`](xattn_kv_chunking_kernels.md): 三阶段 softmax kernels
- [`docs/gpuonly_density_alignment_test.md`](gpuonly_density_alignment_test.md): Density 对齐测试
- [`docs/xattn_bsa_policy_design.md`](xattn_bsa_policy_design.md): XAttention BSA Policy 设计
---
## 经验总结
### 1. Stream 同步的隐蔽性
CUDA stream 同步 bug 很难发现:
- 数据可能"大部分时间"正确(取决于时序)
- 错误表现为随机/间歇性的结果偏差
- 需要精确的 debug logging 才能定位
### 2. Event vs Synchronize
| 方法 | 优点 | 缺点 |
|------|------|------|
| `stream.wait_event()` | 非阻塞,保持 pipeline | 只同步指定 stream |
| `stream.synchronize()` | 保证完成 | 阻塞整个 stream破坏 pipeline |
**最佳实践**: 使用 event 进行精确同步,避免 synchronize 阻塞。
### 3. 调试方法
```python
# 打印 tensor sum 验证数据一致性
print(f"K_chunk sum = {K_chunk.sum().item()}")
# 保存中间结果进行离线比较
torch.save({'K': K_chunk, 'layer': layer_id}, f'/tmp/debug_{pass}_{chunk}.pt')
```

View File

@@ -227,9 +227,9 @@ class ModelRunner:
device=torch.device("cuda"), device=torch.device("cuda"),
) )
# GPU-only mode: pre-allocate policy metadata buffers # Pre-allocate policy metadata buffers
# This avoids dynamic GPU memory allocation during forward pass # - Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
# if not config.enable_cpu_offload: # - GPU-only mode: additionally allocate GQA expansion buffers
num_heads = hf_config.num_attention_heads // self.world_size num_heads = hf_config.num_attention_heads // self.world_size
self.kvcache_manager.sparse_policy.alloc_policy_metadata( self.kvcache_manager.sparse_policy.alloc_policy_metadata(
num_heads=num_heads, num_heads=num_heads,
@@ -238,6 +238,7 @@ class ModelRunner:
max_seq_len=config.max_model_len, max_seq_len=config.max_model_len,
dtype=hf_config.torch_dtype, dtype=hf_config.torch_dtype,
device=torch.device("cuda"), device=torch.device("cuda"),
enable_cpu_offload=config.enable_cpu_offload,
) )
# Log policy info (handle both enum and None cases) # Log policy info (handle both enum and None cases)

View File

@@ -116,13 +116,15 @@ class SparsePolicy(ABC):
max_seq_len: int, max_seq_len: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
enable_cpu_offload: bool = False,
) -> None: ) -> None:
""" """
Pre-allocate GPU buffers for policy computation. Pre-allocate GPU buffers for policy computation.
Called by the framework after KV cache allocation, but ONLY for GPU-only Called by the framework after KV cache allocation. Implementations should
mode (not CPU offload mode). Override this to pre-allocate buffers that use enable_cpu_offload to decide which buffers to allocate:
would otherwise be dynamically allocated during forward pass. - Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
- GPU-only mode: additionally allocate GQA expansion buffers
This is separate from initialize() which is used for CPU offload metadata. This is separate from initialize() which is used for CPU offload metadata.
@@ -133,6 +135,7 @@ class SparsePolicy(ABC):
max_seq_len: Maximum sequence length (for buffer sizing) max_seq_len: Maximum sequence length (for buffer sizing)
dtype: Data type (typically float16/bfloat16) dtype: Data type (typically float16/bfloat16)
device: Target device (cuda) device: Target device (cuda)
enable_cpu_offload: Whether CPU offload is enabled
""" """
pass pass

View File

@@ -96,7 +96,7 @@ class XAttentionBSAPolicy(SparsePolicy):
self, self,
threshold: float = 0.95, # High threshold for accuracy testing threshold: float = 0.95, # High threshold for accuracy testing
stride: int = 8, stride: int = 8,
chunk_size: int = 16384, chunk_size: int = 4096, # Match offload Q chunk size for density alignment
block_size: int = 128, block_size: int = 128,
samples_per_chunk: int = 128, samples_per_chunk: int = 128,
use_triton: bool = True, use_triton: bool = True,
@@ -175,6 +175,7 @@ class XAttentionBSAPolicy(SparsePolicy):
max_seq_len: int, max_seq_len: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
enable_cpu_offload: bool = False,
) -> None: ) -> None:
""" """
Pre-allocate GQA expansion buffers for GPU-only mode. Pre-allocate GQA expansion buffers for GPU-only mode.
@@ -235,7 +236,14 @@ class XAttentionBSAPolicy(SparsePolicy):
f"m/l shape={m_partial_shape} ({m_l_memory_mb:.1f} MB), " f"m/l shape={m_partial_shape} ({m_l_memory_mb:.1f} MB), "
f"block_sums shape={block_sums_shape} ({block_sums_memory_mb:.1f} MB)") f"block_sums shape={block_sums_shape} ({block_sums_memory_mb:.1f} MB)")
# Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads) # Skip GQA buffers in offload mode
# Chunked prefill uses compute_chunked_prefill() which handles GQA inline
if enable_cpu_offload:
logger.info("[XAttn] Offload mode: skipping GQA expansion buffers (saves ~16GB for 1M seq)")
return
# GPU-only mode: pre-allocate GQA buffers for compute_prefill()
# Only allocate if GQA (num_heads != num_kv_heads)
if num_heads == num_kv_heads: if num_heads == num_kv_heads:
logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})") logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
return return
@@ -289,9 +297,11 @@ class XAttentionBSAPolicy(SparsePolicy):
Returns: Returns:
Attention output [total_q, num_heads, head_dim] Attention output [total_q, num_heads, head_dim]
""" """
# When block_tables is provided (paged KV cache / prefix cache), # Fallback to flash attention when:
# fallback to flash_attn as XAttention expects contiguous K, V # 1. block_tables provided (paged KV cache / prefix cache) - XAttention expects contiguous K, V
if block_tables is not None: # 2. BSA kernel not available
# 3. xattn_estimate not available
if block_tables is not None or not BSA_AVAILABLE or not XATTN_AVAILABLE:
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
return flash_attn_varlen_func( return flash_attn_varlen_func(
q, k, v, q, k, v,
@@ -304,32 +314,6 @@ class XAttentionBSAPolicy(SparsePolicy):
block_table=block_tables, block_table=block_tables,
) )
if not BSA_AVAILABLE:
# Fallback to flash attention if BSA not available
from flash_attn import flash_attn_varlen_func
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=True,
)
if not XATTN_AVAILABLE:
# Fallback to flash attention if xattn not available
from flash_attn import flash_attn_varlen_func
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=True,
)
from nanovllm.ops.xattn import xattn_estimate from nanovllm.ops.xattn import xattn_estimate
# Set DensityObserver mode on first layer # Set DensityObserver mode on first layer
@@ -477,8 +461,7 @@ class XAttentionBSAPolicy(SparsePolicy):
causal_total = q_bk * (q_bk + 1) // 2 * mask_trimmed.shape[0] * mask_trimmed.shape[1] causal_total = q_bk * (q_bk + 1) // 2 * mask_trimmed.shape[0] * mask_trimmed.shape[1]
causal_mask = torch.tril(torch.ones(q_bk, k_bk, device=mask_trimmed.device, dtype=torch.bool)) causal_mask = torch.tril(torch.ones(q_bk, k_bk, device=mask_trimmed.device, dtype=torch.bool))
selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
logger.info(f"[DEBUG GPU-only Layer0] mask_shape={mask_trimmed.shape}, "
f"density={selected/causal_total:.6f}, selected={selected}, total={causal_total}")
DensityObserver.record(layer_id, mask_trimmed, causal=True) DensityObserver.record(layer_id, mask_trimmed, causal=True)
return output return output
@@ -633,98 +616,108 @@ class XAttentionBSAPolicy(SparsePolicy):
l_chunks = [] l_chunks = []
num_kv_chunks = num_historical_blocks + 1 # +1 for current chunk num_kv_chunks = num_historical_blocks + 1 # +1 for current chunk
# Get compute_stream for all compute kernels (like attention computation)
compute_stream = offload_engine.compute_stream
with nvtx.range("xattn_estimate_pass1"): with nvtx.range("xattn_estimate_pass1"):
slot = 0 slot = 0
# Process historical blocks (from CPU) # Process historical blocks (from CPU)
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks): for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
# Load K from CPU # Load K from CPU (on slot_transfer_stream)
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
# wait_slot_layer makes compute_stream wait for H2D transfer
offload_engine.wait_slot_layer(slot) offload_engine.wait_slot_layer(slot)
k_block = offload_engine.get_k_for_slot(slot) # [1, block_size, num_kv_heads, head_dim] # All compute kernels run on compute_stream (like attention computation)
K_chunk = k_block.transpose(1, 2) # [1, num_kv_heads, block_size, head_dim] with torch.cuda.stream(compute_stream):
k_block = offload_engine.get_k_for_slot(slot) # [1, block_size, num_kv_heads, head_dim]
K_chunk = k_block.transpose(1, 2) # [1, num_kv_heads, block_size, head_dim]
# GQA expansion # GQA expansion
num_kv_heads = K_chunk.shape[1] num_kv_heads = K_chunk.shape[1]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
# KV offset in reshaped space
kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
# Compute raw attention scores
attn_weights_kv = flat_group_gemm_fuse_reshape(
Q, K_chunk, self.stride,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False, # K 不完整,不能在这里用 causal
)
# Compute partial stats (带 causal mask)
m_partial, l_partial = softmax_compute_partial_stats(
attn_weights_kv,
reshaped_block_size,
segment_size,
scale,
chunk_start=chunk_start,
kv_offset=kv_offset_reshaped,
is_causal=True,
)
m_chunks.append(m_partial)
l_chunks.append(l_partial)
offload_engine.record_slot_compute_done(slot)
del attn_weights_kv
# Process current chunk K (already on GPU) on compute_stream
with torch.cuda.stream(compute_stream):
# k: [seq_len, num_kv_heads, head_dim] -> [1, num_kv_heads, seq_len, head_dim]
K_current = k.unsqueeze(0).transpose(1, 2)
# GQA expansion for current chunk
num_kv_heads = K_current.shape[1]
if num_heads != num_kv_heads: if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads num_groups = num_heads // num_kv_heads
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1) K_current = K_current.repeat_interleave(num_groups, dim=1)
# KV offset in reshaped space # Pad current K to alignment
kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped curr_k_len = K_current.shape[2]
padded_curr_k_len = ((curr_k_len + alignment - 1) // alignment) * alignment
if padded_curr_k_len != curr_k_len:
K_current = torch.nn.functional.pad(K_current, (0, 0, 0, padded_curr_k_len - curr_k_len), value=0)
# Compute raw attention scores # KV offset for current chunk
attn_weights_kv = flat_group_gemm_fuse_reshape( kv_offset_current = num_historical_blocks * kv_chunk_reshaped
Q, K_chunk, self.stride,
# Compute attention scores for current chunk
attn_weights_curr = flat_group_gemm_fuse_reshape(
Q, K_current, self.stride,
chunk_start=chunk_start, chunk_start=chunk_start,
chunk_end=chunk_end, chunk_end=chunk_end,
is_causal=False, # K 不完整,不能在这里用 causal is_causal=False,
) )
# Compute partial stats (带 causal mask) # Compute partial stats for current chunk
m_partial, l_partial = softmax_compute_partial_stats( m_partial_curr, l_partial_curr = softmax_compute_partial_stats(
attn_weights_kv, attn_weights_curr,
reshaped_block_size, reshaped_block_size,
segment_size, segment_size,
scale, scale,
chunk_start=chunk_start, chunk_start=chunk_start,
kv_offset=kv_offset_reshaped, kv_offset=kv_offset_current,
is_causal=True, is_causal=True,
) )
m_chunks.append(m_partial) m_chunks.append(m_partial_curr)
l_chunks.append(l_partial) l_chunks.append(l_partial_curr)
offload_engine.record_slot_compute_done(slot) del attn_weights_curr
del attn_weights_kv
# Process current chunk K (already on GPU)
# k: [seq_len, num_kv_heads, head_dim] -> [1, num_kv_heads, seq_len, head_dim]
K_current = k.unsqueeze(0).transpose(1, 2)
# GQA expansion for current chunk
num_kv_heads = K_current.shape[1]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
K_current = K_current.repeat_interleave(num_groups, dim=1)
# Pad current K to alignment
curr_k_len = K_current.shape[2]
padded_curr_k_len = ((curr_k_len + alignment - 1) // alignment) * alignment
if padded_curr_k_len != curr_k_len:
K_current = torch.nn.functional.pad(K_current, (0, 0, 0, padded_curr_k_len - curr_k_len), value=0)
# KV offset for current chunk
kv_offset_current = num_historical_blocks * kv_chunk_reshaped
# Compute attention scores for current chunk
attn_weights_curr = flat_group_gemm_fuse_reshape(
Q, K_current, self.stride,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False,
)
# Compute partial stats for current chunk
m_partial_curr, l_partial_curr = softmax_compute_partial_stats(
attn_weights_curr,
reshaped_block_size,
segment_size,
scale,
chunk_start=chunk_start,
kv_offset=kv_offset_current,
is_causal=True,
)
m_chunks.append(m_partial_curr)
l_chunks.append(l_partial_curr)
del attn_weights_curr
# ================================================================ # ================================================================
# Step 2: Merge all partial stats # Step 2: Merge all partial stats (on compute_stream)
# ================================================================ # ================================================================
with nvtx.range("xattn_estimate_merge"): with torch.cuda.stream(compute_stream):
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks) with nvtx.range("xattn_estimate_merge"):
del m_chunks, l_chunks m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
del m_chunks, l_chunks
# ================================================================ # ================================================================
# Step 3: Second pass - normalize and compute block sums # Step 3: Second pass - normalize and compute block sums
@@ -736,30 +729,61 @@ class XAttentionBSAPolicy(SparsePolicy):
# Process historical blocks again # Process historical blocks again
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks): for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
# Load K from CPU (on slot_transfer_stream)
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
# wait_slot_layer makes compute_stream wait for H2D transfer
offload_engine.wait_slot_layer(slot) offload_engine.wait_slot_layer(slot)
k_block = offload_engine.get_k_for_slot(slot) # All compute kernels run on compute_stream
K_chunk = k_block.transpose(1, 2) with torch.cuda.stream(compute_stream):
k_block = offload_engine.get_k_for_slot(slot)
K_chunk = k_block.transpose(1, 2)
num_kv_heads = K_chunk.shape[1] num_kv_heads = K_chunk.shape[1]
if num_heads != num_kv_heads: if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads num_groups = num_heads // num_kv_heads
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1) K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
# Recompute attention scores (trade-off: compute vs memory) # Recompute attention scores (trade-off: compute vs memory)
attn_weights_kv = flat_group_gemm_fuse_reshape( attn_weights_kv = flat_group_gemm_fuse_reshape(
Q, K_chunk, self.stride, Q, K_chunk, self.stride,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False,
)
# Normalize with global stats and compute block sums
block_sum_kv = softmax_normalize_and_block_sum(
attn_weights_kv,
m_global,
l_global,
reshaped_block_size,
segment_size,
chunk_start=chunk_start,
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
scale=scale,
kv_offset=kv_offset_reshaped,
is_causal=True,
)
attn_sum_per_kv.append(block_sum_kv)
offload_engine.record_slot_compute_done(slot)
del attn_weights_kv
# Process current chunk on compute_stream
with torch.cuda.stream(compute_stream):
# Recompute attention scores for current chunk
attn_weights_curr = flat_group_gemm_fuse_reshape(
Q, K_current, self.stride,
chunk_start=chunk_start, chunk_start=chunk_start,
chunk_end=chunk_end, chunk_end=chunk_end,
is_causal=False, is_causal=False,
) )
# Normalize with global stats and compute block sums block_sum_curr = softmax_normalize_and_block_sum(
block_sum_kv = softmax_normalize_and_block_sum( attn_weights_curr,
attn_weights_kv,
m_global, m_global,
l_global, l_global,
reshaped_block_size, reshaped_block_size,
@@ -767,67 +791,42 @@ class XAttentionBSAPolicy(SparsePolicy):
chunk_start=chunk_start, chunk_start=chunk_start,
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad, real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
scale=scale, scale=scale,
kv_offset=kv_offset_reshaped, kv_offset=kv_offset_current,
is_causal=True, is_causal=True,
) )
attn_sum_per_kv.append(block_sum_kv) attn_sum_per_kv.append(block_sum_curr)
del attn_weights_curr, K_current
offload_engine.record_slot_compute_done(slot)
del attn_weights_kv
# Process current chunk
# Recompute attention scores for current chunk
attn_weights_curr = flat_group_gemm_fuse_reshape(
Q, K_current, self.stride,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False,
)
block_sum_curr = softmax_normalize_and_block_sum(
attn_weights_curr,
m_global,
l_global,
reshaped_block_size,
segment_size,
chunk_start=chunk_start,
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
scale=scale,
kv_offset=kv_offset_current,
is_causal=True,
)
attn_sum_per_kv.append(block_sum_curr)
del attn_weights_curr, K_current
# ================================================================ # ================================================================
# Step 4: Concatenate block sums and select blocks # Step 4: Concatenate block sums and select blocks (on compute_stream)
# ================================================================ # ================================================================
attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1) with torch.cuda.stream(compute_stream):
del attn_sum_per_kv, m_global, l_global attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
del attn_sum_per_kv, m_global, l_global
# Calculate q_block offset for find_blocks_chunked # Calculate q_block offset for find_blocks_chunked
# This is the number of BSA blocks before Q in the full sequence # This is the number of BSA blocks before Q in the full sequence
num_blocks_per_chunk = q_reshaped_len // reshaped_block_size num_blocks_per_chunk = q_reshaped_len // reshaped_block_size
current_index = k_block_num - q_block_num # Q starts at this BSA block index current_index = k_block_num - q_block_num # Q starts at this BSA block index
with nvtx.range("xattn_find_blocks"): with nvtx.range("xattn_find_blocks"):
mask = find_blocks_chunked( mask = find_blocks_chunked(
attn_sum_concat, attn_sum_concat,
current_index=current_index, current_index=current_index,
threshold=self.threshold, threshold=self.threshold,
num_to_choose=None, num_to_choose=None,
decoding=False, decoding=False,
mode="prefill", mode="prefill",
causal=True, causal=True,
)
# Apply causal mask post-processing (same as xattn.py lines 1300-1306)
mask[:, :, -q_block_num:, -q_block_num:] = torch.where(
torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=mask.device), diagonal=0),
mask[:, :, -q_block_num:, -q_block_num:],
False,
) )
# Apply causal mask post-processing (same as xattn.py lines 1300-1306)
mask[:, :, -q_block_num:, -q_block_num:] = torch.where(
torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=mask.device), diagonal=0),
mask[:, :, -q_block_num:, -q_block_num:],
False,
)
# ================================================================ # ================================================================
# Step 5: Record density (only on layer 0) # Step 5: Record density (only on layer 0)
# ================================================================ # ================================================================
@@ -908,20 +907,21 @@ class XAttentionBSAPolicy(SparsePolicy):
if available_blocks and available_blocks[-1] not in selected_block_ids: if available_blocks and available_blocks[-1] not in selected_block_ids:
selected_block_ids.append(available_blocks[-1]) selected_block_ids.append(available_blocks[-1])
# Record communication density
if available_blocks:
DensityObserver.record_comm_density(
layer_id,
selected_cpu_blocks=len(selected_block_ids),
total_cpu_blocks=len(available_blocks),
)
# Update statistics (only for layer 0 to avoid overcounting) # Update statistics (only for layer 0 to avoid overcounting)
if layer_id == 0 and available_blocks: if layer_id == 0 and available_blocks:
self._stats_total_available_blocks += len(available_blocks) self._stats_total_available_blocks += len(available_blocks)
self._stats_total_selected_blocks += len(selected_block_ids) self._stats_total_selected_blocks += len(selected_block_ids)
self._stats_num_chunks += 1 self._stats_num_chunks += 1
# Record communication density to DensityObserver
# Comm density = selected_cpu_blocks / available_cpu_blocks
# This is different from compute density (BSA block granularity)
DensityObserver.record_comm_density(
layer_id=layer_id,
selected_cpu_blocks=len(selected_block_ids),
total_cpu_blocks=len(available_blocks),
)
# Log per-chunk density # Log per-chunk density
chunk_density = len(selected_block_ids) / len(available_blocks) chunk_density = len(selected_block_ids) / len(available_blocks)
logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, available={len(available_blocks)}, " logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, available={len(available_blocks)}, "

View File

@@ -266,14 +266,31 @@ class DensityObserver(Observer):
return 0.0 return 0.0
return sum(all_densities) / len(all_densities) return sum(all_densities) / len(all_densities)
@classmethod
def get_per_layer_comm_density(cls) -> Dict[int, float]:
"""
获取每层的 communication density (CPU block 粒度)。
Returns:
Dict[layer_id, avg_comm_density]
"""
result = {}
for layer_id, densities in cls._layer_comm_densities.items():
if densities:
result[layer_id] = sum(densities) / len(densities)
return result
@classmethod @classmethod
def get_summary(cls) -> dict: def get_summary(cls) -> dict:
"""返回统计摘要""" """返回统计摘要"""
per_layer = cls.get_per_layer_density() per_layer = cls.get_per_layer_density()
per_layer_comm = cls.get_per_layer_comm_density()
return { return {
"mode": cls._mode, "mode": cls._mode,
"overall_density": cls.get_overall_density(), "overall_compute_density": cls.get_overall_density(),
"per_layer_density": per_layer, "overall_comm_density": cls.get_overall_comm_density(),
"per_layer_compute_density": per_layer,
"per_layer_comm_density": per_layer_comm,
"num_layers": len(per_layer), "num_layers": len(per_layer),
"last_mask_shape": { "last_mask_shape": {
"q_blocks": cls._last_q_blocks, "q_blocks": cls._last_q_blocks,
@@ -301,7 +318,9 @@ class DensityObserver(Observer):
print(f"[DensityObserver] Mode: {cls._mode}") print(f"[DensityObserver] Mode: {cls._mode}")
print(f" Compute density: {overall:.4f} (min: {min_density:.4f} @ layer {min_layer})") print(f" Compute density: {overall:.4f} (min: {min_density:.4f} @ layer {min_layer})")
if overall_comm > 0: if overall_comm > 0:
print(f" Comm density: {overall_comm:.4f}") # Offload mode: show both densities with explanation
print(f" Comm density: {overall_comm:.4f} (CPU block granularity)")
print(f" Savings ratio: {1 - overall_comm:.1%} H2D transfer reduction")
print(f" Num layers: {len(per_layer)}") print(f" Num layers: {len(per_layer)}")
# 输出 layer 0 的 density 用于对比 # 输出 layer 0 的 density 用于对比
if 0 in per_layer: if 0 in per_layer:

View File

@@ -1,314 +0,0 @@
"""
Benchmark: block_size impact on XAttention estimate phase performance.
This script tests how different block_size values affect the performance of:
1. flat_group_gemm_fuse_reshape (estimate GEMM)
2. softmax_fuse_block_sum (estimate softmax + block aggregation)
Key insight: The current select_blocks uses global kvcache_block_size for estimation,
which may not be optimal for the Triton kernels.
"""
import sys
import os
import torch
import time
import math
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
# ============================================================
# Configuration
# ============================================================
# Test configurations
BLOCK_SIZES = [64, 128, 256, 512] # BSA optimal is 128
STRIDE = 8
NUM_WARMUP = 3
NUM_RUNS = 10
# Model dimensions (Llama-3.1-8B-Instruct)
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
# Context lengths to test
CONTEXT_LENGTHS = [16384, 32768, 65536] # 16K, 32K, 64K
# ============================================================
# Benchmark Functions
# ============================================================
def benchmark_flat_group_gemm(Q, K, stride, block_size, num_warmup=3, num_runs=10):
"""
Benchmark flat_group_gemm_fuse_reshape kernel.
Args:
Q: [batch, heads, q_len, head_dim]
K: [batch, heads, k_len, head_dim]
stride: Stride for reshape
block_size: Block size (affects alignment requirements)
Returns:
(avg_time_ms, output_tensor)
"""
q_len = Q.shape[2]
k_len = K.shape[2]
# Compute reshaped dimensions
reshaped_q_len = q_len // stride
reshaped_k_len = k_len // stride
reshaped_block_size = block_size // stride
# Warmup
for _ in range(num_warmup):
_ = flat_group_gemm_fuse_reshape(
Q, K, stride,
chunk_start=0,
chunk_end=reshaped_q_len,
is_causal=False,
)
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_runs):
output = flat_group_gemm_fuse_reshape(
Q, K, stride,
chunk_start=0,
chunk_end=reshaped_q_len,
is_causal=False,
)
torch.cuda.synchronize()
end = time.perf_counter()
avg_time_ms = (end - start) / num_runs * 1000
return avg_time_ms, output
def benchmark_softmax_fuse_block_sum(attn_weights, reshaped_block_size, num_warmup=3, num_runs=10):
"""
Benchmark softmax_fuse_block_sum kernel.
Args:
attn_weights: [batch, heads, q_len, k_len] attention weights
reshaped_block_size: Block size in reshaped space
Returns:
avg_time_ms
"""
batch_size, num_heads, q_len, k_len = attn_weights.shape
head_dim = HEAD_DIM
stride = STRIDE
norm = 1.0
# segment_size must divide k_len and be >= reshaped_block_size
segment_size = min(4096, reshaped_block_size)
# Ensure k_len is divisible by segment_size
if k_len % segment_size != 0:
# Pad k_len
pad_size = segment_size - (k_len % segment_size)
attn_weights = torch.nn.functional.pad(attn_weights, (0, pad_size), value=0)
k_len = attn_weights.shape[3]
# Scale factor
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
# Warmup
for _ in range(num_warmup):
_ = softmax_fuse_block_sum(
attn_weights,
reshaped_block_size,
segment_size,
chunk_start=0,
chunk_end=q_len,
real_q_len=q_len,
scale=scale,
is_causal=False,
)
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_runs):
output = softmax_fuse_block_sum(
attn_weights,
reshaped_block_size,
segment_size,
chunk_start=0,
chunk_end=q_len,
real_q_len=q_len,
scale=scale,
is_causal=False,
)
torch.cuda.synchronize()
end = time.perf_counter()
avg_time_ms = (end - start) / num_runs * 1000
return avg_time_ms
def run_estimate_benchmark(q_len, k_len, block_size, stride=STRIDE):
"""
Run full estimate benchmark for given configuration.
Args:
q_len: Query length
k_len: Key length (usually same as q_len for current chunk scenario)
block_size: Block size to test
stride: Stride for reshape
Returns:
dict with timing results
"""
# Create random Q and K tensors
# Shape: [batch, heads, seq_len, head_dim]
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
K = torch.randn(1, NUM_HEADS, k_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
reshaped_block_size = block_size // stride
reshaped_q_len = q_len // stride
reshaped_k_len = k_len // stride
# Benchmark GEMM
gemm_time, attn_weights = benchmark_flat_group_gemm(
Q, K, stride, block_size,
num_warmup=NUM_WARMUP, num_runs=NUM_RUNS
)
# Benchmark softmax + block sum
softmax_time = benchmark_softmax_fuse_block_sum(
attn_weights, reshaped_block_size,
num_warmup=NUM_WARMUP, num_runs=NUM_RUNS
)
# Clean up
del Q, K, attn_weights
torch.cuda.empty_cache()
return {
"q_len": q_len,
"k_len": k_len,
"block_size": block_size,
"reshaped_block_size": reshaped_block_size,
"gemm_time_ms": gemm_time,
"softmax_time_ms": softmax_time,
"total_time_ms": gemm_time + softmax_time,
}
# ============================================================
# Main Benchmark
# ============================================================
def main():
import argparse
parser = argparse.ArgumentParser(description="Benchmark block_size impact on estimate phase")
parser.add_argument("--gpu", type=int, default=0, help="GPU to use")
parser.add_argument("--ctx-len", type=int, default=None,
help="Single context length to test (default: test multiple)")
args = parser.parse_args()
# Set GPU
torch.cuda.set_device(args.gpu)
device_name = torch.cuda.get_device_name(args.gpu)
print(f"Using GPU {args.gpu}: {device_name}")
print()
# Determine context lengths to test
if args.ctx_len:
context_lengths = [args.ctx_len]
else:
context_lengths = CONTEXT_LENGTHS
print("=" * 80)
print("Benchmark: block_size impact on XAttention estimate phase")
print("=" * 80)
print(f"Configuration:")
print(f" NUM_HEADS: {NUM_HEADS}")
print(f" NUM_KV_HEADS: {NUM_KV_HEADS}")
print(f" HEAD_DIM: {HEAD_DIM}")
print(f" STRIDE: {STRIDE}")
print(f" BLOCK_SIZES: {BLOCK_SIZES}")
print(f" NUM_WARMUP: {NUM_WARMUP}")
print(f" NUM_RUNS: {NUM_RUNS}")
print()
all_results = []
for ctx_len in context_lengths:
print(f"\n{'='*80}")
print(f"Context Length: {ctx_len // 1024}K ({ctx_len} tokens)")
print(f"{'='*80}")
# Pad to alignment
alignment = STRIDE * 128 # Triton BLOCK_M requirement
padded_len = ((ctx_len + alignment - 1) // alignment) * alignment
print(f"Padded to: {padded_len} tokens (alignment={alignment})")
print()
results = []
for block_size in BLOCK_SIZES:
print(f"Testing block_size={block_size} (reshaped={block_size // STRIDE})...", end=" ")
try:
result = run_estimate_benchmark(padded_len, padded_len, block_size)
results.append(result)
print(f"GEMM={result['gemm_time_ms']:.2f}ms, "
f"Softmax={result['softmax_time_ms']:.2f}ms, "
f"Total={result['total_time_ms']:.2f}ms")
except Exception as e:
print(f"ERROR: {e}")
import traceback
traceback.print_exc()
if results:
all_results.extend(results)
# Print summary table for this context length
print(f"\n--- Summary for {ctx_len // 1024}K context ---")
print(f"{'block_size':>12} {'reshaped':>10} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12} {'Speedup':>10}")
print("-" * 74)
baseline_total = results[0]["total_time_ms"]
for r in results:
speedup = baseline_total / r["total_time_ms"]
print(f"{r['block_size']:>12} {r['reshaped_block_size']:>10} "
f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} "
f"{r['total_time_ms']:>12.2f} {speedup:>9.2f}x")
# Final summary across all context lengths
if len(context_lengths) > 1:
print(f"\n{'='*80}")
print("OVERALL SUMMARY")
print(f"{'='*80}")
print(f"{'ctx_len':>10} {'block_size':>12} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12}")
print("-" * 64)
for r in all_results:
print(f"{r['q_len']//1024:>9}K {r['block_size']:>12} "
f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} "
f"{r['total_time_ms']:>12.2f}")
# Find optimal block_size for softmax
print(f"\n{'='*80}")
print("ANALYSIS: Optimal block_size for softmax_fuse_block_sum")
print(f"{'='*80}")
for ctx_len in context_lengths:
ctx_results = [r for r in all_results if r["q_len"] == ((ctx_len + STRIDE * 128 - 1) // (STRIDE * 128)) * (STRIDE * 128)]
if ctx_results:
best = min(ctx_results, key=lambda x: x["softmax_time_ms"])
worst = max(ctx_results, key=lambda x: x["softmax_time_ms"])
improvement = worst["softmax_time_ms"] / best["softmax_time_ms"]
print(f"Context {ctx_len // 1024}K:")
print(f" Best: block_size={best['block_size']} ({best['softmax_time_ms']:.2f}ms)")
print(f" Worst: block_size={worst['block_size']} ({worst['softmax_time_ms']:.2f}ms)")
print(f" Potential improvement: {improvement:.2f}x")
print("\nbench_estimate_block_size: DONE")
if __name__ == "__main__":
main()

View File

@@ -1,757 +0,0 @@
"""
Custom Qwen3 implementation using only torch and transformers.
This file provides a clean reference implementation for understanding the model computation graph.
Computation Graph:
==================
Input: token_ids [batch, seq_len]
┌─────────────┐
│ Embedding │ embed_tokens: [vocab_size, hidden_size]
└─────────────┘
hidden_states [batch, seq_len, hidden_size]
┌─────────────────────────────────────────────────────────┐
│ Decoder Layer (x N) │
│ ┌───────────────────────────────────────────────────┐ │
│ │ Self Attention Block │ │
│ │ │ │
│ │ input_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3Attention │ │ │
│ │ │ Q = q_proj(x) → q_norm → reshape │ │ │
│ │ │ K = k_proj(x) → k_norm → reshape │ │ │
│ │ │ V = v_proj(x) → reshape │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ attn_output = attention(Q, K, V) │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ output = o_proj(attn_output) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + attn_output │ │
│ └───────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ MLP Block │ │
│ │ │ │
│ │ post_attention_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3MLP │ │ │
│ │ │ gate = gate_proj(x) │ │ │
│ │ │ up = up_proj(x) │ │ │
│ │ │ output = down_proj(silu(gate) * up) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + mlp_output │ │
│ └───────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
┌─────────────┐
│ norm │ final RMSNorm
└─────────────┘
┌─────────────┐
│ lm_head │ [hidden_size, vocab_size]
└─────────────┘
logits [batch, seq_len, vocab_size]
"""
import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
class Qwen3RMSNorm(nn.Module):
"""RMSNorm implementation."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
input_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
return self.weight * x.to(input_dtype)
class Qwen3RotaryEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE)."""
def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Compute inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor [batch, seq_len, num_heads, head_dim] or similar
position_ids: Position indices [batch, seq_len]
Returns:
cos, sin: [batch, seq_len, head_dim]
"""
# inv_freq: [dim/2]
# position_ids: [batch, seq_len]
inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len]
# freqs: [batch, dim/2, seq_len]
freqs = inv_freq_expanded @ position_ids_expanded
# freqs: [batch, seq_len, dim/2]
freqs = freqs.transpose(1, 2)
# Duplicate for full head_dim: [batch, seq_len, dim]
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(x.dtype)
sin = emb.sin().to(x.dtype)
return cos, sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary position embeddings to Q and K.
Args:
q: [batch, num_heads, seq_len, head_dim]
k: [batch, num_kv_heads, seq_len, head_dim]
cos: [batch, seq_len, head_dim]
sin: [batch, seq_len, head_dim]
Returns:
q_embed, k_embed with same shapes as inputs
"""
# Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Qwen3Attention(nn.Module):
"""
Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
└──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
apply_rotary_pos_emb(Q, K)
attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim]
reshape ──► o_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
attention_bias: bool = False,
rms_norm_eps: float = 1e-6,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.num_kv_heads = num_key_value_heads
self.head_dim = head_dim
self.num_kv_groups = num_attention_heads // num_key_value_heads
self.layer_idx = layer_idx
# Scaling factor
self.scaling = head_dim ** -0.5
# QKV projections
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
# QK normalization (Qwen3 specific)
self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
# Rotary embeddings
self.rotary_emb = Qwen3RotaryEmbedding(
head_dim,
max_position_embeddings=max_position_embeddings,
base=rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask)
past_key_value: (k_cache, v_cache) from previous steps
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V tensors for debugging
Returns:
output: [batch, seq_len, hidden_size]
past_key_value: Updated cache (if use_cache=True)
qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
"""
batch_size, seq_len, _ = hidden_states.shape
# === QKV Projections ===
q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim]
k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
# Reshape to [batch, seq_len, num_heads, head_dim]
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# === QK Normalization (Qwen3 specific) ===
q = self.q_norm(q)
k = self.k_norm(k)
# Transpose to [batch, num_heads, seq_len, head_dim]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# === Rotary Position Embeddings ===
cos, sin = self.rotary_emb(v, position_ids)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
# === KV Cache Update ===
if past_key_value is not None:
k_cache, v_cache = past_key_value
k = torch.cat([k_cache, k], dim=2)
v = torch.cat([v_cache, v], dim=2)
new_past_key_value = (k, v) if use_cache else None
# === Grouped Query Attention (expand KV heads if needed) ===
if self.num_kv_groups > 1:
# Repeat KV for each query group
k = k.repeat_interleave(self.num_kv_groups, dim=1)
v = v.repeat_interleave(self.num_kv_groups, dim=1)
# === Attention Computation (using SDPA for memory efficiency) ===
# Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
# is_causal only works when q_len == kv_len (prefill), not during decode
q_len, kv_len = q.shape[2], k.shape[2]
is_causal = (q_len == kv_len) and (q_len > 1)
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=0.0,
is_causal=is_causal,
scale=self.scaling,
) # [batch, num_heads, seq_len, head_dim]
# === Output Projection ===
# Transpose back and reshape
attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim]
attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, hidden_size]
output = self.o_proj(attn_output)
# Optional QKV output for debugging
qkv_dict = None
if output_qkv:
qkv_dict = {
"q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE)
"k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
"v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded)
}
return output, new_past_key_value, qkv_dict
class Qwen3MLP(nn.Module):
"""
Qwen3 MLP with SwiGLU activation.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► gate_proj ──► gate [batch, seq_len, intermediate_size]
└──► up_proj ──► up [batch, seq_len, intermediate_size]
silu(gate) * up
down_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.gate_proj(x)
up = self.up_proj(x)
return self.down_proj(F.silu(gate) * up)
class Qwen3DecoderLayer(nn.Module):
"""Single Qwen3 Decoder Layer."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
layer_idx: int = 0,
):
super().__init__()
self.layer_idx = layer_idx
# Pre-attention LayerNorm
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# Self-attention
self.self_attn = Qwen3Attention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
attention_bias=attention_bias,
rms_norm_eps=rms_norm_eps,
layer_idx=layer_idx,
)
# Post-attention LayerNorm
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# MLP
self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: Causal attention mask
past_key_value: KV cache for this layer
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V for debugging
Returns:
hidden_states: [batch, seq_len, hidden_size]
past_key_value: Updated cache
qkv_dict: QKV tensors (if output_qkv=True)
"""
# === Self Attention Block ===
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_output, new_past_key_value, qkv_dict = self.self_attn(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_qkv=output_qkv,
)
hidden_states = residual + attn_output
# === MLP Block ===
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, new_past_key_value, qkv_dict
class Qwen3Model(nn.Module):
"""Qwen3 Transformer Model (without LM head)."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
):
super().__init__()
self.vocab_size = vocab_size
self.num_hidden_layers = num_hidden_layers
# Token embeddings
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
# Decoder layers
self.layers = nn.ModuleList([
Qwen3DecoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
layer_idx=i,
)
for i in range(num_hidden_layers)
])
# Final LayerNorm
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
position_ids: [batch, seq_len]
attention_mask: [batch, seq_len] or pre-computed 4D mask
past_key_values: List of (k, v) tuples for each layer
use_cache: Whether to return new cache
output_qkv_layers: List of layer indices to output QKV for
Returns:
hidden_states: [batch, seq_len, hidden_size]
new_past_key_values: Updated cache
qkv_outputs: {layer_idx: qkv_dict}
"""
batch_size, seq_len = input_ids.shape
# Embedding
hidden_states = self.embed_tokens(input_ids)
# Position IDs
if position_ids is None:
past_len = past_key_values[0][0].shape[2] if past_key_values else 0
position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Attention mask (create causal mask if not provided)
if attention_mask is None or attention_mask.dim() == 2:
kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
causal_mask = torch.triu(
torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
diagonal=kv_seq_len - seq_len + 1,
)
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len]
# Initialize cache list
new_past_key_values = [] if use_cache else None
qkv_outputs = {} if output_qkv_layers else None
# Decoder layers
for i, layer in enumerate(self.layers):
past_kv = past_key_values[i] if past_key_values else None
output_qkv = output_qkv_layers is not None and i in output_qkv_layers
hidden_states, new_kv, qkv_dict = layer(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_kv,
use_cache=use_cache,
output_qkv=output_qkv,
)
if use_cache:
new_past_key_values.append(new_kv)
if qkv_dict is not None:
qkv_outputs[i] = qkv_dict
# Final norm
hidden_states = self.norm(hidden_states)
return hidden_states, new_past_key_values, qkv_outputs
class Qwen3ForCausalLM(nn.Module):
"""Qwen3 Model with Language Modeling head."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
tie_word_embeddings: bool = True,
):
super().__init__()
self.vocab_size = vocab_size
self.tie_word_embeddings = tie_word_embeddings
# Transformer model
self.model = Qwen3Model(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
)
# LM head
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
... (same as Qwen3Model)
Returns:
logits: [batch, seq_len, vocab_size]
past_key_values: Updated KV cache
qkv_outputs: QKV tensors for specified layers
"""
hidden_states, new_past_key_values, qkv_outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_qkv_layers=output_qkv_layers,
)
logits = self.lm_head(hidden_states)
return logits, new_past_key_values, qkv_outputs
@classmethod
def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
"""
Load weights from a pretrained Qwen3 model.
Args:
model_path: Path to model directory containing config.json and model weights
dtype: Data type for model weights
Returns:
Initialized Qwen3ForCausalLM model
"""
import json
import os
from safetensors.torch import load_file
# Load config
config_path = os.path.join(model_path, "config.json")
with open(config_path) as f:
config = json.load(f)
# Create model
model = cls(
vocab_size=config["vocab_size"],
hidden_size=config["hidden_size"],
intermediate_size=config["intermediate_size"],
num_hidden_layers=config["num_hidden_layers"],
num_attention_heads=config["num_attention_heads"],
num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
max_position_embeddings=config.get("max_position_embeddings", 32768),
rope_theta=config.get("rope_theta", 10000.0),
rms_norm_eps=config.get("rms_norm_eps", 1e-6),
attention_bias=config.get("attention_bias", False),
mlp_bias=config.get("mlp_bias", False),
tie_word_embeddings=config.get("tie_word_embeddings", True),
)
# Load weights
weight_files = sorted([
f for f in os.listdir(model_path)
if f.endswith(".safetensors")
])
state_dict = {}
for wf in weight_files:
state_dict.update(load_file(os.path.join(model_path, wf)))
# Load into model
model.load_state_dict(state_dict, strict=False)
# Tie lm_head weights to embed_tokens if configured
if model.tie_word_embeddings:
model.lm_head.weight = model.model.embed_tokens.weight
model = model.to(dtype)
return model
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 32,
temperature: float = 1.0,
do_sample: bool = True,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> torch.Tensor:
"""Simple autoregressive generation."""
device = input_ids.device
batch_size, seq_len = input_ids.shape
past_key_values = None
generated = input_ids.clone()
for _ in range(max_new_tokens):
if past_key_values is None:
current_input = generated
else:
current_input = generated[:, -1:]
logits, past_key_values, _ = self(
input_ids=current_input,
past_key_values=past_key_values,
use_cache=True,
)
next_token_logits = logits[:, -1, :]
if temperature > 0 and do_sample:
next_token_logits = next_token_logits / temperature
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
if eos_token_id is not None and (next_token == eos_token_id).all():
break
return generated
def print_computation_graph():
"""Print the computation graph for reference."""
print(__doc__)
if __name__ == "__main__":
print_computation_graph()

View File

@@ -1,151 +0,0 @@
#!/usr/bin/env python3
"""
Test: Pre-allocated chunk pair graphs for block sparse attention.
Each (Q_chunk, K_chunk) pair has its own captured CUDA graph.
Zero copy_() during replay - all data pre-filled.
Usage:
CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph.py
"""
from dataclasses import dataclass
from typing import List, Optional
import torch
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@dataclass
class ChunkAttentionGraph:
"""Container for a captured chunk attention graph."""
graph: torch.cuda.CUDAGraph
static_q: torch.Tensor
static_k: torch.Tensor
static_v: torch.Tensor
static_output: torch.Tensor
static_lse: torch.Tensor
causal: bool
def capture_chunk_attention_graph(
chunk_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
scale: float,
device: torch.device,
dtype: torch.dtype,
causal: bool = False,
) -> ChunkAttentionGraph:
"""Capture a CUDA graph for single chunk attention."""
static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_q.normal_()
static_k.normal_()
static_v.normal_()
# Warmup
with torch.inference_mode():
for _ in range(3):
_ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
# Capture
graph = torch.cuda.CUDAGraph()
with torch.inference_mode():
with torch.cuda.graph(graph):
static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
return ChunkAttentionGraph(
graph=graph,
static_q=static_q,
static_k=static_k,
static_v=static_v,
static_output=static_output,
static_lse=static_lse,
causal=causal,
)
def main():
device = torch.device("cuda")
dtype = torch.bfloat16
chunk_size = 64
num_chunks = 4
num_heads = 8
num_kv_heads = 8
head_dim = 64
scale = 1.0 / (head_dim ** 0.5)
seq_len = chunk_size * num_chunks
print(f"Device: {torch.cuda.get_device_name()}")
print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}")
print(f"Total graphs: {num_chunks * (num_chunks + 1) // 2}")
# Test data
full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
# Reference
with torch.inference_mode():
full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)
# Capture all graphs
graphs: List[List[Optional[ChunkAttentionGraph]]] = [[None] * num_chunks for _ in range(num_chunks)]
for q_idx in range(num_chunks):
for k_idx in range(q_idx + 1):
graphs[q_idx][k_idx] = capture_chunk_attention_graph(
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype,
causal=(k_idx == q_idx)
)
print("All graphs captured")
# Pre-fill static tensors
for q_idx in range(num_chunks):
for k_idx in range(q_idx + 1):
g = graphs[q_idx][k_idx]
g.static_q.copy_(full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size])
g.static_k.copy_(full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size])
g.static_v.copy_(full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size])
print("Static tensors pre-filled")
# Replay and merge
chunked_output = torch.zeros_like(full_output)
for q_idx in range(num_chunks):
acc_out, acc_lse = None, None
for k_idx in range(q_idx + 1):
g = graphs[q_idx][k_idx]
g.graph.replay()
out, lse = g.static_output.clone(), g.static_lse.clone()
if acc_out is None:
acc_out, acc_lse = out, lse
else:
with torch.inference_mode():
acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out
torch.cuda.synchronize()
# Compare
all_pass = True
for q_idx in range(num_chunks):
s, e = q_idx * chunk_size, (q_idx + 1) * chunk_size
diff = (full_output[:, s:e] - chunked_output[:, s:e]).abs().max().item()
status = "" if diff < 1e-2 else ""
print(f"Q[{q_idx}]: max_diff={diff:.2e} {status}")
if diff >= 1e-2:
all_pass = False
print("✅ PASSED" if all_pass else "❌ FAILED")
if __name__ == "__main__":
main()

View File

@@ -1,156 +0,0 @@
#!/usr/bin/env python3
"""
Test: Reuse a single CUDA Graph across all layers and all chunk pairs.
Key insight: LLM layers have identical computation structure.
We only need 2 graphs (causal + non-causal), reused for all (layer, Q_i, K_j) combinations.
Usage:
CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph_reuse.py
"""
from dataclasses import dataclass
import torch
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@dataclass
class ReusableChunkGraph:
"""A single graph that can be reused with copy_() updates."""
graph: torch.cuda.CUDAGraph
static_q: torch.Tensor
static_k: torch.Tensor
static_v: torch.Tensor
static_output: torch.Tensor
static_lse: torch.Tensor
def capture_reusable_graph(
chunk_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
scale: float,
device: torch.device,
dtype: torch.dtype,
causal: bool,
) -> ReusableChunkGraph:
"""Capture ONE graph to be reused for all chunk pairs."""
static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_q.normal_()
static_k.normal_()
static_v.normal_()
# Warmup
with torch.inference_mode():
for _ in range(3):
_ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
# Capture
graph = torch.cuda.CUDAGraph()
with torch.inference_mode():
with torch.cuda.graph(graph):
static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
return ReusableChunkGraph(
graph=graph,
static_q=static_q,
static_k=static_k,
static_v=static_v,
static_output=static_output,
static_lse=static_lse,
)
def replay_with_copy(graph: ReusableChunkGraph, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Replay graph after updating static tensors with copy_()."""
graph.static_q.copy_(q)
graph.static_k.copy_(k)
graph.static_v.copy_(v)
graph.graph.replay()
return graph.static_output.clone(), graph.static_lse.clone()
def main():
device = torch.device("cuda")
dtype = torch.bfloat16
chunk_size = 64
num_chunks = 4
num_layers = 3 # Simulate multiple layers
num_heads = 8
num_kv_heads = 8
head_dim = 64
scale = 1.0 / (head_dim ** 0.5)
seq_len = chunk_size * num_chunks
print(f"Device: {torch.cuda.get_device_name()}")
print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}, Num layers: {num_layers}")
print(f"Only 2 graphs (causal + non-causal) for ALL layer × chunk combinations")
# Capture only 2 graphs
graph_causal = capture_reusable_graph(
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=True
)
graph_non_causal = capture_reusable_graph(
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=False
)
print("2 graphs captured (causal + non-causal)")
all_pass = True
for layer_id in range(num_layers):
# Different Q/K/V for each layer (simulating different layer outputs)
full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
# Reference: full causal attention
with torch.inference_mode():
full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)
# Chunked with graph reuse
chunked_output = torch.zeros_like(full_output)
for q_idx in range(num_chunks):
q_chunk = full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size]
acc_out, acc_lse = None, None
for k_idx in range(q_idx + 1):
k_chunk = full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
v_chunk = full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
# Reuse graph with copy_()
graph = graph_causal if k_idx == q_idx else graph_non_causal
out, lse = replay_with_copy(graph, q_chunk, k_chunk, v_chunk)
if acc_out is None:
acc_out, acc_lse = out, lse
else:
with torch.inference_mode():
acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out
torch.cuda.synchronize()
# Compare
max_diff = (full_output - chunked_output).abs().max().item()
status = "" if max_diff < 1e-2 else ""
print(f"Layer {layer_id}: max_diff={max_diff:.2e} {status}")
if max_diff >= 1e-2:
all_pass = False
print("✅ PASSED - Single graph reuse across layers works!" if all_pass else "❌ FAILED")
if __name__ == "__main__":
main()

View File

@@ -1,357 +0,0 @@
#!/usr/bin/env python3
"""
CUDA Graph Memory Analysis Test
This script analyzes the memory overhead of CUDA Graph at each stage:
1. Model loading
2. StaticCache allocation
3. Warmup runs
4. Graph capture
5. Graph replay
Usage:
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --model ~/models/Qwen3-0.6B
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --max-cache-len 2048
"""
import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache
def get_memory_mb():
"""Get current allocated memory in MB."""
return torch.cuda.memory_allocated() / 1024**2
def get_memory_gb():
"""Get current allocated memory in GB."""
return torch.cuda.memory_allocated() / 1024**3
def get_peak_memory_gb():
"""Get peak allocated memory in GB."""
return torch.cuda.max_memory_allocated() / 1024**3
def print_separator(title=None):
"""Print a separator line."""
if title:
print(f"\n{'=' * 70}")
print(f" {title}")
print(f"{'=' * 70}")
else:
print("-" * 70)
def test_memory_stages(model_path: str, max_cache_len: int, batch_size: int = 1):
"""
Test memory usage at each stage of CUDA Graph setup.
Args:
model_path: Path to the model
max_cache_len: Maximum cache length for StaticCache
batch_size: Batch size for inference
"""
print_separator("CUDA Graph Memory Analysis")
print(f"Model: {model_path}")
print(f"Max cache length: {max_cache_len}")
print(f"Batch size: {batch_size}")
results = {}
# Stage 0: Initial
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
results["initial"] = get_memory_mb()
# Stage 1: Load model
print_separator("Stage 1: Model Loading")
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="cuda",
trust_remote_code=True,
)
model.eval()
results["after_model"] = get_memory_mb()
model_size = results["after_model"] - results["initial"]
print(f" Memory: {results['after_model']:.0f} MB")
print(f" Model size: {model_size:.0f} MB ({model_size/1024:.2f} GB)")
config = model.config
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
# Stage 2: Allocate StaticCache
print_separator("Stage 2: StaticCache Allocation")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
static_cache = StaticCache(
config=config,
max_batch_size=batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=dtype,
)
results["after_cache"] = get_memory_mb()
cache_size = results["after_cache"] - before
print(f" Memory: {results['after_cache']:.0f} MB")
print(f" StaticCache size: {cache_size:.0f} MB")
# Calculate theoretical cache size
num_layers = config.num_hidden_layers
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
head_dim = config.hidden_size // config.num_attention_heads
dtype_size = 2 # bfloat16
theoretical_cache = (
num_layers * 2 * batch_size * num_kv_heads * max_cache_len * head_dim * dtype_size
) / (1024**2)
print(f" Theoretical: {theoretical_cache:.0f} MB")
print(f" Overhead: {cache_size - theoretical_cache:.0f} MB ({(cache_size/theoretical_cache - 1)*100:.1f}%)")
# Stage 3: Prepare static tensors
print_separator("Stage 3: Static Tensor Allocation")
before = get_memory_mb()
static_input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
static_position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
static_cache_position = torch.tensor([0], dtype=torch.long, device=device)
results["after_tensors"] = get_memory_mb()
tensor_size = results["after_tensors"] - before
print(f" Memory: {results['after_tensors']:.0f} MB")
print(f" Static tensors: {tensor_size:.2f} MB (negligible)")
# Stage 4: Warmup runs
print_separator("Stage 4: Warmup Runs (3 iterations)")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
with torch.inference_mode():
for i in range(3):
_ = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
torch.cuda.synchronize()
results["after_warmup"] = get_memory_mb()
results["warmup_peak"] = get_peak_memory_gb() * 1024
warmup_size = results["after_warmup"] - before
print(f" Memory: {results['after_warmup']:.0f} MB")
print(f" Peak: {results['warmup_peak']:.0f} MB")
print(f" Warmup overhead: {warmup_size:.0f} MB")
# Stage 5: CUDA Graph capture
print_separator("Stage 5: CUDA Graph Capture")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
graph = torch.cuda.CUDAGraph()
with torch.inference_mode():
with torch.cuda.graph(graph):
outputs = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
static_logits = outputs.logits
torch.cuda.synchronize()
results["after_capture"] = get_memory_mb()
results["capture_peak"] = get_peak_memory_gb() * 1024
capture_size = results["after_capture"] - before
print(f" Memory: {results['after_capture']:.0f} MB")
print(f" Peak: {results['capture_peak']:.0f} MB")
print(f" Graph capture overhead: {capture_size:.0f} MB")
# Stage 6: Graph replay
print_separator("Stage 6: Graph Replay (10 iterations)")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
with torch.inference_mode():
for _ in range(10):
static_input_ids.fill_(1)
static_cache_position.fill_(0)
graph.replay()
torch.cuda.synchronize()
results["after_replay"] = get_memory_mb()
results["replay_peak"] = get_peak_memory_gb() * 1024
replay_change = results["after_replay"] - before
print(f" Memory: {results['after_replay']:.0f} MB")
print(f" Peak: {results['replay_peak']:.0f} MB")
print(f" Replay memory change: {replay_change:.0f} MB (should be ~0)")
# Summary
print_separator("SUMMARY")
total_overhead = results["after_capture"] - results["after_model"]
print(f"{'Stage':<25} {'Memory (MB)':>12} {'Delta (MB)':>12}")
print("-" * 50)
print(f"{'Model loaded':<25} {results['after_model']:>12.0f} {model_size:>+12.0f}")
print(f"{'StaticCache allocated':<25} {results['after_cache']:>12.0f} {cache_size:>+12.0f}")
print(f"{'After warmup':<25} {results['after_warmup']:>12.0f} {warmup_size:>+12.0f}")
print(f"{'After graph capture':<25} {results['after_capture']:>12.0f} {capture_size:>+12.0f}")
print(f"{'After graph replay':<25} {results['after_replay']:>12.0f} {replay_change:>+12.0f}")
print("-" * 50)
print(f"{'Total (excl. model)':<25} {'':<12} {total_overhead:>+12.0f}")
print_separator("KEY FINDINGS")
print(f" 1. Model size: {model_size/1024:.2f} GB")
print(f" 2. StaticCache: {cache_size:.0f} MB (main overhead, scales with cache_len)")
print(f" 3. Graph capture: {capture_size:.0f} MB (small, stores kernel sequence)")
print(f" 4. Graph replay: {replay_change:.0f} MB (zero allocation, reuses memory)")
print(f" 5. Total CUDA Graph overhead: {total_overhead:.0f} MB")
return results
def test_cache_length_scaling(model_path: str, cache_lengths: list):
"""
Test how memory scales with different cache lengths.
Args:
model_path: Path to the model
cache_lengths: List of cache lengths to test
"""
print_separator("Cache Length Scaling Test")
print(f"Model: {model_path}")
print(f"Cache lengths: {cache_lengths}")
# Load model once
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="cuda",
trust_remote_code=True,
)
model.eval()
config = model.config
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
model_mem = get_memory_mb()
results = []
for cache_len in cache_lengths:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Create cache and capture graph
static_cache = StaticCache(
config=config,
max_batch_size=1,
max_cache_len=cache_len,
device=device,
dtype=dtype,
)
static_input_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
static_position_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
static_cache_position = torch.tensor([0], dtype=torch.long, device=device)
with torch.inference_mode():
# Warmup
for _ in range(3):
_ = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
torch.cuda.synchronize()
# Capture
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
outputs = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
torch.cuda.synchronize()
total_mem = get_memory_mb()
overhead = total_mem - model_mem
results.append((cache_len, total_mem, overhead))
del static_cache, graph
torch.cuda.empty_cache()
# Print results
print()
print(f"{'Cache Length':>12} | {'Total (MB)':>12} | {'Overhead (MB)':>14} | {'Per 1K tokens':>14}")
print("-" * 60)
for cache_len, total, overhead in results:
per_1k = overhead / (cache_len / 1000)
print(f"{cache_len:>12} | {total:>12.0f} | {overhead:>14.0f} | {per_1k:>14.1f}")
return results
def main():
parser = argparse.ArgumentParser(description="CUDA Graph Memory Analysis")
parser.add_argument(
"--model",
type=str,
default="~/models/Qwen3-4B-Instruct-2507",
help="Model path",
)
parser.add_argument(
"--max-cache-len",
type=int,
default=1024,
help="Maximum cache length",
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Batch size",
)
parser.add_argument(
"--test-scaling",
action="store_true",
help="Test cache length scaling",
)
args = parser.parse_args()
model_path = os.path.expanduser(args.model)
if not torch.cuda.is_available():
print("CUDA is not available!")
return
print(f"Device: cuda:{torch.cuda.current_device()}")
print(f"GPU: {torch.cuda.get_device_name()}")
if args.test_scaling:
cache_lengths = [256, 512, 1024, 2048, 4096]
test_cache_length_scaling(model_path, cache_lengths)
else:
test_memory_stages(model_path, args.max_cache_len, args.batch_size)
print("\ntest_cudagraph_memory: PASSED")
if __name__ == "__main__":
main()

View File

@@ -1,149 +0,0 @@
"""
Test: GPU-only density alignment verification
验证 xattn_bsa.py 中 GPU-only 路径的 density 计算是否与独立调用 xattn_estimate 一致。
流程:
1. 运行 GPU-only 推理,保存 Q, K, mask, attn_sums
2. 加载保存的数据,独立调用 xattn_estimate
3. 比较两者的 mask 和 density
"""
import torch
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
from nanovllm.ops.xattn import xattn_estimate
# ============================================================
# 参数配置
# ============================================================
DATA_PATH = "/home/zijie/Code/nano-vllm/results/mask_alignment/gpuonly_layer0.pt"
# ============================================================
# 加载保存的数据
# ============================================================
print(f"Loading data from {DATA_PATH}")
data = torch.load(DATA_PATH, weights_only=False)
Q = data["Q"].cuda() # [1, num_heads, q_len, head_dim]
K = data["K"].cuda() # [1, num_heads, k_len, head_dim]
chunk_size = data["chunk_size"]
block_size = data["block_size"]
stride = data["stride"]
threshold = data["threshold"]
mask_saved = data["mask"] # [1, num_heads, valid_q_blocks, valid_k_blocks]
attn_sums_saved = data["attn_sums"]
q_len = data["q_len"]
k_len = data["k_len"]
valid_q_blocks = data["valid_q_blocks"]
valid_k_blocks = data["valid_k_blocks"]
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"q_len: {q_len}, k_len: {k_len}")
print(f"chunk_size: {chunk_size}, block_size: {block_size}, stride: {stride}, threshold: {threshold}")
print(f"valid_q_blocks: {valid_q_blocks}, valid_k_blocks: {valid_k_blocks}")
print(f"mask_saved shape: {mask_saved.shape}")
# ============================================================
# 独立调用 xattn_estimate
# ============================================================
print("\nCalling xattn_estimate independently...")
attn_sums_ext, mask_ext = xattn_estimate(
Q, K,
chunk_size=chunk_size,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
causal=True,
)
# Trim to valid blocks
mask_ext_valid = mask_ext[:, :, :valid_q_blocks, :valid_k_blocks]
attn_sums_ext_valid = attn_sums_ext[:, :, :valid_q_blocks, :valid_k_blocks]
print(f"mask_ext shape: {mask_ext.shape}")
print(f"mask_ext_valid shape: {mask_ext_valid.shape}")
# ============================================================
# 比较 attn_sums
# ============================================================
print("\n" + "=" * 60)
print("Comparing attn_sums")
print("=" * 60)
attn_sums_saved_gpu = attn_sums_saved.cuda()
attn_diff = (attn_sums_ext_valid - attn_sums_saved_gpu).abs()
print(f"attn_sums max diff: {attn_diff.max().item():.6e}")
print(f"attn_sums mean diff: {attn_diff.mean().item():.6e}")
# Check if attn_sums match
attn_match = attn_diff.max().item() < 1e-4
print(f"attn_sums match: {attn_match}")
# ============================================================
# 比较 mask
# ============================================================
print("\n" + "=" * 60)
print("Comparing mask")
print("=" * 60)
mask_saved_gpu = mask_saved.cuda()
mask_match = (mask_ext_valid == mask_saved_gpu).all().item()
print(f"mask exact match: {mask_match}")
if not mask_match:
diff_count = (mask_ext_valid != mask_saved_gpu).sum().item()
total_count = mask_ext_valid.numel()
print(f"mask diff count: {diff_count} / {total_count} ({diff_count/total_count*100:.2f}%)")
# ============================================================
# 计算 density
# ============================================================
print("\n" + "=" * 60)
print("Comparing density")
print("=" * 60)
# 计算 causal mask
q_offset_blocks = valid_k_blocks - valid_q_blocks
indices = torch.arange(valid_k_blocks, device=mask_ext_valid.device).unsqueeze(0)
q_indices = torch.arange(valid_q_blocks, device=mask_ext_valid.device).unsqueeze(1)
causal_mask = indices <= (q_indices + q_offset_blocks)
# Density from saved mask
total_saved = causal_mask.sum().item() * mask_saved_gpu.shape[0] * mask_saved_gpu.shape[1]
selected_saved = (mask_saved_gpu & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_saved = selected_saved / total_saved
# Density from external xattn_estimate
total_ext = causal_mask.sum().item() * mask_ext_valid.shape[0] * mask_ext_valid.shape[1]
selected_ext = (mask_ext_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_ext = selected_ext / total_ext
print(f"Saved density: {density_saved:.6f} (selected={selected_saved}, total={total_saved})")
print(f"External density: {density_ext:.6f} (selected={selected_ext}, total={total_ext})")
print(f"Density diff: {abs(density_saved - density_ext):.6f}")
# ============================================================
# 结论
# ============================================================
print("\n" + "=" * 60)
print("RESULT")
print("=" * 60)
if attn_match and mask_match:
print("✅ PASSED: GPU-only density matches external xattn_estimate")
else:
print("❌ FAILED: Mismatch detected")
if not attn_match:
print(" - attn_sums mismatch")
if not mask_match:
print(" - mask mismatch")

View File

@@ -1,442 +0,0 @@
"""
Test: Hierarchical Block Sum Estimation for XAttention
Verify that hierarchical estimation (small estimate_block_size + aggregation)
produces equivalent results to direct estimation (large block_size), while
being significantly faster.
Key changes validated:
1. Hierarchical block sum: estimate_block_size=1024 → aggregate to cpu_block_size=4096
2. Selection strategy: score + threshold (NOT mask + majority voting)
This test uses pure torch + xattn kernels, independent of nanovllm framework.
"""
import sys
import os
import torch
import math
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
# ============================================================
# Configuration
# ============================================================
# Model dimensions (Llama-3.1-8B-Instruct style)
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
STRIDE = 8
# Block sizes
CPU_BLOCK_SIZE = 4096 # External CPU block size (fixed, for overlap)
ESTIMATE_BLOCK_SIZE = 1024 # Internal estimate block size (optimized)
# Selection parameters
THRESHOLD = 0.95 # Cumulative attention threshold
# ============================================================
# Hierarchical Estimation Implementation
# ============================================================
def compute_attention_scores(Q, K_blocks, stride):
"""
Compute attention scores for Q against multiple K blocks.
Args:
Q: [1, num_heads, q_len, head_dim]
K_blocks: List of K tensors, each [1, num_heads, block_size, head_dim]
stride: Stride for reshape
Returns:
attn_scores: [1, num_heads, q_reshaped, total_k_reshaped]
"""
q_len = Q.shape[2]
q_reshaped = q_len // stride
attn_chunks = []
for K_block in K_blocks:
# flat_group_gemm_fuse_reshape
attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_block, stride,
chunk_start=0,
chunk_end=q_reshaped,
is_causal=False,
)
attn_chunks.append(attn_chunk)
# Concatenate along K dimension
attn_scores = torch.cat(attn_chunks, dim=-1)
return attn_scores
def hierarchical_block_sum(
attn_scores,
estimate_block_size,
cpu_block_size,
stride,
head_dim,
):
"""
Compute hierarchical block sums: fine-grained → aggregated to CPU block level.
Args:
attn_scores: [batch, heads, q_reshaped, k_reshaped]
estimate_block_size: Small block size for efficient softmax (e.g., 1024)
cpu_block_size: External CPU block size (e.g., 4096)
stride: Stride used in reshape
head_dim: Head dimension for scale computation
Returns:
cpu_block_scores: [batch, heads, num_cpu_blocks] - attention score per CPU block
"""
batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape
# Compute reshaped block sizes
reshaped_est_bs = estimate_block_size // stride # 1024/8 = 128
reshaped_cpu_bs = cpu_block_size // stride # 4096/8 = 512
# Scale factor
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
# Segment size for softmax kernel
segment_size = min(4096, reshaped_est_bs)
# Step 1: Fine-grained softmax + block sum
block_sums_fine = softmax_fuse_block_sum(
attn_scores,
reshaped_est_bs,
segment_size,
chunk_start=0,
chunk_end=q_reshaped,
real_q_len=q_reshaped,
scale=scale,
is_causal=False,
)
# block_sums_fine: [batch, heads, q_est_blocks, k_est_blocks]
q_est_blocks = block_sums_fine.shape[2]
k_est_blocks = block_sums_fine.shape[3]
# Step 2: Aggregate to CPU block level
# ratio = cpu_block_size / estimate_block_size = 4
ratio = cpu_block_size // estimate_block_size
num_cpu_blocks = k_est_blocks // ratio
# Reshape and sum along K dimension
# [batch, heads, q_est, k_est] → [batch, heads, q_est, num_cpu, ratio]
block_sums_coarse = block_sums_fine.view(
batch_size, num_heads, q_est_blocks, num_cpu_blocks, ratio
).sum(dim=-1) # [batch, heads, q_est_blocks, num_cpu_blocks]
# Step 3: Sum over Q dimension (total attention from Q chunk to each K block)
cpu_block_scores = block_sums_coarse.sum(dim=2) # [batch, heads, num_cpu_blocks]
return cpu_block_scores, block_sums_fine
def direct_block_sum(
attn_scores,
cpu_block_size,
stride,
head_dim,
):
"""
Compute block sums directly with CPU block size (baseline for comparison).
Args:
attn_scores: [batch, heads, q_reshaped, k_reshaped]
cpu_block_size: Block size (e.g., 4096)
stride: Stride used in reshape
head_dim: Head dimension for scale computation
Returns:
cpu_block_scores: [batch, heads, num_cpu_blocks]
"""
batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape
reshaped_cpu_bs = cpu_block_size // stride # 4096/8 = 512
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
segment_size = min(4096, reshaped_cpu_bs)
block_sums = softmax_fuse_block_sum(
attn_scores,
reshaped_cpu_bs,
segment_size,
chunk_start=0,
chunk_end=q_reshaped,
real_q_len=q_reshaped,
scale=scale,
is_causal=False,
)
# block_sums: [batch, heads, q_cpu_blocks, k_cpu_blocks]
# Sum over Q dimension
cpu_block_scores = block_sums.sum(dim=2) # [batch, heads, num_cpu_blocks]
return cpu_block_scores
def select_blocks_by_score(
cpu_block_scores,
threshold=0.95,
always_include_first=True,
always_include_last=True,
):
"""
Select CPU blocks based on score + threshold.
⚠️ IMPORTANT: This replaces the original mask + majority voting strategy.
This change should be documented in the final implementation.
Args:
cpu_block_scores: [batch, heads, num_cpu_blocks]
threshold: Cumulative attention threshold (e.g., 0.95)
always_include_first: Always include first block (sink)
always_include_last: Always include last block (safety)
Returns:
selected_block_ids: List of selected block indices
density: Fraction of blocks selected
"""
# Average scores across heads
scores_per_block = cpu_block_scores.mean(dim=(0, 1)) # [num_cpu_blocks]
num_blocks = scores_per_block.shape[0]
# Normalize to get attention distribution
total_score = scores_per_block.sum()
score_ratio = scores_per_block / total_score
# Sort by score (descending)
sorted_indices = torch.argsort(score_ratio, descending=True)
# Select blocks until cumulative threshold is reached
cumsum = 0.0
selected = set()
for idx in sorted_indices.tolist():
selected.add(idx)
cumsum += score_ratio[idx].item()
if cumsum >= threshold:
break
# Always include first and last blocks
if always_include_first:
selected.add(0)
if always_include_last:
selected.add(num_blocks - 1)
selected_block_ids = sorted(list(selected))
density = len(selected_block_ids) / num_blocks
return selected_block_ids, density
# ============================================================
# Test Cases
# ============================================================
def test_equivalence():
"""
Test that hierarchical estimation produces equivalent scores to direct estimation.
"""
print("=" * 60)
print("Test 1: Hierarchical vs Direct - Equivalence")
print("=" * 60)
# Create random Q and multiple K blocks
q_len = CPU_BLOCK_SIZE # 4096
num_k_blocks = 4
# Q: [1, num_heads, q_len, head_dim]
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
# K blocks: each [1, num_heads, cpu_block_size, head_dim]
K_blocks = [
torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
for _ in range(num_k_blocks)
]
# Compute attention scores
attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
print(f"attn_scores shape: {attn_scores.shape}")
# Method 1: Hierarchical (fast)
scores_hier, _ = hierarchical_block_sum(
attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
)
print(f"scores_hier shape: {scores_hier.shape}")
# Method 2: Direct (slow)
scores_direct = direct_block_sum(
attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
)
print(f"scores_direct shape: {scores_direct.shape}")
# Compare
diff = (scores_hier - scores_direct).abs().max().item()
print(f"\nMax difference: {diff:.6f}")
# Per-block comparison
print("\nPer-block scores comparison:")
for i in range(num_k_blocks):
h_val = scores_hier[0, 0, i].item()
d_val = scores_direct[0, 0, i].item()
print(f" Block {i}: hierarchical={h_val:.4f}, direct={d_val:.4f}, diff={abs(h_val-d_val):.6f}")
passed = diff < 0.01
print(f"\nTest 1: {'PASSED' if passed else 'FAILED'}")
return passed
def test_selection():
"""
Test the score + threshold selection strategy.
"""
print("\n" + "=" * 60)
print("Test 2: Score + Threshold Selection")
print("=" * 60)
# Create Q and K blocks with varying importance
q_len = CPU_BLOCK_SIZE
num_k_blocks = 8
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
# Create K blocks - make some more important than others
K_blocks = []
for i in range(num_k_blocks):
# First and middle blocks are more important (higher values)
importance = 2.0 if i in [0, 3, 4] else 1.0
K = torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
K = K * importance
K_blocks.append(K)
# Compute scores
attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
scores, _ = hierarchical_block_sum(
attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
)
# Print scores per block
print("\nCPU block scores (head 0):")
for i in range(num_k_blocks):
print(f" Block {i}: {scores[0, 0, i].item():.4f}")
# Select blocks with different thresholds
for thresh in [0.9, 0.95, 0.99]:
selected, density = select_blocks_by_score(scores, threshold=thresh)
print(f"\nThreshold {thresh}: selected {len(selected)}/{num_k_blocks} blocks ({density:.1%})")
print(f" Selected: {selected}")
print("\nTest 2: PASSED (visual inspection)")
return True
def test_performance():
"""
Benchmark hierarchical vs direct estimation performance.
"""
print("\n" + "=" * 60)
print("Test 3: Performance Benchmark")
print("=" * 60)
import time
NUM_WARMUP = 3
NUM_RUNS = 10
# Larger test case
q_len = CPU_BLOCK_SIZE
num_k_blocks = 16 # 64K context
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
K_blocks = [
torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
for _ in range(num_k_blocks)
]
# Compute attention scores (shared)
attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
print(f"attn_scores shape: {attn_scores.shape}")
print(f"Context: {num_k_blocks * CPU_BLOCK_SIZE // 1024}K tokens")
# Warmup and benchmark hierarchical
for _ in range(NUM_WARMUP):
_ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(NUM_RUNS):
_ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
hier_time = (time.perf_counter() - start) / NUM_RUNS * 1000
# Warmup and benchmark direct
for _ in range(NUM_WARMUP):
_ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(NUM_RUNS):
_ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
direct_time = (time.perf_counter() - start) / NUM_RUNS * 1000
speedup = direct_time / hier_time
print(f"\nResults:")
print(f" Hierarchical (bs=1024): {hier_time:.2f} ms")
print(f" Direct (bs=4096): {direct_time:.2f} ms")
print(f" Speedup: {speedup:.2f}x")
passed = speedup > 5.0 # Expect at least 5x speedup
print(f"\nTest 3: {'PASSED' if passed else 'FAILED'} (speedup > 5x expected)")
return passed
# ============================================================
# Main
# ============================================================
if __name__ == "__main__":
print("=" * 60)
print("Hierarchical Block Sum Estimation Test")
print("=" * 60)
print(f"\nConfiguration:")
print(f" NUM_HEADS: {NUM_HEADS}")
print(f" NUM_KV_HEADS: {NUM_KV_HEADS}")
print(f" HEAD_DIM: {HEAD_DIM}")
print(f" STRIDE: {STRIDE}")
print(f" CPU_BLOCK_SIZE: {CPU_BLOCK_SIZE}")
print(f" ESTIMATE_BLOCK_SIZE: {ESTIMATE_BLOCK_SIZE}")
print(f" THRESHOLD: {THRESHOLD}")
print()
results = []
results.append(("Equivalence", test_equivalence()))
results.append(("Selection", test_selection()))
results.append(("Performance", test_performance()))
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for name, passed in results:
status = "PASSED" if passed else "FAILED"
print(f" {name}: {status}")
all_passed = all(p for _, p in results)
print("=" * 60)
if all_passed:
print("test_hierarchical_estimate: ALL PASSED")
sys.exit(0)
else:
print("test_hierarchical_estimate: SOME FAILED")
sys.exit(1)

View File

@@ -1,254 +0,0 @@
"""
Needle-in-a-haystack test for LLM.
Tests: Long context retrieval capability with configurable sequence length.
NOTE: CPU offload mode has a known bug that causes incorrect outputs for
sequences longer than ~200 tokens. Use --no-offload for correctness testing.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
from nanovllm.config import SparsePolicyType
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
# Main Test
# ============================================================
def run_needle_test(
model_path: str,
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
block_size: int = 1024,
needle_position: float = 0.5,
needle_value: str = "7492",
max_new_tokens: int = 32,
enable_cpu_offload: bool = False,
enable_quest: bool = False,
enable_xattn_bsa: bool = False,
sparse_topk: int = 8,
sparse_threshold: int = 4,
sparse_samples: int = 128,
verbose: bool = True,
) -> bool:
"""
Run a needle-in-haystack test.
Args:
model_path: Path to model
max_model_len: Maximum model context length
input_len: Target input sequence length
num_gpu_blocks: Number of GPU blocks for offload
block_size: KV cache block size
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
enable_cpu_offload: Enable CPU offload mode
enable_quest: Enable Quest sparse attention (decode-only Top-K)
enable_xattn_bsa: Enable XAttention BSA sparse attention (prefill-only)
sparse_topk: Top-K blocks for Quest
sparse_threshold: Threshold for sparse selection (Quest/XAttention BSA)
sparse_samples: Samples per chunk for XAttention BSA estimation
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
# Determine sparse policy
if enable_xattn_bsa:
sparse_policy = SparsePolicyType.XATTN_BSA
elif enable_quest:
sparse_policy = SparsePolicyType.QUEST
else:
sparse_policy = SparsePolicyType.FULL
if verbose:
print(f"\n{'='*60}")
print(f"Needle-in-Haystack Test")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Block size: {block_size}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"CPU offload: {enable_cpu_offload}")
if enable_cpu_offload:
print(f"Sparse policy: {sparse_policy.name}")
if sparse_policy == SparsePolicyType.QUEST:
print(f" Quest: topk={sparse_topk}, threshold={sparse_threshold}")
elif sparse_policy == SparsePolicyType.XATTN_BSA:
print(f" XAttention BSA: threshold={sparse_threshold}, samples={sparse_samples}")
print(f"{'='*60}\n")
# 1. Initialize LLM
llm_kwargs = {
"enforce_eager": True,
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm_kwargs["sparse_policy"] = sparse_policy
if sparse_policy == SparsePolicyType.QUEST:
llm_kwargs["sparse_topk_blocks"] = sparse_topk
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
elif sparse_policy == SparsePolicyType.XATTN_BSA:
llm_kwargs["sparse_threshold"] = float(sparse_threshold) / 10.0 # Convert to 0.0-1.0 range
llm_kwargs["sparse_samples_per_chunk"] = sparse_samples
llm = LLM(model_path, **llm_kwargs)
# 2. Generate needle prompt
prompt, expected = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=needle_position,
needle_value=needle_value,
)
# 3. Generate output
sampling_params = SamplingParams(
temperature=0.6, # Moderate temperature
max_tokens=max_new_tokens,
)
outputs = llm.generate([prompt], sampling_params, use_tqdm=True)
# 4. Check result
output_text = outputs[0]["text"]
output_token_ids = outputs[0]["token_ids"]
passed = check_needle_answer(output_text, expected)
if verbose:
print(f"\n{'='*60}")
print(f"Result")
print(f"{'='*60}")
print(f"Expected: {expected}")
print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
print(f"Output: {output_text[:200]}...")
print(f"Status: {'PASSED' if passed else 'FAILED'}")
print(f"{'='*60}\n")
return passed
# ============================================================
# CLI Entry Point
# ============================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Needle-in-haystack test for long context LLM")
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
help="Path to model"
)
parser.add_argument(
"--max-model-len",
type=int,
default=128 * 1024,
help="Maximum model context length"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--num-gpu-blocks",
type=int,
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--block-size",
type=int,
default=1024,
help="KV cache block size"
)
parser.add_argument(
"--needle-position",
type=float,
default=0.5,
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
)
parser.add_argument(
"--needle-value",
type=str,
default="7492",
help="The secret value to hide"
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=32,
help="Maximum tokens to generate"
)
parser.add_argument(
"--enable-offload",
action="store_true",
help="Enable CPU offload (has known bug for long sequences)"
)
parser.add_argument(
"--enable-quest",
action="store_true",
help="Enable Quest sparse attention (decode-only Top-K selection)"
)
parser.add_argument(
"--enable-xattn-bsa",
action="store_true",
help="Enable XAttention BSA sparse attention (prefill-only)"
)
parser.add_argument(
"--sparse-topk",
type=int,
default=8,
help="Top-K blocks for Quest sparse attention"
)
parser.add_argument(
"--sparse-threshold",
type=int,
default=4,
help="Apply sparse only when blocks > threshold (Quest) or attention threshold 0-9 (XAttention BSA)"
)
parser.add_argument(
"--sparse-samples",
type=int,
default=128,
help="Samples per chunk for XAttention BSA estimation"
)
args = parser.parse_args()
passed = run_needle_test(
model_path=args.model,
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size,
needle_position=args.needle_position,
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,
enable_cpu_offload=args.enable_offload,
enable_quest=args.enable_quest,
enable_xattn_bsa=args.enable_xattn_bsa,
sparse_topk=args.sparse_topk,
sparse_threshold=args.sparse_threshold,
sparse_samples=args.sparse_samples,
verbose=True,
)
if passed:
print("test_needle: PASSED")
else:
print("test_needle: FAILED")
exit(1)

View File

@@ -1,176 +0,0 @@
"""
Needle-in-a-haystack reference test using pure torch + transformers.
This is a reference implementation for comparison with nanovllm.
Uses standard HuggingFace inference (no custom KV cache, no offload).
"""
import os
import argparse
import torch
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
# Main Test
# ============================================================
def run_needle_test(
model_path: str,
input_len: int,
needle_position: float = 0.5,
needle_value: str = "7492",
max_new_tokens: int = 32,
dtype: str = "auto",
verbose: bool = True,
) -> bool:
"""
Run a needle-in-haystack test using standard transformers inference.
Args:
model_path: Path to model
input_len: Target input sequence length
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
dtype: Model dtype ("auto", "float16", "bfloat16")
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
if verbose:
print(f"\n{'='*60}")
print(f"Needle-in-Haystack Reference Test (torch + transformers)")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Input length: {input_len}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"Dtype: {dtype}")
print(f"{'='*60}\n")
# 1. Load tokenizer
print("[1/4] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# 2. Generate needle prompt
print("[2/4] Generating needle prompt...")
prompt, expected = generate_needle_prompt(
tokenizer=tokenizer,
target_length=input_len,
needle_position=needle_position,
needle_value=needle_value,
)
# 3. Load model
print("[3/4] Loading model...")
torch_dtype = {
"auto": torch.float16, # default to float16 for custom model
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}.get(dtype, torch.float16)
model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
# 4. Generate output
print("[4/4] Running inference...")
device = next(model.parameters()).device
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
print(f" Input shape: {input_ids.shape}")
with torch.no_grad():
output_ids = model.generate(
input_ids,
max_new_tokens=max_new_tokens,
temperature=0.6,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
# Decode only the new tokens
new_token_ids = output_ids[0, input_ids.shape[1]:]
output_text = tokenizer.decode(new_token_ids, skip_special_tokens=False)
# 5. Check result
passed = check_needle_answer(output_text, expected)
if verbose:
print(f"\n{'='*60}")
print(f"Result")
print(f"{'='*60}")
print(f"Expected: {expected}")
print(f"Output tokens ({len(new_token_ids)}): {new_token_ids[:20].tolist()}")
print(f"Output: {output_text[:200]}...")
print(f"Status: {'PASSED' if passed else 'FAILED'}")
print(f"{'='*60}\n")
return passed
# ============================================================
# CLI Entry Point
# ============================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Needle-in-haystack reference test (torch + transformers)"
)
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
help="Path to model"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--needle-position",
type=float,
default=0.5,
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
)
parser.add_argument(
"--needle-value",
type=str,
default="7492",
help="The secret value to hide"
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=32,
help="Maximum tokens to generate"
)
parser.add_argument(
"--dtype",
type=str,
default="auto",
choices=["auto", "float16", "bfloat16"],
help="Model dtype"
)
args = parser.parse_args()
passed = run_needle_test(
model_path=args.model,
input_len=args.input_len,
needle_position=args.needle_position,
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,
dtype=args.dtype,
verbose=True,
)
if passed:
print("test_needle_ref: PASSED")
else:
print("test_needle_ref: FAILED")
exit(1)

View File

@@ -1,136 +0,0 @@
"""
Test for QuestPolicy block selection with GQA (Grouped Query Attention).
Demonstrates the key limitation: scores are AVERAGED across heads,
so blocks strongly needed by one head but not others may be dropped.
This is the expected Quest behavior - not a bug.
"""
import torch
from nanovllm.kvcache.sparse import (
create_sparse_policy,
SparsePolicyType,
PolicyContext,
)
# ============================================================
# Test: Per-Head Score Averaging in GQA
# ============================================================
# Determine device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running test on device: {device}")
# Setup: 2 KV heads, 4 query heads (GQA group_size=2)
# topk=2 to make selection competitive
quest = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=2, threshold_blocks=0)
quest.initialize(
num_layers=1,
num_kv_heads=2,
head_dim=4,
num_cpu_blocks=6,
dtype=torch.float32,
device=device, # Metadata stored on GPU
)
metadata = quest.metadata
def set_key(block_id, head_id, values):
"""Set both key_min and key_max to same values for deterministic scoring."""
# Values need to be on the same device as metadata
tensor = torch.tensor(values, device=device)
metadata.key_min[block_id, 0, head_id, :] = tensor
metadata.key_max[block_id, 0, head_id, :] = tensor
# ============================================================
# Design: Different heads want different blocks
# ============================================================
#
# Query = [1,1,1,1] for all heads, so score = sum(key values)
#
# Block | Head 0 | Head 1 | Average | Result
# ------|--------|--------|---------|--------
# 0 | +4 | -4 | 0 | Head0 wants, Head1 doesn't → DROPPED
# 1 | -4 | +4 | 0 | Head1 wants, Head0 doesn't → DROPPED
# 2 | +4 | +4 | +4 | Both want → SELECTED (rank 1)
# 3 | +3 | +3 | +3 | Both want → SELECTED (rank 2)
# 4 | +4 | 0 | +2 | Head0 strongly wants, Head1 neutral → rank 3
# 5 | 0 | +4 | +2 | Head1 strongly wants, Head0 neutral → rank 3
# Block 0: Head 0 strongly wants, Head 1 strongly rejects
set_key(0, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(0, 1, [-1.0, -1.0, -1.0, -1.0]) # head1: -4
# Block 1: Head 1 strongly wants, Head 0 strongly rejects
set_key(1, 0, [-1.0, -1.0, -1.0, -1.0]) # head0: -4
set_key(1, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# Block 2: Both heads want equally (highest average)
set_key(2, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(2, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# Block 3: Both heads want moderately
set_key(3, 0, [0.75, 0.75, 0.75, 0.75]) # head0: +3
set_key(3, 1, [0.75, 0.75, 0.75, 0.75]) # head1: +3
# Block 4: Head 0 strongly wants, Head 1 neutral
set_key(4, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(4, 1, [0.0, 0.0, 0.0, 0.0]) # head1: 0
# Block 5: Head 1 strongly wants, Head 0 neutral
set_key(5, 0, [0.0, 0.0, 0.0, 0.0]) # head0: 0
set_key(5, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# ============================================================
# Run selection
# ============================================================
# Query on same device as metadata
query = torch.ones(1, 4, 4, device=device) # GQA: 4 query heads → 2 KV heads
ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=0,
query=query,
is_prefill=False,
block_size=1024,
total_kv_len=6144,
)
available = list(range(6))
selected = quest.select_blocks(available, ctx)
# ============================================================
# Verify: Averaging behavior
# ============================================================
# topk=2, so only blocks 2 (+4 avg) and 3 (+3 avg) should be selected
assert len(selected) == 2, f"Expected 2 blocks, got {len(selected)}"
assert selected == [2, 3], f"Expected [2, 3], got {selected}"
# Key insight: blocks 0 and 1 have score +4 for ONE head,
# but they cancel out due to averaging with the other head's -4
assert 0 not in selected, "Block 0 should NOT be selected (head scores cancel out)"
assert 1 not in selected, "Block 1 should NOT be selected (head scores cancel out)"
# Blocks 4 and 5 have +4 for one head, 0 for other → avg=+2
# But +2 < +3 (block 3), so they don't make the top-2
assert 4 not in selected, "Block 4 avg=+2 < block 3 avg=+3"
assert 5 not in selected, "Block 5 avg=+2 < block 3 avg=+3"
print("✓ Block 2 selected: both heads want it (+4, +4) → avg=+4")
print("✓ Block 3 selected: both heads want it (+3, +3) → avg=+3")
print("✓ Block 0 NOT selected: head0=+4, head1=-4 → avg=0 (cancel out)")
print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")
# Verify metadata is on correct device
assert metadata.key_min.device.type == device.type, f"key_min on wrong device: {metadata.key_min.device}"
assert metadata.key_max.device.type == device.type, f"key_max on wrong device: {metadata.key_max.device}"
print(f"✓ Metadata stored on {device.type.upper()}")
print("\ntest_quest_policy: PASSED")

View File

@@ -386,8 +386,11 @@ def run_ruler_benchmark(
if sparse_policy and sparse_policy.upper() == "XATTN_BSA": if sparse_policy and sparse_policy.upper() == "XATTN_BSA":
DensityObserver.enable() DensityObserver.enable()
DensityObserver.complete_reset() DensityObserver.complete_reset()
# Set mode for correct density interpretation
DensityObserver.set_mode("offload" if enable_cpu_offload else "gpu_only")
if not json_output: if not json_output:
print("[DensityObserver] Enabled for XAttention BSA") mode_str = "offload" if enable_cpu_offload else "gpu_only"
print(f"[DensityObserver] Enabled for XAttention BSA (mode: {mode_str})")
# LLM initialization kwargs # LLM initialization kwargs
llm_kwargs = { llm_kwargs = {

View File

@@ -1,199 +0,0 @@
"""
Sequential inference test for LLM.
Tests: After completing one prompt, the system can correctly handle
a second prompt with a clean state (first prompt's KV cache deallocated).
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
from utils import generate_needle_prompt, check_needle_answer
def run_sequential_test(
model_path: str,
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
block_size: int = 1024,
enable_cpu_offload: bool = False,
verbose: bool = True,
) -> bool:
"""
Run sequential inference test with two different prompts.
Each prompt has a different needle value. Both must be retrieved correctly.
"""
if verbose:
print(f"\n{'='*60}")
print(f"Sequential Inference Test")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Block size: {block_size}")
print(f"CPU offload: {enable_cpu_offload}")
print(f"{'='*60}\n")
# Initialize LLM once
llm_kwargs = {
"enforce_eager": True,
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm = LLM(model_path, **llm_kwargs)
sampling_params = SamplingParams(
temperature=0.6,
max_tokens=32,
)
# ============================================================
# Test 1: First prompt with needle value "1234"
# ============================================================
needle_value_1 = "1234"
if verbose:
print(f"\n[Test 1] Generating prompt with needle value: {needle_value_1}")
prompt_1, expected_1 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_1,
)
outputs_1 = llm.generate([prompt_1], sampling_params, use_tqdm=True)
output_text_1 = outputs_1[0]["text"]
passed_1 = check_needle_answer(output_text_1, expected_1)
if verbose:
print(f" Expected: {expected_1}")
print(f" Output: {output_text_1[:100]}...")
print(f" Status: {'PASSED' if passed_1 else 'FAILED'}")
# ============================================================
# Test 2: Second prompt with needle value "5678"
# ============================================================
needle_value_2 = "5678"
if verbose:
print(f"\n[Test 2] Generating prompt with needle value: {needle_value_2}")
prompt_2, expected_2 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_2,
)
outputs_2 = llm.generate([prompt_2], sampling_params, use_tqdm=True)
output_text_2 = outputs_2[0]["text"]
passed_2 = check_needle_answer(output_text_2, expected_2)
if verbose:
print(f" Expected: {expected_2}")
print(f" Output: {output_text_2[:100]}...")
print(f" Status: {'PASSED' if passed_2 else 'FAILED'}")
# ============================================================
# Test 3: Third prompt - repeat first needle to ensure no cross-contamination
# ============================================================
needle_value_3 = "9999"
if verbose:
print(f"\n[Test 3] Generating prompt with needle value: {needle_value_3}")
prompt_3, expected_3 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_3,
)
outputs_3 = llm.generate([prompt_3], sampling_params, use_tqdm=True)
output_text_3 = outputs_3[0]["text"]
passed_3 = check_needle_answer(output_text_3, expected_3)
if verbose:
print(f" Expected: {expected_3}")
print(f" Output: {output_text_3[:100]}...")
print(f" Status: {'PASSED' if passed_3 else 'FAILED'}")
# ============================================================
# Summary
# ============================================================
all_passed = passed_1 and passed_2 and passed_3
if verbose:
print(f"\n{'='*60}")
print(f"Summary")
print(f"{'='*60}")
print(f"Test 1 (needle={needle_value_1}): {'PASSED' if passed_1 else 'FAILED'}")
print(f"Test 2 (needle={needle_value_2}): {'PASSED' if passed_2 else 'FAILED'}")
print(f"Test 3 (needle={needle_value_3}): {'PASSED' if passed_3 else 'FAILED'}")
print(f"Overall: {'PASSED' if all_passed else 'FAILED'}")
print(f"{'='*60}\n")
return all_passed
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Sequential inference test")
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-0.6B/"),
help="Path to model"
)
parser.add_argument(
"--max-model-len",
type=int,
default=36 * 1024,
help="Maximum model context length"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--num-gpu-blocks",
type=int,
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--block-size",
type=int,
default=1024,
help="KV cache block size"
)
parser.add_argument(
"--enable-offload",
action="store_true",
help="Enable CPU offload"
)
args = parser.parse_args()
passed = run_sequential_test(
model_path=args.model,
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size,
enable_cpu_offload=args.enable_offload,
verbose=True,
)
if passed:
print("test_sequential: PASSED")
else:
print("test_sequential: FAILED")
exit(1)

View File

@@ -1,334 +0,0 @@
"""
Test XAttention + BSA with RULER benchmark data.
Tests XAttention sparse attention correctness using RULER NIAH task.
Attention methods:
- Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
- Decode: FlashAttention (always, since q_len=1)
Usage (in compass conda env with BSA available):
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct
# Test with XAttention + BSA for prefill (default)
python tests/test_xattn_bsa.py --prefill-method xattn
# Test with FlashAttention for prefill (baseline)
python tests/test_xattn_bsa.py --prefill-method flash
# Test specific sample(s)
python tests/test_xattn_bsa.py --sample-id 0
python tests/test_xattn_bsa.py --sample-ids 0,1,2
Note: Compatible with transformers 4.53+ (handles both old `past_key_value`
and new `past_key_values` API).
"""
import argparse
import json
import sys
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
from nanovllm.ops.xattn import xattn_estimate
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
# ============================================================
# XAttention + BSA Functions
# ============================================================
def expand_kv_for_gqa(key_states, value_states, num_heads):
"""Expand KV for Grouped Query Attention."""
num_kv_heads = key_states.shape[1]
if num_heads == num_kv_heads:
return key_states, value_states
num_groups = num_heads // num_kv_heads
return key_states.repeat_interleave(num_groups, dim=1), value_states.repeat_interleave(num_groups, dim=1)
def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
"""Standard FlashAttention."""
from flash_attn import flash_attn_func
q = query_states.transpose(1, 2)
k = key_states.transpose(1, 2)
v = value_states.transpose(1, 2)
return flash_attn_func(q, k, v, causal=is_causal).transpose(1, 2)
def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
"""XAttention + BSA sparse attention."""
from block_sparse_attn import block_sparse_attn_func
batch_size, num_heads, q_len, head_dim = query_states.shape
k_len = key_states.shape[2]
_, mask = xattn_estimate(
query_states, key_states,
chunk_size=16384, block_size=128, threshold=threshold,
use_triton=True, causal=True,
)
q_block_num = (q_len + 127) // 128
k_block_num = (k_len + 127) // 128
q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
__import__('pdb').set_trace()
output = block_sparse_attn_func(
q, k, v,
torch.tensor([0, q_len], dtype=torch.int32, device=q.device),
torch.tensor([0, k_len], dtype=torch.int32, device=k.device),
torch.ones(num_heads, dtype=torch.int32, device=q.device),
None,
mask[:, :, :q_block_num, :k_block_num].contiguous(),
q_len, k_len,
p_dropout=0.0, deterministic=True, is_causal=True,
)
return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
DEBUG = False # Set to True to enable debugging
def create_patched_forward(prefill_method="xattn", threshold=0.9):
"""Create patched forward with configurable prefill method.
Args:
prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense)
threshold: XAttention threshold for block selection (only used when prefill_method="xattn")
Note:
- Prefill (q_len > 1): Uses specified prefill_method
- Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query)
"""
call_count = [0] # Mutable to track calls across layers
def patched_forward(
self,
hidden_states,
position_embeddings=None,
attention_mask=None,
past_key_value=None, # Old API (transformers < 4.57)
past_key_values=None, # New API (transformers >= 4.57)
cache_position=None,
**kwargs
):
# Handle both old and new transformers API
kv_cache = past_key_values if past_key_values is not None else past_key_value
bsz, q_len, _ = hidden_states.size()
num_heads = self.config.num_attention_heads
num_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
# Compute Q, K, V projections
query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
# Apply rotary position embedding
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Handle KV cache
if kv_cache is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = kv_cache.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# Expand KV for GQA
key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)
# Debug output
if DEBUG and self.layer_idx == 0:
call_count[0] += 1
if call_count[0] <= 5:
phase = "prefill" if q_len > 1 else "decode"
print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
print(f" kv_cache is None: {kv_cache is None}")
# Choose attention method:
# - Prefill (q_len > 1): Use prefill_method (xattn or flash)
# - Decode (q_len = 1): Always use FlashAttention
is_prefill = q_len > 1
if is_prefill and prefill_method == "xattn":
# Prefill with XAttention + BSA (sparse)
attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
else:
# Prefill with FlashAttention (dense) OR Decode (always FlashAttention)
# Note: For decode (q_len=1), causal=False since single query attends to all KV
attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)
attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
return attn_output, None
return patched_forward
# ============================================================
# Data & Evaluation
# ============================================================
def load_samples(filepath, indices=None):
"""Load samples from JSONL file."""
samples = []
with open(filepath) as f:
for i, line in enumerate(f):
if indices is None or i in indices:
sample = json.loads(line)
sample["_idx"] = i
samples.append(sample)
return samples
def string_match_all(output_text, expected_list):
"""RULER metric: fraction of expected values found in output."""
output_lower = output_text.lower().replace('\n', ' ')
if not expected_list:
return 1.0
return sum(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list) / len(expected_list)
# ============================================================
# Test
# ============================================================
def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
"""Test attention methods using RULER data.
Args:
prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
"""
prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"
print("=" * 60)
print("RULER NIAH Attention Test")
print("=" * 60)
print(f"Data: {data_file}")
print(f"Samples: {sample_ids}")
print(f"Prefill method: {prefill_desc}")
print(f"Decode method: FlashAttention (always)")
if prefill_method == "xattn":
print(f"XAttention threshold: {threshold}")
samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
if not samples:
print("No samples found!")
return False
print(f"Loaded {len(samples)} samples")
# Load model
print(f"\nLoading model: {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch.float16, device_map="cuda",
attn_implementation="eager", # Will be patched
)
model.eval()
# Patch all layers
print(f"Patching attention layers...")
print(f" - Prefill: {prefill_desc}")
print(f" - Decode: FlashAttention")
for idx, layer in enumerate(model.model.layers):
layer.self_attn.layer_idx = idx # Ensure layer_idx is set
layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
layer.self_attn, type(layer.self_attn)
)
total_score = 0.0
results = []
for sample in samples:
idx = sample["_idx"]
prompt = sample["input"]
expected = sample["outputs"]
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
num_tokens = inputs["input_ids"].shape[1]
print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
print(f"Expected: {expected}")
with torch.no_grad():
output = model.generate(
inputs["input_ids"],
max_new_tokens=max_new_tokens,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
)
output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
score = string_match_all(output_text, expected)
total_score += score
status = "✓ PASS" if score >= 0.5 else "✗ FAIL"
print(f"Output: '{output_text[:100]}...'")
print(f"Result: {status} (score={score:.2f})")
results.append({"idx": idx, "score": score, "passed": score >= 0.5})
avg_score = total_score / len(samples)
passed = sum(1 for r in results if r["passed"])
print(f"\n{'='*60}")
print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
print(f"{'='*60}")
return avg_score >= 0.5
def main():
parser = argparse.ArgumentParser(
description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark"
)
parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index")
parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
parser.add_argument("--max-new-tokens", type=int, default=50)
# Keep old option for backwards compatibility
parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
args = parser.parse_args()
model_path = args.model.replace("~", "/home/zijie")
# Handle deprecated --no-xattn option
prefill_method = args.prefill_method
if args.no_xattn:
prefill_method = "flash"
print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")
if args.sample_id is not None:
sample_ids = [args.sample_id]
elif args.sample_ids:
sample_ids = [int(x) for x in args.sample_ids.split(",")]
else:
sample_ids = [0]
# Check BSA availability if using xattn
if prefill_method == "xattn":
try:
from block_sparse_attn import block_sparse_attn_func
print("✓ BSA (Block Sparse Attention) available")
except ImportError:
print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
sys.exit(1)
if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
print("\ntest_xattn_bsa: PASSED")
else:
print("\ntest_xattn_bsa: FAILED")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,259 +0,0 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked
Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation.
Uses real QKV data captured from model inference.
"""
import sys
import os
import torch
import warnings
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
# ============================================================
# Configuration
# ============================================================
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096
# Default QKV data directory (relative to project root)
DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache")
# ============================================================
# Utility Functions
# ============================================================
def load_qkv(path):
"""Load saved QKV data."""
data = torch.load(path, map_location="cpu", weights_only=False)
print(f"Loaded: {path}")
print(f" Query shape: {data['query'].shape}")
print(f" Key shape: {data['key'].shape}")
print(f" Layer: {data['layer_id']}, Density: {data['density']:.2%}")
return data
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
"""Compare two masks and report differences."""
if mask1.shape != mask2.shape:
print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
return False
diff = (mask1 != mask2).sum().item()
total = mask1.numel()
match_rate = (total - diff) / total * 100
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
if diff > 0:
diff_indices = torch.where(mask1 != mask2)
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
return diff == 0
def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size):
"""
Run xattn_estimate_chunked with EXTERNAL chunking.
This simulates how chunked prefill should be used in practice.
"""
batch_size, num_heads, q_len, head_dim = query.shape
_, _, k_len, _ = key.shape
q_block_num = (q_len + block_size - 1) // block_size
k_block_num = (k_len + block_size - 1) // block_size
# If Q fits in one chunk, call directly
if q_len <= chunk_size:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return xattn_estimate_chunked(
query, key,
q_start_pos=q_start_pos,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# External chunking: split Q and call for each chunk
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
print(f" External chunking: {num_q_chunks} chunks")
combined_attn_sum = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=query.dtype, device=query.device
)
combined_mask = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=torch.bool, device=query.device
)
q_block_offset = 0
for q_chunk_idx in range(num_q_chunks):
q_chunk_start = q_chunk_idx * chunk_size
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
# For causal attention, K accumulates up to current Q position
k_end = q_start_pos + q_chunk_end
k_chunk = key[:, :, :k_end, :]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
q_chunk, k_chunk,
q_start_pos=q_start_pos + q_chunk_start,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# Place chunk results into combined output
chunk_q_blocks = mask_chunk.shape[2]
chunk_k_blocks = mask_chunk.shape[3]
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
q_block_offset += chunk_q_blocks
return combined_attn_sum, combined_mask
def test_single_qkv(qkv_path):
"""Test a single QKV file."""
data = load_qkv(qkv_path)
query = data["query"].cuda().to(torch.bfloat16)
key = data["key"].cuda().to(torch.bfloat16)
seq_len = query.shape[2]
print(f"\nTesting with seq_len={seq_len}")
print("=" * 60)
# Run standard xattn_estimate
print("[1] Running standard xattn_estimate...")
try:
attn_sum_std, mask_std = xattn_estimate(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
use_triton=True,
)
print(f" mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}")
except Exception as e:
print(f" ERROR: {e}")
import traceback
traceback.print_exc()
return False
# Run chunked xattn_estimate with EXTERNAL chunking
print("[2] Running chunked xattn_estimate (external chunking)...")
try:
attn_sum_chunked, mask_chunked = run_chunked_externally(
query, key,
q_start_pos=0,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
)
print(f" mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}")
except Exception as e:
print(f" ERROR: {e}")
import traceback
traceback.print_exc()
return False
# Compare results
print("[3] Comparing results...")
chunked_q_blocks = mask_chunked.shape[2]
chunked_k_blocks = mask_chunked.shape[3]
# Extract comparable region from standard mask
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
# Compare masks
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
# Compare attn_sums
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
print(f" Attn sum max diff: {attn_diff:.6f}")
else:
print(f" Attn sum shape mismatch")
# Clean up GPU memory
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
torch.cuda.empty_cache()
return masks_match
# ============================================================
# Main Test
# ============================================================
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked")
parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR,
help="Directory containing QKV files")
args = parser.parse_args()
# QKV files to test
qkv_files = [
os.path.join(args.qkv_dir, "qkv_3688.pt"), # ~4K
os.path.join(args.qkv_dir, "qkv_7888.pt"), # ~8K
os.path.join(args.qkv_dir, "qkv_15685.pt"), # ~16K
os.path.join(args.qkv_dir, "qkv_32485.pt"), # ~32K
os.path.join(args.qkv_dir, "qkv_64891.pt"), # ~64K
]
available_files = [p for p in qkv_files if os.path.exists(p)]
if not available_files:
print(f"No QKV file found in {args.qkv_dir}.")
print(f"Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt")
sys.exit(1)
print(f"Found {len(available_files)} QKV files to test")
print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})")
print(f"Using Triton kernels")
all_passed = True
results = []
for qkv_path in available_files:
passed = test_single_qkv(qkv_path)
seq_len = int(os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", ""))
results.append((seq_len, passed))
if not passed:
all_passed = False
# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for seq_len, passed in results:
status = "PASSED" if passed else "FAILED"
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
print(f" seq_len={seq_len} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}")
print("=" * 60)
if all_passed:
print("test_xattn_chunked: PASSED")
sys.exit(0)
else:
print("test_xattn_chunked: FAILED")
sys.exit(1)

View File

@@ -1,244 +0,0 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked
Verify that chunked estimation with EXTERNAL chunking produces the same mask
as standard estimation. This ensures the chunked version can be used in
chunked prefill scenarios without accuracy loss.
Usage:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_chunked.py
"""
import sys
import traceback
import torch
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
# ============================================================
# Configuration
# ============================================================
# Configuration for xattn_estimate_chunked consistency test.
# Key requirements for 100% match:
# 1. Use matching chunk_size for both standard and chunked versions
# 2. Use same random seed for reproducibility
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
# floating point precision in cumulative sum calculations.
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096 # External chunking size
# Test sequence lengths
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
# ============================================================
# Utility Functions
# ============================================================
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
"""Compare two masks and report differences."""
if mask1.shape != mask2.shape:
print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
return False
diff = (mask1 != mask2).sum().item()
total = mask1.numel()
match_rate = (total - diff) / total * 100
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
if diff > 0:
diff_indices = torch.where(mask1 != mask2)
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
return diff == 0
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
"""
Run xattn_estimate_chunked with EXTERNAL chunking.
This simulates how chunked prefill should be used in practice.
"""
batch_size, num_heads, q_len, head_dim = query.shape
_, _, k_len, _ = key.shape
q_block_num = (q_len + block_size - 1) // block_size
k_block_num = (k_len + block_size - 1) // block_size
# If Q fits in one chunk, call directly
if q_len <= chunk_size:
return xattn_estimate_chunked(
query, key,
q_start_pos=0,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# External chunking: split Q and call for each chunk
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
print(f" External chunking: {num_q_chunks} chunks")
combined_attn_sum = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=query.dtype, device=query.device
)
combined_mask = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=torch.bool, device=query.device
)
q_block_offset = 0
for q_chunk_idx in range(num_q_chunks):
q_chunk_start = q_chunk_idx * chunk_size
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
# For causal attention, K accumulates up to current Q position
# q_start_pos=0 means Q starts at position 0 in the full sequence
# K is [0, q_chunk_end) for causal attention
k_end = q_chunk_end
k_chunk = key[:, :, :k_end, :]
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
q_chunk, k_chunk,
q_start_pos=q_chunk_start,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# Place chunk results into combined output
chunk_q_blocks = mask_chunk.shape[2]
chunk_k_blocks = mask_chunk.shape[3]
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
q_block_offset += chunk_q_blocks
return combined_attn_sum, combined_mask
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
"""Test a single sequence length."""
print(f"\nTesting seq_len={seq_len}")
print("=" * 60)
# Generate random Q/K
query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
# Run standard xattn_estimate
print("[1] Running standard xattn_estimate...")
try:
attn_sum_std, mask_std = xattn_estimate(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
use_triton=True,
causal=True,
)
density_std = mask_std.float().mean().item()
print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
except Exception as e:
print(f" ERROR: {e}")
traceback.print_exc()
return False
# Run chunked xattn_estimate with EXTERNAL chunking
print("[2] Running chunked xattn_estimate (external chunking)...")
try:
attn_sum_chunked, mask_chunked = run_chunked_externally(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
)
density_chunked = mask_chunked.float().mean().item()
print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
except Exception as e:
print(f" ERROR: {e}")
traceback.print_exc()
return False
# Compare results
print("[3] Comparing results...")
chunked_q_blocks = mask_chunked.shape[2]
chunked_k_blocks = mask_chunked.shape[3]
# Extract comparable region from standard mask
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
# Compare masks
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
# Compare attn_sums
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
print(f" Attn sum max diff: {attn_diff:.6f}")
else:
print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")
# Clean up GPU memory
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
torch.cuda.empty_cache()
return masks_match
# ============================================================
# Main Test
# ============================================================
if __name__ == "__main__":
print("XAttention Chunked vs Standard Test")
print("=" * 60)
print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
print(f"External chunk_size={CHUNK_SIZE}")
print()
# Check CUDA availability
if not torch.cuda.is_available():
print("CUDA not available!")
sys.exit(1)
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
print("✓ xattn_estimate imported")
print("✓ xattn_estimate_chunked imported")
# Run tests
all_passed = True
results = []
for seq_len in TEST_SEQ_LENS:
passed = test_single_seq_len(seq_len)
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
results.append((seq_len, chunks, passed))
if not passed:
all_passed = False
# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for seq_len, chunks, passed in results:
status = "PASSED" if passed else "FAILED"
print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")
print("=" * 60)
if all_passed:
print("ALL TESTS PASSED!")
sys.exit(0)
else:
print("SOME TESTS FAILED!")
sys.exit(1)

View File

@@ -1,132 +0,0 @@
"""
Test: XAttention Triton kernels
演示 XAttention 的两个核心 Triton kernel:
1. flat_group_gemm_fuse_reshape: 计算 stride reshape 后的 attention scores (反对角线求和)
2. softmax_fuse_block_sum: 对 attention scores 做 softmax 后按 block 求和
数据流:
Q [batch, heads, q_len, head_dim]
K [batch, heads, kv_len, head_dim]
↓ flat_group_gemm_fuse_reshape
attn_scores [batch, heads, q_len/stride, kv_len/stride]
↓ softmax_fuse_block_sum
block_sums [batch, heads, q_blocks, k_blocks]
"""
import torch
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
# ============================================================
# 参数配置
# ============================================================
# Triton 约束: q_len >= stride * BLOCK_M, kv_len >= stride * BLOCK_N
# A100: BLOCK_M = BLOCK_N = 128, 所以 min = 4 * 128 = 512
# RTX 3090: BLOCK_M = BLOCK_N = 64, 所以 min = 4 * 64 = 256
q_len = 512
kv_len = 2048
head_dim = 128
stride = 4
block_size = 128 # softmax block size (in reshaped space)
segment_size = 128 # Triton kernel 要求 segment_size >= block_size
# ============================================================
# 构造输入: 偶数位置=1, 奇数位置=2
# ============================================================
Q = torch.zeros(1, 1, q_len, head_dim, dtype=torch.bfloat16).cuda()
K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()
for i in range(q_len):
if i % 2 == 0:
Q[0, 0, i, :] = 1 * (i // stride + 1)
else:
Q[0, 0, i, :] = 2 * (i // stride + 1)
for i in range(kv_len):
if i % 2 == 0:
K[0, 0, i, :] = 1
else:
K[0, 0, i, :] = 2
# ============================================================
# Step 1: flat_group_gemm_fuse_reshape (chunked along K)
# ============================================================
q_reshaped_len = q_len // stride # 128
kv_reshaped_len = kv_len // stride # 512
# 将 K 沿着长度维度分成多个 chunk
k_chunk_size = 512 # 每个 chunk 512 tokens
num_k_chunks = kv_len // k_chunk_size # 4 chunks
attn_scores_list = []
for k_chunk_idx in range(num_k_chunks):
k_start = k_chunk_idx * k_chunk_size
k_end = k_start + k_chunk_size
K_chunk = K[:, :, k_start:k_end, :] # [1, 1, k_chunk_size, head_dim]
# 对每个 K chunk 调用 flat_group_gemm_fuse_reshape
# 输出: [batch, heads, q_len/stride, k_chunk_size/stride]
attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_chunk, stride,
chunk_start=0,
chunk_end=q_reshaped_len,
is_causal=True
)
__import__('pdb').set_trace()
attn_scores_list.append(attn_chunk)
# 拼接所有 K chunks 的结果
# 每个 chunk: [1, 1, q_reshaped_len, k_chunk_size/stride]
# 拼接后: [1, 1, q_reshaped_len, kv_reshaped_len]
attn_scores = torch.cat(attn_scores_list, dim=-1)
# 验证 shape: [batch, heads, q_len/stride, kv_len/stride]
assert attn_scores.shape == (1, 1, q_reshaped_len, kv_reshaped_len), \
f"shape mismatch: {attn_scores.shape} != (1, 1, {q_reshaped_len}, {kv_reshaped_len})"
# 验证: 反对角线求和
# 每个 stride x stride 块的反对角线: Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4
# 反对角线有 stride/2 对,再乘以 head_dim
expected_gemm = (2*1 + 1*2) * (stride // 2) * head_dim
actual_gemm = attn_scores[0, 0, 0, 0].item()
assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expected_gemm}"
# ============================================================
# Step 2: softmax_fuse_block_sum
# ============================================================
scale = 1.4426950408889634 # log2(e) for exp2
block_sums = softmax_fuse_block_sum(
attn_scores,
block_size,
segment_size,
chunk_start=0,
chunk_end=q_reshaped_len,
real_q_len=q_reshaped_len,
scale=scale,
is_causal=False
)
# 验证 shape: [batch, heads, q_blocks, k_blocks]
q_blocks = q_reshaped_len // block_size # 128 / 128 = 1
k_blocks = kv_reshaped_len // block_size # 512 / 128 = 4
assert block_sums.shape == (1, 1, q_blocks, k_blocks), \
f"shape mismatch: {block_sums.shape} != (1, 1, {q_blocks}, {k_blocks})"
# 验证: 每个 block 的 softmax 结果求和
# 所有 attn_scores 相同 → softmax 均匀分布
# 每行对一个 K block 的贡献 = block_size / kv_reshaped_len
# 每个 Q block 有 block_size 行
# block_sum = block_size * (block_size / kv_reshaped_len)
expected_sum = block_size * block_size / kv_reshaped_len
actual_sum = block_sums[0, 0, 0, 0].item()
assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"
print("test_xattn_kernels: PASSED")

View File

@@ -1,246 +0,0 @@
"""
Test: 批量验证 xattn_estimate 与 KV chunking kernels 的一致性
测试 results/kvcache 下所有保存的 QKV 数据
Usage:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_xattn_kv_chunking_batch.py
"""
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
import os
import glob
import torch
import math
from nanovllm.ops.xattn import (
xattn_estimate,
flat_group_gemm_fuse_reshape,
softmax_compute_partial_stats,
softmax_normalize_and_block_sum,
merge_softmax_stats,
find_blocks_chunked,
)
# ============================================================
# 参数配置
# ============================================================
DATA_DIR = "/home/zijie/Code/nano-vllm/results/kvcache"
BSA_BLOCK_SIZE = 128
CHUNK_SIZE = 16384
device = "cuda"
def test_single_file(data_file: str) -> dict:
"""测试单个 kvcache 文件"""
data = torch.load(data_file, map_location="cpu")
Q = data["query"].to(device)
K = data["key"].to(device)
batch_size, num_heads, seq_len, head_dim = Q.shape
STRIDE = data["stride"]
THRESHOLD = data["threshold"][0].item() if isinstance(data["threshold"], torch.Tensor) else data["threshold"]
# ========== xattn_estimate API ==========
attn_sums_api, mask_api = xattn_estimate(
Q, K,
block_size=BSA_BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
causal=True,
)
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
mask_api_valid = mask_api[:, :, :q_blocks, :k_blocks]
causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=device, dtype=torch.bool))
total_api = causal_mask.sum().item() * batch_size * num_heads
selected_api = (mask_api_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_api = selected_api / total_api
# ========== 三阶段 KV Chunking ==========
k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
kv_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE
k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE
reshaped_chunk_size = CHUNK_SIZE // STRIDE
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
k_reshaped_num_to_pad = k_num_to_pad // STRIDE
num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
kv_reshaped_chunk_size = CHUNK_SIZE // STRIDE
if k_num_to_pad > 0:
K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
else:
K_padded = K
if q_num_to_pad > 0:
Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0)
else:
Q_padded = Q
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm
simple_mask_list = []
for q_chunk_idx in range(q_chunk_num):
q_start = q_chunk_idx * reshaped_chunk_size * STRIDE
q_end = q_start + reshaped_chunk_size * STRIDE
Q_chunk = Q_padded[:, :, q_start:q_end, :]
chunk_start = (k_block_num - q_block_num) * reshaped_block_size + q_chunk_idx * reshaped_chunk_size
chunk_end = chunk_start + reshaped_chunk_size
m_chunks = []
l_chunks = []
attn_weights_chunks = []
for kv_chunk_idx in range(kv_chunk_num):
kv_start = kv_chunk_idx * CHUNK_SIZE
kv_end = kv_start + CHUNK_SIZE
K_chunk = K_padded[:, :, kv_start:kv_end, :]
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
attn_weights_kv = flat_group_gemm_fuse_reshape(
Q_chunk, K_chunk, STRIDE,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False,
)
attn_weights_chunks.append(attn_weights_kv)
m_partial, l_partial = softmax_compute_partial_stats(
attn_weights_kv,
reshaped_block_size,
min(4096, reshaped_block_size),
scale,
chunk_start=chunk_start,
kv_offset=kv_offset_reshaped,
is_causal=True,
)
m_chunks.append(m_partial)
l_chunks.append(l_partial)
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
attn_sum_per_kv = []
for kv_chunk_idx, attn_weights_kv in enumerate(attn_weights_chunks):
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
attn_sum_kv = softmax_normalize_and_block_sum(
attn_weights_kv,
m_global,
l_global,
reshaped_block_size,
min(4096, reshaped_block_size),
chunk_start=chunk_start,
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
scale=scale,
kv_offset=kv_offset_reshaped,
is_causal=True,
)
attn_sum_per_kv.append(attn_sum_kv)
attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
simple_mask = find_blocks_chunked(
attn_sum_concat,
current_index=k_block_num - q_block_num + q_chunk_idx * num_blocks_per_chunk,
threshold=THRESHOLD,
num_to_choose=None,
decoding=False,
mode="prefill",
causal=True,
)
simple_mask_list.append(simple_mask)
mask_kv_chunking = torch.cat(simple_mask_list, dim=2)
# 应用与 xattn_estimate 相同的 causal mask 后处理 (xattn.py 第 1300-1306 行)
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=device), diagonal=0),
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
False,
)
mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_kv = selected_kv / total_api
mask_total = mask_api_valid.numel()
mask_diff = (mask_api_valid != mask_kv_chunking_valid).sum().item()
mask_diff_pct = 100 * mask_diff / mask_total
return {
"seq_len": seq_len,
"stride": STRIDE,
"threshold": THRESHOLD,
"kv_chunks": kv_chunk_num,
"density_api": density_api,
"density_kv": density_kv,
"density_diff": abs(density_api - density_kv),
"mask_diff_pct": mask_diff_pct,
"passed": abs(density_api - density_kv) < 1e-6 and mask_diff_pct < 0.01,
}
def main():
files = sorted(glob.glob(os.path.join(DATA_DIR, "qkv_*.pt")))
print("=" * 80)
print("XAttention KV Chunking Alignment Test")
print("=" * 80)
print()
results = []
for f in files:
fname = os.path.basename(f)
print(f"Testing {fname}...", end=" ", flush=True)
try:
r = test_single_file(f)
results.append(r)
status = "✓ PASS" if r["passed"] else "✗ FAIL"
print(f"{status} (seq_len={r['seq_len']}, kv_chunks={r['kv_chunks']})")
except Exception as e:
print(f"✗ ERROR: {e}")
results.append({"file": fname, "error": str(e)})
print()
print("=" * 80)
print("Results Summary")
print("=" * 80)
print()
print("| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff | mask_diff | status |")
print("|---------|--------|-----------|-----------|-------------|------------|------|-----------|--------|")
all_passed = True
for r in results:
if "error" in r:
print(f"| ERROR | - | - | - | - | - | - | - | {r['error'][:20]} |")
all_passed = False
else:
status = "PASS" if r["passed"] else "FAIL"
if not r["passed"]:
all_passed = False
print(f"| {r['seq_len']:>7} | {r['stride']:>6} | {r['threshold']:.2f} | {r['kv_chunks']:>9} | "
f"{r['density_api']:.6f} | {r['density_kv']:.6f} | {r['density_diff']:.6f} | "
f"{r['mask_diff_pct']:.4f}% | {status} |")
print()
if all_passed:
print("test_xattn_kv_chunking_batch: ALL PASSED")
else:
print("test_xattn_kv_chunking_batch: SOME FAILED")
if __name__ == "__main__":
main()