Files
nano-vllm/docs/sparse_prefill_integration_plan.md
Zijie Tian 2771312565 [docs] Add sparse prefill integration plan from int-minference analysis
Consolidated analysis from int-minference-1/2/3 branches into a unified
integration plan for MInference, XAttention, and FlexPrefill strategies.

Key design decisions:
- Backward compatible: Keep existing SparsePolicy interface
- Unified BlockMask intermediate representation for new strategies
- XAttention/FlexPrefill use block_sparse_attn_func kernel
- MInference can optionally use block_sparse_attn (Phase 4)

Five-phase implementation plan:
1. BlockMask + block_sparse_attn wrapper
2. XAttention implementation
3. FlexPrefill implementation
4. Optional MInference refactoring
5. Integration and testing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-10 23:33:09 +08:00

15 KiB
Raw Blame History

Sparse Prefill Attention Integration Plan

Executive Summary

本文档整合了 int-minference-1/2/3 三个分支的分析提出统一的三种稀疏注意力策略MInference、XAttention、FlexPrefill集成方案。


Part 1: 现状分析

1.1 x-attention 仓库策略对比

策略 Pattern 类型 估计方法 Kernel Backend
MInference Vertical + Slash Last-64-Q attention → 列/对角线求和 vertical_slash_sparse_attention (minference lib)
XAttention Block Mask Stride-based Q/K 下采样 → block 分数 block_sparse_attn_func (MIT-HAN-LAB)
FlexPrefill Adaptive V+S Last-block attention + JS 散度自适应 triton_block_wise_attention (custom triton)

1.2 关键发现:两种 Kernel 接口

接口 A: Index-Based (minference)

# MInference 使用 vertical+slash indices
vertical_indices = [heads, vertical_size]  # 重要 K 列位置
slash_indices = [heads, slash_size]        # 对角线偏移
output = vertical_slash_sparse_attention(q, k, v, vertical_indices, slash_indices)

接口 B: Block Mask-Based (block_sparse_attn)

# XAttention/FlexPrefill 使用 boolean block mask
block_mask = torch.bool[batch, heads, q_blocks, k_blocks]  # True = 计算
output = block_sparse_attn_func(q, k, v, block_mask, ...)

1.3 当前 nanovllm MInference 实现

文件: nanovllm/kvcache/sparse/minference.py

已实现功能:

  • estimate_pattern(): 使用 last-64-Q 估计 vertical+slash pattern
  • sparse_prefill_attention(): 调用 minference kernel 执行稀疏注意力
  • 支持 GQA通过 K/V repeat_interleave
  • 支持 adaptive_budget 自适应预算

问题:

  1. 与 XAttention/FlexPrefill 使用不同 kernel无法统一接口
  2. sparse_prefill_attention() 将估计和执行耦合在一起
  3. 没有 BlockMask 中间表示,难以复用

Part 2: 架构设计

2.1 设计原则

  1. 向后兼容: 保持现有 SparsePolicy 接口不变
  2. 渐进式重构: 添加新功能而非替换
  3. 统一中间表示: 新策略使用 BlockMask 作为可选中间表示
  4. 可插拔 Kernel: 支持多种 attention kernel backend

2.2 架构图

┌──────────────────────────────────────────────────────────────────────────────┐
│                       Unified Sparse Prefill Framework                        │
├──────────────────────────────────────────────────────────────────────────────┤
│                                                                               │
│  ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐               │
│  │   MInference    │  │   XAttention    │  │   FlexPrefill   │  Strategies   │
│  │   Policy        │  │   Policy        │  │   Policy        │               │
│  └────────┬────────┘  └────────┬────────┘  └────────┬────────┘               │
│           │                    │                    │                         │
│           │ (indices)          │ (BlockMask)        │ (BlockMask)             │
│           │                    │                    │                         │
│           ▼                    └────────┬───────────┘                         │
│  ┌─────────────────┐                    ▼                                     │
│  │   minference    │  ┌─────────────────────────────────────────────────────┐│
│  │   kernel        │  │              BlockMask Container                    ││
│  └────────┬────────┘  │  [batch, num_heads, q_blocks, k_blocks] - boolean   ││
│           │           └─────────────────────────────────────────────────────┘│
│           │                              │                                    │
│           │                              ▼                                    │
│           │           ┌─────────────────────────────────────────────────────┐│
│           │           │         block_sparse_attn_func                      ││
│           │           │         (MIT-HAN-LAB kernel)                        ││
│           │           └─────────────────────────────────────────────────────┘│
│           │                              │                                    │
│           └──────────────────────────────┼────────────────────────────────── │
│                                          ▼                                    │
│  ┌─────────────────────────────────────────────────────────────────────────┐ │
│  │                        Attention Output                                  │ │
│  │                   [seq_len, num_heads, head_dim]                         │ │
│  └─────────────────────────────────────────────────────────────────────────┘ │
│                                                                               │
└──────────────────────────────────────────────────────────────────────────────┘

