- Update xattn_kv_chunking_kernels.md with: - Detailed storage overhead analysis (O(S) vs O(S²)) - Peak memory optimization (8x reduction) - Support for independent Q/KV chunk sizes - Batch verification results (3K-64K seqlen) - ASCII pipeline diagram - Add test_xattn_kv_chunking_batch.py for batch validation - Fix causal mask post-processing in alignment test - Update CLAUDE.md documentation index Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
15 KiB
XAttention KV Chunking Kernels
概述
本文档描述了支持 KV 维度分 chunk 的 softmax kernels 实现。这些 kernels 允许在 CPU offload 场景下,沿 KV 维度分块计算 sparse attention estimation,而不需要在 GPU 上保存完整的 raw attention scores。
背景
原始的 softmax_fuse_block_sum kernel 需要完整的 K 序列来计算正确的 softmax 归一化分母:
softmax(x_i) = exp(x_i) / Σ_j exp(x_j)
如果只有部分 K (KV chunk),分母 Σ_j exp(x_j) 不完整,导致归一化错误。
解决方案:三阶段计算
通过将 softmax 计算拆分为三个阶段,实现正确的 KV chunking:
┌─────────────────────────────────────────────────────────────────┐
│ 三阶段 Pipeline │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ KV Chunk 0 │ │ KV Chunk 1 │ │ KV Chunk N │ │
│ │ attn_scores │ │ attn_scores │ │ attn_scores │ │
│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ 阶段 1: softmax_compute_partial_stats │ │
│ │ 计算每个 chunk 的 (m_partial, l_partial) │ │
│ └─────────────────────────────────────────────────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ (m_0, l_0) (m_1, l_1) (m_N, l_N) │
│ │ │ │ │
│ └────────────────┬┴─────────────────┘ │
│ ▼ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ 阶段 2: merge_softmax_stats │ │
│ │ Host 端合并 → (m_global, l_global) │ │
│ └─────────────────────────────────────────────────┘ │
│ │ │
│ ┌────────────────┼────────────────┐ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ 阶段 3: softmax_normalize_and_block_sum │ │
│ │ 使用全局 stats 归一化并计算 block sums │ │
│ └─────────────────────────────────────────────────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ block_sums_0 block_sums_1 block_sums_N │
│ │ │ │ │
│ └────────────────┴────────────────┘ │
│ │ │
│ ▼ │
│ torch.cat → final mask │
│ │
└─────────────────────────────────────────────────────────────────┘
阶段 1: softmax_compute_partial_stats
计算每个 KV chunk 的 partial statistics:
m_partial: 该 chunk 内的最大值 (per query row)l_partial: 该 chunk 内的 partial sum = Σ exp(x - m_partial)
m_partial, l_partial = softmax_compute_partial_stats(
attn_weights_kv, # [batch, heads, q_len, k_chunk_len]
reshaped_block_size,
segment_size,
scale,
chunk_start=chunk_start,
kv_offset=kv_offset, # KV chunk 在完整序列中的偏移
is_causal=True,
)
# 输出: m_partial, l_partial 形状为 [batch, heads, q_len]
阶段 2: merge_softmax_stats
Host 端合并所有 KV chunks 的 statistics:
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
合并公式 (Online Softmax):
m_new = max(m_global, m_chunk)
l_new = l_global * exp(m_global - m_new) + l_chunk * exp(m_chunk - m_new)
阶段 3: softmax_normalize_and_block_sum
使用全局 statistics 归一化并计算 block sums:
attn_sum_kv = softmax_normalize_and_block_sum(
attn_weights_kv, # [batch, heads, q_len, k_chunk_len]
m_global, # [batch, heads, q_len]
l_global, # [batch, heads, q_len]
reshaped_block_size,
segment_size,
chunk_start=chunk_start,
real_q_len=real_q_len,
scale=scale,
kv_offset=kv_offset,
is_causal=True,
)
# 输出: [batch, heads, q_blocks, k_chunk_blocks]
数学等价性证明
原始 softmax 计算 (完整 K):
softmax(x_i) = exp(x_i - m) / Σ_j exp(x_j - m)
分 KV chunk 计算:
Chunk 0: m_0 = max(x[0:N/2]), l_0 = Σ exp(x[0:N/2] - m_0)
Chunk 1: m_1 = max(x[N/2:N]), l_1 = Σ exp(x[N/2:N] - m_1)
合并:
m_global = max(m_0, m_1)
l_global = l_0 * exp(m_0 - m_global) + l_1 * exp(m_1 - m_global)
= Σ exp(x[0:N] - m_global) # 等于全局 sum
归一化:
softmax(x_i) = exp(x_i - m_global) / l_global # 正确!
Causal Mask 处理
两个 kernel 都正确处理了 causal attention:
-
softmax_partial_stats_kernel: 通过kv_offset参数确定当前 KV chunk 在完整序列中的位置,正确计算 causal boundary -
softmax_normalize_block_sum_kernel: 同样使用kv_offset,对 causal boundary 之后的位置输出 0
存储开销分析
符号定义
| 符号 | 含义 | 典型值 |
|---|---|---|
| S | seq_len | 64K |
| B | batch_size | 1 |
| H | num_heads | 32 |
| D | head_dim | 128 |
| T | stride | 4-8 |
| C | chunk_size | 16K |
| n | num_kv_chunks = ceil(S/C) | 4 |
原始方式 (无 KV chunking)
attn_weights 峰值内存:
[B, H, S/T, S/T] × 4 bytes = B × H × (S/T)² × 4
例: S=64K, T=4, B=1, H=32
= 1 × 32 × 16384² × 4 = 32 GB
KV Chunking 方式的额外存储
1. Partial Stats (每个 KV chunk)
m_partial: [B, H, C/T] × 4 bytes
l_partial: [B, H, C/T] × 4 bytes
单个 chunk = 2 × B × H × (C/T) × 4
= 2 × 1 × 32 × 4096 × 4 = 1 MB
2. Global Stats
m_global: [B, H, S/T] × 4 bytes
l_global: [B, H, S/T] × 4 bytes
= 2 × B × H × (S/T) × 4
= 2 × 1 × 32 × 16384 × 4 = 4 MB
3. 总额外开销
total_extra = n × partial_stats + global_stats
= 4 × 1MB + 4MB = 8 MB
存储开销随 seqlen 变化
| seqlen | num_chunks | 原始 attn_weights | 额外 stats | 比例 |
|---|---|---|---|---|
| 16K | 1 | 2 GB | 2 MB | 0.1% |
| 32K | 2 | 8 GB | 4 MB | 0.05% |
| 64K | 4 | 32 GB | 8 MB | 0.025% |
| 128K | 8 | 128 GB | 16 MB | 0.012% |
复杂度分析
| 存储组件 | 复杂度 | 说明 |
|---|---|---|
| 原始 attn_weights | O(S²) | 二次增长 |
| Partial/Global stats | O(S) | 线性增长 |
| 相对开销 | O(1/S) | 随 seqlen 递减 |
峰值显存优化
KV chunking 的主要收益是峰值显存从 O(S²) 降到 O(S×C):
原始: O(B × H × (S/T)²) # 完整 attn_weights
KV chunking: O(B × H × (S/T) × (C/T)) # 一次只处理一个 chunk
以 S=128K, C=16K 为例:
- 原始峰值: ~128 GB
- KV chunking 峰值: ~16 GB (降低 8 倍)
支持不同 Q/KV Chunk Size
三阶段 pipeline 支持 Q 和 KV 使用不同的 chunk size:
q_chunk_size = 8192 # Q 分块大小
kv_chunk_size = 16384 # KV 分块大小
for q_chunk_idx in range(q_chunk_num):
Q_chunk = Q[:, :, q_start:q_end, :] # [B, H, q_chunk_size, D]
for kv_chunk_idx in range(kv_chunk_num):
K_chunk = K[:, :, kv_start:kv_end, :] # [B, H, kv_chunk_size, D]
# ... 三阶段处理
测试验证结果
| Config | seq_len | Q chunks | KV chunks | density | 对齐 |
|---|---|---|---|---|---|
| Q=16K, KV=16K | 64891 | 4 | 4 | 0.1117 | ✓ 100% |
| Q=8K, KV=16K | 64891 | 8 | 4 | 0.1112 | ✓ 100% |
| Q=16K, KV=8K | 64891 | 4 | 8 | 0.1117 | ✓ 100% |
| Q=8K, KV=8K | 64891 | 8 | 8 | 0.1112 | ✓ 100% |
| Q=4K, KV=16K | 64891 | 16 | 4 | 0.1109 | ✓ 100% |
API 参考
softmax_compute_partial_stats
def softmax_compute_partial_stats(
attn_weights_slice: torch.Tensor, # [batch, heads, q_len, k_chunk_len]
reshaped_block_size: int,
segment_size: int,
scale: float,
chunk_start: int = 0, # Q chunk 起始位置 (reshaped space)
kv_offset: int = 0, # KV chunk 偏移 (reshaped space)
is_causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""返回 (m, l) partial stats"""
merge_softmax_stats
def merge_softmax_stats(
m_chunks: list, # List of [batch, heads, q_len] tensors
l_chunks: list, # List of [batch, heads, q_len] tensors
) -> Tuple[torch.Tensor, torch.Tensor]:
"""返回 (m_global, l_global)"""
softmax_normalize_and_block_sum
def softmax_normalize_and_block_sum(
attn_weights_slice: torch.Tensor, # [batch, heads, q_len, k_chunk_len]
m_global: torch.Tensor, # [batch, heads, q_len]
l_global: torch.Tensor, # [batch, heads, q_len]
reshaped_block_size: int,
segment_size: int,
chunk_start: int,
real_q_len: int,
scale: float,
kv_offset: int = 0,
is_causal: bool = False,
) -> torch.Tensor:
"""返回 block sums [batch, heads, q_blocks, k_chunk_blocks]"""
使用示例
from nanovllm.ops.xattn import (
flat_group_gemm_fuse_reshape,
softmax_compute_partial_stats,
softmax_normalize_and_block_sum,
merge_softmax_stats,
find_blocks_chunked,
)
# 对每个 Q chunk
for q_chunk_idx in range(q_chunk_num):
Q_chunk = Q_padded[:, :, q_start:q_end, :]
# 阶段 1: 每个 KV chunk 计算 partial stats
m_chunks, l_chunks = [], []
attn_weights_chunks = []
for kv_chunk_idx in range(kv_chunk_num):
K_chunk = K_padded[:, :, kv_start:kv_end, :]
kv_offset = kv_chunk_idx * kv_chunk_size // STRIDE
# 计算 raw scores
attn_weights = flat_group_gemm_fuse_reshape(
Q_chunk, K_chunk, STRIDE,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False, # K 不完整
)
attn_weights_chunks.append(attn_weights)
# 计算 partial stats
m, l = softmax_compute_partial_stats(
attn_weights, block_size, segment_size, scale,
chunk_start=chunk_start,
kv_offset=kv_offset,
is_causal=True,
)
m_chunks.append(m)
l_chunks.append(l)
# 阶段 2: 合并 stats
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
# 阶段 3: 归一化并计算 block sums
block_sums_list = []
for kv_chunk_idx, attn_weights in enumerate(attn_weights_chunks):
kv_offset = kv_chunk_idx * kv_chunk_size // STRIDE
block_sums = softmax_normalize_and_block_sum(
attn_weights, m_global, l_global,
block_size, segment_size, chunk_start, real_q_len, scale,
kv_offset=kv_offset,
is_causal=True,
)
block_sums_list.append(block_sums)
# 拼接并选择 blocks
attn_sum = torch.cat(block_sums_list, dim=-1)
mask = find_blocks_chunked(attn_sum, ...)
性能对比
| 方面 | 原始实现 | KV Chunking 实现 |
|---|---|---|
| Kernel 数量 | 1 | 2 (stats + normalize) |
| Raw scores 读取次数 | 2 | 2 |
| 额外内存 | 0 | O(B × H × S/T × 2) for (m, l) |
| Host 计算 | 无 | merge stats (轻量) |
| 峰值显存 | O(S²) | O(S × C) |
验证测试
批量测试结果
测试脚本 tests/test_xattn_kv_chunking_batch.py 验证了不同 seqlen 下的一致性:
| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff | mask_diff | status |
|---------|--------|-----------|-----------|-------------|------------|----------|-----------|--------|
| 3688 | 4 | 0.90 | 1 | 0.383405 | 0.383405 | 0.000000 | 0.0000% | PASS |
| 7888 | 4 | 0.90 | 1 | 0.290611 | 0.290611 | 0.000000 | 0.0000% | PASS |
| 15685 | 4 | 0.90 | 1 | 0.197724 | 0.197724 | 0.000000 | 0.0000% | PASS |
| 32485 | 4 | 0.90 | 2 | 0.159023 | 0.159023 | 0.000000 | 0.0000% | PASS |
| 64891 | 4 | 0.90 | 4 | 0.111656 | 0.111656 | 0.000000 | 0.0000% | PASS |
关键结论
- 数学等价性: density_diff = 0.000000 对于所有测试
- Mask 完全对齐: mask_diff = 0.0000% 对于所有测试
- 支持任意 Q/KV chunk size 组合
相关文件
nanovllm/ops/xattn.py: Kernel 实现tests/test_xattn_estimate_alignment.py: 单文件验证测试tests/test_xattn_kv_chunking_batch.py: 批量验证测试docs/xattn_kernels_guide.md: 原始 kernel 文档