Port XAttention operators from the COMPASS project:
- flat_group_gemm_fuse_reshape: stride-reshape GEMM kernel
- softmax_fuse_block_sum: fused softmax with block-level summation
- xattn_estimate: main estimation function for block-sparse attention
- find_blocks_chunked: cumulative threshold-based block selection

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
37 lines
784 B
Python
37 lines
784 B
Python
"""
|
|
Operators module for nano-vLLM.
|
|
|
|
This module contains low-level attention operators and kernels.
|
|
"""
|
|
|
|
from nanovllm.ops.chunked_attention import (
|
|
flash_attn_with_lse,
|
|
merge_attention_outputs,
|
|
chunked_attention_varlen,
|
|
ChunkedPrefillState,
|
|
)
|
|
|
|
from nanovllm.ops.xattn import (
|
|
xattn_estimate,
|
|
flat_group_gemm_fuse_reshape,
|
|
softmax_fuse_block_sum,
|
|
find_blocks_chunked,
|
|
create_causal_mask,
|
|
compute_sparsity,
|
|
)
|
|
|
|
# Public API of this package: the names exposed by `from nanovllm.ops import *`.
# Entries are grouped by the submodule they are re-exported from (see the
# imports above). The list is built in two static steps so each group stays
# next to its source-module label; type checkers understand both
# `__all__ = [...]` and `__all__ += [...]`.
__all__ = [
    # chunked_attention
    "flash_attn_with_lse",
    "merge_attention_outputs",
    "chunked_attention_varlen",
    "ChunkedPrefillState",
]
__all__ += [
    # xattn
    "xattn_estimate",
    "flat_group_gemm_fuse_reshape",
    "softmax_fuse_block_sum",
    "find_blocks_chunked",
    "create_causal_mask",
    "compute_sparsity",
]
|