diff --git a/CLAUDE.md b/CLAUDE.md
index 84c5106..047a211 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -16,7 +16,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
 | [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention algorithm deep dive: stride reshape, Triton kernels, BSA dependencies, block selection algorithm |
 | [`docs/xattn_kernels_guide.md`](docs/xattn_kernels_guide.md) | XAttention Triton kernels: flat_group_gemm (anti-diagonal summation), softmax_fuse_block_sum (block aggregation) |
-| [`docs/xattn_kv_chunking_kernels.md`](docs/xattn_kv_chunking_kernels.md) | XAttention KV chunking: three-stage softmax (partial stats + merge + normalize), supports chunking along the KV dimension |
+| [`docs/xattn_kv_chunking_kernels.md`](docs/xattn_kv_chunking_kernels.md) | XAttention KV chunking: three-stage softmax, storage overhead analysis (O(S) vs O(S²)), peak-memory optimization (8x), independent Q/KV chunking |
 | [`docs/xattn_chunked_prefill.md`](docs/xattn_chunked_prefill.md) | XAttention chunked prefill: API, usage, consistency requirements |
 | [`docs/xattn_bsa_policy_design.md`](docs/xattn_bsa_policy_design.md) | XAttention BSA policy: algorithm design, performance benchmarks (128K), memory management, density statistics |
 | [`docs/xattn_density_benchmark.md`](docs/xattn_density_benchmark.md) | 📊 XAttention density benchmark: 4K-32K context, stride parameters, per-layer density analysis |
diff --git a/docs/xattn_kv_chunking_kernels.md b/docs/xattn_kv_chunking_kernels.md
index 32c5307..4bbc7ee 100644
--- a/docs/xattn_kv_chunking_kernels.md
+++ b/docs/xattn_kv_chunking_kernels.md
@@ -18,6 +18,50 @@ softmax(x_i) = exp(x_i) / Σ_j exp(x_j)
 
 Splitting the softmax computation into three stages makes KV chunking correct:
 
+```
+┌─────────────────────────────────────────────────────────────────┐
+│                      Three-Stage Pipeline                       │
+├─────────────────────────────────────────────────────────────────┤
+│                                                                 │
+│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐          │
+│  │ KV Chunk 0  │    │ KV Chunk 1  │    │ KV Chunk N  │          │
+│  │ attn_scores │    │ attn_scores │    │ attn_scores │          │
+│  └──────┬──────┘    └──────┬──────┘    └──────┬──────┘          │
+│         │                  │                  │                 │
+│         ▼                  ▼                  ▼                 │
+│  ┌─────────────────────────────────────────────────┐            │
+│  │ Stage 1: softmax_compute_partial_stats          │            │
+│  │ compute (m_partial, l_partial) per chunk        │            │
+│  └─────────────────────────────────────────────────┘            │
+│         │                  │                  │                 │
+│         ▼                  ▼                  ▼                 │
+│     (m_0, l_0)        (m_1, l_1)         (m_N, l_N)             │
+│         │                  │                  │                 │
+│         └──────────────────┼──────────────────┘                 │
+│                            ▼                                    │
+│  ┌─────────────────────────────────────────────────┐            │
+│  │ Stage 2: merge_softmax_stats                    │            │
+│  │ merge on host → (m_global, l_global)            │            │
+│  └─────────────────────────────────────────────────┘            │
+│         ┌──────────────────┼──────────────────┐                 │
+│         ▼                  ▼                  ▼                 │
+│  ┌─────────────────────────────────────────────────┐            │
+│  │ Stage 3: softmax_normalize_and_block_sum        │            │
+│  │ normalize with global stats, compute block sums │            │
+│  └─────────────────────────────────────────────────┘            │
+│         │                  │                  │                 │
+│         ▼                  ▼                  ▼                 │
+│   block_sums_0       block_sums_1       block_sums_N            │
+│         │                  │                  │                 │
+│         └──────────────────┴──────────────────┘                 │
+│                            ▼                                    │
+│                  torch.cat → final mask                         │
+│                                                                 │
+└─────────────────────────────────────────────────────────────────┘
+```
+
 ### Stage 1: `softmax_compute_partial_stats`
 
 Computes the partial statistics for each KV chunk:
@@ -100,6 +144,115 @@ softmax(x_i) = exp(x_i - m_global) / l_global  # correct!
 
 2. **`softmax_normalize_block_sum_kernel`**: also uses `kv_offset`, outputting 0 for positions past the causal boundary
 
+## Storage Overhead Analysis
+
+### Symbols
+
+| Symbol | Meaning | Typical value |
+|--------|---------|---------------|
+| S | seq_len | 64K |
+| B | batch_size | 1 |
+| H | num_heads | 32 |
+| D | head_dim | 128 |
+| T | stride | 4-8 |
+| C | chunk_size | 16K |
+| n | num_kv_chunks = ceil(S/C) | 4 |
+
+### Baseline (no KV chunking)
+
+**Peak memory for attn_weights**:
+```
+[B, H, S/T, S/T] × 4 bytes = B × H × (S/T)² × 4
+
+Example: S=64K, T=4, B=1, H=32
+= 1 × 32 × 16384² × 4 = 32 GB
+```
+
+### Extra storage with KV chunking
+
+#### 1. Partial stats (per KV chunk)
+
+```
+m_partial: [B, H, C/T] × 4 bytes
+l_partial: [B, H, C/T] × 4 bytes
+
+per chunk = 2 × B × H × (C/T) × 4
+          = 2 × 1 × 32 × 4096 × 4 = 1 MB
+```
+
+#### 2. Global stats
+
+```
+m_global: [B, H, S/T] × 4 bytes
+l_global: [B, H, S/T] × 4 bytes
+
+= 2 × B × H × (S/T) × 4
+= 2 × 1 × 32 × 16384 × 4 = 4 MB
+```
+
+#### 3. Total extra overhead
+
+```
+total_extra = n × partial_stats + global_stats
+            = 4 × 1 MB + 4 MB = 8 MB
+```
+
+### Storage overhead vs. seqlen
+
+| seqlen | num_chunks | baseline attn_weights | extra stats | ratio |
+|--------|------------|-----------------------|-------------|-------|
+| 16K | 1 | 2 GB | 2 MB | 0.1% |
+| 32K | 2 | 8 GB | 4 MB | 0.05% |
+| 64K | 4 | 32 GB | 8 MB | 0.025% |
+| 128K | 8 | 128 GB | 16 MB | 0.012% |
+
+### Complexity
+
+| Storage component | Complexity | Notes |
+|-------------------|------------|-------|
+| Baseline attn_weights | O(S²) | quadratic growth |
+| Partial/global stats | O(S) | linear growth |
+| **Relative overhead** | O(1/S) | **shrinks as seqlen grows** |
+
+### Peak-memory optimization
+
+The main benefit of KV chunking is that **peak memory** drops from O(S²) to O(S×C):
+
+```
+Baseline:    O(B × H × (S/T)²)           # full attn_weights
+KV chunking: O(B × H × (S/T) × (C/T))    # one chunk at a time
+```
+
+For S=128K, C=16K:
+- Baseline peak: ~128 GB
+- KV chunking peak: ~16 GB (an **8x** reduction)
+
+## Supporting Different Q/KV Chunk Sizes
+
+The three-stage pipeline supports different chunk sizes for Q and KV:
+
+```python
+q_chunk_size = 8192    # Q chunk size
+kv_chunk_size = 16384  # KV chunk size
+
+for q_chunk_idx in range(q_chunk_num):
+    Q_chunk = Q[:, :, q_start:q_end, :]          # [B, H, q_chunk_size, D]
+
+    for kv_chunk_idx in range(kv_chunk_num):
+        K_chunk = K[:, :, kv_start:kv_end, :]    # [B, H, kv_chunk_size, D]
+        # ... three-stage processing
+```
+
+### Test Results
+
+| Config | seq_len | Q chunks | KV chunks | density | aligned |
+|--------|---------|----------|-----------|---------|---------|
+| Q=16K, KV=16K | 64891 | 4 | 4 | 0.1117 | ✓ 100% |
+| Q=8K, KV=16K | 64891 | 8 | 4 | 0.1112 | ✓ 100% |
+| Q=16K, KV=8K | 64891 | 4 | 8 | 0.1117 | ✓ 100% |
+| Q=8K, KV=8K | 64891 | 8 | 8 | 0.1112 | ✓ 100% |
+| Q=4K, KV=16K | 64891 | 16 | 4 | 0.1109 | ✓ 100% |
+
 ## API Reference
 
 ### `softmax_compute_partial_stats`
@@ -213,23 +366,35 @@ for q_chunk_idx in range(q_chunk_num):
 |------|---------|-----------------|
 | Number of kernels | 1 | 2 (stats + normalize) |
 | Raw-score reads | 2 | 2 |
-| Extra memory | 0 | O(batch × heads × q_len × 2) for (m, l) |
+| Extra memory | 0 | O(B × H × S/T × 2) for (m, l) |
 | Host computation | none | merge stats (lightweight) |
-| **Peak GPU memory** | O(q_len × k_full_len) | **O(q_len × k_chunk_len)** |
+| **Peak GPU memory** | O(S²) | **O(S × C)** |
 
-## Verification
+## Verification Tests
 
-The test script `tests/test_xattn_estimate_alignment.py` verifies that the KV chunking implementation is consistent with the original `xattn_estimate` API:
+### Batch Test Results
+
+The test script `tests/test_xattn_kv_chunking_batch.py` verifies consistency across different seqlens:
 
 ```
-| Method | density | diff vs. API | mask diff |
-|--------|---------|--------------|-----------|
-| xattn_estimate API | 0.159023 | - | - |
-| KV chunking | 0.159023 | 0.000000 | 0.0044% |
+| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff | mask_diff | status |
+|---------|--------|-----------|-----------|-------------|------------|----------|-----------|--------|
+| 3688 | 4 | 0.90 | 1 | 0.383405 | 0.383405 | 0.000000 | 0.0000% | PASS |
+| 7888 | 4 | 0.90 | 1 | 0.290611 | 0.290611 | 0.000000 | 0.0000% | PASS |
+| 15685 | 4 | 0.90 | 1 | 0.197724 | 0.197724 | 0.000000 | 0.0000% | PASS |
+| 32485 | 4 | 0.90 | 2 | 0.159023 | 0.159023 | 0.000000 | 0.0000% | PASS |
+| 64891 | 4 | 0.90 | 4 | 0.111656 | 0.111656 | 0.000000 | 0.0000% | PASS |
 ```
 
+### Key Conclusions
+
+1. **Mathematical equivalence**: density_diff = 0.000000 for all tests
+2. **Exact mask alignment**: mask_diff = 0.0000% for all tests
+3. **Arbitrary Q/KV chunk size combinations are supported**
+
 ## Related Files
 
 - `nanovllm/ops/xattn.py`: kernel implementation
-- `tests/test_xattn_estimate_alignment.py`: verification test
+- `tests/test_xattn_estimate_alignment.py`: single-file verification test
+- `tests/test_xattn_kv_chunking_batch.py`: batch verification test
 - `docs/xattn_kernels_guide.md`: original kernel documentation
diff --git a/tests/test_xattn_estimate_alignment.py b/tests/test_xattn_estimate_alignment.py
index fcc7232..20dc0b9 100644
--- a/tests/test_xattn_estimate_alignment.py
+++ b/tests/test_xattn_estimate_alignment.py
@@ -226,6 +226,14 @@ for q_chunk_idx in range(q_chunk_num):
     print(f"  Q chunk {q_chunk_idx}: merged {kv_chunk_num} KV chunks, attn_sum shape={attn_sum_concat.shape}")
 
 mask_kv_chunking = torch.cat(simple_mask_list, dim=2)
+
+# Apply the same causal-mask post-processing as xattn_estimate (xattn.py, lines 1300-1306)
+mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
+    torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=device), diagonal=0),
+    mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
+    False,
+)
+
 mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
 selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
 density_kv = selected_kv / total_api
