Files
nano-vllm/docs/xattention_algorithm_guide.md
Zijie Tian e440c45e73 📝 docs: add XAttention algorithm guide based on COMPASS implementation
- Create docs/xattention_algorithm_guide.md with detailed algorithm explanation
  - Stride reshape (inverse mode) for Q/K interleaved sampling
  - Triton kernels: flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
  - Block selection via find_blocks_chunked with cumulative threshold
  - BSA (block_sparse_attn) dependency for sparse computation
- Update docs/sparse_attention_guide.md XAttention section with accurate description
- Add documentation index entry in CLAUDE.md

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 02:50:03 +08:00

13 KiB
Raw Permalink Blame History

XAttention 算法实现指南

本文档详细描述 COMPASS 项目中 XAttention 的算法原理和实现细节。

概述

XAttention 是一种基于 stride reshape 的块稀疏注意力方法,通过低成本估计识别重要块,然后使用 BSA (Block Sparse Attention) 库执行稀疏计算。

核心依赖

组件 来源 作用
Triton Kernels COMPASS 自研 Q/K reshape + 块级估计
BSA MIT-HAN-LAB block_sparse_attn 稀疏注意力计算

算法流程

输入: Q [batch, heads, q_len, head_dim]
      K [batch, heads, k_len, head_dim]
      V [batch, heads, k_len, head_dim]

┌─────────────────────────────────────────────────────────────┐
│ Phase 1: Stride Reshape (inverse 模式)                       │
│                                                              │
│ K_reshaped = concat([K[:,:,k::stride,:] for k in stride])   │
│ Q_reshaped = concat([Q[:,:,(stride-1-q)::stride,:] for q])  │
│                                                              │
│ 效果: 序列长度从 seq_len 缩短到 seq_len/stride               │
│       head_dim 扩展到 head_dim * stride                      │
└─────────────────────────────────────────────────────────────┘
                          ↓
┌─────────────────────────────────────────────────────────────┐
│ Phase 2: 块级注意力估计 (Triton 加速)                         │
│                                                              │
│ 2a. flat_group_gemm_fuse_reshape:                           │
│     计算 Q_reshaped @ K_reshaped^T                          │
│     输出: attn_weights [batch, heads, q_len/stride, k_len/stride] │
│                                                              │
│ 2b. softmax_fuse_block_sum:                                 │
│     - 在线 softmax (数值稳定)                                │
│     - 按 block_size/stride 分组求和                          │
│     输出: attn_sum [batch, heads, q_blocks, k_blocks]        │
└─────────────────────────────────────────────────────────────┘
                          ↓
┌─────────────────────────────────────────────────────────────┐
│ Phase 3: 块选择 (find_blocks_chunked)                        │
│                                                              │
│ 对每个 Q block:                                              │
│   1. 按 attn_sum 降序排序 K blocks                           │
│   2. 累积求和直到 >= threshold * total_sum                   │
│   3. 累积到的 blocks 标记为 True                             │
│                                                              │
│ 特殊处理:                                                    │
│   - 对角块 (causal) 始终保留                                 │
│   - Sink 块 (block 0) 可选保留                               │
│                                                              │
│ 输出: simple_mask [batch, heads, q_blocks, k_blocks] (bool)  │
└─────────────────────────────────────────────────────────────┘
                          ↓
┌─────────────────────────────────────────────────────────────┐
│ Phase 4: 稀疏注意力计算 (BSA)                                 │
│                                                              │
│ attn_output = block_sparse_attn_func(                       │
│     Q, K, V,                                                 │
│     q_cu_seq_lens,      # [0, q_len]                        │
│     k_cu_seq_lens,      # [0, k_len]                        │
│     head_mask_type,     # [num_heads] 全 1                   │
│     None,               # left_mask                          │
│     simple_mask,        # 块稀疏 mask                        │
│     q_len, k_len,                                            │
│     is_causal=True,                                          │
│ )                                                            │
│                                                              │
│ 输出: attn_output [batch, heads, q_len, head_dim]            │
└─────────────────────────────────────────────────────────────┘

Stride Reshape 详解

Inverse 模式

XAttention 默认使用 select_mode="inverse",这是一种交错采样策略:

# 原始: Q/K shape = [batch, heads, seq_len, head_dim]
# stride = 8

