Files
nano-vllm/docs/xattn_kv_chunking_kernels.md
Zijie Tian 6e34efd58a 📝 docs: add storage overhead analysis and batch tests for KV chunking
- Update xattn_kv_chunking_kernels.md with:
  - Detailed storage overhead analysis (O(S) vs O(S²))
  - Peak memory optimization (8x reduction)
  - Support for independent Q/KV chunk sizes
  - Batch verification results (3K-64K seqlen)
  - ASCII pipeline diagram

- Add test_xattn_kv_chunking_batch.py for batch validation
- Fix causal mask post-processing in alignment test
- Update CLAUDE.md documentation index

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-01 19:22:36 +08:00

15 KiB
Raw Permalink Blame History

XAttention KV Chunking Kernels

概述

本文档描述了支持 KV 维度分 chunk 的 softmax kernels 实现。这些 kernels 允许在 CPU offload 场景下,沿 KV 维度分块计算 sparse attention estimation而不需要在 GPU 上保存完整的 raw attention scores。

背景

原始的 softmax_fuse_block_sum kernel 需要完整的 K 序列来计算正确的 softmax 归一化分母:

softmax(x_i) = exp(x_i) / Σ_j exp(x_j)

如果只有部分 K (KV chunk),分母 Σ_j exp(x_j) 不完整,导致归一化错误。

解决方案:三阶段计算

通过将 softmax 计算拆分为三个阶段,实现正确的 KV chunking

┌─────────────────────────────────────────────────────────────────┐
│                        三阶段 Pipeline                           │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────────┐   ┌─────────────┐   ┌─────────────┐           │
│  │ KV Chunk 0  │   │ KV Chunk 1  │   │ KV Chunk N  │           │
│  │ attn_scores │   │ attn_scores │   │ attn_scores │           │
│  └──────┬──────┘   └──────┬──────┘   └──────┬──────┘           │
│         │                 │                 │                   │
│         ▼                 ▼                 ▼                   │
│  ┌─────────────────────────────────────────────────┐           │
│  │     阶段 1: softmax_compute_partial_stats       │           │
│  │     计算每个 chunk 的 (m_partial, l_partial)    │           │
│  └─────────────────────────────────────────────────┘           │
│         │                 │                 │                   │
│         ▼                 ▼                 ▼                   │
│      (m_0, l_0)       (m_1, l_1)       (m_N, l_N)              │
│         │                 │                 │                   │
│         └────────────────┬┴─────────────────┘                   │
│                          ▼                                      │
│  ┌─────────────────────────────────────────────────┐           │
│  │         阶段 2: merge_softmax_stats             │           │
│  │         Host 端合并 → (m_global, l_global)      │           │
│  └─────────────────────────────────────────────────┘           │
│                          │                                      │
│         ┌────────────────┼────────────────┐                     │
│         ▼                ▼                ▼                     │
│  ┌─────────────────────────────────────────────────┐           │
│  │    阶段 3: softmax_normalize_and_block_sum      │           │
│  │    使用全局 stats 归一化并计算 block sums        │           │
│  └─────────────────────────────────────────────────┘           │
│         │                │                │                     │
│         ▼                ▼                ▼                     │
│    block_sums_0    block_sums_1    block_sums_N                │
│         │                │                │                     │
│         └────────────────┴────────────────┘                     │
│                          │                                      │
│                          ▼                                      │
│                 torch.cat → final mask                          │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

阶段 1: softmax_compute_partial_stats

计算每个 KV chunk 的 partial statistics

  • m_partial: 该 chunk 内的最大值 (per query row)
  • l_partial: 该 chunk 内的 partial sum = Σ exp(x - m_partial)
m_partial, l_partial = softmax_compute_partial_stats(
    attn_weights_kv,      # [batch, heads, q_len, k_chunk_len]
    reshaped_block_size,
    segment_size,
    scale,
    chunk_start=chunk_start,
    kv_offset=kv_offset,  # KV chunk 在完整序列中的偏移
    is_causal=True,
)
# 输出: m_partial, l_partial 形状为 [batch, heads, q_len]

阶段 2: merge_softmax_stats

Host 端合并所有 KV chunks 的 statistics

m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)

合并公式 (Online Softmax):

m_new = max(m_global, m_chunk)
l_new = l_global * exp(m_global - m_new) + l_chunk * exp(m_chunk - m_new)

阶段 3: softmax_normalize_and_block_sum

使用全局 statistics 归一化并计算 block sums

attn_sum_kv = softmax_normalize_and_block_sum(
    attn_weights_kv,      # [batch, heads, q_len, k_chunk_len]
    m_global,             # [batch, heads, q_len]
    l_global,             # [batch, heads, q_len]
    reshaped_block_size,
    segment_size,
    chunk_start=chunk_start,
    real_q_len=real_q_len,
    scale=scale,
    kv_offset=kv_offset,
    is_causal=True,
)
# 输出: [batch, heads, q_blocks, k_chunk_blocks]

