✨ feat: add XAttention Triton operators for sparse attention estimation
Port XAttention operators from COMPASS project:

- flat_group_gemm_fuse_reshape: stride reshape GEMM kernel
- softmax_fuse_block_sum: fused softmax with block-level summation
- xattn_estimate: main estimation function for block sparse attention
- find_blocks_chunked: cumulative threshold-based block selection

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@@ -11,9 +11,26 @@ from nanovllm.ops.chunked_attention import (
     ChunkedPrefillState,
 )
 
+from nanovllm.ops.xattn import (
+    xattn_estimate,
+    flat_group_gemm_fuse_reshape,
+    softmax_fuse_block_sum,
+    find_blocks_chunked,
+    create_causal_mask,
+    compute_sparsity,
+)
+
 __all__ = [
+    # chunked_attention
     "flash_attn_with_lse",
     "merge_attention_outputs",
     "chunked_attention_varlen",
     "ChunkedPrefillState",
+    # xattn
+    "xattn_estimate",
+    "flat_group_gemm_fuse_reshape",
+    "softmax_fuse_block_sum",
+    "find_blocks_chunked",
+    "create_causal_mask",
+    "compute_sparsity",
 ]
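The cumulative threshold-based block selection named in the commit message (find_blocks_chunked) can be sketched roughly as follows. This is an illustrative NumPy sketch, not the ported Triton kernel: the function name `select_blocks_by_cumulative_mass`, the `(query_blocks, key_blocks)` score layout, and the `threshold` semantics are all assumptions for the example.

```python
import numpy as np

def select_blocks_by_cumulative_mass(block_scores, threshold=0.9):
    """Hypothetical sketch of cumulative threshold-based block selection.

    block_scores: (num_query_blocks, num_key_blocks) array of per-block
    attention mass (e.g. block-summed softmax scores). For each query
    block, keep the smallest set of key blocks whose cumulative mass
    reaches `threshold` of that row's total mass.
    """
    # Rank key blocks by descending score within each query-block row.
    order = np.argsort(-block_scores, axis=-1)
    sorted_scores = np.take_along_axis(block_scores, order, axis=-1)

    # Running mass in ranked order; a block is kept while the mass
    # accumulated *before* it is still below the target.
    cum = np.cumsum(sorted_scores, axis=-1)
    target = threshold * block_scores.sum(axis=-1, keepdims=True)
    keep_sorted = (cum - sorted_scores) < target

    # Scatter the keep decisions back to the original block positions.
    mask = np.zeros_like(block_scores, dtype=bool)
    np.put_along_axis(mask, order, keep_sorted, axis=-1)
    return mask
```

A mask like this is what a block-sparse attention kernel would then consume, attending only to the kept key blocks per query block.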