📝 docs: add XAttention algorithm guide based on COMPASS implementation
- Create docs/xattention_algorithm_guide.md with detailed algorithm explanation
- Stride reshape (inverse mode) for Q/K interleaved sampling
- Triton kernels: flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
- Block selection via find_blocks_chunked with cumulative threshold
- BSA (block_sparse_attn) dependency for sparse computation
- Update docs/sparse_attention_guide.md XAttention section with accurate description
- Add documentation index entry in CLAUDE.md

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
## Method 1: XAttention (xattn_estimate)

**Source**: `compass/src/Xattention.py`

**Detailed documentation**: [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md)

### Core Idea

Use **stride-interleaved reshape (inverse mode)** to efficiently estimate block-level attention importance, then use the **BSA (Block Sparse Attention)** library for the sparse computation.

### Algorithm

```python
def xattn_estimate(query, key, block_size=128, stride=8):
    """
    Estimate block importance using stride-interleaved attention.

    1. K reshape (forward interleave): concat([K[:,:,k::stride,:] for k in range(stride)])
       Q reshape (reverse interleave): concat([Q[:,:,(stride-1-q)::stride,:] for q])
       Result: sequence length seq_len -> seq_len/stride, head_dim -> head_dim*stride

    2. Triton kernel (flat_group_gemm_fuse_reshape):
       fuses the reshape with the GEMM, computing Q_reshaped @ K_reshaped^T

    3. Triton kernel (softmax_fuse_block_sum):
       online softmax + grouped sum over block_size/stride entries
       Output: attn_sum [batch, heads, q_blocks, k_blocks]

    4. find_blocks_chunked:
       sort by attn_sum in descending order and mark blocks True until the
       cumulative sum reaches threshold; diagonal and sink blocks are always kept
    """
```
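
The interleaved reshape in step 1 can be reproduced in a few lines of plain PyTorch. This is an illustrative sketch, not the Triton-fused kernel, and the function name `stride_interleave` is ours:

```python
import torch

def stride_interleave(x: torch.Tensor, stride: int, reverse: bool = False) -> torch.Tensor:
    """Stride-interleaved reshape: [B, H, S, D] -> [B, H, S // stride, D * stride].

    Forward mode takes offsets 0, 1, ..., stride-1 (used for K);
    reverse mode takes offsets stride-1, ..., 1, 0 (used for Q).
    """
    offsets = range(stride - 1, -1, -1) if reverse else range(stride)
    # Every stride-th token at each offset, stacked along the head_dim axis.
    return torch.cat([x[:, :, k::stride, :] for k in offsets], dim=-1)

q = torch.randn(1, 2, 16, 4)
k = torch.randn(1, 2, 16, 4)
q_r = stride_interleave(q, stride=4, reverse=True)  # reverse interleave for Q
k_r = stride_interleave(k, stride=4)                # forward interleave for K
# Both become [1, 2, 4, 16]: seq_len / stride, head_dim * stride.
```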

| Parameter | Default | Description |
|-----------|---------|-------------|
| `block_size` | 128 | Tokens per block (BSA requires a fixed 128) |
| `stride` | 8 | Q/K interleaved sampling stride; larger is faster but gives a coarser estimate |
| `threshold` | 0.9 | Cumulative attention threshold: blocks are selected until their cumulative weight reaches this fraction |
| `chunk_size` | 16384 | Chunk size used during estimation |

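
Putting the steps above together, the estimation path can be sketched in plain PyTorch. This is an unfused reference under our own assumptions, omitting the causal mask, chunking, and sink handling, and the exact softmax scaling in the real kernels may differ:

```python
import math
import torch

def xattn_estimate_reference(q, k, block_size=128, stride=8):
    """Unfused sketch of the estimation path; q, k are [B, H, S, D]."""
    B, H, S, D = q.shape
    # Stride-interleaved reshape: reverse offsets for Q, forward offsets for K.
    q_r = torch.cat([q[:, :, (stride - 1 - i)::stride, :] for i in range(stride)], dim=-1)
    k_r = torch.cat([k[:, :, i::stride, :] for i in range(stride)], dim=-1)
    # Coarse attention over the shortened sequence (this scaling is an assumption).
    probs = ((q_r @ k_r.transpose(-1, -2)) / math.sqrt(D)).softmax(dim=-1)
    # Group-sum into blocks: each block covers block_size // stride coarse positions.
    g = block_size // stride
    nb = S // block_size
    attn_sum = probs.reshape(B, H, nb, g, nb, g).sum(dim=(3, 5))
    return attn_sum  # [B, H, q_blocks, k_blocks]

attn_sum = xattn_estimate_reference(torch.randn(1, 2, 256, 64), torch.randn(1, 2, 256, 64))
# attn_sum has shape [1, 2, 2, 2]; each query-block row sums to block_size / stride.
```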
### Computation Flow

```
query [B, H, S, D]
        |
        v
Stride interleaved reshape (Triton fused)
        |
        v
flat_group_gemm_fuse_reshape: Q_r @ K_r^T
        |
        v
softmax_fuse_block_sum: online softmax + block sum
        |
        v
attn_sum [B, H, q_blocks, k_blocks]
        |
        v
find_blocks_chunked: cumulative threshold selection
        |
        v
simple_mask [B, H, q_blocks, k_blocks] (bool)
        |
        v
block_sparse_attn_func(q, k, v, simple_mask) ← BSA library
        |
        v
output [B, H, S, D]
```
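
The cumulative-threshold step can be illustrated without the chunking and diagonal/sink handling of the real `find_blocks_chunked`. A minimal sketch under our assumptions (the function name is ours):

```python
import torch

def cumulative_threshold_mask(attn_sum: torch.Tensor, threshold: float) -> torch.Tensor:
    """Per query-block row, keep the highest-scoring key blocks until their
    cumulative share of the row's total mass reaches `threshold`."""
    probs = attn_sum / attn_sum.sum(dim=-1, keepdim=True)
    sorted_probs, order = probs.sort(dim=-1, descending=True)
    cum = sorted_probs.cumsum(dim=-1)
    # Keep a sorted position if the mass accumulated *before* it is below threshold.
    keep_sorted = (cum - sorted_probs) < threshold
    # Scatter the decisions back to the original block order.
    return torch.zeros_like(probs).scatter(-1, order, keep_sorted.float()).bool()

attn_sum = torch.tensor([[[[5.0, 3.0, 1.5, 0.5]]]])
mask = cumulative_threshold_mask(attn_sum, threshold=0.9)
# Keeps the top blocks whose cumulative share reaches 90%: [True, True, True, False]
```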

### Dependencies

```python
from block_sparse_attn import block_sparse_attn_func  # MIT-HAN-LAB BSA library
import triton  # Triton kernels for estimation
```

### Usage

```python
from compass.src.Xattention import Xattention_prefill

output = Xattention_prefill(
    query_states, key_states, value_states,
    threshold=0.9,
    stride=8,
    block_size=128,
    use_triton=True,
)
```

---