# XAttention Algorithm Implementation Guide

This document describes the algorithm and implementation details of XAttention in the COMPASS project.

## Overview

XAttention is a block-sparse attention method based on **stride reshape**: it identifies the important blocks with a low-cost estimate, then performs the sparse computation with the **BSA (Block Sparse Attention)** library.

### Core Dependencies

| Component | Source | Role |
|------|------|------|
| Triton kernels | COMPASS (in-house) | Q/K reshape + block-level estimation |
| BSA | MIT-HAN-LAB `block_sparse_attn` | Sparse attention computation |

---

## Algorithm Flow

```
Input: Q [batch, heads, q_len, head_dim]
       K [batch, heads, k_len, head_dim]
       V [batch, heads, k_len, head_dim]

┌─────────────────────────────────────────────────────────────┐
│ Phase 1: Stride Reshape (inverse mode)                      │
│                                                             │
│ K_reshaped = concat([K[:,:,k::stride,:] for k in stride])   │
│ Q_reshaped = concat([Q[:,:,(stride-1-q)::stride,:] for q])  │
│                                                             │
│ Effect: seq length shrinks from seq_len to seq_len/stride;  │
│         head_dim grows to head_dim * stride                 │
└─────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────┐
│ Phase 2: Block-level attention estimation (Triton)          │
│                                                             │
│ 2a. flat_group_gemm_fuse_reshape:                           │
│     computes Q_reshaped @ K_reshaped^T                      │
│     output: attn_weights                                    │
│             [batch, heads, q_len/stride, k_len/stride]      │
│                                                             │
│ 2b. softmax_fuse_block_sum:                                 │
│     - online softmax (numerically stable)                   │
│     - sums over groups of block_size/stride                 │
│     output: attn_sum [batch, heads, q_blocks, k_blocks]     │
└─────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────┐
│ Phase 3: Block selection (find_blocks_chunked)              │
│                                                             │
│ For each Q block:                                           │
│   1. sort the K blocks by attn_sum, descending              │
│   2. accumulate until the sum >= threshold * total_sum      │
│   3. mark the accumulated blocks as True                    │
│                                                             │
│ Special cases:                                              │
│   - diagonal blocks (causal) are always kept                │
│   - the sink block (block 0) may optionally be kept         │
│                                                             │
│ Output: simple_mask [batch, heads, q_blocks, k_blocks] bool │
└─────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────┐
│ Phase 4: Sparse attention computation (BSA)                 │
│                                                             │
│ attn_output = block_sparse_attn_func(                       │
│     Q, K, V,                                                │
│     q_cu_seq_lens,   # [0, q_len]                           │
│     k_cu_seq_lens,   # [0, k_len]                           │
│     head_mask_type,  # [num_heads], all ones                │
│     None,            # left_mask                            │
│     simple_mask,     # block-sparse mask                    │
│     q_len, k_len,                                           │
│     is_causal=True,                                         │
│ )                                                           │
│                                                             │
│ Output: attn_output [batch, heads, q_len, head_dim]         │
└─────────────────────────────────────────────────────────────┘
```

---

## Stride Reshape in Detail

### Inverse Mode

XAttention defaults to `select_mode="inverse"`, an interleaved sampling strategy:

```python
# Original: Q/K shape = [batch, heads, seq_len, head_dim]
# stride = 8

# K reshape: forward interleaving
K_reshaped = concat([K[:, :, 0::8, :],   # positions 0, 8, 16, ...
                     K[:, :, 1::8, :],   # positions 1, 9, 17, ...
                     K[:, :, 2::8, :],   # positions 2, 10, 18, ...
                     ...
                     K[:, :, 7::8, :]])  # positions 7, 15, 23, ...
# Result: [batch, heads, seq_len/8, head_dim * 8]

# Q reshape: reverse interleaving (inverse)
Q_reshaped = concat([Q[:, :, 7::8, :],   # positions 7, 15, 23, ...
                     Q[:, :, 6::8, :],   # positions 6, 14, 22, ...
                     Q[:, :, 5::8, :],   # positions 5, 13, 21, ...
                     ...
                     Q[:, :, 0::8, :]])  # positions 0, 8, 16, ...
# Result: [batch, heads, seq_len/8, head_dim * 8]
```
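
As a concrete, runnable sketch, the interleaving above can be written as a small PyTorch helper (the `stride_reshape` name and shapes are illustrative, not the COMPASS API):

```python
import torch

def stride_reshape(q, k, stride=8):
    """Shorten seq_len by `stride` and widen head_dim by `stride`.

    K is sampled forward (offsets 0..stride-1); Q is sampled in reverse
    (offsets stride-1..0), matching XAttention's "inverse" mode.
    """
    k_r = torch.cat([k[:, :, s::stride, :] for s in range(stride)], dim=-1)
    q_r = torch.cat([q[:, :, s::stride, :] for s in reversed(range(stride))], dim=-1)
    return q_r, k_r

q = torch.randn(1, 2, 64, 16)   # [batch, heads, seq_len, head_dim]
k = torch.randn(1, 2, 64, 16)
q_r, k_r = stride_reshape(q, k, stride=8)
assert q_r.shape == (1, 2, 8, 128)  # seq_len/8, head_dim*8
assert k_r.shape == (1, 2, 8, 128)
```

Note that the first `head_dim` columns of `K_reshaped` come from K positions `0, 8, 16, ...`, while the first `head_dim` columns of `Q_reshaped` come from Q positions `7, 15, 23, ...` — exactly the pairing the next subsection discusses.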

### Why Inverse Mode?

When computing `Q_reshaped @ K_reshaped^T`, inverse mode makes:
- the latter part of each Q stride group align with the front part of each K stride group
- which approximately captures the **diagonal pattern of causal attention**

---

## Triton Kernels in Detail

### 1. flat_group_gemm_fuse_reshape

**File**: `compass/src/kernels.py:198-235`

**Purpose**: fuses the stride reshape into the GEMM, so the large reshaped tensors are never materialized
```python
@triton.jit
def flat_group_gemm_fuse_reshape_kernel(Q, K, Out, ...):
    # Key idea: never actually reshape; emulate it with pointer arithmetic
    Q_ptrs = Q + block_m * BLOCK_M * STRIDE * stride_qn
    K_ptrs = K + block_n * BLOCK_N * STRIDE * stride_kn

    # Accumulate over the stride positions
    for iter in range(STRIDE):
        q = tl.load(Q_ptrs - iter * stride_qn)  # Q: inverse sampling
        k = tl.load(K_ptrs + iter * stride_kn)  # K: forward sampling
        o += tl.dot(q, k)
```

**Advantages**:
- Memory savings: no `[batch, heads, seq_len/stride, head_dim*stride]` intermediate tensor is ever created
- Fusion: reshape + GEMM complete in a single kernel
### 2. softmax_fuse_block_sum

**File**: `compass/src/kernels.py:6-95`

**Purpose**: online softmax + intra-block summation
```python
@triton.jit
def softmax_fuse_block_sum_kernel_causal(In, Out, ...):
    # Pass 1: compute the global max and sum (online algorithm)
    for iter in range(num_iters):
        X = tl.load(input_ptr + iter * segment_size) * scale
        m_local = tl.max(X, 1)
        m_new = tl.maximum(m_i, m_local)
        alpha = tl.math.exp2(m_i - m_new)
        X = X - m_new[:, None]
        l_local = tl.sum(tl.math.exp2(X), 1)
        l_i = l_i * alpha + l_local
        m_i = m_new

    # Pass 2: normalize and sum within each block
    for iter in range(num_iters):
        X = tl.load(input_ptr + iter * segment_size) * scale
        X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]  # softmax
        X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
        X = tl.sum(X, 2).sum(0)  # intra-block summation
        tl.store(output_ptr + iter * segment_size // block_size, X)
```

**Output meaning**: `attn_sum[b, h, qi, ki]` is the sum of the **normalized attention weights** from Q block `qi` to K block `ki`

---
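The two kernels can be cross-checked against a dense, unfused PyTorch reference. This is a sketch only: the real kernels never materialize the reshaped tensors, apply causal masking, and use an exp2-based softmax; the `1/sqrt(head_dim)` scale and omission of masking here are simplifying assumptions.

```python
import math
import torch

def estimate_block_scores(q, k, stride=8, block_size=32):
    b, h, n, d = q.shape
    # Phase 1: inverse-mode stride reshape
    k_r = torch.cat([k[:, :, s::stride, :] for s in range(stride)], dim=-1)
    q_r = torch.cat([q[:, :, s::stride, :] for s in reversed(range(stride))], dim=-1)
    # Phase 2a: GEMM over the shortened sequences
    attn = (q_r @ k_r.transpose(-1, -2)) / math.sqrt(d)
    # Phase 2b: softmax, then sum over groups of block_size/stride rows/cols
    attn = attn.softmax(dim=-1)
    g = block_size // stride            # strided positions per block
    nb = n // block_size                # number of blocks
    attn = attn.reshape(b, h, nb, g, nb, g)
    return attn.sum(dim=(3, 5))         # [b, h, q_blocks, k_blocks]

scores = estimate_block_scores(torch.randn(1, 2, 256, 64), torch.randn(1, 2, 256, 64))
assert scores.shape == (1, 2, 8, 8)
```

Each q-block row of `scores` sums to `block_size / stride` (one unit per softmaxed strided row), which is why `find_blocks_chunked` works against the row total rather than assuming a sum of 1.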
## Block Selection Algorithm (find_blocks_chunked)

**File**: `compass/src/utils.py:44-191`

### Algorithm Steps
```python
def find_blocks_chunked(input_tensor, current_index, threshold, ...):
    """
    input_tensor: [batch, heads, q_blocks, k_blocks] - per-block attention weight sums
    threshold: 0.9 - cumulative threshold
    """
    # 1. Row totals
    total_sum = input_tensor.sum(dim=-1, keepdim=True)
    required_sum = total_sum * threshold  # cumulative mass to reach

    # 2. Special blocks are always kept
    mask = zeros_like(input_tensor, dtype=bool)
    mask[:, :, :, 0] = True         # sink block
    mask[:, :, :, diagonal] = True  # diagonal block (causal)

    # 3. Sort the remaining blocks by weight
    other_values = input_tensor.masked_fill(mask, 0)
    sorted_values, index = sort(other_values, descending=True)

    # 4. Accumulate until the threshold is reached
    cumsum = sorted_values.cumsum(dim=-1)
    index_mask = cumsum < required_sum

    # 5. Mark the selected blocks
    mask[..., index[index_mask]] = True

    return mask
```

### Example
```
threshold = 0.9
one row of attn_sum = [0.05, 0.30, 0.40, 0.15, 0.10]  (already softmaxed, sums to 1.0)
required_sum = 0.9

sorted:  [0.40, 0.30, 0.15, 0.10, 0.05]
cumsum:  [0.40, 0.70, 0.85, 0.95, 1.00]
                            ↑ reaches 0.9

selected: the top 4 blocks (indices: 2, 1, 3, 4)
```
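
The same numbers, executed (a sketch using torch; the sink/diagonal special cases of `find_blocks_chunked` are omitted):

```python
import torch

row = torch.tensor([0.05, 0.30, 0.40, 0.15, 0.10])  # one attn_sum row
threshold = 0.9
sorted_vals, idx = row.sort(descending=True)
cum = sorted_vals.cumsum(-1)
# keep blocks until the cumulative mass first reaches threshold * total:
# all blocks with cumsum < required, plus the one that crosses the line
n_keep = int((cum < threshold * row.sum()).sum()) + 1
selected = idx[:n_keep]
print(sorted(selected.tolist()))  # → [1, 2, 3, 4]
```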

---

## BSA (Block Sparse Attention)

### Library Source

```python
from block_sparse_attn import block_sparse_attn_func
```

From MIT-HAN-LAB; provides an efficient block-masked sparse FlashAttention implementation.

### Interface
```python
attn_output = block_sparse_attn_func(
    query_states,    # [total_q, num_heads, head_dim]
    key_states,      # [total_k, num_heads, head_dim]
    value_states,    # [total_k, num_heads, head_dim]
    q_cu_seq_lens,   # [batch+1] cumulative sequence lengths
    k_cu_seq_lens,   # [batch+1]
    head_mask_type,  # [num_heads] int32, 1=causal, 0=full
    left_mask,       # optional left padding mask
    block_mask,      # [batch, heads, q_blocks, k_blocks] bool
    max_seqlen_q,    # int
    max_seqlen_k,    # int
    p_dropout=0.0,
    deterministic=True,
    is_causal=True,  # global causal flag
)
```
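
For the single-sequence case, preparing this varlen layout looks roughly like the sketch below (`to_varlen` is a hypothetical helper for illustration, not part of `block_sparse_attn`; this does not call the BSA kernel itself):

```python
import torch

def to_varlen(x):
    # [1, heads, seq, dim] -> [seq, heads, dim] (packed varlen layout, batch=1)
    return x.squeeze(0).transpose(0, 1).contiguous()

h, q_len, d = 8, 4096, 128
Q = torch.randn(1, h, q_len, d)          # [batch, heads, q_len, head_dim]
q_var = to_varlen(Q)                     # [total_q, num_heads, head_dim]
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32)  # [batch+1]
head_mask_type = torch.ones(h, dtype=torch.int32)  # every head block-sparse
assert q_var.shape == (q_len, h, d)
```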

### Block Size Requirement

BSA requires **block_size = 128** (hard-coded):
```python
assert block_size == 128  # Xattention.py:358
```

---

## Key Parameters

| Parameter | Default | Range | Role |
|------|--------|------|------|
| `stride` | 8 | 4-16 | Q/K interleaved sampling stride; larger = faster but coarser estimation |
| `threshold` | 0.9 | 0.7-0.99 | cumulative attention threshold; higher keeps more blocks |
| `block_size` | 128 | 128 (fixed) | BSA block size, not tunable |
| `chunk_size` | 16384 | 2048-131072 | chunk size during estimation; affects memory use |
| `norm` | 1.0 | 0.5-2.0 | attention score normalization coefficient |
| `keep_sink` | False | bool | whether to always keep the first block |
| `keep_recent` | False | bool | whether to always keep the diagonal block |

---

## Computational Complexity

### Estimation Phase

| Operation | Complexity |
|------|--------|
| Stride reshape GEMM | O(seq_len/stride × seq_len/stride × head_dim × stride) = O(seq_len² × head_dim / stride) |
| Softmax + block sum | O(seq_len² / stride²) |
| Block selection | O(num_blocks² × log(num_blocks)) |

**Estimation phase total**: O(seq_len² × head_dim / stride)

### Computation Phase (BSA)

Let ρ be the fraction of selected blocks (typically 0.3-0.5):

| Operation | Complexity |
|------|--------|
| Block sparse attention | O(ρ × num_blocks² × block_size² × head_dim) = O(ρ × seq_len² × head_dim) |

**Overall**: O(seq_len² × head_dim × (1/stride + ρ))

With stride=8 and ρ=0.4, this saves roughly **50%** of the compute relative to full attention.
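
The headline figure follows directly from the formula above:

```python
# Relative cost vs. full attention: 1/stride for estimation + rho for BSA
stride, rho = 8, 0.4
relative = 1 / stride + rho
print(relative)  # → 0.525, i.e. roughly half the compute of full attention
```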

---

## nano-vllm Integration Notes

### Dependency Requirements

```
block_sparse_attn   # pip install block-sparse-attn
triton >= 2.0       # Triton kernels
```

### Adapting to CPU Offload

The original XAttention implementation assumes all KV resides on the GPU. A CPU-offload scenario needs:

1. **Estimation phase**: all historical KV must still be loaded to the GPU for estimation
2. **Computation phase**: only the selected blocks are loaded

This likely requires restructuring into a two-stage pipeline:
- first estimate the important blocks from sampled data
- then load only the important blocks for the computation
### block_size Alignment

nano-vllm's `kvcache_block_size` must be aligned with BSA's 128:
- with `kvcache_block_size = 1024`, each KV cache block contains 8 BSA blocks
- the block selection granularity must be adjusted accordingly
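
One way to bridge the granularity gap is to coarsen the BSA-level mask to cache-block granularity. The sketch below is a hypothetical illustration (helper name and constants are assumptions, not nano-vllm or COMPASS code):

```python
import torch

BSA_BLOCK = 128  # fixed BSA block size

def kvcache_blocks_to_load(simple_mask_row, kvcache_block_size=1024):
    """simple_mask_row: bool[k_bsa_blocks] for one query block.

    A KV cache block must be loaded if ANY of its BSA blocks is selected.
    """
    ratio = kvcache_block_size // BSA_BLOCK        # BSA blocks per cache block
    needed = simple_mask_row.view(-1, ratio).any(dim=-1)
    return needed.nonzero(as_tuple=True)[0]

mask = torch.zeros(32, dtype=torch.bool)  # 32 BSA blocks = 4096 tokens
mask[[0, 9, 30]] = True                   # selected BSA blocks
print(kvcache_blocks_to_load(mask).tolist())  # → [0, 1, 3]
```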
---

## Source File Index

| File | Location | Contents |
|------|------|------|
| `Xattention.py` | `compass/src/Xattention.py` | main entry points: `xattn_estimate()`, `Xattention_prefill()` |
| `kernels.py` | `compass/src/kernels.py` | Triton kernels |
| `utils.py` | `compass/src/utils.py` | `find_blocks_chunked()`, `create_causal_mask()` |

---

## References

- COMPASS project: `/home/zijie/Code/COMPASS/`
- BSA library: MIT-HAN-LAB block_sparse_attn
- Test report: `docs/xattention_bsa_test_report.md`