Add ops module ported from the tzj/minference branch, containing:

- xattn.py: XAttention block importance estimation with Triton kernels
  - xattn_estimate(): standard estimation of the sparse attention mask
  - xattn_estimate_chunked(): chunked-prefill-compatible version
  - flat_group_gemm_fuse_reshape(): fused stride reshape + GEMM kernel
  - softmax_fuse_block_sum(): online softmax + block-wise sum kernel
- chunked_attention.py: flash attention with LSE output for chunk merging
- test_xattn_estimate_chunked.py: verification test (all seq_lens pass)

This lays the foundation for the AttentionPolicy refactoring, in which
XAttentionPolicy.estimate() will call these ops.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
39 lines
842 B
Python
39 lines
842 B
Python
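"""
Operators module for nano-vLLM.

This package re-exports low-level attention operators and kernels from two
submodules:

- ``chunked_attention``: flash attention that also returns the log-sum-exp
  (LSE), merging of per-chunk attention outputs, and chunked-prefill helpers.
- ``xattn``: XAttention block-importance estimation (Triton kernels) used to
  build sparse attention masks, plus related mask and sparsity utilities.
"""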
|
|
|
|
from nanovllm.ops.chunked_attention import (
|
|
flash_attn_with_lse,
|
|
merge_attention_outputs,
|
|
chunked_attention_varlen,
|
|
ChunkedPrefillState,
|
|
)
|
|
|
|
from nanovllm.ops.xattn import (
|
|
xattn_estimate,
|
|
xattn_estimate_chunked,
|
|
flat_group_gemm_fuse_reshape,
|
|
softmax_fuse_block_sum,
|
|
find_blocks_chunked,
|
|
create_causal_mask,
|
|
compute_sparsity,
|
|
)
|
|
|
|
# Public API of this package. The names are listed in the same order they are
# imported above and grouped by the submodule that defines them. Each group is
# its own list literal; the groups are concatenated into one flat list.
__all__ = [
    # -- from nanovllm.ops.chunked_attention --
    "flash_attn_with_lse",
    "merge_attention_outputs",
    "chunked_attention_varlen",
    "ChunkedPrefillState",
] + [
    # -- from nanovllm.ops.xattn --
    "xattn_estimate",
    "xattn_estimate_chunked",
    "flat_group_gemm_fuse_reshape",
    "softmax_fuse_block_sum",
    "find_blocks_chunked",
    "create_causal_mask",
    "compute_sparsity",
]
|