[feat] Added chunked prefill and kvcache offload mechenism.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Tuple, Any
|
||||
import torch
|
||||
|
||||
|
||||
@@ -13,14 +14,60 @@ class Context:
|
||||
context_lens: torch.Tensor | None = None
|
||||
block_tables: torch.Tensor | None = None
|
||||
|
||||
# Chunked prefill support
|
||||
is_chunked_prefill: bool = False
|
||||
# Previous KV chunks info: List of (start_pos, end_pos) for blocks on CPU
|
||||
prev_kv_ranges: List[Tuple[int, int]] = field(default_factory=list)
|
||||
# Current chunk's position offset (for causal mask)
|
||||
chunk_offset: int = 0
|
||||
# Reference to kvcache manager for loading previous KV (HybridKVCacheManager)
|
||||
offload_engine: Any = None
|
||||
# Current layer's previous K/V chunks (loaded from CPU)
|
||||
# Set by model_runner before each layer's forward
|
||||
prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
|
||||
# Current sequence being processed (for chunked prefill to load KV)
|
||||
chunked_seq: Any = None
|
||||
|
||||
|
||||
_CONTEXT = Context()
|
||||
|
||||
|
||||
def get_context():
|
||||
return _CONTEXT
|
||||
|
||||
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
|
||||
|
||||
def set_context(
|
||||
is_prefill,
|
||||
cu_seqlens_q=None,
|
||||
cu_seqlens_k=None,
|
||||
max_seqlen_q=0,
|
||||
max_seqlen_k=0,
|
||||
slot_mapping=None,
|
||||
context_lens=None,
|
||||
block_tables=None,
|
||||
is_chunked_prefill=False,
|
||||
prev_kv_ranges=None,
|
||||
chunk_offset=0,
|
||||
offload_engine=None,
|
||||
chunked_seq=None,
|
||||
):
|
||||
global _CONTEXT
|
||||
_CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
|
||||
_CONTEXT = Context(
|
||||
is_prefill=is_prefill,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
slot_mapping=slot_mapping,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
is_chunked_prefill=is_chunked_prefill,
|
||||
prev_kv_ranges=prev_kv_ranges or [],
|
||||
chunk_offset=chunk_offset,
|
||||
offload_engine=offload_engine,
|
||||
chunked_seq=chunked_seq,
|
||||
)
|
||||
|
||||
|
||||
def reset_context():
|
||||
global _CONTEXT
|
||||
|
||||
Reference in New Issue
Block a user