# Module-level attention context shared between the model runner and attention layers.
from dataclasses import dataclass
from typing import Any
import torch
@dataclass
class Context:
    """Per-step attention metadata shared across layers.

    A single module-level instance (``_CONTEXT``) carries the state for the
    current prefill or decode step; it is swapped wholesale by
    ``set_context()`` and read via ``get_context()``.
    """

    # True while running the prefill phase; False during decode.
    is_prefill: bool = False
    # Cumulative sequence lengths for queries / keys — the varlen
    # (packed-batch) layout used by FlashAttention-style kernels.
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    # Longest query / key sequence in the current batch.
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    # KV-cache bookkeeping — NOTE(review): exact shapes/dtypes are set by the
    # caller that builds the context; confirm against the model runner.
    slot_mapping: torch.Tensor | None = None
    context_lens: torch.Tensor | None = None
    block_tables: torch.Tensor | None = None

    # Sparse prefill attention support (GPU-only path)
    # When set, uses policy.sparse_prefill_attention() instead of FlashAttention
    sparse_prefill_policy: Any = None  # SparsePolicy instance with supports_prefill=True
# The one live context; starts as all-defaults until set_context() is called.
_CONTEXT = Context()
def get_context() -> Context:
    """Return the current module-level :class:`Context` instance."""
    return _CONTEXT
def set_context(
    is_prefill,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    max_seqlen_q=0,
    max_seqlen_k=0,
    slot_mapping=None,
    context_lens=None,
    block_tables=None,
    sparse_prefill_policy=None,
):
    """Replace the module-level context with a freshly built one.

    Each parameter mirrors a :class:`Context` field; anything not supplied
    falls back to the dataclass default, so stale state from a previous
    step can never leak into the new context.
    """
    global _CONTEXT
    # Arguments are forwarded positionally in field-declaration order,
    # which is equivalent to naming every keyword explicitly.
    _CONTEXT = Context(
        is_prefill,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        slot_mapping,
        context_lens,
        block_tables,
        sparse_prefill_policy,
    )
def reset_context():
    """Restore the module-level context to an all-defaults :class:`Context`."""
    global _CONTEXT
    _CONTEXT = Context()