数学等价性证明

原始 softmax 计算 (完整 K):

softmax(x_i) = exp(x_i - m) / Σ_j exp(x_j - m)

分 KV chunk 计算:

Chunk 0: m_0 = max(x[0:N/2]), l_0 = Σ exp(x[0:N/2] - m_0)
Chunk 1: m_1 = max(x[N/2:N]), l_1 = Σ exp(x[N/2:N] - m_1)

合并:
m_global = max(m_0, m_1)
l_global = l_0 * exp(m_0 - m_global) + l_1 * exp(m_1 - m_global)
         = Σ exp(x[0:N] - m_global)  # 等于全局 sum

归一化:
softmax(x_i) = exp(x_i - m_global) / l_global  # 正确!

Causal Mask 处理

两个 kernel 都正确处理了 causal attention

  1. softmax_partial_stats_kernel: 通过 kv_offset 参数确定当前 KV chunk 在完整序列中的位置,正确计算 causal boundary

  2. softmax_normalize_block_sum_kernel: 同样使用 kv_offset,对 causal boundary 之后的位置输出 0

存储开销分析

符号定义

符号 含义 典型值
S seq_len 64K
B batch_size 1
H num_heads 32
D head_dim 128
T stride 4-8
C chunk_size 16K
n num_kv_chunks = ceil(S/C) 4

原始方式 (无 KV chunking)

attn_weights 峰值内存:

[B, H, S/T, S/T] × 4 bytes = B × H × (S/T)² × 4

例: S=64K, T=4, B=1, H=32
= 1 × 32 × 16384² × 4 = 32 GB

KV Chunking 方式的额外存储

1. Partial Stats (每个 KV chunk)

m_partial: [B, H, C/T] × 4 bytes
l_partial: [B, H, C/T] × 4 bytes

单个 chunk = 2 × B × H × (C/T) × 4
           = 2 × 1 × 32 × 4096 × 4 = 1 MB

2. Global Stats

m_global: [B, H, S/T] × 4 bytes
l_global: [B, H, S/T] × 4 bytes

= 2 × B × H × (S/T) × 4
= 2 × 1 × 32 × 16384 × 4 = 4 MB

3. 总额外开销

total_extra = n × partial_stats + global_stats
            = 4 × 1MB + 4MB = 8 MB

存储开销随 seqlen 变化

seqlen num_chunks 原始 attn_weights 额外 stats 比例
16K 1 2 GB 2 MB 0.1%
32K 2 8 GB 4 MB 0.05%
64K 4 32 GB 8 MB 0.025%
128K 8 128 GB 16 MB 0.012%

复杂度分析

存储组件 复杂度 说明
原始 attn_weights O(S²) 二次增长
Partial/Global stats O(S) 线性增长
相对开销 O(1/S) 随 seqlen 递减

峰值显存优化

KV chunking 的主要收益是峰值显存从 O(S²) 降到 O(S×C)

原始: O(B × H × (S/T)²)           # 完整 attn_weights
KV chunking: O(B × H × (S/T) × (C/T))  # 一次只处理一个 chunk

以 S=128K, C=16K 为例:

  • 原始峰值: ~128 GB
  • KV chunking 峰值: ~16 GB (降低 8 倍)

支持不同 Q/KV Chunk Size

三阶段 pipeline 支持 Q 和 KV 使用不同的 chunk size

q_chunk_size = 8192   # Q 分块大小
kv_chunk_size = 16384 # KV 分块大小

for q_chunk_idx in range(q_chunk_num):
    Q_chunk = Q[:, :, q_start:q_end, :]  # [B, H, q_chunk_size, D]

    for kv_chunk_idx in range(kv_chunk_num):
        K_chunk = K[:, :, kv_start:kv_end, :]  # [B, H, kv_chunk_size, D]
        # ... 三阶段处理

测试验证结果

Config seq_len Q chunks KV chunks density 对齐
Q=16K, KV=16K 64891 4 4 0.1117 ✓ 100%
Q=8K, KV=16K 64891 8 4 0.1112 ✓ 100%
Q=16K, KV=8K 64891 4 8 0.1117 ✓ 100%
Q=8K, KV=8K 64891 8 8 0.1112 ✓ 100%
Q=4K, KV=16K 64891 16 4 0.1109 ✓ 100%

API 参考

softmax_compute_partial_stats

