Files
nano-vllm/docs/xattn_bsa_policy_design.md
Zijie Tian da5e13e2bb 📝 docs: update XAttention BSA Policy with benchmarks and memory management
Add new sections to xattn_bsa_policy_design.md:
- Performance benchmarks: 128K context comparison (Full vs XAttn BSA)
- Density trend analysis across chunks
- Memory leak issue and fix (64GB -> 4GB reduction)
- Memory monitoring guide with gpu-monitor agent
- Density statistics API documentation
- Known issues and optimization directions

Update CLAUDE.md description to reflect new content.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:35:18 +08:00

15 KiB
Raw Permalink Blame History

XAttention BSA Policy 设计文档

本文档描述 XAttentionBSAPolicy 的设计和实现,这是一个基于 XAttention 算法的稀疏注意力策略,用于 CPU offload 模式下的 chunked prefill。

概述

XAttentionBSAPolicy 实现了基于 XAttention 的块级稀疏注意力选择。核心思想是:

  1. 估计阶段:使用 XAttention kernels 快速估计每个 KV block 的重要性
  2. 选择阶段:基于阈值和 majority voting 选择重要的 blocks
  3. 计算阶段:只加载选中的 blocks 进行 attention 计算
┌─────────────────────────────────────────────────────────────┐
│                    XAttention BSA Policy                     │
├─────────────────────────────────────────────────────────────┤
│  select_blocks()                                             │
│  ┌─────────────┐   ┌──────────────────┐   ┌──────────────┐  │
│  │ Load K      │──>│ flat_group_gemm  │──>│ softmax_fuse │  │
│  │ blocks      │   │ _fuse_reshape    │   │ _block_sum   │  │
│  └─────────────┘   └──────────────────┘   └──────────────┘  │
│         │                   │                    │           │
│         v                   v                    v           │
│  ┌─────────────┐   ┌──────────────────┐   ┌──────────────┐  │
│  │ K: [B,H,L,D]│   │ attn_scores:     │   │ block_sums:  │  │
│  │             │   │ [B,H,Q/s,K/s]    │   │ [B,H,Qb,Kb]  │  │
│  └─────────────┘   └──────────────────┘   └──────────────┘  │
│                                                  │           │
│                           ┌──────────────────────┘           │
│                           v                                  │
│                    ┌──────────────┐                          │
│                    │find_blocks   │                          │
│                    │_chunked      │                          │
│                    └──────────────┘                          │
│                           │                                  │
│                           v                                  │
│                    ┌──────────────┐                          │
│                    │ GQA-aware    │                          │
│                    │ aggregation  │                          │
│                    │ + majority   │                          │
│                    │ voting       │                          │
│                    └──────────────┘                          │
│                           │                                  │
│                           v                                  │
│                    selected_block_ids                        │
├─────────────────────────────────────────────────────────────┤
│  compute_chunked_prefill()                                   │
│  ┌─────────────┐   ┌──────────────────┐   ┌──────────────┐  │
│  │ Ring buffer │──>│ flash_attn_      │──>│ merge_       │  │
│  │ pipeline    │   │ with_lse         │   │ attention    │  │
│  └─────────────┘   └──────────────────┘   └──────────────┘  │
└─────────────────────────────────────────────────────────────┘

文件位置

主文件: nanovllm/kvcache/sparse/xattn_bsa.py

依赖的 XAttention kernels: nanovllm/ops/xattn.py

  • flat_group_gemm_fuse_reshape: 计算 stride reshape 后的 attention scores
  • softmax_fuse_block_sum: 对 attention scores 做 softmax 后按 block 求和
  • find_blocks_chunked: 基于阈值选择 blocks

核心算法

1. select_blocks: 块选择算法

def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]:

Step 1: 加载 K blocks 并计算 attention scores

对每个 CPU block加载 K 到 GPU 并使用 flat_group_gemm_fuse_reshape 计算:

for cpu_block_id in available_blocks:
    # 加载 K block: [1, block_size, num_kv_heads, head_dim]
    offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
    k_block, _ = offload_engine.get_kv_for_slot(slot)

    # 转换为 [batch, heads, k_len, head_dim]
    K_chunk = k_block.transpose(1, 2)

    # GQA: 扩展 K heads 匹配 Q heads
    if num_heads != num_kv_heads:
        K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)

    # 计算 attention scores
    attn_chunk = flat_group_gemm_fuse_reshape(Q, K_chunk, stride, ...)
    attn_scores_list.append(attn_chunk)

