[WIP] Snapshot before refactoring the sparse attention policies.
This commit is contained in:
@@ -7,10 +7,17 @@ from CPU for each query chunk during chunked attention computation.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import List, Optional, Any
|
||||
import torch
|
||||
|
||||
|
||||
class SparsePolicyType(Enum):
    """Enumerates the built-in sparse attention policy variants."""

    # Dense attention over all KV blocks; usable in both prefill and decode.
    FULL = auto()
    # Quest-style selective block loading; usable in decode only.
    QUEST = auto()
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyContext:
|
||||
"""
|
||||
@@ -54,8 +61,15 @@ class SparsePolicy(ABC):
|
||||
sparse attention patterns. The policy receives context about
|
||||
the current query chunk and returns which KV blocks to load.
|
||||
|
||||
Attributes:
|
||||
supports_prefill: Whether this policy can be used for prefill phase.
|
||||
supports_decode: Whether this policy can be used for decode phase.
|
||||
|
||||
Example:
|
||||
class MySparsePolicy(SparsePolicy):
|
||||
supports_prefill = False # decode-only policy
|
||||
supports_decode = True
|
||||
|
||||
def select_blocks(self, available_blocks, ctx):
|
||||
# Load first block and last 2 blocks
|
||||
if len(available_blocks) <= 3:
|
||||
@@ -63,6 +77,34 @@ class SparsePolicy(ABC):
|
||||
return [available_blocks[0]] + available_blocks[-2:]
|
||||
"""
|
||||
|
||||
# Compatibility flags - override in subclasses
|
||||
supports_prefill: bool = True
|
||||
supports_decode: bool = True
|
||||
|
||||
def initialize(
    self,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    num_cpu_blocks: int,
    dtype: torch.dtype,
) -> None:
    """Set up per-policy resources; the base implementation is a no-op.

    The framework invokes this hook once, after the KV cache has been
    allocated. Subclasses override it to build auxiliary metadata
    structures (e.g. a BlockMetadataManager for Quest); by default
    nothing happens.

    Args:
        num_layers: Number of transformer layers.
        num_kv_heads: Number of KV attention heads.
        head_dim: Dimension per attention head.
        num_cpu_blocks: Number of CPU blocks allocated.
        dtype: Data type for tensors.
    """
|
||||
|
||||
@abstractmethod
|
||||
def select_blocks(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user