diff --git a/tests/test_xattn_kv_chunking_batch.py b/tests/test_xattn_kv_chunking_batch.py
new file mode 100644
index 0000000..60c8288
--- /dev/null
+++ b/tests/test_xattn_kv_chunking_batch.py
@@ -0,0 +1,246 @@
+"""
+Test: batch-verify consistency between xattn_estimate and the KV chunking kernels
+
+Runs against every saved QKV dump under results/kvcache
+
+Usage:
+    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
+        python tests/test_xattn_kv_chunking_batch.py
+"""
+import sys
+sys.path.insert(0, "/home/zijie/Code/nano-vllm")
+
+import os
+import glob
+import torch
+import math
+from nanovllm.ops.xattn import (
+    xattn_estimate,
+    flat_group_gemm_fuse_reshape,
+    softmax_compute_partial_stats,
+    softmax_normalize_and_block_sum,
+    merge_softmax_stats,
+    find_blocks_chunked,
+)
+
+# ============================================================
+# Configuration
+# ============================================================
+DATA_DIR = "/home/zijie/Code/nano-vllm/results/kvcache"
+BSA_BLOCK_SIZE = 128
+CHUNK_SIZE = 16384
+
+device = "cuda"
+
+
+def test_single_file(data_file: str) -> dict:
+    """Test a single kvcache file."""
+    data = torch.load(data_file, map_location="cpu")
+    Q = data["query"].to(device)
+    K = data["key"].to(device)
+
+    batch_size, num_heads, seq_len, head_dim = Q.shape
+    STRIDE = data["stride"]
+    THRESHOLD = data["threshold"][0].item() if isinstance(data["threshold"], torch.Tensor) else data["threshold"]
+
+    # ========== xattn_estimate API ==========
+    attn_sums_api, mask_api = xattn_estimate(
+        Q, K,
+        block_size=BSA_BLOCK_SIZE,
+        stride=STRIDE,
+        threshold=THRESHOLD,
+        chunk_size=CHUNK_SIZE,
+        causal=True,
+    )
+
+    q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
+    k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
+    mask_api_valid = mask_api[:, :, :q_blocks, :k_blocks]
+
+    causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=device, dtype=torch.bool))
+    total_api = causal_mask.sum().item() * batch_size * num_heads
+    selected_api = (mask_api_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
+    density_api = selected_api / total_api
+
+    # ========== Three-stage KV chunking ==========
+    k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
+    q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
+    q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
+    kv_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE
+
+    k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
+    q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE
+
+    reshaped_chunk_size = CHUNK_SIZE // STRIDE
+    reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
+    k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
+    k_reshaped_num_to_pad = k_num_to_pad // STRIDE
+    num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
+    kv_reshaped_chunk_size = CHUNK_SIZE // STRIDE
+
+    if k_num_to_pad > 0:
+        K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
+    else:
+        K_padded = K
+
+    if q_num_to_pad > 0:
+        Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0)
+    else:
+        Q_padded = Q
+
+    norm = 1.0
+    scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm
+
+    simple_mask_list = []
+
+    for q_chunk_idx in range(q_chunk_num):
+        q_start = q_chunk_idx * reshaped_chunk_size * STRIDE
+        q_end = q_start + reshaped_chunk_size * STRIDE
+        Q_chunk = Q_padded[:, :, q_start:q_end, :]
+
+        chunk_start = (k_block_num - q_block_num) * reshaped_block_size + q_chunk_idx * reshaped_chunk_size
+        chunk_end = chunk_start + reshaped_chunk_size
+
+        m_chunks = []
+        l_chunks = []
+        attn_weights_chunks = []
+
+        for kv_chunk_idx in range(kv_chunk_num):
+            kv_start = kv_chunk_idx * CHUNK_SIZE
+            kv_end = kv_start + CHUNK_SIZE
+            K_chunk = K_padded[:, :, kv_start:kv_end, :]
+            kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
+
+            attn_weights_kv = flat_group_gemm_fuse_reshape(
+                Q_chunk, K_chunk, STRIDE,
+                chunk_start=chunk_start,
+                chunk_end=chunk_end,
+                is_causal=False,
+            )
+            attn_weights_chunks.append(attn_weights_kv)
+
+            m_partial, l_partial = softmax_compute_partial_stats(
+                attn_weights_kv,
+                reshaped_block_size,
+                min(4096, reshaped_block_size),
+                scale,
+                chunk_start=chunk_start,
+                kv_offset=kv_offset_reshaped,
+                is_causal=True,
+            )
+            m_chunks.append(m_partial)
+            l_chunks.append(l_partial)
+
+        m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
+
+        attn_sum_per_kv = []
+        for kv_chunk_idx, attn_weights_kv in enumerate(attn_weights_chunks):
+            kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
+            attn_sum_kv = softmax_normalize_and_block_sum(
+                attn_weights_kv,
+                m_global,
+                l_global,
+                reshaped_block_size,
+                min(4096, reshaped_block_size),
+                chunk_start=chunk_start,
+                real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
+                scale=scale,
+                kv_offset=kv_offset_reshaped,
+                is_causal=True,
+            )
+            attn_sum_per_kv.append(attn_sum_kv)
+
+        attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
+
+        simple_mask = find_blocks_chunked(
+            attn_sum_concat,
+            current_index=k_block_num - q_block_num + q_chunk_idx * num_blocks_per_chunk,
+            threshold=THRESHOLD,
+            num_to_choose=None,
+            decoding=False,
+            mode="prefill",
+            causal=True,
+        )
+        simple_mask_list.append(simple_mask)
+
+    mask_kv_chunking = torch.cat(simple_mask_list, dim=2)
+
+    # Apply the same causal-mask post-processing as xattn_estimate (xattn.py, lines 1300-1306)
+    mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
+        torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=device), diagonal=0),
+        mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
+        False,
+    )
+
+    mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
+    selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
+    density_kv = selected_kv / total_api
+
+    mask_total = mask_api_valid.numel()
+    mask_diff = (mask_api_valid != mask_kv_chunking_valid).sum().item()
+    mask_diff_pct = 100 * mask_diff / mask_total
+
+    return {
+        "seq_len": seq_len,
+        "stride": STRIDE,
+        "threshold": THRESHOLD,
+        "kv_chunks": kv_chunk_num,
+        "density_api": density_api,
+        "density_kv": density_kv,
+        "density_diff": abs(density_api - density_kv),
+        "mask_diff_pct": mask_diff_pct,
+        "passed": abs(density_api - density_kv) < 1e-6 and mask_diff_pct < 0.01,
+    }
+
+
+def main():
+    files = sorted(glob.glob(os.path.join(DATA_DIR, "qkv_*.pt")))
+
+    print("=" * 80)
+    print("XAttention KV Chunking Alignment Test")
+    print("=" * 80)
+    print()
+
+    results = []
+    for f in files:
+        fname = os.path.basename(f)
+        print(f"Testing {fname}...", end=" ", flush=True)
+        try:
+            r = test_single_file(f)
+            results.append(r)
+            status = "✓ PASS" if r["passed"] else "✗ FAIL"
+            print(f"{status} (seq_len={r['seq_len']}, kv_chunks={r['kv_chunks']})")
+        except Exception as e:
+            print(f"✗ ERROR: {e}")
+            results.append({"file": fname, "error": str(e)})
+
+    print()
+    print("=" * 80)
+    print("Results Summary")
+    print("=" * 80)
+    print()
+    print("| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff | mask_diff | status |")
+    print("|---------|--------|-----------|-----------|-------------|------------|------|-----------|--------|")
+
+    all_passed = True
+    for r in results:
+        if "error" in r:
+            print(f"| ERROR | - | - | - | - | - | - | - | {r['error'][:20]} |")
+            all_passed = False
+        else:
+            status = "PASS" if r["passed"] else "FAIL"
+            if not r["passed"]:
+                all_passed = False
+            print(f"| {r['seq_len']:>7} | {r['stride']:>6} | {r['threshold']:.2f} | {r['kv_chunks']:>9} | "
+                  f"{r['density_api']:.6f} | {r['density_kv']:.6f} | {r['density_diff']:.6f} | "
+                  f"{r['mask_diff_pct']:.4f}% | {status} |")
+
+    print()
+    if all_passed:
+        print("test_xattn_kv_chunking_batch: ALL PASSED")
+    else:
+        print("test_xattn_kv_chunking_batch: SOME FAILED")
+
+
+if __name__ == "__main__":
+    main()