# nano-vllm/nanovllm/utils/context.py
from dataclasses import dataclass
from typing import Any
import torch
@dataclass
class Context:
    """Per-batch attention metadata, shared with the attention kernels via a
    module-level singleton (see get_context / set_context / reset_context).

    All fields default to "empty" values so a default-constructed Context
    represents "no batch in flight".
    """

    # True while running the prefill phase; False during decode.
    is_prefill: bool = False
    # Cumulative sequence-length tensors for queries / keys — presumably the
    # varlen (cu_seqlens) layout expected by FlashAttention; confirm at call site.
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    # Longest single query / key sequence in the batch.
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    # Per-token mapping into KV-cache slots — assumed flat indices; verify against caller.
    slot_mapping: torch.Tensor | None = None
    # Per-sequence number of cached context tokens (decode path).
    context_lens: torch.Tensor | None = None
    # Per-sequence block table for paged KV-cache lookup.
    block_tables: torch.Tensor | None = None
    # Sparse prefill attention support (GPU-only path)
    # When set, uses policy.sparse_prefill_attention() instead of FlashAttention
    sparse_prefill_policy: Any = None  # SparsePolicy instance with supports_prefill=True
# Module-level singleton holding the metadata of the batch currently being
# executed. Rebound wholesale by set_context(); restored to defaults by
# reset_context().
_CONTEXT: Context = Context()


def get_context() -> Context:
    """Return the currently installed Context singleton.

    Callers should not cache the result across batches: set_context()
    rebinds the module-level name rather than mutating the object in place.
    """
    return _CONTEXT
def set_context(
    is_prefill,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    max_seqlen_q=0,
    max_seqlen_k=0,
    slot_mapping=None,
    context_lens=None,
    block_tables=None,
    sparse_prefill_policy=None,
):
    """Install a fresh Context singleton describing the current batch.

    The parameters mirror Context's fields one-to-one (same names, same
    defaults); see Context for their meaning. The previously installed
    context is discarded entirely — no fields carry over.
    """
    global _CONTEXT
    # Positional construction is safe here: the parameter order above matches
    # the declaration order of Context's fields exactly.
    _CONTEXT = Context(
        is_prefill,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        slot_mapping,
        context_lens,
        block_tables,
        sparse_prefill_policy,
    )
def reset_context() -> None:
    """Discard the current batch metadata.

    Rebinds the singleton to a default-constructed Context (all fields at
    their declared defaults), signalling that no batch is in flight.
    """
    global _CONTEXT
    _CONTEXT = Context()