# 拼接所有 K chunks: [1, heads, q_reshaped_len, total_k_reshaped_len]
attn_scores = torch.cat(attn_scores_list, dim=-1)

Step 2: 聚合到 block 级别

使用 softmax_fuse_block_sum 将 attention scores 聚合到 block 级别:

# reshaped_block_size = block_size / stride = 1024 / 8 = 128
block_sums = softmax_fuse_block_sum(
    attn_scores,
    reshaped_block_size,  # 1:1 对应 CPU blocks
    segment_size,
    chunk_start=0,
    chunk_end=q_reshaped_len,
    real_q_len=q_reshaped_len,
    scale=scale,
    is_causal=False,
)
# block_sums: [batch, heads, q_blocks, k_blocks]

关键点: reshaped_block_size 必须与 CPU block 对齐,确保输出的 k_blocks 维度 1:1 对应 available_blocks

Step 3: 阈值选择

使用 find_blocks_chunked 基于累积注意力阈值选择 blocks

mask = find_blocks_chunked(
    block_sums,
    current_index=0,
    threshold=self.threshold,  # e.g., 0.95
    num_to_choose=None,
    decoding=False,
    mode="prefill",
    causal=False,
)
# mask: [batch, num_heads, q_blocks, k_blocks] - boolean

Step 4: GQA-aware 聚合 + Majority Voting

# GQA: 在同一个 KV head group 内,任一 Q head 选择即选择
if num_groups > 1:
    mask_gqa = mask.view(batch_size, num_kv_heads, num_groups, q_blocks, k_blocks)
    mask_per_kv_head = mask_gqa.any(dim=2)  # [batch, num_kv_heads, q_blocks, k_blocks]

# Majority voting: 跨 KV heads 和 q_blocks 投票
vote_count = mask_per_kv_head[0].float().sum(dim=0).sum(dim=0)  # [k_blocks]
total_votes = num_kv_heads * q_blocks
vote_ratio = vote_count / total_votes

# 选择 >50% 投票的 blocks
vote_threshold = 0.5
block_selected = vote_ratio > vote_threshold
selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel]

# 安全措施: 始终包含第一个 (sink) 和最后一个 block
if available_blocks[0] not in selected_block_ids:
    selected_block_ids.insert(0, available_blocks[0])
if available_blocks[-1] not in selected_block_ids:
    selected_block_ids.append(available_blocks[-1])

为什么使用 Majority Voting?

聚合方式 问题
any() 跨所有 heads 密度接近 100%,失去稀疏性
all() 太激进,可能丢失重要 blocks
Majority voting (>50%) 平衡稀疏性和准确性

实验结果显示:

  • 每 head 密度: 20-35%
  • any() 聚合后: ~100%
  • Majority voting 后: ~45%

2. compute_chunked_prefill: 注意力计算

复用 FullAttentionPolicy 的 ring buffer pipeline 实现:

def compute_chunked_prefill(self, q, k, v, layer_id, softmax_scale,
                            offload_engine, kvcache_manager,
                            current_chunk_idx, seq, num_tokens,
                            selected_blocks) -> torch.Tensor:

计算流程

  1. 加载历史 blocks (使用 selected_blocks):

    for block_idx in range(num_blocks):
        # Ring buffer pipeline: load -> wait -> compute -> next
        offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
        offload_engine.wait_slot_layer(slot)
    
        prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
        prev_o, prev_lse = flash_attn_with_lse(q, prev_k, prev_v, causal=False)
    
        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
    
  2. 计算当前 chunk (causal mask):

    k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
    current_o, current_lse = flash_attn_with_lse(q, k_curr, v_curr, causal=True)
    
  3. 合并结果:

    final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
    

参数配置

参数 默认值 说明
threshold 0.95 累积注意力阈值 (tau),越高越保守
stride 8 XAttention stride reshape 参数
chunk_size 16384 估计时的处理 chunk size
block_size 128 BSA block size (固定值)

使用方式

# 在 config 中设置
config.sparse_policy = SparsePolicyType.XATTN_BSA
config.sparse_threshold = 0.95

# 或通过命令行
python tests/test_needle.py \
    --enable-offload \
    --enable-xattn-bsa \
    --sparse-threshold 9  # 会被除以 10 变为 0.9