# K reshape: 正向交错
K_reshaped = concat([K[:, :, 0::8, :],   # 位置 0, 8, 16, ...
                     K[:, :, 1::8, :],   # 位置 1, 9, 17, ...
                     K[:, :, 2::8, :],   # 位置 2, 10, 18, ...
                     ...
                     K[:, :, 7::8, :]])  # 位置 7, 15, 23, ...
# 结果: [batch, heads, seq_len/8, head_dim * 8]

# Q reshape: 反向交错 (inverse)
Q_reshaped = concat([Q[:, :, 7::8, :],   # 位置 7, 15, 23, ...
                     Q[:, :, 6::8, :],   # 位置 6, 14, 22, ...
                     Q[:, :, 5::8, :],   # 位置 5, 13, 21, ...
                     ...
                     Q[:, :, 0::8, :]])  # 位置 0, 8, 16, ...
# 结果: [batch, heads, seq_len/8, head_dim * 8]

为什么用 Inverse 模式?

当计算 Q_reshaped @ K_reshaped^Tinverse 模式使得:

  • Q 的后半部分与 K 的前半部分对齐
  • 这样可以近似捕获 causal attention 的对角模式

Triton Kernels 详解

1. flat_group_gemm_fuse_reshape

文件: compass/src/kernels.py:198-235

功能: 融合 stride reshape 和 GEMM避免显式创建 reshape 后的大张量

@triton.jit
def flat_group_gemm_fuse_reshape_kernel(Q, K, Out, ...):
    # 关键: 不实际 reshape而是通过指针算术模拟
    Q_ptrs = Q + block_m * BLOCK_M * STRIDE * stride_qn
    K_ptrs = K + block_n * BLOCK_N * STRIDE * stride_kn

    # 对 stride 个位置累加
    for iter in range(STRIDE):
        q = tl.load(Q_ptrs - iter * stride_qn)  # Q inverse 采样
        k = tl.load(K_ptrs + iter * stride_kn)  # K 正向采样
        o += tl.dot(q, k)

优势:

  • 内存节省: 不需要创建 [batch, heads, seq_len/stride, head_dim*stride] 的中间张量
  • 计算融合: reshape + GEMM 一次完成

2. softmax_fuse_block_sum

文件: compass/src/kernels.py:6-95

功能: 在线 softmax + 块内求和

