Compare commits
6 commits: 2e96d1d97d ... 6e34efd58a

| SHA1 |
|---|
| 6e34efd58a |
| 5acd5558d6 |
| 193ef55d18 |
| f173a3f7f5 |
| 8035e4db3d |
| 8ab53e7331 |
@@ -16,6 +16,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
| [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention algorithm deep dive: stride reshape, Triton kernels, BSA dependencies, block selection algorithm |
| [`docs/xattn_kernels_guide.md`](docs/xattn_kernels_guide.md) | XAttention Triton kernels: flat_group_gemm (antidiagonal sums), softmax_fuse_block_sum (block aggregation) |
| [`docs/xattn_kv_chunking_kernels.md`](docs/xattn_kv_chunking_kernels.md) | XAttention KV chunking: three-stage softmax, storage overhead analysis (O(S) vs O(S²)), peak memory reduction (8x), independent Q/KV chunking |
| [`docs/xattn_chunked_prefill.md`](docs/xattn_chunked_prefill.md) | XAttention chunked prefill: API, usage, consistency requirements |
| [`docs/xattn_bsa_policy_design.md`](docs/xattn_bsa_policy_design.md) | XAttention BSA policy: algorithm design, performance benchmarks (128K), memory management, density statistics |
| [`docs/xattn_density_benchmark.md`](docs/xattn_density_benchmark.md) | 📊 XAttention density benchmark: 4K-32K context, stride parameters, per-layer density analysis |
@@ -38,6 +39,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
| [`docs/long_context_models_1m.md`](docs/long_context_models_1m.md) | 📚 REF: list of models with 1M+ context length (Qwen/GLM/InternLM/Llama/VL), recommended models ≤10B |
| [`docs/new_model_integration_guide.md`](docs/new_model_integration_guide.md) | 🔧 GUIDE: new model integration guide - config mapping, RoPE variants, EOS handling, weight conversion, validation checklist |
| [`docs/xattn_density_alignment_analysis.md`](docs/xattn_density_alignment_analysis.md) | 📊 ANALYSIS: density alignment between GPU-only and offload modes, chunked softmax boundary effects, root cause of the 5-7% gap |
| [`docs/xattn_kv_chunking_density_test.md`](docs/xattn_kv_chunking_density_test.md) | 🧪 TEST: XAttention KV chunking density validation, aligned at threshold=1.0, 10-13% gap for threshold<1.0 |

## Rules Index

docs/xattn_kv_chunking_density_test.md (new file, 122 lines)
@@ -0,0 +1,122 @@

# XAttention KV Chunking Density Validation Test

## Background

Verify whether the XAttention Triton kernel can only be chunked along the Q axis, not along the KV axis.

**Hypothesis**: `softmax_fuse_block_sum` needs the complete K to compute the correct normalization denominator, so the attention distribution computed per chunk differs from that of the full sequence.

## Test Method

1. **GPU-only mode**: call `xattn_estimate` once on the full sequence and record the density of Layer 0
2. **Offload DEBUG mode**: call `xattn_estimate` per chunk, accumulate selected/total counts, and compute the final density
3. Use the same `_debug_k_full` buffer to collect the full K cache, ensuring identical input data

### Key Code Logic

```python
# Offload DEBUG: accumulate selected/total per chunk
for each chunk:
    K_full = _debug_k_full[:, :, :total_k_len, :]  # accumulated K
    _, mask_chunk = xattn_estimate(Q_chunk, K_full, threshold=threshold, causal=True)

    # Trim to the valid region and build the correct causal mask (accounting for the Q offset)
    q_offset_blocks = k_blocks - q_blocks
    causal_mask = indices <= (q_indices + q_offset_blocks)

    selected += (mask_valid & causal_mask).sum()
    total += causal_mask.sum()

density = selected / total
```
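The Q-offset causal mask used above can be sketched in plain Python (a toy illustration; `causal_block_mask` and the block counts are hypothetical stand-ins for the trimmed tensor shapes):

```python
def causal_block_mask(q_blocks, k_blocks):
    # Q block i sits at global position i + (k_blocks - q_blocks),
    # so it may attend to K blocks 0 .. i + offset.
    offset = k_blocks - q_blocks
    return [[k <= q + offset for k in range(k_blocks)]
            for q in range(q_blocks)]

# 2 Q blocks at the tail of 4 K blocks: the rows see 3 and 4 K blocks.
mask = causal_block_mask(q_blocks=2, k_blocks=4)
total = sum(sum(row) for row in mask)
```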

## Test Results

### 64K sequence (niah_single_1, sequence length 64891)

| threshold | GPU-only selected | Offload selected | GPU-only density | Offload density | Difference (selected) |
|-----------|------------------|------------------|------------------|-----------------|-----------------|
| **0.90** | 1,524,617 | 1,330,506 | **0.3700** | **0.3229** | 194,111 (12.7%) |
| **0.95** | 1,955,015 | 1,747,585 | **0.4744** | **0.4241** | 207,430 (10.6%) |
| **1.00** | 4,118,719 | 4,118,896 | **0.9995** | **0.9995** | -177 (~0%) |

- **total**: 4,120,896 (identical in both modes)

### 32K sequence (niah_single_1, sequence length 32485)

| threshold | GPU-only selected | Offload selected | GPU-only density | Offload density | Difference (selected) |
|-----------|------------------|------------------|------------------|-----------------|-----------------|
| **0.90** | 520,314 | 466,937 | **0.5021** | **0.4506** | 53,377 (10.3%) |
| **0.95** | 647,765 | 602,953 | **0.6251** | **0.5818** | 44,812 (6.9%) |
| **1.00** | 1,036,295 | 1,036,264 | **0.9999** | **0.9999** | 31 (~0%) |

- **total**: 1,036,320 (identical in both modes)

### Summary Comparison

| seq length | threshold | GPU-only density | Offload density | density difference (pp) |
|---------|-----------|------------------|-----------------|--------------|
| 32K | 0.90 | 0.5021 | 0.4506 | 5.2% |
| 64K | 0.90 | 0.3700 | 0.3229 | 4.7% |
| 32K | 0.95 | 0.6251 | 0.5818 | 4.3% |
| 64K | 0.95 | 0.4744 | 0.4241 | 5.0% |
| 32K | 1.00 | 0.9999 | 0.9999 | ~0% |
| 64K | 1.00 | 0.9995 | 0.9995 | ~0% |
## Conclusions

### 1. Softmax normalization itself is correct

With `threshold=1.0` (all blocks selected), the GPU-only and Offload densities are almost perfectly aligned (difference < 0.01%).

This shows that:
- `_debug_k_full` correctly collects the full K cache
- when `xattn_estimate` is called per chunk, softmax normalization is computed over the correct K sequence
- the Q-offset handling in the causal mask is correct

### 2. The problem lies in how the threshold is applied

With `threshold < 1.0`, the gap is significant (10-13%):

- **GPU-only**: applies the threshold once over the full sequence, selecting blocks whose cumulative attention >= threshold
- **Offload**: applies the threshold independently per chunk and accumulates selected counts

Applying the threshold independently per chunk means:
- some blocks selected in GPU-only mode are not selected in the chunked run, because the attention distribution differs
- the accumulated selected count is lower than the one-shot computation

### 3. KV chunking limits of the XAttention Triton kernel

**Verified conclusion**: XAttention's `xattn_estimate` can handle KV chunking correctly (softmax normalization is right), but **threshold-based block selection cannot simply be accumulated**.

To make Offload-mode block selection match GPU-only:
1. first accumulate attention scores across all chunks
2. then apply the threshold once to select blocks

Alternatively, accept the 10-13% density gap; its impact on actual inference accuracy needs further evaluation.
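Conclusion 2 can be illustrated at toy scale (made-up scores; `select_until` is a hypothetical stand-in for threshold-based block selection): applying a cumulative-mass threshold per chunk generally yields a different block set than one global pass over the full row.

```python
def select_until(scores, threshold):
    # Greedily take the highest-scoring blocks until the chosen set
    # covers `threshold` of the total attention mass.
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    total, acc, chosen = sum(scores), 0.0, set()
    for i in order:
        if acc / total >= threshold:
            break
        chosen.add(i)
        acc += scores[i]
    return chosen

row = [0.4, 0.1, 0.1, 0.4]    # attention mass per block of one query row
full = select_until(row, 0.9)  # one global pass
per_chunk = select_until(row[:2], 0.9) | {2 + i for i in select_until(row[2:], 0.9)}
```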

## Test Commands

```bash
# GPU-only mode
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py --dataset niah_single_1 --sample 0 \
    --sparse-policy xattn_bsa --sparse-threshold 0.9

# Offload mode (64K)
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py --dataset niah_single_1 --sample 0 \
    --sparse-policy xattn_bsa --sparse-threshold 0.9 --enable-offload

# Offload mode (32K)
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py --dataset niah_single_1 --sample 0 \
    --sparse-policy xattn_bsa --sparse-threshold 0.9 --enable-offload \
    --data-dir /home/zijie/Code/nano-vllm/tests/data/ruler_32k --max-model-len 34000
```

## Related Files

- `nanovllm/kvcache/sparse/xattn_bsa.py`: DEBUG code location
- `nanovllm/ops/xattn.py`: `xattn_estimate` implementation
- `nanovllm/utils/density_observer.py`: DensityObserver implementation

docs/xattn_kv_chunking_kernels.md (new file, 400 lines)
@@ -0,0 +1,400 @@

# XAttention KV Chunking Kernels

## Overview

This document describes the softmax kernels that support chunking along the KV dimension. They allow sparse attention estimation to be computed in KV-dimension chunks in CPU offload scenarios, without keeping the full raw attention scores on the GPU.

## Background

The original `softmax_fuse_block_sum` kernel needs the complete K sequence to compute the correct softmax normalization denominator:

```
softmax(x_i) = exp(x_i) / Σ_j exp(x_j)
```

With only part of K (a KV chunk), the denominator `Σ_j exp(x_j)` is incomplete, so the normalization is wrong.

## Solution: Three-Stage Computation

Splitting the softmax computation into three stages makes KV chunking correct:

```
KV Chunk 0        KV Chunk 1      ...     KV Chunk N
attn_scores       attn_scores             attn_scores
     │                 │                       │
     ▼                 ▼                       ▼
┌─────────────────────────────────────────────────┐
│ Stage 1: softmax_compute_partial_stats          │
│ compute (m_partial, l_partial) per chunk        │
└─────────────────────────────────────────────────┘
     │                 │                       │
 (m_0, l_0)        (m_1, l_1)              (m_N, l_N)
     └─────────────────┼───────────────────────┘
                       ▼
┌─────────────────────────────────────────────────┐
│ Stage 2: merge_softmax_stats                    │
│ merge on host → (m_global, l_global)            │
└─────────────────────────────────────────────────┘
     ┌─────────────────┼───────────────────────┐
     ▼                 ▼                       ▼
┌─────────────────────────────────────────────────┐
│ Stage 3: softmax_normalize_and_block_sum        │
│ normalize with global stats, compute block sums │
└─────────────────────────────────────────────────┘
     │                 │                       │
block_sums_0      block_sums_1           block_sums_N
     └─────────────────┴───────────────────────┘
                       ▼
             torch.cat → final mask
```

### Stage 1: `softmax_compute_partial_stats`

Computes partial statistics for each KV chunk:
- `m_partial`: the max value within the chunk (per query row)
- `l_partial`: the partial sum within the chunk = Σ exp(x - m_partial)

```python
m_partial, l_partial = softmax_compute_partial_stats(
    attn_weights_kv,       # [batch, heads, q_len, k_chunk_len]
    reshaped_block_size,
    segment_size,
    scale,
    chunk_start=chunk_start,
    kv_offset=kv_offset,   # offset of the KV chunk within the full sequence
    is_causal=True,
)
# Output: m_partial, l_partial with shape [batch, heads, q_len]
```

### Stage 2: `merge_softmax_stats`

Merges the statistics of all KV chunks on the host:

```python
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
```

Merge formula (online softmax):
```
m_new = max(m_global, m_chunk)
l_new = l_global * exp(m_global - m_new) + l_chunk * exp(m_chunk - m_new)
```
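The merge formula can be checked numerically in a few lines of plain Python (a scalar sketch using natural exp, independent of the kernels' exp2 formulation):

```python
import math

def partial_stats(xs):
    # (m, l) statistics of one chunk: max and sum of exp(x - m)
    m = max(xs)
    return m, sum(math.exp(x - m) for x in xs)

def merge(m_a, l_a, m_b, l_b):
    # online-softmax merge of two (m, l) pairs
    m = max(m_a, m_b)
    return m, l_a * math.exp(m_a - m) + l_b * math.exp(m_b - m)

xs = [0.3, -1.2, 2.5, 0.0, 1.7, -0.4]
merged = merge(*partial_stats(xs[:3]), *partial_stats(xs[3:]))
full = partial_stats(xs)
```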

### Stage 3: `softmax_normalize_and_block_sum`

Normalizes with the global statistics and computes block sums:

```python
attn_sum_kv = softmax_normalize_and_block_sum(
    attn_weights_kv,       # [batch, heads, q_len, k_chunk_len]
    m_global,              # [batch, heads, q_len]
    l_global,              # [batch, heads, q_len]
    reshaped_block_size,
    segment_size,
    chunk_start=chunk_start,
    real_q_len=real_q_len,
    scale=scale,
    kv_offset=kv_offset,
    is_causal=True,
)
# Output: [batch, heads, q_blocks, k_chunk_blocks]
```
## Proof of Mathematical Equivalence

Original softmax computation (full K):
```
softmax(x_i) = exp(x_i - m) / Σ_j exp(x_j - m)
```

Per-KV-chunk computation:
```
Chunk 0: m_0 = max(x[0:N/2]), l_0 = Σ exp(x[0:N/2] - m_0)
Chunk 1: m_1 = max(x[N/2:N]), l_1 = Σ exp(x[N/2:N] - m_1)

Merge:
m_global = max(m_0, m_1)
l_global = l_0 * exp(m_0 - m_global) + l_1 * exp(m_1 - m_global)
         = Σ exp(x[0:N] - m_global)   # equals the global sum

Normalize:
softmax(x_i) = exp(x_i - m_global) / l_global   # correct!
```
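The same argument as executable Python (a scalar sketch of the three stages; the chunk size and inputs are arbitrary):

```python
import math

def chunked_softmax(xs, chunk):
    # Stage 1: per-chunk (m, l) statistics
    stats = []
    for s in range(0, len(xs), chunk):
        part = xs[s:s + chunk]
        m = max(part)
        stats.append((m, sum(math.exp(v - m) for v in part)))
    # Stage 2: online-softmax merge into global (m, l)
    m_g, l_g = stats[0]
    for m, l in stats[1:]:
        m_new = max(m_g, m)
        l_g = l_g * math.exp(m_g - m_new) + l * math.exp(m - m_new)
        m_g = m_new
    # Stage 3: normalize every element with the global stats
    return [math.exp(v - m_g) / l_g for v in xs]

xs = [0.5, -2.0, 3.1, 0.0, 1.2, -0.7, 2.4, 0.9]
m_ref = max(xs)
l_ref = sum(math.exp(v - m_ref) for v in xs)
reference = [math.exp(v - m_ref) / l_ref for v in xs]
chunked = chunked_softmax(xs, chunk=3)
```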

## Causal Mask Handling

Both kernels handle causal attention correctly:

1. **`softmax_partial_stats_kernel`**: uses the `kv_offset` parameter to locate the current KV chunk within the full sequence and computes the causal boundary correctly

2. **`softmax_normalize_block_sum_kernel`**: likewise uses `kv_offset`, and writes 0 for positions beyond the causal boundary
## Storage Overhead Analysis

### Symbol Definitions

| Symbol | Meaning | Typical value |
|------|------|--------|
| S | seq_len | 64K |
| B | batch_size | 1 |
| H | num_heads | 32 |
| D | head_dim | 128 |
| T | stride | 4-8 |
| C | chunk_size | 16K |
| n | num_kv_chunks = ceil(S/C) | 4 |

### Original Approach (no KV chunking)

**Peak memory for attn_weights**:
```
[B, H, S/T, S/T] × 4 bytes = B × H × (S/T)² × 4

e.g. S=64K, T=4, B=1, H=32
= 1 × 32 × 16384² × 4 = 32 GB
```

### Extra Storage for KV Chunking

#### 1. Partial stats (per KV chunk)

```
m_partial: [B, H, C/T] × 4 bytes
l_partial: [B, H, C/T] × 4 bytes

per chunk = 2 × B × H × (C/T) × 4
          = 2 × 1 × 32 × 4096 × 4 = 1 MB
```

#### 2. Global stats

```
m_global: [B, H, S/T] × 4 bytes
l_global: [B, H, S/T] × 4 bytes

= 2 × B × H × (S/T) × 4
= 2 × 1 × 32 × 16384 × 4 = 4 MB
```

#### 3. Total extra overhead

```
total_extra = n × partial_stats + global_stats
            = 4 × 1MB + 4MB = 8 MB
```

### Storage Overhead vs. Sequence Length

| seqlen | num_chunks | original attn_weights | extra stats | ratio |
|--------|------------|-------------------|------------|------|
| 16K | 1 | 2 GB | 2 MB | 0.1% |
| 32K | 2 | 8 GB | 4 MB | 0.05% |
| 64K | 4 | 32 GB | 8 MB | 0.025% |
| 128K | 8 | 128 GB | 16 MB | 0.012% |
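The arithmetic above can be reproduced with a short script (fp32 assumed; the helper names are hypothetical):

```python
K = 1024

def attn_weights_gb(S, B=1, H=32, T=4):
    # [B, H, S/T, S/T] fp32 scores
    return B * H * (S // T) ** 2 * 4 / 2**30

def extra_stats_mb(S, C, B=1, H=32, T=4):
    n = -(-S // C)                       # ceil(S/C) KV chunks
    partial = 2 * B * H * (C // T) * 4   # (m, l) per chunk
    global_ = 2 * B * H * (S // T) * 4   # merged (m, l)
    return (n * partial + global_) / 2**20

weights = attn_weights_gb(64 * K)        # matches the 32 GB table entry
extra = extra_stats_mb(64 * K, 16 * K)   # matches the 8 MB table entry
```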

### Complexity Analysis

| Storage component | Complexity | Notes |
|----------|--------|------|
| Original attn_weights | O(S²) | quadratic growth |
| Partial/global stats | O(S) | linear growth |
| **Relative overhead** | O(1/S) | **shrinks as seqlen grows** |

### Peak Memory Optimization

The main benefit of KV chunking is reducing **peak memory** from O(S²) to O(S×C):

```
Original:    O(B × H × (S/T)²)        # full attn_weights
KV chunking: O(B × H × (S/T) × (C/T)) # only one chunk at a time
```

For S=128K, C=16K:
- original peak: ~128 GB
- KV chunking peak: ~16 GB (an **8x** reduction)
## Support for Different Q/KV Chunk Sizes

The three-stage pipeline supports different chunk sizes for Q and KV:

```python
q_chunk_size = 8192    # Q chunk size
kv_chunk_size = 16384  # KV chunk size

for q_chunk_idx in range(q_chunk_num):
    Q_chunk = Q[:, :, q_start:q_end, :]  # [B, H, q_chunk_size, D]

    for kv_chunk_idx in range(kv_chunk_num):
        K_chunk = K[:, :, kv_start:kv_end, :]  # [B, H, kv_chunk_size, D]
        # ... three-stage processing
```

### Validation Results

| Config | seq_len | Q chunks | KV chunks | density | aligned |
|--------|---------|----------|-----------|---------|------|
| Q=16K, KV=16K | 64891 | 4 | 4 | 0.1117 | ✓ 100% |
| Q=8K, KV=16K | 64891 | 8 | 4 | 0.1112 | ✓ 100% |
| Q=16K, KV=8K | 64891 | 4 | 8 | 0.1117 | ✓ 100% |
| Q=8K, KV=8K | 64891 | 8 | 8 | 0.1112 | ✓ 100% |
| Q=4K, KV=16K | 64891 | 16 | 4 | 0.1109 | ✓ 100% |
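The alignment shown in the table can be mimicked at toy scale: block sums computed per KV chunk with merged global stats match a single full-row computation (a plain-Python sketch with arbitrary inputs; the helper names are hypothetical):

```python
import math

def block_sums_full(row, block):
    # reference: softmax over the full row, then per-block sums
    m = max(row)
    l = sum(math.exp(v - m) for v in row)
    p = [math.exp(v - m) / l for v in row]
    return [sum(p[i:i + block]) for i in range(0, len(row), block)]

def block_sums_kv_chunked(row, block, chunk):
    # Stages 1+2: global (m, l) via online merge over KV chunks
    m_g, l_g = -math.inf, 0.0
    for s in range(0, len(row), chunk):
        part = row[s:s + chunk]
        m = max(part)
        l = sum(math.exp(v - m) for v in part)
        m_new = max(m_g, m)
        l_g = l_g * math.exp(m_g - m_new) + l * math.exp(m - m_new)
        m_g = m_new
    # Stage 3: normalize each chunk with global stats, block-sum, concatenate
    out = []
    for s in range(0, len(row), chunk):
        part = row[s:s + chunk]
        p = [math.exp(v - m_g) / l_g for v in part]
        out += [sum(p[i:i + block]) for i in range(0, len(part), block)]
    return out

row = [0.3, -1.1, 2.0, 0.4, 1.5, -0.2, 0.9, 2.2]
a = block_sums_full(row, block=2)
b = block_sums_kv_chunked(row, block=2, chunk=4)
```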
## API Reference

### `softmax_compute_partial_stats`

```python
def softmax_compute_partial_stats(
    attn_weights_slice: torch.Tensor,  # [batch, heads, q_len, k_chunk_len]
    reshaped_block_size: int,
    segment_size: int,
    scale: float,
    chunk_start: int = 0,  # Q chunk start position (reshaped space)
    kv_offset: int = 0,    # KV chunk offset (reshaped space)
    is_causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Returns (m, l) partial stats"""
```

### `merge_softmax_stats`

```python
def merge_softmax_stats(
    m_chunks: list,  # List of [batch, heads, q_len] tensors
    l_chunks: list,  # List of [batch, heads, q_len] tensors
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Returns (m_global, l_global)"""
```

### `softmax_normalize_and_block_sum`

```python
def softmax_normalize_and_block_sum(
    attn_weights_slice: torch.Tensor,  # [batch, heads, q_len, k_chunk_len]
    m_global: torch.Tensor,  # [batch, heads, q_len]
    l_global: torch.Tensor,  # [batch, heads, q_len]
    reshaped_block_size: int,
    segment_size: int,
    chunk_start: int,
    real_q_len: int,
    scale: float,
    kv_offset: int = 0,
    is_causal: bool = False,
) -> torch.Tensor:
    """Returns block sums [batch, heads, q_blocks, k_chunk_blocks]"""
```

## Usage Example

```python
from nanovllm.ops.xattn import (
    flat_group_gemm_fuse_reshape,
    softmax_compute_partial_stats,
    softmax_normalize_and_block_sum,
    merge_softmax_stats,
    find_blocks_chunked,
)

# For each Q chunk
for q_chunk_idx in range(q_chunk_num):
    Q_chunk = Q_padded[:, :, q_start:q_end, :]

    # Stage 1: compute partial stats per KV chunk
    m_chunks, l_chunks = [], []
    attn_weights_chunks = []

    for kv_chunk_idx in range(kv_chunk_num):
        K_chunk = K_padded[:, :, kv_start:kv_end, :]
        kv_offset = kv_chunk_idx * kv_chunk_size // STRIDE

        # Compute raw scores
        attn_weights = flat_group_gemm_fuse_reshape(
            Q_chunk, K_chunk, STRIDE,
            chunk_start=chunk_start,
            chunk_end=chunk_end,
            is_causal=False,  # K is incomplete
        )
        attn_weights_chunks.append(attn_weights)

        # Compute partial stats
        m, l = softmax_compute_partial_stats(
            attn_weights, block_size, segment_size, scale,
            chunk_start=chunk_start,
            kv_offset=kv_offset,
            is_causal=True,
        )
        m_chunks.append(m)
        l_chunks.append(l)

    # Stage 2: merge stats
    m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)

    # Stage 3: normalize and compute block sums
    block_sums_list = []
    for kv_chunk_idx, attn_weights in enumerate(attn_weights_chunks):
        kv_offset = kv_chunk_idx * kv_chunk_size // STRIDE
        block_sums = softmax_normalize_and_block_sum(
            attn_weights, m_global, l_global,
            block_size, segment_size, chunk_start, real_q_len, scale,
            kv_offset=kv_offset,
            is_causal=True,
        )
        block_sums_list.append(block_sums)

    # Concatenate and select blocks
    attn_sum = torch.cat(block_sums_list, dim=-1)
    mask = find_blocks_chunked(attn_sum, ...)
```
## Performance Comparison

| Aspect | Original | KV Chunking |
|------|---------|-----------------|
| Number of kernels | 1 | 2 (stats + normalize) |
| Raw score reads | 2 | 2 |
| Extra memory | 0 | O(B × H × S/T × 2) for (m, l) |
| Host computation | none | merge stats (lightweight) |
| **Peak memory** | O(S²) | **O(S × C)** |

## Validation Tests

### Batch Test Results

The test script `tests/test_xattn_kv_chunking_batch.py` verifies consistency across sequence lengths:

```
| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff     | mask_diff | status |
|---------|--------|-----------|-----------|-------------|------------|----------|-----------|--------|
| 3688    | 4      | 0.90      | 1         | 0.383405    | 0.383405   | 0.000000 | 0.0000%   | PASS   |
| 7888    | 4      | 0.90      | 1         | 0.290611    | 0.290611   | 0.000000 | 0.0000%   | PASS   |
| 15685   | 4      | 0.90      | 1         | 0.197724    | 0.197724   | 0.000000 | 0.0000%   | PASS   |
| 32485   | 4      | 0.90      | 2         | 0.159023    | 0.159023   | 0.000000 | 0.0000%   | PASS   |
| 64891   | 4      | 0.90      | 4         | 0.111656    | 0.111656   | 0.000000 | 0.0000%   | PASS   |
```

### Key Conclusions

1. **Mathematical equivalence**: density_diff = 0.000000 in all tests
2. **Masks fully aligned**: mask_diff = 0.0000% in all tests
3. **Arbitrary Q/KV chunk size combinations are supported**

## Related Files

- `nanovllm/ops/xattn.py`: kernel implementations
- `tests/test_xattn_estimate_alignment.py`: single-case validation test
- `tests/test_xattn_kv_chunking_batch.py`: batch validation test
- `docs/xattn_kernels_guide.md`: original kernel documentation

@@ -147,8 +147,10 @@ class XAttentionBSAPolicy(SparsePolicy):
        self._selected_cpu_indices: List[int] = []
        self._bsa_per_cpu: int = 0  # BSA blocks per CPU block

-       #> Debug: store all K cache
+       #> Debug: store all K cache and density counts
        self._debug_k_full: torch.Tensor | None = None
+       self._debug_selected: int = 0  # accumulated selected blocks
+       self._debug_total: int = 0  # accumulated total blocks

    def alloc_policy_metadata(
        self,
@@ -202,8 +204,10 @@ class XAttentionBSAPolicy(SparsePolicy):
        memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
        logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")

-       #DEBUG : buffer for save all K cache.
+       #DEBUG : buffer for save all K cache
        self._debug_k_full = torch.empty((1, num_heads, max_seq_len, head_dim), dtype=dtype, device=device)
+       self._debug_selected = 0
+       self._debug_total = 0

    # =========================================================================
    # GPU-only methods (non-chunked)

@@ -395,6 +399,15 @@
        )

        # Record density for all layers via DensityObserver
        if layer_id == 0:
            # DEBUG: print mask details for GPU-only Layer 0
            q_bk = mask_trimmed.shape[2]
            k_bk = mask_trimmed.shape[3]
            causal_total = q_bk * (q_bk + 1) // 2 * mask_trimmed.shape[0] * mask_trimmed.shape[1]
            causal_mask = torch.tril(torch.ones(q_bk, k_bk, device=mask_trimmed.device, dtype=torch.bool))
            selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
            logger.info(f"[DEBUG GPU-only Layer0] mask_shape={mask_trimmed.shape}, "
                        f"density={selected/causal_total:.6f}, selected={selected}, total={causal_total}")
            DensityObserver.record(layer_id, mask_trimmed, causal=True)

        return output

@@ -567,9 +580,67 @@
        k_repeated = k.repeat_interleave(num_groups, dim=1).unsqueeze(0).transpose(1, 2)  # [1, num_heads, historical_k_len, head_dim]

        self._debug_k_full[:, :, historical_k_len:historical_k_len + q_len, :].copy_(k_repeated)

        # ============================================================
        # DEBUG: accumulate selected/total counts (layer 0 only)
        # Call xattn_estimate with the full K, matching the GPU-only logic
        # ============================================================
        if layer_id == 0:
            __import__('pdb').set_trace()
            from nanovllm.ops.xattn import xattn_estimate

            total_k_len = historical_k_len + q_len
            K_full = self._debug_k_full[:, :, :total_k_len, :]

            # Call xattn_estimate with the current Q chunk and the accumulated K
            # Set chunk_size to the smallest aligned value >= q_len (stride * BLOCK_M = 8 * 128 = 1024)
            alignment = self.stride * 128
            aligned_chunk_size = ((q_len + alignment - 1) // alignment) * alignment
            # DEBUG: test with a fixed threshold
            _, mask_chunk = xattn_estimate(
                Q[:, :, :q_len, :],  # current Q chunk
                K_full,              # accumulated K
                block_size=self.BSA_BLOCK_SIZE,
                stride=self.stride,
                threshold=self.threshold,  # DEBUG: use the threshold passed in
                chunk_size=aligned_chunk_size,  # aligned chunk_size
                causal=True,
            )

            # Compute the number of valid blocks (excluding padding)
            valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
            valid_k_blocks = (total_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE

            # Trim the mask to the valid region
            mask_valid = mask_chunk[:, :, :valid_q_blocks, :valid_k_blocks]

            # Compute selected/total for this chunk (causal, accounting for the Q offset)
            q_blocks = valid_q_blocks
            k_blocks = valid_k_blocks
            # Q starts at position (k_blocks - q_blocks), so Q block i is actually at i + offset
            # Q block i (actual position i+offset) can see K blocks 0 through i+offset
            q_offset_blocks = k_blocks - q_blocks
            indices = torch.arange(k_blocks, device=mask_valid.device).unsqueeze(0)    # [1, k_blocks]
            q_indices = torch.arange(q_blocks, device=mask_valid.device).unsqueeze(1)  # [q_blocks, 1]
            causal_mask = indices <= (q_indices + q_offset_blocks)  # [q_blocks, k_blocks]
            chunk_total = causal_mask.sum().item() * mask_valid.shape[0] * mask_valid.shape[1]
            chunk_selected = (mask_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()

            # Accumulate
            self._debug_selected += chunk_selected
            self._debug_total += chunk_total

            # Log the running accumulated density
            if self._debug_total > 0:
                density = self._debug_selected / self._debug_total
                logger.info(f"[DEBUG Offload Layer0] accumulated density: {density:.4f} "
                            f"(selected={self._debug_selected}, total={self._debug_total}, k_len={total_k_len}, "
                            f"mask_shape={mask_chunk.shape}, q_offset={q_offset_blocks})")

            # DEBUG: skip the normal offload logic and return all blocks
            return available_blocks
        else:
            # DEBUG: non-layer-0 also skips the normal offload logic
            return available_blocks

        # ============================================================
        # Step 3: Get current chunk K and compute its attn_scores

@@ -656,14 +727,16 @@
        q_start_bsa_block = historical_k_bsa_blocks  # Q starts after historical K

        with nvtx.range("xattn_find_blocks"):
+           # For selecting historical K, use causal=False since all historical K precedes the current Q
+           # current_index=0 avoids exceeding the K dimension of block_sums
            mask = find_blocks_chunked(
                block_sums,
-               current_index=q_start_bsa_block,  # Q's position in BSA blocks
+               current_index=0,
                threshold=self.threshold,
                num_to_choose=None,
                decoding=False,
-               mode="prefill",
-               causal=True,  # Causal for block-level mask
+               mode="both",
+               causal=False,
            )
            # mask shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]

@@ -676,47 +749,13 @@
        valid_q_bsa = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
        valid_curr_k_bsa = (curr_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE

-       # 7a: Record historical blocks density
-       # IMPORTANT: For historical blocks, apply causal mask to match GPU-only density calculation!
-       # Q block i (global position = q_start_bsa_block + i) can see historical K block j
-       # only if j <= q_start_bsa_block + i (causal constraint)
-       mask_historical = mask[:, :, :valid_q_bsa, :historical_k_bsa_blocks]
+       # 7a: Record historical blocks density (temporarily disabled; replaced by DEBUG output)
+       # if historical_k_bsa_blocks > 0:
+       #     ... DensityObserver.record_counts ...

-       if historical_k_bsa_blocks > 0:
-           # Create causal mask for historical blocks
-           # Q_global[i] = q_start_bsa_block + i, K[j] = j
-           # Causal: j <= Q_global[i] => j <= q_start_bsa_block + i
-           q_global_indices = torch.arange(valid_q_bsa, device=mask.device) + q_start_bsa_block
-           k_indices = torch.arange(historical_k_bsa_blocks, device=mask.device)
-           # Q at position q_global_indices[i] can see K at position k_indices[j] if k_indices[j] <= q_global_indices[i]
-           causal_mask_historical = k_indices.unsqueeze(0) <= q_global_indices.unsqueeze(1)  # [valid_q_bsa, historical_k_bsa_blocks]
-
-           # Count positions within causal mask only
-           total_historical_causal = causal_mask_historical.sum().item() * B * H
-           selected_historical = (mask_historical & causal_mask_historical.unsqueeze(0).unsqueeze(0)).sum().item()
-
-           if total_historical_causal > 0:
-               DensityObserver.record_counts(layer_id, selected_historical, total_historical_causal)
-
-       # 7b: Record current chunk density (causal, to align with GPU-only mode)
-       # Current chunk is the portion after historical blocks
-       if valid_curr_k_bsa > 0:
-           # Extract current chunk mask (only valid portion, not padded)
-           mask_current = mask[:, :, :valid_q_bsa, historical_k_bsa_blocks:historical_k_bsa_blocks + valid_curr_k_bsa]
-
-           q_dim = mask_current.shape[2]
-           k_dim = mask_current.shape[3]
-
-           # Create causal mask (lower triangular)
-           # For current chunk: Q[i] can see K[j] where j <= i (standard causal)
-           causal_mask = torch.tril(torch.ones(q_dim, k_dim, device=mask.device, dtype=torch.bool))
-
-           # Count positions within causal mask only
-           total_current_causal = causal_mask.sum().item() * B * H
-           selected_current = (mask_current & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
-
-           if total_current_causal > 0:
-               DensityObserver.record_counts(layer_id, selected_current, total_current_causal)
+       # 7b: Record current chunk density (temporarily disabled)
+       # if valid_curr_k_bsa > 0:
+       #     ... DensityObserver.record_counts ...

        # Step 7.5: Save historical mask to pre-allocated buffer for compute_chunked_prefill
        # Use full Q_bsa (padded) for buffer, not valid_q_bsa
@@ -218,6 +218,209 @@ def softmax_fuse_block_sum_kernel_non_causal(
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
|
||||
# ============================================================
|
||||
# KV Chunking Support Kernels
|
||||
# ============================================================
|
||||
|
||||
@triton.jit
|
||||
def softmax_partial_stats_kernel(
|
||||
In,
|
||||
M_out, # max per row
|
||||
L_out, # sum per row (normalized by M_out)
|
||||
scale,
|
||||
input_stride_0,
|
||||
input_stride_1,
|
||||
input_stride_2,
|
||||
stats_stride_0,
|
||||
stats_stride_1,
|
||||
k_len,
|
||||
chunk_start, # Q start position (for causal)
|
||||
kv_offset, # KV chunk offset (for causal)
|
||||
segment_size: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
is_causal: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Compute partial softmax statistics for a KV chunk.
|
||||
|
||||
    For each query row, computes:
    - m: max value in this chunk
    - l: sum of 2^(x - m) in this chunk (base 2; the log2(e) factor is folded into `scale`)

    These can be merged across chunks using the online softmax formula.

    Input shape: [batch, heads, q_len, k_chunk_len]
    Output shapes: M[batch, heads, q_len], L[batch, heads, q_len]
    """
    block_id = tl.program_id(0)
    head_id = tl.program_id(1)
    batch_id = tl.program_id(2)

    offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
    offs_k = tl.arange(0, segment_size)

    num_iters = k_len // segment_size

    # For causal: compute boundary
    if is_causal:
        # Causal boundary: number of KV segments fully below the diagonal
        # for this Q block. Q[i] can attend to K[j] iff i >= j; for a KV
        # chunk starting at kv_offset, Q[i] attends iff i >= kv_offset + j.
        num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1 - kv_offset) // segment_size
        num_iters_before_causal = tl.minimum(num_iters_before_causal, num_iters)
        num_iters_before_causal = tl.maximum(num_iters_before_causal, 0)
    else:
        num_iters_before_causal = num_iters

    # Online softmax state
    m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([block_size], dtype=tl.float32)

    # Input pointer
    input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
    input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2

    # Compute max and sum over the fully unmasked segments before the causal boundary
    for iter in range(0, num_iters_before_causal):
        X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
        m_local = tl.max(X, 1)
        m_new = tl.maximum(m_i, m_local)
        alpha = tl.math.exp2(m_i - m_new)

        X = X - m_new[:, None]
        l_local = tl.sum(tl.math.exp2(X), 1)
        l_i = l_i * alpha + l_local

        m_i = m_new

    # Handle the segment that straddles the causal boundary
    if is_causal:
        for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
            if iter < num_iters:
                X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
                # Causal mask: Q[i] >= K[j] + kv_offset
                mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size + kv_offset)
                X = tl.where(mask, X, -1.0e6)
                m_local = tl.max(X, 1)
                m_new = tl.maximum(m_i, m_local)
                alpha = tl.math.exp2(m_i - m_new)

                X = X - m_new[:, None]
                l_local = tl.sum(tl.math.exp2(X), 1)
                l_i = l_i * alpha + l_local

                m_i = m_new

    # Output pointers
    m_ptr = M_out + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
    l_ptr = L_out + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size

    offs = tl.arange(0, block_size)
    tl.store(m_ptr + offs, m_i.to(M_out.type.element_ty))
    tl.store(l_ptr + offs, l_i.to(L_out.type.element_ty))


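The causal-boundary expression above is reused verbatim in the normalize kernel below. Its behavior can be checked with a small pure-Python model (hypothetical sizes, not from the repo): segments fully below the diagonal are counted, the straddling segment is handled by the separate masked pass, and the count is clamped to [0, num_iters].

```python
def iters_before_causal(chunk_start, block_id, block_size, kv_offset, segment_size, num_iters):
    # Number of KV segments fully below the causal diagonal for this Q block
    n = (chunk_start + (block_id + 1) * block_size - 1 - kv_offset) // segment_size
    return max(0, min(n, num_iters))

# Q block covering reshaped rows [0, 16) vs. a KV chunk at offset 0,
# 4 segments of 16 columns: segment 0 straddles the diagonal, so 0 full segments.
assert iters_before_causal(0, 0, 16, 0, 16, 4) == 0
# Q block covering rows [48, 64): three full segments precede the diagonal.
assert iters_before_causal(0, 3, 16, 0, 16, 4) == 3
# A KV chunk starting beyond this Q block sees nothing (clamped to 0).
assert iters_before_causal(0, 0, 16, 64, 16, 4) == 0
```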
@triton.jit
def softmax_normalize_block_sum_kernel(
    In,
    Out,
    M_global,  # global max per row
    L_global,  # global sum per row
    scale,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_stride_0,
    output_stride_1,
    output_stride_2,
    stats_stride_0,
    stats_stride_1,
    real_q_len,
    k_len,
    chunk_start,
    kv_offset,  # KV chunk offset (for causal)
    segment_size: tl.constexpr,
    block_size: tl.constexpr,
    is_causal: tl.constexpr,
):
    """
    Normalize with global stats and compute block sums for a KV chunk.

    Uses pre-computed global m and l to correctly normalize the softmax
    across all KV chunks.

    Input shape: [batch, heads, q_len, k_chunk_len]
    Output shape: [batch, heads, q_blocks, k_chunk_blocks]
    """
    block_id = tl.program_id(0)
    head_id = tl.program_id(1)
    batch_id = tl.program_id(2)

    offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
    offs_k = tl.arange(0, segment_size)

    num_iters = k_len // segment_size

    # For causal: compute boundary (same formula as softmax_partial_stats_kernel)
    if is_causal:
        num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1 - kv_offset) // segment_size
        num_iters_before_causal = tl.minimum(num_iters_before_causal, num_iters)
        num_iters_before_causal = tl.maximum(num_iters_before_causal, 0)
    else:
        num_iters_before_causal = num_iters

    # Load global stats
    m_ptr = M_global + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
    l_ptr = L_global + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size

    offs = tl.arange(0, block_size)
    m_global = tl.load(m_ptr + offs).to(tl.float32)
    l_global = tl.load(l_ptr + offs).to(tl.float32)
    # Guard against l_global == 0 (all positions masked for this row)
    l_global_safe = tl.where(l_global > 0, l_global, 1.0)
    l_global_inv = 1.0 / l_global_safe

    # Input pointer
    input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
    input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2

    # Output pointer
    output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
    output_ptr = output_ptr + tl.arange(0, segment_size // block_size)

    sum_mask = offs_q[:, None] < real_q_len

    # Normalize and compute block sums (before the causal boundary)
    for iter in range(0, num_iters_before_causal):
        X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
        X = tl.exp2(X - m_global[:, None]) * l_global_inv[:, None]
        X = tl.where(sum_mask, X, 0)
        X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
        X = tl.sum(X, 2)
        X = tl.sum(X, 0)
        tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))

    # Handle the segment that straddles the causal boundary
    if is_causal:
        for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
            if iter < num_iters:
                X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
                # Causal mask: Q[i] >= K[j] + kv_offset
                causal_mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size + kv_offset)
                X = tl.where(causal_mask, X, -1.0e6)
                X = tl.exp2(X - m_global[:, None]) * l_global_inv[:, None]
                X = tl.where(sum_mask, X, 0)
                X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
                X = tl.sum(X, 2)
                X = tl.sum(X, 0)
                tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))

        # Zero out blocks that lie entirely above the diagonal
        for iter in range(num_iters_before_causal + 1, num_iters):
            X = tl.zeros([segment_size // block_size], dtype=tl.float32)
            tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))


@triton.jit
def flat_group_gemm_fuse_reshape_kernel(
    Q, K, Out,
@@ -380,6 +583,194 @@ def softmax_fuse_block_sum(
    return output


def softmax_compute_partial_stats(
    attn_weights_slice: torch.Tensor,
    reshaped_block_size: int,
    segment_size: int,
    scale: float,
    chunk_start: int = 0,
    kv_offset: int = 0,
    is_causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute partial softmax statistics for a KV chunk.

    This is the first step of KV-chunked softmax computation.
    For each query row, computes:
    - m: max value in this chunk
    - l: sum of 2^(x - m) in this chunk (base 2; the log2(e) factor is folded into `scale`)

    These partial stats can be merged across KV chunks with
    `merge_softmax_stats()`, then used with `softmax_normalize_and_block_sum()`.

    Args:
        attn_weights_slice: Raw attention scores [batch, heads, q_len, k_chunk_len]
        reshaped_block_size: Block size in reshaped space
        segment_size: Processing segment size
        scale: Softmax scale factor
        chunk_start: Q chunk start position (in reshaped space)
        kv_offset: KV chunk offset (in reshaped space, for causal masking)
        is_causal: Whether to apply causal masking

    Returns:
        Tuple of (m, l) where:
        - m: [batch, heads, q_len] max values per row
        - l: [batch, heads, q_len] partial sums per row
    """
    batch_size, num_heads, q_len, k_len = attn_weights_slice.shape

    assert q_len % reshaped_block_size == 0
    assert k_len % segment_size == 0
    assert attn_weights_slice.stride(-1) == 1

    m_out = torch.empty(
        (batch_size, num_heads, q_len),
        dtype=torch.float32,
        device=attn_weights_slice.device,
    )
    l_out = torch.empty(
        (batch_size, num_heads, q_len),
        dtype=torch.float32,
        device=attn_weights_slice.device,
    )

    grid = (q_len // reshaped_block_size, num_heads, batch_size)

    softmax_partial_stats_kernel[grid](
        attn_weights_slice,
        m_out,
        l_out,
        scale,
        attn_weights_slice.stride(0),
        attn_weights_slice.stride(1),
        attn_weights_slice.stride(2),
        m_out.stride(0),
        m_out.stride(1),
        k_len,
        chunk_start,
        kv_offset,
        segment_size,
        reshaped_block_size,
        is_causal,
    )

    return m_out, l_out


def merge_softmax_stats(
    m_chunks: list,
    l_chunks: list,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge partial softmax statistics from multiple KV chunks.

    Uses the online softmax merging formula, in base 2 to match the
    kernels' use of exp2:
        m_new = max(m1, m2)
        l_new = l1 * 2^(m1 - m_new) + l2 * 2^(m2 - m_new)

    Args:
        m_chunks: List of max tensors [batch, heads, q_len] from each chunk
        l_chunks: List of sum tensors [batch, heads, q_len] from each chunk

    Returns:
        Tuple of (m_global, l_global) with the same shape as the inputs
    """
    assert len(m_chunks) == len(l_chunks)
    assert len(m_chunks) > 0

    m_global = m_chunks[0].clone()
    l_global = l_chunks[0].clone()

    for i in range(1, len(m_chunks)):
        m_chunk = m_chunks[i]
        l_chunk = l_chunks[i]

        m_new = torch.maximum(m_global, m_chunk)
        # Rescale both partial sums to the new max: 2^(m - m_new)
        l_global = l_global * torch.pow(2.0, m_global - m_new) + l_chunk * torch.pow(2.0, m_chunk - m_new)
        m_global = m_new

    return m_global, l_global


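The merge identity above can be checked on plain floats: splitting a row of scores into two chunks and merging the per-chunk (m, l) pairs reproduces the global max and the global sum of 2^(x - m). A minimal pure-Python sketch (no torch; base 2 throughout, values chosen arbitrarily):

```python
import math

def partial_stats(xs):
    # (m, l) for one chunk: chunk max and sum of 2^(x - m)
    m = max(xs)
    return m, sum(2.0 ** (x - m) for x in xs)

def merge(m1, l1, m2, l2):
    # Online softmax merge, base 2
    m = max(m1, m2)
    return m, l1 * 2.0 ** (m1 - m) + l2 * 2.0 ** (m2 - m)

scores = [0.3, -1.2, 2.5, 0.0, 1.7, -0.4]
m_a, l_a = partial_stats(scores[:3])
m_b, l_b = partial_stats(scores[3:])
m_merged, l_merged = merge(m_a, l_a, m_b, l_b)

m_ref, l_ref = partial_stats(scores)  # stats over the whole row at once
assert m_merged == m_ref
assert math.isclose(l_merged, l_ref, rel_tol=1e-12)
```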
def softmax_normalize_and_block_sum(
    attn_weights_slice: torch.Tensor,
    m_global: torch.Tensor,
    l_global: torch.Tensor,
    reshaped_block_size: int,
    segment_size: int,
    chunk_start: int,
    real_q_len: int,
    scale: float,
    kv_offset: int = 0,
    is_causal: bool = False,
) -> torch.Tensor:
    """
    Normalize with global stats and compute block sums for a KV chunk.

    This is the second step of KV-chunked softmax computation.
    Uses pre-computed global m and l (from `merge_softmax_stats()`)
    to correctly normalize softmax values and compute block sums.

    Args:
        attn_weights_slice: Raw attention scores [batch, heads, q_len, k_chunk_len]
        m_global: Global max values [batch, heads, q_len]
        l_global: Global sum values [batch, heads, q_len]
        reshaped_block_size: Block size in reshaped space
        segment_size: Processing segment size
        chunk_start: Start position for this chunk (for masking)
        real_q_len: Actual Q length (before padding)
        scale: Softmax scale factor
        kv_offset: KV chunk offset (in reshaped space, for causal masking)
        is_causal: Whether to apply causal masking

    Returns:
        Block-level attention sums [batch, heads, q_blocks, k_chunk_blocks]
    """
    batch_size, num_heads, q_len, k_len = attn_weights_slice.shape

    assert q_len % reshaped_block_size == 0
    assert k_len % segment_size == 0
    assert segment_size % reshaped_block_size == 0
    assert attn_weights_slice.stride(-1) == 1

    output = torch.empty(
        (batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size),
        dtype=attn_weights_slice.dtype,
        device=attn_weights_slice.device,
    )

    grid = (q_len // reshaped_block_size, num_heads, batch_size)

    softmax_normalize_block_sum_kernel[grid](
        attn_weights_slice,
        output,
        m_global,
        l_global,
        scale,
        attn_weights_slice.stride(0),
        attn_weights_slice.stride(1),
        attn_weights_slice.stride(2),
        output.stride(0),
        output.stride(1),
        output.stride(2),
        m_global.stride(0),
        m_global.stride(1),
        real_q_len,
        k_len,
        chunk_start,
        kv_offset,
        segment_size,
        reshaped_block_size,
        is_causal,
    )

    return output


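Together with `merge_softmax_stats`, the two wrappers implement a three-phase pipeline: per-chunk stats, host-side merge, then globally normalized block sums. A minimal pure-Python reference of the same three phases on one query row (hypothetical sizes, no torch/Triton), checking that chunked block sums match a direct softmax:

```python
import math

row = [0.5, 2.0, -1.0, 0.7, 1.1, -0.3, 0.2, 1.9]  # one row of raw scores
CHUNK, BLOCK = 4, 2                                # KV chunk / block sizes

chunks = [row[i:i + CHUNK] for i in range(0, len(row), CHUNK)]

# Phase 1: per-chunk partial stats (m, l), base 2
stats = [(max(c), sum(2.0 ** (x - max(c)) for x in c)) for c in chunks]

# Phase 2: merge stats across chunks on the host
m_g, l_g = stats[0]
for m_c, l_c in stats[1:]:
    m_new = max(m_g, m_c)
    l_g = l_g * 2.0 ** (m_g - m_new) + l_c * 2.0 ** (m_c - m_new)
    m_g = m_new

# Phase 3: normalize each chunk with the global stats, then block-sum
block_sums = []
for c in chunks:
    probs = [2.0 ** (x - m_g) / l_g for x in c]
    block_sums += [sum(probs[i:i + BLOCK]) for i in range(0, len(probs), BLOCK)]

# Reference: direct base-2 softmax over the whole row, block-summed
m_ref = max(row)
denom = sum(2.0 ** (x - m_ref) for x in row)
ref = [2.0 ** (x - m_ref) / denom for x in row]
ref_sums = [sum(ref[i:i + BLOCK]) for i in range(0, len(ref), BLOCK)]

assert all(math.isclose(a, b, rel_tol=1e-12) for a, b in zip(block_sums, ref_sums))
assert math.isclose(sum(block_sums), 1.0, rel_tol=1e-12)
```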
def flat_group_gemm_fuse_reshape(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
265
tests/test_xattn_estimate_alignment.py
Normal file
@@ -0,0 +1,265 @@
"""
Test: verify consistency between xattn_estimate and the KV chunking kernels.

Using real KV cache data, compare:
1. xattn_estimate (high-level API)
2. Three-phase KV chunking (softmax_compute_partial_stats + merge + normalize)

Three-phase KV chunking flow:
1. softmax_compute_partial_stats: compute (m, l) for each KV chunk
2. merge_softmax_stats: merge the stats from all chunks on the host
3. softmax_normalize_and_block_sum: normalize with the global stats

Usage:
    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
        python tests/test_xattn_estimate_alignment.py
"""
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")

import torch
import math
from nanovllm.ops.xattn import (
    xattn_estimate,
    flat_group_gemm_fuse_reshape,
    softmax_compute_partial_stats,
    softmax_normalize_and_block_sum,
    merge_softmax_stats,
    find_blocks_chunked,
)

# ============================================================
# Configuration
# ============================================================
DATA_FILE = "/home/zijie/Code/nano-vllm/results/kvcache/qkv_32485.pt"
BSA_BLOCK_SIZE = 128
CHUNK_SIZE = 16384  # xattn_estimate default
USE_SAVED_PARAMS = True  # set to False to use the fallback values below

device = "cuda"

# ============================================================
# Step 1: load real data
# ============================================================
print("=" * 60)
print("Step 1: load real KV cache data")
print("=" * 60)

data = torch.load(DATA_FILE, map_location="cpu")
Q = data["query"].to(device)  # [1, 32, seq_len, 128]
K = data["key"].to(device)    # [1, 32, seq_len, 128]

batch_size, num_heads, seq_len, head_dim = Q.shape

# Read the parameters recorded alongside the saved data
if USE_SAVED_PARAMS:
    STRIDE = data["stride"]
    THRESHOLD = data["threshold"][0].item() if isinstance(data["threshold"], torch.Tensor) else data["threshold"]
else:
    STRIDE = 8
    THRESHOLD = 0.9

print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"Data layer_id: {data['layer_id']}, saved density: {data['density']:.4f}")
print(f"Parameters: STRIDE={STRIDE}, THRESHOLD={THRESHOLD}, CHUNK_SIZE={CHUNK_SIZE}")
print()

# ============================================================
# Step 2: run the high-level xattn_estimate API
# ============================================================
print("=" * 60)
print("Step 2: call xattn_estimate (high-level API)")
print("=" * 60)

attn_sums_api, mask_api = xattn_estimate(
    Q, K,
    block_size=BSA_BLOCK_SIZE,
    stride=STRIDE,
    threshold=THRESHOLD,
    chunk_size=CHUNK_SIZE,
    causal=True,
)

# Crop to the valid (unpadded) region
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
mask_api_valid = mask_api[:, :, :q_blocks, :k_blocks]

# Compute density (causal)
causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=device, dtype=torch.bool))
total_api = causal_mask.sum().item() * batch_size * num_heads
selected_api = (mask_api_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_api = selected_api / total_api

print(f"mask_api shape (padded): {mask_api.shape}")
print(f"mask_api_valid shape: {mask_api_valid.shape}")
print(f"[xattn_estimate] density: {density_api:.6f} (selected={selected_api}, total={total_api})")
print()

# ============================================================
# Step 3: three-phase KV chunking
# ============================================================
print("=" * 60)
print("Step 3: three-phase KV chunking")
print("=" * 60)
print(" 1) compute partial stats for each KV chunk")
print(" 2) merge the stats on the host")
print(" 3) normalize with the global stats and compute block sums")
print()

# Padding parameters
k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
kv_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE

k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE

reshaped_chunk_size = CHUNK_SIZE // STRIDE
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
k_reshaped_num_to_pad = k_num_to_pad // STRIDE
num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
kv_reshaped_chunk_size = CHUNK_SIZE // STRIDE

print(f"seq_len: {seq_len}, q_chunk_num: {q_chunk_num}, kv_chunk_num: {kv_chunk_num}")
print()

# Padding
if k_num_to_pad > 0:
    K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
else:
    K_padded = K

if q_num_to_pad > 0:
    Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0)
else:
    Q_padded = Q

# Softmax scale (log2(e) factor folded in, since the kernels use exp2)
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm

simple_mask_list = []

for q_chunk_idx in range(q_chunk_num):
    q_start = q_chunk_idx * reshaped_chunk_size * STRIDE
    q_end = q_start + reshaped_chunk_size * STRIDE
    Q_chunk = Q_padded[:, :, q_start:q_end, :]

    chunk_start = (k_block_num - q_block_num) * reshaped_block_size + q_chunk_idx * reshaped_chunk_size
    chunk_end = chunk_start + reshaped_chunk_size

    # Phase 1: compute partial stats and raw scores for each KV chunk
    m_chunks = []
    l_chunks = []
    attn_weights_chunks = []

    for kv_chunk_idx in range(kv_chunk_num):
        kv_start = kv_chunk_idx * CHUNK_SIZE
        kv_end = kv_start + CHUNK_SIZE
        K_chunk = K_padded[:, :, kv_start:kv_end, :]

        # KV offset in reshaped space
        kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size

        # Raw attention scores
        attn_weights_kv = flat_group_gemm_fuse_reshape(
            Q_chunk, K_chunk, STRIDE,
            chunk_start=chunk_start,
            chunk_end=chunk_end,
            is_causal=False,  # K is incomplete here, so causal masking cannot be applied yet
        )
        attn_weights_chunks.append(attn_weights_kv)

        # Partial stats (with causal mask)
        m_partial, l_partial = softmax_compute_partial_stats(
            attn_weights_kv,
            reshaped_block_size,
            min(4096, reshaped_block_size),
            scale,
            chunk_start=chunk_start,
            kv_offset=kv_offset_reshaped,
            is_causal=True,
        )
        m_chunks.append(m_partial)
        l_chunks.append(l_partial)

    # Phase 2: merge the stats on the host
    m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)

    # Phase 3: normalize with the global stats and compute block sums
    attn_sum_per_kv = []
    for kv_chunk_idx, attn_weights_kv in enumerate(attn_weights_chunks):
        kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
        attn_sum_kv = softmax_normalize_and_block_sum(
            attn_weights_kv,
            m_global,
            l_global,
            reshaped_block_size,
            min(4096, reshaped_block_size),
            chunk_start=chunk_start,
            real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
            scale=scale,
            kv_offset=kv_offset_reshaped,
            is_causal=True,
        )
        attn_sum_per_kv.append(attn_sum_kv)

    # Concatenate the block sums from all KV chunks
    attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)

    # Select blocks
    simple_mask = find_blocks_chunked(
        attn_sum_concat,
        current_index=k_block_num - q_block_num + q_chunk_idx * num_blocks_per_chunk,
        threshold=THRESHOLD,
        num_to_choose=None,
        decoding=False,
        mode="prefill",
        causal=True,
    )
    simple_mask_list.append(simple_mask)

    print(f"  Q chunk {q_chunk_idx}: merged {kv_chunk_num} KV chunks, attn_sum shape={attn_sum_concat.shape}")

mask_kv_chunking = torch.cat(simple_mask_list, dim=2)

# Apply the same causal-mask post-processing as xattn_estimate (xattn.py lines 1300-1306)
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
    torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=device), diagonal=0),
    mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
    False,
)

mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_kv = selected_kv / total_api

print()
print(f"[KV chunking] density: {density_kv:.6f} (selected={selected_kv}, total={total_api})")
print()

# ============================================================
# Step 4: compare results
# ============================================================
print("=" * 60)
print("Step 4: compare results")
print("=" * 60)
print()

mask_total = mask_api_valid.numel()
mask_diff = (mask_api_valid != mask_kv_chunking_valid).sum().item()

print("| Method | density | diff vs API | mask diff |")
print("|--------|---------|-------------|-----------|")
print(f"| xattn_estimate API | {density_api:.6f} | - | - |")
print(f"| KV chunking | {density_kv:.6f} | {abs(density_api - density_kv):.6f} | {100*mask_diff/mask_total:.4f}% |")
print()

if abs(density_api - density_kv) < 1e-6 and mask_diff / mask_total < 0.001:
    print("test_xattn_estimate_alignment: PASSED")
else:
    print("test_xattn_estimate_alignment: FAILED")
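Both tests report density as selected blocks over the causal lower triangle of the block grid. That accounting can be sanity-checked in plain Python on a tiny hypothetical 3x3 block grid (no torch):

```python
q_blocks = k_blocks = 3
# Lower-triangular (causal) block positions: (0,0),(1,0),(1,1),(2,0),(2,1),(2,2)
causal = [[r >= c for c in range(k_blocks)] for r in range(q_blocks)]
total = sum(sum(row) for row in causal)

# A block mask keeping the diagonal plus one off-diagonal block
mask = [[False] * k_blocks for _ in range(q_blocks)]
for i in range(q_blocks):
    mask[i][i] = True
mask[2][0] = True

# Only positions inside the causal triangle count toward density
selected = sum(mask[r][c] for r in range(q_blocks) for c in range(k_blocks) if causal[r][c])
density = selected / total
assert total == 6 and selected == 4
assert abs(density - 4 / 6) < 1e-12
```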
246
tests/test_xattn_kv_chunking_batch.py
Normal file
@@ -0,0 +1,246 @@
"""
Test: batch-verify consistency between xattn_estimate and the KV chunking kernels.

Runs over every saved QKV file under results/kvcache.

Usage:
    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
        python tests/test_xattn_kv_chunking_batch.py
"""
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")

import os
import glob
import torch
import math
from nanovllm.ops.xattn import (
    xattn_estimate,
    flat_group_gemm_fuse_reshape,
    softmax_compute_partial_stats,
    softmax_normalize_and_block_sum,
    merge_softmax_stats,
    find_blocks_chunked,
)

# ============================================================
# Configuration
# ============================================================
DATA_DIR = "/home/zijie/Code/nano-vllm/results/kvcache"
BSA_BLOCK_SIZE = 128
CHUNK_SIZE = 16384

device = "cuda"


def test_single_file(data_file: str) -> dict:
    """Run the alignment check for a single kvcache file."""
    data = torch.load(data_file, map_location="cpu")
    Q = data["query"].to(device)
    K = data["key"].to(device)

    batch_size, num_heads, seq_len, head_dim = Q.shape
    STRIDE = data["stride"]
    THRESHOLD = data["threshold"][0].item() if isinstance(data["threshold"], torch.Tensor) else data["threshold"]

    # ========== xattn_estimate API ==========
    attn_sums_api, mask_api = xattn_estimate(
        Q, K,
        block_size=BSA_BLOCK_SIZE,
        stride=STRIDE,
        threshold=THRESHOLD,
        chunk_size=CHUNK_SIZE,
        causal=True,
    )

    q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
    k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
    mask_api_valid = mask_api[:, :, :q_blocks, :k_blocks]

    causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=device, dtype=torch.bool))
    total_api = causal_mask.sum().item() * batch_size * num_heads
    selected_api = (mask_api_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
    density_api = selected_api / total_api

    # ========== Three-phase KV chunking ==========
    k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
    q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
    q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
    kv_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE

    k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
    q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE

    reshaped_chunk_size = CHUNK_SIZE // STRIDE
    reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
    k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
    k_reshaped_num_to_pad = k_num_to_pad // STRIDE
    num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
    kv_reshaped_chunk_size = CHUNK_SIZE // STRIDE

    if k_num_to_pad > 0:
        K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
    else:
        K_padded = K

    if q_num_to_pad > 0:
        Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0)
    else:
        Q_padded = Q

    norm = 1.0
    scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm

    simple_mask_list = []

    for q_chunk_idx in range(q_chunk_num):
        q_start = q_chunk_idx * reshaped_chunk_size * STRIDE
        q_end = q_start + reshaped_chunk_size * STRIDE
        Q_chunk = Q_padded[:, :, q_start:q_end, :]

        chunk_start = (k_block_num - q_block_num) * reshaped_block_size + q_chunk_idx * reshaped_chunk_size
        chunk_end = chunk_start + reshaped_chunk_size

        m_chunks = []
        l_chunks = []
        attn_weights_chunks = []

        for kv_chunk_idx in range(kv_chunk_num):
            kv_start = kv_chunk_idx * CHUNK_SIZE
            kv_end = kv_start + CHUNK_SIZE
            K_chunk = K_padded[:, :, kv_start:kv_end, :]
            kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size

            attn_weights_kv = flat_group_gemm_fuse_reshape(
                Q_chunk, K_chunk, STRIDE,
                chunk_start=chunk_start,
                chunk_end=chunk_end,
                is_causal=False,
            )
            attn_weights_chunks.append(attn_weights_kv)

            m_partial, l_partial = softmax_compute_partial_stats(
                attn_weights_kv,
                reshaped_block_size,
                min(4096, reshaped_block_size),
                scale,
                chunk_start=chunk_start,
                kv_offset=kv_offset_reshaped,
                is_causal=True,
            )
            m_chunks.append(m_partial)
            l_chunks.append(l_partial)

        m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)

        attn_sum_per_kv = []
        for kv_chunk_idx, attn_weights_kv in enumerate(attn_weights_chunks):
            kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
            attn_sum_kv = softmax_normalize_and_block_sum(
                attn_weights_kv,
                m_global,
                l_global,
                reshaped_block_size,
                min(4096, reshaped_block_size),
                chunk_start=chunk_start,
                real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
                scale=scale,
                kv_offset=kv_offset_reshaped,
                is_causal=True,
            )
            attn_sum_per_kv.append(attn_sum_kv)

        attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)

        simple_mask = find_blocks_chunked(
            attn_sum_concat,
            current_index=k_block_num - q_block_num + q_chunk_idx * num_blocks_per_chunk,
            threshold=THRESHOLD,
            num_to_choose=None,
            decoding=False,
            mode="prefill",
            causal=True,
        )
        simple_mask_list.append(simple_mask)

    mask_kv_chunking = torch.cat(simple_mask_list, dim=2)

    # Apply the same causal-mask post-processing as xattn_estimate (xattn.py lines 1300-1306)
    mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
        torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=device), diagonal=0),
        mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
        False,
    )

    mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
    selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
    density_kv = selected_kv / total_api

    mask_total = mask_api_valid.numel()
    mask_diff = (mask_api_valid != mask_kv_chunking_valid).sum().item()
    mask_diff_pct = 100 * mask_diff / mask_total

    return {
        "seq_len": seq_len,
        "stride": STRIDE,
        "threshold": THRESHOLD,
        "kv_chunks": kv_chunk_num,
        "density_api": density_api,
        "density_kv": density_kv,
        "density_diff": abs(density_api - density_kv),
        "mask_diff_pct": mask_diff_pct,
        "passed": abs(density_api - density_kv) < 1e-6 and mask_diff_pct < 0.01,
    }


def main():
    files = sorted(glob.glob(os.path.join(DATA_DIR, "qkv_*.pt")))

    print("=" * 80)
    print("XAttention KV Chunking Alignment Test")
    print("=" * 80)
    print()

    results = []
    for f in files:
        fname = os.path.basename(f)
        print(f"Testing {fname}...", end=" ", flush=True)
        try:
            r = test_single_file(f)
            results.append(r)
            status = "✓ PASS" if r["passed"] else "✗ FAIL"
            print(f"{status} (seq_len={r['seq_len']}, kv_chunks={r['kv_chunks']})")
        except Exception as e:
            print(f"✗ ERROR: {e}")
            results.append({"file": fname, "error": str(e)})

    print()
    print("=" * 80)
    print("Results Summary")
    print("=" * 80)
    print()
    print("| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff | mask_diff | status |")
    print("|---------|--------|-----------|-----------|-------------|------------|------|-----------|--------|")

    all_passed = True
    for r in results:
        if "error" in r:
            print(f"| ERROR | - | - | - | - | - | - | - | {r['error'][:20]} |")
            all_passed = False
        else:
            status = "PASS" if r["passed"] else "FAIL"
            if not r["passed"]:
                all_passed = False
            print(f"| {r['seq_len']:>7} | {r['stride']:>6} | {r['threshold']:.2f} | {r['kv_chunks']:>9} | "
                  f"{r['density_api']:.6f} | {r['density_kv']:.6f} | {r['density_diff']:.6f} | "
                  f"{r['mask_diff_pct']:.4f}% | {status} |")

    print()
    if all_passed:
        print("test_xattn_kv_chunking_batch: ALL PASSED")
    else:
        print("test_xattn_kv_chunking_batch: SOME FAILED")


if __name__ == "__main__":
    main()