Use FlashInfer's optimized merge_state kernel for attention output merging
in chunked prefill. End-to-end improvement: +0.8% (32K) to +2.4% (64K).
Key changes:
- Add merge_attention_outputs_flashinfer() with LSE format conversion
- FlashInfer uses log2, flash_attn uses ln: convert via LOG2_E/LN_2
- Keep original Triton kernel for fallback
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add load_k_only_to_slot_layer() to OffloadEngine for estimate phase:
- Only load K (not K+V) during block selection in select_blocks()
- Reduces H2D transfer by 50% in estimate phase
- 64K context: XAttn/Full ratio drops from 1.48x to 0.99x
- 32K context: XAttn/Full ratio drops from 1.67x to 1.20x
The estimate phase uses flat_group_gemm_fuse_reshape(Q, K) which
only requires K for attention score computation. V is unused.
Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)
Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
- Add write_to_prefill_buffer() and write_to_decode_buffer() methods
- Add chunk_idx parameter to load_to_slot_layer() for NVTX labeling
- Replace direct copy_() calls with OffloadEngine methods in attention.py
- Update all load_to_slot_layer() calls to pass chunk_idx
- NVTX markers now show chunk info: "H2D: L{layer} Chunk{chunk} CPU[{block}]->Slot[{slot}]"
All KV cache data transfers in chunked offload mode now go through
OffloadEngine, enabling better profiling and consistent management.
Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)
Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
Add statistics tracking to compare block selection between policies:
- XAttentionBSAPolicy: track available/selected blocks per chunk
- FullAttentionPolicy: track total blocks (always 100% density)
- Add reset_stats(), get_density_stats(), print_density_stats() methods
- Use logger.debug for per-chunk density logging
Results on 32K niah_single_1:
- Full: 100% density across all chunks
- XAttn BSA: 90% -> 73% density (saves ~25-30% blocks in later chunks)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Copy compute_chunked_prefill implementation from FullAttentionPolicy
- Set default threshold to 0.95 for accuracy testing
- Remove debug code (sys.exit, verbose prints)
- Use ring buffer pipeline for historical block loading
- Merge with current chunk attention using flash_attn_with_lse
RULER NIAH test passed with 5/5 samples (100% accuracy).
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Implement XAttention-based block selection for sparse attention:
- Use flat_group_gemm_fuse_reshape to compute Q@K^T attention scores
- Apply softmax_fuse_block_sum to aggregate into block-level attention
- Use find_blocks_chunked for threshold-based block selection
- Handle GQA by aggregating within KV head groups first
- Use majority voting (>50%) across heads instead of any() for better sparsity
- Align block_size with CPU offload block size (1024 tokens / stride = 128)
Test results show ~45% density at chunk 40 (down from 100% with any() aggregation).
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add test_xattn_kernels.py demonstrating flat_group_gemm_fuse_reshape
and softmax_fuse_block_sum Triton kernels with structured data
- Update testing.md with new test code style guidelines
- Update xattn.py and xattn_bsa.py with improvements
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>