Files
nano-vllm/nanovllm/utils/context.py

84 lines
2.6 KiB
Python

from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Any
import torch
@dataclass
class Context:
    """Record of the state published through this module's global context.

    Every field has a default, so ``Context()`` produces an "empty" context.
    The generated ``__init__`` accepts positional arguments, which makes the
    field order below part of the interface; append new fields at the end.
    """

    # ---- base fields ----
    is_prefill: bool = False
    cu_seqlens_q: Optional[torch.Tensor] = None
    cu_seqlens_k: Optional[torch.Tensor] = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    slot_mapping: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # ---- chunked prefill support ----
    is_chunked_prefill: bool = False
    # (start_pos, end_pos) pairs describing earlier KV chunks whose blocks
    # live on the CPU.
    prev_kv_ranges: List[Tuple[int, int]] = field(default_factory=list)
    # Position offset of the chunk being processed (needed for the causal
    # mask).
    chunk_offset: int = 0
    # Handle on the KV-cache manager (HybridKVCacheManager) through which
    # earlier KV is loaded.
    offload_engine: Any = None
    # (K, V) tensors of earlier chunks for the layer about to run, loaded
    # from the CPU. model_runner assigns this before each layer's forward.
    prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(
        default_factory=list
    )
    # Sequence currently being processed; chunked prefill uses it to load KV.
    chunked_seq: Any = None

    # ---- decode bookkeeping ----
    # In-block position for decode (used when reading from the Decode
    # region).
    decode_pos_in_block: int = 0
    # In-block position at which decode tokens started accumulating. When
    # decode offloads are batched, attention has to span every accumulated
    # token, so the starting point is tracked here.
    decode_start_pos_in_block: int = 0
# Module-level context instance. get_context() returns it, while
# set_context() and reset_context() rebind this name to a new Context.
_CONTEXT = Context()
def get_context():
    """Return the ``Context`` object currently bound to the module global.

    The object itself is handed back (no copy is made), so attribute
    assignments made by the caller are visible to later callers until the
    global is rebound.
    """
    current = _CONTEXT
    return current
def set_context(
    is_prefill,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    max_seqlen_q=0,
    max_seqlen_k=0,
    slot_mapping=None,
    context_lens=None,
    block_tables=None,
    is_chunked_prefill=False,
    prev_kv_ranges=None,
    chunk_offset=0,
    offload_engine=None,
    chunked_seq=None,
    decode_pos_in_block=0,
    decode_start_pos_in_block=0,
):
    """Rebind the module-level context to a ``Context`` built from the arguments.

    Each argument is stored on the ``Context`` field of the same name.
    ``prev_kv_chunks`` is not a parameter, so the new context starts with
    that field's dataclass default. A falsy ``prev_kv_ranges`` (``None`` or
    an empty sequence) is replaced by a new empty list.
    """
    global _CONTEXT

    # Group the keyword arguments by theme, then splat them into a single
    # constructor call.
    base_fields = dict(
        is_prefill=is_prefill,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
    )
    chunked_prefill_fields = dict(
        is_chunked_prefill=is_chunked_prefill,
        prev_kv_ranges=prev_kv_ranges or [],
        chunk_offset=chunk_offset,
        offload_engine=offload_engine,
        chunked_seq=chunked_seq,
    )
    decode_fields = dict(
        decode_pos_in_block=decode_pos_in_block,
        decode_start_pos_in_block=decode_start_pos_in_block,
    )

    _CONTEXT = Context(**base_fields, **chunked_prefill_fields, **decode_fields)
def reset_context():
    """Rebind the module-level context to a default-constructed ``Context``.

    The previous object is not mutated; only the module global is pointed
    at a new instance.
    """
    global _CONTEXT
    pristine = Context()
    _CONTEXT = pristine