Commit Graph

6 Commits

Author SHA1 Message Date
Zijie Tian
5acd5558d6 feat: add KV chunking support for XAttention softmax kernels
Implement three-phase KV chunking for sparse attention estimation:
1. softmax_compute_partial_stats: compute (m, l) per KV chunk
2. merge_softmax_stats: merge partial stats on host (see the sketch below)
3. softmax_normalize_and_block_sum: normalize with global stats

This allows computing sparse attention masks without storing full
raw attention scores in GPU memory, reducing peak memory usage
from O(q_len * k_full_len) to O(q_len * k_chunk_len).
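
A minimal sketch of the phase-2 merge (assuming (m, l) are per-row softmax stats held as same-shaped tensors; the body is illustrative, not the actual merge_softmax_stats implementation):

```python
import torch

def merge_softmax_stats(m_a, l_a, m_b, l_b):
    # m_*: per-row max of raw scores within a KV chunk
    # l_*: per-row sum of exp(score - m_*) within that chunk
    m = torch.maximum(m_a, m_b)  # global row max across both chunks
    # Rescale each partial sum to the new max before adding, so the merged
    # (m, l) equals what a single pass over both chunks would produce.
    l = l_a * torch.exp(m_a - m) + l_b * torch.exp(m_b - m)
    return m, l
```

Folding this merge across all KV chunks reproduces the global (m, l)
of an unchunked softmax, which is what lets phase 3 normalize exactly.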

Key changes:
- Add softmax_partial_stats_kernel with causal mask support
- Add softmax_normalize_block_sum_kernel with kv_offset parameter
- Add Python wrappers for new kernels
- Update test script to validate KV chunking alignment
- Add documentation for the new kernels

Test results show near-exact alignment with the xattn_estimate API:
- Density difference: 0.000000
- Mask difference: 0.0044%

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-01 18:53:26 +08:00
Zijie Tian
8d19e61446 ⚡ perf: replace Triton merge with FlashInfer merge_state
Use FlashInfer's optimized merge_state kernel for attention output merging
in chunked prefill. End-to-end improvement: +0.8% (32K) to +2.4% (64K).

Key changes:
- Add merge_attention_outputs_flashinfer() with LSE format conversion
- FlashInfer stores LSE in log2 while flash_attn uses ln; convert via LOG2_E / LN_2 (see the sketch below)
- Keep the original Triton merge kernel as a fallback
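
A hedged sketch of the conversion-plus-merge path, assuming flash_attn-style inputs (V: [seq, heads, dim]; LSE: [seq, heads] in natural log) and FlashInfer's merge_state API; the wrapper name mirrors the commit, but the body is illustrative:

```python
import math
import torch
import flashinfer

LOG2_E = math.log2(math.e)  # = 1 / LN_2

def merge_attention_outputs_flashinfer(v_a, lse_a, v_b, lse_b):
    # flash_attn reports LSE in natural log; FlashInfer's merge_state expects
    # base-2 LSE, so rescale on the way in and convert back on the way out.
    v, s = flashinfer.merge_state(v_a, lse_a * LOG2_E, v_b, lse_b * LOG2_E)
    return v, s / LOG2_E  # ln-based LSE for downstream flash_attn code
```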

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 10:04:38 +08:00
Zijie Tian
999858e82f feat: add xattn kernels test and update testing rules
- Add test_xattn_kernels.py demonstrating flat_group_gemm_fuse_reshape
  and softmax_fuse_block_sum Triton kernels with structured data
- Update testing.md with new test code style guidelines
- Update xattn.py and xattn_bsa.py with improvements

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 03:01:25 +08:00
Zijie Tian
bc92c1fdb8 feat: add xattn_estimate_chunked for chunked prefill support
- Add xattn_estimate_chunked function ported from COMPASS
- Support chunked prefill via the q_start_pos parameter (sketched below)
- Ensure 100% consistency with the standard xattn_estimate when
  using a matching chunk_size parameter
- Add test and documentation
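
A minimal sketch of what the q_start_pos offset does to causality, written as a dense boolean mask for clarity (the estimator itself works on block summaries, so this function is an assumption, not the ported code):

```python
import torch

def chunked_causal_mask(q_len: int, k_len: int, q_start_pos: int):
    # Query row i sits at global position q_start_pos + i and may attend to
    # every key at global position <= q_start_pos + i.
    q_pos = torch.arange(q_len).unsqueeze(1) + q_start_pos  # (q_len, 1)
    k_pos = torch.arange(k_len).unsqueeze(0)                # (1, k_len)
    return k_pos <= q_pos                                   # (q_len, k_len)
```

With q_start_pos = 0 and k_len = q_len this degenerates to the standard
lower-triangular mask, consistent with the 100% match against the unchunked path.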

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 01:13:17 +08:00
Zijie Tian
3aef6fc3a2 feat: add XAttention Triton operators for sparse attention estimation
Port XAttention operators from COMPASS project:
- flat_group_gemm_fuse_reshape: GEMM kernel with fused strided reshape
- softmax_fuse_block_sum: fused softmax with block-level summation
- xattn_estimate: main estimation function for block sparse attention
- find_blocks_chunked: cumulative threshold-based block selection
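
A rough sketch of cumulative threshold-based selection in the spirit of find_blocks_chunked, shown for a single row of block sums; the signature, default threshold, and tie handling are assumptions:

```python
import torch

def select_blocks(block_sums: torch.Tensor, threshold: float = 0.9):
    # block_sums: (num_blocks,) attention mass per block after block summation.
    # Keep the highest-mass blocks until they cover `threshold` of total mass.
    vals, idx = torch.sort(block_sums, descending=True)
    cum = torch.cumsum(vals, dim=0) / block_sums.sum()
    k = int((cum < threshold).sum().item()) + 1  # smallest covering prefix
    keep = torch.zeros_like(block_sums, dtype=torch.bool)
    keep[idx[:k]] = True
    return keep  # boolean block mask for the sparse attention kernel
```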

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 04:27:07 +08:00
Zijie Tian
690456dbf9 ♻️ refactor: create ops module and move chunked_attention
- Create nanovllm/ops/ module for low-level attention operators
- Move chunked_attention.py from kvcache/ to ops/
- Update imports in full_policy.py (3 locations)
- Fix: remove dead code in OffloadEngine.reset() referencing
  non-existent layer_k/v_buffer_a/b attributes

Verified with needle test (32K offload): PASSED

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 02:50:14 +08:00