📝 docs: add storage overhead analysis and batch tests for KV chunking
- Update xattn_kv_chunking_kernels.md with: - Detailed storage overhead analysis (O(S) vs O(S²)) - Peak memory optimization (8x reduction) - Support for independent Q/KV chunk sizes - Batch verification results (3K-64K seqlen) - ASCII pipeline diagram - Add test_xattn_kv_chunking_batch.py for batch validation - Fix causal mask post-processing in alignment test - Update CLAUDE.md documentation index Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -226,6 +226,14 @@ for q_chunk_idx in range(q_chunk_num):
|
||||
print(f" Q chunk {q_chunk_idx}: merged {kv_chunk_num} KV chunks, attn_sum shape={attn_sum_concat.shape}")
|
||||
|
||||
mask_kv_chunking = torch.cat(simple_mask_list, dim=2)
|
||||
|
||||
# 应用与 xattn_estimate 相同的 causal mask 后处理 (xattn.py 第 1300-1306 行)
|
||||
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
|
||||
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=device), diagonal=0),
|
||||
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
|
||||
False,
|
||||
)
|
||||
|
||||
mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
|
||||
selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
density_kv = selected_kv / total_api
|
||||
|
||||
Reference in New Issue
Block a user