[refactor] Refactor offload code to multi-chunk.
@@ -21,7 +21,7 @@ class Context:
     # Current chunk's position offset (for causal mask)
     chunk_offset: int = 0
     # Reference to kvcache manager for loading previous KV (HybridKVCacheManager)
-    offload_engine: Any = None
+    kvcache_manager: Any = None
     # Current layer's previous K/V chunks (loaded from CPU)
     # Set by model_runner before each layer's forward
     prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
@@ -33,14 +33,6 @@ class Context:
     # Used when batching decode offloads - we need to attend to all accumulated tokens
     decode_start_pos_in_block: int = 0
-
-    # ========== Per-layer chunked attention state ==========
-    # Whether chunked decode/prefill is currently active (for hooks to check)
-    chunked_decode_active: bool = False
-    # CPU block IDs for the current chunk being processed
-    chunked_decode_chunk_ids: List[int] = field(default_factory=list)
-    # Current chunk index being processed
-    chunked_decode_current_chunk: int = 0
 
 
 _CONTEXT = Context()
 
@@ -61,7 +53,7 @@ def set_context(
     is_chunked_prefill=False,
     prev_kv_ranges=None,
     chunk_offset=0,
-    offload_engine=None,
+    kvcache_manager=None,
     chunked_seq=None,
     decode_pos_in_block=0,
     decode_start_pos_in_block=0,
@@ -79,7 +71,7 @@ def set_context(
         is_chunked_prefill=is_chunked_prefill,
         prev_kv_ranges=prev_kv_ranges or [],
         chunk_offset=chunk_offset,
-        offload_engine=offload_engine,
+        kvcache_manager=kvcache_manager,
         chunked_seq=chunked_seq,
         decode_pos_in_block=decode_pos_in_block,
         decode_start_pos_in_block=decode_start_pos_in_block,
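For orientation, here is a minimal sketch of how the refactored fields might be driven by the model runner under the multi-chunk scheme: before each layer's forward, the previous K/V chunks for that layer are fetched through kvcache_manager and published on the shared context as prev_kv_chunks, then cleared afterwards. The manager method load_prev_kv_chunks, the model.layers loop, and the helper function below are illustrative assumptions, not APIs defined in this commit.

# Hypothetical illustration only: load_prev_kv_chunks and model.layers are assumed names.
from typing import Any, List, Tuple

import torch


def run_layers_with_offloaded_kv(model: Any, hidden: torch.Tensor,
                                 ctx: Any, seq: Any) -> torch.Tensor:
    # ctx is the Context populated via set_context(..., kvcache_manager=manager,
    # chunked_seq=seq) as in the hunks above.
    for layer_idx, layer in enumerate(model.layers):
        # Before each layer's forward, stage that layer's previous K/V chunks
        # (loaded from CPU by the kvcache manager) on the shared context,
        # mirroring the prev_kv_chunks comment in the dataclass.
        chunks: List[Tuple[torch.Tensor, torch.Tensor]] = (
            ctx.kvcache_manager.load_prev_kv_chunks(seq, layer_idx)  # assumed API
        )
        ctx.prev_kv_chunks = chunks
        hidden = layer(hidden)
        # Drop the references so the CPU->GPU staging buffers can be reused.
        ctx.prev_kv_chunks = []
    return hidden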