[refactor] Refactor offload code to a multi-chunk design.

Author: Zijie Tian
Date:   2025-12-15 01:13:58 +08:00
Parent: 5949537faf
Commit: 1081ab51ea

7 changed files with 36 additions and 233 deletions


@@ -21,7 +21,7 @@ class Context:
     # Current chunk's position offset (for causal mask)
     chunk_offset: int = 0
     # Reference to kvcache manager for loading previous KV (HybridKVCacheManager)
-    offload_engine: Any = None
+    kvcache_manager: Any = None
     # Current layer's previous K/V chunks (loaded from CPU)
     # Set by model_runner before each layer's forward
     prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
@@ -33,14 +33,6 @@ class Context:
     # Used when batching decode offloads - we need to attend to all accumulated tokens
     decode_start_pos_in_block: int = 0
-    # ========== Per-layer chunked attention state ==========
-    # Whether chunked decode/prefill is currently active (for hooks to check)
-    chunked_decode_active: bool = False
-    # CPU block IDs for the current chunk being processed
-    chunked_decode_chunk_ids: List[int] = field(default_factory=list)
-    # Current chunk index being processed
-    chunked_decode_current_chunk: int = 0
 _CONTEXT = Context()
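
For orientation, here is a minimal sketch of what the Context dataclass in this file looks like after the two hunks above are applied. It is reconstructed only from the context lines shown here: fields that do not appear in this diff are omitted, and get_context() is an assumed accessor that mirrors set_context below.

```python
# Sketch reconstructed from this commit's hunks; not the full file.
# Fields outside the diff are omitted; get_context() is assumed.
from dataclasses import dataclass, field
from typing import Any, List, Tuple

import torch


@dataclass
class Context:
    # Current chunk's position offset (for causal mask)
    chunk_offset: int = 0
    # Reference to kvcache manager for loading previous KV (HybridKVCacheManager);
    # renamed from offload_engine in this commit
    kvcache_manager: Any = None
    # Current layer's previous K/V chunks (loaded from CPU),
    # set by model_runner before each layer's forward
    prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
    # Decode-offload bookkeeping retained by the refactor
    decode_pos_in_block: int = 0
    decode_start_pos_in_block: int = 0
    # The chunked_decode_* fields deleted above no longer exist.


_CONTEXT = Context()


def get_context() -> Context:  # assumed accessor alongside set_context
    return _CONTEXT
```
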
@@ -61,7 +53,7 @@ def set_context(
     is_chunked_prefill=False,
     prev_kv_ranges=None,
     chunk_offset=0,
-    offload_engine=None,
+    kvcache_manager=None,
     chunked_seq=None,
     decode_pos_in_block=0,
     decode_start_pos_in_block=0,
@@ -79,7 +71,7 @@ def set_context(
         is_chunked_prefill=is_chunked_prefill,
         prev_kv_ranges=prev_kv_ranges or [],
         chunk_offset=chunk_offset,
-        offload_engine=offload_engine,
+        kvcache_manager=kvcache_manager,
         chunked_seq=chunked_seq,
         decode_pos_in_block=decode_pos_in_block,
         decode_start_pos_in_block=decode_start_pos_in_block,
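
As a usage note continuing the sketch above, the following is a hedged illustration of how the renamed keyword travels from a model runner into the context and down to a layer. The trimmed set_context below forwards only two keywords (the real one, per the hunks above, forwards every argument into Context), and StubKVCacheManager with its load_layer_chunks method is a hypothetical stand-in for HybridKVCacheManager, whose actual interface this commit does not show.

```python
# Continues the Context/get_context sketch above.
# set_context is trimmed; StubKVCacheManager / load_layer_chunks are
# hypothetical stand-ins, not the real HybridKVCacheManager API.
import torch


def set_context(chunk_offset=0, kvcache_manager=None, **other_fields):
    # The real set_context forwards every keyword into Context (see hunks above).
    global _CONTEXT
    _CONTEXT = Context(chunk_offset=chunk_offset, kvcache_manager=kvcache_manager)


class StubKVCacheManager:
    """Illustrative stand-in for a manager holding CPU-offloaded K/V chunks."""

    def load_layer_chunks(self, layer_id):
        # Pretend two earlier K/V chunks were offloaded for this layer.
        return [(torch.zeros(1, 4, 8), torch.zeros(1, 4, 8)) for _ in range(2)]


# Model-runner side: register the manager under the renamed keyword.
set_context(chunk_offset=16, kvcache_manager=StubKVCacheManager())

# Before each layer's forward, the runner fills prev_kv_chunks from the manager;
# attention layers then read only the context, never an offload-engine handle.
ctx = get_context()
ctx.prev_kv_chunks = ctx.kvcache_manager.load_layer_chunks(layer_id=0)
for prev_k, prev_v in ctx.prev_kv_chunks:
    assert prev_k.shape == prev_v.shape
```
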