[WIP] Before refactor policies.

This commit is contained in:
Zijie Tian
2026-01-06 20:47:55 +08:00
parent 7cc8a394a5
commit 690492e074
6 changed files with 112 additions and 237 deletions

View File

@@ -7,10 +7,17 @@ from CPU for each query chunk during chunked attention computation.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional, Any
import torch
class SparsePolicyType(Enum):
"""Built-in sparse attention policy types."""
FULL = auto() # prefill + decode
QUEST = auto() # decode only
@dataclass
class PolicyContext:
"""
@@ -54,8 +61,15 @@ class SparsePolicy(ABC):
sparse attention patterns. The policy receives context about
the current query chunk and returns which KV blocks to load.
Attributes:
supports_prefill: Whether this policy can be used for prefill phase.
supports_decode: Whether this policy can be used for decode phase.
Example:
class MySparsePolicy(SparsePolicy):
supports_prefill = False # decode-only policy
supports_decode = True
def select_blocks(self, available_blocks, ctx):
# Load first block and last 2 blocks
if len(available_blocks) <= 3:
@@ -63,6 +77,34 @@ class SparsePolicy(ABC):
return [available_blocks[0]] + available_blocks[-2:]
"""
# Compatibility flags - override in subclasses
supports_prefill: bool = True
supports_decode: bool = True
def initialize(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
num_cpu_blocks: int,
dtype: torch.dtype,
) -> None:
"""
Initialize policy resources.
Called by the framework after KV cache is allocated. Override this
to create metadata structures (e.g., BlockMetadataManager for Quest).
Default implementation does nothing.
Args:
num_layers: Number of transformer layers
num_kv_heads: Number of KV attention heads
head_dim: Dimension per head
num_cpu_blocks: Number of CPU blocks allocated
dtype: Data type for tensors
"""
pass
@abstractmethod
def select_blocks(
self,