Implement three-phase KV chunking for sparse attention estimation:
1. softmax_compute_partial_stats: compute (m, l) per KV chunk
2. merge_softmax_stats: merge partial stats on host
3. softmax_normalize_and_block_sum: normalize with global stats
This allows computing sparse attention masks without storing full
raw attention scores in GPU memory, reducing peak memory usage
from O(q_len * k_full_len) to O(q_len * k_chunk_len).
Key changes:
- Add softmax_partial_stats_kernel with causal mask support
- Add softmax_normalize_block_sum_kernel with kv_offset parameter
- Add Python wrappers for new kernels
- Update test script to validate KV chunking alignment
- Add documentation for the new kernels
Test results show perfect alignment with xattn_estimate API:
- Density difference: 0.000000
- Mask difference: 0.0044%
Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)
Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
- Add test_xattn_kernels.py demonstrating flat_group_gemm_fuse_reshape
and softmax_fuse_block_sum Triton kernels with structured data
- Update testing.md with new test code style guidelines
- Update xattn.py and xattn_bsa.py with improvements
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add xattn_estimate_chunked function ported from COMPASS
- Support chunked prefill with q_start_pos parameter
- Ensure 100% consistency with standard xattn_estimate when
using matching chunk_size parameter
- Add test and documentation
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Port XAttention operators from COMPASS project:
- flat_group_gemm_fuse_reshape: stride reshape GEMM kernel
- softmax_fuse_block_sum: fused softmax with block-level summation
- xattn_estimate: main estimation function for block sparse attention
- find_blocks_chunked: cumulative threshold-based block selection
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>