From e440c45e73061827089bfbbdd71403a724f037ac Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Tue, 20 Jan 2026 02:50:03 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20docs:=20add=20XAttention=20algor?=
 =?UTF-8?q?ithm=20guide=20based=20on=20COMPASS=20implementation?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Create docs/xattention_algorithm_guide.md with detailed algorithm explanation
  - Stride reshape (inverse mode) for Q/K interleaved sampling
  - Triton kernels: flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
  - Block selection via find_blocks_chunked with cumulative threshold
  - BSA (block_sparse_attn) dependency for sparse computation
- Update docs/sparse_attention_guide.md XAttention section with accurate description
- Add documentation index entry in CLAUDE.md

Co-Authored-By: Claude Opus 4.5
---
 CLAUDE.md                          |   1 +
 docs/sparse_attention_guide.md     |  70 +++---
 docs/xattention_algorithm_guide.md | 349 +++++++++++++++++++++++++++++
 3 files changed, 395 insertions(+), 25 deletions(-)
 create mode 100644 docs/xattention_algorithm_guide.md

diff --git a/CLAUDE.md b/CLAUDE.md
index 66eb8fb..b2c64a5 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -14,6 +14,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/sparse_policy_architecture.md`](docs/sparse_policy_architecture.md) | SparsePolicy abstraction: prefill/decode delegation, pipeline modes, policy implementations |
 | [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
 | [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
+| [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention algorithm deep dive: stride reshape, Triton kernels, BSA dependency, block selection |
 | [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
 | [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
 | [`docs/known_issues.md`](docs/known_issues.md) | Documented bugs and fixes: partial last block bug, block size 4096 race condition |
diff --git a/docs/sparse_attention_guide.md b/docs/sparse_attention_guide.md
index bda5b13..09e6e2d 100644
--- a/docs/sparse_attention_guide.md
+++ b/docs/sparse_attention_guide.md
@@ -50,30 +50,35 @@ output = block_sparse_attn_func(
 ## Method 1: XAttention (xattn_estimate)
 
-**Source**: `xattn/src/Xattention.py`
+**Source**: `compass/src/Xattention.py`
+
+**Detailed documentation**: [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md)
 
 ### Core Idea
 
-Use **strided Q/K reshaping** to create coarse-grained representations, compute block-level attention scores, and select blocks above a threshold.
+Use a **stride-interleaved reshape (inverse mode)** to cheaply estimate block-level attention importance, then use the **BSA (Block Sparse Attention)** library for the sparse computation.
 
 ### Algorithm
 
 ```python
-def xattn_estimate(query, key, block_size=64, stride=16):
+def xattn_estimate(query, key, block_size=128, stride=8):
     """
-    Estimate block importance using strided attention.
+    Estimate block importance using stride-interleaved attention.
 
-    1. Reshape Q: [batch, seq, heads, dim] -> [batch, num_blocks, stride, heads, dim]
-       Then take mean over stride dimension to get block-level Q
+    1. K reshape (forward interleave): concat([K[:,:,k::stride,:] for k in range(stride)])
+       Q reshape (reverse interleave): concat([Q[:,:,(stride-1-q)::stride,:] for q])
+       Result: sequence length seq_len -> seq_len/stride, head_dim -> head_dim*stride
 
-    2. Reshape K: Same process to get block-level K
+    2. Triton kernel (flat_group_gemm_fuse_reshape):
+       fuses reshape + GEMM, computing Q_reshaped @ K_reshaped^T
-    3. Compute block attention: softmax(block_Q @ block_K.T / sqrt(d))
-       Result shape: [batch, heads, q_blocks, k_blocks]
+    3. Triton kernel (softmax_fuse_block_sum):
+       online softmax + grouped sums over block_size/stride segments
+       Output: attn_sum [batch, heads, q_blocks, k_blocks]
 
-    4. Apply causal mask (upper triangle = 0)
-
-    5. Threshold: blocks with score > threshold are selected
+    4. find_blocks_chunked:
+       sort by attn_sum in descending order, mark blocks as True until the
+       cumulative sum reaches threshold; diagonal and sink blocks are always kept
     """
 ```

@@ -81,45 +86,60 @@ def xattn_estimate(query, key, block_size=64, stride=16):
 | Parameter | Default | Description |
 |-----------|---------|-------------|
-| `block_size` | 64 | Tokens per block |
-| `stride` | 16 | Stride for coarse Q/K computation |
-| `threshold` | 0.9 | Selection threshold (cumulative or direct) |
+| `block_size` | 128 | Tokens per block (BSA requires a fixed 128) |
+| `stride` | 8 | Q/K interleaved sampling stride; larger is faster but coarser |
+| `threshold` | 0.9 | Cumulative attention threshold; blocks are selected until their cumulative weight reaches this fraction |
+| `chunk_size` | 16384 | Chunk size used during estimation |
 
 ### Computation Flow
 
 ```
-query [B, S, H, D]
+query [B, H, S, D]
   |
   v
-Reshape to [B, num_blocks, stride, H, D]
+Stride interleaved reshape (Triton fused)
   |
   v
-Mean over stride -> block_q [B, num_blocks, H, D]
+flat_group_gemm_fuse_reshape: Q_r @ K_r^T
   |
   v
-Compute block attention scores [B, H, q_blocks, k_blocks]
+softmax_fuse_block_sum: online softmax + block sums
   |
   v
-Apply threshold -> block_mask [B, H, q_blocks, k_blocks]
+attn_sum [B, H, q_blocks, k_blocks]
   |
   v
-block_sparse_attn_func(q, k, v, block_mask)
+find_blocks_chunked: cumulative-threshold selection
   |
   v
-output [B, S, H, D]
+simple_mask [B, H, q_blocks, k_blocks] (bool)
+  |
+  v
+block_sparse_attn_func(q, k, v, simple_mask) ← BSA library
+  |
+  v
+output [B, H, S, D]
 ```
+
+### Dependencies
+
+```python
+from block_sparse_attn import block_sparse_attn_func  # MIT-HAN-LAB BSA library
+import triton  # Triton kernels for estimation
+```
 
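The estimation pipeline above can be sketched end-to-end in plain NumPy. This is purely illustrative: the real implementation runs fused Triton kernels on GPU tensors, applies a causal mask, and hands the mask to BSA; `estimate_block_mask` and its toy sizes are hypothetical names, not COMPASS code.

```python
# Hypothetical NumPy reference for the estimation flow (no Triton, no BSA,
# no causal mask) -- a sketch of the steps, not the COMPASS implementation.
import numpy as np

def estimate_block_mask(q, k, block_size=8, stride=2, threshold=0.9):
    """q, k: [seq_len, head_dim]; returns bool mask [q_blocks, k_blocks]."""
    s, d = q.shape
    # Stride-interleaved reshape: K forward, Q reversed ("inverse" mode).
    k_r = np.concatenate([k[i::stride] for i in range(stride)], axis=-1)
    q_r = np.concatenate([q[stride - 1 - i::stride] for i in range(stride)], axis=-1)
    scores = q_r @ k_r.T / np.sqrt(d * stride)        # [s/stride, s/stride]
    e = np.exp(scores - scores.max(-1, keepdims=True))
    probs = e / e.sum(-1, keepdims=True)              # row-wise softmax
    g = block_size // stride                          # coarse rows per block
    nb = s // block_size
    attn_sum = probs.reshape(nb, g, nb, g).sum(axis=(1, 3))  # [q_blocks, k_blocks]
    # Cumulative-threshold selection per query block.
    mask = np.zeros_like(attn_sum, dtype=bool)
    order = np.argsort(-attn_sum, axis=-1)
    for qi in range(nb):
        need, acc = threshold * attn_sum[qi].sum(), 0.0
        for ki in order[qi]:
            mask[qi, ki] = True
            acc += attn_sum[qi, ki]
            if acc >= need:
                break
    return mask

rng = np.random.default_rng(0)
q = rng.standard_normal((32, 16))
k = rng.standard_normal((32, 16))
mask = estimate_block_mask(q, k)
print(mask.shape)  # (4, 4)
```

With `seq_len=32`, `block_size=8`, `stride=2` this yields a 4×4 boolean block mask; every query block keeps at least one key block because accumulation stops only after the threshold is crossed.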
 ### Usage
 
 ```python
-from xattn.src.Xattention import Xattention_prefill
+from compass.src.Xattention import Xattention_prefill
 
 output = Xattention_prefill(
     query_states,
     key_states,
     value_states,
     threshold=0.9,
-    stride=16,
+    stride=8,
+    block_size=128,
+    use_triton=True,
 )
 ```
 
 ---
 
diff --git a/docs/xattention_algorithm_guide.md b/docs/xattention_algorithm_guide.md
new file mode 100644
index 0000000..5cca74c
--- /dev/null
+++ b/docs/xattention_algorithm_guide.md
@@ -0,0 +1,349 @@
+# XAttention Algorithm Implementation Guide
+
+This document describes the algorithm and implementation details of XAttention in the COMPASS project.
+
+## Overview
+
+XAttention is a block-sparse attention method based on a **stride reshape**: a low-cost estimate identifies the important blocks, and the **BSA (Block Sparse Attention)** library then performs the sparse computation.
+
+### Core Dependencies
+
+| Component | Source | Role |
+|------|------|------|
+| Triton kernels | COMPASS (in-house) | Q/K reshape + block-level estimation |
+| BSA | MIT-HAN-LAB `block_sparse_attn` | Sparse attention computation |
+
+---
+
+## Algorithm Flow
+
+```
+Input: Q [batch, heads, q_len, head_dim]
+       K [batch, heads, k_len, head_dim]
+       V [batch, heads, k_len, head_dim]
+
+┌──────────────────────────────────────────────────────────────┐
+│ Phase 1: Stride Reshape (inverse mode)                       │
+│                                                              │
+│ K_reshaped = concat([K[:,:,k::stride,:] for k in stride])    │
+│ Q_reshaped = concat([Q[:,:,(stride-1-q)::stride,:] for q])   │
+│                                                              │
+│ Effect: sequence length seq_len -> seq_len/stride            │
+│         head_dim -> head_dim * stride                        │
+└──────────────────────────────────────────────────────────────┘
+                              ↓
+┌──────────────────────────────────────────────────────────────┐
+│ Phase 2: Block-Level Attention Estimation (Triton)           │
+│                                                              │
+│ 2a. flat_group_gemm_fuse_reshape:                            │
+│     computes Q_reshaped @ K_reshaped^T                       │
+│     output: attn_weights                                     │
+│             [batch, heads, q_len/stride, k_len/stride]       │
+│                                                              │
+│ 2b. softmax_fuse_block_sum:                                  │
+│     - online softmax (numerically stable)                    │
+│     - grouped sums over block_size/stride segments           │
+│     output: attn_sum [batch, heads, q_blocks, k_blocks]      │
+└──────────────────────────────────────────────────────────────┘
+                              ↓
+┌──────────────────────────────────────────────────────────────┐
+│ Phase 3: Block Selection (find_blocks_chunked)               │
+│                                                              │
+│ For each Q block:                                            │
+│   1. Sort K blocks by attn_sum in descending order           │
+│   2. Accumulate until >= threshold * total_sum               │
+│   3. Mark the accumulated blocks as True                     │
+│                                                              │
+│ Special handling:                                            │
+│   - Diagonal blocks (causal) are always kept                 │
+│   - Sink block (block 0) is optionally kept                  │
+│                                                              │
+│ Output: simple_mask [batch, heads, q_blocks, k_blocks] bool  │
+└──────────────────────────────────────────────────────────────┘
+                              ↓
+┌──────────────────────────────────────────────────────────────┐
+│ Phase 4: Sparse Attention Computation (BSA)                  │
+│                                                              │
+│ attn_output = block_sparse_attn_func(                        │
+│     Q, K, V,                                                 │
+│     q_cu_seq_lens,     # [0, q_len]                          │
+│     k_cu_seq_lens,     # [0, k_len]                          │
+│     head_mask_type,    # [num_heads], all ones               │
+│     None,              # left_mask                           │
+│     simple_mask,       # block-sparse mask                   │
+│     q_len, k_len,                                            │
+│     is_causal=True,                                          │
+│ )                                                            │
+│                                                              │
+│ Output: attn_output [batch, heads, q_len, head_dim]          │
+└──────────────────────────────────────────────────────────────┘
+```
+
+---
+
+## Stride Reshape in Detail
+
+### Inverse Mode
+
+XAttention uses `select_mode="inverse"` by default, an interleaved sampling strategy:
+
+```python
+# Original: Q/K shape = [batch, heads, seq_len, head_dim]
+# stride = 8
+
+# K reshape: forward interleave
+K_reshaped = concat([K[:, :, 0::8, :],   # positions 0, 8, 16, ...
+                     K[:, :, 1::8, :],   # positions 1, 9, 17, ...
+                     K[:, :, 2::8, :],   # positions 2, 10, 18, ...
+                     ...
+                     K[:, :, 7::8, :]])  # positions 7, 15, 23, ...
+# Result: [batch, heads, seq_len/8, head_dim * 8]
+
+# Q reshape: reverse interleave (inverse)
+Q_reshaped = concat([Q[:, :, 7::8, :],   # positions 7, 15, 23, ...
+                     Q[:, :, 6::8, :],   # positions 6, 14, 22, ...
+                     Q[:, :, 5::8, :],   # positions 5, 13, 21, ...
+                     ...
+                     Q[:, :, 0::8, :]])  # positions 0, 8, 16, ...
+# Result: [batch, heads, seq_len/8, head_dim * 8]
+```
+
+### Why Inverse Mode?
+
+When `Q_reshaped @ K_reshaped^T` is computed, inverse mode makes:
+- the later positions within each Q group align with the earlier positions of each K group
+- which approximately captures the **diagonal pattern of causal attention**
+
+---
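A small NumPy check (illustrative, not the COMPASS code) makes the inverse-mode pairing concrete: each coarse score `coarse[i, j]` is the sum of token-level dot products whose positions within the two groups pair up in reverse order.

```python
# Verify what the inverse-mode reshape computes for one head:
# coarse[i, j] == sum_t  Q[i*stride + (stride-1-t)] . K[j*stride + t]
import numpy as np

stride, seq, dim = 4, 16, 8
rng = np.random.default_rng(1)
Q = rng.standard_normal((seq, dim))
K = rng.standard_normal((seq, dim))

# K: forward interleave; Q: reverse interleave (inverse mode).
K_r = np.concatenate([K[t::stride] for t in range(stride)], axis=-1)
Q_r = np.concatenate([Q[stride - 1 - t::stride] for t in range(stride)], axis=-1)
assert Q_r.shape == (seq // stride, dim * stride)

coarse = Q_r @ K_r.T  # [seq/stride, seq/stride]

# Check the pairing for one coarse entry.
i, j = 2, 1
expected = sum(Q[i * stride + (stride - 1 - t)] @ K[j * stride + t]
               for t in range(stride))
assert np.allclose(coarse[i, j], expected)
```

Note that the paired token positions always sum to `(i + j) * stride + stride - 1`, so each coarse entry aggregates an anti-diagonal band of the token-level score matrix, which is why the estimate tracks the causal diagonal well.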
+## Triton Kernels in Detail
+
+### 1. flat_group_gemm_fuse_reshape
+
+**File**: `compass/src/kernels.py:198-235`
+
+**Purpose**: fuses the stride reshape with the GEMM, avoiding materializing the large reshaped tensor
+
+```python
+@triton.jit
+def flat_group_gemm_fuse_reshape_kernel(Q, K, Out, ...):
+    # Key point: no actual reshape; pointer arithmetic simulates it
+    Q_ptrs = Q + block_m * BLOCK_M * STRIDE * stride_qn
+    K_ptrs = K + block_n * BLOCK_N * STRIDE * stride_kn
+
+    # Accumulate over the stride positions
+    for iter in range(STRIDE):
+        q = tl.load(Q_ptrs - iter * stride_qn)  # Q: inverse sampling
+        k = tl.load(K_ptrs + iter * stride_kn)  # K: forward sampling
+        o += tl.dot(q, k)
+```
+
+**Advantages**:
+- Memory savings: no intermediate `[batch, heads, seq_len/stride, head_dim*stride]` tensor is created
+- Kernel fusion: reshape + GEMM in a single pass
+
+### 2. softmax_fuse_block_sum
+
+**File**: `compass/src/kernels.py:6-95`
+
+**Purpose**: online softmax + within-block summation
+
+```python
+@triton.jit
+def softmax_fuse_block_sum_kernel_causal(In, Out, ...):
+    # Pass 1: compute the global max and sum (online algorithm)
+    for iter in range(num_iters):
+        X = tl.load(input_ptr + iter * segment_size) * scale
+        m_local = tl.max(X, 1)
+        m_new = tl.maximum(m_i, m_local)
+        alpha = tl.math.exp2(m_i - m_new)
+        X = X - m_new[:, None]
+        l_local = tl.sum(tl.math.exp2(X), 1)
+        l_i = l_i * alpha + l_local
+        m_i = m_new
+
+    # Pass 2: normalize and sum within blocks
+    for iter in range(num_iters):
+        X = tl.load(input_ptr + iter * segment_size) * scale
+        X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]  # softmax
+        X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
+        X = tl.sum(X, 2).sum(0)  # sum within each block
+        tl.store(output_ptr + iter * segment_size // block_size, X)
+```
+
+**Output semantics**: `attn_sum[b, h, qi, ki]` = the sum of **normalized attention weights** from Q block qi to K block ki
+
+---
+
+## Block Selection Algorithm (find_blocks_chunked)
+
+**File**: `compass/src/utils.py:44-191`
+
+### Algorithm Steps
+
+```python
+def find_blocks_chunked(input_tensor, current_index, threshold, ...):
+    """
+    input_tensor: [batch, heads, q_blocks, k_blocks] - per-block attention weight sums
+    threshold: 0.9 - cumulative threshold
+    """
+    # 1. Compute each row's total
+    total_sum = input_tensor.sum(dim=-1, keepdim=True)
+    required_sum = total_sum * threshold  # target cumulative sum
+
+    # 2. Always keep the special blocks
+    mask = zeros_like(input_tensor, dtype=bool)
+    mask[:, :, :, 0] = True         # sink block
+    mask[:, :, :, diagonal] = True  # diagonal block (causal)
+
+    # 3. Sort the remaining blocks by weight
+    other_values = input_tensor.masked_fill(mask, 0)
+    sorted_values, index = sort(other_values, descending=True)
+
+    # 4. Accumulate until the threshold is reached
+    cumsum = sorted_values.cumsum(dim=-1)
+    index_mask = (cumsum - sorted_values) < required_sum  # include the block that crosses the threshold
+
+    # 5. Mark the selected blocks
+    mask[..., index[index_mask]] = True
+
+    return mask
+```
+
+### Example
+
+```
+threshold = 0.9
+one row of attn_sum = [0.05, 0.30, 0.40, 0.15, 0.10]  (post-softmax; sums to 1.0)
+required_sum = 0.9
+
+sorted:     [0.40, 0.30, 0.15, 0.10, 0.05]
+cumulative: [0.40, 0.70, 0.85, 0.95, 1.00]
+                                ↑ reaches 0.9
+
+selected: the first 4 blocks (indices: 2, 1, 3, 4)
+```
+
+---
+
+## BSA (Block Sparse Attention)
+
+### Library Source
+
+```python
+from block_sparse_attn import block_sparse_attn_func
+```
+
+From MIT-HAN-LAB; provides an efficient block-mask-driven sparse FlashAttention implementation.
+
+### Interface
+
+```python
+attn_output = block_sparse_attn_func(
+    query_states,    # [total_q, num_heads, head_dim]
+    key_states,      # [total_k, num_heads, head_dim]
+    value_states,    # [total_k, num_heads, head_dim]
+    q_cu_seq_lens,   # [batch+1] cumulative sequence lengths
+    k_cu_seq_lens,   # [batch+1]
+    head_mask_type,  # [num_heads] int32, 1=causal, 0=full
+    left_mask,       # Optional left padding mask
+    block_mask,      # [batch, heads, q_blocks, k_blocks] bool
+    max_seqlen_q,    # int
+    max_seqlen_k,    # int
+    p_dropout=0.0,
+    deterministic=True,
+    is_causal=True,  # global causal flag
+)
+```
+
+### Block Size Requirement
+
+BSA requires **block_size = 128** (hard-coded):
+```python
+assert block_size == 128  # Xattention.py:358
+```
+
+---
+
+## Key Parameters
+
+| Parameter | Default | Range | Role |
+|------|--------|------|------|
+| `stride` | 8 | 4-16 | Q/K interleaved sampling stride; larger is faster but coarser |
+| `threshold` | 0.9 | 0.7-0.99 | Cumulative attention threshold; higher keeps more blocks |
+| `block_size` | 128 | 128 (fixed) | BSA block size, not tunable |
+| `chunk_size` | 16384 | 2048-131072 | Chunk size during estimation; affects memory use |
+| `norm` | 1.0 | 0.5-2.0 | Normalization coefficient for attention scores |
+| `keep_sink` | False | bool | Whether to always keep the first block |
+| `keep_recent` | False | bool | Whether to always keep the diagonal block |
+
+---
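The cumulative-threshold selection above can be reproduced for a single row in a few lines of NumPy (a hypothetical sketch: the sink/diagonal forcing and batching of the real `find_blocks_chunked` are omitted), matching the worked example:

```python
# Single-row sketch of cumulative-threshold block selection.
import numpy as np

def select_blocks(row, threshold=0.9):
    order = np.argsort(-row)            # block indices, descending by weight
    csum = np.cumsum(row[order])
    # Keep every block up to and including the one that crosses the threshold.
    keep = (csum - row[order]) < threshold * row.sum()
    mask = np.zeros_like(row, dtype=bool)
    mask[order[keep]] = True
    return mask

row = np.array([0.05, 0.30, 0.40, 0.15, 0.10])
mask = select_blocks(row)
print(mask)  # blocks 2, 1, 3, 4 selected; block 0 dropped
```

The `(csum - row[order]) < target` form is what makes the crossing block inclusive: at the 4th sorted block the prior cumulative sum is 0.85 < 0.9, so it is kept, and the 5th (prior sum 0.95) is dropped.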
+## Computational Complexity
+
+### Estimation Phase
+
+| Operation | Complexity |
+|------|--------|
+| Stride reshape GEMM | O(seq_len/stride × seq_len/stride × head_dim × stride) = O(seq_len² × head_dim / stride) |
+| Softmax + block sum | O(seq_len² / stride²) |
+| Block selection | O(num_blocks² × log(num_blocks)) |
+
+**Total estimation complexity**: O(seq_len² × head_dim / stride)
+
+### Computation Phase (BSA)
+
+Let ρ be the fraction of selected blocks (typically 0.3-0.5):
+
+| Operation | Complexity |
+|------|--------|
+| Block sparse attention | O(ρ × num_blocks² × block_size² × head_dim) = O(ρ × seq_len² × head_dim) |
+
+**Total complexity**: O(seq_len² × head_dim × (1/stride + ρ))
+
+With stride=8 and ρ=0.4, the cost factor is 1/8 + 0.4 = 0.525, i.e. roughly **50%** of the computation is saved relative to full attention.
+
+---
+
+## Notes on Integrating with nano-vllm
+
+### Dependency Requirements
+
+```
+block_sparse_attn   # pip install block-sparse-attn
+triton >= 2.0       # Triton kernels
+```
+
+### Adapting to the CPU Offload Scenario
+
+The original XAttention implementation assumes all KV resides on the GPU. For the CPU offload scenario, we need to:
+
+1. **Estimation phase**: still load the full KV history onto the GPU for estimation
+2. **Computation phase**: load only the selected blocks
+
+This likely requires restructuring into a two-stage pipeline:
+- first estimate the important blocks from sampled data
+- then load only the important blocks for the computation
+
+### block_size Alignment
+
+nano-vllm's `kvcache_block_size` must be aligned with BSA's 128:
+- if `kvcache_block_size = 1024`, each kv block contains 8 BSA blocks
+- the block-selection granularity must be adjusted accordingly
+
+---
+
+## Source File Index
+
+| File | Location | Contents |
+|------|------|------|
+| `Xattention.py` | `compass/src/Xattention.py` | Entry points: `xattn_estimate()`, `Xattention_prefill()` |
+| `kernels.py` | `compass/src/kernels.py` | Triton kernels |
+| `utils.py` | `compass/src/utils.py` | `find_blocks_chunked()`, `create_causal_mask()` |
+
+---
+
+## References
+
+- COMPASS project: `/home/zijie/Code/COMPASS/`
+- BSA library: MIT-HAN-LAB block_sparse_attn
+- Test report: `docs/xattention_bsa_test_report.md`
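As a quick sanity check on the cost-factor formula in the complexity section (an editorial aside, not part of the patch):

```python
# Relative cost of XAttention vs. full attention: 1/stride + rho vs. 1.0.
stride, rho = 8, 0.4
ratio = 1 / stride + rho     # 0.525
savings = 1 - ratio          # about 0.475, i.e. roughly 50% saved
print(ratio, savings)
```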