def softmax_compute_partial_stats(
    attn_weights_slice: torch.Tensor,  # [batch, heads, q_len, k_chunk_len]
    reshaped_block_size: int,
    segment_size: int,
    scale: float,
    chunk_start: int = 0,              # Q chunk 起始位置 (reshaped space)
    kv_offset: int = 0,                # KV chunk 偏移 (reshaped space)
    is_causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """返回 (m, l) partial stats"""

merge_softmax_stats

def merge_softmax_stats(
    m_chunks: list,  # List of [batch, heads, q_len] tensors
    l_chunks: list,  # List of [batch, heads, q_len] tensors
) -> Tuple[torch.Tensor, torch.Tensor]:
    """返回 (m_global, l_global)"""

softmax_normalize_and_block_sum

def softmax_normalize_and_block_sum(
    attn_weights_slice: torch.Tensor,  # [batch, heads, q_len, k_chunk_len]
    m_global: torch.Tensor,            # [batch, heads, q_len]
    l_global: torch.Tensor,            # [batch, heads, q_len]
    reshaped_block_size: int,
    segment_size: int,
    chunk_start: int,
    real_q_len: int,
    scale: float,
    kv_offset: int = 0,
    is_causal: bool = False,
) -> torch.Tensor:
    """返回 block sums [batch, heads, q_blocks, k_chunk_blocks]"""

使用示例

from nanovllm.ops.xattn import (
    flat_group_gemm_fuse_reshape,
    softmax_compute_partial_stats,
    softmax_normalize_and_block_sum,
    merge_softmax_stats,
    find_blocks_chunked,
)

# 对每个 Q chunk
for q_chunk_idx in range(q_chunk_num):
    Q_chunk = Q_padded[:, :, q_start:q_end, :]

    # 阶段 1: 每个 KV chunk 计算 partial stats
    m_chunks, l_chunks = [], []
    attn_weights_chunks = []

    for kv_chunk_idx in range(kv_chunk_num):
        K_chunk = K_padded[:, :, kv_start:kv_end, :]
        kv_offset = kv_chunk_idx * kv_chunk_size // STRIDE

        # 计算 raw scores
        attn_weights = flat_group_gemm_fuse_reshape(
            Q_chunk, K_chunk, STRIDE,
            chunk_start=chunk_start,
            chunk_end=chunk_end,
            is_causal=False,  # K 不完整
        )
        attn_weights_chunks.append(attn_weights)

        # 计算 partial stats
        m, l = softmax_compute_partial_stats(
            attn_weights, block_size, segment_size, scale,
            chunk_start=chunk_start,
            kv_offset=kv_offset,
            is_causal=True,
        )
        m_chunks.append(m)
        l_chunks.append(l)

    # 阶段 2: 合并 stats
    m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)

    # 阶段 3: 归一化并计算 block sums
    block_sums_list = []
    for kv_chunk_idx, attn_weights in enumerate(attn_weights_chunks):
        kv_offset = kv_chunk_idx * kv_chunk_size // STRIDE
        block_sums = softmax_normalize_and_block_sum(
            attn_weights, m_global, l_global,
            block_size, segment_size, chunk_start, real_q_len, scale,
            kv_offset=kv_offset,
            is_causal=True,
        )
        block_sums_list.append(block_sums)

    # 拼接并选择 blocks
    attn_sum = torch.cat(block_sums_list, dim=-1)
    mask = find_blocks_chunked(attn_sum, ...)

性能对比

方面 原始实现 KV Chunking 实现
Kernel 数量 1 2 (stats + normalize)
Raw scores 读取次数 2 2
额外内存 0 O(B × H × S/T × 2) for (m, l)
Host 计算 merge stats (轻量)
峰值显存 O(S²) O(S × C)

验证测试

批量测试结果

测试脚本 tests/test_xattn_kv_chunking_batch.py 验证了不同 seqlen 下的一致性:

| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff     | mask_diff | status |
|---------|--------|-----------|-----------|-------------|------------|----------|-----------|--------|
|    3688 |      4 |      0.90 |         1 |    0.383405 |   0.383405 | 0.000000 |   0.0000% |   PASS |
|    7888 |      4 |      0.90 |         1 |    0.290611 |   0.290611 | 0.000000 |   0.0000% |   PASS |
|   15685 |      4 |      0.90 |         1 |    0.197724 |   0.197724 | 0.000000 |   0.0000% |   PASS |
|   32485 |      4 |      0.90 |         2 |    0.159023 |   0.159023 | 0.000000 |   0.0000% |   PASS |
|   64891 |      4 |      0.90 |         4 |    0.111656 |   0.111656 | 0.000000 |   0.0000% |   PASS |

关键结论

  1. 数学等价性: density_diff = 0.000000 对于所有测试
  2. Mask 完全对齐: mask_diff = 0.0000% 对于所有测试
  3. 支持任意 Q/KV chunk size 组合

相关文件

  • nanovllm/ops/xattn.py: Kernel 实现
  • tests/test_xattn_estimate_alignment.py: 单文件验证测试
  • tests/test_xattn_kv_chunking_batch.py: 批量验证测试
  • docs/xattn_kernels_guide.md: 原始 kernel 文档