[WIP] Before add Quest policy.
This commit is contained in:
@@ -3,11 +3,6 @@ import time
|
||||
from random import randint, seed
|
||||
from nanovllm import LLM, SamplingParams
|
||||
|
||||
# Import sparse policy classes
|
||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
|
||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||
|
||||
|
||||
def bench_decode(llm, num_seqs, input_len, output_len):
|
||||
"""Benchmark decode performance (original test)"""
|
||||
@@ -38,58 +33,6 @@ def bench_prefill(llm, num_seqs, input_len):
|
||||
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
||||
|
||||
|
||||
def setup_quest_policy(llm, topk_blocks=8, threshold_blocks=4):
|
||||
"""
|
||||
Setup Quest sparse policy for decode phase.
|
||||
|
||||
Uses HybridPolicy: Full attention for prefill, Quest Top-K for decode.
|
||||
"""
|
||||
import torch
|
||||
|
||||
kvcache_manager = llm.model_runner.kvcache_manager
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
|
||||
# Get model parameters from offload engine
|
||||
num_layers = offload_engine.num_layers
|
||||
num_kv_heads = offload_engine.num_kv_heads
|
||||
head_dim = offload_engine.head_dim
|
||||
num_cpu_blocks = kvcache_manager.num_cpu_blocks
|
||||
dtype = offload_engine.k_cache_cpu.dtype
|
||||
|
||||
print(f"Setting up Quest policy:")
|
||||
print(f" num_layers={num_layers}, num_kv_heads={num_kv_heads}, head_dim={head_dim}")
|
||||
print(f" num_cpu_blocks={num_cpu_blocks}, dtype={dtype}")
|
||||
print(f" topk_blocks={topk_blocks}, threshold_blocks={threshold_blocks}")
|
||||
|
||||
# Create BlockMetadataManager for storing min/max keys
|
||||
metadata = BlockMetadataManager(
|
||||
num_blocks=num_cpu_blocks,
|
||||
num_layers=num_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Create Quest policy for decode
|
||||
quest_config = QuestConfig(
|
||||
topk_blocks=topk_blocks,
|
||||
threshold_blocks=threshold_blocks,
|
||||
)
|
||||
quest_policy = QuestPolicy(quest_config, metadata)
|
||||
|
||||
# Create Hybrid policy: Full for prefill, Quest for decode
|
||||
hybrid_policy = HybridPolicy(
|
||||
prefill_policy=FullAttentionPolicy(),
|
||||
decode_policy=quest_policy,
|
||||
)
|
||||
|
||||
# Set the policy
|
||||
kvcache_manager.set_sparse_policy(hybrid_policy)
|
||||
print(f" Policy set: HybridPolicy(prefill=Full, decode=Quest)")
|
||||
|
||||
return hybrid_policy
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -101,7 +44,18 @@ def main():
|
||||
|
||||
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
||||
# Note: Qwen3-4B-Instruct-2507 max_position_embeddings = 262144
|
||||
max_len = 131072 # 128K tokens
|
||||
max_len = 32 * 1024 # 128K tokens
|
||||
|
||||
# Setup policy configuration
|
||||
if not args.no_sparse:
|
||||
prefill_policy = "full" # Full attention for prefill
|
||||
decode_policy = "quest" # Quest Top-K for decode
|
||||
print(f"\n[Quest Sparse Attention] prefill={prefill_policy}, decode={decode_policy}, topk={args.topk}")
|
||||
else:
|
||||
prefill_policy = "full" # Full attention for both phases
|
||||
decode_policy = "full"
|
||||
print("\n[Full Attention] No sparse policy (baseline)")
|
||||
|
||||
llm = LLM(
|
||||
path,
|
||||
enforce_eager=False,
|
||||
@@ -109,15 +63,12 @@ def main():
|
||||
max_num_batched_tokens=max_len,
|
||||
enable_cpu_offload=True,
|
||||
num_gpu_blocks=6, # Small GPU buffer for offload testing
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
sparse_topk_blocks=args.topk,
|
||||
sparse_threshold_blocks=4,
|
||||
)
|
||||
|
||||
if not args.no_sparse:
|
||||
# Setup Quest policy for decode (Top-K blocks, apply when > 4 blocks)
|
||||
setup_quest_policy(llm, topk_blocks=args.topk, threshold_blocks=4)
|
||||
print(f"\n[Quest Sparse Attention] topk={args.topk}")
|
||||
else:
|
||||
print("\n[Full Attention] No sparse policy (baseline)")
|
||||
|
||||
# Warmup
|
||||
llm.generate(["Benchmark: "], SamplingParams())
|
||||
|
||||
|
||||
@@ -29,8 +29,9 @@ class Config:
|
||||
num_gpu_kvcache_blocks: int = -1
|
||||
num_cpu_kvcache_blocks: int = -1
|
||||
|
||||
# Sparse attention configuration
|
||||
sparse_policy: str | None = None # "vertical_slash", "quest", "streaming_llm", or None
|
||||
# Sparse attention configuration (dual policy architecture)
|
||||
prefill_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm"
|
||||
decode_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm"
|
||||
sparse_num_sink_blocks: int = 1 # Number of sink blocks for sparse patterns
|
||||
sparse_local_window_blocks: int = 2 # Local window size for VerticalSlash
|
||||
sparse_topk_blocks: int = 8 # Top-K blocks for Quest
|
||||
|
||||
@@ -156,6 +156,25 @@ class ModelRunner:
|
||||
dtype=hf_config.torch_dtype,
|
||||
)
|
||||
|
||||
# Initialize sparse policies if manager has them (CPU offload mode)
|
||||
if hasattr(self.kvcache_manager, 'prefill_policy') and hasattr(self.kvcache_manager, 'decode_policy'):
|
||||
# Initialize both policies with model config
|
||||
for policy in [self.kvcache_manager.prefill_policy, self.kvcache_manager.decode_policy]:
|
||||
if policy is not None:
|
||||
policy.initialize(
|
||||
num_layers=hf_config.num_hidden_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
num_cpu_blocks=config.num_cpu_kvcache_blocks,
|
||||
dtype=hf_config.torch_dtype,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Sparse policies initialized: prefill={config.prefill_policy}, decode={config.decode_policy} "
|
||||
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
|
||||
)
|
||||
|
||||
# Log KV cache allocation info with detailed per-token breakdown
|
||||
gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
|
||||
cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
|
||||
|
||||
@@ -56,14 +56,36 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
||||
# Need CPU offload: use hybrid manager
|
||||
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
|
||||
from nanovllm.kvcache.policies import get_policy
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
|
||||
|
||||
policy = get_policy(getattr(config, 'offload_policy', 'lru'))
|
||||
eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru'))
|
||||
|
||||
# Create sparse policies from config
|
||||
prefill_policy_type = getattr(config, 'prefill_policy', 'full')
|
||||
decode_policy_type = getattr(config, 'decode_policy', 'full')
|
||||
|
||||
def create_policy(policy_type_str):
|
||||
"""Create a sparse policy from config string."""
|
||||
if policy_type_str.lower() == 'full':
|
||||
return create_sparse_policy(SparsePolicyType.FULL)
|
||||
policy_type = SparsePolicyType[policy_type_str.upper()]
|
||||
return create_sparse_policy(
|
||||
policy_type,
|
||||
topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
|
||||
threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
|
||||
include_sink_blocks=getattr(config, 'sparse_num_sink_blocks', 1),
|
||||
)
|
||||
|
||||
prefill_policy = create_policy(prefill_policy_type)
|
||||
decode_policy = create_policy(decode_policy_type)
|
||||
|
||||
return HybridKVCacheManager(
|
||||
num_gpu_slots=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=config.kvcache_block_size,
|
||||
policy=policy,
|
||||
policy=eviction_policy,
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -90,6 +90,8 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_cpu_blocks: int,
|
||||
block_size: int,
|
||||
policy: Optional[EvictionPolicy] = None,
|
||||
prefill_policy: "SparsePolicy" = None,
|
||||
decode_policy: "SparsePolicy" = None,
|
||||
):
|
||||
"""
|
||||
Initialize hybrid manager with CPU-primary ring buffer design.
|
||||
@@ -102,6 +104,8 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_cpu_blocks: Number of CPU pool blocks (primary storage)
|
||||
block_size: Tokens per block
|
||||
policy: Eviction policy (default: LRU, used for prefix cache management)
|
||||
prefill_policy: Sparse attention policy for prefill phase
|
||||
decode_policy: Sparse attention policy for decode phase
|
||||
"""
|
||||
self._block_size = block_size
|
||||
self.num_gpu_slots = num_gpu_slots
|
||||
@@ -113,6 +117,10 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
# Eviction policy
|
||||
self.policy = policy or LRUPolicy()
|
||||
|
||||
# Sparse attention policies (set at construction time, immutable)
|
||||
self.prefill_policy = prefill_policy
|
||||
self.decode_policy = decode_policy
|
||||
|
||||
# Logical blocks (what sequences reference) - one per CPU block
|
||||
self.logical_blocks: List[LogicalBlock] = [
|
||||
LogicalBlock(i) for i in range(self.total_blocks)
|
||||
@@ -153,9 +161,6 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
# Key: sequence id, Value: number of tokens from prefill (before decode started)
|
||||
self._prefill_len: Dict[int, int] = {}
|
||||
|
||||
# Sparse attention policy (optional)
|
||||
self.sparse_policy: Optional["SparsePolicy"] = None
|
||||
|
||||
@property
|
||||
def block_size(self) -> int:
|
||||
return self._block_size
|
||||
@@ -180,6 +185,8 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
dtype=dtype,
|
||||
prefill_policy=self.prefill_policy,
|
||||
decode_policy=self.decode_policy,
|
||||
)
|
||||
|
||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
@@ -187,23 +194,17 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
assert self.offload_engine is not None
|
||||
return self.offload_engine.get_layer_cache(layer_id)
|
||||
|
||||
def set_sparse_policy(self, policy: "SparsePolicy") -> None:
|
||||
def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]:
|
||||
"""
|
||||
Set sparse attention policy for block selection.
|
||||
|
||||
The sparse policy determines which KV blocks to load from CPU
|
||||
for each query chunk during chunked attention computation.
|
||||
Get sparse policy for the specified phase.
|
||||
|
||||
Args:
|
||||
policy: SparsePolicy instance (e.g., VerticalSlashPolicy, QuestPolicy)
|
||||
is_prefill: True for prefill phase, False for decode phase
|
||||
|
||||
Example:
|
||||
from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig
|
||||
policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=2))
|
||||
manager.set_sparse_policy(policy)
|
||||
Returns:
|
||||
SparsePolicy for the phase, or None if not set
|
||||
"""
|
||||
self.sparse_policy = policy
|
||||
logger.info(f"Sparse attention policy set: {policy}")
|
||||
return self.prefill_policy if is_prefill else self.decode_policy
|
||||
|
||||
def can_allocate(self, seq: Sequence) -> bool:
|
||||
"""Check if we can allocate blocks for a new sequence."""
|
||||
|
||||
@@ -17,6 +17,11 @@ from nanovllm.kvcache.kernels import gathered_copy_kv
|
||||
from nanovllm.comm import memcpy_2d_async
|
||||
from nanovllm.utils.logger import get_logger
|
||||
|
||||
# Import for type hints only (avoid circular import)
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from nanovllm.kvcache.sparse import SparsePolicy
|
||||
|
||||
logger = get_logger("offload_engine")
|
||||
|
||||
|
||||
@@ -55,6 +60,8 @@ class OffloadEngine:
|
||||
head_dim: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_streams: int = 4,
|
||||
prefill_policy: "SparsePolicy" = None,
|
||||
decode_policy: "SparsePolicy" = None,
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
@@ -210,6 +217,10 @@ class OffloadEngine:
|
||||
self._debug_mode = False
|
||||
self._debug_hooks: List = [] # External hooks for debug events
|
||||
|
||||
# ========== Sparse attention policies (set at construction time) ==========
|
||||
self.prefill_policy = prefill_policy
|
||||
self.decode_policy = decode_policy
|
||||
|
||||
def _get_next_stream(self) -> torch.cuda.Stream:
|
||||
"""Round-robin stream selection for parallel transfers."""
|
||||
stream = self.transfer_streams[self._stream_idx]
|
||||
@@ -730,7 +741,14 @@ class OffloadEngine:
|
||||
"""Wait for slot offload to complete."""
|
||||
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
|
||||
|
||||
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
||||
def offload_slot_layer_to_cpu(
|
||||
self,
|
||||
slot_idx: int,
|
||||
layer_id: int,
|
||||
cpu_block_id: int,
|
||||
num_valid_tokens: int = -1,
|
||||
is_prefill: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Async offload a ring buffer slot to CPU for one layer.
|
||||
|
||||
@@ -741,9 +759,27 @@ class OffloadEngine:
|
||||
slot_idx: Source GPU slot index
|
||||
layer_id: Target layer in CPU cache
|
||||
cpu_block_id: Target CPU block ID
|
||||
num_valid_tokens: Number of valid tokens in this block (-1 = use block_size)
|
||||
is_prefill: True if in prefill phase, False if in decode phase
|
||||
"""
|
||||
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
|
||||
|
||||
# Collect metadata BEFORE offload (while k_cache is still on GPU)
|
||||
# Both policies' callbacks are called - each decides whether to respond
|
||||
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
|
||||
k_cache = self.k_cache_gpu[slot_idx]
|
||||
|
||||
if is_prefill:
|
||||
if self.prefill_policy is not None:
|
||||
self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
if self.decode_policy is not None:
|
||||
self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
else:
|
||||
if self.prefill_policy is not None:
|
||||
self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
if self.decode_policy is not None:
|
||||
self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
|
||||
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
# Wait for both compute_stream and default stream
|
||||
|
||||
@@ -22,7 +22,6 @@ Usage:
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
|
||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
|
||||
|
||||
|
||||
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
||||
@@ -67,6 +66,5 @@ __all__ = [
|
||||
"QuestPolicy",
|
||||
"QuestConfig",
|
||||
"BlockMetadataManager",
|
||||
"HybridPolicy",
|
||||
"create_sparse_policy",
|
||||
]
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
"""
|
||||
Hybrid sparse attention policy.
|
||||
|
||||
Allows using different policies for prefill vs decode phases.
|
||||
This is useful because optimal sparsity patterns often differ:
|
||||
- Prefill: fixed patterns work well (e.g., VerticalSlash)
|
||||
- Decode: query-aware selection helps (e.g., Quest)
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import torch
|
||||
from .policy import SparsePolicy, PolicyContext
|
||||
|
||||
|
||||
class HybridPolicy(SparsePolicy):
|
||||
"""
|
||||
Hybrid policy that uses different policies for prefill and decode.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
from nanovllm.kvcache.sparse import (
|
||||
HybridPolicy, VerticalSlashPolicy, QuestPolicy,
|
||||
VerticalSlashConfig, QuestConfig, BlockMetadataManager
|
||||
)
|
||||
|
||||
# Prefill: use fast fixed pattern
|
||||
prefill_policy = VerticalSlashPolicy(VerticalSlashConfig(
|
||||
num_sink_blocks=1,
|
||||
local_window_blocks=3,
|
||||
))
|
||||
|
||||
# Decode: use query-aware selection
|
||||
metadata = BlockMetadataManager(num_blocks, num_layers, num_heads, head_dim)
|
||||
decode_policy = QuestPolicy(QuestConfig(topk_blocks=8), metadata)
|
||||
|
||||
# Combine
|
||||
policy = HybridPolicy(prefill_policy, decode_policy)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prefill_policy: SparsePolicy,
|
||||
decode_policy: SparsePolicy,
|
||||
):
|
||||
"""
|
||||
Initialize hybrid policy.
|
||||
|
||||
Args:
|
||||
prefill_policy: Policy to use during prefill phase
|
||||
decode_policy: Policy to use during decode phase
|
||||
"""
|
||||
self.prefill_policy = prefill_policy
|
||||
self.decode_policy = decode_policy
|
||||
|
||||
def select_blocks(
|
||||
self,
|
||||
available_blocks: List[int],
|
||||
ctx: PolicyContext,
|
||||
) -> List[int]:
|
||||
"""Delegate to appropriate policy based on phase."""
|
||||
if ctx.is_prefill:
|
||||
return self.prefill_policy.select_blocks(available_blocks, ctx)
|
||||
else:
|
||||
return self.decode_policy.select_blocks(available_blocks, ctx)
|
||||
|
||||
def on_block_offloaded(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
layer_id: int,
|
||||
k_cache: torch.Tensor,
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""Forward to both policies (both may need metadata updates)."""
|
||||
self.prefill_policy.on_block_offloaded(
|
||||
cpu_block_id, layer_id, k_cache, num_valid_tokens
|
||||
)
|
||||
self.decode_policy.on_block_offloaded(
|
||||
cpu_block_id, layer_id, k_cache, num_valid_tokens
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset both policies."""
|
||||
self.prefill_policy.reset()
|
||||
self.decode_policy.reset()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"HybridPolicy(\n"
|
||||
f" prefill={self.prefill_policy},\n"
|
||||
f" decode={self.decode_policy}\n"
|
||||
f")"
|
||||
)
|
||||
@@ -134,7 +134,7 @@ class SparsePolicy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_block_offloaded(
|
||||
def on_prefill_offload(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
layer_id: int,
|
||||
@@ -142,15 +142,38 @@ class SparsePolicy(ABC):
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Hook called when a block is offloaded from GPU to CPU.
|
||||
Hook called when a block is offloaded during prefill phase.
|
||||
|
||||
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
|
||||
Override this to collect metadata about blocks (e.g., min/max keys
|
||||
for Quest-style selection). Default implementation does nothing.
|
||||
|
||||
Args:
|
||||
cpu_block_id: The CPU block ID that was written
|
||||
cpu_block_id: The CPU block ID that will be written
|
||||
layer_id: Transformer layer index
|
||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
|
||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
|
||||
num_valid_tokens: Number of valid tokens in this block
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_decode_offload(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
layer_id: int,
|
||||
k_cache: torch.Tensor,
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Hook called when a block is offloaded during decode phase.
|
||||
|
||||
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
|
||||
Override this to update metadata about blocks. Default implementation
|
||||
does nothing.
|
||||
|
||||
Args:
|
||||
cpu_block_id: The CPU block ID that will be written
|
||||
layer_id: Transformer layer index
|
||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
|
||||
num_valid_tokens: Number of valid tokens in this block
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -289,14 +289,25 @@ class QuestPolicy(SparsePolicy):
|
||||
|
||||
return result
|
||||
|
||||
def on_block_offloaded(
|
||||
def on_prefill_offload(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
layer_id: int,
|
||||
k_cache: torch.Tensor,
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""Update min/max key metadata when block is offloaded."""
|
||||
"""Update min/max key metadata during prefill offload."""
|
||||
if self.metadata is not None:
|
||||
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||
|
||||
def on_decode_offload(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
layer_id: int,
|
||||
k_cache: torch.Tensor,
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
"""Update min/max key metadata during decode offload (for new blocks)."""
|
||||
if self.metadata is not None:
|
||||
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||
|
||||
|
||||
@@ -189,7 +189,8 @@ class Attention(nn.Module):
|
||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
|
||||
# Apply sparse policy if enabled
|
||||
if cpu_block_table and kvcache_manager.sparse_policy is not None:
|
||||
prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True)
|
||||
if cpu_block_table and prefill_policy is not None:
|
||||
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=current_chunk_idx,
|
||||
@@ -200,7 +201,7 @@ class Attention(nn.Module):
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
|
||||
cpu_block_table = prefill_policy.select_blocks(
|
||||
cpu_block_table, policy_ctx
|
||||
)
|
||||
|
||||
@@ -279,7 +280,11 @@ class Attention(nn.Module):
|
||||
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
||||
if current_chunk_idx < len(cpu_block_ids):
|
||||
cpu_block_id = cpu_block_ids[current_chunk_idx]
|
||||
offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
|
||||
# k.shape[0] = number of tokens in current chunk
|
||||
num_valid_tokens = k.shape[0]
|
||||
offload_engine.offload_slot_layer_to_cpu(
|
||||
write_slot, self.layer_id, cpu_block_id, num_valid_tokens
|
||||
)
|
||||
|
||||
# CRITICAL: compute_stream must wait for offload to complete
|
||||
# before the next layer's store_kvcache can overwrite the GPU slot.
|
||||
@@ -508,7 +513,8 @@ class Attention(nn.Module):
|
||||
last_block_valid_tokens = block_size # Last block was exactly full
|
||||
|
||||
# Apply sparse policy if enabled
|
||||
if kvcache_manager.sparse_policy is not None:
|
||||
decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False)
|
||||
if decode_policy is not None:
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=0,
|
||||
num_query_chunks=1,
|
||||
@@ -518,7 +524,7 @@ class Attention(nn.Module):
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
|
||||
cpu_block_table = decode_policy.select_blocks(
|
||||
cpu_block_table, policy_ctx
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user