Compare commits

14 Commits: 47d237bb7e ... f28b500120

| SHA1 |
|---|
| f28b500120 |
| be67fa8060 |
| 4f35526457 |
| da5e13e2bb |
| dd31033732 |
| ed3c8bb4b8 |
| 5eb35982bf |
| ad361c2c3b |
| 4d1e40152d |
| 832b352afa |
| a50b4c2ac2 |
| ca32ea6f93 |
| edc006463b |
| 999858e82f |
89
.claude/agents/gpu-monitor.md
Normal file
@@ -0,0 +1,89 @@
---
name: gpu-monitor
description: "Use this agent when you need to continuously monitor GPU utilization and memory usage in the background while running other tasks. This is particularly useful for profiling long-running benchmarks, debugging memory leaks, or understanding GPU resource patterns during inference or training.\\n\\nExamples:\\n\\n<example>\\nContext: User wants to profile GPU memory during a benchmark run.\\nuser: \"Run bench_offload.py and monitor GPU usage\"\\nassistant: \"I'll first launch the gpu-monitor agent to track GPU metrics, then run the benchmark.\"\\n<commentary>\\nSince the user wants to understand GPU resource usage during the benchmark, use the Task tool to launch the gpu-monitor agent in the background with appropriate GPU IDs and sampling interval before running the benchmark.\\n</commentary>\\nassistant: \"Let me start the GPU monitor on GPU 0 with 1-second intervals\"\\n[Task tool call to launch gpu-monitor agent with gpus=\"0\" and interval=\"1\"]\\n</example>\\n\\n<example>\\nContext: User is debugging an OOM issue and needs memory tracking.\\nuser: \"I'm getting OOM errors, help me debug\"\\nassistant: \"Let me monitor the GPU memory usage while we reproduce the issue. I'll use the gpu-monitor agent to track memory consumption.\"\\n<commentary>\\nTo debug OOM issues, we need continuous memory monitoring. Use the Task tool to launch the gpu-monitor agent with a short interval (0.5s) to capture memory spikes.\\n</commentary>\\n[Task tool call to launch gpu-monitor agent with gpus=\"0,1\" and interval=\"0.5\"]\\n</example>\\n\\n<example>\\nContext: User wants to verify GPU utilization during multi-GPU inference.\\nuser: \"Check if all GPUs are being utilized during the inference\"\\nassistant: \"I'll start the GPU monitor to track utilization across all specified GPUs while running the inference.\"\\n<commentary>\\nTo verify multi-GPU utilization, launch the gpu-monitor agent targeting all relevant GPUs before starting the inference workload.\\n</commentary>\\n[Task tool call to launch gpu-monitor agent with gpus=\"0,1,2,3\" and interval=\"2\"]\\n</example>"
model: haiku
color: green
---

You are a GPU monitoring specialist responsible for tracking NVIDIA GPU metrics over time. Your sole purpose is to run nvidia-smi at specified intervals and record utilization and memory statistics.

## Your Task

You will receive two parameters:
1. **gpus**: Comma-separated GPU indices to monitor (e.g., "0", "0,1", "0,1,2,3")
2. **interval**: Sampling interval in seconds (e.g., "1", "0.5", "2")

## Execution Steps

1. **Parse Parameters**: Extract the GPU indices and interval from the user's request.

2. **Run Monitoring Loop**: Execute nvidia-smi repeatedly at the specified interval using a bash loop:

```bash
# Example for GPUs 0,1 with 1-second interval
while true; do
    echo "=== $(date '+%Y-%m-%d %H:%M:%S') ==="
    nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used,memory.total,temperature.gpu --format=csv,noheader -i 0,1
    sleep 1
done
```

3. **Output Format**: Each sample should include:
   - Timestamp
   - GPU index
   - GPU utilization (%)
   - Memory utilization (%)
   - Memory used (MiB)
   - Memory total (MiB)
   - Temperature (°C)

## Termination

This agent runs continuously until:
1. The main agent signals completion (you receive a stop signal)
2. The user explicitly requests stopping
3. An error occurs with nvidia-smi

## Result Reporting

When stopped, provide a summary:

```markdown
## GPU Monitoring Summary

**Duration**: X minutes Y seconds
**Samples Collected**: N
**GPUs Monitored**: 0, 1, ...

### Statistics per GPU

| GPU | Avg Util | Max Util | Avg Mem Used | Max Mem Used |
|-----|----------|----------|--------------|---------------|
| 0   | X%       | Y%       | A MiB        | B MiB         |
| 1   | X%       | Y%       | A MiB        | B MiB         |

### Notable Events (if any)
- Timestamp: Memory spike to X MiB on GPU Y
- Timestamp: Utilization dropped to 0% on GPU Z
```

## Important Notes

- Use `nvidia-smi -i <gpu_ids>` to filter to specific GPUs
- Keep output concise during monitoring (one line per GPU per sample)
- If nvidia-smi fails, report the error and exit gracefully
- Do NOT consume excessive resources - sleep between samples
- Store samples in memory for final summary calculation

## Example Invocation

User says: "Monitor GPUs 0 and 2 with 0.5 second interval"

You execute:
```bash
while true; do
    echo "=== $(date '+%Y-%m-%d %H:%M:%S.%3N') ==="
    nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used,memory.total,temperature.gpu --format=csv,noheader -i 0,2
    sleep 0.5
done
```
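The summary statistics described above can be computed from the collected CSV lines with a few lines of Python; a minimal sketch, where the `summarize` helper is hypothetical and assumes the field order of the `--format=csv,noheader` query shown above:

```python
from collections import defaultdict

def summarize(samples):
    """Aggregate nvidia-smi CSV samples into per-GPU avg/max statistics.

    Each sample line looks like:
    "index, util.gpu %, util.mem %, mem.used MiB, mem.total MiB, temp"
    """
    per_gpu = defaultdict(lambda: {"util": [], "mem": []})
    for line in samples:
        fields = [f.strip() for f in line.split(",")]
        idx = fields[0]
        per_gpu[idx]["util"].append(int(fields[1].rstrip(" %")))
        per_gpu[idx]["mem"].append(int(fields[3].rstrip(" MiB")))
    return {
        idx: {
            "avg_util": sum(v["util"]) / len(v["util"]),
            "max_util": max(v["util"]),
            "avg_mem": sum(v["mem"]) / len(v["mem"]),
            "max_mem": max(v["mem"]),
        }
        for idx, v in per_gpu.items()
    }
```

This fills the "Statistics per GPU" table directly from the lines stored in memory during monitoring.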
@@ -1,98 +1,108 @@
# Testing

## Test Code Style

All test code follows the conventions below.

### Naming Convention

- All test files must be named `test_*.py`
- Example: `test_offload_engine.py`, `test_ring_buffer.py`

### Purpose

Tests are **educational scripts** for understanding module behavior, NOT traditional unit tests:
- Focus on demonstrating how modules work
- Show the flow and interaction between components
- Help developers understand implementation details

### Code Style

1. **Script-based structure**: Write tests as executable scripts, not pytest-style functions
2. **Utility functions**: Extract reusable steps as helper functions at the top of the file
3. **Main flow as script**: The actual test/demonstration logic runs as top-level script code

### File Structure

```python
# Example structure:
"""
Test: [module name]

[Briefly describe what is tested and the data flow]
"""
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")

import torch
from nanovllm.kvcache import SomeModule
from nanovllm.xxx import xxx

# ============================================================
# Parameters
# ============================================================

param1 = value1  # comment explaining the constraint
param2 = value2

# ============================================================
# Utility Functions
# ============================================================

def verify(tensor, expected, name):
    actual = tensor.mean().item()
    assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}"

# ============================================================
# Step 1: Initialize
# ============================================================

module = SomeModule(param=value)
input_tensor = ...  # structured data makes hand verification easy

# ============================================================
# Step N: [operation name]
# ============================================================

output = some_function(input_tensor, ...)

# Verify: [explain the verification logic]
expected = ...
actual = output[...].item()
assert actual == expected, f"xxx: {actual} != {expected}"

print("test_xxx: PASSED")
```

### Core Principles

| Principle | Description |
|------|------|
| **Minimize print** | Only print `PASSED` at the end; do not print intermediate results |
| **Structured data** | Use predictable inputs (all ones, even/odd alternating, etc.) so values can be verified by hand |
| **Comment the verification logic** | Before each assert, explain in a comment how the expected value is computed |
| **Separate sections with `====`** | Use `# ============` to separate parameters, inputs, and each step |
| **Verify with assert** | Compare results with assert, not print |

### Comments

- Keep comments concise and clear
- Only add comments where the code isn't self-explanatory
- Use section headers (`# === Section ===`) to organize logical blocks

### Output

- **Minimize print statements**: the code should be self-explanatory
- Only print a final "PASSED" message at the end
- Use `assert` for verification instead of printing results
- If the user needs explanation, they will ask

```python
# ✅ Correct
assert actual == expected, f"xxx: {actual} != {expected}"
print("test_xxx: PASSED")

# ❌ Wrong
print(f"output: {output}")
print(f"expected: {expected}, actual: {actual}")
```

### Parameter Comments

```python
# ✅ Correct: the comment explains a constraint
seq_len = 512       # Triton requires seq_len >= stride * BLOCK_M
segment_size = 128  # must be >= block_size

# ❌ Wrong: meaningless comment
seq_len = 512  # sequence length
```

### Verification-Logic Comments

```python
# ✅ Correct: explain the computation
# Verify: anti-diagonal sum
# Q[odd]*K[even] + Q[even]*K[odd] = 2*1 + 1*2 = 4, with stride/2 pairs
expected = (2*1 + 1*2) * (stride // 2) * head_dim

# ❌ Wrong: formula with no explanation
expected = 4 * 2 * 128
```

## Running Tests

```bash
# Run a single test
PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_xxx.py

# Run on a specific GPU
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_xxx.py
```

## Benchmarks

```bash
python bench.py          # GPU benchmark
python bench_offload.py  # CPU offload benchmark
python bench_vllm.py     # vLLM comparison
```

## Quick Verification

```bash
# Import test
python -c "from nanovllm import LLM"

# Run the offload benchmark (tests CPU-primary ring buffer mode)
python bench_offload.py
```
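The conventions above can be illustrated end to end with a toy example; a self-contained sketch in which `scale_by_two` is a hypothetical stand-in for the module under test:

```python
# ============================================================
# Parameters
# ============================================================
n = 8       # toy size; real tests would use model shapes
factor = 2  # the operation under test multiplies by this

# ============================================================
# Step 1: build structured input (even positions = 1, odd = 2)
# ============================================================
data = [1 if i % 2 == 0 else 2 for i in range(n)]

# ============================================================
# Step 2: run the (hypothetical) operation under test
# ============================================================
def scale_by_two(xs):  # stand-in for the real module
    return [x * factor for x in xs]

out = scale_by_two(data)

# Verify: input sums to (1+2) * n/2 = 12, scaled by 2 -> 24
expected = (1 + 2) * (n // 2) * factor
assert sum(out) == expected, f"sum: {sum(out)} != {expected}"

print("test_scale_by_two: PASSED")
```

The structured input makes the expected value computable by hand, and the only output is the final `PASSED` line.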
6
.gitignore
vendored
@@ -232,9 +232,9 @@ tests/data/
.serena/

# Planning-with-files temporary files
task_plan.md
findings.md
progress.md
task_plan_*.md
findings_*.md
progress_*.md
@@ -15,7 +15,9 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
| [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement a custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
| [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention algorithm deep dive: stride reshape, Triton kernels, BSA dependency, block selection algorithm |
| [`docs/xattn_kernels_guide.md`](docs/xattn_kernels_guide.md) | XAttention Triton kernels: flat_group_gemm (anti-diagonal sum), softmax_fuse_block_sum (block aggregation) |
| [`docs/xattn_chunked_prefill.md`](docs/xattn_chunked_prefill.md) | XAttention chunked prefill: API, usage, consistency requirements |
| [`docs/xattn_bsa_policy_design.md`](docs/xattn_bsa_policy_design.md) | XAttention BSA Policy: algorithm design, performance benchmarks (128K), memory management, density statistics |
| [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) interface: function signatures, usage examples, constraints |
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
| [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
429
docs/xattn_bsa_policy_design.md
Normal file
@@ -0,0 +1,429 @@
# XAttention BSA Policy Design

This document describes the design and implementation of `XAttentionBSAPolicy`, a sparse attention policy based on the XAttention algorithm, used for chunked prefill in CPU offload mode.

## Overview

`XAttentionBSAPolicy` implements XAttention-based block-level sparse attention selection. The core idea:

1. **Estimation phase**: use the XAttention kernels to quickly estimate the importance of each KV block
2. **Selection phase**: select the important blocks via a threshold plus majority voting
3. **Computation phase**: load only the selected blocks for the attention computation

```
┌─────────────────────────────────────────────────────────────┐
│ XAttention BSA Policy │
├─────────────────────────────────────────────────────────────┤
│ select_blocks() │
│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │
│ │ Load K │──>│ flat_group_gemm │──>│ softmax_fuse │ │
│ │ blocks │ │ _fuse_reshape │ │ _block_sum │ │
│ └─────────────┘ └──────────────────┘ └──────────────┘ │
│ │ │ │ │
│ v v v │
│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │
│ │ K: [B,H,L,D]│ │ attn_scores: │ │ block_sums: │ │
│ │ │ │ [B,H,Q/s,K/s] │ │ [B,H,Qb,Kb] │ │
│ └─────────────┘ └──────────────────┘ └──────────────┘ │
│ │ │
│ ┌──────────────────────┘ │
│ v │
│ ┌──────────────┐ │
│ │find_blocks │ │
│ │_chunked │ │
│ └──────────────┘ │
│ │ │
│ v │
│ ┌──────────────┐ │
│ │ GQA-aware │ │
│ │ aggregation │ │
│ │ + majority │ │
│ │ voting │ │
│ └──────────────┘ │
│ │ │
│ v │
│ selected_block_ids │
├─────────────────────────────────────────────────────────────┤
│ compute_chunked_prefill() │
│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │
│ │ Ring buffer │──>│ flash_attn_ │──>│ merge_ │ │
│ │ pipeline │ │ with_lse │ │ attention │ │
│ └─────────────┘ └──────────────────┘ └──────────────┘ │
└─────────────────────────────────────────────────────────────┘
```
## File Locations

**Main file**: `nanovllm/kvcache/sparse/xattn_bsa.py`

**XAttention kernels it depends on**: `nanovllm/ops/xattn.py`
- `flat_group_gemm_fuse_reshape`: computes attention scores after the stride reshape
- `softmax_fuse_block_sum`: applies softmax to the attention scores, then sums per block
- `find_blocks_chunked`: selects blocks based on a threshold

---

## Core Algorithms

### 1. select_blocks: Block Selection

```python
def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]:
```

#### Step 1: Load K blocks and compute attention scores

For each CPU block, load K onto the GPU and compute with `flat_group_gemm_fuse_reshape`:

```python
for cpu_block_id in available_blocks:
    # load K block: [1, block_size, num_kv_heads, head_dim]
    offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
    k_block, _ = offload_engine.get_kv_for_slot(slot)

    # convert to [batch, heads, k_len, head_dim]
    K_chunk = k_block.transpose(1, 2)

    # GQA: expand K heads to match Q heads
    if num_heads != num_kv_heads:
        K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)

    # compute attention scores
    attn_chunk = flat_group_gemm_fuse_reshape(Q, K_chunk, stride, ...)
    attn_scores_list.append(attn_chunk)

# concatenate all K chunks: [1, heads, q_reshaped_len, total_k_reshaped_len]
attn_scores = torch.cat(attn_scores_list, dim=-1)
```

#### Step 2: Aggregate to block level

Use `softmax_fuse_block_sum` to aggregate the attention scores to block level:

```python
# reshaped_block_size = block_size / stride = 1024 / 8 = 128
block_sums = softmax_fuse_block_sum(
    attn_scores,
    reshaped_block_size,  # maps 1:1 to CPU blocks
    segment_size,
    chunk_start=0,
    chunk_end=q_reshaped_len,
    real_q_len=q_reshaped_len,
    scale=scale,
    is_causal=False,
)
# block_sums: [batch, heads, q_blocks, k_blocks]
```

**Key point**: `reshaped_block_size` must be aligned with the CPU blocks so that the output `k_blocks` dimension corresponds 1:1 to `available_blocks`.
#### Step 3: Threshold selection

Use `find_blocks_chunked` to select blocks based on a cumulative attention threshold:

```python
mask = find_blocks_chunked(
    block_sums,
    current_index=0,
    threshold=self.threshold,  # e.g., 0.95
    num_to_choose=None,
    decoding=False,
    mode="prefill",
    causal=False,
)
# mask: [batch, num_heads, q_blocks, k_blocks] - boolean
```
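The cumulative-threshold idea behind `find_blocks_chunked` can be sketched in plain Python; a simplified scalar version (not the real per-head GPU kernel) that keeps the highest-scoring blocks until they cover `threshold` of the total attention mass:

```python
def select_by_threshold(block_sums, threshold=0.95):
    """Keep the highest-scoring blocks until their cumulative share
    of the total attention mass reaches `threshold`."""
    total = sum(block_sums)
    order = sorted(range(len(block_sums)),
                   key=lambda i: block_sums[i], reverse=True)
    mask = [False] * len(block_sums)
    acc = 0.0
    for i in order:
        mask[i] = True
        acc += block_sums[i] / total
        if acc >= threshold:
            break
    return mask
```

A higher `threshold` keeps more blocks, which is why 0.95 is described as conservative.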
#### Step 4: GQA-aware aggregation + majority voting

```python
# GQA: within a KV head group, a block is selected if ANY Q head selects it
if num_groups > 1:
    mask_gqa = mask.view(batch_size, num_kv_heads, num_groups, q_blocks, k_blocks)
    mask_per_kv_head = mask_gqa.any(dim=2)  # [batch, num_kv_heads, q_blocks, k_blocks]

# Majority voting: vote across KV heads and q_blocks
vote_count = mask_per_kv_head[0].float().sum(dim=0).sum(dim=0)  # [k_blocks]
total_votes = num_kv_heads * q_blocks
vote_ratio = vote_count / total_votes

# select blocks with >50% of the votes
vote_threshold = 0.5
block_selected = vote_ratio > vote_threshold
selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel]

# safety: always include the first (sink) block and the last block
if available_blocks[0] not in selected_block_ids:
    selected_block_ids.insert(0, available_blocks[0])
if available_blocks[-1] not in selected_block_ids:
    selected_block_ids.append(available_blocks[-1])
```
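The two aggregation steps can be checked standalone with plain Python lists; a minimal sketch (the real code operates on boolean tensors):

```python
def gqa_any(group_masks):
    """GQA step: within one KV-head group, a k_block counts as selected
    if ANY Q head's mask picked it."""
    return [any(col) for col in zip(*group_masks)]

def majority_vote(rows, vote_threshold=0.5):
    """rows: per-(kv_head, q_block) boolean selections over k_blocks.
    Keep a k_block when more than `vote_threshold` of the rows chose it."""
    n_rows = len(rows)
    n_blocks = len(rows[0])
    votes = [sum(row[j] for row in rows) for j in range(n_blocks)]
    return [v / n_rows > vote_threshold for v in votes]
```

For example, with three voting rows `[T,F,T]`, `[T,T,F]`, `[F,F,T]`, the first and third k_blocks get 2/3 of the votes and survive, while the second (1/3) is dropped.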
**Why majority voting?**

| Aggregation | Problem |
|---------|------|
| `any()` across all heads | density approaches 100%, losing sparsity |
| `all()` | too aggressive; may drop important blocks |
| **Majority voting (>50%)** | balances sparsity and accuracy |

Experimental results:
- Per-head density: 20-35%
- After `any()` aggregation: ~100%
- **After majority voting: ~45%**

---
### 2. compute_chunked_prefill: Attention Computation

Reuses the ring buffer pipeline implementation from `FullAttentionPolicy`:

```python
def compute_chunked_prefill(self, q, k, v, layer_id, softmax_scale,
                            offload_engine, kvcache_manager,
                            current_chunk_idx, seq, num_tokens,
                            selected_blocks) -> torch.Tensor:
```

#### Computation flow

1. **Load historical blocks** (using selected_blocks):
```python
for block_idx in range(num_blocks):
    # ring buffer pipeline: load -> wait -> compute -> next
    offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
    offload_engine.wait_slot_layer(slot)

    prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
    prev_o, prev_lse = flash_attn_with_lse(q, prev_k, prev_v, causal=False)

    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
```

2. **Compute the current chunk** (causal mask):
```python
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(q, k_curr, v_curr, causal=True)
```

3. **Merge the results**:
```python
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
```
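`merge_attention_outputs` combines partial attention results using their log-sum-exp (LSE) statistics; a scalar sketch of the underlying math (the real function works on tensors, and `merge_scalar` is an illustrative name):

```python
import math

def merge_scalar(o1, lse1, o2, lse2):
    """Merge two partial attention outputs o1, o2, computed over disjoint
    key sets, into the output over the union of the keys. lse1/lse2 are the
    log-sum-exp of the attention logits over each key set."""
    m = max(lse1, lse2)            # subtract the max for numerical stability
    w1 = math.exp(lse1 - m)
    w2 = math.exp(lse2 - m)
    o = (o1 * w1 + o2 * w2) / (w1 + w2)
    lse = m + math.log(w1 + w2)    # LSE over the union, reusable for chaining
    return o, lse
```

Because the merge also returns the combined LSE, partial results can be folded in one block at a time, which is exactly what the loop over historical blocks does.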
---

## Parameters

| Parameter | Default | Description |
|------|--------|------|
| `threshold` | 0.95 | cumulative attention threshold (tau); higher is more conservative |
| `stride` | 8 | XAttention stride reshape parameter |
| `chunk_size` | 16384 | processing chunk size during estimation |
| `block_size` | 128 | BSA block size (fixed) |

### Usage

```python
# set in the config
config.sparse_policy = SparsePolicyType.XATTN_BSA
config.sparse_threshold = 0.95

# or via the command line
python tests/test_needle.py \
    --enable-offload \
    --enable-xattn-bsa \
    --sparse-threshold 9  # divided by 10, giving 0.9
```
---

## Performance Characteristics

| Feature | Notes |
|------|------|
| **Prefill support** | ✅ fully supported |
| **Decode support** | ❌ not supported (falls back to FullAttentionPolicy) |
| **Sparsity** | ~45-55% (threshold=0.95, majority voting) |
| **Accuracy** | 100% pass on RULER NIAH |

### Limitations

1. **No decode support**: XAttention estimation needs a sufficiently long Q sequence; single-token decode does not qualify
2. **Estimation overhead**: `select_blocks` must load every K block for estimation
3. **Triton alignment**: Q/K lengths must satisfy the `stride * BLOCK_M/N` alignment requirement

---

## Comparison with Other Policies

| Policy | select_blocks | Sparsity | Decode support |
|--------|--------------|--------|-------------|
| FullAttentionPolicy | returns all blocks | 0% | ✅ |
| QuestPolicy | based on min/max keys | ~50% | ✅ |
| **XAttentionBSAPolicy** | XAttention + majority voting | ~45-55% | ❌ |

---
## Test Validation

```bash
# Needle test (32K)
CUDA_VISIBLE_DEVICES=0 python tests/test_needle.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --enable-offload \
    --enable-xattn-bsa \
    --input-len 32768

# RULER benchmark
CUDA_VISIBLE_DEVICES=0 python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.95 \
    --data-dir tests/data/ruler_niah
```
---

## Performance Benchmarks

### 128K context comparison (Llama-3.1-8B, A100 80GB)

| Policy | Density | Time | Peak memory | Accuracy |
|--------|---------|------|---------|--------|
| **Full** | 100% | 120.9s | 16.4GB (stable) | 100% |
| **XAttn BSA** | ~52% | 152.3s | 19.8GB | 100% |

### Density trend

| Chunk | Full | XAttn BSA |
|-------|------|-----------|
| 10 | 100% | 90% |
| 30 | 100% | 73% |
| 60 | 100% | 50% |
| 100 | 100% | 50% |
| 126 | 100% | 52% |

**Observation**: XAttn BSA density decreases as chunks accumulate, eventually stabilizing around ~50%.

### Performance Analysis

**Current problem**: although XAttn BSA runs at only ~52% density, it is slower than Full (152s vs 121s).

**Cause**: `select_blocks` must load every K block to estimate attention scores, so each block is loaded twice:
1. Estimation phase: load K to compute attention scores
2. Computation phase: load the selected K/V for the actual computation

**Optimization directions**:
1. Share estimation results across layers (estimate on layer 0, reuse elsewhere)
2. Sampled estimation (estimate with only a subset of K blocks)
3. Cache estimation results to avoid recomputation
---

## Memory Management

### Memory leak (fixed)

**Problem**: during a 128K prefill, GPU memory grew from 16GB to 80GB.

**Root cause**:
```python
# problematic code: accumulated but never used
self.sparse_metadata[layer_id] = attn_scores
```

Every layer of every chunk stored `attn_scores`, so memory grew continuously.

**Fix**:
```python
# 1. remove the unused sparse_metadata storage

# 2. release intermediates immediately
del attn_scores_list
del attn_scores, block_sums, mask, mask_per_kv_head, vote_count, vote_ratio, block_selected
```

**Effect of the fix**:

| Version | Memory growth | Peak |
|------|---------|------|
| Before | +64GB | 80GB |
| **After** | +4GB | 19.8GB |

### Memory Monitoring

Monitor memory with the `gpu-monitor` agent:

```bash
# start monitoring
# in Claude Code, launch the gpu-monitor agent via the Task tool

# or monitor manually
watch -n 1 'nvidia-smi --query-gpu=memory.used --format=csv,noheader -i 0'
```
---

## Density Statistics API

### Enabling statistics

```python
# statistics are updated automatically in select_blocks (layer 0 only)
# per-chunk density is emitted via logger.debug
```

### Retrieving statistics

```python
policy = XAttentionBSAPolicy(threshold=0.95)

# after running prefill...

# get the statistics
stats = policy.get_density_stats()
# {
#   "total_available_blocks": 8001,
#   "total_selected_blocks": 4160,
#   "num_chunks": 126,
#   "overall_density": 0.52
# }

# print the statistics
policy.print_density_stats()

# reset the statistics
policy.reset_stats()
```

### Enabling DEBUG logging

```python
# in test_ruler.py
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"

# sample output:
# [XAttn] chunk=30, available=30, selected=22, chunk_density=73.3%
```

---
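A minimal sketch of the bookkeeping behind these statistics (`DensityStats` is a hypothetical helper class; the real policy updates its counters inside `select_blocks`):

```python
class DensityStats:
    """Track how many blocks were selected out of those available."""

    def __init__(self):
        self.reset_stats()

    def reset_stats(self):
        self.total_available_blocks = 0
        self.total_selected_blocks = 0
        self.num_chunks = 0

    def update(self, available, selected):
        # called once per chunk with the counts from select_blocks
        self.total_available_blocks += available
        self.total_selected_blocks += selected
        self.num_chunks += 1

    def get_density_stats(self):
        density = (self.total_selected_blocks / self.total_available_blocks
                   if self.total_available_blocks else 0.0)
        return {
            "total_available_blocks": self.total_available_blocks,
            "total_selected_blocks": self.total_selected_blocks,
            "num_chunks": self.num_chunks,
            "overall_density": density,
        }
```

`overall_density` is simply selected/available over the whole prefill, which is how the ~0.52 figure above is obtained.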
## Known Issues

| Issue | Status | Notes |
|------|------|------|
| Estimation overhead too high | 🟡 to optimize | select_blocks must load every K block |
| Slower than Full | 🟡 to optimize | 152s vs 121s at 128K |
| Small memory growth | 🟢 acceptable | ~4GB, likely from Triton caches |
| No decode support | ✅ by design | falls back to FullAttentionPolicy |

---

## Related Documents

- [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md): XAttention algorithm deep dive
- [`docs/xattn_kernels_guide.md`](xattn_kernels_guide.md): Triton kernel implementation
- [`docs/sparse_policy_architecture.md`](sparse_policy_architecture.md): SparsePolicy architecture
- [`docs/sparse_policy_implementation_guide.md`](sparse_policy_implementation_guide.md): implementation guide
198
docs/xattn_kernels_guide.md
Normal file
@@ -0,0 +1,198 @@
# XAttention Kernels Guide

This document explains in detail how XAttention's two core Triton kernels work.

## Overview

XAttention uses stride sampling to quickly estimate the attention distribution, which drives block selection for sparse attention.

**Data flow**:
```
Q [batch, heads, q_len, head_dim]
K [batch, heads, kv_len, head_dim]
    ↓ flat_group_gemm_fuse_reshape (stride sampling + GEMM)
attn_scores [batch, heads, q_len/stride, kv_len/stride]
    ↓ softmax_fuse_block_sum (softmax + per-block sum)
block_sums [batch, heads, q_blocks, k_blocks]
    ↓ threshold selection
sparse_mask [batch, heads, q_blocks, k_blocks]
```

**Note**: Q and K may have different lengths (q_len ≠ kv_len), which is common in chunked prefill.
## Kernel 1: flat_group_gemm_fuse_reshape

### Purpose

Computes attention scores after the stride reshape. This is equivalent to computing the **anti-diagonal sum** of every stride×stride block of the original attention matrix.

### Signature

```python
def flat_group_gemm_fuse_reshape(
    query_states: torch.Tensor,  # [batch, heads, q_len, head_dim]
    key_states: torch.Tensor,    # [batch, heads, kv_len, head_dim]
    stride: int,
    chunk_start: int,
    chunk_end: int,
    is_causal: bool = True,
) -> torch.Tensor:               # [batch, heads, q_len/stride, kv_len/stride]
```

### Sampling scheme

```
Q sampling: (stride-1-s)::stride  (reversed)
K sampling: s::stride             (forward)

For stride=4:
Q sample positions: 3, 7, 11, 15, ... (start at position 3, step 4)
K sample positions: 0, 4, 8, 12, ...  (start at position 0, step 4)
```

### Anti-diagonal principle

For each stride×stride block of the original attention matrix:

```
stride=4 block:
       K[0] K[1] K[2] K[3]
Q[0]    ·    ·    ·    X   ← anti-diagonal
Q[1]    ·    ·    X    ·
Q[2]    ·    X    ·    ·
Q[3]    X    ·    ·    ·
```

**Output value = sum of the anti-diagonal elements**

Because:
- `Q[i]` is sampled from original position `(stride-1-i)`
- `K[j]` is sampled from original position `j`
- when `i + j = stride - 1`, the element lies exactly on the anti-diagonal
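The anti-diagonal equivalence is easy to check against a naive reference; a pure-Python sketch using the structured input from the document's verification example (scalar per-position values stand in for head_dim-sized vectors, and `antidiag_block_sums` is an illustrative helper, not the real kernel):

```python
def antidiag_block_sums(q_vals, k_vals, stride, head_dim):
    """Naive reference: sum the anti-diagonal of every stride x stride block
    of the unnormalized attention matrix A[p][q] = q_vals[p] * k_vals[q] * head_dim."""
    n_q, n_k = len(q_vals) // stride, len(k_vals) // stride
    out = [[0.0] * n_k for _ in range(n_q)]
    for a in range(n_q):
        for b in range(n_k):
            for i in range(stride):  # anti-diagonal: i + j = stride - 1
                p, q = a * stride + i, b * stride + (stride - 1 - i)
                out[a][b] += q_vals[p] * k_vals[q] * head_dim
    return out

# structured input: even positions = 1, odd positions = 2
stride, head_dim = 4, 128
vals = [1 if i % 2 == 0 else 2 for i in range(16)]
sums = antidiag_block_sums(vals, vals, stride, head_dim)

# every anti-diagonal term pairs an even with an odd position (1*2*head_dim each),
# so each block sums to (2*1 + 1*2) * (stride // 2) * head_dim = 1024
assert all(s == (2 * 1 + 1 * 2) * (stride // 2) * head_dim
           for row in sums for s in row)
```

The same structured input run through the real kernel should reproduce these values, which is how the verification example below arrives at 1024.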
### Triton constraints

**GPU-dependent BLOCK sizes**:

| GPU | Memory | BLOCK_M/N | Min q_len/kv_len |
|----------|------|-----------|-------------------|
| RTX 3090 | 24GB | 64 | stride × 64 = 256 |
| A100/H100 | ≥40GB | 128 | stride × 128 = 512 |

```python
# selection logic in the code
if props.total_memory < 30 * 1024**3:  # < 30GB
    BLOCK_M = BLOCK_N = 64
else:
    BLOCK_M = BLOCK_N = 128

assert q_len % (stride * BLOCK_M) == 0
assert kv_len % (stride * BLOCK_N) == 0
```
### Verification example

```python
# input: even positions = 1, odd positions = 2
# q_len=512, kv_len=2048, stride=4, head_dim=128

# anti-diagonal elements (stride=4):
# Q[odd]*K[even] + Q[even]*K[odd] = 2*1 + 1*2 = 4 (per pair)
# stride=4 gives 2 pairs
# multiplied by head_dim=128
# expected value: 4 * 2 * 128 = 1024

# output shape: [1, 1, 128, 512]  (512/4=128, 2048/4=512)
```
## Kernel 2: softmax_fuse_block_sum

### Purpose

Applies softmax to the output of `flat_group_gemm_fuse_reshape`, then sums per block, giving the total attention weight of each block.

### Parameters

| Parameter | Meaning |
|------|------|
| `attn_weights_slice` | input attention scores `[batch, heads, q_reshaped, k_reshaped]` |
| `reshaped_block_size` | block size (in reshaped space, = block_size / stride) |
| `segment_size` | size of the K-dimension tile processed per iteration |
| `chunk_start` | Q start position (for the causal mask) |
| `chunk_end` | Q end position |
| `real_q_len` | effective Q length (for the padding mask) |
| `scale` | scaling factor (several factors fused together) |
| `is_causal` | whether to apply the causal mask |

### Scale factor

```python
scale = log2(e) / sqrt(head_dim) / stride / norm
      = 1.4426950408889634 / sqrt(head_dim) / stride / norm
```

| Factor | Value | Role |
|------|-----|------|
| `log2(e)` | 1.4426950408889634 | Triton uses `exp2` instead of `exp`, requiring a base change |
| `1/sqrt(head_dim)` | 1/√128 | standard attention scaling |
| `1/stride` | 1/4 | normalization for the stride sampling |
| `1/norm` | varies | extra normalization factor |

**Why exp2**: Triton's `exp2` is faster than `exp` (native hardware support), so log₂(e) is folded into the scale.
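The base conversion that motivates folding log₂(e) into the scale can be verified directly; a short sketch:

```python
import math

# exp(x) == 2 ** (x * log2(e)); folding log2(e) into the scale lets the
# kernel use the hardware-native exp2 instead of exp
LOG2_E = 1.4426950408889634

x = 0.7
assert abs(math.exp(x) - 2.0 ** (x * LOG2_E)) < 1e-12
```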
### Segment size constraint

```python
assert segment_size >= reshaped_block_size
```

Reason: the kernel reshapes using `segment_size // block_size`:

```python
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
```

If `segment_size < block_size`, then `segment_size // block_size = 0`, producing an invalid dimension.

### Verification example

```python
# input: attn_scores [1, 1, 128, 512] (all values equal)
# block_size=128

# after softmax every row is uniform (equal values → uniform distribution)
# each row's contribution to one K block = block_size / kv_reshaped_len = 128/512 = 0.25
# each Q block has block_size=128 rows
# block_sum = 128 * 0.25 = 32

# output shape: [1, 1, 1, 4]  (128/128=1, 512/128=4)
```
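The uniform-softmax arithmetic in this example can be reproduced with a small pure-Python reference (scaled-down sizes instead of 128/512; `softmax_block_sums` is an illustrative helper, not the Triton kernel):

```python
import math

def softmax_block_sums(scores, block_size):
    """Softmax each row, then sum the probabilities within each K block
    and within each group of block_size rows (one Q block)."""
    n_rows, n_cols = len(scores), len(scores[0])
    q_blocks, k_blocks = n_rows // block_size, n_cols // block_size
    out = [[0.0] * k_blocks for _ in range(q_blocks)]
    for r, row in enumerate(scores):
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        z = sum(exps)
        for c, e in enumerate(exps):
            out[r // block_size][c // block_size] += e / z
    return out

# all-equal scores -> uniform softmax; block_size=4, kv_reshaped_len=16:
# per row each K block gets 4/16 = 0.25; 4 rows per Q block -> 4 * 0.25 = 1.0
scores = [[1.0] * 16 for _ in range(4)]
sums = softmax_block_sums(scores, 4)
assert all(abs(s - 4 * (4 / 16)) < 1e-9 for row in sums for s in row)
```

With the document's sizes (block_size=128, kv_reshaped_len=512), the same formula gives 128 * 0.25 = 32.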
## Complete Example

```python
# parameters
q_len = 512     # Q length
kv_len = 2048   # K/V length (may differ from q_len)
stride = 4
block_size = 128

# Step 1: flat_group_gemm_fuse_reshape
# input:  Q [1,1,512,128], K [1,1,2048,128]
# output: attn_scores [1,1,128,512]

# Step 2: softmax_fuse_block_sum
# input:  attn_scores [1,1,128,512]
# output: block_sums [1,1,1,4]
#   q_blocks = 128/128 = 1
#   k_blocks = 512/128 = 4
```

## Test Code

See `tests/test_xattn_kernels.py`, which verifies both kernels using structured data.

## Related Documents

- [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md): XAttention algorithm deep dive
- [`docs/sparse_attention_guide.md`](sparse_attention_guide.md): overview of sparse attention methods
109
findings.md
@@ -1,109 +0,0 @@
|
||||
# Findings: CUDA Graph for Offload Mode

## Discovery 1: Why Offload Mode Does Not Use CUDA Graph

**Location**: `nanovllm/engine/model_runner.py:421`

```python
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
```

**Reason**: `run_chunked_offload_decode()` sets `is_chunked_prefill=True`, which forces eager mode.

---

## Discovery 2: Current CUDA Graph Architecture

**File**: `model_runner.py:682-717`

```python
def capture_cudagraph(self):
    # Capture the full model forward for each batch size
    for bs in [1, 2, 4, 8, 16, ...]:
        with torch.cuda.graph(graph):
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
```

**Characteristics**:
- Captures the complete `model()` call (all layers)
- Shares memory via a graph pool
- Used only for decode (prefill is always eager)

---

## Discovery 3: Attention Flow in Offload Decode

**File**: `nanovllm/kvcache/sparse/full_policy.py:304-379`

**Ring Buffer Pipeline**:
```
1. Preload the first N blocks into GPU slots
2. For each block:
   a. wait_slot_layer()          # wait for H2D
   b. get_kv_for_slot()          # fetch KV
   c. flash_attn_with_lse()      # ⭐ graphable
   d. record_slot_compute_done()
   e. load_next_block()          # kick off the next H2D
   f. merge_attention_outputs()  # ⭐ graphable (but dynamic)
```

**Key point**: H2D transfers cannot be captured in a graph, but the attention computation can.
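The slot rotation in this pipeline can be illustrated host-side (a hypothetical scheduling sketch; the real engine overlaps H2D copies with compute on CUDA streams):

```python
def ring_buffer_schedule(num_blocks: int, num_slots: int):
    """Return (block_idx, slot, prefetch_block) tuples for the pipeline."""
    schedule = []
    for block_idx in range(num_blocks):
        slot = block_idx % num_slots      # reuse slots round-robin
        nxt = block_idx + num_slots       # the block that will refill this slot
        prefetch = nxt if nxt < num_blocks else None
        schedule.append((block_idx, slot, prefetch))
    return schedule

sched = ring_buffer_schedule(num_blocks=5, num_slots=2)
print(sched)
# [(0, 0, 2), (1, 1, 3), (2, 0, 4), (3, 1, None), (4, 0, None)]
```

While block `i` is being computed from its slot, the transfer for block `i + num_slots` is already in flight, which is exactly the overlap step `e` above performs.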

---

## Discovery 4: Validating Graph Reuse Feasibility

**Test**: `tests/test_chunk_attention_graph_reuse.py`

**Conclusions**:
- Only 2 graphs are needed (causal + non-causal)
- Static tensors are updated via `copy_()`
- The graphs can be reused across all layers and all chunk pairs

**Test results**:
```
Layer 0: max_diff=3.91e-03 ✅
Layer 1: max_diff=7.81e-03 ✅
Layer 2: max_diff=3.91e-03 ✅
✅ PASSED
```
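The reuse pattern — fixed-shape static buffers updated in place, then a captured computation replayed — can be mimicked with a NumPy analogy (my sketch of the `copy_()` + `graph.replay()` idiom, not real CUDA graph code):

```python
import numpy as np

# "Static" buffers: the captured graph always reads/writes these addresses.
static_in = np.zeros(4)
static_out = np.zeros(4)

def replay():
    # Stands in for graph.replay(): the same ops on the same buffers every time.
    np.multiply(static_in, 2.0, out=static_out)

for chunk in ([1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]):
    static_in[:] = chunk        # analog of static_tensor.copy_(new_kv)
    replay()
    result = static_out.copy()  # analog of cloning the graph's output
    print(result)
```

Because the captured work only ever touches the static buffers, one "graph" serves every layer and chunk pair, matching the conclusion above.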

---

## Discovery 5: Relationship Between Chunk Size and Block Size

**Observations**:
- Prefilled blocks have KV size = `block_size`
- The decode buffer's KV size ranges from `1` to `block_size` (dynamic)

**Graph strategy**:
- Prefilled blocks: fixed size = block_size, well suited to graph capture
- Decode buffer: dynamic size, better kept eager

---

## Discovery 6: Triton Operators Used

**File**: `nanovllm/ops/chunked_attention.py`

| Operator | Purpose | Graphable |
|----------|---------|-----------|
| `flash_attn_with_lse()` | Attention + LSE | ✅ |
| `merge_attention_outputs()` | Merge two attention outputs | ✅ |

Both operators are pure GPU compute and can be captured by a CUDA Graph.
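The merge step combines two partial softmax outputs using their log-sum-exp statistics. A dense NumPy reference of the standard LSE-merge update (my sketch, not the Triton kernel's exact code):

```python
import numpy as np

def merge_with_lse(o1, lse1, o2, lse2):
    """Merge two partial attention outputs given their log-sum-exp stats."""
    lse = np.logaddexp(lse1, lse2)                  # log(exp(lse1) + exp(lse2))
    w1 = np.exp(np.asarray(lse1 - lse))[..., None]  # weight of first partial
    w2 = np.exp(np.asarray(lse2 - lse))[..., None]
    return w1 * o1 + w2 * o2, lse

# Sanity check: attention over [K1 | K2] equals merging the per-half results.
rng = np.random.default_rng(0)
q = rng.normal(size=(8,))
K = rng.normal(size=(16, 8))
V = rng.normal(size=(16, 8))

def attn(q, K, V):
    s = K @ q
    lse = np.log(np.sum(np.exp(s)))
    return np.exp(s - lse) @ V, lse

o1, l1 = attn(q, K[:8], V[:8])
o2, l2 = attn(q, K[8:], V[8:])
o_merged, _ = merge_with_lse(o1, l1, o2, l2)
o_full, _ = attn(q, K, V)
assert np.allclose(o_merged, o_full)
```

This is why the pipeline can process KV blocks one at a time and still reproduce full attention exactly.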

---

## Discovery 7: Data Dependency Analysis

**Attention inputs**:
- `q`: from the current layer's QKV projection, fixed shape
- `k, v`: from a GPU slot (after the H2D transfer), shape = [1, block_size, heads, dim]

**Dependency chain**:
```
H2D(block) → wait() → get_kv() → copy_to_static() → graph.replay() → clone_output()
```

**Key point**: the graph wraps only the attention computation, not the data transfers.
@@ -48,7 +48,7 @@ class Config:
    # XAttention BSA specific parameters
    sparse_block_size: int = 128  # Block size for BSA (tokens per block)
    sparse_samples_per_chunk: int = 128  # Samples per chunk for estimation
    sparse_threshold: float = 0.9  # Cumulative attention threshold (0-1)
    sparse_threshold: float = 0.95  # Cumulative attention threshold (tau in XAttention)
    sparse_use_triton: bool = True  # Use Triton kernels for estimation
    sparse_stride: int = 8  # Stride for Q/K downsampling

@@ -37,6 +37,11 @@ class FullAttentionPolicy(SparsePolicy):
    supports_prefill = True
    supports_decode = True

    def __init__(self):
        """Initialize with statistics tracking."""
        self._stats_total_blocks = 0
        self._stats_num_chunks = 0

    def select_blocks(
        self,
        available_blocks: List[int],
@@ -44,8 +49,33 @@ class FullAttentionPolicy(SparsePolicy):
        ctx: PolicyContext,
    ) -> List[int]:
        """Return all blocks - no sparsity."""
        # Update statistics (only for layer 0 to avoid overcounting)
        if ctx.layer_id == 0 and available_blocks:
            self._stats_total_blocks += len(available_blocks)
            self._stats_num_chunks += 1
            logger.debug(f"[Full] chunk={ctx.query_chunk_idx}, blocks={len(available_blocks)}, density=100.0%")
        return available_blocks

    def reset_stats(self) -> None:
        """Reset density statistics."""
        self._stats_total_blocks = 0
        self._stats_num_chunks = 0

    def get_density_stats(self) -> dict:
        """Get density statistics."""
        return {
            "total_available_blocks": self._stats_total_blocks,
            "total_selected_blocks": self._stats_total_blocks,  # Full = all selected
            "num_chunks": self._stats_num_chunks,
            "overall_density": 1.0,  # Always 100%
        }

    def print_density_stats(self) -> None:
        """Print density statistics summary."""
        stats = self.get_density_stats()
        logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
                    f"blocks={stats['total_available_blocks']}, density=100.0%")

    def compute_chunked_prefill(
        self,
        q: torch.Tensor,
@@ -58,16 +88,17 @@ class FullAttentionPolicy(SparsePolicy):
        current_chunk_idx: int,
        seq: "Sequence",
        num_tokens: int,
        selected_blocks: List[int],
    ) -> torch.Tensor:
        """
        Compute full attention for chunked prefill.

        This method handles the complete chunked prefill flow:
        1. Get historical blocks
        2. Select blocks via select_blocks
        3. Load and compute attention to historical chunks
        4. Compute attention to current chunk
        5. Merge all results
        This method handles the chunked prefill computation:
        1. Load and compute attention to historical chunks (using selected_blocks)
        2. Compute attention to current chunk
        3. Merge all results

        Note: Block selection is done by the caller before invoking this method.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
@@ -80,6 +111,7 @@ class FullAttentionPolicy(SparsePolicy):
            current_chunk_idx: Current chunk index
            seq: Sequence object
            num_tokens: Number of tokens in current chunk
            selected_blocks: List of CPU block IDs to process (already filtered)

        Returns:
            Attention output [seq_len, num_heads, head_dim]
@@ -87,30 +119,16 @@ class FullAttentionPolicy(SparsePolicy):
        from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
                     f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
                     f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}, "
                     f"selected_blocks={len(selected_blocks)}")

        q_batched = q.unsqueeze(0)  # [1, seq_len, num_heads, head_dim]
        o_acc = None
        lse_acc = None
        compute_stream = offload_engine.compute_stream

        # Step 1: Get historical blocks
        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

        # Step 2: Apply select_blocks to filter blocks
        if cpu_block_table:
            num_chunks = current_chunk_idx + 1
            policy_ctx = PolicyContext(
                query_chunk_idx=current_chunk_idx,
                num_query_chunks=num_chunks,
                layer_id=layer_id,
                query=None,  # Prefill typically doesn't use query for selection
                is_prefill=True,
                block_size=kvcache_manager.block_size,
                total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
            )
            cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
            logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks")
        # Use the pre-selected blocks directly
        cpu_block_table = selected_blocks

        if cpu_block_table:
            load_slots = list(range(offload_engine.num_ring_slots))
@@ -200,16 +218,17 @@ class FullAttentionPolicy(SparsePolicy):
        offload_engine: "OffloadEngine",
        kvcache_manager: "KVCacheManager",
        seq: "Sequence",
        selected_blocks: List[int],
    ) -> torch.Tensor:
        """
        Compute full attention for chunked decode.

        This method handles the complete chunked decode flow:
        1. Get prefilled CPU blocks
        2. Apply select_blocks for block filtering
        3. Load blocks via pipeline (ring buffer or cross-layer)
        4. Read accumulated decode tokens from decode buffer
        5. Merge all results
        This method handles the chunked decode computation:
        1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
        2. Read accumulated decode tokens from decode buffer
        3. Merge all results

        Note: Block selection is done by the caller before invoking this method.

        Args:
            q: Query tensor [batch_size, num_heads, head_dim]
@@ -218,6 +237,7 @@ class FullAttentionPolicy(SparsePolicy):
            offload_engine: OffloadEngine for loading blocks
            kvcache_manager: KVCacheManager for block management
            seq: Sequence object
            selected_blocks: List of CPU block IDs to process (already filtered)

        Returns:
            Attention output [batch_size, 1, num_heads, head_dim]
@@ -227,40 +247,35 @@ class FullAttentionPolicy(SparsePolicy):
        # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]

        # Get only PREFILLED CPU blocks (exclude the current decode block)
        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
        # Use the pre-selected blocks directly
        cpu_block_table = selected_blocks
        if layer_id == 0:
            logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
            logger.debug(f"Decode attention: selected_blocks={len(selected_blocks)}, seq.block_table={list(seq.block_table)}")
        if not cpu_block_table:
            raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")

        # Calculate valid tokens in the last CPU block
        # CRITICAL: Use original prefill length, not current seq length!
        # CPU blocks are fixed after prefill, their content doesn't change during decode.
        # Note: We need to get all prefilled blocks to determine last_block_valid_tokens
        block_size = kvcache_manager.block_size
        num_prefill_blocks = len(cpu_block_table)
        all_prefilled_blocks = kvcache_manager.get_prefilled_cpu_blocks(seq)
        total_prefill_tokens = kvcache_manager.get_prefill_len(seq)  # Original prefill length
        last_block_valid_tokens = total_prefill_tokens % block_size
        if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
            last_block_valid_tokens = block_size  # Last block was exactly full

        # Apply sparse policy (self) for block filtering
        policy_ctx = PolicyContext(
            query_chunk_idx=0,
            num_query_chunks=1,
            layer_id=layer_id,
            query=q_batched,
            is_prefill=False,
            block_size=kvcache_manager.block_size,
            total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
        )
        cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
        # Determine if selected_blocks contains the last prefilled block
        # If not, all selected blocks are full blocks (use block_size as valid tokens)
        last_prefilled_block = all_prefilled_blocks[-1] if all_prefilled_blocks else None
        selected_contains_last = (cpu_block_table and cpu_block_table[-1] == last_prefilled_block)
        effective_last_block_tokens = last_block_valid_tokens if selected_contains_last else block_size

        # Use ring buffer pipeline for loading prefilled blocks
        load_slots = offload_engine.decode_load_slots
        o_acc, lse_acc = self._decode_ring_buffer_pipeline(
            q_batched, cpu_block_table, load_slots, offload_engine,
            block_size, last_block_valid_tokens, layer_id, softmax_scale
            block_size, effective_last_block_tokens, layer_id, softmax_scale
        )

        # Now attend to accumulated decode tokens from per-layer decode buffer

@@ -204,17 +204,20 @@ class SparsePolicy(ABC):
        current_chunk_idx: int,
        seq: "Sequence",
        num_tokens: int,
        selected_blocks: List[int],
    ) -> torch.Tensor:
        """
        Compute chunked prefill attention (complete flow).

        This is the main entry point for prefill attention computation.
        It defines the complete prefill flow:
        1. Get historical blocks
        2. Select blocks (call select_blocks)
        3. Load and compute historical blocks via offload_engine
        4. Get current chunk KV from offload_engine, compute attention
        5. Merge all results
        1. Load and compute historical blocks via offload_engine (using selected_blocks)
        2. Get current chunk KV from offload_engine, compute attention
        3. Merge all results

        Note: Block selection (select_blocks) is called by the caller (attention.py)
        before invoking this method. The selected_blocks parameter contains the
        filtered block IDs to process.

        Args:
            q: [seq_len, num_heads, head_dim] query for current chunk
@@ -227,6 +230,7 @@ class SparsePolicy(ABC):
            current_chunk_idx: current chunk index
            seq: Sequence object
            num_tokens: number of tokens in current chunk
            selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)

        Returns:
            [seq_len, num_heads, head_dim] final attention output
@@ -242,17 +246,20 @@ class SparsePolicy(ABC):
        offload_engine: "OffloadEngine",
        kvcache_manager: "KVCacheManager",
        seq: "Sequence",
        selected_blocks: List[int],
    ) -> torch.Tensor:
        """
        Compute chunked decode attention (complete flow).

        This is the main entry point for decode attention computation.
        It defines the complete decode flow:
        1. Get prefilled blocks from CPU
        2. Select blocks (call select_blocks)
        3. Load blocks via pipeline (ring buffer or cross-layer)
        4. Read accumulated decode tokens from decode buffer
        5. Merge all results
        1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
        2. Read accumulated decode tokens from decode buffer
        3. Merge all results

        Note: Block selection (select_blocks) is called by the caller (attention.py)
        before invoking this method. The selected_blocks parameter contains the
        filtered block IDs to process.

        The decode position information can be computed internally:
        - decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
@@ -265,6 +272,7 @@ class SparsePolicy(ABC):
            offload_engine: OffloadEngine for loading blocks
            kvcache_manager: KVCacheManager for block management
            seq: Sequence object
            selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)

        Returns:
            [batch_size, 1, num_heads, head_dim] final attention output

@@ -2,69 +2,508 @@
XAttention Block Sparse Attention (BSA) Policy for nano-vllm.

This module implements XAttention-inspired block sparse attention for chunked prefill.
Current implementation loads all historical blocks (FULL strategy).

Sparse selection to be implemented in next phase.
Key design:
1. Use xattn_estimate_chunked to estimate sparse block mask
2. Use BSA kernel for efficient sparse attention computation
3. Support chunked prefill with q_start_pos for correct position handling

Note: Decode phase is not supported - use FullAttentionPolicy for decode.
"""

import logging
import torch
from typing import List, Optional, Tuple
from typing import List, Tuple, TYPE_CHECKING

from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.utils.context import get_context

if TYPE_CHECKING:
    from nanovllm.kvcache.offload_engine import OffloadEngine
    from nanovllm.kvcache.manager import KVCacheManager
    from nanovllm.engine.sequence import Sequence

logger = logging.getLogger(__name__)

# Check BSA availability
try:
    from block_sparse_attn import block_sparse_attn_func
    BSA_AVAILABLE = True
except ImportError:
    BSA_AVAILABLE = False
    logger.warning("block_sparse_attn not available, XAttentionBSAPolicy will fallback to dense")

# Check xattn_estimate_chunked availability
try:
    from nanovllm.ops.xattn import xattn_estimate_chunked
    XATTN_AVAILABLE = True
except ImportError:
    XATTN_AVAILABLE = False
    logger.warning("xattn_estimate_chunked not available")


def expand_kv_for_gqa(
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    num_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Expand KV for Grouped Query Attention.

    Args:
        key_states: [B, num_kv_heads, seq_len, head_dim]
        value_states: [B, num_kv_heads, seq_len, head_dim]
        num_heads: Number of query heads

    Returns:
        Expanded (key, value) with shape [B, num_heads, seq_len, head_dim]
    """
    num_kv_heads = key_states.shape[1]
    if num_heads == num_kv_heads:
        return key_states, value_states
    num_groups = num_heads // num_kv_heads
    return (
        key_states.repeat_interleave(num_groups, dim=1),
        value_states.repeat_interleave(num_groups, dim=1),
    )


class XAttentionBSAPolicy(SparsePolicy):
    """
    XAttention Block Sparse Attention policy for chunked prefill.

    This policy uses block-level estimation to determine which KV blocks
    are important for the current chunk's queries, enabling sparse computation.
    Uses xattn_estimate_chunked to estimate sparse mask, then BSA kernel
    for efficient sparse attention computation.

    Note: Current implementation loads all historical chunks (FULL strategy).
    Sparse selection to be implemented in next phase.
    Note:
    - Only supports prefill phase (decode uses FullAttentionPolicy)
    - BSA block size is fixed at 128 tokens
    """

    supports_prefill = False  # Uses standard select_blocks interface
    supports_decode = False  # BSA is prefill-only
    requires_block_selection = False  # Selection happens at chunk level, not block level
    supports_prefill = True
    supports_decode = False  # Decode uses FullAttentionPolicy
    requires_block_selection = False  # Selection happens internally

    # BSA requires 128-token blocks
    BSA_BLOCK_SIZE = 128

    def __init__(
        self,
        threshold: float = 0.95,  # High threshold for accuracy testing
        stride: int = 8,
        chunk_size: int = 16384,
        block_size: int = 128,
        samples_per_chunk: int = 128,
        threshold: float = 0.9,
        use_triton: bool = True,
    ):
        """
        Initialize XAttention BSA policy.

        Args:
            block_size: Number of tokens per block (default: 128)
            samples_per_chunk: Number of tokens to sample from each historical chunk for estimation
            threshold: Cumulative attention threshold for chunk selection (0-1)
            threshold: Cumulative attention threshold for block selection (0-1)
                Higher values = more blocks selected = less sparse
            stride: Stride for Q/K reshape in estimation (typically 8)
            chunk_size: Processing chunk size for xattn_estimate (Triton alignment)
            block_size: BSA block size (must be 128)
            samples_per_chunk: Samples per chunk for estimation (unused)
            use_triton: Whether to use Triton kernels
        """
        self.block_size = block_size
        self.samples_per_chunk = samples_per_chunk
        self.threshold = threshold
        self.stride = stride
        self.chunk_size = chunk_size
        self.use_triton = use_triton
        self._num_heads = None  # Set during first forward

    def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
        # Sparse metadata: stores attention scores per layer
        # Dict[layer_id, Tensor[num_q_blocks, num_k_blocks]]
        self.sparse_metadata: dict = {}

        # Statistics for density tracking
        self._stats_total_available_blocks = 0
        self._stats_total_selected_blocks = 0
        self._stats_num_chunks = 0

    def select_blocks(
        self,
        available_blocks: List[int],
        offload_engine: "OffloadEngine",
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select blocks to load from CPU.
        Compute attention scores for all available blocks using flat_group_gemm,
        then use softmax_fuse_block_sum and find_blocks_chunked to select important blocks.

        Current implementation returns all blocks (FULL strategy).
        Sparse selection to be implemented in next phase.
        This method:
        1. Loads each K block from CPU
        2. Computes Q@K^T attention scores using XAttention stride reshape
        3. Applies softmax_fuse_block_sum to get block-level attention
        4. Uses find_blocks_chunked to select blocks based on threshold

        Args:
            available_blocks: List of all available CPU block IDs
            ctx: Policy context with query info, chunk index, etc.
            available_blocks: List of CPU block IDs
            offload_engine: OffloadEngine for loading blocks
            ctx: PolicyContext with query tensor and metadata

        Returns:
            List of selected block IDs to load
            Selected block IDs based on attention threshold
        """
        # Current: Return all blocks (FULL strategy)
        # TODO: Implement sparse selection based on query attention estimation
        return available_blocks
        if not available_blocks or ctx.query is None:
            return available_blocks

        from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum, find_blocks_chunked
        import math

        layer_id = ctx.layer_id
        q = ctx.query  # [seq_len, num_heads, head_dim]

        # Convert Q to [batch, heads, seq_len, head_dim]
        # q: [seq_len, num_heads, head_dim] -> [1, num_heads, seq_len, head_dim]
        Q = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, seq_len, head_dim]

        num_heads = Q.shape[1]
        head_dim = Q.shape[3]
        q_len = Q.shape[2]

        # flat_group_gemm requires q_len to be divisible by stride * BLOCK_M (typically 8 * 128 = 1024)
        # Pad Q if necessary
        BLOCK_M = 128  # Triton block size
        alignment = self.stride * BLOCK_M
        if q_len < alignment:
            # Q too short, skip estimation and return all blocks
            logger.debug(f"[XAttn] select_blocks: q_len={q_len} < alignment={alignment}, skipping estimation")
            return available_blocks

        # Pad Q to alignment
        padded_q_len = ((q_len + alignment - 1) // alignment) * alignment
        if padded_q_len != q_len:
            pad_size = padded_q_len - q_len
            Q = torch.nn.functional.pad(Q, (0, 0, 0, pad_size), value=0)

        q_reshaped_len = padded_q_len // self.stride

        # Use a single slot for loading (synchronous mode for simplicity)
        slot = 0
        attn_scores_list = []

        # Get block size from context
        block_size = ctx.block_size  # tokens per CPU block (e.g., 1024)
        reshaped_block_size = block_size // self.stride  # e.g., 1024/8 = 128

        for cpu_block_id in available_blocks:
            # Load K block from CPU to GPU
            offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
            offload_engine.wait_slot_layer(slot)

            # Get KV: [1, block_size, num_kv_heads, head_dim]
            k_block, _ = offload_engine.get_kv_for_slot(slot)

            # Convert K to [batch, heads, k_len, head_dim]
            # k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim]
            K_chunk = k_block.transpose(1, 2)

            # Handle GQA: expand K heads to match Q heads
            num_kv_heads = K_chunk.shape[1]
            if num_heads != num_kv_heads:
                num_groups = num_heads // num_kv_heads
                K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)

            # Pad K if necessary (k_len must be divisible by stride * BLOCK_N)
            k_len = K_chunk.shape[2]
            BLOCK_N = 128
            k_alignment = self.stride * BLOCK_N
            if k_len < k_alignment:
                # K too short, pad it
                pad_size = k_alignment - k_len
                K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)

            # Compute attention scores using flat_group_gemm_fuse_reshape
            # Output: [batch, heads, q_len/stride, k_len/stride]
            attn_chunk = flat_group_gemm_fuse_reshape(
                Q, K_chunk, self.stride,
                chunk_start=0,
                chunk_end=q_reshaped_len,
                is_causal=False
            )
            attn_scores_list.append(attn_chunk)

            # Mark slot as done for reuse
            offload_engine.record_slot_compute_done(slot)

        # Concatenate all attention scores along K dimension
        # Each chunk: [1, heads, q_reshaped_len, block_reshaped_len]
        # Result: [1, heads, q_reshaped_len, total_k_reshaped_len]
        if not attn_scores_list:
            return available_blocks

        attn_scores = torch.cat(attn_scores_list, dim=-1)
        # Free intermediate list immediately
        del attn_scores_list

        # Step 2: Apply softmax_fuse_block_sum to get block-level attention
        # block_size = reshaped_block_size so each CPU block maps to exactly 1 output block
        # This ensures block_sums.shape[-1] == num_available_blocks (1:1 mapping)
        norm = 1.0  # Normalization factor
        scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm  # log2(e) with scaling
        segment_size = min(4096, reshaped_block_size)

        block_sums = softmax_fuse_block_sum(
            attn_scores,
            reshaped_block_size,  # Use CPU block size in reshaped space (1024/8=128)
            segment_size,
            chunk_start=0,
            chunk_end=q_reshaped_len,
            real_q_len=q_reshaped_len,
            scale=scale,
            is_causal=False,  # Historical blocks are all before current chunk
        )
        # block_sums shape: [batch, heads, q_blocks, k_blocks]
        # where k_blocks == len(available_blocks) (1:1 mapping with CPU blocks)

        # Step 3: Use find_blocks_chunked to get selection mask
        # current_index = 0 since we're looking at historical blocks only
        mask = find_blocks_chunked(
            block_sums,
            current_index=0,
            threshold=self.threshold,
            num_to_choose=None,
            decoding=False,
            mode="prefill",
            causal=False,  # Historical blocks don't need causal mask
        )
        # mask shape: [batch, num_heads, q_blocks, k_blocks] - boolean
        # where k_blocks == len(available_blocks)

        # GQA-aware aggregation:
        # For GQA, multiple Q heads share one KV head. We need to select a block
        # if ANY Q head within the same KV head group selects it.
        # mask: [batch, num_heads, q_blocks, k_blocks]
        # Reshape to [batch, num_kv_heads, num_groups, q_blocks, k_blocks]
        batch_size, num_q_heads, q_blocks, k_blocks = mask.shape
        # num_kv_heads was set in the K loading loop above (line ~199)
        # num_groups = num_heads // num_kv_heads (for GQA)
        num_groups = num_heads // num_kv_heads if num_heads != num_kv_heads else 1

        if num_groups > 1:
            # Reshape: [batch, num_kv_heads, num_groups, q_blocks, k_blocks]
            mask_gqa = mask.view(batch_size, num_kv_heads, num_groups, q_blocks, k_blocks)
            # Aggregate within each KV head group: any Q head selects -> KV head selects
            mask_per_kv_head = mask_gqa.any(dim=2)  # [batch, num_kv_heads, q_blocks, k_blocks]
        else:
            mask_per_kv_head = mask  # [batch, num_heads, q_blocks, k_blocks]

        # Aggregate across KV heads and q_blocks using majority voting
        # Instead of any(), use voting: select if >50% of kv_heads select it
        # mask_per_kv_head: [batch, num_kv_heads, q_blocks, k_blocks]
        # Sum across kv_heads and q_blocks to get vote count per k_block
        vote_count = mask_per_kv_head[0].float().sum(dim=0).sum(dim=0)  # [k_blocks]
        total_votes = num_kv_heads * q_blocks
        vote_ratio = vote_count / total_votes

        # Select blocks with >50% votes (majority voting)
        vote_threshold = 0.5
        block_selected = vote_ratio > vote_threshold
        selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel]

        # Always include first block (sink) and last block for safety
        if available_blocks and available_blocks[0] not in selected_block_ids:
            selected_block_ids.insert(0, available_blocks[0])
        if available_blocks and available_blocks[-1] not in selected_block_ids:
            selected_block_ids.append(available_blocks[-1])

        # Update statistics (only for layer 0 to avoid overcounting)
        if layer_id == 0 and available_blocks:
            self._stats_total_available_blocks += len(available_blocks)
            self._stats_total_selected_blocks += len(selected_block_ids)
            self._stats_num_chunks += 1

            # Log per-chunk density
            chunk_density = len(selected_block_ids) / len(available_blocks)
            logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, available={len(available_blocks)}, "
                         f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")

        # Free intermediate tensors to prevent memory leak
        del attn_scores, block_sums, mask, mask_per_kv_head, vote_count, vote_ratio, block_selected

        return selected_block_ids

    def compute_chunked_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
        offload_engine: "OffloadEngine",
        kvcache_manager: "KVCacheManager",
        current_chunk_idx: int,
        seq: "Sequence",
        num_tokens: int,
        selected_blocks: List[int],
    ) -> torch.Tensor:
        """
        Compute attention for chunked prefill using XAttention sparse block selection.

        This method handles the chunked prefill computation:
        1. Load and compute attention to historical chunks (using selected_blocks)
        2. Compute attention to the current chunk
        3. Merge all results

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
            v: Value tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
            layer_id: Current layer index
            softmax_scale: Softmax scaling factor
            offload_engine: OffloadEngine for loading blocks
            kvcache_manager: KVCacheManager for block management
            current_chunk_idx: Current chunk index
            seq: Sequence object
            num_tokens: Number of tokens in the current chunk
            selected_blocks: List of CPU block IDs selected by select_blocks

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        q_batched = q.unsqueeze(0)  # [1, seq_len, num_heads, head_dim]
        o_acc = None
        lse_acc = None
        compute_stream = offload_engine.compute_stream

        # Use the pre-selected blocks directly
        cpu_block_table = selected_blocks

        if cpu_block_table:
            load_slots = list(range(offload_engine.num_ring_slots))
            num_blocks = len(cpu_block_table)

            if len(load_slots) == 1:
                # Only 1 slot - use synchronous mode
                slot = load_slots[0]
                for block_idx in range(num_blocks):
                    cpu_block_id = cpu_block_table[block_idx]
                    offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
                    offload_engine.wait_slot_layer(slot)

                    with torch.cuda.stream(compute_stream):
                        prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
                        prev_o, prev_lse = flash_attn_with_lse(
                            q_batched, prev_k, prev_v,
                            softmax_scale=softmax_scale,
                            causal=False,
                        )
                        if o_acc is None:
                            o_acc, lse_acc = prev_o, prev_lse
                        else:
                            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
                        offload_engine.record_slot_compute_done(slot)
            else:
                # Multiple slots - use pipeline
                num_slots = len(load_slots)
                num_preload = min(num_slots, num_blocks)
                for i in range(num_preload):
                    offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])

                for block_idx in range(num_blocks):
                    current_slot = load_slots[block_idx % num_slots]

                    offload_engine.wait_slot_layer(current_slot)

                    with torch.cuda.stream(compute_stream):
                        prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
                        prev_o, prev_lse = flash_attn_with_lse(
                            q_batched, prev_k, prev_v,
                            softmax_scale=softmax_scale,
                            causal=False,
                        )
                        offload_engine.record_slot_compute_done(current_slot)

                        if o_acc is None:
                            o_acc, lse_acc = prev_o, prev_lse
                        else:
                            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

                    # Issue next transfer
                    next_block_idx = block_idx + num_slots
                    if next_block_idx < num_blocks:
                        next_slot = load_slots[next_block_idx % num_slots]
                        next_cpu_block_id = cpu_block_table[next_block_idx]
                        offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)

        # Compute attention to current chunk (causal mask)
        with torch.cuda.stream(compute_stream):
            k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
            current_o, current_lse = flash_attn_with_lse(
                q_batched, k_curr, v_curr,
                softmax_scale=softmax_scale,
                causal=True,
            )

        # Merge historical and current attention
        with torch.cuda.stream(compute_stream):
            if o_acc is None:
                final_o = current_o
            else:
                final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)

        # Sync default stream with compute_stream before returning
        torch.cuda.default_stream().wait_stream(compute_stream)

        # Remove batch dimension: [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim]
        return final_o.squeeze(0)

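The pipeline above accumulates per-chunk attention outputs and combines them through their log-sum-exp (LSE) values. A minimal CPU sketch of that merge rule, using numpy with illustrative stand-ins (`attn_with_lse`, `merge_outputs`) for the real `flash_attn_with_lse` / `merge_attention_outputs`:

```python
import numpy as np

def attn_with_lse(q, k, v, scale):
    # Naive attention over one KV chunk; also returns the row-wise log-sum-exp.
    s = (q @ k.T) * scale                       # [q_len, kv_len]
    m = s.max(axis=-1, keepdims=True)
    p = np.exp(s - m)
    lse = np.log(p.sum(axis=-1)) + m[:, 0]      # [q_len]
    o = (p / p.sum(axis=-1, keepdims=True)) @ v
    return o, lse

def merge_outputs(o1, lse1, o2, lse2):
    # Reweight two partial softmax outputs by their LSEs; the result equals
    # exact attention over the union of both KV chunks.
    lse = np.logaddexp(lse1, lse2)
    w1, w2 = np.exp(lse1 - lse)[:, None], np.exp(lse2 - lse)[:, None]
    return o1 * w1 + o2 * w2, lse

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(4, 8)), rng.normal(size=(10, 8)), rng.normal(size=(10, 8))
scale = 8 ** -0.5
o_a, lse_a = attn_with_lse(q, k[:6], v[:6], scale)
o_b, lse_b = attn_with_lse(q, k[6:], v[6:], scale)
o_merged, _ = merge_outputs(o_a, lse_a, o_b, lse_b)
o_full, _ = attn_with_lse(q, k, v, scale)
assert np.allclose(o_merged, o_full)
```

Because the merge is associative, the loop can fold chunks in one at a time (`o_acc`, `lse_acc`) in any order.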
    def compute_chunked_decode(
        self,
        q: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
        offload_engine: "OffloadEngine",
        kvcache_manager: "KVCacheManager",
        seq: "Sequence",
        selected_blocks: List[int],
    ) -> torch.Tensor:
        """XAttention does not support the decode phase."""
        raise NotImplementedError(
            "XAttentionBSAPolicy does not support decode phase. "
            "Use FullAttentionPolicy for decode."
        )

    def reset(self) -> None:
        """Reset policy state and clear sparse metadata."""
        self.sparse_metadata.clear()
        # Don't reset statistics here - they accumulate across the entire prefill

    def reset_stats(self) -> None:
        """Reset density statistics."""
        self._stats_total_available_blocks = 0
        self._stats_total_selected_blocks = 0
        self._stats_num_chunks = 0

    def get_density_stats(self) -> dict:
        """Get density statistics."""
        if self._stats_total_available_blocks == 0:
            return {
                "total_available_blocks": 0,
                "total_selected_blocks": 0,
                "num_chunks": 0,
                "overall_density": 0.0,
            }
        return {
            "total_available_blocks": self._stats_total_available_blocks,
            "total_selected_blocks": self._stats_total_selected_blocks,
            "num_chunks": self._stats_num_chunks,
            "overall_density": self._stats_total_selected_blocks / self._stats_total_available_blocks,
        }

    def print_density_stats(self) -> None:
        """Print a density statistics summary."""
        stats = self.get_density_stats()
        logger.info(f"[XAttn BSA] Density Stats: chunks={stats['num_chunks']}, "
                    f"available={stats['total_available_blocks']}, "
                    f"selected={stats['total_selected_blocks']}, "
                    f"density={stats['overall_density']:.1%}")

    def __repr__(self) -> str:
        return f"XAttentionBSAPolicy(threshold={self.threshold}, stride={self.stride})"

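The density bookkeeping above reduces to a running ratio of selected to available blocks. A minimal standalone sketch (class and attribute names here are illustrative, not the policy's actual fields):

```python
class DensityStats:
    """Accumulate block-selection density across prefill chunks."""

    def __init__(self):
        self.available = 0  # total candidate blocks seen
        self.selected = 0   # total blocks actually selected
        self.chunks = 0     # number of chunks recorded

    def update(self, available, selected):
        self.available += available
        self.selected += selected
        self.chunks += 1

    @property
    def density(self):
        # Guard against division by zero before any chunk is recorded
        return self.selected / self.available if self.available else 0.0

stats = DensityStats()
stats.update(100, 25)
stats.update(100, 15)
assert stats.chunks == 2 and abs(stats.density - 0.2) < 1e-12
```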
@@ -5,6 +5,7 @@ from torch import nn
 
 from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 from nanovllm.utils.context import get_context
+from nanovllm.kvcache.sparse.policy import PolicyContext
 
 logger = logging.getLogger(__name__)
 
@@ -197,11 +198,30 @@ class Attention(nn.Module):
         if sparse_policy is None:
             raise RuntimeError("sparse_policy is required for chunked prefill")
 
+        # Step 1: Get historical CPU blocks
+        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
+
+        # Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
+        selected_blocks = []
+        if cpu_block_table:
+            num_chunks = current_chunk_idx + 1
+            policy_ctx = PolicyContext(
+                query_chunk_idx=current_chunk_idx,
+                num_query_chunks=num_chunks,
+                layer_id=self.layer_id,
+                query=q,  # Pass query for sparse policies that need it
+                is_prefill=True,
+                block_size=kvcache_manager.block_size,
+                total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
+            )
+            selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
+            logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
+
         # [DEBUG] Verify execution path
         logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
                      f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
 
-        # Delegate all computation to policy (no flash_attn or merge calls here!)
+        # Delegate computation to policy with pre-selected blocks
         final_o = sparse_policy.compute_chunked_prefill(
             q, k, v,
             self.layer_id,
@@ -211,6 +231,7 @@ class Attention(nn.Module):
             current_chunk_idx,
             seq,
             num_tokens,
+            selected_blocks,
         )
 
         torch.cuda.nvtx.range_pop()  # ChunkedPrefill
@@ -258,14 +279,36 @@ class Attention(nn.Module):
             raise RuntimeError("sparse_policy is required for chunked decode")
 
         # Check if policy supports decode phase
+        # If not, fallback to FullAttentionPolicy (e.g., XAttentionBSAPolicy only supports prefill)
         if not sparse_policy.supports_decode:
-            raise RuntimeError(f"{sparse_policy} does not support decode phase")
+            from nanovllm.kvcache.sparse import FullAttentionPolicy
+            sparse_policy = FullAttentionPolicy()
+            logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
+                         f"falling back to FullAttentionPolicy")
+
+        # Step 1: Get prefilled CPU blocks
+        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
+
+        # Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_decode)
+        selected_blocks = []
+        if cpu_block_table:
+            policy_ctx = PolicyContext(
+                query_chunk_idx=0,
+                num_query_chunks=1,
+                layer_id=self.layer_id,
+                query=q,  # Pass query for sparse policies that need it
+                is_prefill=False,
+                block_size=kvcache_manager.block_size,
+                total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
+            )
+            selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
+            logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
 
         # [DEBUG] Verify execution path
         logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
                      f"policy={sparse_policy}, layer={self.layer_id}")
 
-        # Delegate all computation to policy (no flash_attn or merge calls here!)
+        # Delegate computation to policy with pre-selected blocks
         return sparse_policy.compute_chunked_decode(
             q,
             self.layer_id,
@@ -273,4 +316,5 @@ class Attention(nn.Module):
             offload_engine,
             kvcache_manager,
             seq,
+            selected_blocks,
         )
 
@@ -419,7 +419,9 @@ def flat_group_gemm_fuse_reshape(
     assert key_states.shape[1] == num_heads
     assert key_states.shape[3] == head_dim
 
-    output = torch.empty(
+    # Use zeros instead of empty to handle causal early-exit in the kernel
+    # (some blocks may never be written due to the causal mask optimization)
+    output = torch.zeros(
         (batch_size, num_heads, q_len // stride, kv_len // stride),
         dtype=query_states.dtype,
         device=query_states.device
@@ -1067,6 +1069,7 @@ def xattn_estimate_chunked(
     )
 
     # Softmax + block sum
+    # segment_size should match the standard xattn_estimate for consistency
    attn_sum = softmax_fuse_block_sum(
        attn_weights,
        reshaped_block_size,
@@ -1082,6 +1085,14 @@ def xattn_estimate_chunked(
         attn_sum = attn_sum[:, :, :q_block_num, :k_block_num]
     else:
         # PyTorch fallback implementation
+        # Match the Triton kernel exactly for consistency.
+        #
+        # The Triton kernel uses:
+        #   1. exp2 (base-2 exponential) for softmax
+        #   2. a scale factor that includes log2(e) = 1.4426950408889634
+        #   3. causal mask: q_pos >= k_pos (not q_pos + 1 > k_pos)
+        #   4. chunk_start for global Q position tracking
+
         # Reshape K: interleave positions and concatenate head dims
         reshaped_key = torch.cat(
             [(key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
@@ -1093,49 +1104,58 @@ def xattn_estimate_chunked(
             dim=-1,
         )
 
+        # Use the same scale as Triton: it includes log2(e) for exp2 compatibility
+        # Triton: scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
+        scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
+
+        # Convert to float32 for numerical stability (matching Triton)
+        reshaped_query_f32 = reshaped_query.to(torch.float32)
+        reshaped_key_f32 = reshaped_key.to(torch.float32)
+
         # Compute attention weights: (B, H, q_len/stride, k_len/stride)
         attn_weights = torch.matmul(
-            reshaped_query, reshaped_key.transpose(2, 3)
-        ) / math.sqrt(head_dim) / stride / norm
+            reshaped_query_f32, reshaped_key_f32.transpose(2, 3)
+        ) * scale
 
-        # Apply causal mask
+        # Apply causal mask (matching Triton's logic exactly)
         if causal:
-            reshaped_q_positions = reshaped_q_len
-            causal_mask = torch.zeros(
-                (batch_size, num_heads, reshaped_q_positions, reshaped_k_len),
-                device=key_states.device,
-                dtype=attn_weights.dtype,
+            # Triton uses: offs_q = chunk_start + block_id * block_size + arange(0, block_size)
+            # chunk_start = q_start_block * reshaped_block_size
+            chunk_start = q_start_block * reshaped_block_size
+
+            # Create position indices in reshaped space
+            q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
+            k_positions = torch.arange(reshaped_k_len, device=attn_weights.device)
+
+            # Triton causal mask: q_pos >= k_pos
+            causal_mask = q_positions[:, None] >= k_positions[None, :]  # (reshaped_q_len, reshaped_k_len)
+
+            # Apply causal mask: set future positions to -1e6 (matching Triton)
+            attn_weights = attn_weights.masked_fill(
+                ~causal_mask.unsqueeze(0).unsqueeze(0), -1e6
             )
 
-        # Mask out padding in K
-        if k_pad > 0:
-            causal_mask[:, :, :, -(k_pad // stride):] = float("-inf")
+        # Softmax using exp2 (matching Triton exactly)
+        # Triton: X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
+        # All computation in float32
+        attn_max = attn_weights.max(dim=-1, keepdim=True).values
+        attn_weights_shifted = attn_weights - attn_max
+        attn_exp2 = torch.exp2(attn_weights_shifted)
+        attn_sum_exp2 = attn_exp2.sum(dim=-1, keepdim=True)
+        attn_weights = attn_exp2 / attn_sum_exp2
 
-        # Mask out future positions
-        q_start_reshaped = q_start_pos // stride
-        for q_idx in range(reshaped_q_positions):
-            q_pos_reshaped = q_start_reshaped + q_idx
-            if q_pos_reshaped + 1 < reshaped_k_len:
-                causal_mask[:, :, q_idx, q_pos_reshaped + 1:] = float("-inf")
+        # Mask for valid Q positions (matching Triton's sum_mask)
+        # Triton: sum_mask = offs_q[:, None] < real_q_len
+        # real_q_len = chunk_start + valid_q_reshaped
+        chunk_start = q_start_block * reshaped_block_size
+        real_q_len = chunk_start + valid_q_reshaped
+        q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
+        valid_q_mask = q_positions < real_q_len  # (reshaped_q_len,)
 
-        # Handle padding in Q
-        if q_pad > 0:
-            q_pad_reshaped = q_pad // stride
-            if q_pad_reshaped > 0:
-                causal_mask[:, :, -q_pad_reshaped:, :] = float("-inf")
+        # Zero out invalid Q positions
+        attn_weights = attn_weights * valid_q_mask.view(1, 1, -1, 1).float()
 
-        attn_weights = attn_weights + causal_mask
-
-        # Apply softmax
-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-
-        # Zero out padded Q positions
-        if q_pad > 0:
-            q_pad_reshaped = q_pad // stride
-            if q_pad_reshaped > 0:
-                attn_weights[:, :, -q_pad_reshaped:, :] = 0
-
-        # Aggregate to block level
+        # Aggregate to block level (keep in float32)
         attn_sum = attn_weights.view(
             batch_size,
             num_heads,
@@ -1145,6 +1165,9 @@ def xattn_estimate_chunked(
             reshaped_block_size,
         ).sum(dim=-1).sum(dim=-2)
 
+    # Convert back to the input dtype for consistency
+    attn_sum = attn_sum.to(query_states.dtype)
+
     # Find blocks that exceed threshold
     simple_mask = find_blocks_chunked(
         attn_sum,
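The exp2-based softmax in the fallback above is numerically identical to a standard softmax once log2(e) is folded into the scale, since exp2(x · log2 e) = exp(x) and the max-shift cancels in the ratio. A quick standalone check in numpy:

```python
import numpy as np

LOG2E = 1.4426950408889634  # log2(e), the factor folded into the Triton scale

def softmax_exp2(scores):
    # Base-2 softmax: pre-multiply by log2(e), shift by the row max, use exp2.
    x = scores * LOG2E
    x = x - x.max(axis=-1, keepdims=True)
    p = np.exp2(x)
    return p / p.sum(axis=-1, keepdims=True)

def softmax_ref(scores):
    # Standard base-e softmax for comparison
    x = scores - scores.max(axis=-1, keepdims=True)
    p = np.exp(x)
    return p / p.sum(axis=-1, keepdims=True)

scores = np.random.default_rng(0).normal(size=(4, 16))
assert np.allclose(softmax_exp2(scores), softmax_ref(scores))
```

GPUs expose a fast exp2 instruction, which is why kernels prefer this formulation over calling exp directly.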
progress.md (55 lines, deleted)
@@ -1,55 +0,0 @@
# Progress: CUDA Graph for Offload Mode

## Session: 2026-01-22

### Investigation ✅ Complete

**Completed investigation**:

1. ✅ Analyzed the CUDA Graph implementation in `model_runner.py`
   - `capture_cudagraph()`: captures the full model forward for different batch sizes
   - `run_model()`: chooses eager vs. graph via `is_chunked_prefill`

2. ✅ Analyzed the offload decode flow
   - `run_chunked_offload_decode()` sets `is_chunked_prefill=True`
   - As a result, eager mode is always used

3. ✅ Analyzed the ring buffer pipeline
   - `_decode_ring_buffer_pipeline()` contains H2D transfers + attention compute
   - H2D cannot be graphed; attention can

4. ✅ Verified the graph reuse strategy
   - Created `test_chunk_attention_graph_reuse.py`
   - Confirmed that 2 graphs can be reused across all layers

### Planning ✅ Complete

- ✅ Created `task_plan.md`
- ✅ Created `findings.md`
- ✅ Created `progress.md`

### Next: Implementation

**Phase 1**: Add graph capture to OffloadEngine
- [ ] Add `capture_attention_graphs()` to `offload_engine.py`
- [ ] Add `attention_graph_causal` and `attention_graph_non_causal` attributes

**Phase 2**: Modify the ring buffer pipeline
- [ ] Use graph replay in `_decode_ring_buffer_pipeline()`
- [ ] Keep H2D and merge eager

**Phase 3**: Testing
- [ ] Run the needle test to verify correctness
- [ ] Compare performance

---

## File Inventory

| File | Status | Notes |
|------|--------|------|
| `tests/test_chunk_attention_graph.py` | ✅ Committed | Pre-allocated chunk pair graphs test |
| `tests/test_chunk_attention_graph_reuse.py` | Pending commit | Graph reuse verification |
| `task_plan.md` | ✅ Created | Implementation plan |
| `findings.md` | ✅ Created | Investigation findings |
| `progress.md` | ✅ Created | Progress log |
task_plan.md (357 lines, deleted)
@@ -1,357 +0,0 @@
# Task Plan: CUDA Graph Optimization for Offload-Mode Decode

## Goal

Add CUDA Graph support to nanovllm's CPU offload mode to accelerate computation in the decode phase.

## Problem Analysis

### Full structure of a Transformer layer

```
Qwen3DecoderLayer.forward:
├── input_layernorm (RMSNorm)            # ✅ pure GPU
├── self_attn:
│   ├── qkv_proj (Linear)                # ✅ pure GPU
│   ├── q_norm, k_norm (RMSNorm)         # ✅ pure GPU
│   ├── rotary_emb                       # ✅ pure GPU
│   ├── attn._chunked_decode_attention:  # ⚠️ contains CPU→GPU
│   │   ├── H2D transfer                 # ❌ cannot be graphed
│   │   ├── flash_attn_with_lse          # ✅ can be graphed
│   │   └── merge                        # ✅ pure GPU
│   └── o_proj (Linear)                  # ✅ pure GPU
├── post_attention_layernorm             # ✅ pure GPU
└── mlp (FFN: gate, up, down)            # ✅ pure GPU
```

**Core problem**: the H2D transfer sits in the middle of attention, which breaks graph capture for the whole layer.

### Candidate approaches

| Approach | Description | Pros | Cons |
|------|------|------|------|
| A. Segmented graphs | Split each layer into pre/post-attention segments | Broad coverage | Large change; layer execution must be split |
| B. Graph attention only | Only optimize flash_attn_with_lse | Small change | Limited benefit |
| C. Restructure execution | Fully rewrite the model forward | Best result | Huge effort |

### Recommended: Approach A (segmented graphs)

Split each layer into two graphs:
1. **pre_attention_graph**: `norm → qkv_proj → q/k_norm → rotary`
2. **post_attention_graph**: `o_proj → norm → FFN`

The `_chunked_decode_attention` in between stays eager (it contains H2D), but its inner `flash_attn_with_lse` uses a graph.

---

## Current State Analysis

### Existing CUDA Graph implementation

**File**: `nanovllm/engine/model_runner.py`

| Method | Lines | Purpose |
|------|------|------|
| `capture_cudagraph()` | 682-717 | Captures the full model forward for different batch sizes |
| `run_model()` | 415-436 | Decides between eager execution and graph replay |

**Key logic** (`run_model`):
```python
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
```

**Problem**: `run_chunked_offload_decode` sets `is_chunked_prefill=True`, so **eager mode is always used**.

### Offload decode flow

**File**: `nanovllm/kvcache/sparse/full_policy.py`

`_decode_ring_buffer_pipeline()` (L304-379):
```
for block in cpu_blocks:
    1. wait_slot_layer(slot)           # wait for H2D to finish
    2. k, v = get_kv_for_slot(slot)    # fetch KV
    3. o, lse = flash_attn_with_lse()  # ⭐ pure GPU compute
    4. record_slot_compute_done(slot)  # mark compute done
    5. load_next_block()               # kick off the next H2D
    6. merge_attention_outputs()       # ⭐ pure GPU compute
```

**Graphable parts**:
- `flash_attn_with_lse()` - pure GPU compute
- Not graphable: H2D transfers, dynamic merge

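The slot-reuse pattern of that loop (preload every slot, compute block b on slot b % num_slots, then immediately refill the slot just consumed with block b + num_slots) can be simulated without a GPU. A small sketch with illustrative names:

```python
def ring_buffer_schedule(num_blocks, num_slots):
    """Return the (event, block, slot) order the pipeline issues."""
    events = []
    # Preload: fill every slot before the compute loop starts
    for i in range(min(num_slots, num_blocks)):
        events.append(("load", i, i % num_slots))
    for b in range(num_blocks):
        slot = b % num_slots
        events.append(("compute", b, slot))     # overlaps with loads into other slots
        nxt = b + num_slots
        if nxt < num_blocks:
            events.append(("load", nxt, slot))  # reuse the slot just consumed
    return events

print(ring_buffer_schedule(4, 2))
```

With 2 slots, every compute on one slot is overlapped by the H2D load filling the other slot, which is what hides the transfer latency.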
## Verification Results

**Test file**: `tests/test_chunk_attention_graph_reuse.py`

| Test | Result |
|------|------|
| 2 graphs reused across all layers and all chunks | ✅ PASSED |
| copy_() updates static tensors | ✅ Works |
| Eager merge | ✅ Accepted by user |

**Conclusion**: only 2 graphs are needed (causal + non-causal), reused via copy_().

---

## Implementation Plan (Approach A: segmented graphs)

### Architecture

```
Per-layer execution flow (offload decode):
┌─────────────────────────────────────────────────────────────┐
│ PRE-ATTENTION GRAPH (reused across all layers)              │
│   input_layernorm → qkv_proj → q/k_norm → rotary → split Q  │
└─────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────┐
│ CHUNKED ATTENTION (eager + partial graph)                   │
│   for block in cpu_blocks:                                  │
│       H2D transfer (eager)                                  │
│       flash_attn_with_lse (GRAPH - 2 reusable)              │
│       merge (eager)                                         │
│   decode_buffer attention (eager)                           │
└─────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────┐
│ POST-ATTENTION GRAPH (reused across all layers)             │
│   o_proj → post_layernorm → gate_proj → up_proj → SiLU      │
│   → down_proj → residual                                    │
└─────────────────────────────────────────────────────────────┘
```

**Total graphs needed**:
- 1 pre_attention_graph (shared by all layers)
- 2 attention graphs (causal + non-causal, shared by all layers)
- 1 post_attention_graph (shared by all layers)
- **Total: 4 graphs**

---

### Phase 1: Split DecoderLayer execution

**Goal**: split `Qwen3DecoderLayer.forward` into three independently callable segments

**File to modify**: `nanovllm/models/qwen3.py`

**New methods**:
```python
class Qwen3DecoderLayer:
    def forward_pre_attention(self, positions, hidden_states, residual):
        """Pre-attention: norm → qkv → rotary → returns q, k, v"""
        if residual is None:
            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        qkv = self.self_attn.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        q = self.self_attn.q_norm(q)
        k = self.self_attn.k_norm(k)
        q, k = self.self_attn.rotary_emb(positions, q, k)
        return q, k, v, hidden_states, residual

    def forward_post_attention(self, attn_output, hidden_states, residual):
        """Post-attention: o_proj → norm → FFN"""
        output = self.self_attn.o_proj(attn_output.flatten(1, -1))
        hidden_states, residual = self.post_attention_layernorm(output, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
```

**Status**: `pending`

---

### Phase 2: Capture the pre/post-attention graphs

**Goal**: capture graphs for pre_attention and post_attention

**File to modify**: `nanovllm/engine/model_runner.py`

**New method**: `capture_offload_layer_graphs()`

```python
def capture_offload_layer_graphs(self):
    """Capture layer graphs for offload mode."""
    # Use any one layer as the template (all layers share the same structure)
    layer = self.model.model.layers[0]

    # Static tensors
    static_hidden = torch.zeros(1, self.hidden_size, ...)
    static_residual = torch.zeros(1, self.hidden_size, ...)
    static_positions = torch.zeros(1, ...)

    # Pre-attention graph
    self.pre_attn_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(self.pre_attn_graph):
        static_q, static_k, static_v, _, _ = layer.forward_pre_attention(
            static_positions, static_hidden, static_residual
        )

    # Post-attention graph
    self.post_attn_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(self.post_attn_graph):
        _, _ = layer.forward_post_attention(
            static_attn_output, static_hidden, static_residual
        )
```

**Status**: `pending`

---

### Phase 3: Capture the attention graphs

**Goal**: capture 2 attention graphs (causal + non-causal)

**File to modify**: `nanovllm/kvcache/offload_engine.py`

```python
class OffloadEngine:
    def capture_attention_graphs(self):
        """Capture attention graphs (reused across all layers)."""
        self.attn_graph_causal = self._capture_attn_graph(causal=True)
        self.attn_graph_non_causal = self._capture_attn_graph(causal=False)

    def _capture_attn_graph(self, causal: bool):
        static_q = torch.zeros(1, 1, num_heads, head_dim, ...)
        static_k = torch.zeros(1, block_size, num_kv_heads, head_dim, ...)
        static_v = torch.zeros(1, block_size, num_kv_heads, head_dim, ...)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            output, lse = flash_attn_with_lse(static_q, static_k, static_v,
                                              self.scale, causal)
        return AttentionGraph(graph, static_q, static_k, static_v, output, lse)
```

**Status**: `pending`

---

### Phase 4: Modify the offload decode execution flow

**Goal**: run offload decode via graph replay

**File to modify**: `nanovllm/engine/model_runner.py`

**Method to modify**: `run_chunked_offload_decode()`

```python
def run_chunked_offload_decode_with_graph(self, seqs):
    """Graph-accelerated offload decode."""
    seq = seqs[0]

    # Prepare inputs
    input_ids = torch.tensor([seq.last_token], ...)
    positions = torch.tensor([len(seq) - 1], ...)

    # Embedding
    hidden_states = self.model.model.embed_tokens(input_ids)
    residual = None

    for layer_id, layer in enumerate(self.model.model.layers):
        # Phase 1: Pre-attention (GRAPH)
        self.pre_attn_vars["hidden"].copy_(hidden_states)
        if residual is not None:
            self.pre_attn_vars["residual"].copy_(residual)
        self.pre_attn_vars["positions"].copy_(positions)
        self.pre_attn_graph.replay()
        q = self.pre_attn_vars["q"].clone()
        k = self.pre_attn_vars["k"].clone()
        v = self.pre_attn_vars["v"].clone()

        # Phase 2: Chunked attention (eager + graph)
        attn_output = self._chunked_attention_with_graph(q, k, v, layer_id, ...)

        # Phase 3: Post-attention (GRAPH)
        self.post_attn_vars["attn_output"].copy_(attn_output)
        self.post_attn_graph.replay()
        hidden_states = self.post_attn_vars["hidden"].clone()
        residual = self.post_attn_vars["residual"].clone()

    # LM head
    logits = self.model.compute_logits(hidden_states)
    return logits
```

**Status**: `pending`

---

### Phase 5: Modify the ring buffer pipeline

**Goal**: use the graph inside attention

**File to modify**: `nanovllm/kvcache/sparse/full_policy.py`

**Change**: the `flash_attn_with_lse` call in `_decode_ring_buffer_pipeline()`

```python
# Current: eager
prev_o, prev_lse = flash_attn_with_lse(q, k, v, scale, causal=False)

# Change to: graph replay
graph = offload_engine.attn_graph_non_causal
graph.static_q.copy_(q)
graph.static_k.copy_(k)
graph.static_v.copy_(v)
graph.graph.replay()
prev_o = graph.static_output.clone()
prev_lse = graph.static_lse.clone()
```

**Status**: `pending`

---

### Phase 6: Add a config switch

**File to modify**: `nanovllm/config.py`

```python
enable_offload_graph: bool = True  # enabled by default
```

**Status**: `pending`

---

## File Change List

| File | Change type | Notes |
|------|----------|------|
| `nanovllm/engine/model_runner.py` | New method | `capture_offload_attention_graph()` |
| `nanovllm/kvcache/offload_engine.py` | New attributes + methods | Graph storage and access |
| `nanovllm/kvcache/sparse/full_policy.py` | Modified method | Use graph replay |
| `nanovllm/config.py` | New config | `enable_offload_graph` |

---

## Risks and Caveats

1. **Graph capture timing**: capture must happen after KV cache allocation and before the first decode
2. **Chunk size match**: the graph's chunk_size must equal block_size
3. **Multi-GPU**: graphs must be captured separately on each GPU
4. **Memory**: the extra memory overhead of 2 graphs is tiny

---

## Test Plan

1. **Unit test**: verify that graph replay produces correct results
2. **Integration test**: run `test_needle.py --enable-offload --input-len 32768`
3. **Performance test**: compare decode latency, eager vs. graph

---

## Expected Benefits

- Faster attention in the decode phase (reduced kernel launch overhead)
- Compatible with the existing ring buffer pipeline
- Minimal memory overhead (only 2 extra graphs)
tests/test_xattn_bsa.py (334 lines, new file)
@@ -0,0 +1,334 @@
"""
Test XAttention + BSA with RULER benchmark data.

Tests XAttention sparse attention correctness using the RULER NIAH task.

Attention methods:
- Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
- Decode: FlashAttention (always, since q_len=1)

Usage (in the compass conda env with BSA available):
    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
        python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct

    # Test with XAttention + BSA for prefill (default)
    python tests/test_xattn_bsa.py --prefill-method xattn

    # Test with FlashAttention for prefill (baseline)
    python tests/test_xattn_bsa.py --prefill-method flash

    # Test specific sample(s)
    python tests/test_xattn_bsa.py --sample-id 0
    python tests/test_xattn_bsa.py --sample-ids 0,1,2

Note: Compatible with transformers 4.53+ (handles both the old `past_key_value`
and the new `past_key_values` API).
"""

import argparse
import json
import sys
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

from nanovllm.ops.xattn import xattn_estimate
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


# ============================================================
# XAttention + BSA Functions
# ============================================================

def expand_kv_for_gqa(key_states, value_states, num_heads):
    """Expand KV heads for Grouped Query Attention."""
    num_kv_heads = key_states.shape[1]
    if num_heads == num_kv_heads:
        return key_states, value_states
    num_groups = num_heads // num_kv_heads
    return key_states.repeat_interleave(num_groups, dim=1), value_states.repeat_interleave(num_groups, dim=1)

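`repeat_interleave` on the head axis duplicates each KV head for its group of query heads. A tiny numpy illustration of the same expansion (np.repeat on an axis has the same semantics; shapes are made up for the example):

```python
import numpy as np

# Toy GQA layout: 2 KV heads serving 4 query heads, layout (B, H_kv, T, D)
num_heads, num_kv_heads = 4, 2
k = np.arange(12).reshape(1, num_kv_heads, 3, 2)

groups = num_heads // num_kv_heads           # query heads per KV head
k_expanded = np.repeat(k, groups, axis=1)    # (B, H, T, D)

assert k_expanded.shape == (1, 4, 3, 2)
# Query heads 0 and 1 both read KV head 0; heads 2 and 3 read KV head 1
assert (k_expanded[0, 0] == k_expanded[0, 1]).all()
assert (k_expanded[0, 2] == k[0, 1]).all()
```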
|
||||
|
||||
def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
|
||||
"""Standard FlashAttention."""
|
||||
from flash_attn import flash_attn_func
|
||||
q = query_states.transpose(1, 2)
|
||||
k = key_states.transpose(1, 2)
|
||||
v = value_states.transpose(1, 2)
|
||||
return flash_attn_func(q, k, v, causal=is_causal).transpose(1, 2)
|
||||
|
||||
|
||||
def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
    """XAttention + BSA sparse attention."""
    from block_sparse_attn import block_sparse_attn_func

    batch_size, num_heads, q_len, head_dim = query_states.shape
    k_len = key_states.shape[2]

    # Estimate the block-level sparsity mask with XAttention
    _, mask = xattn_estimate(
        query_states, key_states,
        chunk_size=16384, block_size=128, threshold=threshold,
        use_triton=True, causal=True,
    )

    q_block_num = (q_len + 127) // 128
    k_block_num = (k_len + 127) // 128

    # BSA expects packed [total_tokens, heads, head_dim] layout
    q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
    k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
    v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)

    output = block_sparse_attn_func(
        q, k, v,
        torch.tensor([0, q_len], dtype=torch.int32, device=q.device),
        torch.tensor([0, k_len], dtype=torch.int32, device=k.device),
        torch.ones(num_heads, dtype=torch.int32, device=q.device),
        None,
        mask[:, :, :q_block_num, :k_block_num].contiguous(),
        q_len, k_len,
        p_dropout=0.0, deterministic=True, is_causal=True,
    )
    return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)

DEBUG = False  # Set to True to enable debug output


def create_patched_forward(prefill_method="xattn", threshold=0.9):
    """Create a patched attention forward with a configurable prefill method.

    Args:
        prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense)
        threshold: XAttention threshold for block selection (only used when prefill_method="xattn")

    Note:
        - Prefill (q_len > 1): uses the specified prefill_method
        - Decode (q_len == 1): always uses FlashAttention (sparsity is unnecessary for a single query)
    """
    call_count = [0]  # Mutable so the count is shared across calls

    def patched_forward(
        self,
        hidden_states,
        position_embeddings=None,
        attention_mask=None,
        past_key_value=None,   # Old API (transformers < 4.57)
        past_key_values=None,  # New API (transformers >= 4.57)
        cache_position=None,
        **kwargs
    ):
        # Handle both old and new transformers APIs
        kv_cache = past_key_values if past_key_values is not None else past_key_value

        bsz, q_len, _ = hidden_states.size()
        num_heads = self.config.num_attention_heads
        num_kv_heads = self.config.num_key_value_heads
        head_dim = self.head_dim

        # Compute Q, K, V projections
        query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)

        # Apply rotary position embedding
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Update the KV cache
        if kv_cache is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = kv_cache.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Expand KV for GQA
        key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)

        # Debug output
        if DEBUG and self.layer_idx == 0:
            call_count[0] += 1
            if call_count[0] <= 5:
                phase = "prefill" if q_len > 1 else "decode"
                print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
                print(f"  kv_cache is None: {kv_cache is None}")

        # Choose the attention method:
        # - Prefill (q_len > 1): use prefill_method (xattn or flash)
        # - Decode (q_len == 1): always use FlashAttention
        is_prefill = q_len > 1

        if is_prefill and prefill_method == "xattn":
            # Prefill with XAttention + BSA (sparse)
            attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
        else:
            # Prefill with FlashAttention (dense) OR decode (always FlashAttention).
            # For decode (q_len == 1), causal=False since the single query attends to all KV.
            attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)

        attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
        return attn_output, None

    return patched_forward

# ============================================================
# Data & Evaluation
# ============================================================

def load_samples(filepath, indices=None):
    """Load samples from a JSONL file, optionally keeping only the given indices."""
    samples = []
    with open(filepath) as f:
        for i, line in enumerate(f):
            if indices is None or i in indices:
                sample = json.loads(line)
                sample["_idx"] = i
                samples.append(sample)
    return samples


def string_match_all(output_text, expected_list):
    """RULER metric: fraction of expected values found in the output."""
    output_lower = output_text.lower().replace('\n', ' ')
    if not expected_list:
        return 1.0
    return sum(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list) / len(expected_list)

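A quick worked example of the RULER string-match metric (the function is re-stated standalone here; the sample strings are made up for illustration):

```python
# Standalone restatement of the metric: fraction of expected substrings
# found (case-insensitively) in the model output.
def string_match_all(output_text, expected_list):
    output_lower = output_text.lower().replace('\n', ' ')
    if not expected_list:
        return 1.0
    return sum(1.0 if exp.strip().lower() in output_lower else 0.0
               for exp in expected_list) / len(expected_list)

score = string_match_all("The magic numbers are 42 and 7.", ["42", "7", "99"])
print(round(score, 3))  # 2 of 3 expected values found -> 0.667
```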
# ============================================================
# Test
# ============================================================

def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
    """Test attention methods using RULER data.

    Args:
        prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
    """
    prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"

    print("=" * 60)
    print("RULER NIAH Attention Test")
    print("=" * 60)
    print(f"Data: {data_file}")
    print(f"Samples: {sample_ids}")
    print(f"Prefill method: {prefill_desc}")
    print("Decode method: FlashAttention (always)")
    if prefill_method == "xattn":
        print(f"XAttention threshold: {threshold}")

    samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
    if not samples:
        print("No samples found!")
        return False
    print(f"Loaded {len(samples)} samples")

    # Load model
    print(f"\nLoading model: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="cuda",
        attn_implementation="eager",  # Will be patched below
    )
    model.eval()

    # Patch all layers
    print("Patching attention layers...")
    print(f"  - Prefill: {prefill_desc}")
    print("  - Decode: FlashAttention")
    for idx, layer in enumerate(model.model.layers):
        layer.self_attn.layer_idx = idx  # Ensure layer_idx is set
        layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
            layer.self_attn, type(layer.self_attn)
        )

    total_score = 0.0
    results = []

    for sample in samples:
        idx = sample["_idx"]
        prompt = sample["input"]
        expected = sample["outputs"]

        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        num_tokens = inputs["input_ids"].shape[1]
        print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
        print(f"Expected: {expected}")

        with torch.no_grad():
            output = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
        score = string_match_all(output_text, expected)
        total_score += score

        status = "✓ PASS" if score >= 0.5 else "✗ FAIL"
        print(f"Output: '{output_text[:100]}...'")
        print(f"Result: {status} (score={score:.2f})")
        results.append({"idx": idx, "score": score, "passed": score >= 0.5})

    avg_score = total_score / len(samples)
    passed = sum(1 for r in results if r["passed"])

    print(f"\n{'='*60}")
    print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
    print(f"{'='*60}")

    return avg_score >= 0.5

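The `__get__` trick used when patching each layer's forward can be sketched minimally (class and function names here are illustrative, not from the transformers API):

```python
# Binding a plain function to an existing instance via the descriptor
# protocol, so it behaves like a bound method on that one object.
class Attn:
    def forward(self):
        return "original"

def patched(self):
    return "patched"

layer = Attn()
layer.forward = patched.__get__(layer, Attn)  # bind to this instance only
print(layer.forward())  # patched
```

Other `Attn` instances keep the original `forward`; only the patched instance is affected.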
def main():
    parser = argparse.ArgumentParser(
        description="Test XAttention + BSA vs FlashAttention for prefill using the RULER NIAH benchmark"
    )
    parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
    parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
    parser.add_argument("--sample-id", type=int, default=None, help="Test a single sample by index")
    parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
    parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
                        help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
    parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
    parser.add_argument("--max-new-tokens", type=int, default=50)
    # Kept for backwards compatibility
    parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
    args = parser.parse_args()

    model_path = str(Path(args.model).expanduser())

    # Handle the deprecated --no-xattn option
    prefill_method = args.prefill_method
    if args.no_xattn:
        prefill_method = "flash"
        print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")

    if args.sample_id is not None:
        sample_ids = [args.sample_id]
    elif args.sample_ids:
        sample_ids = [int(x) for x in args.sample_ids.split(",")]
    else:
        sample_ids = [0]

    # Check BSA availability when using xattn
    if prefill_method == "xattn":
        try:
            from block_sparse_attn import block_sparse_attn_func
            print("✓ BSA (Block Sparse Attention) available")
        except ImportError:
            print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
            sys.exit(1)

    if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
        print("\ntest_xattn_bsa: PASSED")
    else:
        print("\ntest_xattn_bsa: FAILED")
        sys.exit(1)


if __name__ == "__main__":
    main()
tests/test_xattn_chunked.py (new file, 259 lines)
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Test: Compare xattn_estimate vs xattn_estimate_chunked
|
||||
Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation.
|
||||
|
||||
Uses real QKV data captured from model inference.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
|
||||
|
||||
# ============================================================
|
||||
# Configuration
|
||||
# ============================================================
|
||||
|
||||
BLOCK_SIZE = 64
|
||||
STRIDE = 4
|
||||
THRESHOLD = 0.9
|
||||
CHUNK_SIZE = 4096
|
||||
|
||||
# Default QKV data directory (relative to project root)
|
||||
DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache")
|
||||
|
||||
# ============================================================
|
||||
# Utility Functions
|
||||
# ============================================================
|
||||
|
||||
def load_qkv(path):
    """Load saved QKV data."""
    data = torch.load(path, map_location="cpu", weights_only=False)
    print(f"Loaded: {path}")
    print(f"  Query shape: {data['query'].shape}")
    print(f"  Key shape: {data['key'].shape}")
    print(f"  Layer: {data['layer_id']}, Density: {data['density']:.2%}")
    return data


def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Compare two masks and report differences."""
    if mask1.shape != mask2.shape:
        print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False

    diff = (mask1 != mask2).sum().item()
    total = mask1.numel()
    match_rate = (total - diff) / total * 100

    print(f"  Match rate: {match_rate:.4f}% ({total - diff}/{total})")

    if diff > 0:
        diff_indices = torch.where(mask1 != mask2)
        print(f"  First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")

    return diff == 0

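The match-rate computation reduces to counting mismatched positions over a flattened mask; a pure-Python sketch with made-up boolean masks:

```python
# Element-wise mask comparison: percentage of positions where both
# masks agree, mirroring compare_masks' match-rate calculation.
def match_rate(m1, m2):
    diff = sum(a != b for a, b in zip(m1, m2))
    return (len(m1) - diff) / len(m1) * 100

print(match_rate([True, True, False, True], [True, False, False, True]))  # 75.0
```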
def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size):
    """
    Run xattn_estimate_chunked with EXTERNAL chunking.
    This simulates how chunked prefill should be used in practice.
    """
    batch_size, num_heads, q_len, head_dim = query.shape
    _, _, k_len, _ = key.shape

    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size

    # If Q fits in one chunk, call directly
    if q_len <= chunk_size:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return xattn_estimate_chunked(
                query, key,
                q_start_pos=q_start_pos,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )

    # External chunking: split Q and call once per chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f"  External chunking: {num_q_chunks} chunks")

    combined_attn_sum = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=query.dtype, device=query.device
    )
    combined_mask = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=torch.bool, device=query.device
    )

    q_block_offset = 0
    for q_chunk_idx in range(num_q_chunks):
        q_chunk_start = q_chunk_idx * chunk_size
        q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)

        q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]

        # For causal attention, K accumulates up to the current Q position
        k_end = q_start_pos + q_chunk_end
        k_chunk = key[:, :, :k_end, :]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
                q_chunk, k_chunk,
                q_start_pos=q_start_pos + q_chunk_start,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )

        # Place the chunk results into the combined output
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
        combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
        q_block_offset += chunk_q_blocks

    return combined_attn_sum, combined_mask

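The external-chunking schedule can be sketched without tensors (lengths are illustrative): each Q chunk covers `[start, end)` of the query, and for causal attention its K window extends to `q_start_pos + end`.

```python
# Sketch of the chunk schedule used by external chunking: a list of
# (q_start, q_end, k_end) triples, one per Q chunk.
def chunk_schedule(q_len, chunk_size, q_start_pos=0):
    sched = []
    for start in range(0, q_len, chunk_size):
        end = min(start + chunk_size, q_len)
        sched.append((start, end, q_start_pos + end))
    return sched

print(chunk_schedule(10000, 4096))
# [(0, 4096, 4096), (4096, 8192, 8192), (8192, 10000, 10000)]
```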
def test_single_qkv(qkv_path):
    """Test a single QKV file."""
    data = load_qkv(qkv_path)
    query = data["query"].cuda().to(torch.bfloat16)
    key = data["key"].cuda().to(torch.bfloat16)

    seq_len = query.shape[2]
    print(f"\nTesting with seq_len={seq_len}")
    print("=" * 60)

    # Run standard xattn_estimate
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
        )
        print(f"  mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}")
    except Exception as e:
        print(f"  ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False

    # Run chunked xattn_estimate with EXTERNAL chunking
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query, key,
            q_start_pos=0,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        print(f"  mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}")
    except Exception as e:
        print(f"  ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False

    # Compare results
    print("[3] Comparing results...")
    chunked_q_blocks = mask_chunked.shape[2]
    chunked_k_blocks = mask_chunked.shape[3]

    # Extract the comparable region from the standard mask
    mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]

    # Compare masks
    masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")

    # Compare attn_sums
    attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f"  Attn sum max diff: {attn_diff:.6f}")
    else:
        print("  Attn sum shape mismatch")

    # Clean up GPU memory
    del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
    torch.cuda.empty_cache()

    return masks_match

# ============================================================
# Main Test
# ============================================================

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked")
    parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR,
                        help="Directory containing QKV files")
    args = parser.parse_args()

    # QKV files to test
    qkv_files = [
        os.path.join(args.qkv_dir, "qkv_3688.pt"),   # ~4K
        os.path.join(args.qkv_dir, "qkv_7888.pt"),   # ~8K
        os.path.join(args.qkv_dir, "qkv_15685.pt"),  # ~16K
        os.path.join(args.qkv_dir, "qkv_32485.pt"),  # ~32K
        os.path.join(args.qkv_dir, "qkv_64891.pt"),  # ~64K
    ]

    available_files = [p for p in qkv_files if os.path.exists(p)]

    if not available_files:
        print(f"No QKV files found in {args.qkv_dir}.")
        print("Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt")
        sys.exit(1)

    print(f"Found {len(available_files)} QKV files to test")
    print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})")
    print("Using Triton kernels")

    all_passed = True
    results = []

    for qkv_path in available_files:
        passed = test_single_qkv(qkv_path)
        seq_len = int(os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", ""))
        results.append((seq_len, passed))
        if not passed:
            all_passed = False

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for seq_len, passed in results:
        status = "PASSED" if passed else "FAILED"
        chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
        print(f"  seq_len={seq_len} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}")

    print("=" * 60)
    if all_passed:
        print("test_xattn_chunked: PASSED")
        sys.exit(0)
    else:
        print("test_xattn_chunked: FAILED")
        sys.exit(1)
tests/test_xattn_kernels.py (new file, 129 lines)
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Test: XAttention Triton kernels
|
||||
|
||||
演示 XAttention 的两个核心 Triton kernel:
|
||||
1. flat_group_gemm_fuse_reshape: 计算 stride reshape 后的 attention scores (反对角线求和)
|
||||
2. softmax_fuse_block_sum: 对 attention scores 做 softmax 后按 block 求和
|
||||
|
||||
数据流:
|
||||
Q [batch, heads, q_len, head_dim]
|
||||
K [batch, heads, kv_len, head_dim]
|
||||
↓ flat_group_gemm_fuse_reshape
|
||||
attn_scores [batch, heads, q_len/stride, kv_len/stride]
|
||||
↓ softmax_fuse_block_sum
|
||||
block_sums [batch, heads, q_blocks, k_blocks]
|
||||
"""
|
||||
import torch
|
||||
import sys
|
||||
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
|
||||
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
|
||||
|
||||
# ============================================================
# Parameters
# ============================================================

# Triton constraints: q_len >= stride * BLOCK_M, kv_len >= stride * BLOCK_N
# A100: BLOCK_M = BLOCK_N = 128, so the minimum is 4 * 128 = 512
# RTX 3090: BLOCK_M = BLOCK_N = 64, so the minimum is 4 * 64 = 256
q_len = 512
kv_len = 2048
head_dim = 128
stride = 4
block_size = 128    # softmax block size (in the reshaped space)
segment_size = 128  # the Triton kernel requires segment_size >= block_size

# ============================================================
# Construct inputs: even positions = 1, odd positions = 2
# ============================================================

Q = torch.zeros(1, 1, q_len, head_dim, dtype=torch.bfloat16).cuda()
K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()

for i in range(q_len):
    Q[0, 0, i, :] = 1 if i % 2 == 0 else 2

for i in range(kv_len):
    K[0, 0, i, :] = 1 if i % 2 == 0 else 2

# ============================================================
# Step 1: flat_group_gemm_fuse_reshape (chunked along K)
# ============================================================

q_reshaped_len = q_len // stride    # 128
kv_reshaped_len = kv_len // stride  # 512

# Split K into chunks along the length dimension
k_chunk_size = 512                     # 512 tokens per chunk
num_k_chunks = kv_len // k_chunk_size  # 4 chunks

attn_scores_list = []
for k_chunk_idx in range(num_k_chunks):
    k_start = k_chunk_idx * k_chunk_size
    k_end = k_start + k_chunk_size
    K_chunk = K[:, :, k_start:k_end, :]  # [1, 1, k_chunk_size, head_dim]

    # Call flat_group_gemm_fuse_reshape on each K chunk
    # Output: [batch, heads, q_len/stride, k_chunk_size/stride]
    attn_chunk = flat_group_gemm_fuse_reshape(
        Q, K_chunk, stride,
        chunk_start=0,
        chunk_end=q_reshaped_len,
        is_causal=False
    )
    attn_scores_list.append(attn_chunk)

# Concatenate the results from all K chunks
# Each chunk: [1, 1, q_reshaped_len, k_chunk_size/stride]
# After concatenation: [1, 1, q_reshaped_len, kv_reshaped_len]
attn_scores = torch.cat(attn_scores_list, dim=-1)

# Verify the shape: [batch, heads, q_len/stride, kv_len/stride]
assert attn_scores.shape == (1, 1, q_reshaped_len, kv_reshaped_len), \
    f"shape mismatch: {attn_scores.shape} != (1, 1, {q_reshaped_len}, {kv_reshaped_len})"

# Verify the anti-diagonal sums:
# each stride x stride tile's anti-diagonal pairs give Q[odd]*K[even] + Q[even]*K[odd] = 2*1 + 1*2 = 4
# the anti-diagonal has stride/2 such pairs; multiply by head_dim
expected_gemm = (2*1 + 1*2) * (stride // 2) * head_dim
actual_gemm = attn_scores[0, 0, 0, 0].item()
assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expected_gemm}"

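The `expected_gemm` arithmetic above can be checked directly in plain Python by summing one stride x stride tile's anti-diagonal, with Q and K rows alternating 1, 2 as in the test inputs:

```python
# Anti-diagonal sum of one stride x stride tile: pair q_vals[i] with
# k_vals[stride - 1 - i], then scale by head_dim (all channels equal).
stride, head_dim = 4, 128
q_vals = [1 if i % 2 == 0 else 2 for i in range(stride)]
k_vals = [1 if i % 2 == 0 else 2 for i in range(stride)]
anti_diag = sum(q_vals[i] * k_vals[stride - 1 - i] for i in range(stride))
print(anti_diag * head_dim)  # 1024, matching expected_gemm in the test
```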
# ============================================================
# Step 2: softmax_fuse_block_sum
# ============================================================

scale = 1.4426950408889634  # log2(e), for the exp2-based softmax

block_sums = softmax_fuse_block_sum(
    attn_scores,
    block_size,
    segment_size,
    chunk_start=0,
    chunk_end=q_reshaped_len,
    real_q_len=q_reshaped_len,
    scale=scale,
    is_causal=False
)

# Verify the shape: [batch, heads, q_blocks, k_blocks]
q_blocks = q_reshaped_len // block_size   # 128 / 128 = 1
k_blocks = kv_reshaped_len // block_size  # 512 / 128 = 4
assert block_sums.shape == (1, 1, q_blocks, k_blocks), \
    f"shape mismatch: {block_sums.shape} != (1, 1, {q_blocks}, {k_blocks})"

# Verify the per-block softmax sums:
# all attn_scores are equal -> the softmax is uniform
# each row's contribution to one K block = block_size / kv_reshaped_len
# each Q block has block_size rows
# block_sum = block_size * (block_size / kv_reshaped_len)
expected_sum = block_size * block_size / kv_reshaped_len
actual_sum = block_sums[0, 0, 0, 0].item()
assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"

print("test_xattn_kernels: PASSED")