Add comprehensive documentation for the MIT-Han-Lab Block-Sparse-Attention library (3rdparty submodule, branch: tzj/minference). The new document covers: - Four sparse attention modes (dense, token/block streaming, block sparse) - Hybrid mask support (different patterns per head) - Complete API reference for all three functions - Performance benchmarks (up to 3-4x speedup on A100) - Integration considerations for nano-vllm Co-Authored-By: Claude <noreply@anthropic.com>
5.7 KiB
5.7 KiB
Block-Sparse-Attention Library Reference
MIT Han Lab 的块稀疏注意力内核库,基于 FlashAttention 2.4.2 修改,支持多种稀疏注意力模式。
库信息
- 来源: MIT-Han-Lab/Block-Sparse-Attention
- 本地路径:
3rdparty/Block-Sparse-Attention(submodule, branch:tzj/minference) - 基于: FlashAttention 2.4.2
- 安装位置:
site-packages/block_sparse_attn
支持的稀疏模式
1. Dense Attention
计算完整注意力矩阵,无稀疏化。
2. Token Streaming (token granularity)
固定数量的 sink tokens + local tokens,参考 StreamingLLM。
适用场景: 需要精确保留部分关键 token 的长上下文推理
3. Block Streaming (block granularity)
Block 粒度的 streaming attention,block_size = 128。
适用场景: 长序列推理,牺牲少量精度换取更大加速
4. Block Sparse
基于自定义 block mask 的稀疏注意力。
适用场景: 已知特定 attention 模式的工作负载
混合模式
关键特性: 支持不同 head 使用不同稀疏模式
# 8 个 heads 的混合配置示例
head_mask_type = [1, 1, 0, 0, 0, -1, 0, -1]
# 含义:
# - head 0,1: blocksparse (使用 basemask[0])
# - head 2-4,6: dense
# - head 5,7: streaming
Mask 类型编码:
0= Dense attention-1= Streaming attention1, 2, ...= Block sparse (使用 basemask[mask_type - 1])
API 参考
block_sparse_attn_func
通用块稀疏注意力函数,支持所有模式。
from block_sparse_attn import block_sparse_attn_func
output = block_sparse_attn_func(
q, k, v, # [total_tokens, heads, head_dim] unpadded
cu_seqlens_q, cu_seqlens_k, # cumulative sequence lengths
head_mask_type, # [heads] tensor, 每个头的模式
streaming_info, # streaming 配置 (sink/local 数量)
base_blockmask, # [q_blocks, k_blocks, n_masks] bool tensor
max_seqlen_q, max_seqlen_k, # 最大序列长度
p_dropout, # dropout 概率 (推理时设为 0.0)
deterministic=False,
softmax_scale=None,
is_causal=False,
exact_streaming=False, # True=token streaming, False=block streaming
return_attn_probs=False,
)
关键参数:
| 参数 | 类型 | 说明 |
|---|---|---|
head_mask_type |
Tensor[heads] | 每个头的稀疏模式,0=dense, -1=streaming, 1+=blocksparse |
streaming_info |
Tensor | [sink_blocks, local_blocks] 或 [sink_tokens, local_tokens] |
base_blockmask |
Tensor | Block mask,形状 [q_blocks, k_blocks, n_masks] |
exact_streaming |
bool | True=token 粒度,False=block 粒度 streaming |
block_streaming_attn_func
Block 粒度 streaming attention(block_size=128)。
from block_sparse_attn import block_streaming_attn_func
output = block_streaming_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info, # [sink_blocks, local_blocks]
max_seqlen_q, max_seqlen_k,
p_dropout,
deterministic=False,
softmax_scale=None,
is_causal=True,
return_attn_probs=False,
)
token_streaming_attn_func
Token 粒度 streaming attention。
注意: 不支持反向传播(仅推理)。
from block_sparse_attn import token_streaming_attn_func
output = token_streaming_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info, # [sink_tokens, local_tokens]
max_seqlen_q, max_seqlen_k,
deterministic=False,
softmax_scale=None,
return_attn_probs=False,
)
技术规格
| 特性 | 支持情况 |
|---|---|
| 数据类型 | fp16, bf16 (bf16 需要 Ampere/Ada/Hopper GPU) |
| Head 维度 | 32, 64, 128 |
| Block Size | 128 (固定) |
| CUDA 要求 | 11.6+ |
| PyTorch 要求 | 1.12+ |
性能参考
测试环境: A100 GPU, head_dim=128, 32 heads, batch_size=1
Block Sparse 加速比
- 相比 FlashAttention2: 最高 3-4x 加速
- 加速随序列长度增加而提升
Streaming 混合模式加速比
- Token streaming: 64 sink + 256 local tokens
- Block streaming: 1 sink block + 3 local blocks
- 50% Dense + 50% Streaming: 最高 2x 加速
与 nano-vllm 的集成考虑
潜在集成点
-
长上下文推理优化
- 使用 block streaming 减少计算量
- 在 CPU offload 模式下减少 GPU-CPU 传输
-
混合注意力策略
- 部分 head 使用 streaming(减少计算)
- 部分 head 使用 dense(保持精度)
- 参考 Duo Attention 论文的混合模式
-
稀疏 offload
- 只 offload 重要 blocks 的 KV cache
- 结合
requires_block_selection接口
实现注意事项
- 输入格式: 库使用 unpadded 格式(
cu_seqlens),需要与 nano-vllm 的 padded 格式转换 - Block size 固定: 库固定 block_size=128,需要适配
- Streaming info 配置: 需要根据模型特性调整 sink/local 数量
相关工作
- FlashAttention - 基础实现
- StreamingLLM - Streaming attention 理论基础
- Duo Attention - 混合 dense/streaming 模式
- MInference - 混合 mask 方法
测试
库自带测试位于 3rdparty/Block-Sparse-Attention/block_sparse_tests/:
# 正确性测试
cd 3rdparty/Block-Sparse-Attention/block_sparse_tests/fwd/test_correctness
pytest full_test.py
# 性能测试
cd 3rdparty/Block-Sparse-Attention/block_sparse_tests/fwd/test_performance
python token_streaming.py
python blocksparse.py