2.3 新增类设计

# nanovllm/kvcache/sparse/block_mask.py

@dataclass
class BlockMask:
    """Block-level attention mask container."""
    mask: torch.Tensor      # [batch, heads, q_blocks, k_blocks]
    block_size: int
    seq_len: int
    num_q_blocks: int
    num_k_blocks: int

    def sparsity_ratio(self) -> float:
        """Fraction of blocks masked out."""
        return 1.0 - self.mask.float().mean().item()

    def to_flat_indices(self, head_idx: int) -> torch.Tensor:
        """Convert to flattened block indices for a given head."""
        pass

    @classmethod
    def from_vertical_slash(
        cls,
        vertical_idx: torch.Tensor,
        slash_idx: torch.Tensor,
        seq_len: int,
        block_size: int,
    ) -> "BlockMask":
        """Convert MInference-style indices to block mask."""
        pass

    def apply_causal(self) -> "BlockMask":
        """Apply causal constraint (lower triangular)."""
        pass
# nanovllm/kvcache/sparse/kernels/block_sparse.py

def block_sparse_attention(
    q: torch.Tensor,           # [seq_len, num_heads, head_dim]
    k: torch.Tensor,           # [seq_len, num_kv_heads, head_dim]
    v: torch.Tensor,           # [seq_len, num_kv_heads, head_dim]
    block_mask: BlockMask,
) -> torch.Tensor:
    """
    Execute block sparse attention using MIT-HAN-LAB kernel.

    Handles:
    - GQA expansion (K/V heads < Q heads)
    - Tensor format conversion
    - Causal masking
    """
    from block_sparse_attn import block_sparse_attn_func
    # ... implementation

Part 3: 实现计划

Phase 1: 基础设施 (新增文件)

目标: 添加 BlockMask 和 block_sparse_attn 封装

文件:

  • nanovllm/kvcache/sparse/block_mask.py (NEW)
  • nanovllm/kvcache/sparse/kernels/__init__.py (NEW)
  • nanovllm/kvcache/sparse/kernels/block_sparse.py (NEW)

任务:

  1. 实现 BlockMask 数据类
  2. 实现 block_sparse_attention() 封装函数
  3. 处理 GQA 和 tensor 格式转换
  4. 测试:使用全 True 的 block mask 验证输出正确

Phase 2: XAttention 实现

目标: 移植 x-attention 的 XAttention 策略

文件:

  • nanovllm/kvcache/sparse/xattention.py (NEW)
  • nanovllm/config.py (添加 XATTENTION 枚举)
  • nanovllm/kvcache/sparse/__init__.py (更新工厂函数)

关键函数移植:

# From x-attention/xattn/src/Xattention.py
def xattn_estimate(q, k, block_size, stride, threshold, ...):
    # 1. Stride-based Q/K downsampling
    reshaped_k = cat([k[:, :, i::stride, :] for i in range(stride)], dim=-1)
    reshaped_q = cat([q[:, :, stride-1-i::stride, :] for i in range(stride)], dim=-1)

    # 2. Block-level attention scores
    attn_weights = matmul(reshaped_q, reshaped_k.T) / sqrt(d) / stride

    # 3. Threshold selection
    block_mask = find_blocks_chunked(attn_sum, threshold)
    return block_mask

配置参数:

xattention_stride: int = 16           # Q/K 下采样步长
xattention_threshold: float = 0.9     # 累积分数阈值
xattention_block_size: int = 128      # Block 大小

测试: python tests/test_needle.py --input-len 32768 --enable-xattention

Phase 3: FlexPrefill 实现

目标: 移植 x-attention 的 FlexPrefill 策略

文件:

  • nanovllm/kvcache/sparse/flexprefill.py (NEW)
  • nanovllm/config.py (添加 FLEXPREFILL 枚举)

关键函数移植:

# From x-attention/xattn/src/Flexprefill.py
def get_active_blocks(q, k, gamma, tau, block_size, ...):
    # 1. Last-block attention analysis
    last_q = q[:, -block_size:, :, :]
    qk = einsum('bihd,bjhd->bhij', last_q, k)

    # 2. Vertical + slash pattern detection
    vertical = qk.mean(-2)  # Column importance
    slash = sum_all_diagonal_matrix(qk)  # Diagonal importance

    # 3. JS divergence for adaptive budget
    kl_div = js_divergence(avg_qk, vertical_pooled)
    is_sparse_head = kl_div > tau
    budget = gamma if is_sparse_head else 1.0

    # 4. Select blocks
    block_idx = transform_vertical_slash_idx(...)
    return block_mask

配置参数:

flexprefill_gamma: float = 0.9        # 基础覆盖率
flexprefill_tau: float = 0.1          # JS 散度阈值
flexprefill_min_budget: int = 128     # 最小 token 预算
flexprefill_block_size: int = 128     # Block 大小

测试: python tests/test_needle.py --input-len 32768 --enable-flexprefill

Phase 4: MInference 可选重构

目标: (可选) 让 MInference 也可以使用 block_sparse_attn

修改文件:

  • nanovllm/kvcache/sparse/minference.py

新增方法:

class MInferencePolicy(SparsePolicy):
    def __init__(self, ..., use_block_sparse: bool = False):
        self.use_block_sparse = use_block_sparse

    def estimate_block_mask(self, q, k, layer_id) -> BlockMask:
        """Convert vertical+slash indices to BlockMask."""
        vertical_idx, slash_idx = self.estimate_pattern(q, k, layer_id)
        return BlockMask.from_vertical_slash(vertical_idx, slash_idx, ...)

    def sparse_prefill_attention(self, q, k, v, layer_id):
        if self.use_block_sparse:
            block_mask = self.estimate_block_mask(q, k, layer_id)
            return block_sparse_attention(q, k, v, block_mask)
        else:
            # 使用原有 minference kernel
            return self._minference_kernel_attention(q, k, v, layer_id)

Phase 5: 集成和测试

任务:

  1. 更新 __init__.py 工厂函数支持所有策略
  2. 更新 Config 添加所有配置参数
  3. 添加性能基准测试脚本
  4. 更新文档

Part 4: 依赖管理

必需依赖

# requirements.txt 新增
block-sparse-attn    # MIT-HAN-LAB block sparse kernel
triton>=2.0          # FlexPrefill Triton kernels

安装说明

# block_sparse_attn from MIT-HAN-LAB
pip install git+https://github.com/mit-han-lab/Block-Sparse-Attention.git

# 或从本地安装(如果有)
cd /home/zijie/Code/x-attention/Block-Sparse-Attention
pip install -e .

Part 5: 配置参数汇总

SparsePolicyType 枚举

class SparsePolicyType(str, Enum):
    FULL = "full"              # 全注意力(无稀疏)
    QUEST = "quest"            # Decode-only Top-K
    MINFERENCE = "minference"  # Prefill vertical+slash
    XATTENTION = "xattention"  # Prefill stride-based block
    FLEXPREFILL = "flexprefill"  # Prefill adaptive JS-divergence

策略参数对照表

策略 参数 默认值 说明
MInference adaptive_budget 0.3 预算占 seq_len 比例
MInference vertical_size 1000 固定 vertical 大小
MInference slash_size 6096 固定 slash 大小
XAttention stride 16 Q/K 下采样步长
XAttention threshold 0.9 累积分数阈值
XAttention block_size 128 Block 大小
FlexPrefill gamma 0.9 基础覆盖率
FlexPrefill tau 0.1 JS 散度阈值
FlexPrefill min_budget 128 最小 token 预算
FlexPrefill block_size 128 Block 大小

Part 6: 成功标准

  1. 正确性: 所有三种策略通过 32K+ needle-in-haystack 测试
  2. 性能: 稀疏 prefill 比全注意力快 (>1.5x speedup at 64K)
  3. 统一接口: XAttention/FlexPrefill 使用 BlockMask + block_sparse_attn
  4. 向后兼容: 现有 MInference 配置继续工作
  5. 可配置: 所有策略参数可通过 LLM 配置设置

Part 7: 风险评估

风险 影响 可能性 缓解措施
block_sparse_attn 硬件兼容性 测试目标硬件fallback 到 flash_attn
MInference → block mask 精度损失 对比测试输出差异
Triton kernel 移植问题 使用非 Triton fallback
内存开销增加 block_size=128 → 1KB/head for 128K

References

  • x-attention repo: /home/zijie/Code/x-attention
  • MIT-HAN-LAB Block-Sparse-Attention: https://github.com/mit-han-lab/Block-Sparse-Attention
  • MInference paper: https://arxiv.org/abs/2407.02490
  • Current nanovllm sparse implementation: nanovllm/kvcache/sparse/