@triton.jit
def softmax_fuse_block_sum_kernel_causal(In, Out, ...):
    # Pass 1: 计算全局 max 和 sum (在线算法)
    for iter in range(num_iters):
        X = tl.load(input_ptr + iter * segment_size) * scale
        m_local = tl.max(X, 1)
        m_new = tl.maximum(m_i, m_local)
        alpha = tl.math.exp2(m_i - m_new)
        X = X - m_new[:, None]
        l_local = tl.sum(tl.math.exp2(X), 1)
        l_i = l_i * alpha + l_local
        m_i = m_new

    # Pass 2: 归一化并按块求和
    for iter in range(num_iters):
        X = tl.load(input_ptr + iter * segment_size) * scale
        X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]  # softmax
        X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
        X = tl.sum(X, 2).sum(0)  # 块内求和
        tl.store(output_ptr + iter * segment_size // block_size, X)

输出含义: attn_sum[b, h, qi, ki] = Q block qi 对 K block ki 的归一化注意力权重之和


块选择算法 (find_blocks_chunked)

文件: compass/src/utils.py:44-191

算法步骤

def find_blocks_chunked(input_tensor, current_index, threshold, ...):
    """
    input_tensor: [batch, heads, q_blocks, k_blocks] - 块级注意力权重和
    threshold: 0.9 - 累积阈值
    """
    # 1. 计算每行总和
    total_sum = input_tensor.sum(dim=-1, keepdim=True)
    required_sum = total_sum * threshold  # 需要达到的累积和

    # 2. 特殊块始终保留
    mask = zeros_like(input_tensor, dtype=bool)
    mask[:, :, :, 0] = True              # sink 块
    mask[:, :, :, diagonal] = True       # 对角块 (causal)

    # 3. 对剩余块按权重排序
    other_values = input_tensor.masked_fill(mask, 0)
    sorted_values, index = sort(other_values, descending=True)

    # 4. 累积求和直到达到阈值
    cumsum = sorted_values.cumsum(dim=-1)
    index_mask = cumsum < required_sum

    # 5. 标记选中的块
    mask[..., index[index_mask]] = True

    return mask

示例

threshold = 0.9
attn_sum 某一行 = [0.05, 0.30, 0.40, 0.15, 0.10]  (已 softmax, 和为 1.0)
required_sum = 0.9

排序后: [0.40, 0.30, 0.15, 0.10, 0.05]
累积和: [0.40, 0.70, 0.85, 0.95, 1.00]
                            ↑ 达到 0.9

选中: 前 4 个块 (indices: 2, 1, 3, 4)

BSA (Block Sparse Attention)

库来源

from block_sparse_attn import block_sparse_attn_func

来自 MIT-HAN-LAB提供基于块 mask 的高效稀疏 FlashAttention 实现。

接口

attn_output = block_sparse_attn_func(
    query_states,         # [total_q, num_heads, head_dim]
    key_states,           # [total_k, num_heads, head_dim]
    value_states,         # [total_k, num_heads, head_dim]
    q_cu_seq_lens,        # [batch+1] cumulative sequence lengths
    k_cu_seq_lens,        # [batch+1]
    head_mask_type,       # [num_heads] int32, 1=causal, 0=full
    left_mask,            # Optional left padding mask
    block_mask,           # [batch, heads, q_blocks, k_blocks] bool
    max_seqlen_q,         # int
    max_seqlen_k,         # int
    p_dropout=0.0,
    deterministic=True,
    is_causal=True,       # 全局 causal flag
)

块大小要求

BSA 要求 block_size = 128(硬编码):

assert block_size == 128  # Xattention.py:358

关键参数

参数 默认值 范围 作用
stride 8 4-16 Q/K 交错采样步长,越大估计越快但越粗糙
threshold 0.9 0.7-0.99 累积注意力阈值,越高保留块越多
block_size 128 128 (固定) BSA 块大小,不可调
chunk_size 16384 2048-131072 估计时的分块大小,影响内存使用
norm 1.0 0.5-2.0 注意力分数归一化系数
keep_sink False bool 是否始终保留第一个块
keep_recent False bool 是否始终保留对角块

计算复杂度

估计阶段

操作 复杂度
Stride reshape GEMM O(seq_len/stride × seq_len/stride × head_dim × stride) = O(seq_len² × head_dim / stride)
Softmax + block sum O(seq_len² / stride²)
Block selection O(num_blocks² × log(num_blocks))

估计阶段总复杂度: O(seq_len² × head_dim / stride)

计算阶段 (BSA)

设选中块比例为 ρ (通常 0.3-0.5):

操作 复杂度
Block sparse attention O(ρ × num_blocks² × block_size² × head_dim) = O(ρ × seq_len² × head_dim)

总复杂度: O(seq_len² × head_dim × (1/stride + ρ))

当 stride=8, ρ=0.4 时,相比 full attention 节省约 50% 计算量。


与 nano-vllm 集成注意事项

依赖要求

block_sparse_attn  # pip install block-sparse-attn
triton >= 2.0      # Triton kernels

CPU Offload 场景适配

XAttention 原始实现假设所有 KV 在 GPU 上。对于 CPU offload 场景,需要:

  1. 估计阶段: 仍需加载所有历史 KV 到 GPU 进行估计
  2. 计算阶段: 只加载选中的块

这可能需要修改为两阶段 pipeline:

  • 先用采样数据估计重要块
  • 再只加载重要块进行计算

block_size 对齐

nano-vllm 的 kvcache_block_size 需要与 BSA 的 128 对齐:

  • 如果 kvcache_block_size = 1024,则每个 kv block 包含 8 个 BSA blocks
  • 块选择粒度需要相应调整

源文件索引

文件 位置 内容
Xattention.py compass/src/Xattention.py 主入口: xattn_estimate(), Xattention_prefill()
kernels.py compass/src/kernels.py Triton 内核
utils.py compass/src/utils.py find_blocks_chunked(), create_causal_mask()

参考

  • COMPASS 项目: /home/zijie/Code/COMPASS/
  • BSA 库: MIT-HAN-LAB block_sparse_attn
  • 测试报告: docs/xattention_bsa_test_report.md