[claudesquad] update from 'lw-offload-2' on 08 Jan 26 20:53 CST
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
|
||||
@@ -14,27 +14,6 @@ class Context:
|
||||
context_lens: torch.Tensor | None = None
|
||||
block_tables: torch.Tensor | None = None
|
||||
|
||||
# Chunked prefill support
|
||||
is_chunked_prefill: bool = False
|
||||
# Previous KV chunks info: List of (start_pos, end_pos) for blocks on CPU
|
||||
prev_kv_ranges: List[Tuple[int, int]] = field(default_factory=list)
|
||||
# Current chunk's position offset (for causal mask)
|
||||
chunk_offset: int = 0
|
||||
# Reference to kvcache manager for loading previous KV (HybridKVCacheManager)
|
||||
kvcache_manager: Any = None
|
||||
# Current layer's previous K/V chunks (loaded from CPU)
|
||||
# Set by model_runner before each layer's forward
|
||||
prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
|
||||
# Current sequence being processed (for chunked prefill to load KV)
|
||||
chunked_seq: Any = None
|
||||
# Position within block for decode (used for reading from Decode region)
|
||||
decode_pos_in_block: int = 0
|
||||
# Starting position within block where decode tokens began (for accumulated token tracking)
|
||||
# Used when batching decode offloads - we need to attend to all accumulated tokens
|
||||
decode_start_pos_in_block: int = 0
|
||||
# Current chunk index for ring buffer pipeline (prefill only)
|
||||
current_chunk_idx: int = 0
|
||||
|
||||
# Sparse prefill attention support (GPU-only path)
|
||||
# When set, uses policy.sparse_prefill_attention() instead of FlashAttention
|
||||
sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True
|
||||
@@ -56,14 +35,6 @@ def set_context(
|
||||
slot_mapping=None,
|
||||
context_lens=None,
|
||||
block_tables=None,
|
||||
is_chunked_prefill=False,
|
||||
prev_kv_ranges=None,
|
||||
chunk_offset=0,
|
||||
kvcache_manager=None,
|
||||
chunked_seq=None,
|
||||
decode_pos_in_block=0,
|
||||
decode_start_pos_in_block=0,
|
||||
current_chunk_idx=0,
|
||||
sparse_prefill_policy=None,
|
||||
):
|
||||
global _CONTEXT
|
||||
@@ -76,14 +47,6 @@ def set_context(
|
||||
slot_mapping=slot_mapping,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
is_chunked_prefill=is_chunked_prefill,
|
||||
prev_kv_ranges=prev_kv_ranges or [],
|
||||
chunk_offset=chunk_offset,
|
||||
kvcache_manager=kvcache_manager,
|
||||
chunked_seq=chunked_seq,
|
||||
decode_pos_in_block=decode_pos_in_block,
|
||||
decode_start_pos_in_block=decode_start_pos_in_block,
|
||||
current_chunk_idx=current_chunk_idx,
|
||||
sparse_prefill_policy=sparse_prefill_policy,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user