📝 docs: add XAttention algorithm guide based on COMPASS implementation

- Create docs/xattention_algorithm_guide.md with detailed algorithm explanation
  - Stride reshape (inverse mode) for Q/K interleaved sampling
  - Triton kernels: flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
  - Block selection via find_blocks_chunked with cumulative threshold
  - BSA (block_sparse_attn) dependency for sparse computation
- Update docs/sparse_attention_guide.md XAttention section with accurate description
- Add documentation index entry in CLAUDE.md

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Zijie Tian
2026-01-20 02:50:03 +08:00
parent 07f5220f40
commit e440c45e73
3 changed files with 395 additions and 25 deletions


@@ -50,30 +50,35 @@ output = block_sparse_attn_func(
## Method 1: XAttention (xattn_estimate)
-**Source**: `xattn/src/Xattention.py`
+**Source**: `compass/src/Xattention.py`
+**Detailed docs**: [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md)
### Core Idea
-Use **strided Q/K reshaping** to create coarse-grained representations, compute block-level attention scores, and select blocks above a threshold.
+Use a **stride-interleaved reshape (inverse mode)** to efficiently estimate block-level attention importance, then use the **BSA (Block Sparse Attention)** library for the sparse computation.
### Algorithm
```python
-def xattn_estimate(query, key, block_size=64, stride=16):
+def xattn_estimate(query, key, block_size=128, stride=8):
    """
-    Estimate block importance using strided attention.
+    Estimate block importance using stride-interleaved attention.
-    1. Reshape Q: [batch, seq, heads, dim] -> [batch, num_blocks, stride, heads, dim]
-       Then take mean over stride dimension to get block-level Q
+    1. K reshape (forward interleave): concat([K[:,:,k::stride,:] for k in range(stride)])
+       Q reshape (reverse interleave): concat([Q[:,:,(stride-1-q)::stride,:] for q in range(stride)])
+       Result: seq_len -> seq_len / stride, head_dim -> head_dim * stride
-    2. Reshape K: Same process to get block-level K
+    2. Triton kernel (flat_group_gemm_fuse_reshape):
+       fused reshape + GEMM computing Q_reshaped @ K_reshaped^T
-    3. Compute block attention: softmax(block_Q @ block_K.T / sqrt(d))
-       Result shape: [batch, heads, q_blocks, k_blocks]
+    3. Triton kernel (softmax_fuse_block_sum):
+       online softmax + grouped sums over block_size / stride entries
+       Output: attn_sum [batch, heads, q_blocks, k_blocks]
-    4. Apply causal mask (upper triangle = 0)
-    5. Threshold: blocks with score > threshold are selected
+    4. find_blocks_chunked:
+       sort by attn_sum in descending order and mark blocks True until the
+       cumulative sum reaches threshold; diagonal and sink blocks are always kept
    """
```
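The interleaved reshape in step 1 can be sketched in plain NumPy. This is an illustrative stand-in for the fused Triton kernel, not the repository's code; shapes and the `stride_reshape` helper name are assumptions:

```python
import numpy as np

def stride_reshape(q, k, stride=8):
    # q, k: [B, H, S, D]. K is interleaved forward, Q in reverse, so each
    # entry of q_r @ k_r.T sums `stride` token-level dot products taken
    # along an anti-diagonal of the corresponding stride x stride tile
    # of the full attention map.
    k_r = np.concatenate([k[:, :, s::stride, :] for s in range(stride)], axis=-1)
    q_r = np.concatenate([q[:, :, (stride - 1 - s)::stride, :] for s in range(stride)], axis=-1)
    return q_r, k_r

q = np.random.randn(1, 2, 64, 16)   # [B, H, S, D]
k = np.random.randn(1, 2, 64, 16)
q_r, k_r = stride_reshape(q, k)
print(q_r.shape, k_r.shape)         # (1, 2, 8, 128) (1, 2, 8, 128)
```

Sequence length shrinks by `stride` while head_dim grows by `stride`, so the score matrix is `stride**2` times smaller than the full attention map.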
@@ -81,45 +86,60 @@ def xattn_estimate(query, key, block_size=64, stride=16):
| Parameter | Default | Description |
|-----------|---------|-------------|
-| `block_size` | 64 | Tokens per block |
-| `stride` | 16 | Stride for coarse Q/K computation |
-| `threshold` | 0.9 | Selection threshold (cumulative or direct) |
+| `block_size` | 128 | Tokens per block (BSA requires a fixed 128) |
+| `stride` | 8 | Interleaved Q/K sampling stride; larger is faster but coarser |
+| `threshold` | 0.9 | Cumulative attention threshold; blocks are selected until this fraction of weight is covered |
+| `chunk_size` | 16384 | Chunk size used during estimation |
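The cumulative-threshold selection controlled by `threshold` can be illustrated with a toy NumPy sketch. `select_blocks` is a hypothetical helper for one head, not the repository's `find_blocks_chunked`:

```python
import numpy as np

def select_blocks(attn_sum, threshold=0.9):
    # attn_sum: [q_blocks, k_blocks] block scores for a single head.
    probs = attn_sum / attn_sum.sum(axis=-1, keepdims=True)
    order = np.argsort(-probs, axis=-1)                  # descending per row
    sorted_p = np.take_along_axis(probs, order, axis=-1)
    cum = np.cumsum(sorted_p, axis=-1)
    # keep blocks up to and including the one that crosses the threshold
    keep_sorted = (cum - sorted_p) < threshold
    mask = np.zeros_like(probs, dtype=bool)
    np.put_along_axis(mask, order, keep_sorted, axis=-1)
    n = min(mask.shape[0], mask.shape[1])
    mask[np.arange(n), np.arange(n)] = True              # diagonal blocks always kept
    mask[:, 0] = True                                    # sink block always kept
    return mask

scores = np.array([[8.0, 1.0, 1.0],
                   [1.0, 8.0, 1.0]])
mask = select_blocks(scores, threshold=0.75)
print(mask.astype(int))   # [[1 0 0]
                          #  [1 1 0]]
```

Row 1 keeps its dominant diagonal block plus the forced sink block, which is the behavior the always-keep rules above guarantee.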
### Computation Flow
```
-query [B, S, H, D]
+query [B, H, S, D]
|
v
-Reshape to [B, num_blocks, stride, H, D]
+Stride interleaved reshape (Triton fused)
|
v
-Mean over stride -> block_q [B, num_blocks, H, D]
+flat_group_gemm_fuse_reshape: Q_r @ K_r^T
|
v
-Compute block attention scores [B, H, q_blocks, k_blocks]
+softmax_fuse_block_sum: online softmax + block sum
|
v
-Apply threshold -> block_mask [B, H, q_blocks, k_blocks]
+attn_sum [B, H, q_blocks, k_blocks]
|
v
-block_sparse_attn_func(q, k, v, block_mask)
+find_blocks_chunked: cumulative-threshold selection
|
v
-output [B, S, H, D]
+simple_mask [B, H, q_blocks, k_blocks] (bool)
+|
+v
+block_sparse_attn_func(q, k, v, simple_mask) ← BSA library
+|
+v
+output [B, H, S, D]
```
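The estimation half of this flow (reshape → GEMM → softmax → block sum) has a simple dense equivalent. `estimate_block_scores` below is an illustrative NumPy stand-in for the fused Triton kernels, for a single batch/head with small toy sizes; the real kernels use an online softmax and different scaling details:

```python
import numpy as np

def estimate_block_scores(q, k, block_size=16, stride=4):
    # q, k: [S, D] for one batch/head.
    # Interleaved reshape (K forward, Q reversed), then GEMM on the
    # reduced sequences, softmax, and per-block grouped sums.
    k_r = np.concatenate([k[s::stride] for s in range(stride)], axis=-1)
    q_r = np.concatenate([q[(stride - 1 - s)::stride] for s in range(stride)], axis=-1)
    scores = (q_r @ k_r.T) / np.sqrt(q.shape[-1])        # [S/stride, S/stride]
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    rows = block_size // stride                          # reduced rows/cols per block
    nb = probs.shape[0] // rows
    # group-sum the reduced map into [q_blocks, k_blocks]
    return probs.reshape(nb, rows, nb, rows).sum(axis=(1, 3))

S, D = 64, 16
attn_sum = estimate_block_scores(np.random.randn(S, D), np.random.randn(S, D))
print(attn_sum.shape)   # (4, 4)
```

Each row of `attn_sum` sums to `block_size / stride` (here 4), since every reduced softmax row sums to 1, which is why a cumulative fraction of that mass is a meaningful selection criterion.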
+### Dependencies
+```python
+from block_sparse_attn import block_sparse_attn_func  # MIT-HAN-LAB BSA library
+import triton  # Triton kernels for estimation
+```
### Usage
```python
-from xattn.src.Xattention import Xattention_prefill
+from compass.src.Xattention import Xattention_prefill
output = Xattention_prefill(
    query_states, key_states, value_states,
    threshold=0.9,
-    stride=16,
+    stride=8,
+    block_size=128,
+    use_triton=True,
)
```
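For sanity-checking a selected mask without a GPU, a dense reference that mimics what `block_sparse_attn_func` computes can be useful. `dense_block_sparse_ref` is a hypothetical helper sketched here, not part of the BSA library, and uses a toy block size rather than BSA's fixed 128:

```python
import numpy as np

def dense_block_sparse_ref(q, k, v, block_mask, block_size=16):
    # q, k, v: [S, D]; block_mask: [S//block_size, S//block_size] bool.
    # Expand the block mask to token resolution, intersect with the
    # causal mask, and run ordinary masked softmax attention.
    S = q.shape[0]
    scores = (q @ k.T) / np.sqrt(q.shape[-1])
    allow = np.kron(block_mask.astype(int),
                    np.ones((block_size, block_size), dtype=int)).astype(bool)
    allow &= np.tril(np.ones((S, S), dtype=bool))        # causal mask
    scores = np.where(allow, scores, -np.inf)            # disallowed blocks -> -inf
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v

S, D, bs = 64, 16, 16
q, k, v = (np.random.randn(S, D) for _ in range(3))
mask = np.tril(np.ones((S // bs, S // bs), dtype=bool))  # keep every causal block
out = dense_block_sparse_ref(q, k, v, mask, bs)
print(out.shape)   # (64, 16)
```

With an all-causal block mask this reduces to ordinary causal attention, which makes it a convenient baseline when comparing sparse and dense outputs for a given threshold.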
---