Files
nano-vllm/nanovllm/utils/context.py
2026-01-22 22:20:34 +08:00

57 lines
1.3 KiB
Python

from dataclasses import dataclass
from typing import Any
import torch
@dataclass
class Context:
is_prefill: bool = False
cu_seqlens_q: torch.Tensor | None = None
cu_seqlens_k: torch.Tensor | None = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
slot_mapping: torch.Tensor | None = None
context_lens: torch.Tensor | None = None
block_tables: torch.Tensor | None = None
# Attention policy support (GPU-only path)
# When set, uses policy.compute_prefill() instead of FlashAttention
attention_policy: Any = None # AttentionPolicy instance
# Module-level singleton holding the currently active attention context.
_CONTEXT = Context()


def get_context() -> Context:
    """Return the currently active attention :class:`Context`."""
    return _CONTEXT
def set_context(
    is_prefill,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    max_seqlen_q=0,
    max_seqlen_k=0,
    slot_mapping=None,
    context_lens=None,
    block_tables=None,
    attention_policy=None,
):
    """Install a fresh global :class:`Context` for the upcoming forward pass.

    Arguments mirror the Context fields one-to-one; any omitted argument
    falls back to the field's default. Replaces (never mutates) the
    module-level singleton.
    """
    global _CONTEXT
    # Positional construction — argument order matches the dataclass
    # field declaration order exactly.
    _CONTEXT = Context(
        is_prefill,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        slot_mapping,
        context_lens,
        block_tables,
        attention_policy,
    )
def reset_context() -> None:
    """Restore the global context to an all-defaults :class:`Context`."""
    global _CONTEXT
    _CONTEXT = Context()