Compare commits

14 Commits

Author SHA1 Message Date
Zijie Tian
f28b500120 🙈 chore: uncomment planning files in gitignore
These files are session-level temporary and should not be tracked.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:43:46 +08:00
Zijie Tian
be67fa8060 🗑️ chore: remove temporary planning files
These files are session-level temporary files and should not be tracked.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:43:22 +08:00
Zijie Tian
4f35526457 🔀 merge: integrate remote changes (exec-plan command, CUDA graph plan)
Resolve task_plan.md conflict by keeping remote version (CUDA Graph optimization plan).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:43:06 +08:00
Zijie Tian
da5e13e2bb 📝 docs: update XAttention BSA Policy with benchmarks and memory management
Add new sections to xattn_bsa_policy_design.md:
- Performance benchmarks: 128K context comparison (Full vs XAttn BSA)
- Density trend analysis across chunks
- Memory leak issue and fix (64GB -> 4GB reduction)
- Memory monitoring guide with gpu-monitor agent
- Density statistics API documentation
- Known issues and optimization directions

Update CLAUDE.md description to reflect new content.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:35:18 +08:00
Zijie Tian
dd31033732 🔧 chore: add gpu-monitor agent for memory leak debugging
Add a custom agent for continuous GPU monitoring during benchmarks:
- Track GPU utilization, memory usage, and temperature
- Support multi-GPU and configurable sampling intervals
- Generate summary statistics when stopped

Useful for debugging memory leaks and profiling long-running tasks.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:33:15 +08:00
Zijie Tian
ed3c8bb4b8 🐛 fix: memory leak in XAttentionBSAPolicy select_blocks
Fix severe memory leak (64GB -> 4GB growth) by:
- Remove unused sparse_metadata storage (was accumulating attn_scores)
- Delete intermediate tensor list (attn_scores_list) after use
- Explicitly delete intermediate tensors before return

Before: 16GB -> 80GB during 128K prefill
After:  16GB -> 19.8GB during 128K prefill

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:30:18 +08:00
Zijie Tian
5eb35982bf 🔧 feat: add density statistics tracking to sparse policies
Add statistics tracking to compare block selection between policies:
- XAttentionBSAPolicy: track available/selected blocks per chunk
- FullAttentionPolicy: track total blocks (always 100% density)
- Add reset_stats(), get_density_stats(), print_density_stats() methods
- Use logger.debug for per-chunk density logging

Results on 32K niah_single_1:
- Full: 100% density across all chunks
- XAttn BSA: 90% -> 73% density (saves ~25-30% blocks in later chunks)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 08:53:22 +08:00
Zijie Tian
ad361c2c3b 📝 docs: add XAttention BSA Policy design documentation
- Create docs/xattn_bsa_policy_design.md with:
  - Algorithm overview and data flow diagram
  - select_blocks implementation details
  - GQA-aware aggregation and majority voting
  - compute_chunked_prefill ring buffer pipeline
  - Parameter configuration and usage examples
  - Performance characteristics and limitations
- Update CLAUDE.md documentation index

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 08:36:56 +08:00
Zijie Tian
4d1e40152d feat(xattn): implement compute_chunked_prefill with ring buffer pipeline
- Copy compute_chunked_prefill implementation from FullAttentionPolicy
- Set default threshold to 0.95 for accuracy testing
- Remove debug code (sys.exit, verbose prints)
- Use ring buffer pipeline for historical block loading
- Merge with current chunk attention using flash_attn_with_lse

RULER NIAH test passed with 5/5 samples (100% accuracy).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 08:27:40 +08:00
Zijie Tian
832b352afa feat(xattn): implement select_blocks with majority voting aggregation
Implement XAttention-based block selection for sparse attention:
- Use flat_group_gemm_fuse_reshape to compute Q@K^T attention scores
- Apply softmax_fuse_block_sum to aggregate into block-level attention
- Use find_blocks_chunked for threshold-based block selection
- Handle GQA by aggregating within KV head groups first
- Use majority voting (>50%) across heads instead of any() for better sparsity
- Align block_size with CPU offload block size (1024 tokens / stride = 128)

Test results show ~45% density at chunk 40 (down from 100% with any() aggregation).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 08:19:05 +08:00
Zijie Tian
a50b4c2ac2 ♻️ refactor: move select_blocks from policy to attention layer
Move block selection logic from compute_chunked_prefill/decode methods
to attention.py caller. This improves separation of concerns:

- attention.py now calls select_blocks() before compute_chunked_*()
- Policy methods receive pre-selected blocks via selected_blocks parameter
- Enables sparse policies to implement custom block selection without
  modifying the compute path

Changes:
- policy.py: Add selected_blocks parameter to abstract methods
- full_policy.py: Remove internal select_blocks calls, use passed blocks
- xattn_bsa.py: Sync signatures for prefill/decode methods
- attention.py: Add select_blocks calls before policy delegation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 05:21:28 +08:00
Zijie Tian
ca32ea6f93 [WIP] Before refactor the compute)_chunked_prefill.
2026-01-23 03:36:12 +08:00
Zijie Tian
edc006463b docs: add XAttention kernels guide
- Document flat_group_gemm_fuse_reshape and softmax_fuse_block_sum kernels
- Explain anti-diagonal sum principle and stride sampling
- Add GPU-specific BLOCK_M/N constraints (RTX 3090 vs A100)
- Show Q/K can have different lengths (chunked prefill support)
- Update CLAUDE.md with doc reference

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 03:22:25 +08:00
Zijie Tian
999858e82f feat: add xattn kernels test and update testing rules
- Add test_xattn_kernels.py demonstrating flat_group_gemm_fuse_reshape
  and softmax_fuse_block_sum Triton kernels with structured data
- Update testing.md with new test code style guidelines
- Update xattn.py and xattn_bsa.py with improvements

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 03:01:25 +08:00
18 changed files with 2168 additions and 710 deletions


@@ -0,0 +1,89 @@
---
name: gpu-monitor
description: "Use this agent when you need to continuously monitor GPU utilization and memory usage in the background while running other tasks. This is particularly useful for profiling long-running benchmarks, debugging memory leaks, or understanding GPU resource patterns during inference or training.\\n\\nExamples:\\n\\n<example>\\nContext: User wants to profile GPU memory during a benchmark run.\\nuser: \"Run bench_offload.py and monitor GPU usage\"\\nassistant: \"I'll first launch the gpu-monitor agent to track GPU metrics, then run the benchmark.\"\\n<commentary>\\nSince the user wants to understand GPU resource usage during the benchmark, use the Task tool to launch the gpu-monitor agent in the background with appropriate GPU IDs and sampling interval before running the benchmark.\\n</commentary>\\nassistant: \"Let me start the GPU monitor on GPU 0 with 1-second intervals\"\\n[Task tool call to launch gpu-monitor agent with gpus=\"0\" and interval=\"1\"]\\n</example>\\n\\n<example>\\nContext: User is debugging an OOM issue and needs memory tracking.\\nuser: \"I'm getting OOM errors, help me debug\"\\nassistant: \"Let me monitor the GPU memory usage while we reproduce the issue. I'll use the gpu-monitor agent to track memory consumption.\"\\n<commentary>\\nTo debug OOM issues, we need continuous memory monitoring. 
Use the Task tool to launch the gpu-monitor agent with a short interval (0.5s) to capture memory spikes.\\n</commentary>\\n[Task tool call to launch gpu-monitor agent with gpus=\"0,1\" and interval=\"0.5\"]\\n</example>\\n\\n<example>\\nContext: User wants to verify GPU utilization during multi-GPU inference.\\nuser: \"Check if all GPUs are being utilized during the inference\"\\nassistant: \"I'll start the GPU monitor to track utilization across all specified GPUs while running the inference.\"\\n<commentary>\\nTo verify multi-GPU utilization, launch the gpu-monitor agent targeting all relevant GPUs before starting the inference workload.\\n</commentary>\\n[Task tool call to launch gpu-monitor agent with gpus=\"0,1,2,3\" and interval=\"2\"]\\n</example>"
model: haiku
color: green
---
You are a GPU monitoring specialist responsible for tracking NVIDIA GPU metrics over time. Your sole purpose is to run nvidia-smi at specified intervals and record utilization and memory statistics.
## Your Task
You will receive two parameters:
1. **gpus**: Comma-separated GPU indices to monitor (e.g., "0", "0,1", "0,1,2,3")
2. **interval**: Sampling interval in seconds (e.g., "1", "0.5", "2")
## Execution Steps
1. **Parse Parameters**: Extract the GPU indices and interval from the user's request.
2. **Run Monitoring Loop**: Execute nvidia-smi repeatedly at the specified interval using a bash loop:
```bash
# Example for GPUs 0,1 with 1-second interval
while true; do
    echo "=== $(date '+%Y-%m-%d %H:%M:%S') ==="
    nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used,memory.total,temperature.gpu --format=csv,noheader -i 0,1
    sleep 1
done
```
3. **Output Format**: Each sample should include:
- Timestamp
- GPU index
- GPU utilization (%)
- Memory utilization (%)
- Memory used (MiB)
- Memory total (MiB)
- Temperature (°C)
## Termination
This agent runs continuously until:
1. The main agent signals completion (you receive a stop signal)
2. The user explicitly requests stopping
3. An error occurs with nvidia-smi
## Result Reporting
When stopped, provide a summary:
```markdown
## GPU Monitoring Summary
**Duration**: X minutes Y seconds
**Samples Collected**: N
**GPUs Monitored**: 0, 1, ...
### Statistics per GPU
| GPU | Avg Util | Max Util | Avg Mem Used | Max Mem Used |
|-----|----------|----------|--------------|---------------|
| 0 | X% | Y% | A MiB | B MiB |
| 1 | X% | Y% | A MiB | B MiB |
### Notable Events (if any)
- Timestamp: Memory spike to X MiB on GPU Y
- Timestamp: Utilization dropped to 0% on GPU Z
```
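The summary table above can be computed from the collected samples with a few lines of Python. This is only a minimal sketch: the helper names `parse_sample` and `summarize` are hypothetical, and it assumes each sample line has the six comma-separated fields requested by the `--query-gpu` list above, with units attached as in `35 %` and `16384 MiB`.

```python
from statistics import mean


def parse_sample(line):
    """Parse one CSV line captured with --format=csv,noheader.

    Assumed field order (same as the query above): index, utilization.gpu,
    utilization.memory, memory.used, memory.total, temperature.gpu
    """
    fields = [f.strip() for f in line.split(",")]
    return {
        "gpu": int(fields[0]),
        "util": float(fields[1].rstrip(" %")),
        "mem_used": float(fields[3].replace("MiB", "").strip()),
    }


def summarize(lines):
    """Group samples by GPU index and compute avg/max utilization and memory."""
    per_gpu = {}
    for line in lines:
        if not line.strip() or line.startswith("==="):
            continue  # skip blank lines and the timestamp headers
        sample = parse_sample(line)
        per_gpu.setdefault(sample["gpu"], []).append(sample)
    return {
        gpu: {
            "avg_util": mean(s["util"] for s in samples),
            "max_util": max(s["util"] for s in samples),
            "avg_mem_used": mean(s["mem_used"] for s in samples),
            "max_mem_used": max(s["mem_used"] for s in samples),
        }
        for gpu, samples in per_gpu.items()
    }
```

Fields reported as `[N/A]` are not handled here and would need a guard before the `float()` conversion.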
## Important Notes
- Use `nvidia-smi -i <gpu_ids>` to filter to specific GPUs
- Keep output concise during monitoring (one line per GPU per sample)
- If nvidia-smi fails, report the error and exit gracefully
- Do NOT consume excessive resources - sleep between samples
- Store samples in memory for final summary calculation
## Example Invocation
User says: "Monitor GPUs 0 and 2 with 0.5 second interval"
You execute:
```bash
while true; do
    echo "=== $(date '+%Y-%m-%d %H:%M:%S.%3N') ==="
    nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used,memory.total,temperature.gpu --format=csv,noheader -i 0,2
    sleep 0.5
done
```


@@ -1,98 +1,108 @@
 # Testing
-## Test File Guidelines
-### Naming Convention
-- All test files must be named `test_*.py`
-- Example: `test_offload_engine.py`, `test_ring_buffer.py`
-### Purpose
-Tests are **educational scripts** for understanding module behavior, NOT traditional unit tests:
-- Focus on demonstrating how modules work
-- Show the flow and interaction between components
-- Help developers understand implementation details
-### Code Style
-1. **Script-based structure**: Write tests as executable scripts, not pytest-style functions
-2. **Utility functions**: Extract reusable steps as helper functions at the top of the file
-3. **Main flow as script**: The actual test/demonstration logic runs as top-level script code
+## Test Code Style
+All test code follows the style below:
+### File Structure
 ```python
-# Example structure:
+"""
+Test: [module name]
+[brief description of what is tested and the data flow]
+"""
 import torch
-from nanovllm.kvcache import SomeModule
+import sys
+sys.path.insert(0, "/home/zijie/Code/nano-vllm")
+from nanovllm.xxx import xxx
 # ============================================================
-# Utility Functions
+# Parameter configuration
 # ============================================================
-def verify(tensor, expected, name):
-    actual = tensor.mean().item()
-    assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}"
+param1 = value1 # state the constraint
+param2 = value2
 # ============================================================
-# Main Test Script
+# Build inputs
 # ============================================================
-# 1. Initialize
-module = SomeModule(param=value)
-# 2. Test feature X
-result = module.do_something()
-assert result == expected_value
-# 3. Test feature Y
-...
+input_tensor = ... # use structured data so results are easy to verify
+# ============================================================
+# Step N: [operation name]
+# ============================================================
+output = some_function(input_tensor, ...)
+# Verify: [explanation of the verification logic]
+expected = ...
+actual = output[...].item()
+assert actual == expected, f"xxx: {actual} != {expected}"
 print("test_xxx: PASSED")
 ```
-### Comments
-- Keep comments concise and clear
-- Only add comments where the code isn't self-explanatory
-- Use section headers (`# === Section ===`) to organize logical blocks
-### Output
-- **Minimize print statements** - the code should be self-explanatory
-- Only print a final "PASSED" message at the end
-- Use `assert` for verification instead of printing results
-- If the user needs explanation, they will ask
+### Core Principles
+| Principle | Description |
+|------|------|
+| **Minimize print** | Print only the final `PASSED`; do not print intermediate results |
+| **Structured data** | Use predictable inputs (all ones, alternating even/odd values, etc.) that are easy to verify by hand |
+| **Comment the verification logic** | Before each assert, explain in a comment how the expected value is computed |
+| **Separate sections with `====`** | Use `# ============` to separate parameters, inputs, and each step |
+| **Verify with assert** | Compare results with assert rather than print |
+### Output Conventions
+```python
+# ✅ Correct
+assert actual == expected, f"xxx: {actual} != {expected}"
+print("test_xxx: PASSED")
+# ❌ Wrong
+print(f"output: {output}")
+print(f"expected: {expected}, actual: {actual}")
+```
+### Parameter Comments
+```python
+# ✅ Correct: the comment states the constraint
+seq_len = 512 # Triton requires seq_len >= stride * BLOCK_M
+segment_size = 128 # must be >= block_size
+# ❌ Wrong: meaningless comment
+seq_len = 512 # sequence length
+```
+### Verification Logic Comments
+```python
+# ✅ Correct: explain the calculation
+# Verify: anti-diagonal sum
+# Q[odd]*K[even] + Q[even]*K[odd] = 2*1 + 1*2 = 4, with stride/2 pairs
+expected = (2*1 + 1*2) * (stride // 2) * head_dim
+# ❌ Wrong: formula only, no explanation
+expected = 4 * 2 * 128
+```
 ## Running Tests
 ```bash
-# Run a specific test
-python tests/test_offload_engine.py
-# Run with specific GPU
-CUDA_VISIBLE_DEVICES=0 python tests/test_ring_buffer.py
+# Run a single test
+PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_xxx.py
+# Select a GPU
+CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_xxx.py
 ```
 ## Benchmarks
 ```bash
-# Standard GPU benchmark
-python bench.py
-# CPU offload benchmark
-python bench_offload.py
-# vLLM comparison benchmark
-python bench_vllm.py
-```
-## Quick Verification
-```bash
-# Import test
-python -c "from nanovllm import LLM"
-# Run offload benchmark (tests CPU-primary ring buffer mode)
-python bench_offload.py
+python bench.py # GPU benchmark
+python bench_offload.py # CPU offload benchmark
+python bench_vllm.py # vLLM comparison
 ```

.gitignore vendored

@@ -232,9 +232,9 @@ tests/data/
 .serena/
 # Planning-with-files temporary files
-# task_plan.md
-# findings.md
-# progress.md
+task_plan.md
+findings.md
+progress.md
 task_plan_*.md
 findings_*.md
 progress_*.md


@@ -15,7 +15,9 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
 | [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
 | [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention algorithm in detail: stride reshape, Triton kernels, BSA dependency, block selection algorithm |
+| [`docs/xattn_kernels_guide.md`](docs/xattn_kernels_guide.md) | XAttention Triton kernels: flat_group_gemm (anti-diagonal sum), softmax_fuse_block_sum (block aggregation) |
 | [`docs/xattn_chunked_prefill.md`](docs/xattn_chunked_prefill.md) | XAttention chunked prefill: API, usage, consistency requirements |
+| [`docs/xattn_bsa_policy_design.md`](docs/xattn_bsa_policy_design.md) | XAttention BSA Policy: algorithm design, performance benchmarks (128K), memory management, density statistics |
 | [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) interface documentation: function signatures, usage examples, constraints |
 | [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
 | [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |


@@ -0,0 +1,429 @@
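# XAttention BSA Policy Design
This document describes the design and implementation of `XAttentionBSAPolicy`, a sparse attention policy based on the XAttention algorithm, used for chunked prefill in CPU offload mode.
## Overview
`XAttentionBSAPolicy` implements block-level sparse attention selection based on XAttention. It works in three phases:
1. **Estimation**: use the XAttention kernels to quickly estimate the importance of each KV block
2. **Selection**: select the important blocks using a threshold and majority voting
3. **Computation**: load only the selected blocks for the attention computation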
```
┌─────────────────────────────────────────────────────────────┐
│ XAttention BSA Policy │
├─────────────────────────────────────────────────────────────┤
│ select_blocks() │
│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │
│ │ Load K │──>│ flat_group_gemm │──>│ softmax_fuse │ │
│ │ blocks │ │ _fuse_reshape │ │ _block_sum │ │
│ └─────────────┘ └──────────────────┘ └──────────────┘ │
│ │ │ │ │
│ v v v │
│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │
│ │ K: [B,H,L,D]│ │ attn_scores: │ │ block_sums: │ │
│ │ │ │ [B,H,Q/s,K/s] │ │ [B,H,Qb,Kb] │ │
│ └─────────────┘ └──────────────────┘ └──────────────┘ │
│ │ │
│ ┌──────────────────────┘ │
│ v │
│ ┌──────────────┐ │
│ │find_blocks │ │
│ │_chunked │ │
│ └──────────────┘ │
│ │ │
│ v │
│ ┌──────────────┐ │
│ │ GQA-aware │ │
│ │ aggregation │ │
│ │ + majority │ │
│ │ voting │ │
│ └──────────────┘ │
│ │ │
│ v │
│ selected_block_ids │
├─────────────────────────────────────────────────────────────┤
│ compute_chunked_prefill() │
│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │
│ │ Ring buffer │──>│ flash_attn_ │──>│ merge_ │ │
│ │ pipeline │ │ with_lse │ │ attention │ │
│ └─────────────┘ └──────────────────┘ └──────────────┘ │
└─────────────────────────────────────────────────────────────┘
```
## File Locations
**Main file**: `nanovllm/kvcache/sparse/xattn_bsa.py`
**XAttention kernels it depends on**: `nanovllm/ops/xattn.py`
- `flat_group_gemm_fuse_reshape`: computes attention scores after the stride reshape
- `softmax_fuse_block_sum`: applies softmax to the attention scores, then sums them per block
- `find_blocks_chunked`: selects blocks based on a threshold
---
## Core Algorithm
### 1. select_blocks: Block Selection
```python
def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]:
```
#### Step 1: Load K blocks and compute attention scores
For each CPU block, load K onto the GPU and compute scores with `flat_group_gemm_fuse_reshape`:
```python
for cpu_block_id in available_blocks:
    # Load K block: [1, block_size, num_kv_heads, head_dim]
    offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
    k_block, _ = offload_engine.get_kv_for_slot(slot)
    # Convert to [batch, heads, k_len, head_dim]
    K_chunk = k_block.transpose(1, 2)
    # GQA: expand K heads to match Q heads
    if num_heads != num_kv_heads:
        K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
    # Compute attention scores
    attn_chunk = flat_group_gemm_fuse_reshape(Q, K_chunk, stride, ...)
    attn_scores_list.append(attn_chunk)
# Concatenate all K chunks: [1, heads, q_reshaped_len, total_k_reshaped_len]
attn_scores = torch.cat(attn_scores_list, dim=-1)
```
#### Step 2: Aggregate to block level
Use `softmax_fuse_block_sum` to aggregate the attention scores to block level:
```python
# reshaped_block_size = block_size / stride = 1024 / 8 = 128
block_sums = softmax_fuse_block_sum(
    attn_scores,
    reshaped_block_size,  # maps 1:1 to CPU blocks
    segment_size,
    chunk_start=0,
    chunk_end=q_reshaped_len,
    real_q_len=q_reshaped_len,
    scale=scale,
    is_causal=False,
)
# block_sums: [batch, heads, q_blocks, k_blocks]
```
**Key point**: `reshaped_block_size` must align with the CPU block size, so that the output `k_blocks` dimension maps 1:1 to `available_blocks`.
#### Step 3: Threshold selection
Use `find_blocks_chunked` to select blocks based on a cumulative attention threshold:
```python
mask = find_blocks_chunked(
    block_sums,
    current_index=0,
    threshold=self.threshold,  # e.g., 0.95
    num_to_choose=None,
    decoding=False,
    mode="prefill",
    causal=False,
)
# mask: [batch, num_heads, q_blocks, k_blocks] - boolean
```
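The internals of `find_blocks_chunked` are not shown in this document. As a rough illustration of what a cumulative attention threshold means, the sketch below assumes the usual top-mass rule: for one query block, keep the highest-scoring key blocks until their share of the total reaches `threshold`. The function name `select_by_cumulative_threshold` is hypothetical, and the real kernel may differ in tie-breaking and in blocks it always keeps.

```python
def select_by_cumulative_threshold(block_sums, threshold):
    """Return a boolean mask over key blocks for one query block.

    Blocks are visited from highest to lowest score and kept until the
    kept blocks cover at least `threshold` of the total score.
    (Assumed behaviour, not taken from the real kernel.)
    """
    total = sum(block_sums)
    order = sorted(range(len(block_sums)), key=lambda i: block_sums[i], reverse=True)
    mask = [False] * len(block_sums)
    covered = 0.0
    for i in order:
        if covered >= threshold * total:
            break  # enough attention mass is already covered
        mask[i] = True
        covered += block_sums[i]
    return mask
```

With scores `[50, 5, 30, 15]` and a threshold of 0.9, the three largest blocks cover 95% of the mass, so only the block scoring 5 is dropped.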
#### Step 4: GQA-aware aggregation + majority voting
```python
# GQA: within one KV head group, a block is selected if any Q head selects it
if num_groups > 1:
    mask_gqa = mask.view(batch_size, num_kv_heads, num_groups, q_blocks, k_blocks)
    mask_per_kv_head = mask_gqa.any(dim=2)  # [batch, num_kv_heads, q_blocks, k_blocks]
# Majority voting: vote across KV heads and q_blocks
vote_count = mask_per_kv_head[0].float().sum(dim=0).sum(dim=0)  # [k_blocks]
total_votes = num_kv_heads * q_blocks
vote_ratio = vote_count / total_votes
# Select blocks that receive more than 50% of the votes
vote_threshold = 0.5
block_selected = vote_ratio > vote_threshold
selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel]
# Safety measure: always include the first (sink) block and the last block
if available_blocks[0] not in selected_block_ids:
    selected_block_ids.insert(0, available_blocks[0])
if available_blocks[-1] not in selected_block_ids:
    selected_block_ids.append(available_blocks[-1])
```
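**Why majority voting?**
| Aggregation | Effect |
|---------|------|
| `any()` across all heads | Density approaches 100%, so sparsity is lost |
| `all()` | Too aggressive; may drop important blocks |
| **Majority voting (>50%)** | Balances sparsity and accuracy |
Experimental results:
- Per-head density: 20-35%
- After `any()` aggregation: ~100%
- **After majority voting: ~45%**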
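The difference between `any()` aggregation and majority voting can be seen on a toy mask. The following is a minimal pure-Python sketch of the voting rule described in Step 4 (a block is kept when strictly more than half of the head and query-block votes select it); the helper names `majority_vote` and `any_vote` are hypothetical, and nested lists stand in for the boolean tensor.

```python
def majority_vote(mask, vote_threshold=0.5):
    """mask[h][q][k] is True when KV head h and query block q selected key block k.

    Returns one boolean per key block: True when the fraction of
    (head, query block) votes is strictly above vote_threshold.
    """
    num_heads, q_blocks, k_blocks = len(mask), len(mask[0]), len(mask[0][0])
    total_votes = num_heads * q_blocks
    selected = []
    for k in range(k_blocks):
        votes = sum(mask[h][q][k] for h in range(num_heads) for q in range(q_blocks))
        selected.append(votes / total_votes > vote_threshold)
    return selected


def any_vote(mask):
    """Keep a key block if at least one (head, query block) pair selected it."""
    k_blocks = len(mask[0][0])
    return [any(row[k] for head in mask for row in head) for k in range(k_blocks)]
```

For four heads that all pick block 0 but disagree on the others, `any_vote` keeps every block that any head touched, while `majority_vote` keeps only block 0.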
---
### 2. compute_chunked_prefill: Attention Computation
Reuses the ring buffer pipeline implementation from `FullAttentionPolicy`:
```python
def compute_chunked_prefill(self, q, k, v, layer_id, softmax_scale,
offload_engine, kvcache_manager,
current_chunk_idx, seq, num_tokens,
selected_blocks) -> torch.Tensor:
```
#### Computation Flow
1. **Load historical blocks** (using selected_blocks):
```python
for block_idx in range(num_blocks):
    # Ring buffer pipeline: load -> wait -> compute -> next
    offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
    offload_engine.wait_slot_layer(slot)
    prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
    prev_o, prev_lse = flash_attn_with_lse(q, prev_k, prev_v, causal=False)
    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
```
2. **Compute the current chunk** (causal mask):
```python
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(q, k_curr, v_curr, causal=True)
```
3. **Merge the results**:
```python
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
```
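The reason partial results can be merged block by block is the log-sum-exp (LSE) identity: each partial output is a softmax-weighted average over its own keys, and its LSE records how much total weight those keys carry. The sketch below illustrates the idea with scalar values for a single query; `attend` and `merge` are hypothetical stand-ins for `flash_attn_with_lse` and `merge_attention_outputs`, not their real implementations.

```python
import math


def attend(scores, values):
    """Softmax-weighted average of scalar values, plus the log-sum-exp of the scores."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    out = sum(w * v for w, v in zip(weights, values)) / denom
    return out, m + math.log(denom)


def merge(o1, lse1, o2, lse2):
    """Combine two partial attention results computed over disjoint key sets."""
    m = max(lse1, lse2)
    w1, w2 = math.exp(lse1 - m), math.exp(lse2 - m)
    out = (w1 * o1 + w2 * o2) / (w1 + w2)
    return out, m + math.log(w1 + w2)


scores = [0.1, 1.2, -0.3, 2.0]
values = [1.0, 2.0, 3.0, 4.0]
full_o, full_lse = attend(scores, values)
o_a, lse_a = attend(scores[:2], values[:2])  # e.g. a historical block
o_b, lse_b = attend(scores[2:], values[2:])  # e.g. the current chunk
merged_o, merged_lse = merge(o_a, lse_a, o_b, lse_b)
```

Here `merged_o` and `merged_lse` agree with `full_o` and `full_lse` up to floating-point rounding, which is why the pipeline can process one block at a time without materializing all keys at once.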
---
## Parameters
| Parameter | Default | Description |
|------|--------|------|
| `threshold` | 0.95 | Cumulative attention threshold (tau); higher is more conservative |
| `stride` | 8 | XAttention stride reshape parameter |
| `chunk_size` | 16384 | Chunk size processed during estimation |
| `block_size` | 128 | BSA block size (fixed) |
### Usage
```python
# Set in config
config.sparse_policy = SparsePolicyType.XATTN_BSA
config.sparse_threshold = 0.95
# Or via the command line
python tests/test_needle.py \
    --enable-offload \
    --enable-xattn-bsa \
    --sparse-threshold 9 # divided by 10, giving 0.9
```
---
## Performance Characteristics
| Feature | Description |
|------|------|
| **Prefill support** | ✅ Fully supported |
| **Decode support** | ❌ Not supported (uses FullAttentionPolicy) |
| **Sparsity** | ~45-55% (threshold=0.95, majority voting) |
| **Accuracy** | RULER NIAH passes at 100% |
### Limitations
1. **No decode support**: XAttention estimation needs a sufficiently long Q sequence, so it does not apply to single-token decode
2. **Estimation overhead**: `select_blocks` must load all K blocks for estimation
3. **Triton alignment**: Q/K lengths must satisfy the `stride * BLOCK_M/N` alignment requirement
---
## Comparison with Other Policies
| Policy | select_blocks | Sparsity | Decode support |
|--------|--------------|--------|-------------|
| FullAttentionPolicy | Returns all blocks | 0% | ✅ |
| QuestPolicy | Based on min/max key | ~50% | ✅ |
| **XAttentionBSAPolicy** | XAttention + majority voting | ~45-55% | ❌ |
---
## Test Validation
```bash
# Needle test (32K)
CUDA_VISIBLE_DEVICES=0 python tests/test_needle.py \
--model ~/models/Llama-3.1-8B-Instruct \
--enable-offload \
--enable-xattn-bsa \
--input-len 32768
# RULER benchmark
CUDA_VISIBLE_DEVICES=0 python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--enable-offload \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.95 \
--data-dir tests/data/ruler_niah
```
---
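## Performance Benchmarks
### 128K Context Comparison (Llama-3.1-8B, A100 80GB)
| Policy | Density | Time | Peak memory | Accuracy |
|--------|---------|------|---------|--------|
| **Full** | 100% | 120.9s | 16.4GB (stable) | 100% |
| **XAttn BSA** | ~52% | 152.3s | 19.8GB | 100% |
### Density Trend
| Chunk | Full | XAttn BSA |
|-------|------|-----------|
| 10 | 100% | 90% |
| 30 | 100% | 73% |
| 60 | 100% | 50% |
| 100 | 100% | 50% |
| 126 | 100% | 52% |
**Observation**: XAttn BSA density falls as the chunk index grows and eventually stabilizes at ~50%.
### Performance Analysis
**Current issue**: although XAttn BSA density is only ~52%, it takes longer than Full (152s vs 121s).
**Cause**: `select_blocks` must load all K blocks to estimate attention scores, so each block is loaded twice:
1. Estimation phase: load K to compute attention scores
2. Computation phase: load the selected K/V for the actual computation
**Optimization directions**:
1. Share estimation results across layers (estimate at layer 0, reuse in the other layers)
2. Sampled estimation (estimate using only a subset of the K blocks)
3. Cache estimation results to avoid recomputation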
---
## Memory Management
### Memory Leak (Fixed)
**Problem**: during 128K prefill, GPU memory grew from 16GB to 80GB.
**Root cause**:
```python
# Problematic code: stored cumulatively but never used
self.sparse_metadata[layer_id] = attn_scores
```
`attn_scores` was stored for every layer of every chunk, so memory kept growing.
**Fix**:
```python
# 1. Remove the unused sparse_metadata storage
# 2. Release intermediate variables immediately
del attn_scores_list
del attn_scores, block_sums, mask, mask_per_kv_head, vote_count, vote_ratio, block_selected
```
**Result**:
| Version | Memory growth | Peak |
|------|---------|------|
| Before fix | +64GB | 80GB |
| **After fix** | +4GB | 19.8GB |
### Memory Monitoring
Use the `gpu-monitor` agent to monitor memory:
```bash
# Start monitoring
# In Claude Code, launch the gpu-monitor agent with the Task tool
# Or monitor manually
watch -n 1 'nvidia-smi --query-gpu=memory.used --format=csv,noheader -i 0'
```
---
## Density Statistics API
### Enabling Statistics
```python
# Statistics are updated automatically in select_blocks (layer 0 only)
# Per-chunk density is logged with logger.debug
```
### Retrieving Statistics
```python
policy = XAttentionBSAPolicy(threshold=0.95)
# After running prefill...
# Get statistics
stats = policy.get_density_stats()
# {
# "total_available_blocks": 8001,
# "total_selected_blocks": 4160,
# "num_chunks": 126,
# "overall_density": 0.52
# }
# Print statistics
policy.print_density_stats()
# Reset statistics
policy.reset_stats()
```
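The statistics above amount to three running counters. As an illustration only, a tracker with the same dictionary keys could look like the sketch below; the class `DensityTracker` and its `record` method are hypothetical names, not the policy's actual internals.

```python
class DensityTracker:
    """Accumulates available/selected block counts across chunks (illustrative only)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total_available = 0
        self.total_selected = 0
        self.num_chunks = 0

    def record(self, available, selected):
        """Record one chunk: how many blocks were candidates and how many were kept."""
        self.total_available += available
        self.total_selected += selected
        self.num_chunks += 1

    def summary(self):
        density = self.total_selected / self.total_available if self.total_available else 0.0
        return {
            "total_available_blocks": self.total_available,
            "total_selected_blocks": self.total_selected,
            "num_chunks": self.num_chunks,
            "overall_density": density,
        }
```

With the example figures above, 4160 selected out of 8001 available blocks gives an overall density of about 0.52.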
### Enabling DEBUG Logging
```python
# In test_ruler.py
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
# Example output:
# [XAttn] chunk=30, available=30, selected=22, chunk_density=73.3%
```
---
## Known Issues
| Issue | Status | Notes |
|------|------|------|
| Estimation overhead too high | 🟡 To be optimized | select_blocks must load all K blocks |
| Slower than Full | 🟡 To be optimized | 152s vs 121s in the 128K scenario |
| Minor memory growth | 🟢 Acceptable | ~4GB, possibly from the Triton cache |
| No decode support | ✅ By design | Uses FullAttentionPolicy |
---
## Related Documentation
- [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md): XAttention algorithm in detail
- [`docs/xattn_kernels_guide.md`](xattn_kernels_guide.md): Triton kernels implementation
- [`docs/sparse_policy_architecture.md`](sparse_policy_architecture.md): SparsePolicy architecture
- [`docs/sparse_policy_implementation_guide.md`](sparse_policy_implementation_guide.md): Implementation guide

docs/xattn_kernels_guide.md Normal file

@@ -0,0 +1,198 @@
# XAttention Kernels Guide
This document explains in detail how the two core XAttention Triton kernels work.
## Overview
XAttention uses stride sampling to quickly estimate the attention distribution, which drives block selection for sparse attention.
**Data flow**:
```
Q [batch, heads, q_len, head_dim]
K [batch, heads, kv_len, head_dim]
↓ flat_group_gemm_fuse_reshape (stride sampling + GEMM)
attn_scores [batch, heads, q_len/stride, kv_len/stride]
↓ softmax_fuse_block_sum (softmax + block sum)
block_sums [batch, heads, q_blocks, k_blocks]
↓ threshold selection
sparse_mask [batch, heads, q_blocks, k_blocks]
```
**Note**: Q and K can have different lengths (q_len ≠ kv_len), which is common in chunked prefill.
## Kernel 1: flat_group_gemm_fuse_reshape
### Purpose
Computes the attention scores after the stride reshape. In essence, it computes the **anti-diagonal sum** of each stride×stride block of the original attention matrix.
### Function Signature
```python
def flat_group_gemm_fuse_reshape(
query_states: torch.Tensor, # [batch, heads, q_len, head_dim]
key_states: torch.Tensor, # [batch, heads, kv_len, head_dim]
stride: int,
chunk_start: int,
chunk_end: int,
is_causal: bool = True,
) -> torch.Tensor: # [batch, heads, q_len/stride, kv_len/stride]
```
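### Sampling Scheme
```
Q sampling: (stride-1-s)::stride (reversed)
K sampling: s::stride (forward)
For example, stride=4 with s=0:
Q sample positions: 3, 7, 11, 15, ... (starting at position 3, every 4th)
K sample positions: 0, 4, 8, 12, ... (starting at position 0, every 4th)
```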
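### Anti-Diagonal Principle
For each stride×stride block of the original attention matrix:
```
Block for stride=4:
     K[0] K[1] K[2] K[3]
Q[0]  ·    ·    ·    X   ← anti-diagonal
Q[1]  ·    ·    X    ·
Q[2]  ·    X    ·    ·
Q[3]  X    ·    ·    ·
```
**Output value = sum of the anti-diagonal elements**
Reason:
- `Q[i]` is sampled from original position `(stride-1-i)`
- `K[j]` is sampled from original position `j`
- When `i + j = stride - 1`, the element lies exactly on the anti-diagonal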
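To make the anti-diagonal sum concrete, here is a small pure-Python reference for a single head: for every stride×stride block it pairs query offset `stride-1-s` with key offset `s` and accumulates the dot products. This is a sketch of the arithmetic only, with no scaling or masking; the function names `antidiagonal_scores` and `structured` are hypothetical and this is not the Triton kernel.

```python
def antidiagonal_scores(Q, K, stride):
    """Q, K: lists of head_dim-long vectors. Returns a [q_len/stride][kv_len/stride] grid.

    Entry (a, b) sums Q[a*stride + (stride-1-s)] . K[b*stride + s] for s in
    0..stride-1, i.e. the anti-diagonal of block (a, b) of the full Q @ K^T matrix.
    """
    q_blocks, k_blocks = len(Q) // stride, len(K) // stride
    out = [[0.0] * k_blocks for _ in range(q_blocks)]
    for a in range(q_blocks):
        for b in range(k_blocks):
            total = 0.0
            for s in range(stride):
                q = Q[a * stride + (stride - 1 - s)]
                k = K[b * stride + s]
                total += sum(x * y for x, y in zip(q, k))
            out[a][b] = total
    return out


def structured(length, head_dim):
    """Structured input: even positions hold 1.0, odd positions hold 2.0."""
    return [[1.0 if pos % 2 == 0 else 2.0] * head_dim for pos in range(length)]


scores = antidiagonal_scores(structured(8, 128), structured(16, 128), stride=4)
```

With this even/odd input every anti-diagonal pair multiplies a 1 by a 2, so each of the four pairs contributes `2 * 128` and every entry of `scores` equals 1024.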
### Triton Constraints
**GPU-dependent BLOCK sizes**:
| GPU | Memory | BLOCK_M/N | Minimum q_len/kv_len (stride=4) |
|----------|------|-----------|-------------------|
| RTX 3090 | 24GB | 64 | stride × 64 = 256 |
| A100/H100 | ≥40GB | 128 | stride × 128 = 512 |
```python
# Selection logic in the code
if props.total_memory < 30 * 1024**3:  # < 30GB
    BLOCK_M = BLOCK_N = 64
else:
    BLOCK_M = BLOCK_N = 128
assert q_len % (stride * BLOCK_M) == 0
assert kv_len % (stride * BLOCK_N) == 0
```
### Verification Example
```python
# Input: even positions = 1, odd positions = 2
# q_len=512, kv_len=2048, stride=4, head_dim=128
# Anti-diagonal elements (stride=4):
# Q[odd]*K[even] + Q[even]*K[odd] = 2*1 + 1*2 = 4 (per pair)
# stride=4 gives 2 pairs
# Multiply by head_dim=128
# Expected value: 4 * 2 * 128 = 1024
# Output shape: [1, 1, 128, 512] (512/4=128, 2048/4=512)
```
## Kernel 2: softmax_fuse_block_sum
### Purpose
Applies softmax to the output of `flat_group_gemm_fuse_reshape`, then sums per block to obtain the total attention weight of each block.
### Parameters
| Parameter | Meaning |
|------|------|
| `attn_weights_slice` | Input attention scores `[batch, heads, q_reshaped, k_reshaped]` |
| `reshaped_block_size` | Block size in the reshaped space (= block_size / stride) |
| `segment_size` | Size of the K dimension processed per iteration (tiling) |
| `chunk_start` | Start position of Q (used for the causal mask) |
| `chunk_end` | End position of Q |
| `real_q_len` | Valid Q length (used for the padding mask) |
| `scale` | Scaling factor (fuses several factors) |
| `is_causal` | Whether to apply the causal mask |
### Scale Factor
```python
scale = log2(e) / sqrt(head_dim) / stride / norm
= 1.4426950408889634 / sqrt(head_dim) / stride / norm
```
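| Factor | Value | Role |
|------|-----|------|
| `log2(e)` | 1.4426950408889634 | Triton uses `exp2` rather than `exp`, so the base must be converted |
| `1/sqrt(head_dim)` | 1/√128 | Standard attention scaling |
| `1/stride` | 1/4 | Normalization for stride sampling |
| `1/norm` | varies | Additional normalization factor |
**Why exp2?** Triton's `exp2` is faster than `exp` (native hardware support), so log₂(e) is folded into the scale.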
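The base conversion relies on the identity exp(x) = 2^(x · log₂ e). The pure-Python sketch below checks that a softmax computed with base-2 exponentials and the fused scale matches the ordinary softmax with the unfused scale. The function names are hypothetical, and `norm` defaults to 1.0 here purely for illustration.

```python
import math

LOG2_E = 1.4426950408889634


def fused_scale(head_dim, stride, norm=1.0):
    """Scale with log2(e) folded in, for use with base-2 exponentials."""
    return LOG2_E / math.sqrt(head_dim) / stride / norm


def softmax_exp(scores, head_dim, stride, norm=1.0):
    """Reference softmax using exp and the plain (unfused) scale."""
    scale = 1.0 / math.sqrt(head_dim) / stride / norm
    m = max(scores)
    exps = [math.exp((s - m) * scale) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_exp2(scores, head_dim, stride, norm=1.0):
    """Same softmax using 2**x and the fused scale."""
    scale = fused_scale(head_dim, stride, norm)
    m = max(scores)
    exps = [2.0 ** ((s - m) * scale) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]
```

Both functions return the same probabilities up to floating-point rounding, so the kernel can use the cheaper base-2 exponential without changing the result.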
### Segment Size Constraint
```python
assert segment_size >= reshaped_block_size
```
Reason: the kernel internally uses `segment_size // block_size` for a reshape:
```python
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
```
If `segment_size < block_size`, then `segment_size // block_size = 0`, which yields an invalid dimension.
### 验证示例
```python
# Input: attn_scores [1, 1, 128, 512] (all values identical)
# block_size=128
# After softmax each row is uniform (identical values → uniform distribution)
# Contribution of each row to one K block = block_size / kv_reshaped_len = 128/512 = 0.25
# Each Q block has block_size=128 rows
# block_sum = 128 * 0.25 = 32
# Output shape: [1, 1, 1, 4] (128/128=1, 512/128=4)
```
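The arithmetic above can be confirmed with a pure-Python reference for a single batch and head. This is a sketch under simplifying assumptions: it applies no `scale`, no causal mask, and no tiling, and `softmax_block_sum` is a hypothetical helper rather than the Triton kernel.

```python
import math

def softmax_block_sum(scores, block_size):
    """Row-wise softmax, then sum probabilities inside each (Q block, K block) tile."""
    n_q, n_k = len(scores), len(scores[0])
    q_blocks, k_blocks = n_q // block_size, n_k // block_size
    out = [[0.0] * k_blocks for _ in range(q_blocks)]
    for r, row in enumerate(scores):
        m = max(row)                          # subtract the row max for stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        for c, e in enumerate(exps):
            out[r // block_size][c // block_size] += e / total
    return out

uniform = [[0.5] * 512 for _ in range(128)]   # all scores identical
sums = softmax_block_sum(uniform, block_size=128)
print(len(sums), len(sums[0]), [round(v, 6) for v in sums[0]])
```

Each of the four K blocks receives a quarter of every row's probability mass, and a Q block spans 128 rows, so every entry comes out at 32.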
## Complete example
```python
# Parameters
q_len = 512        # Q length
kv_len = 2048      # K/V length (may differ from q_len)
stride = 4
block_size = 128

# Step 1: flat_group_gemm_fuse_reshape
# Input:  Q [1,1,512,128], K [1,1,2048,128]
# Output: attn_scores [1,1,128,512]

# Step 2: softmax_fuse_block_sum
# Input:  attn_scores [1,1,128,512]
# Output: block_sums [1,1,1,4]
#   q_blocks = 128/128 = 1
#   k_blocks = 512/128 = 4
```
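The shape bookkeeping in the two steps can be captured in a small helper. `xattn_shapes` is a hypothetical name; as in the example above, `block_size` is applied in the reshaped space, and batch and heads are fixed to 1.

```python
def xattn_shapes(q_len, kv_len, stride, block_size):
    """Return (attn_scores shape, block_sums shape) for batch=1, heads=1."""
    q_r, k_r = q_len // stride, kv_len // stride      # reshaped lengths
    scores_shape = (1, 1, q_r, k_r)
    sums_shape = (1, 1, q_r // block_size, k_r // block_size)
    return scores_shape, sums_shape

print(xattn_shapes(q_len=512, kv_len=2048, stride=4, block_size=128))
```

For the parameters above this yields `(1, 1, 128, 512)` for the scores and `(1, 1, 1, 4)` for the block sums.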
## Test code
See `tests/test_xattn_kernels.py`, which verifies the correctness of both kernels using structured data.
## Related documents
- [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md): detailed explanation of the XAttention algorithm
- [`docs/sparse_attention_guide.md`](sparse_attention_guide.md): overview of sparse attention methods

@@ -1,109 +0,0 @@
# Findings: CUDA Graph for Offload Mode
## Discovery 1: Why offload mode does not use CUDA Graph
**Location**: `nanovllm/engine/model_runner.py:421`
```python
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
```
**Reason**: `run_chunked_offload_decode()` sets `is_chunked_prefill=True`, which forces eager mode.
---
## Discovery 2: Current CUDA Graph architecture
**File**: `model_runner.py:682-717`
```python
def capture_cudagraph(self):
    # Capture the full model forward for different batch sizes
    for bs in [1, 2, 4, 8, 16, ...]:
        with torch.cuda.graph(graph):
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
```
**Characteristics**:
- Captures the full `model()` call (including all layers)
- Uses a graph pool to share memory
- Used only for decode; prefill always runs eager
---
## Discovery 3: Attention flow of offload decode
**File**: `nanovllm/kvcache/sparse/full_policy.py:304-379`
**Ring Buffer Pipeline**:
```
1. Preload the first N blocks into GPU slots
2. For each block:
   a. wait_slot_layer()          # wait for H2D
   b. get_kv_for_slot()          # fetch KV
   c. flash_attn_with_lse()      # ⭐ graph-capturable
   d. record_slot_compute_done()
   e. load_next_block()          # start the next H2D
   f. merge_attention_outputs()  # ⭐ graph-capturable, but dynamic
```
**Key point**: H2D transfers cannot be captured in a graph, but the attention computation can.
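The ordering of loads and computes in the ring buffer pipeline can be illustrated with a pure-Python simulation. This is only a sketch: `ring_buffer_schedule` is a hypothetical helper, no GPU work is performed, and the slot reuse rule (`block_idx % num_slots`) is an assumption about how slots are recycled.

```python
def ring_buffer_schedule(num_blocks: int, num_slots: int):
    """Return the ordered list of (event, block_idx, slot) for a ring-buffer pipeline."""
    events = []
    # Preload: fill up to num_slots slots before any compute starts.
    for i in range(min(num_slots, num_blocks)):
        events.append(("load", i, i))
    for block_idx in range(num_blocks):
        slot = block_idx % num_slots
        events.append(("compute", block_idx, slot))  # wait + attention + merge
        nxt = block_idx + num_slots                   # block that reuses this slot
        if nxt < num_blocks:
            events.append(("load", nxt, slot))
    return events

for ev in ring_buffer_schedule(num_blocks=5, num_slots=2):
    print(ev)
```

In this schedule each `load` is the part that stays outside a CUDA graph, while each `compute` step is the candidate for graph replay.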
---
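## Discovery 4: Verifying that graph reuse is feasible
**Test**: `tests/test_chunk_attention_graph_reuse.py`
**Conclusions**:
- Only 2 graphs are needed (causal + non-causal)
- Static tensors are updated via `copy_()`
- The graphs can be reused across all layers and all chunk pairs
**Test results**: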
```
Layer 0: max_diff=3.91e-03 ✅
Layer 1: max_diff=7.81e-03 ✅
Layer 2: max_diff=3.91e-03 ✅
✅ PASSED
```
---
## Discovery 5: Relationship between chunk size and block size
**Observations**:
- KV size of prefilled blocks = `block_size`
- KV size of the decode buffer = `1` to `block_size` (dynamic)
**Graph strategy**:
- Prefilled blocks: fixed size = block_size, suitable for graphs
- Decode buffer: dynamic size, recommended to stay eager
---
## Discovery 6: Triton operators used
**File**: `nanovllm/ops/chunked_attention.py`
| Operator | Function | Graph-capturable |
|------|------|----------|
| `flash_attn_with_lse()` | Attention + LSE | ✅ |
| `merge_attention_outputs()` | Merges two attention outputs | ✅ |
Both operators are pure GPU computation and can be captured by CUDA Graph.
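To show why returning the LSE makes partial results mergeable, here is a scalar pure-Python sketch of the standard log-sum-exp merge. It is not the repository's Triton implementation: `attn_with_lse` and `merge` are hypothetical stand-ins that operate on plain lists of scores and scalar values.

```python
import math

def attn_with_lse(scores, values):
    """Softmax-weighted average of scalar values plus the log-sum-exp of the scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    out = sum(e * v for e, v in zip(exps, values)) / total
    return out, m + math.log(total)

def merge(o1, lse1, o2, lse2):
    """Combine two partial results computed over disjoint key sets."""
    m = max(lse1, lse2)
    lse = m + math.log(math.exp(lse1 - m) + math.exp(lse2 - m))
    return o1 * math.exp(lse1 - lse) + o2 * math.exp(lse2 - lse), lse

scores = [0.3, -1.2, 2.0, 0.7, 1.1]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
full, full_lse = attn_with_lse(scores, values)
o1, l1 = attn_with_lse(scores[:2], values[:2])
o2, l2 = attn_with_lse(scores[2:], values[2:])
merged, merged_lse = merge(o1, l1, o2, l2)
print(math.isclose(full, merged), math.isclose(full_lse, merged_lse))
```

Because the merge needs only the two outputs and their LSE values, blocks can be processed one at a time as they arrive from the CPU without revisiting earlier keys.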
---
## Discovery 7: Data dependency analysis
**Attention inputs**:
- `q`: from the current layer's QKV projection; shape is fixed
- `k, v`: from a GPU slot (after the H2D transfer); shape = [1, block_size, heads, dim]
**Dependency chain**:
```
H2D(block) → wait() → get_kv() → copy_to_static() → graph.replay() → clone_output()
```
**Key point**: the graph wraps only the attention computation and does not include data transfers.

@@ -48,7 +48,7 @@ class Config:
     # XAttention BSA specific parameters
     sparse_block_size: int = 128 # Block size for BSA (tokens per block)
     sparse_samples_per_chunk: int = 128 # Samples per chunk for estimation
-    sparse_threshold: float = 0.9 # Cumulative attention threshold (0-1)
+    sparse_threshold: float = 0.95 # Cumulative attention threshold (tau in XAttention)
     sparse_use_triton: bool = True # Use Triton kernels for estimation
     sparse_stride: int = 8 # Stride for Q/K downsampling

@@ -37,6 +37,11 @@ class FullAttentionPolicy(SparsePolicy):
     supports_prefill = True
     supports_decode = True
 
+    def __init__(self):
+        """Initialize with statistics tracking."""
+        self._stats_total_blocks = 0
+        self._stats_num_chunks = 0
+
     def select_blocks(
         self,
         available_blocks: List[int],
@@ -44,8 +49,33 @@ class FullAttentionPolicy(SparsePolicy):
         ctx: PolicyContext,
     ) -> List[int]:
         """Return all blocks - no sparsity."""
+        # Update statistics (only for layer 0 to avoid overcounting)
+        if ctx.layer_id == 0 and available_blocks:
+            self._stats_total_blocks += len(available_blocks)
+            self._stats_num_chunks += 1
+            logger.debug(f"[Full] chunk={ctx.query_chunk_idx}, blocks={len(available_blocks)}, density=100.0%")
         return available_blocks
 
+    def reset_stats(self) -> None:
+        """Reset density statistics."""
+        self._stats_total_blocks = 0
+        self._stats_num_chunks = 0
+
+    def get_density_stats(self) -> dict:
+        """Get density statistics."""
+        return {
+            "total_available_blocks": self._stats_total_blocks,
+            "total_selected_blocks": self._stats_total_blocks, # Full = all selected
+            "num_chunks": self._stats_num_chunks,
+            "overall_density": 1.0, # Always 100%
+        }
+
+    def print_density_stats(self) -> None:
+        """Print density statistics summary."""
+        stats = self.get_density_stats()
+        logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
+                    f"blocks={stats['total_available_blocks']}, density=100.0%")
+
     def compute_chunked_prefill(
         self,
         q: torch.Tensor,
@@ -58,16 +88,17 @@ class FullAttentionPolicy(SparsePolicy):
         current_chunk_idx: int,
         seq: "Sequence",
         num_tokens: int,
+        selected_blocks: List[int],
     ) -> torch.Tensor:
         """
         Compute full attention for chunked prefill.
 
-        This method handles the complete chunked prefill flow:
-        1. Get historical blocks
-        2. Select blocks via select_blocks
-        3. Load and compute attention to historical chunks
-        4. Compute attention to current chunk
-        5. Merge all results
+        This method handles the chunked prefill computation:
+        1. Load and compute attention to historical chunks (using selected_blocks)
+        2. Compute attention to current chunk
+        3. Merge all results
+
+        Note: Block selection is done by the caller before invoking this method.
 
         Args:
             q: Query tensor [seq_len, num_heads, head_dim]
@@ -80,6 +111,7 @@ class FullAttentionPolicy(SparsePolicy):
             current_chunk_idx: Current chunk index
             seq: Sequence object
             num_tokens: Number of tokens in current chunk
+            selected_blocks: List of CPU block IDs to process (already filtered)
 
         Returns:
             Attention output [seq_len, num_heads, head_dim]
@@ -87,30 +119,16 @@ class FullAttentionPolicy(SparsePolicy):
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, " logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}") f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}, "
f"selected_blocks={len(selected_blocks)}")
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim] q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
o_acc = None o_acc = None
lse_acc = None lse_acc = None
compute_stream = offload_engine.compute_stream compute_stream = offload_engine.compute_stream
# Step 1: Get historical blocks # Use the pre-selected blocks directly
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) cpu_block_table = selected_blocks
# Step 2: Apply select_blocks to filter blocks
if cpu_block_table:
num_chunks = current_chunk_idx + 1
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=layer_id,
query=None, # Prefill typically doesn't use query for selection
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks")
if cpu_block_table: if cpu_block_table:
load_slots = list(range(offload_engine.num_ring_slots)) load_slots = list(range(offload_engine.num_ring_slots))
@@ -200,16 +218,17 @@ class FullAttentionPolicy(SparsePolicy):
offload_engine: "OffloadEngine", offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager", kvcache_manager: "KVCacheManager",
seq: "Sequence", seq: "Sequence",
selected_blocks: List[int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Compute full attention for chunked decode. Compute full attention for chunked decode.
This method handles the complete chunked decode flow: This method handles the chunked decode computation:
1. Get prefilled CPU blocks 1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
2. Apply select_blocks for block filtering 2. Read accumulated decode tokens from decode buffer
3. Load blocks via pipeline (ring buffer or cross-layer) 3. Merge all results
4. Read accumulated decode tokens from decode buffer
5. Merge all results Note: Block selection is done by the caller before invoking this method.
Args: Args:
q: Query tensor [batch_size, num_heads, head_dim] q: Query tensor [batch_size, num_heads, head_dim]
@@ -218,6 +237,7 @@ class FullAttentionPolicy(SparsePolicy):
offload_engine: OffloadEngine for loading blocks offload_engine: OffloadEngine for loading blocks
kvcache_manager: KVCacheManager for block management kvcache_manager: KVCacheManager for block management
seq: Sequence object seq: Sequence object
selected_blocks: List of CPU block IDs to process (already filtered)
Returns: Returns:
Attention output [batch_size, 1, num_heads, head_dim] Attention output [batch_size, 1, num_heads, head_dim]
@@ -227,40 +247,35 @@ class FullAttentionPolicy(SparsePolicy):
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence) # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim] q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
# Get only PREFILLED CPU blocks (exclude the current decode block) # Use the pre-selected blocks directly
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) cpu_block_table = selected_blocks
if layer_id == 0: if layer_id == 0:
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}") logger.debug(f"Decode attention: selected_blocks={len(selected_blocks)}, seq.block_table={list(seq.block_table)}")
if not cpu_block_table: if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available") raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Calculate valid tokens in the last CPU block # Calculate valid tokens in the last CPU block
# CRITICAL: Use original prefill length, not current seq length! # CRITICAL: Use original prefill length, not current seq length!
# CPU blocks are fixed after prefill, their content doesn't change during decode. # CPU blocks are fixed after prefill, their content doesn't change during decode.
# Note: We need to get all prefilled blocks to determine last_block_valid_tokens
block_size = kvcache_manager.block_size block_size = kvcache_manager.block_size
num_prefill_blocks = len(cpu_block_table) all_prefilled_blocks = kvcache_manager.get_prefilled_cpu_blocks(seq)
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
last_block_valid_tokens = total_prefill_tokens % block_size last_block_valid_tokens = total_prefill_tokens % block_size
if last_block_valid_tokens == 0 and total_prefill_tokens > 0: if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
last_block_valid_tokens = block_size # Last block was exactly full last_block_valid_tokens = block_size # Last block was exactly full
# Apply sparse policy (self) for block filtering # Determine if selected_blocks contains the last prefilled block
policy_ctx = PolicyContext( # If not, all selected blocks are full blocks (use block_size as valid tokens)
query_chunk_idx=0, last_prefilled_block = all_prefilled_blocks[-1] if all_prefilled_blocks else None
num_query_chunks=1, selected_contains_last = (cpu_block_table and cpu_block_table[-1] == last_prefilled_block)
layer_id=layer_id, effective_last_block_tokens = last_block_valid_tokens if selected_contains_last else block_size
query=q_batched,
is_prefill=False,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
# Use ring buffer pipeline for loading prefilled blocks # Use ring buffer pipeline for loading prefilled blocks
load_slots = offload_engine.decode_load_slots load_slots = offload_engine.decode_load_slots
o_acc, lse_acc = self._decode_ring_buffer_pipeline( o_acc, lse_acc = self._decode_ring_buffer_pipeline(
q_batched, cpu_block_table, load_slots, offload_engine, q_batched, cpu_block_table, load_slots, offload_engine,
block_size, last_block_valid_tokens, layer_id, softmax_scale block_size, effective_last_block_tokens, layer_id, softmax_scale
) )
# Now attend to accumulated decode tokens from per-layer decode buffer # Now attend to accumulated decode tokens from per-layer decode buffer

@@ -204,17 +204,20 @@ class SparsePolicy(ABC):
current_chunk_idx: int, current_chunk_idx: int,
seq: "Sequence", seq: "Sequence",
num_tokens: int, num_tokens: int,
selected_blocks: List[int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Compute chunked prefill attention (complete flow). Compute chunked prefill attention (complete flow).
This is the main entry point for prefill attention computation. This is the main entry point for prefill attention computation.
It defines the complete prefill flow: It defines the complete prefill flow:
1. Get historical blocks 1. Load and compute historical blocks via offload_engine (using selected_blocks)
2. Select blocks (call select_blocks) 2. Get current chunk KV from offload_engine, compute attention
3. Load and compute historical blocks via offload_engine 3. Merge all results
4. Get current chunk KV from offload_engine, compute attention
5. Merge all results Note: Block selection (select_blocks) is called by the caller (attention.py)
before invoking this method. The selected_blocks parameter contains the
filtered block IDs to process.
Args: Args:
q: [seq_len, num_heads, head_dim] query for current chunk q: [seq_len, num_heads, head_dim] query for current chunk
@@ -227,6 +230,7 @@ class SparsePolicy(ABC):
current_chunk_idx: current chunk index current_chunk_idx: current chunk index
seq: Sequence object seq: Sequence object
num_tokens: number of tokens in current chunk num_tokens: number of tokens in current chunk
selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
Returns: Returns:
[seq_len, num_heads, head_dim] final attention output [seq_len, num_heads, head_dim] final attention output
@@ -242,17 +246,20 @@ class SparsePolicy(ABC):
offload_engine: "OffloadEngine", offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager", kvcache_manager: "KVCacheManager",
seq: "Sequence", seq: "Sequence",
selected_blocks: List[int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Compute chunked decode attention (complete flow). Compute chunked decode attention (complete flow).
This is the main entry point for decode attention computation. This is the main entry point for decode attention computation.
It defines the complete decode flow: It defines the complete decode flow:
1. Get prefilled blocks from CPU 1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
2. Select blocks (call select_blocks) 2. Read accumulated decode tokens from decode buffer
3. Load blocks via pipeline (ring buffer or cross-layer) 3. Merge all results
4. Read accumulated decode tokens from decode buffer
5. Merge all results Note: Block selection (select_blocks) is called by the caller (attention.py)
before invoking this method. The selected_blocks parameter contains the
filtered block IDs to process.
The decode position information can be computed internally: The decode position information can be computed internally:
- decode_start_pos = kvcache_manager.get_decode_start_pos(seq) - decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
@@ -265,6 +272,7 @@ class SparsePolicy(ABC):
offload_engine: OffloadEngine for loading blocks offload_engine: OffloadEngine for loading blocks
kvcache_manager: KVCacheManager for block management kvcache_manager: KVCacheManager for block management
seq: Sequence object seq: Sequence object
selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
Returns: Returns:
[batch_size, 1, num_heads, head_dim] final attention output [batch_size, 1, num_heads, head_dim] final attention output

@@ -2,69 +2,508 @@
XAttention Block Sparse Attention (BSA) Policy for nano-vllm. XAttention Block Sparse Attention (BSA) Policy for nano-vllm.
This module implements XAttention-inspired block sparse attention for chunked prefill. This module implements XAttention-inspired block sparse attention for chunked prefill.
Current implementation loads all historical blocks (FULL strategy).
Sparse selection to be implemented in next phase. Key design:
1. Use xattn_estimate_chunked to estimate sparse block mask
2. Use BSA kernel for efficient sparse attention computation
3. Support chunked prefill with q_start_pos for correct position handling
Note: Decode phase is not supported - use FullAttentionPolicy for decode.
""" """
import logging
import torch import torch
from typing import List, Optional, Tuple from typing import List, Tuple, TYPE_CHECKING
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.utils.context import get_context
if TYPE_CHECKING:
from nanovllm.kvcache.offload_engine import OffloadEngine
from nanovllm.kvcache.manager import KVCacheManager
from nanovllm.engine.sequence import Sequence
logger = logging.getLogger(__name__)
# Check BSA availability
try:
from block_sparse_attn import block_sparse_attn_func
BSA_AVAILABLE = True
except ImportError:
BSA_AVAILABLE = False
logger.warning("block_sparse_attn not available, XAttentionBSAPolicy will fallback to dense")
# Check xattn_estimate_chunked availability
try:
from nanovllm.ops.xattn import xattn_estimate_chunked
XATTN_AVAILABLE = True
except ImportError:
XATTN_AVAILABLE = False
logger.warning("xattn_estimate_chunked not available")
def expand_kv_for_gqa(
key_states: torch.Tensor,
value_states: torch.Tensor,
num_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Expand KV for Grouped Query Attention.
Args:
key_states: [B, num_kv_heads, seq_len, head_dim]
value_states: [B, num_kv_heads, seq_len, head_dim]
num_heads: Number of query heads
Returns:
Expanded (key, value) with shape [B, num_heads, seq_len, head_dim]
"""
num_kv_heads = key_states.shape[1]
if num_heads == num_kv_heads:
return key_states, value_states
num_groups = num_heads // num_kv_heads
return (
key_states.repeat_interleave(num_groups, dim=1),
value_states.repeat_interleave(num_groups, dim=1),
)
class XAttentionBSAPolicy(SparsePolicy): class XAttentionBSAPolicy(SparsePolicy):
""" """
XAttention Block Sparse Attention policy for chunked prefill. XAttention Block Sparse Attention policy for chunked prefill.
This policy uses block-level estimation to determine which KV blocks Uses xattn_estimate_chunked to estimate sparse mask, then BSA kernel
are important for the current chunk's queries, enabling sparse computation. for efficient sparse attention computation.
Note: Current implementation loads all historical chunks (FULL strategy). Note:
Sparse selection to be implemented in next phase. - Only supports prefill phase (decode uses FullAttentionPolicy)
- BSA block size is fixed at 128 tokens
""" """
supports_prefill = False # Uses standard select_blocks interface supports_prefill = True
supports_decode = False # BSA is prefill-only supports_decode = False # Decode uses FullAttentionPolicy
requires_block_selection = False # Selection happens at chunk level, not block level requires_block_selection = False # Selection happens internally
# BSA requires 128-token blocks
BSA_BLOCK_SIZE = 128
def __init__( def __init__(
self, self,
threshold: float = 0.95, # High threshold for accuracy testing
stride: int = 8,
chunk_size: int = 16384,
block_size: int = 128, block_size: int = 128,
samples_per_chunk: int = 128, samples_per_chunk: int = 128,
threshold: float = 0.9, use_triton: bool = True,
): ):
""" """
Initialize XAttention BSA policy. Initialize XAttention BSA policy.
Args: Args:
block_size: Number of tokens per block (default: 128) threshold: Cumulative attention threshold for block selection (0-1)
samples_per_chunk: Number of tokens to sample from each historical chunk for estimation Higher values = more blocks selected = less sparse
threshold: Cumulative attention threshold for chunk selection (0-1) stride: Stride for Q/K reshape in estimation (typically 8)
chunk_size: Processing chunk size for xattn_estimate (Triton alignment)
block_size: BSA block size (must be 128)
samples_per_chunk: Samples per chunk for estimation (unused)
use_triton: Whether to use Triton kernels
""" """
self.block_size = block_size
self.samples_per_chunk = samples_per_chunk
self.threshold = threshold self.threshold = threshold
self.stride = stride
self.chunk_size = chunk_size
self.use_triton = use_triton
self._num_heads = None # Set during first forward
def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]: # Sparse metadata: stores attention scores per layer
# Dict[layer_id, Tensor[num_q_blocks, num_k_blocks]]
self.sparse_metadata: dict = {}
# Statistics for density tracking
self._stats_total_available_blocks = 0
self._stats_total_selected_blocks = 0
self._stats_num_chunks = 0
def select_blocks(
self,
available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext,
) -> List[int]:
""" """
Select blocks to load from CPU. Compute attention scores for all available blocks using flat_group_gemm,
then use softmax_fuse_block_sum and find_blocks_chunked to select important blocks.
Current implementation returns all blocks (FULL strategy). This method:
Sparse selection to be implemented in next phase. 1. Loads each K block from CPU
2. Computes Q@K^T attention scores using XAttention stride reshape
3. Applies softmax_fuse_block_sum to get block-level attention
4. Uses find_blocks_chunked to select blocks based on threshold
Args: Args:
available_blocks: List of all available CPU block IDs available_blocks: List of CPU block IDs
ctx: Policy context with query info, chunk index, etc. offload_engine: OffloadEngine for loading blocks
ctx: PolicyContext with query tensor and metadata
Returns: Returns:
List of selected block IDs to load Selected block IDs based on attention threshold
""" """
# Current: Return all blocks (FULL strategy) if not available_blocks or ctx.query is None:
# TODO: Implement sparse selection based on query attention estimation
return available_blocks return available_blocks
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum, find_blocks_chunked
import math
layer_id = ctx.layer_id
q = ctx.query # [seq_len, num_heads, head_dim]
# Convert Q to [batch, heads, seq_len, head_dim]
# q: [seq_len, num_heads, head_dim] -> [1, num_heads, seq_len, head_dim]
Q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim]
num_heads = Q.shape[1]
head_dim = Q.shape[3]
q_len = Q.shape[2]
# flat_group_gemm requires q_len to be divisible by stride * BLOCK_M (typically 8 * 128 = 1024)
# Pad Q if necessary
BLOCK_M = 128 # Triton block size
alignment = self.stride * BLOCK_M
if q_len < alignment:
# Q too short, skip estimation and return all blocks
logger.debug(f"[XAttn] select_blocks: q_len={q_len} < alignment={alignment}, skipping estimation")
return available_blocks
# Pad Q to alignment
padded_q_len = ((q_len + alignment - 1) // alignment) * alignment
if padded_q_len != q_len:
pad_size = padded_q_len - q_len
Q = torch.nn.functional.pad(Q, (0, 0, 0, pad_size), value=0)
q_reshaped_len = padded_q_len // self.stride
# Use a single slot for loading (synchronous mode for simplicity)
slot = 0
attn_scores_list = []
# Get block size from context
block_size = ctx.block_size # tokens per CPU block (e.g., 1024)
reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128
for cpu_block_id in available_blocks:
# Load K block from CPU to GPU
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
# Get KV: [1, block_size, num_kv_heads, head_dim]
k_block, _ = offload_engine.get_kv_for_slot(slot)
# Convert K to [batch, heads, k_len, head_dim]
# k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim]
K_chunk = k_block.transpose(1, 2)
# Handle GQA: expand K heads to match Q heads
num_kv_heads = K_chunk.shape[1]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
# Pad K if necessary (k_len must be divisible by stride * BLOCK_N)
k_len = K_chunk.shape[2]
BLOCK_N = 128
k_alignment = self.stride * BLOCK_N
if k_len < k_alignment:
# K too short, pad it
pad_size = k_alignment - k_len
K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
# Compute attention scores using flat_group_gemm_fuse_reshape
# Output: [batch, heads, q_len/stride, k_len/stride]
attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_chunk, self.stride,
chunk_start=0,
chunk_end=q_reshaped_len,
is_causal=False
)
attn_scores_list.append(attn_chunk)
# Mark slot as done for reuse
offload_engine.record_slot_compute_done(slot)
# Concatenate all attention scores along K dimension
# Each chunk: [1, heads, q_reshaped_len, block_reshaped_len]
# Result: [1, heads, q_reshaped_len, total_k_reshaped_len]
if not attn_scores_list:
return available_blocks
attn_scores = torch.cat(attn_scores_list, dim=-1)
# Free intermediate list immediately
del attn_scores_list
# Step 2: Apply softmax_fuse_block_sum to get block-level attention
# block_size = reshaped_block_size so each CPU block maps to exactly 1 output block
# This ensures block_sums.shape[-1] == num_available_blocks (1:1 mapping)
norm = 1.0 # Normalization factor
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm # log2(e) with scaling
segment_size = min(4096, reshaped_block_size)
block_sums = softmax_fuse_block_sum(
attn_scores,
reshaped_block_size, # Use CPU block size in reshaped space (1024/8=128)
segment_size,
chunk_start=0,
chunk_end=q_reshaped_len,
real_q_len=q_reshaped_len,
scale=scale,
is_causal=False, # Historical blocks are all before current chunk
)
# block_sums shape: [batch, heads, q_blocks, k_blocks]
# where k_blocks == len(available_blocks) (1:1 mapping with CPU blocks)
# Step 3: Use find_blocks_chunked to get selection mask
# current_index = 0 since we're looking at historical blocks only
mask = find_blocks_chunked(
block_sums,
current_index=0,
threshold=self.threshold,
num_to_choose=None,
decoding=False,
mode="prefill",
causal=False, # Historical blocks don't need causal mask
)
# mask shape: [batch, num_heads, q_blocks, k_blocks] - boolean
# where k_blocks == len(available_blocks)
# GQA-aware aggregation:
# For GQA, multiple Q heads share one KV head. We need to select a block
# if ANY Q head within the same KV head group selects it.
# mask: [batch, num_heads, q_blocks, k_blocks]
# Reshape to [batch, num_kv_heads, num_groups, q_blocks, k_blocks]
batch_size, num_q_heads, q_blocks, k_blocks = mask.shape
# num_kv_heads was set in the K loading loop above (line ~199)
# num_groups = num_heads // num_kv_heads (for GQA)
num_groups = num_heads // num_kv_heads if num_heads != num_kv_heads else 1
if num_groups > 1:
# Reshape: [batch, num_kv_heads, num_groups, q_blocks, k_blocks]
mask_gqa = mask.view(batch_size, num_kv_heads, num_groups, q_blocks, k_blocks)
# Aggregate within each KV head group: any Q head selects -> KV head selects
mask_per_kv_head = mask_gqa.any(dim=2) # [batch, num_kv_heads, q_blocks, k_blocks]
else:
mask_per_kv_head = mask # [batch, num_heads, q_blocks, k_blocks]
# Aggregate across KV heads and q_blocks using majority voting
# Instead of any(), use voting: select if >50% of kv_heads select it
# mask_per_kv_head: [batch, num_kv_heads, q_blocks, k_blocks]
# Sum across kv_heads and q_blocks to get vote count per k_block
vote_count = mask_per_kv_head[0].float().sum(dim=0).sum(dim=0) # [k_blocks]
total_votes = num_kv_heads * q_blocks
vote_ratio = vote_count / total_votes
# Select blocks with >50% votes (majority voting)
vote_threshold = 0.5
block_selected = vote_ratio > vote_threshold
selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel]
# Always include first block (sink) and last block for safety
if available_blocks and available_blocks[0] not in selected_block_ids:
selected_block_ids.insert(0, available_blocks[0])
if available_blocks and available_blocks[-1] not in selected_block_ids:
selected_block_ids.append(available_blocks[-1])
# Update statistics (only for layer 0 to avoid overcounting)
if layer_id == 0 and available_blocks:
self._stats_total_available_blocks += len(available_blocks)
self._stats_total_selected_blocks += len(selected_block_ids)
self._stats_num_chunks += 1
# Log per-chunk density
chunk_density = len(selected_block_ids) / len(available_blocks)
logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, available={len(available_blocks)}, "
f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")
# Free intermediate tensors to prevent memory leak
del attn_scores, block_sums, mask, mask_per_kv_head, vote_count, vote_ratio, block_selected
return selected_block_ids
def compute_chunked_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
current_chunk_idx: int,
seq: "Sequence",
num_tokens: int,
selected_blocks: List[int],
) -> torch.Tensor:
"""
Compute attention for chunked prefill using XAttention sparse block selection.
This method handles the chunked prefill computation:
1. Load and compute attention to historical chunks (using selected_blocks)
2. Compute attention to current chunk
3. Merge all results
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
v: Value tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
layer_id: Current layer index
softmax_scale: Softmax scaling factor
offload_engine: OffloadEngine for loading blocks
kvcache_manager: KVCacheManager for block management
current_chunk_idx: Current chunk index
seq: Sequence object
num_tokens: Number of tokens in current chunk
selected_blocks: List of CPU block IDs selected by select_blocks
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
o_acc = None
lse_acc = None
compute_stream = offload_engine.compute_stream
# Use the pre-selected blocks directly
cpu_block_table = selected_blocks
if cpu_block_table:
load_slots = list(range(offload_engine.num_ring_slots))
num_blocks = len(cpu_block_table)
if len(load_slots) == 1:
# Only 1 slot - use synchronous mode
slot = load_slots[0]
for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=softmax_scale,
causal=False,
)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
offload_engine.record_slot_compute_done(slot)
else:
# Multiple slots - use pipeline
num_slots = len(load_slots)
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots]
offload_engine.wait_slot_layer(current_slot)
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=softmax_scale,
causal=False,
)
offload_engine.record_slot_compute_done(current_slot)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Issue next transfer
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
next_slot = load_slots[next_block_idx % num_slots]
next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
# Compute attention to current chunk (causal mask)
with torch.cuda.stream(compute_stream):
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(
q_batched, k_curr, v_curr,
softmax_scale=softmax_scale,
causal=True,
)
# Merge historical and current attention
with torch.cuda.stream(compute_stream):
if o_acc is None:
final_o = current_o
else:
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
# Sync default stream with compute_stream before returning
torch.cuda.default_stream().wait_stream(compute_stream)
# Remove batch dimension: [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim]
return final_o.squeeze(0)
def compute_chunked_decode(
self,
q: torch.Tensor,
layer_id: int,
softmax_scale: float,
offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
seq: "Sequence",
selected_blocks: List[int],
) -> torch.Tensor:
"""
XAttention does not support decode phase.
"""
raise NotImplementedError(
"XAttentionBSAPolicy does not support decode phase. "
"Use FullAttentionPolicy for decode."
)
 def reset(self) -> None:
-"""Reset policy state."""
-pass
+"""Reset policy state and clear sparse metadata."""
+self.sparse_metadata.clear()
# Don't reset statistics here - they accumulate across the entire prefill
def reset_stats(self) -> None:
"""Reset density statistics."""
self._stats_total_available_blocks = 0
self._stats_total_selected_blocks = 0
self._stats_num_chunks = 0
def get_density_stats(self) -> dict:
"""Get density statistics."""
if self._stats_total_available_blocks == 0:
return {
"total_available_blocks": 0,
"total_selected_blocks": 0,
"num_chunks": 0,
"overall_density": 0.0,
}
return {
"total_available_blocks": self._stats_total_available_blocks,
"total_selected_blocks": self._stats_total_selected_blocks,
"num_chunks": self._stats_num_chunks,
"overall_density": self._stats_total_selected_blocks / self._stats_total_available_blocks,
}
def print_density_stats(self) -> None:
"""Print density statistics summary."""
stats = self.get_density_stats()
logger.info(f"[XAttn BSA] Density Stats: chunks={stats['num_chunks']}, "
f"available={stats['total_available_blocks']}, "
f"selected={stats['total_selected_blocks']}, "
f"density={stats['overall_density']:.1%}")
def __repr__(self) -> str:
return f"XAttentionBSAPolicy(threshold={self.threshold}, stride={self.stride})"
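The prefill path above folds per-block results together with `merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)`, whose body is not part of this diff. The sketch below shows the standard log-sum-exp merge that such a function is assumed to perform, reduced to scalar queries and values in plain Python. `attend` and `merge` are hypothetical names for illustration, not the nanovllm API.

```python
import math

def attend(q, keys, values):
    """Softmax attention for one scalar query; returns (output, log-sum-exp)."""
    scores = [q * k for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    out = sum(w * v for w, v in zip(weights, values)) / denom
    return out, m + math.log(denom)

def merge(o1, lse1, o2, lse2):
    """Combine two partial attention results computed over disjoint KV blocks."""
    m = max(lse1, lse2)
    w1 = math.exp(lse1 - m)  # relative softmax mass of the first block
    w2 = math.exp(lse2 - m)  # relative softmax mass of the second block
    out = (w1 * o1 + w2 * o2) / (w1 + w2)
    return out, m + math.log(w1 + w2)

keys = [0.3, -1.2, 0.8, 2.0, -0.5]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
q = 0.7

# Attention over all keys at once versus two blocks merged afterwards.
full_o, full_lse = attend(q, keys, values)
o_a, lse_a = attend(q, keys[:2], values[:2])
o_b, lse_b = attend(q, keys[2:], values[2:])
merged_o, merged_lse = merge(o_a, lse_a, o_b, lse_b)
```

Because the merge only needs each block's output and its log-sum-exp, blocks can be processed in any order and one at a time, which is what lets the ring-buffer loop keep a single running `(o_acc, lse_acc)` pair.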


@@ -5,6 +5,7 @@ from torch import nn
 from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 from nanovllm.utils.context import get_context
+from nanovllm.kvcache.sparse.policy import PolicyContext
 logger = logging.getLogger(__name__)
@@ -197,11 +198,30 @@ class Attention(nn.Module):
 if sparse_policy is None:
 raise RuntimeError("sparse_policy is required for chunked prefill")
+# Step 1: Get historical CPU blocks
+cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
+# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
+selected_blocks = []
+if cpu_block_table:
+num_chunks = current_chunk_idx + 1
+policy_ctx = PolicyContext(
+query_chunk_idx=current_chunk_idx,
+num_query_chunks=num_chunks,
+layer_id=self.layer_id,
+query=q, # Pass query for sparse policies that need it
+is_prefill=True,
+block_size=kvcache_manager.block_size,
+total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
+)
+selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
+logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
 # [DEBUG] Verify execution path
 logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
 f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
-# Delegate all computation to policy (no flash_attn or merge calls here!)
+# Delegate computation to policy with pre-selected blocks
 final_o = sparse_policy.compute_chunked_prefill(
 q, k, v,
 self.layer_id,
@@ -211,6 +231,7 @@ class Attention(nn.Module):
 current_chunk_idx,
 seq,
 num_tokens,
+selected_blocks,
 )
 torch.cuda.nvtx.range_pop() # ChunkedPrefill
@@ -258,14 +279,36 @@ class Attention(nn.Module):
 raise RuntimeError("sparse_policy is required for chunked decode")
 # Check if policy supports decode phase
+# If not, fallback to FullAttentionPolicy (e.g., XAttentionBSAPolicy only supports prefill)
 if not sparse_policy.supports_decode:
-raise RuntimeError(f"{sparse_policy} does not support decode phase")
+from nanovllm.kvcache.sparse import FullAttentionPolicy
+sparse_policy = FullAttentionPolicy()
+logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
+f"falling back to FullAttentionPolicy")
+# Step 1: Get prefilled CPU blocks
+cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
+# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_decode)
+selected_blocks = []
+if cpu_block_table:
+policy_ctx = PolicyContext(
+query_chunk_idx=0,
+num_query_chunks=1,
+layer_id=self.layer_id,
+query=q, # Pass query for sparse policies that need it
+is_prefill=False,
+block_size=kvcache_manager.block_size,
+total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
+)
+selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
+logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
 # [DEBUG] Verify execution path
 logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
 f"policy={sparse_policy}, layer={self.layer_id}")
-# Delegate all computation to policy (no flash_attn or merge calls here!)
+# Delegate computation to policy with pre-selected blocks
 return sparse_policy.compute_chunked_decode(
 q,
 self.layer_id,
@@ -273,4 +316,5 @@ class Attention(nn.Module):
 offload_engine,
 kvcache_manager,
 seq,
+selected_blocks,
 )


@@ -419,7 +419,9 @@ def flat_group_gemm_fuse_reshape(
 assert key_states.shape[1] == num_heads
 assert key_states.shape[3] == head_dim
-output = torch.empty(
+# Use zeros instead of empty to handle causal early-exit in kernel
+# (some blocks may not be written due to causal mask optimization)
+output = torch.zeros(
 (batch_size, num_heads, q_len // stride, kv_len // stride),
 dtype=query_states.dtype,
 device=query_states.device
@@ -1067,6 +1069,7 @@ def xattn_estimate_chunked(
 )
 # Softmax + block sum
+# segment_size should match the standard xattn_estimate for consistency
 attn_sum = softmax_fuse_block_sum(
 attn_weights,
 reshaped_block_size,
@@ -1082,6 +1085,14 @@ def xattn_estimate_chunked(
 attn_sum = attn_sum[:, :, :q_block_num, :k_block_num]
 else:
 # PyTorch fallback implementation
+# Match Triton kernel exactly for consistency
+#
+# Triton uses:
+# 1. exp2 (base-2 exponential) for softmax
+# 2. scale factor includes log2(e) = 1.4426950408889634
+# 3. causal mask: q_pos >= k_pos (not q_pos + 1 > k_pos)
+# 4. chunk_start for global Q position tracking
 # Reshape K: interleave positions and concatenate head dims
 reshaped_key = torch.cat(
 [(key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
@@ -1093,49 +1104,58 @@ def xattn_estimate_chunked(
 dim=-1,
 )
+# Use same scale as Triton: includes log2(e) for exp2 compatibility
+# Triton: scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
+scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
+# Convert to float32 for numerical stability (matching Triton)
+reshaped_query_f32 = reshaped_query.to(torch.float32)
+reshaped_key_f32 = reshaped_key.to(torch.float32)
 # Compute attention weights: (B, H, q_len/stride, k_len/stride)
 attn_weights = torch.matmul(
-reshaped_query, reshaped_key.transpose(2, 3)
-) / math.sqrt(head_dim) / stride / norm
-# Apply causal mask
+reshaped_query_f32, reshaped_key_f32.transpose(2, 3)
+) * scale
+# Apply causal mask (matching Triton's logic exactly)
 if causal:
-reshaped_q_positions = reshaped_q_len
-causal_mask = torch.zeros(
-(batch_size, num_heads, reshaped_q_positions, reshaped_k_len),
-device=key_states.device,
-dtype=attn_weights.dtype,
+# Triton uses: offs_q = chunk_start + block_id * block_size + arange(0, block_size)
+# chunk_start = q_start_block * reshaped_block_size
+chunk_start = q_start_block * reshaped_block_size
+# Create position indices in reshaped space
+q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
+k_positions = torch.arange(reshaped_k_len, device=attn_weights.device)
+# Triton causal mask: q_pos >= k_pos
+causal_mask = q_positions[:, None] >= k_positions[None, :] # (reshaped_q_len, reshaped_k_len)
+# Apply causal mask: set future positions to -1e6 (matching Triton)
+attn_weights = attn_weights.masked_fill(
+~causal_mask.unsqueeze(0).unsqueeze(0), -1e6
 )
-# Mask out padding in K
-if k_pad > 0:
-causal_mask[:, :, :, -(k_pad // stride):] = float("-inf")
-# Mask out future positions
-q_start_reshaped = q_start_pos // stride
-for q_idx in range(reshaped_q_positions):
-q_pos_reshaped = q_start_reshaped + q_idx
-if q_pos_reshaped + 1 < reshaped_k_len:
-causal_mask[:, :, q_idx, q_pos_reshaped + 1:] = float("-inf")
-# Handle padding in Q
-if q_pad > 0:
-q_pad_reshaped = q_pad // stride
-if q_pad_reshaped > 0:
-causal_mask[:, :, -q_pad_reshaped:, :] = float("-inf")
-attn_weights = attn_weights + causal_mask
-# Apply softmax
-attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-# Zero out padded Q positions
-if q_pad > 0:
-q_pad_reshaped = q_pad // stride
-if q_pad_reshaped > 0:
-attn_weights[:, :, -q_pad_reshaped:, :] = 0
-# Aggregate to block level
+# Softmax using exp2 (matching Triton exactly)
+# Triton: X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
+# All computation in float32
+attn_max = attn_weights.max(dim=-1, keepdim=True).values
+attn_weights_shifted = attn_weights - attn_max
+attn_exp2 = torch.exp2(attn_weights_shifted)
+attn_sum_exp2 = attn_exp2.sum(dim=-1, keepdim=True)
+attn_weights = attn_exp2 / attn_sum_exp2
+# Mask for valid Q positions (matching Triton's sum_mask)
+# Triton: sum_mask = offs_q[:, None] < real_q_len
+# real_q_len = chunk_start + valid_q_reshaped
+chunk_start = q_start_block * reshaped_block_size
+real_q_len = chunk_start + valid_q_reshaped
+q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
+valid_q_mask = q_positions < real_q_len # (reshaped_q_len,)
+# Zero out invalid Q positions
+attn_weights = attn_weights * valid_q_mask.view(1, 1, -1, 1).float()
+# Aggregate to block level (keep in float32)
 attn_sum = attn_weights.view(
 batch_size,
 num_heads,
@@ -1145,6 +1165,9 @@ def xattn_estimate_chunked(
 reshaped_block_size,
 ).sum(dim=-1).sum(dim=-2)
+# Convert back to input dtype for consistency
+attn_sum = attn_sum.to(query_states.dtype)
 # Find blocks that exceed threshold
 simple_mask = find_blocks_chunked(
 attn_sum,
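The fallback changes above rest on two facts: folding `log2(e)` into the scale makes a base-2 softmax equal to the usual base-e softmax, and the causal mask must compare global query positions (offset by `chunk_start`) against key positions. A minimal plain-Python sketch of both follows. The function names are hypothetical and only the constant comes from the diff.

```python
import math

LOG2_E = 1.4426950408889634  # log2(e), the constant quoted in the diff

def softmax_exp(scores):
    """Reference softmax using the natural exponential."""
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    z = sum(e)
    return [x / z for x in e]

def softmax_exp2(scores):
    """Same softmax via base-2 exponentials: scores are pre-multiplied by log2(e)."""
    scaled = [s * LOG2_E for s in scores]
    m = max(scaled)
    e = [2.0 ** (s - m) for s in scaled]
    z = sum(e)
    return [x / z for x in e]

def causal_mask(num_q, num_k, chunk_start):
    """True where a query at global position chunk_start + i may attend to key j."""
    return [[(chunk_start + i) >= j for j in range(num_k)] for i in range(num_q)]

scores = [0.25, -1.5, 3.0, 0.0]
p_e = softmax_exp(scores)
p_2 = softmax_exp2(scores)
mask = causal_mask(num_q=2, num_k=4, chunk_start=1)
```

With `chunk_start=1`, the first query row sits at global position 1 and sees keys 0 and 1. Without the offset it would see only key 0, which is the kind of mismatch the position tracking in the diff is meant to avoid.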


@@ -1,55 +0,0 @@
# Progress: CUDA Graph for Offload Mode
## Session: 2026-01-22
### Investigation phase ✅ Complete
**Completed investigation**:
1. ✅ Analyzed the CUDA Graph implementation in `model_runner.py`
- `capture_cudagraph()`: captures the full model forward for different batch sizes
- `run_model()`: chooses eager or graph execution based on `is_chunked_prefill`
2. ✅ Analyzed the offload decode flow
- `run_chunked_offload_decode()` sets `is_chunked_prefill=True`
- As a result, eager mode is always used
3. ✅ Analyzed the ring buffer pipeline
- `_decode_ring_buffer_pipeline()` contains H2D transfers + attention computation
- H2D transfers cannot be captured in a graph; attention can
4. ✅ Verified the graph reuse strategy
- Created `test_chunk_attention_graph_reuse.py`
- Confirmed that 2 graphs can be reused across all layers
### Plan writing ✅ Complete
- ✅ Created `task_plan.md`
- ✅ Created `findings.md`
- ✅ Created `progress.md`
### Next step: Implementation
**Phase 1**: Add graph capture to OffloadEngine
- [ ] Add `capture_attention_graphs()` to `offload_engine.py`
- [ ] Add `attention_graph_causal` and `attention_graph_non_causal` attributes
**Phase 2**: Modify the ring buffer pipeline
- [ ] Use graph replay in `_decode_ring_buffer_pipeline()`
- [ ] Keep H2D and merge eager
**Phase 3**: Testing
- [ ] Run the needle test to verify correctness
- [ ] Compare performance
---
## File List
| File | Status | Description |
|------|------|------|
| `tests/test_chunk_attention_graph.py` | ✅ Committed | Tests pre-allocated chunk-pair graphs |
| `tests/test_chunk_attention_graph_reuse.py` | Not yet committed | Graph reuse verification |
| `task_plan.md` | ✅ Created | Implementation plan |
| `findings.md` | ✅ Created | Investigation findings |
| `progress.md` | ✅ Created | Progress log |


@@ -1,357 +0,0 @@
# Task Plan: CUDA Graph Optimization for Offload Mode Decode
## Goal
Add CUDA Graph support to nanovllm's CPU offload mode to speed up computation in the decode phase.
## Problem Analysis
### Full structure of a Transformer layer
```
Qwen3DecoderLayer.forward:
├── input_layernorm (RMSNorm) # ✅ pure GPU
├── self_attn:
│ ├── qkv_proj (Linear) # ✅ pure GPU
│ ├── q_norm, k_norm (RMSNorm) # ✅ pure GPU
│ ├── rotary_emb # ✅ pure GPU
│ ├── attn._chunked_decode_attention: # ⚠️ includes CPU→GPU
│ │ ├── H2D transfer # ❌ cannot be graphed
│ │ ├── flash_attn_with_lse # ✅ can be graphed
│ │ └── merge # ✅ pure GPU
│ └── o_proj (Linear) # ✅ pure GPU
├── post_attention_layernorm # ✅ pure GPU
└── mlp (FFN: gate, up, down) # ✅ pure GPU
```
**Core problem**: The H2D transfer is embedded in the middle of attention, which breaks graph capture of the whole layer.
### Possible approaches
| Approach | Description | Pros | Cons |
|------|------|------|------|
| A. Segmented graphs | Split the layer into pre-attention and post-attention segments | Broad coverage | Large change; requires splitting layer execution |
| B. Graph attention only | Optimize only flash_attn_with_lse | Small change | Limited gain |
| C. Restructure the execution flow | Rewrite model forward entirely | Best result | Very large amount of work |
### Recommended: Approach A (segmented graphs)
Split each layer into two graphs:
1. **pre_attention_graph**: `norm → qkv_proj → q/k_norm → rotary`
2. **post_attention_graph**: `o_proj → norm → FFN`
The `_chunked_decode_attention` call in between stays eager (it contains the H2D transfer), but the `flash_attn_with_lse` call inside it uses a graph.
---
## Current State Analysis
### Existing CUDA Graph implementation
**File**: `nanovllm/engine/model_runner.py`
| Method | Lines | Purpose |
|------|------|------|
| `capture_cudagraph()` | 682-717 | Captures the full model forward for different batch sizes |
| `run_model()` | 415-436 | Decides between eager execution and graph replay |
**Key logic** (`run_model`):
```python
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
```
**Problem**: `run_chunked_offload_decode` sets `is_chunked_prefill=True`, so **eager mode is always used**.
### Offload Decode flow
**File**: `nanovllm/kvcache/sparse/full_policy.py`
`_decode_ring_buffer_pipeline()` (L304-379):
```
for block in cpu_blocks:
    1. wait_slot_layer(slot) # wait for the H2D transfer to finish
    2. k, v = get_kv_for_slot(slot) # fetch KV
    3. o, lse = flash_attn_with_lse() # ⭐ pure GPU computation
    4. record_slot_compute_done(slot) # mark compute as done
    5. load_next_block() # start the next H2D transfer
    6. merge_attention_outputs() # ⭐ pure GPU computation
```
**Parts that can be graphed**:
- `flash_attn_with_lse()` - pure GPU computation
- Cannot be graphed: H2D transfer, dynamic merge
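The ordering of loads and computes in such a ring-buffer loop can be traced without a GPU. The sketch below is a hypothetical simulation: `ring_buffer_schedule` is not a nanovllm function and it only records the order of operations. It assumes the slot assignment used in the multi-slot prefill path of this diff: preload up to `num_slots` blocks, compute block `i` from slot `i % num_slots`, then reuse that slot for block `i + num_slots`.

```python
def ring_buffer_schedule(num_blocks, num_slots):
    """Return the (op, block, slot) sequence issued by a ring-buffer pipeline."""
    events = []
    # Preload: fill as many slots as there are blocks (at most num_slots).
    for i in range(min(num_slots, num_blocks)):
        events.append(("load", i, i))
    for block_idx in range(num_blocks):
        slot = block_idx % num_slots
        # The real loop waits for the slot's transfer, then runs attention on it.
        events.append(("compute", block_idx, slot))
        # Once the slot is free, start the transfer for the block num_slots ahead.
        next_block = block_idx + num_slots
        if next_block < num_blocks:
            events.append(("load", next_block, next_block % num_slots))
    return events

schedule = ring_buffer_schedule(num_blocks=3, num_slots=2)
```

Each compute is preceded by the load of the same block into the same slot, and a slot is only overwritten after its compute has been issued. The real code enforces that with `wait_slot_layer` and `record_slot_compute_done`.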
## Verification Results
**Test file**: `tests/test_chunk_attention_graph_reuse.py`
| Test | Result |
|------|------|
| 2 graphs reused across all layers and all chunks | ✅ PASSED |
| copy_() updates static tensors | ✅ Works |
| Eager merge | ✅ Accepted by the user |
**Conclusion**: Only 2 graphs are needed (causal + non-causal), reused via copy_().
---
## Modification Plan (Approach A: segmented graphs)
### Architecture
```
Per-layer execution flow (Offload Decode):
┌─────────────────────────────────────────────────────────────┐
│ PRE-ATTENTION GRAPH (reusable across all layers) │
│ input_layernorm → qkv_proj → q/k_norm → rotary → split Q │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ CHUNKED ATTENTION (eager + partial graph) │
│ for block in cpu_blocks: │
│ H2D transfer (eager) │
│ flash_attn_with_lse (GRAPH - 2 reusable graphs) │
│ merge (eager) │
│ decode_buffer attention (eager) │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ POST-ATTENTION GRAPH (reusable across all layers) │
│ o_proj → post_layernorm → gate_proj → up_proj → SiLU │
│ → down_proj → residual │
└─────────────────────────────────────────────────────────────┘
```
**Total number of graphs required**:
- 1 pre_attention_graph (reused by all layers)
- 2 attention_graphs (causal + non-causal, reused by all layers)
- 1 post_attention_graph (reused by all layers)
- **Total: 4 graphs**
---
### Phase 1: Split DecoderLayer execution
**Goal**: Split `Qwen3DecoderLayer.forward` into three independently callable segments
**File to modify**: `nanovllm/models/qwen3.py`
**New methods**:
```python
class Qwen3DecoderLayer:
    def forward_pre_attention(self, positions, hidden_states, residual):
        """Pre-attention: norm → qkv → rotary → returns q, k, v"""
        if residual is None:
            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        qkv = self.self_attn.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        q = self.self_attn.q_norm(q)
        k = self.self_attn.k_norm(k)
        q, k = self.self_attn.rotary_emb(positions, q, k)
        return q, k, v, hidden_states, residual

    def forward_post_attention(self, attn_output, hidden_states, residual):
        """Post-attention: o_proj → norm → FFN"""
        output = self.self_attn.o_proj(attn_output.flatten(1, -1))
        hidden_states, residual = self.post_attention_layernorm(output, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
```
**Status**: `pending`
---
### Phase 2: Capture Pre/Post Attention Graphs
**Goal**: Capture the pre_attention and post_attention graphs
**File to modify**: `nanovllm/engine/model_runner.py`
**New method**: `capture_offload_layer_graphs()`
```python
def capture_offload_layer_graphs(self):
    """Capture the layer graphs for offload mode"""
    # Use any layer as the template (all layers share the same structure)
    layer = self.model.model.layers[0]
    # Static tensors
    static_hidden = torch.zeros(1, self.hidden_size, ...)
    static_residual = torch.zeros(1, self.hidden_size, ...)
    static_positions = torch.zeros(1, ...)
    # Pre-attention graph
    self.pre_attn_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(self.pre_attn_graph):
        static_q, static_k, static_v, _, _ = layer.forward_pre_attention(
            static_positions, static_hidden, static_residual
        )
    # Post-attention graph
    self.post_attn_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(self.post_attn_graph):
        _, _ = layer.forward_post_attention(
            static_attn_output, static_hidden, static_residual
        )
```
**Status**: `pending`
---
### Phase 3: Capture Attention Graphs
**Goal**: Capture 2 attention graphs (causal + non-causal)
**File to modify**: `nanovllm/kvcache/offload_engine.py`
```python
class OffloadEngine:
    def capture_attention_graphs(self):
        """Capture attention graphs (reused across all layers)"""
        self.attn_graph_causal = self._capture_attn_graph(causal=True)
        self.attn_graph_non_causal = self._capture_attn_graph(causal=False)

    def _capture_attn_graph(self, causal: bool):
        static_q = torch.zeros(1, 1, num_heads, head_dim, ...)
        static_k = torch.zeros(1, block_size, num_kv_heads, head_dim, ...)
        static_v = torch.zeros(1, block_size, num_kv_heads, head_dim, ...)
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            output, lse = flash_attn_with_lse(static_q, static_k, static_v,
                                              self.scale, causal)
        return AttentionGraph(graph, static_q, static_k, static_v, output, lse)
```
**Status**: `pending`
---
### Phase 4: Modify the Offload Decode Execution Flow
**Goal**: Run offload decode using graph replay
**File to modify**: `nanovllm/engine/model_runner.py`
**Method to modify**: `run_chunked_offload_decode()`
```python
def run_chunked_offload_decode_with_graph(self, seqs):
    """Offload decode accelerated with graphs"""
    seq = seqs[0]
    # Prepare inputs
    input_ids = torch.tensor([seq.last_token], ...)
    positions = torch.tensor([len(seq) - 1], ...)
    # Embedding
    hidden_states = self.model.model.embed_tokens(input_ids)
    residual = None
    for layer_id, layer in enumerate(self.model.model.layers):
        # Phase 1: Pre-attention (GRAPH)
        self.pre_attn_vars["hidden"].copy_(hidden_states)
        # Test against None explicitly: truth-testing a multi-element tensor raises
        if residual is not None:
            self.pre_attn_vars["residual"].copy_(residual)
        self.pre_attn_vars["positions"].copy_(positions)
        self.pre_attn_graph.replay()
        q = self.pre_attn_vars["q"].clone()
        k = self.pre_attn_vars["k"].clone()
        v = self.pre_attn_vars["v"].clone()
        # Phase 2: Chunked Attention (Eager + Graph)
        attn_output = self._chunked_attention_with_graph(q, k, v, layer_id, ...)
        # Phase 3: Post-attention (GRAPH)
        self.post_attn_vars["attn_output"].copy_(attn_output)
        self.post_attn_graph.replay()
        hidden_states = self.post_attn_vars["hidden"].clone()
        residual = self.post_attn_vars["residual"].clone()
    # LM head
    logits = self.model.compute_logits(hidden_states)
    return logits
```
**Status**: `pending`
---
### Phase 5: Modify the Ring Buffer Pipeline
**Goal**: Use a graph inside attention
**File to modify**: `nanovllm/kvcache/sparse/full_policy.py`
**Change**: the `flash_attn_with_lse` call in `_decode_ring_buffer_pipeline()`
```python
# Current: eager
prev_o, prev_lse = flash_attn_with_lse(q, k, v, scale, causal=False)
# Change to: graph replay
graph = offload_engine.attn_graph_non_causal
graph.static_q.copy_(q)
graph.static_k.copy_(k)
graph.static_v.copy_(v)
graph.graph.replay()
prev_o = graph.static_output.clone()
prev_lse = graph.static_lse.clone()
```
**Status**: `pending`
---
### Phase 6: Add a Configuration Switch
**File to modify**: `nanovllm/config.py`
```python
enable_offload_graph: bool = True  # enabled by default
```
**Status**: `pending`
---
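## Files to Modify
| File | Change type | Description |
|------|----------|------|
| `nanovllm/engine/model_runner.py` | New method | `capture_offload_attention_graph()` |
| `nanovllm/kvcache/offload_engine.py` | New attributes + methods | Graph storage and access |
| `nanovllm/kvcache/sparse/full_policy.py` | Modified method | Use graph replay |
| `nanovllm/config.py` | New config option | `enable_offload_graph` |
---
## Risks and Caveats
1. **Graph capture timing**: Capture must happen after KV cache allocation and before the first decode
2. **Chunk size match**: The graph's chunk_size must equal block_size
3. **Multi-GPU**: Graphs must be captured separately on each GPU
4. **Memory**: The extra memory overhead of the 2 graphs is small
---
## Test Plan
1. **Unit tests**: Verify that graph replay produces correct results
2. **Integration test**: Run `test_needle.py --enable-offload --input-len 32768`
3. **Performance test**: Compare decode latency between eager and graph execution
---
## Expected Benefits
- Faster attention computation in the decode phase (less kernel launch overhead)
- Compatible with the existing ring buffer pipeline
- Minimal memory overhead (only 2 additional graphs)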

tests/test_xattn_bsa.py (new file, 334 lines)

@@ -0,0 +1,334 @@
"""
Test XAttention + BSA with RULER benchmark data.
Tests XAttention sparse attention correctness using RULER NIAH task.
Attention methods:
- Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
- Decode: FlashAttention (always, since q_len=1)
Usage (in compass conda env with BSA available):
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct
# Test with XAttention + BSA for prefill (default)
python tests/test_xattn_bsa.py --prefill-method xattn
# Test with FlashAttention for prefill (baseline)
python tests/test_xattn_bsa.py --prefill-method flash
# Test specific sample(s)
python tests/test_xattn_bsa.py --sample-id 0
python tests/test_xattn_bsa.py --sample-ids 0,1,2
Note: Compatible with transformers 4.53+ (handles both old `past_key_value`
and new `past_key_values` API).
"""
import argparse
import json
import sys
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
from nanovllm.ops.xattn import xattn_estimate
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
# ============================================================
# XAttention + BSA Functions
# ============================================================
def expand_kv_for_gqa(key_states, value_states, num_heads):
"""Expand KV for Grouped Query Attention."""
num_kv_heads = key_states.shape[1]
if num_heads == num_kv_heads:
return key_states, value_states
num_groups = num_heads // num_kv_heads
return key_states.repeat_interleave(num_groups, dim=1), value_states.repeat_interleave(num_groups, dim=1)
def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
"""Standard FlashAttention."""
from flash_attn import flash_attn_func
q = query_states.transpose(1, 2)
k = key_states.transpose(1, 2)
v = value_states.transpose(1, 2)
return flash_attn_func(q, k, v, causal=is_causal).transpose(1, 2)
def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
"""XAttention + BSA sparse attention."""
from block_sparse_attn import block_sparse_attn_func
batch_size, num_heads, q_len, head_dim = query_states.shape
k_len = key_states.shape[2]
_, mask = xattn_estimate(
query_states, key_states,
chunk_size=16384, block_size=128, threshold=threshold,
use_triton=True, causal=True,
)
q_block_num = (q_len + 127) // 128
k_block_num = (k_len + 127) // 128
q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
output = block_sparse_attn_func(
q, k, v,
torch.tensor([0, q_len], dtype=torch.int32, device=q.device),
torch.tensor([0, k_len], dtype=torch.int32, device=k.device),
torch.ones(num_heads, dtype=torch.int32, device=q.device),
None,
mask[:, :, :q_block_num, :k_block_num].contiguous(),
q_len, k_len,
p_dropout=0.0, deterministic=True, is_causal=True,
)
return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
DEBUG = False # Set to True to enable debugging
def create_patched_forward(prefill_method="xattn", threshold=0.9):
"""Create patched forward with configurable prefill method.
Args:
prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense)
threshold: XAttention threshold for block selection (only used when prefill_method="xattn")
Note:
- Prefill (q_len > 1): Uses specified prefill_method
- Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query)
"""
call_count = [0] # Mutable to track calls across layers
def patched_forward(
self,
hidden_states,
position_embeddings=None,
attention_mask=None,
past_key_value=None, # Old API (transformers < 4.57)
past_key_values=None, # New API (transformers >= 4.57)
cache_position=None,
**kwargs
):
# Handle both old and new transformers API
kv_cache = past_key_values if past_key_values is not None else past_key_value
bsz, q_len, _ = hidden_states.size()
num_heads = self.config.num_attention_heads
num_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
# Compute Q, K, V projections
query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
# Apply rotary position embedding
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Handle KV cache
if kv_cache is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = kv_cache.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# Expand KV for GQA
key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)
# Debug output
if DEBUG and self.layer_idx == 0:
call_count[0] += 1
if call_count[0] <= 5:
phase = "prefill" if q_len > 1 else "decode"
print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
print(f" kv_cache is None: {kv_cache is None}")
# Choose attention method:
# - Prefill (q_len > 1): Use prefill_method (xattn or flash)
# - Decode (q_len = 1): Always use FlashAttention
is_prefill = q_len > 1
if is_prefill and prefill_method == "xattn":
# Prefill with XAttention + BSA (sparse)
attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
else:
# Prefill with FlashAttention (dense) OR Decode (always FlashAttention)
# Note: For decode (q_len=1), causal=False since single query attends to all KV
attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)
attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
return attn_output, None
return patched_forward
# ============================================================
# Data & Evaluation
# ============================================================
def load_samples(filepath, indices=None):
"""Load samples from JSONL file."""
samples = []
with open(filepath) as f:
for i, line in enumerate(f):
if indices is None or i in indices:
sample = json.loads(line)
sample["_idx"] = i
samples.append(sample)
return samples
def string_match_all(output_text, expected_list):
"""RULER metric: fraction of expected values found in output."""
output_lower = output_text.lower().replace('\n', ' ')
if not expected_list:
return 1.0
return sum(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list) / len(expected_list)
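As a quick illustration of how this metric behaves, the sketch below repeats the function with its indentation restored and applies it to a few made-up outputs. The sample strings are invented for the example and are not RULER data.

```python
def string_match_all(output_text, expected_list):
    """RULER metric: fraction of expected values found in output."""
    output_lower = output_text.lower().replace('\n', ' ')
    if not expected_list:
        return 1.0
    return sum(
        1.0 if exp.strip().lower() in output_lower else 0.0
        for exp in expected_list
    ) / len(expected_list)

# All expected needles present: full score.
score_full = string_match_all("The magic number is 12345.\n", ["12345"])
# One of two needles present (matching is case-insensitive): half score.
score_half = string_match_all("alpha only", ["Alpha", "beta"])
# No expected values at all counts as a trivial pass.
score_empty = string_match_all("anything", [])
```

The test below treats a sample as passed when its score is at least 0.5, so the half-score case still counts as a pass.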
# ============================================================
# Test
# ============================================================
def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
    """Test attention methods using RULER data.

    Args:
        model_path: Model path passed to from_pretrained
        data_file: JSONL file with RULER samples
        sample_ids: Sample indices to evaluate (all samples if empty)
        prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
        threshold: XAttention threshold (only used when prefill_method is "xattn")
        max_new_tokens: Maximum number of tokens to generate per sample

    Returns:
        True if the average score is at least 0.5, False otherwise
        (including when no samples could be loaded).
    """
    prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"
    print("=" * 60)
    print("RULER NIAH Attention Test")
    print("=" * 60)
    print(f"Data: {data_file}")
    print(f"Samples: {sample_ids}")
    print(f"Prefill method: {prefill_desc}")
    print("Decode method: FlashAttention (always)")
    if prefill_method == "xattn":
        print(f"XAttention threshold: {threshold}")

    samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
    if not samples:
        print("No samples found!")
        return False
    print(f"Loaded {len(samples)} samples")

    # Load model
    print(f"\nLoading model: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="cuda",
        attn_implementation="eager",  # Will be patched
    )
    model.eval()

    # Patch all layers
    print("Patching attention layers...")
    print(f" - Prefill: {prefill_desc}")
    print(" - Decode: FlashAttention")
    for idx, layer in enumerate(model.model.layers):
        layer.self_attn.layer_idx = idx  # Ensure layer_idx is set
        layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
            layer.self_attn, type(layer.self_attn)
        )

    total_score = 0.0
    results = []
    for sample in samples:
        idx = sample["_idx"]
        prompt = sample["input"]
        expected = sample["outputs"]
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        num_tokens = inputs["input_ids"].shape[1]
        print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
        print(f"Expected: {expected}")
        with torch.no_grad():
            output = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens (skip the prompt)
        output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
        score = string_match_all(output_text, expected)
        total_score += score
        status = "✓ PASS" if score >= 0.5 else "✗ FAIL"
        print(f"Output: '{output_text[:100]}...'")
        print(f"Result: {status} (score={score:.2f})")
        results.append({"idx": idx, "score": score, "passed": score >= 0.5})

    avg_score = total_score / len(samples)
    passed = sum(1 for r in results if r["passed"])
    print(f"\n{'='*60}")
    print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
    print(f"{'='*60}")
    return avg_score >= 0.5
def main():
    parser = argparse.ArgumentParser(
        description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark"
    )
    parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
    parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
    parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index")
    parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
    parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
                        help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
    parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
    parser.add_argument("--max-new-tokens", type=int, default=50)
    # Keep old option for backwards compatibility
    parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
    args = parser.parse_args()

    # Expand "~" to the current user's home directory instead of a hard-coded path
    model_path = str(Path(args.model).expanduser())

    # Handle deprecated --no-xattn option
    prefill_method = args.prefill_method
    if args.no_xattn:
        prefill_method = "flash"
        print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")

    if args.sample_id is not None:
        sample_ids = [args.sample_id]
    elif args.sample_ids:
        sample_ids = [int(x) for x in args.sample_ids.split(",")]
    else:
        sample_ids = [0]

    # Check BSA availability if using xattn
    if prefill_method == "xattn":
        try:
            # Imported only to verify that the package is installed
            from block_sparse_attn import block_sparse_attn_func  # noqa: F401
            print("✓ BSA (Block Sparse Attention) available")
        except ImportError:
            print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
            sys.exit(1)

    if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
        print("\ntest_xattn_bsa: PASSED")
    else:
        print("\ntest_xattn_bsa: FAILED")
        sys.exit(1)


if __name__ == "__main__":
    main()
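
The pass/fail decision in the test above rests entirely on the string-match metric, so a small worked example may help. This is a minimal standalone sketch that re-implements `string_match_all` with the same logic; the sample outputs and expected values are made up for illustration and are not taken from the RULER data.

```python
def string_match_all(output_text, expected_list):
    """Fraction of expected values found in the output (case-insensitive)."""
    output_lower = output_text.lower().replace("\n", " ")
    if not expected_list:
        return 1.0
    hits = sum(1 for exp in expected_list if exp.strip().lower() in output_lower)
    return hits / len(expected_list)


# Hypothetical model outputs, for illustration only
full = string_match_all("The magic number is 1234567.", ["1234567"])
half = string_match_all("Found 111\nbut not the other one.", ["111", "222"])
empty = string_match_all("anything", [])

print(full, half, empty)
```

Because a sample counts as passed when its score is at least 0.5, a multi-value sample that recovers only half of its expected values still passes under this threshold.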

tests/test_xattn_chunked.py Normal file

@@ -0,0 +1,259 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked
Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation.
Uses real QKV data captured from model inference.
"""
import sys
import os
import torch
import warnings
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
# ============================================================
# Configuration
# ============================================================
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096
# Default QKV data directory (relative to project root)
DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache")
# ============================================================
# Utility Functions
# ============================================================
def load_qkv(path):
    """Load saved QKV data."""
    data = torch.load(path, map_location="cpu", weights_only=False)
    print(f"Loaded: {path}")
    print(f" Query shape: {data['query'].shape}")
    print(f" Key shape: {data['key'].shape}")
    print(f" Layer: {data['layer_id']}, Density: {data['density']:.2%}")
    return data


def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Compare two masks and report differences."""
    if mask1.shape != mask2.shape:
        print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False
    diff = (mask1 != mask2).sum().item()
    total = mask1.numel()
    match_rate = (total - diff) / total * 100
    print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
    if diff > 0:
        diff_indices = torch.where(mask1 != mask2)
        print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
    return diff == 0
def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size):
    """
    Run xattn_estimate_chunked with EXTERNAL chunking.
    This simulates how chunked prefill should be used in practice.
    """
    batch_size, num_heads, q_len, head_dim = query.shape
    _, _, k_len, _ = key.shape
    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size

    # If Q fits in one chunk, call directly
    if q_len <= chunk_size:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return xattn_estimate_chunked(
                query, key,
                q_start_pos=q_start_pos,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )

    # External chunking: split Q and call for each chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f" External chunking: {num_q_chunks} chunks")
    combined_attn_sum = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=query.dtype, device=query.device
    )
    combined_mask = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=torch.bool, device=query.device
    )
    q_block_offset = 0
    for q_chunk_idx in range(num_q_chunks):
        q_chunk_start = q_chunk_idx * chunk_size
        q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
        q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
        # For causal attention, K accumulates up to current Q position
        k_end = q_start_pos + q_chunk_end
        k_chunk = key[:, :, :k_end, :]
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
                q_chunk, k_chunk,
                q_start_pos=q_start_pos + q_chunk_start,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )
        # Place chunk results into combined output
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
        combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
        q_block_offset += chunk_q_blocks
    return combined_attn_sum, combined_mask
def test_single_qkv(qkv_path):
    """Test a single QKV file."""
    data = load_qkv(qkv_path)
    query = data["query"].cuda().to(torch.bfloat16)
    key = data["key"].cuda().to(torch.bfloat16)
    seq_len = query.shape[2]
    print(f"\nTesting with seq_len={seq_len}")
    print("=" * 60)

    # Run standard xattn_estimate
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
        )
        print(f" mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False

    # Run chunked xattn_estimate with EXTERNAL chunking
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query, key,
            q_start_pos=0,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        print(f" mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False

    # Compare results
    print("[3] Comparing results...")
    chunked_q_blocks = mask_chunked.shape[2]
    chunked_k_blocks = mask_chunked.shape[3]
    # Extract comparable region from standard mask
    mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    # Compare masks
    masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
    # Compare attn_sums
    attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f" Attn sum max diff: {attn_diff:.6f}")
    else:
        print(" Attn sum shape mismatch")

    # Clean up GPU memory
    del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
    torch.cuda.empty_cache()
    return masks_match
# ============================================================
# Main Test
# ============================================================
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked")
    parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR,
                        help="Directory containing QKV files")
    args = parser.parse_args()

    # QKV files to test
    qkv_files = [
        os.path.join(args.qkv_dir, "qkv_3688.pt"),  # ~4K
        os.path.join(args.qkv_dir, "qkv_7888.pt"),  # ~8K
        os.path.join(args.qkv_dir, "qkv_15685.pt"),  # ~16K
        os.path.join(args.qkv_dir, "qkv_32485.pt"),  # ~32K
        os.path.join(args.qkv_dir, "qkv_64891.pt"),  # ~64K
    ]
    available_files = [p for p in qkv_files if os.path.exists(p)]
    if not available_files:
        print(f"No QKV files found in {args.qkv_dir}.")
        print("Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt")
        sys.exit(1)
    print(f"Found {len(available_files)} QKV files to test")
    print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})")
    print("Using Triton kernels")

    all_passed = True
    results = []
    for qkv_path in available_files:
        passed = test_single_qkv(qkv_path)
        # The sequence length is encoded in the file name: qkv_<seq_len>.pt
        seq_len = int(os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", ""))
        results.append((seq_len, passed))
        if not passed:
            all_passed = False

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for seq_len, passed in results:
        status = "PASSED" if passed else "FAILED"
        chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
        print(f" seq_len={seq_len} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}")
    print("=" * 60)
    if all_passed:
        print("test_xattn_chunked: PASSED")
        sys.exit(0)
    else:
        print("test_xattn_chunked: FAILED")
        sys.exit(1)
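
The bookkeeping in `run_chunked_externally` (which query rows go into each chunk, how far K extends, and where each chunk's rows land in the combined mask) can be traced without a GPU. The sketch below is a pure-Python illustration, not part of the test: `plan_external_chunks` is a hypothetical helper, and it assumes each chunk yields `ceil(chunk_len / block_size)` query blocks, whereas the real test reads that count from `mask_chunk.shape[2]`.

```python
def plan_external_chunks(q_len, chunk_size, block_size, q_start_pos=0):
    """Hypothetical helper: list the per-chunk ranges used by external chunking."""
    plan = []
    q_block_offset = 0
    for q_chunk_start in range(0, q_len, chunk_size):
        q_chunk_end = min(q_chunk_start + chunk_size, q_len)
        chunk_len = q_chunk_end - q_chunk_start
        # Assumption: one output row per (possibly partial) block of queries
        q_blocks = (chunk_len + block_size - 1) // block_size
        plan.append({
            "q_range": (q_chunk_start, q_chunk_end),
            # Causal attention: K is visible up to the end of this Q chunk
            "k_end": q_start_pos + q_chunk_end,
            "q_block_offset": q_block_offset,
            "q_blocks": q_blocks,
        })
        q_block_offset += q_blocks
    return plan


# Same block and chunk sizes as the test configuration, with the ~8K sequence length
plan = plan_external_chunks(q_len=7888, chunk_size=4096, block_size=64)
for entry in plan:
    print(entry)
```

With these numbers the two chunks contribute 64 and 60 query blocks, which together match the 124 rows allocated for the combined mask.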

tests/test_xattn_kernels.py Normal file

@@ -0,0 +1,129 @@
"""
Test: XAttention Triton kernels
Demonstrates the two core Triton kernels of XAttention:
1. flat_group_gemm_fuse_reshape: computes attention scores after the stride reshape (anti-diagonal sum)
2. softmax_fuse_block_sum: applies softmax to the attention scores, then sums per block

Data flow:
    Q [batch, heads, q_len, head_dim]
    K [batch, heads, kv_len, head_dim]
        ↓ flat_group_gemm_fuse_reshape
    attn_scores [batch, heads, q_len/stride, kv_len/stride]
        ↓ softmax_fuse_block_sum
    block_sums [batch, heads, q_blocks, k_blocks]
"""
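import sys
from pathlib import Path

import torch

# Make the project root importable (this file lives in <project root>/tests/)
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum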
# ============================================================
# Parameters
# ============================================================
# Triton constraint: q_len >= stride * BLOCK_M, kv_len >= stride * BLOCK_N
# A100: BLOCK_M = BLOCK_N = 128, so min = 4 * 128 = 512
# RTX 3090: BLOCK_M = BLOCK_N = 64, so min = 4 * 64 = 256
q_len = 512
kv_len = 2048
head_dim = 128
stride = 4
block_size = 128  # softmax block size (in reshaped space)
segment_size = 128  # Triton kernel requires segment_size >= block_size
# ============================================================
# Build inputs: even positions = 1, odd positions = 2
# ============================================================
Q = torch.zeros(1, 1, q_len, head_dim, dtype=torch.bfloat16).cuda()
K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()
for i in range(q_len):
    if i % 2 == 0:
        Q[0, 0, i, :] = 1
    else:
        Q[0, 0, i, :] = 2
for i in range(kv_len):
    if i % 2 == 0:
        K[0, 0, i, :] = 1
    else:
        K[0, 0, i, :] = 2
# ============================================================
# Step 1: flat_group_gemm_fuse_reshape (chunked along K)
# ============================================================
q_reshaped_len = q_len // stride  # 128
kv_reshaped_len = kv_len // stride  # 512
# Split K into multiple chunks along the sequence-length dimension
k_chunk_size = 512  # 512 tokens per chunk
num_k_chunks = kv_len // k_chunk_size  # 4 chunks
attn_scores_list = []
for k_chunk_idx in range(num_k_chunks):
    k_start = k_chunk_idx * k_chunk_size
    k_end = k_start + k_chunk_size
    K_chunk = K[:, :, k_start:k_end, :]  # [1, 1, k_chunk_size, head_dim]
    # Call flat_group_gemm_fuse_reshape for each K chunk
    # Output: [batch, heads, q_len/stride, k_chunk_size/stride]
    attn_chunk = flat_group_gemm_fuse_reshape(
        Q, K_chunk, stride,
        chunk_start=0,
        chunk_end=q_reshaped_len,
        is_causal=False
    )
    attn_scores_list.append(attn_chunk)
# Concatenate the results of all K chunks
# Each chunk: [1, 1, q_reshaped_len, k_chunk_size/stride]
# After concatenation: [1, 1, q_reshaped_len, kv_reshaped_len]
attn_scores = torch.cat(attn_scores_list, dim=-1)
# Check shape: [batch, heads, q_len/stride, kv_len/stride]
assert attn_scores.shape == (1, 1, q_reshaped_len, kv_reshaped_len), \
    f"shape mismatch: {attn_scores.shape} != (1, 1, {q_reshaped_len}, {kv_reshaped_len})"
# Check: anti-diagonal sum
# Anti-diagonal of each stride x stride block: Q[odd]*K[even] + Q[even]*K[odd] = 2*1 + 1*2 = 4
# The anti-diagonal has stride/2 such pairs; multiply by head_dim
expected_gemm = (2*1 + 1*2) * (stride // 2) * head_dim
actual_gemm = attn_scores[0, 0, 0, 0].item()
assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expected_gemm}"
# ============================================================
# Step 2: softmax_fuse_block_sum
# ============================================================
scale = 1.4426950408889634  # log2(e) for exp2
block_sums = softmax_fuse_block_sum(
    attn_scores,
    block_size,
    segment_size,
    chunk_start=0,
    chunk_end=q_reshaped_len,
    real_q_len=q_reshaped_len,
    scale=scale,
    is_causal=False
)
# Check shape: [batch, heads, q_blocks, k_blocks]
q_blocks = q_reshaped_len // block_size  # 128 / 128 = 1
k_blocks = kv_reshaped_len // block_size  # 512 / 128 = 4
assert block_sums.shape == (1, 1, q_blocks, k_blocks), \
    f"shape mismatch: {block_sums.shape} != (1, 1, {q_blocks}, {k_blocks})"
# Check: sum of the softmax results within each block
# All attn_scores are equal → softmax is uniform
# Each row contributes block_size / kv_reshaped_len to one K block
# Each Q block has block_size rows
# block_sum = block_size * (block_size / kv_reshaped_len)
expected_sum = block_size * block_size / kv_reshaped_len
actual_sum = block_sums[0, 0, 0, 0].item()
assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"
print("test_xattn_kernels: PASSED")
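
The two expected values asserted in the kernel test can be re-derived on the CPU without Triton. The following is a minimal pure-Python sketch of that arithmetic under the test's input pattern (every element of a row is 1 at even positions and 2 at odd positions). The helper names `row_value` and `antidiagonal_score` are hypothetical, and the pairing of Q row `i` with K row `stride - 1 - i` is an assumption about the anti-diagonal orientation; for this input any anti-diagonal pairing gives the same result because `stride` is even.

```python
import math

stride, head_dim = 4, 128
block_size, kv_reshaped_len = 128, 512


def row_value(i):
    """Value of every element in row i: 1 at even positions, 2 at odd positions."""
    return 1 if i % 2 == 0 else 2


def antidiagonal_score(q0, k0):
    """Sum of Q.K dot products along the anti-diagonal of one stride x stride block."""
    return sum(
        row_value(q0 + i) * row_value(k0 + stride - 1 - i) * head_dim
        for i in range(stride)
    )


score = antidiagonal_score(0, 0)

# Every reshaped score is identical, so a numerically stable softmax over
# one row of kv_reshaped_len columns is uniform.
row = [float(score)] * kv_reshaped_len
row_max = max(row)
exps = [math.exp(x - row_max) for x in row]
probs = [e / sum(exps) for e in exps]

# One Q block has block_size rows; each row puts sum(probs[:block_size]) into a K block.
block_sum = block_size * sum(probs[:block_size])

print(score, block_sum)
```

The sketch reproduces `expected_gemm` and `expected_sum` from the test, which suggests the exact-equality assertions there depend on these particular power-of-two sizes; other shapes would likely need a tolerance-based comparison.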