Compare commits

2 Commits

Author SHA1 Message Date
Zijie Tian
f3e4611e3b 📝 docs: add XAttention performance analysis documentation
Add comprehensive performance analysis for XAttention:
- NVTX marker locations and usage
- Block size impact on offload mode (4096 vs 1024)
- Detailed timing breakdown for estimate vs compute phases
- softmax_fuse_block_sum_kernel analysis
- Optimization recommendations

Key findings:
- block_size=4096 is 2x faster than 1024 for 64K context
- find_blocks_chunked is bottleneck (40%) at block_size=4096
- estimate_gemm becomes bottleneck (24%) at block_size=1024

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 00:57:20 +08:00
Zijie Tian
7b5d3b34eb 📈 feat: add NVTX markers to XAttention for profiling
Add NVTX range markers to track XAttention performance:
- GPU-only: xattn_estimate, xattn_bsa_compute
- Offload: xattn_estimate_gemm, xattn_estimate_softmax,
  xattn_estimate_find_blocks, xattn_compute_historical,
  xattn_compute_current, xattn_compute_merge

These markers enable detailed nsys profiling to identify
performance bottlenecks in estimate vs compute phases.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 00:57:11 +08:00
3 changed files with 313 additions and 133 deletions


@@ -30,6 +30,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
| [`docs/bench_offload_results.md`](docs/bench_offload_results.md) | 📊 BENCH: CPU offload benchmark results, Full vs XAttention comparison (32K/128K) |
| [`docs/cpu_offload_optimization_strategies.md`](docs/cpu_offload_optimization_strategies.md) | 🚀 OPT: CPU offload optimization strategies: chunk size, CUDA Graph, recent research (InfiniGen/ShadowKV) |
| [`docs/gpu_only_xattn_guide.md`](docs/gpu_only_xattn_guide.md) | 🚀 GPU-Only XAttention: memory preallocation, performance analysis (32K +15%, 64K +41%), CUDA Graph limitations |
| [`docs/xattn_performance_analysis.md`](docs/xattn_performance_analysis.md) | 📊 XAttention performance analysis: NVTX markers, block size impact, estimate vs compute timing comparison |
## Rules Index


@@ -0,0 +1,170 @@
# XAttention Performance Analysis
This document records performance analysis results for XAttention under different configurations, including NVTX marker locations, the impact of block size, and performance bottlenecks.
## NVTX Markers
NVTX markers were added to the XAttention code for nsys profiling, making it easier to analyze the performance of the estimate and compute phases.
### Marker Locations
| Mode | Marker name | File location | Description |
|------|---------|---------|------|
| GPU-only | `xattn_estimate` | `xattn_bsa.py:compute_prefill` | `xattn_estimate` call |
| GPU-only | `xattn_bsa_compute` | `xattn_bsa.py:compute_prefill` | BSA kernel call |
| Offload | `xattn_estimate_gemm` | `xattn_bsa.py:select_blocks` | `flat_group_gemm` loop |
| Offload | `xattn_estimate_softmax` | `xattn_bsa.py:select_blocks` | `softmax_fuse_block_sum` |
| Offload | `xattn_estimate_find_blocks` | `xattn_bsa.py:select_blocks` | `find_blocks_chunked` |
| Offload | `xattn_compute_historical` | `xattn_bsa.py:compute_chunked_prefill` | Attention over historical chunks |
| Offload | `xattn_compute_current` | `xattn_bsa.py:compute_chunked_prefill` | Attention over the current chunk |
| Offload | `xattn_compute_merge` | `xattn_bsa.py:compute_chunked_prefill` | Merge operation |
### Viewing NVTX Statistics
```bash
# Generate a profile
bash scripts/profile_offload.sh --policy xattn --ctx-len 64k --block-size 4096 --gpu 0
# View NVTX statistics
nsys stats --report nvtx_pushpop_sum results/nsys/<filename>.nsys-rep
```
## Impact of Block Size on Offload Mode
### Test Configuration
- Model: Llama-3.1-8B-Instruct
- Context: 64K tokens
- Mode: xattn + offload
- GPU: A100 40GB
### Performance Comparison
| Metric | block_size=4096 | block_size=1024 | Change |
|------|----------------|-----------------|------|
| **Total time** | 27.7s | 55.5s | **2x slower** |
| **Number of chunks** | 16 | 64 | 4x |
| **CPU blocks** | 18 | 71 | ~4x |
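The chunk counts in the table follow from dividing the context length by the block size. A minimal sketch of that arithmetic, assuming "64K" means 65,536 tokens and one chunk per block (the function and variable names are illustrative, not taken from the repository):

```python
# Assumption: "64K" means 65,536 tokens and each chunk holds exactly one block.
CONTEXT_TOKENS = 64 * 1024


def num_chunks(context_tokens: int, block_size: int) -> int:
    """Number of chunks needed to cover the context (ceiling division)."""
    return -(-context_tokens // block_size)


chunks_4096 = num_chunks(CONTEXT_TOKENS, 4096)
chunks_1024 = num_chunks(CONTEXT_TOKENS, 1024)

# Slowdown reported in the table: total time at 1024 divided by total time at 4096.
slowdown = 55.5 / 27.7

print(chunks_4096, chunks_1024, round(slowdown, 2))
```

A 4x increase in chunk count costs about 2x in wall time here, so the per-chunk cost does not scale linearly with block size.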
### Time Breakdown by Phase
#### block_size=4096
| Phase | Share | Total time | Average time | Calls |
|-----|------|--------|---------|---------|
| **xattn_estimate_find_blocks** | **39.7%** | 18.0s | 37.6ms | 480 |
| xattn_compute_historical | 4.4% | 2.0s | 4.2ms | 480 |
| xattn_estimate_gemm | 3.4% | 1.5s | 3.2ms | 480 |
| xattn_compute_current | 0.2% | 113ms | 0.22ms | 512 |
| xattn_compute_merge | 0.2% | 96ms | 0.19ms | 512 |
| xattn_estimate_softmax | 0.2% | 88ms | 0.18ms | 480 |
#### block_size=1024
| Phase | Share | Total time | Average time | Calls |
|-----|------|--------|---------|---------|
| **xattn_estimate_gemm** | **23.6%** | 22.6s | 11.4ms | 1984 |
| **xattn_compute_historical** | **16.9%** | 16.2s | 8.0ms | 2016 |
| xattn_estimate_find_blocks | 1.4% | 1.3s | 0.66ms | 1984 |
| xattn_compute_current | 0.5% | 433ms | 0.21ms | 2048 |
| xattn_compute_merge | 0.4% | 373ms | 0.18ms | 2048 |
| xattn_estimate_softmax | 0.2% | 222ms | 0.11ms | 1984 |
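Most of the call counts in the two tables are consistent with one call per layer per chunk, given that Llama-3.1-8B has 32 transformer layers. The sketch below checks this under an assumption that is not stated in the source: the first chunk has no historical blocks, so the historical phase runs for one chunk fewer. The helper name is hypothetical.

```python
NUM_LAYERS = 32  # Llama-3.1-8B has 32 transformer layers


def expected_calls(num_chunks: int) -> dict:
    """Expected NVTX range counts, assuming one call per layer per chunk."""
    return {
        # current-chunk attention and merge run for every chunk
        "current_and_merge": num_chunks * NUM_LAYERS,
        # assumption: the first chunk has no history to attend to
        "historical": (num_chunks - 1) * NUM_LAYERS,
    }


calls_4096 = expected_calls(16)
calls_1024 = expected_calls(64)
print(4096, calls_4096)
print(1024, calls_1024)
```

The estimate-phase markers at block_size=1024 report 1984 calls, which equals 62 × 32; the source does not explain why that is one chunk fewer than the historical-compute count.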
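### Key Findings
1. **Block size has a significant impact on performance**
   - block_size=1024 is about 2x slower than 4096
   - A smaller block size produces more chunks, which increases the number of calls
2. **The bottleneck shifts with block size**
   - **block_size=4096**: the bottleneck is `find_blocks_chunked` (39.7%)
   - **block_size=1024**: the bottleneck shifts to `estimate_gemm` (23.6%) and `compute_historical` (16.9%)
3. **Amortization effect**
   - With the larger block size, each `find_blocks` call is slower (37.6ms vs 0.66ms), and `find_blocks` also costs more in total (18.0s vs 1.3s)
   - But there are far fewer calls (480 vs 1984), so `estimate_gemm` (1.5s vs 22.6s) and `compute_historical` (2.0s vs 16.2s) shrink and the end-to-end time is lower (27.7s vs 55.5s)
4. **find_blocks_chunked is a special case**
   - This function runs its block-selection logic mainly on the CPU
   - Its overhead grows significantly as the amount of data it processes increases
   - At block_size=4096 it accounts for about 40% of the time, making it the main optimization target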
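To make the amortization trade-off concrete, the following sketch recomputes per-phase totals as average time × call count, using only the numbers quoted in the tables above, and sums the three dominant phases. The dictionary layout and function name are illustrative.

```python
# (average ms per call, number of calls), copied from the two breakdown tables
PHASES = {
    4096: {
        "find_blocks": (37.6, 480),
        "estimate_gemm": (3.2, 480),
        "compute_historical": (4.2, 480),
    },
    1024: {
        "find_blocks": (0.66, 1984),
        "estimate_gemm": (11.4, 1984),
        "compute_historical": (8.0, 2016),
    },
}


def phase_totals(block_size: int) -> dict:
    """Total seconds per phase = average ms per call * calls / 1000."""
    return {
        name: avg_ms * calls / 1000.0
        for name, (avg_ms, calls) in PHASES[block_size].items()
    }


totals_4096 = phase_totals(4096)
totals_1024 = phase_totals(1024)
for bs, totals in ((4096, totals_4096), (1024, totals_1024)):
    rounded = {name: round(sec, 1) for name, sec in totals.items()}
    print(bs, rounded, "sum:", round(sum(totals.values()), 1))
```

`find_blocks` alone is more expensive at block_size=4096, yet the sum of the three phases is lower, because the GEMM and historical-attention phases shrink by more than `find_blocks` grows.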
## softmax_fuse_block_sum_kernel Performance Analysis
`softmax_fuse_block_sum_kernel_non_causal` is the core Triton kernel of the XAttention estimation phase.
### Kernel Structure
```python
# Pseudocode: data shape processed by each thread block
# Workload: [block_size, segment_size]  (attention of a single Q block over all K)
# Pass 1: compute the global softmax parameters (m_i, l_i)
for iter in range(num_iters):  # num_iters = k_len / segment_size
    X = load [block_size, segment_size]
    compute max, sum for softmax normalization
# Pass 2: Normalize + Block Sum
for iter in range(num_iters):
    X = load [block_size, segment_size]
    X = softmax(X)
    X = reshape(X, [block_size, segment_size/block_size, block_size])
    X = sum(X, axis=2)  # → [block_size, segment_size/block_size]
    X = sum(X, axis=0)  # → [segment_size/block_size]
    store output
```
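The two-pass structure above can be sketched as a plain-Python reference. This is not the Triton kernel: it is a slow, illustrative version that assumes natural `exp` rather than the kernel's base-2 exponent, processes one Q block given as a list of rows, and uses a hypothetical function name.

```python
import math


def softmax_block_sum(scores, block_size, segment_size):
    """Row-wise softmax over all K, then sum probabilities per K block.

    scores: list of rows (one Q block), each row holding k_len scores.
    Returns one value per K block: k_len // block_size entries.
    """
    num_rows, k_len = len(scores), len(scores[0])
    assert k_len % segment_size == 0 and segment_size % block_size == 0
    num_iters = k_len // segment_size

    # Pass 1: running max m_i and running normalizer l_i per row.
    m = [-math.inf] * num_rows
    l = [0.0] * num_rows
    for it in range(num_iters):
        lo = it * segment_size
        for r in range(num_rows):
            seg = scores[r][lo:lo + segment_size]
            m_new = max(m[r], max(seg))
            # Rescale the old sum to the new max before adding this segment.
            l[r] = l[r] * math.exp(m[r] - m_new) + sum(math.exp(x - m_new) for x in seg)
            m[r] = m_new

    # Pass 2: normalize, then reduce over rows and over each K block.
    out = [0.0] * (k_len // block_size)
    for it in range(num_iters):
        lo = it * segment_size
        for r in range(num_rows):
            for j in range(lo, lo + segment_size):
                out[j // block_size] += math.exp(scores[r][j] - m[r]) / l[r]
    return out


# Uniform scores: every K position gets probability 1/8 in each of the 2 rows.
scores = [[0.0] * 8 for _ in range(2)]
block_sums = softmax_block_sum(scores, block_size=2, segment_size=4)
print(block_sums)
```

Because each softmax row sums to 1, the block sums always add up to the number of Q rows, which is a convenient sanity check for any faster implementation.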
### Factors That Change Performance with block_size
| Factor | Small block_size (64) | Large block_size (256) |
|------|-------------------|---------------------|
| Grid parallelism | High (more blocks) | Low (fewer blocks) |
| Register usage | Low | High (may spill) |
| L2 cache reuse | Poor | Good |
| Output size | Large | Small |
### Typical Performance Curve
```
Performance
    │         ┌─────┐
    │        /       \
    │       /         \
    │      /           \
    │     /             \
    └────/───────────────\────────→ block_size
        64    128    256    512
The optimum is usually between 128 and 256
```
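## Optimization Recommendations
1. **Prefer block_size=4096**
   - Fewer chunks, lower scheduling overhead
   - Better amortization
2. **Optimize find_blocks_chunked**
   - It is currently the main bottleneck at block_size=4096
   - Consider GPU acceleration or batched processing
3. **Pipeline optimization**
   - Use the multi-slot ring buffer to overlap compute and transfer
   - This is already implemented, but find_blocks is a CPU operation and cannot be overlapped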
## Test Commands
```bash
# GPU-only mode (requires 40GB+ VRAM)
bash scripts/profile_offload.sh --policy xattn --ctx-len 64k --no-offload --gpu 0
# Offload mode, block_size=4096
bash scripts/profile_offload.sh --policy xattn --ctx-len 64k --block-size 4096 --gpu 0
# Offload mode, block_size=1024
bash scripts/profile_offload.sh --policy xattn --ctx-len 64k --block-size 1024 --gpu 0
# 128K context
bash scripts/profile_offload.sh --policy xattn --ctx-len 128k --block-size 4096 --gpu 0
```


@@ -13,6 +13,7 @@ Note: Decode phase is not supported - use FullAttentionPolicy for decode.
import logging
import torch
import torch.cuda.nvtx as nvtx
from typing import List, Tuple, TYPE_CHECKING
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
@@ -304,6 +305,7 @@ class XAttentionBSAPolicy(SparsePolicy):
K_exp, V_exp = K, V
# Estimate block importance and get sparse mask
with nvtx.range("xattn_estimate"):
    _, mask = xattn_estimate(
        Q, K_exp,
        chunk_size=self.chunk_size,
@@ -339,6 +341,7 @@ class XAttentionBSAPolicy(SparsePolicy):
mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()
# Compute sparse attention using BSA
with nvtx.range("xattn_bsa_compute"):
    output = block_sparse_attn_func(
        q_bsa, k_bsa, v_bsa,
        cu_seqlens_q_bsa,
@@ -453,6 +456,7 @@ class XAttentionBSAPolicy(SparsePolicy):
block_size = ctx.block_size  # tokens per CPU block (e.g., 1024)
reshaped_block_size = block_size // self.stride  # e.g., 1024/8 = 128
with nvtx.range("xattn_estimate_gemm"):
    for cpu_block_id in available_blocks:
        # Load K block from CPU to GPU (cpu_block_id is chunk index)
        offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
@@ -510,6 +514,7 @@ class XAttentionBSAPolicy(SparsePolicy):
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm  # log2(e) with scaling
segment_size = min(4096, reshaped_block_size)
with nvtx.range("xattn_estimate_softmax"):
    block_sums = softmax_fuse_block_sum(
        attn_scores,
        reshaped_block_size,  # Use CPU block size in reshaped space (1024/8=128)
@@ -525,6 +530,7 @@ class XAttentionBSAPolicy(SparsePolicy):
# Step 3: Use find_blocks_chunked to get selection mask
# current_index = 0 since we're looking at historical blocks only
with nvtx.range("xattn_estimate_find_blocks"):
    mask = find_blocks_chunked(
        block_sums,
        current_index=0,
@@ -639,6 +645,7 @@ class XAttentionBSAPolicy(SparsePolicy):
cpu_block_table = selected_blocks
if cpu_block_table:
    with nvtx.range("xattn_compute_historical"):
        load_slots = list(range(offload_engine.num_ring_slots))
        num_blocks = len(cpu_block_table)
@@ -697,6 +704,7 @@ class XAttentionBSAPolicy(SparsePolicy):
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
# Compute attention to current chunk (causal mask)
with nvtx.range("xattn_compute_current"):
    with torch.cuda.stream(compute_stream):
        k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
        current_o, current_lse = flash_attn_with_lse(
@@ -706,6 +714,7 @@ class XAttentionBSAPolicy(SparsePolicy):
        )
# Merge historical and current attention
with nvtx.range("xattn_compute_merge"):
    with torch.cuda.stream(compute_stream):
        if o_acc is None:
            final_o = current_o
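
The offload estimate path in the diff derives its kernel parameters from the CPU block size. A small sketch of that arithmetic follows, assuming `stride=8` (as in the inline comments), `head_dim=128`, and `norm=1.0`; the last two values and the function name are assumptions for illustration, not taken from the diff.

```python
import math

LOG2_E = 1.4426950408889634  # log2(e), the same constant that appears in the diff


def estimate_params(block_size, stride=8, head_dim=128, norm=1.0):
    """Mirror the parameter arithmetic from the diff (default values are assumed)."""
    reshaped_block_size = block_size // stride        # block size after stride reshape
    segment_size = min(4096, reshaped_block_size)     # capped at 4096 as in the diff
    scale = LOG2_E / math.sqrt(head_dim) / stride / norm
    return reshaped_block_size, segment_size, scale


params_1024 = estimate_params(1024)
params_4096 = estimate_params(4096)
print(params_1024)
print(params_4096)
```

Folding log2(e) into the scale lets a kernel use a base-2 exponential, since exp(x) = 2^(x · log2(e)). Under these assumptions, both block sizes discussed in the analysis stay below the 4096 cap, so `segment_size` equals `reshaped_block_size` in each case.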