📝 docs: add storage overhead analysis and batch tests for KV chunking

- Update xattn_kv_chunking_kernels.md with:
  - Detailed storage overhead analysis (O(S) vs O(S²))
  - Peak memory optimization (8x reduction)
  - Support for independent Q/KV chunk sizes
  - Batch verification results (3K-64K seqlen)
  - ASCII pipeline diagram

- Add test_xattn_kv_chunking_batch.py for batch validation
- Fix causal mask post-processing in alignment test
- Update CLAUDE.md documentation index

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-02-01 19:22:36 +08:00
parent 5acd5558d6
commit 6e34efd58a
4 changed files with 429 additions and 10 deletions

View File

@@ -226,6 +226,14 @@ for q_chunk_idx in range(q_chunk_num):
print(f" Q chunk {q_chunk_idx}: merged {kv_chunk_num} KV chunks, attn_sum shape={attn_sum_concat.shape}")
mask_kv_chunking = torch.cat(simple_mask_list, dim=2)
# 应用与 xattn_estimate 相同的 causal mask 后处理 (xattn.py 第 1300-1306 行)
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=device), diagonal=0),
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
False,
)
mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_kv = selected_kv / total_api