性能特性

特性 说明
Prefill 支持 完整支持
Decode 支持 不支持(使用 FullAttentionPolicy
稀疏度 ~45-55%threshold=0.95majority voting
准确性 RULER NIAH 100% 通过

限制

  1. Decode 不支持: XAttention 估计需要足够长的 Q 序列,单 token decode 不适用
  2. 估计开销: select_blocks 需要加载所有 K blocks 进行估计
  3. Triton 对齐: Q/K 长度必须满足 stride * BLOCK_M/N 对齐要求

与其他 Policy 的对比

Policy select_blocks 稀疏度 Decode 支持
FullAttentionPolicy 返回所有 blocks 0%
QuestPolicy 基于 min/max key ~50%
XAttentionBSAPolicy XAttention + majority voting ~45-55%

测试验证

# Needle test (32K)
CUDA_VISIBLE_DEVICES=0 python tests/test_needle.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --enable-offload \
    --enable-xattn-bsa \
    --input-len 32768

# RULER benchmark
CUDA_VISIBLE_DEVICES=0 python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.95 \
    --data-dir tests/data/ruler_niah

性能基准测试

128K 上下文对比 (Llama-3.1-8B, A100 80GB)

Policy Density 时间 内存峰值 准确率
Full 100% 120.9s 16.4GB (稳定) 100%
XAttn BSA ~52% 152.3s 19.8GB 100%

Density 变化趋势

Chunk Full XAttn BSA
10 100% 90%
30 100% 73%
60 100% 50%
100 100% 50%
126 100% 52%

观察XAttn BSA 的 density 随 chunks 增加而下降,最终稳定在 ~50%。

性能分析

当前问题XAttn BSA 虽然 density 只有 ~52%,但时间反而比 Full 更长152s vs 121s

原因select_blocks 需要加载所有 K blocks 来估计 attention scores导致每个 block 被加载两次:

  1. 估计阶段:加载 K 计算 attention scores
  2. 计算阶段:加载选中的 K/V 进行实际计算

优化方向

  1. 跨层共享估计结果layer 0 估计,其他层复用)
  2. 采样估计(只用部分 K blocks 估计)
  3. 缓存估计结果避免重复计算

内存管理

内存泄漏问题 (已修复)

问题128K prefill 时 GPU 内存从 16GB 增长到 80GB。

根因

# 问题代码:累积存储但从未使用
self.sparse_metadata[layer_id] = attn_scores

每个 chunk 的每个 layer 都存储 attn_scores,导致内存持续增长。

修复方法

# 1. 删除无用的 sparse_metadata 存储

# 2. 立即释放中间变量
del attn_scores_list
del attn_scores, block_sums, mask, mask_per_kv_head, vote_count, vote_ratio, block_selected

修复效果

版本 内存增长 峰值
修复前 +64GB 80GB
修复后 +4GB 19.8GB

内存监控

使用 gpu-monitor agent 监控内存:

# 启动监控
# 在 Claude Code 中使用 Task tool 启动 gpu-monitor agent

# 或手动监控
watch -n 1 'nvidia-smi --query-gpu=memory.used --format=csv,noheader -i 0'

Density 统计 API

启用统计

# 统计自动在 select_blocks 中更新(仅 layer 0
# 使用 logger.debug 输出每 chunk 的 density

获取统计

policy = XAttentionBSAPolicy(threshold=0.95)

# 运行 prefill 后...

# 获取统计
stats = policy.get_density_stats()
# {
#     "total_available_blocks": 8001,
#     "total_selected_blocks": 4160,
#     "num_chunks": 126,
#     "overall_density": 0.52
# }

# 打印统计
policy.print_density_stats()

# 重置统计
policy.reset_stats()

启用 DEBUG 日志

# 在 test_ruler.py 中
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"

# 输出示例:
# [XAttn] chunk=30, available=30, selected=22, chunk_density=73.3%

已知问题

问题 状态 说明
估计开销过大 🟡 待优化 select_blocks 需要加载所有 K blocks
时间比 Full 更长 🟡 待优化 128K 场景 152s vs 121s
小幅内存增长 🟢 可接受 ~4GB可能来自 Triton 缓存
Decode 不支持 设计如此 使用 FullAttentionPolicy

相关文档