From ad361c2c3b45c58dac95b83c139252a0a5f95b15 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Fri, 23 Jan 2026 08:36:56 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20docs:=20add=20XAttention=20BSA?= =?UTF-8?q?=20Policy=20design=20documentation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Create docs/xattn_bsa_policy_design.md with: - Algorithm overview and data flow diagram - select_blocks implementation details - GQA-aware aggregation and majority voting - compute_chunked_prefill ring buffer pipeline - Parameter configuration and usage examples - Performance characteristics and limitations - Update CLAUDE.md documentation index Co-Authored-By: Claude Opus 4.5 --- CLAUDE.md | 1 + docs/xattn_bsa_policy_design.md | 294 ++++++++++++++++++++++++++++++++ 2 files changed, 295 insertions(+) create mode 100644 docs/xattn_bsa_policy_design.md diff --git a/CLAUDE.md b/CLAUDE.md index 02b44d1..a908d60 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -17,6 +17,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L | [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention 算法详解: stride reshape、Triton kernels、BSA 依赖、块选择算法 | | [`docs/xattn_kernels_guide.md`](docs/xattn_kernels_guide.md) | XAttention Triton kernels: flat_group_gemm (反对角线求和)、softmax_fuse_block_sum (block 聚合) | | [`docs/xattn_chunked_prefill.md`](docs/xattn_chunked_prefill.md) | XAttention chunked prefill: API、使用方式、一致性要求 | +| [`docs/xattn_bsa_policy_design.md`](docs/xattn_bsa_policy_design.md) | XAttention BSA Policy 设计: select_blocks 算法、majority voting、compute_chunked_prefill | | [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) 接口文档: 函数签名、使用示例、约束条件 | | [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling | | [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) | diff --git a/docs/xattn_bsa_policy_design.md b/docs/xattn_bsa_policy_design.md new file mode 100644 index 0000000..0632aee --- /dev/null +++ b/docs/xattn_bsa_policy_design.md @@ -0,0 +1,294 @@ +# XAttention BSA Policy 设计文档 + +本文档描述 `XAttentionBSAPolicy` 的设计和实现,这是一个基于 XAttention 算法的稀疏注意力策略,用于 CPU offload 模式下的 chunked prefill。 + +## 概述 + +`XAttentionBSAPolicy` 实现了基于 XAttention 的块级稀疏注意力选择。核心思想是: + +1. **估计阶段**:使用 XAttention kernels 快速估计每个 KV block 的重要性 +2. **选择阶段**:基于阈值和 majority voting 选择重要的 blocks +3. **计算阶段**:只加载选中的 blocks 进行 attention 计算 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ XAttention BSA Policy │ +├─────────────────────────────────────────────────────────────┤ +│ select_blocks() │ +│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │ +│ │ Load K │──>│ flat_group_gemm │──>│ softmax_fuse │ │ +│ │ blocks │ │ _fuse_reshape │ │ _block_sum │ │ +│ └─────────────┘ └──────────────────┘ └──────────────┘ │ +│ │ │ │ │ +│ v v v │ +│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │ +│ │ K: [B,H,L,D]│ │ attn_scores: │ │ block_sums: │ │ +│ │ │ │ [B,H,Q/s,K/s] │ │ [B,H,Qb,Kb] │ │ +│ └─────────────┘ └──────────────────┘ └──────────────┘ │ +│ │ │ +│ ┌──────────────────────┘ │ +│ v │ +│ ┌──────────────┐ │ +│ │find_blocks │ │ +│ │_chunked │ │ +│ └──────────────┘ │ +│ │ │ +│ v │ +│ ┌──────────────┐ │ +│ │ GQA-aware │ │ +│ │ aggregation │ │ +│ │ + majority │ │ +│ │ voting │ │ +│ └──────────────┘ │ +│ │ │ +│ v │ +│ selected_block_ids │ +├─────────────────────────────────────────────────────────────┤ +│ compute_chunked_prefill() │ +│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │ +│ │ Ring buffer │──>│ flash_attn_ │──>│ merge_ │ │ +│ │ pipeline │ │ with_lse │ │ attention │ │ +│ └─────────────┘ └──────────────────┘ └──────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +## 文件位置 + +**主文件**: `nanovllm/kvcache/sparse/xattn_bsa.py` + +**依赖的 XAttention kernels**: `nanovllm/ops/xattn.py` +- `flat_group_gemm_fuse_reshape`: 计算 stride reshape 后的 attention scores +- `softmax_fuse_block_sum`: 对 attention scores 做 softmax 后按 block 求和 +- `find_blocks_chunked`: 基于阈值选择 blocks + +--- + +## 核心算法 + +### 1. select_blocks: 块选择算法 + +```python +def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]: +``` + +#### Step 1: 加载 K blocks 并计算 attention scores + +对每个 CPU block,加载 K 到 GPU 并使用 `flat_group_gemm_fuse_reshape` 计算: + +```python +for cpu_block_id in available_blocks: + # 加载 K block: [1, block_size, num_kv_heads, head_dim] + offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id) + k_block, _ = offload_engine.get_kv_for_slot(slot) + + # 转换为 [batch, heads, k_len, head_dim] + K_chunk = k_block.transpose(1, 2) + + # GQA: 扩展 K heads 匹配 Q heads + if num_heads != num_kv_heads: + K_chunk = K_chunk.repeat_interleave(num_groups, dim=1) + + # 计算 attention scores + attn_chunk = flat_group_gemm_fuse_reshape(Q, K_chunk, stride, ...) + attn_scores_list.append(attn_chunk) + +# 拼接所有 K chunks: [1, heads, q_reshaped_len, total_k_reshaped_len] +attn_scores = torch.cat(attn_scores_list, dim=-1) +``` + +#### Step 2: 聚合到 block 级别 + +使用 `softmax_fuse_block_sum` 将 attention scores 聚合到 block 级别: + +```python +# reshaped_block_size = block_size / stride = 1024 / 8 = 128 +block_sums = softmax_fuse_block_sum( + attn_scores, + reshaped_block_size, # 1:1 对应 CPU blocks + segment_size, + chunk_start=0, + chunk_end=q_reshaped_len, + real_q_len=q_reshaped_len, + scale=scale, + is_causal=False, +) +# block_sums: [batch, heads, q_blocks, k_blocks] +``` + +**关键点**: `reshaped_block_size` 必须与 CPU block 对齐,确保输出的 `k_blocks` 维度 1:1 对应 `available_blocks`。 + +#### Step 3: 阈值选择 + +使用 `find_blocks_chunked` 基于累积注意力阈值选择 blocks: + +```python +mask = find_blocks_chunked( + block_sums, + current_index=0, + threshold=self.threshold, # e.g., 0.95 + num_to_choose=None, + decoding=False, + mode="prefill", + causal=False, +) +# mask: [batch, num_heads, q_blocks, k_blocks] - boolean +``` + +#### Step 4: GQA-aware 聚合 + Majority Voting + +```python +# GQA: 在同一个 KV head group 内,任一 Q head 选择即选择 +if num_groups > 1: + mask_gqa = mask.view(batch_size, num_kv_heads, num_groups, q_blocks, k_blocks) + mask_per_kv_head = mask_gqa.any(dim=2) # [batch, num_kv_heads, q_blocks, k_blocks] + +# Majority voting: 跨 KV heads 和 q_blocks 投票 +vote_count = mask_per_kv_head[0].float().sum(dim=0).sum(dim=0) # [k_blocks] +total_votes = num_kv_heads * q_blocks +vote_ratio = vote_count / total_votes + +# 选择 >50% 投票的 blocks +vote_threshold = 0.5 +block_selected = vote_ratio > vote_threshold +selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel] + +# 安全措施: 始终包含第一个 (sink) 和最后一个 block +if available_blocks[0] not in selected_block_ids: + selected_block_ids.insert(0, available_blocks[0]) +if available_blocks[-1] not in selected_block_ids: + selected_block_ids.append(available_blocks[-1]) +``` + +**为什么使用 Majority Voting?** + +| 聚合方式 | 问题 | +|---------|------| +| `any()` 跨所有 heads | 密度接近 100%,失去稀疏性 | +| `all()` | 太激进,可能丢失重要 blocks | +| **Majority voting (>50%)** | 平衡稀疏性和准确性 | + +实验结果显示: +- 每 head 密度: 20-35% +- `any()` 聚合后: ~100% +- **Majority voting 后: ~45%** + +--- + +### 2. compute_chunked_prefill: 注意力计算 + +复用 `FullAttentionPolicy` 的 ring buffer pipeline 实现: + +```python +def compute_chunked_prefill(self, q, k, v, layer_id, softmax_scale, + offload_engine, kvcache_manager, + current_chunk_idx, seq, num_tokens, + selected_blocks) -> torch.Tensor: +``` + +#### 计算流程 + +1. **加载历史 blocks** (使用 selected_blocks): + ```python + for block_idx in range(num_blocks): + # Ring buffer pipeline: load -> wait -> compute -> next + offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id) + offload_engine.wait_slot_layer(slot) + + prev_k, prev_v = offload_engine.get_kv_for_slot(slot) + prev_o, prev_lse = flash_attn_with_lse(q, prev_k, prev_v, causal=False) + + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + ``` + +2. **计算当前 chunk** (causal mask): + ```python + k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens) + current_o, current_lse = flash_attn_with_lse(q, k_curr, v_curr, causal=True) + ``` + +3. **合并结果**: + ```python + final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) + ``` + +--- + +## 参数配置 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `threshold` | 0.95 | 累积注意力阈值 (tau),越高越保守 | +| `stride` | 8 | XAttention stride reshape 参数 | +| `chunk_size` | 16384 | 估计时的处理 chunk size | +| `block_size` | 128 | BSA block size (固定值) | + +### 使用方式 + +```python +# 在 config 中设置 +config.sparse_policy = SparsePolicyType.XATTN_BSA +config.sparse_threshold = 0.95 + +# 或通过命令行 +python tests/test_needle.py \ + --enable-offload \ + --enable-xattn-bsa \ + --sparse-threshold 9 # 会被除以 10 变为 0.9 +``` + +--- + +## 性能特性 + +| 特性 | 说明 | +|------|------| +| **Prefill 支持** | ✅ 完整支持 | +| **Decode 支持** | ❌ 不支持(使用 FullAttentionPolicy) | +| **稀疏度** | ~45-55%(threshold=0.95,majority voting) | +| **准确性** | RULER NIAH 100% 通过 | + +### 限制 + +1. **Decode 不支持**: XAttention 估计需要足够长的 Q 序列,单 token decode 不适用 +2. **估计开销**: `select_blocks` 需要加载所有 K blocks 进行估计 +3. **Triton 对齐**: Q/K 长度必须满足 `stride * BLOCK_M/N` 对齐要求 + +--- + +## 与其他 Policy 的对比 + +| Policy | select_blocks | 稀疏度 | Decode 支持 | +|--------|--------------|--------|-------------| +| FullAttentionPolicy | 返回所有 blocks | 0% | ✅ | +| QuestPolicy | 基于 min/max key | ~50% | ✅ | +| **XAttentionBSAPolicy** | XAttention + majority voting | ~45-55% | ❌ | + +--- + +## 测试验证 + +```bash +# Needle test (32K) +CUDA_VISIBLE_DEVICES=0 python tests/test_needle.py \ + --model ~/models/Llama-3.1-8B-Instruct \ + --enable-offload \ + --enable-xattn-bsa \ + --input-len 32768 + +# RULER benchmark +CUDA_VISIBLE_DEVICES=0 python tests/test_ruler.py \ + --model ~/models/Llama-3.1-8B-Instruct \ + --enable-offload \ + --sparse-policy XATTN_BSA \ + --sparse-threshold 0.95 \ + --data-dir tests/data/ruler_niah +``` + +--- + +## 相关文档 + +- [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md): XAttention 算法详解 +- [`docs/xattn_kernels_guide.md`](xattn_kernels_guide.md): Triton kernels 实现 +- [`docs/sparse_policy_architecture.md`](sparse_policy_architecture.md): SparsePolicy 架构 +- [`docs/sparse_policy_implementation_guide.md`](sparse_policy_implementation_guide.md): 实现指南