Compare commits
2 Commits
07f5220f40
...
690456dbf9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
690456dbf9 | ||
|
|
e440c45e73 |
@@ -14,6 +14,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
|
|||||||
| [`docs/sparse_policy_architecture.md`](docs/sparse_policy_architecture.md) | SparsePolicy abstraction: prefill/decode delegation, pipeline modes, policy implementations |
|
| [`docs/sparse_policy_architecture.md`](docs/sparse_policy_architecture.md) | SparsePolicy abstraction: prefill/decode delegation, pipeline modes, policy implementations |
|
||||||
| [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
|
| [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
|
||||||
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
|
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
|
||||||
|
| [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention 算法详解: stride reshape、Triton kernels、BSA 依赖、块选择算法 |
|
||||||
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
|
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
|
||||||
| [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
|
| [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
|
||||||
| [`docs/known_issues.md`](docs/known_issues.md) | Documented bugs and fixes: partial last block bug, block size 4096 race condition |
|
| [`docs/known_issues.md`](docs/known_issues.md) | Documented bugs and fixes: partial last block bug, block size 4096 race condition |
|
||||||
|
|||||||
@@ -50,30 +50,35 @@ output = block_sparse_attn_func(
|
|||||||
|
|
||||||
## Method 1: XAttention (xattn_estimate)
|
## Method 1: XAttention (xattn_estimate)
|
||||||
|
|
||||||
**Source**: `xattn/src/Xattention.py`
|
**Source**: `compass/src/Xattention.py`
|
||||||
|
|
||||||
|
**详细文档**: [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md)
|
||||||
|
|
||||||
### Core Idea
|
### Core Idea
|
||||||
|
|
||||||
Use **strided Q/K reshaping** to create coarse-grained representations, compute block-level attention scores, and select blocks above a threshold.
|
Use **stride interleaved reshape (inverse mode)** to efficiently estimate block-level attention importance, then use **BSA (Block Sparse Attention)** library for sparse computation.
|
||||||
|
|
||||||
### Algorithm
|
### Algorithm
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def xattn_estimate(query, key, block_size=64, stride=16):
|
def xattn_estimate(query, key, block_size=128, stride=8):
|
||||||
"""
|
"""
|
||||||
Estimate block importance using strided attention.
|
Estimate block importance using stride-interleaved attention.
|
||||||
|
|
||||||
1. Reshape Q: [batch, seq, heads, dim] -> [batch, num_blocks, stride, heads, dim]
|
1. K reshape (正向交错): concat([K[:,:,k::stride,:] for k in range(stride)])
|
||||||
Then take mean over stride dimension to get block-level Q
|
Q reshape (反向交错): concat([Q[:,:,(stride-1-q)::stride,:] for q])
|
||||||
|
结果: 序列长度 seq_len -> seq_len/stride, head_dim -> head_dim*stride
|
||||||
|
|
||||||
2. Reshape K: Same process to get block-level K
|
2. Triton kernel (flat_group_gemm_fuse_reshape):
|
||||||
|
融合 reshape + GEMM,计算 Q_reshaped @ K_reshaped^T
|
||||||
|
|
||||||
3. Compute block attention: softmax(block_Q @ block_K.T / sqrt(d))
|
3. Triton kernel (softmax_fuse_block_sum):
|
||||||
Result shape: [batch, heads, q_blocks, k_blocks]
|
在线 softmax + 按 block_size/stride 分组求和
|
||||||
|
输出: attn_sum [batch, heads, q_blocks, k_blocks]
|
||||||
|
|
||||||
4. Apply causal mask (upper triangle = 0)
|
4. find_blocks_chunked:
|
||||||
|
按 attn_sum 降序排序,累积到 threshold 的块标记为 True
|
||||||
5. Threshold: blocks with score > threshold are selected
|
对角块和 sink 块始终保留
|
||||||
"""
|
"""
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -81,45 +86,60 @@ def xattn_estimate(query, key, block_size=64, stride=16):
|
|||||||
|
|
||||||
| Parameter | Default | Description |
|
| Parameter | Default | Description |
|
||||||
|-----------|---------|-------------|
|
|-----------|---------|-------------|
|
||||||
| `block_size` | 64 | Tokens per block |
|
| `block_size` | 128 | Tokens per block (BSA 要求固定 128) |
|
||||||
| `stride` | 16 | Stride for coarse Q/K computation |
|
| `stride` | 8 | Q/K 交错采样步长,越大估计越快但越粗糙 |
|
||||||
| `threshold` | 0.9 | Selection threshold (cumulative or direct) |
|
| `threshold` | 0.9 | 累积注意力阈值,选择累积权重达到此比例的块 |
|
||||||
|
| `chunk_size` | 16384 | 估计时的分块大小 |
|
||||||
|
|
||||||
### Computation Flow
|
### Computation Flow
|
||||||
|
|
||||||
```
|
```
|
||||||
query [B, S, H, D]
|
query [B, H, S, D]
|
||||||
|
|
|
|
||||||
v
|
v
|
||||||
Reshape to [B, num_blocks, stride, H, D]
|
Stride interleaved reshape (Triton fused)
|
||||||
|
|
|
|
||||||
v
|
v
|
||||||
Mean over stride -> block_q [B, num_blocks, H, D]
|
flat_group_gemm_fuse_reshape: Q_r @ K_r^T
|
||||||
|
|
|
|
||||||
v
|
v
|
||||||
Compute block attention scores [B, H, q_blocks, k_blocks]
|
softmax_fuse_block_sum: 在线 softmax + 块求和
|
||||||
|
|
|
|
||||||
v
|
v
|
||||||
Apply threshold -> block_mask [B, H, q_blocks, k_blocks]
|
attn_sum [B, H, q_blocks, k_blocks]
|
||||||
|
|
|
|
||||||
v
|
v
|
||||||
block_sparse_attn_func(q, k, v, block_mask)
|
find_blocks_chunked: 累积阈值选择
|
||||||
|
|
|
|
||||||
v
|
v
|
||||||
output [B, S, H, D]
|
simple_mask [B, H, q_blocks, k_blocks] (bool)
|
||||||
|
|
|
||||||
|
v
|
||||||
|
block_sparse_attn_func(q, k, v, simple_mask) ← BSA 库
|
||||||
|
|
|
||||||
|
v
|
||||||
|
output [B, H, S, D]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Dependencies
|
||||||
|
|
||||||
|
```python
|
||||||
|
from block_sparse_attn import block_sparse_attn_func # MIT-HAN-LAB BSA 库
|
||||||
|
import triton # Triton kernels for estimation
|
||||||
```
|
```
|
||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from xattn.src.Xattention import Xattention_prefill
|
from compass.src.Xattention import Xattention_prefill
|
||||||
|
|
||||||
output = Xattention_prefill(
|
output = Xattention_prefill(
|
||||||
query_states, key_states, value_states,
|
query_states, key_states, value_states,
|
||||||
threshold=0.9,
|
threshold=0.9,
|
||||||
stride=16,
|
stride=8,
|
||||||
|
block_size=128,
|
||||||
|
use_triton=True,
|
||||||
)
|
)
|
||||||
```
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
349
docs/xattention_algorithm_guide.md
Normal file
349
docs/xattention_algorithm_guide.md
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
# XAttention 算法实现指南
|
||||||
|
|
||||||
|
本文档详细描述 COMPASS 项目中 XAttention 的算法原理和实现细节。
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
XAttention 是一种基于 **stride reshape** 的块稀疏注意力方法,通过低成本估计识别重要块,然后使用 **BSA (Block Sparse Attention)** 库执行稀疏计算。
|
||||||
|
|
||||||
|
### 核心依赖
|
||||||
|
|
||||||
|
| 组件 | 来源 | 作用 |
|
||||||
|
|------|------|------|
|
||||||
|
| Triton Kernels | COMPASS 自研 | Q/K reshape + 块级估计 |
|
||||||
|
| BSA | MIT-HAN-LAB `block_sparse_attn` | 稀疏注意力计算 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 算法流程
|
||||||
|
|
||||||
|
```
|
||||||
|
输入: Q [batch, heads, q_len, head_dim]
|
||||||
|
K [batch, heads, k_len, head_dim]
|
||||||
|
V [batch, heads, k_len, head_dim]
|
||||||
|
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ Phase 1: Stride Reshape (inverse 模式) │
|
||||||
|
│ │
|
||||||
|
│ K_reshaped = concat([K[:,:,k::stride,:] for k in stride]) │
|
||||||
|
│ Q_reshaped = concat([Q[:,:,(stride-1-q)::stride,:] for q]) │
|
||||||
|
│ │
|
||||||
|
│ 效果: 序列长度从 seq_len 缩短到 seq_len/stride │
|
||||||
|
│ head_dim 扩展到 head_dim * stride │
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
↓
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ Phase 2: 块级注意力估计 (Triton 加速) │
|
||||||
|
│ │
|
||||||
|
│ 2a. flat_group_gemm_fuse_reshape: │
|
||||||
|
│ 计算 Q_reshaped @ K_reshaped^T │
|
||||||
|
│ 输出: attn_weights [batch, heads, q_len/stride, k_len/stride] │
|
||||||
|
│ │
|
||||||
|
│ 2b. softmax_fuse_block_sum: │
|
||||||
|
│ - 在线 softmax (数值稳定) │
|
||||||
|
│ - 按 block_size/stride 分组求和 │
|
||||||
|
│ 输出: attn_sum [batch, heads, q_blocks, k_blocks] │
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
↓
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ Phase 3: 块选择 (find_blocks_chunked) │
|
||||||
|
│ │
|
||||||
|
│ 对每个 Q block: │
|
||||||
|
│ 1. 按 attn_sum 降序排序 K blocks │
|
||||||
|
│ 2. 累积求和直到 >= threshold * total_sum │
|
||||||
|
│ 3. 累积到的 blocks 标记为 True │
|
||||||
|
│ │
|
||||||
|
│ 特殊处理: │
|
||||||
|
│ - 对角块 (causal) 始终保留 │
|
||||||
|
│ - Sink 块 (block 0) 可选保留 │
|
||||||
|
│ │
|
||||||
|
│ 输出: simple_mask [batch, heads, q_blocks, k_blocks] (bool) │
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
↓
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ Phase 4: 稀疏注意力计算 (BSA) │
|
||||||
|
│ │
|
||||||
|
│ attn_output = block_sparse_attn_func( │
|
||||||
|
│ Q, K, V, │
|
||||||
|
│ q_cu_seq_lens, # [0, q_len] │
|
||||||
|
│ k_cu_seq_lens, # [0, k_len] │
|
||||||
|
│ head_mask_type, # [num_heads] 全 1 │
|
||||||
|
│ None, # left_mask │
|
||||||
|
│ simple_mask, # 块稀疏 mask │
|
||||||
|
│ q_len, k_len, │
|
||||||
|
│ is_causal=True, │
|
||||||
|
│ ) │
|
||||||
|
│ │
|
||||||
|
│ 输出: attn_output [batch, heads, q_len, head_dim] │
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stride Reshape 详解
|
||||||
|
|
||||||
|
### Inverse 模式
|
||||||
|
|
||||||
|
XAttention 默认使用 `select_mode="inverse"`,这是一种交错采样策略:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 原始: Q/K shape = [batch, heads, seq_len, head_dim]
|
||||||
|
# stride = 8
|
||||||
|
|
||||||
|
# K reshape: 正向交错
|
||||||
|
K_reshaped = concat([K[:, :, 0::8, :], # 位置 0, 8, 16, ...
|
||||||
|
K[:, :, 1::8, :], # 位置 1, 9, 17, ...
|
||||||
|
K[:, :, 2::8, :], # 位置 2, 10, 18, ...
|
||||||
|
...
|
||||||
|
K[:, :, 7::8, :]]) # 位置 7, 15, 23, ...
|
||||||
|
# 结果: [batch, heads, seq_len/8, head_dim * 8]
|
||||||
|
|
||||||
|
# Q reshape: 反向交错 (inverse)
|
||||||
|
Q_reshaped = concat([Q[:, :, 7::8, :], # 位置 7, 15, 23, ...
|
||||||
|
Q[:, :, 6::8, :], # 位置 6, 14, 22, ...
|
||||||
|
Q[:, :, 5::8, :], # 位置 5, 13, 21, ...
|
||||||
|
...
|
||||||
|
Q[:, :, 0::8, :]]) # 位置 0, 8, 16, ...
|
||||||
|
# 结果: [batch, heads, seq_len/8, head_dim * 8]
|
||||||
|
```
|
||||||
|
|
||||||
|
### 为什么用 Inverse 模式?
|
||||||
|
|
||||||
|
当计算 `Q_reshaped @ K_reshaped^T` 时,inverse 模式使得:
|
||||||
|
- Q 的后半部分与 K 的前半部分对齐
|
||||||
|
- 这样可以近似捕获 **causal attention 的对角模式**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Triton Kernels 详解
|
||||||
|
|
||||||
|
### 1. flat_group_gemm_fuse_reshape
|
||||||
|
|
||||||
|
**文件**: `compass/src/kernels.py:198-235`
|
||||||
|
|
||||||
|
**功能**: 融合 stride reshape 和 GEMM,避免显式创建 reshape 后的大张量
|
||||||
|
|
||||||
|
```python
|
||||||
|
@triton.jit
|
||||||
|
def flat_group_gemm_fuse_reshape_kernel(Q, K, Out, ...):
|
||||||
|
# 关键: 不实际 reshape,而是通过指针算术模拟
|
||||||
|
Q_ptrs = Q + block_m * BLOCK_M * STRIDE * stride_qn
|
||||||
|
K_ptrs = K + block_n * BLOCK_N * STRIDE * stride_kn
|
||||||
|
|
||||||
|
# 对 stride 个位置累加
|
||||||
|
for iter in range(STRIDE):
|
||||||
|
q = tl.load(Q_ptrs - iter * stride_qn) # Q inverse 采样
|
||||||
|
k = tl.load(K_ptrs + iter * stride_kn) # K 正向采样
|
||||||
|
o += tl.dot(q, k)
|
||||||
|
```
|
||||||
|
|
||||||
|
**优势**:
|
||||||
|
- 内存节省: 不需要创建 `[batch, heads, seq_len/stride, head_dim*stride]` 的中间张量
|
||||||
|
- 计算融合: reshape + GEMM 一次完成
|
||||||
|
|
||||||
|
### 2. softmax_fuse_block_sum
|
||||||
|
|
||||||
|
**文件**: `compass/src/kernels.py:6-95`
|
||||||
|
|
||||||
|
**功能**: 在线 softmax + 块内求和
|
||||||
|
|
||||||
|
```python
|
||||||
|
@triton.jit
|
||||||
|
def softmax_fuse_block_sum_kernel_causal(In, Out, ...):
|
||||||
|
# Pass 1: 计算全局 max 和 sum (在线算法)
|
||||||
|
for iter in range(num_iters):
|
||||||
|
X = tl.load(input_ptr + iter * segment_size) * scale
|
||||||
|
m_local = tl.max(X, 1)
|
||||||
|
m_new = tl.maximum(m_i, m_local)
|
||||||
|
alpha = tl.math.exp2(m_i - m_new)
|
||||||
|
X = X - m_new[:, None]
|
||||||
|
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||||
|
l_i = l_i * alpha + l_local
|
||||||
|
m_i = m_new
|
||||||
|
|
||||||
|
# Pass 2: 归一化并按块求和
|
||||||
|
for iter in range(num_iters):
|
||||||
|
X = tl.load(input_ptr + iter * segment_size) * scale
|
||||||
|
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] # softmax
|
||||||
|
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||||
|
X = tl.sum(X, 2).sum(0) # 块内求和
|
||||||
|
tl.store(output_ptr + iter * segment_size // block_size, X)
|
||||||
|
```
|
||||||
|
|
||||||
|
**输出含义**: `attn_sum[b, h, qi, ki]` = Q block qi 对 K block ki 的**归一化注意力权重之和**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 块选择算法 (find_blocks_chunked)
|
||||||
|
|
||||||
|
**文件**: `compass/src/utils.py:44-191`
|
||||||
|
|
||||||
|
### 算法步骤
|
||||||
|
|
||||||
|
```python
|
||||||
|
def find_blocks_chunked(input_tensor, current_index, threshold, ...):
|
||||||
|
"""
|
||||||
|
input_tensor: [batch, heads, q_blocks, k_blocks] - 块级注意力权重和
|
||||||
|
threshold: 0.9 - 累积阈值
|
||||||
|
"""
|
||||||
|
# 1. 计算每行总和
|
||||||
|
total_sum = input_tensor.sum(dim=-1, keepdim=True)
|
||||||
|
required_sum = total_sum * threshold # 需要达到的累积和
|
||||||
|
|
||||||
|
# 2. 特殊块始终保留
|
||||||
|
mask = zeros_like(input_tensor, dtype=bool)
|
||||||
|
mask[:, :, :, 0] = True # sink 块
|
||||||
|
mask[:, :, :, diagonal] = True # 对角块 (causal)
|
||||||
|
|
||||||
|
# 3. 对剩余块按权重排序
|
||||||
|
other_values = input_tensor.masked_fill(mask, 0)
|
||||||
|
sorted_values, index = sort(other_values, descending=True)
|
||||||
|
|
||||||
|
# 4. 累积求和直到达到阈值
|
||||||
|
cumsum = sorted_values.cumsum(dim=-1)
|
||||||
|
index_mask = cumsum < required_sum
|
||||||
|
|
||||||
|
# 5. 标记选中的块
|
||||||
|
mask[..., index[index_mask]] = True
|
||||||
|
|
||||||
|
return mask
|
||||||
|
```
|
||||||
|
|
||||||
|
### 示例
|
||||||
|
|
||||||
|
```
|
||||||
|
threshold = 0.9
|
||||||
|
attn_sum 某一行 = [0.05, 0.30, 0.40, 0.15, 0.10] (已 softmax, 和为 1.0)
|
||||||
|
required_sum = 0.9
|
||||||
|
|
||||||
|
排序后: [0.40, 0.30, 0.15, 0.10, 0.05]
|
||||||
|
累积和: [0.40, 0.70, 0.85, 0.95, 1.00]
|
||||||
|
↑ 达到 0.9
|
||||||
|
|
||||||
|
选中: 前 4 个块 (indices: 2, 1, 3, 4)

NOTE(review): 与上文伪代码不一致 — `index_mask = cumsum < required_sum` 只会标记前 3 个块
(累积和 0.85 < 0.9,未达到阈值)。若实际行为如本例选中 4 个块,实现应包含首次使累积和
达到阈值的那个块 (即类似 `cumsum - sorted_values < required_sum`)。请对照
`compass/src/utils.py` 中 find_blocks_chunked 的真实实现确认。
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## BSA (Block Sparse Attention)
|
||||||
|
|
||||||
|
### 库来源
|
||||||
|
|
||||||
|
```python
|
||||||
|
from block_sparse_attn import block_sparse_attn_func
|
||||||
|
```
|
||||||
|
|
||||||
|
来自 MIT-HAN-LAB,提供基于块 mask 的高效稀疏 FlashAttention 实现。
|
||||||
|
|
||||||
|
### 接口
|
||||||
|
|
||||||
|
```python
|
||||||
|
attn_output = block_sparse_attn_func(
|
||||||
|
query_states, # [total_q, num_heads, head_dim]
|
||||||
|
key_states, # [total_k, num_heads, head_dim]
|
||||||
|
value_states, # [total_k, num_heads, head_dim]
|
||||||
|
q_cu_seq_lens, # [batch+1] cumulative sequence lengths
|
||||||
|
k_cu_seq_lens, # [batch+1]
|
||||||
|
head_mask_type, # [num_heads] int32, 1=causal, 0=full
|
||||||
|
left_mask, # Optional left padding mask
|
||||||
|
block_mask, # [batch, heads, q_blocks, k_blocks] bool
|
||||||
|
max_seqlen_q, # int
|
||||||
|
max_seqlen_k, # int
|
||||||
|
p_dropout=0.0,
|
||||||
|
deterministic=True,
|
||||||
|
is_causal=True, # 全局 causal flag
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 块大小要求
|
||||||
|
|
||||||
|
BSA 要求 **block_size = 128**(硬编码):
|
||||||
|
```python
|
||||||
|
assert block_size == 128 # Xattention.py:358
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 关键参数
|
||||||
|
|
||||||
|
| 参数 | 默认值 | 范围 | 作用 |
|
||||||
|
|------|--------|------|------|
|
||||||
|
| `stride` | 8 | 4-16 | Q/K 交错采样步长,越大估计越快但越粗糙 |
|
||||||
|
| `threshold` | 0.9 | 0.7-0.99 | 累积注意力阈值,越高保留块越多 |
|
||||||
|
| `block_size` | 128 | 128 (固定) | BSA 块大小,不可调 |
|
||||||
|
| `chunk_size` | 16384 | 2048-131072 | 估计时的分块大小,影响内存使用 |
|
||||||
|
| `norm` | 1.0 | 0.5-2.0 | 注意力分数归一化系数 |
|
||||||
|
| `keep_sink` | False | bool | 是否始终保留第一个块 |
|
||||||
|
| `keep_recent` | False | bool | 是否始终保留对角块 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 计算复杂度
|
||||||
|
|
||||||
|
### 估计阶段
|
||||||
|
|
||||||
|
| 操作 | 复杂度 |
|
||||||
|
|------|--------|
|
||||||
|
| Stride reshape GEMM | O(seq_len/stride × seq_len/stride × head_dim × stride) = O(seq_len² × head_dim / stride) |
|
||||||
|
| Softmax + block sum | O(seq_len² / stride²) |
|
||||||
|
| Block selection | O(num_blocks² × log(num_blocks)) |
|
||||||
|
|
||||||
|
**估计阶段总复杂度**: O(seq_len² × head_dim / stride)
|
||||||
|
|
||||||
|
### 计算阶段 (BSA)
|
||||||
|
|
||||||
|
设选中块比例为 ρ (通常 0.3-0.5):
|
||||||
|
|
||||||
|
| 操作 | 复杂度 |
|
||||||
|
|------|--------|
|
||||||
|
| Block sparse attention | O(ρ × num_blocks² × block_size² × head_dim) = O(ρ × seq_len² × head_dim) |
|
||||||
|
|
||||||
|
**总复杂度**: O(seq_len² × head_dim × (1/stride + ρ))
|
||||||
|
|
||||||
|
当 stride=8, ρ=0.4 时,相比 full attention 节省约 **50%** 计算量。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 与 nano-vllm 集成注意事项
|
||||||
|
|
||||||
|
### 依赖要求
|
||||||
|
|
||||||
|
```
|
||||||
|
block_sparse_attn # pip install block-sparse-attn
|
||||||
|
triton >= 2.0 # Triton kernels
|
||||||
|
```
|
||||||
|
|
||||||
|
### CPU Offload 场景适配
|
||||||
|
|
||||||
|
XAttention 原始实现假设所有 KV 在 GPU 上。对于 CPU offload 场景,需要:
|
||||||
|
|
||||||
|
1. **估计阶段**: 仍需加载所有历史 KV 到 GPU 进行估计
|
||||||
|
2. **计算阶段**: 只加载选中的块
|
||||||
|
|
||||||
|
这可能需要修改为两阶段 pipeline:
|
||||||
|
- 先用采样数据估计重要块
|
||||||
|
- 再只加载重要块进行计算
|
||||||
|
|
||||||
|
### block_size 对齐
|
||||||
|
|
||||||
|
nano-vllm 的 `kvcache_block_size` 需要与 BSA 的 128 对齐:
|
||||||
|
- 如果 `kvcache_block_size = 1024`,则每个 kv block 包含 8 个 BSA blocks
|
||||||
|
- 块选择粒度需要相应调整
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 源文件索引
|
||||||
|
|
||||||
|
| 文件 | 位置 | 内容 |
|
||||||
|
|------|------|------|
|
||||||
|
| `Xattention.py` | `compass/src/Xattention.py` | 主入口: `xattn_estimate()`, `Xattention_prefill()` |
|
||||||
|
| `kernels.py` | `compass/src/kernels.py` | Triton 内核 |
|
||||||
|
| `utils.py` | `compass/src/utils.py` | `find_blocks_chunked()`, `create_causal_mask()` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 参考
|
||||||
|
|
||||||
|
- COMPASS 项目: `/home/zijie/Code/COMPASS/`
|
||||||
|
- BSA 库: MIT-HAN-LAB block_sparse_attn
|
||||||
|
- 测试报告: `docs/xattention_bsa_test_report.md`
|
||||||
@@ -255,7 +255,6 @@ class OffloadEngine:
|
|||||||
Clears:
|
Clears:
|
||||||
- GPU ring buffer slots (k_cache_gpu, v_cache_gpu)
|
- GPU ring buffer slots (k_cache_gpu, v_cache_gpu)
|
||||||
- Per-layer decode buffers (decode_k_buffer, decode_v_buffer)
|
- Per-layer decode buffers (decode_k_buffer, decode_v_buffer)
|
||||||
- Cross-layer pipeline buffers (layer_k/v_buffer_a/b)
|
|
||||||
- Per-layer prefill buffers (prefill_k/v_buffer)
|
- Per-layer prefill buffers (prefill_k/v_buffer)
|
||||||
- All pending async transfer events
|
- All pending async transfer events
|
||||||
"""
|
"""
|
||||||
@@ -267,12 +266,6 @@ class OffloadEngine:
|
|||||||
self.decode_k_buffer.zero_()
|
self.decode_k_buffer.zero_()
|
||||||
self.decode_v_buffer.zero_()
|
self.decode_v_buffer.zero_()
|
||||||
|
|
||||||
# Clear cross-layer pipeline buffers
|
|
||||||
self.layer_k_buffer_a.zero_()
|
|
||||||
self.layer_v_buffer_a.zero_()
|
|
||||||
self.layer_k_buffer_b.zero_()
|
|
||||||
self.layer_v_buffer_b.zero_()
|
|
||||||
|
|
||||||
# Clear per-layer prefill buffers
|
# Clear per-layer prefill buffers
|
||||||
self.prefill_k_buffer.zero_()
|
self.prefill_k_buffer.zero_()
|
||||||
self.prefill_v_buffer.zero_()
|
self.prefill_v_buffer.zero_()
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
Returns:
|
Returns:
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
|
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
|
||||||
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
|
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
|
||||||
@@ -222,7 +222,7 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
Returns:
|
Returns:
|
||||||
Attention output [batch_size, 1, num_heads, head_dim]
|
Attention output [batch_size, 1, num_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
||||||
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||||
@@ -319,7 +319,7 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
Loads one block at a time, computes attention, and merges results.
|
Loads one block at a time, computes attention, and merges results.
|
||||||
Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods.
|
Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods.
|
||||||
"""
|
"""
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
num_blocks = len(cpu_block_table)
|
num_blocks = len(cpu_block_table)
|
||||||
if num_blocks == 0:
|
if num_blocks == 0:
|
||||||
|
|||||||
19
nanovllm/ops/__init__.py
Normal file
19
nanovllm/ops/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""
|
||||||
|
Operators module for nano-vLLM.
|
||||||
|
|
||||||
|
This module contains low-level attention operators and kernels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from nanovllm.ops.chunked_attention import (
|
||||||
|
flash_attn_with_lse,
|
||||||
|
merge_attention_outputs,
|
||||||
|
chunked_attention_varlen,
|
||||||
|
ChunkedPrefillState,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"flash_attn_with_lse",
|
||||||
|
"merge_attention_outputs",
|
||||||
|
"chunked_attention_varlen",
|
||||||
|
"ChunkedPrefillState",
|
||||||
|
]
|
||||||
Reference in New Issue
Block a user