from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class Context:
    """Per-batch attention metadata, published by the engine and read by model layers."""

    is_prefill: bool = False
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    slot_mapping: torch.Tensor | None = None
    context_lens: torch.Tensor | None = None
    block_tables: torch.Tensor | None = None
    # Sparse prefill attention support (GPU-only path).
    # When set, uses policy.sparse_prefill_attention() instead of FlashAttention.
    sparse_prefill_policy: Any = None  # SparsePolicy instance with supports_prefill=True


_CONTEXT = Context()


def get_context() -> Context:
    """Return the context for the batch currently being executed."""
    return _CONTEXT


def set_context(
    is_prefill: bool,
    cu_seqlens_q: torch.Tensor | None = None,
    cu_seqlens_k: torch.Tensor | None = None,
    max_seqlen_q: int = 0,
    max_seqlen_k: int = 0,
    slot_mapping: torch.Tensor | None = None,
    context_lens: torch.Tensor | None = None,
    block_tables: torch.Tensor | None = None,
    sparse_prefill_policy: Any = None,
):
    """Replace the global context with metadata for the next forward pass."""
    global _CONTEXT
    _CONTEXT = Context(
        is_prefill=is_prefill,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
        sparse_prefill_policy=sparse_prefill_policy,
    )


def reset_context():
    """Restore the default (empty) context after a forward pass completes."""
    global _CONTEXT
    _CONTEXT = Context()
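
# --- Usage sketch (illustrative only, not part of this module's API) ---
# A minimal sketch of how an engine loop and an attention layer might drive
# this module: the engine publishes batch metadata with set_context() before
# the forward pass, layers read it via get_context(), and the engine clears it
# afterwards. The tensor values and the branching below are assumptions for
# illustration, not code from this repo.
if __name__ == "__main__":
    # Engine side: publish varlen prefill metadata for a two-sequence batch
    # (lengths 5 and 7, expressed as cumulative offsets).
    set_context(
        is_prefill=True,
        cu_seqlens_q=torch.tensor([0, 5, 12], dtype=torch.int32),
        cu_seqlens_k=torch.tensor([0, 5, 12], dtype=torch.int32),
        max_seqlen_q=7,
        max_seqlen_k=7,
    )

    # Layer side: read the current context to pick an attention path.
    ctx = get_context()
    if ctx.is_prefill and ctx.sparse_prefill_policy is not None:
        # Hypothetical branch: the policy's sparse prefill kernel would run here.
        print("would call ctx.sparse_prefill_policy.sparse_prefill_attention(...)")
    elif ctx.is_prefill:
        print("would call the FlashAttention varlen prefill path")
    else:
        print("would call the paged decode path using block_tables/context_lens")

    # Engine side: clear the metadata once the forward pass finishes.
    reset_context()
    assert not get_context().is_prefill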