[WIP] need refactor.

This commit is contained in:
Zijie Tian
2026-01-22 22:20:34 +08:00
parent 69b779e252
commit 5fb0f67295
11 changed files with 514 additions and 548 deletions

View File

@@ -11,7 +11,7 @@ import logging
import torch
from dataclasses import dataclass
from typing import List, Tuple, Optional
from .policy import SparsePolicy, PolicyContext
from .policy import AttentionPolicy, PolicyContext
logger = logging.getLogger(__name__)
@@ -137,7 +137,7 @@ class QuestConfig:
"""Always include this many recent blocks (last N blocks), in addition to Top-K."""
class QuestPolicy(SparsePolicy):
class QuestPolicy(AttentionPolicy):
"""
Quest-style Top-K block selection using min/max key bounds.
@@ -317,6 +317,25 @@ class QuestPolicy(SparsePolicy):
if self.metadata is not None:
self.metadata.reset()
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Quest does not support prefill - raises error.
Quest is a decode-only policy for selective block loading.
For prefill, use FullAttentionPolicy or XAttentionPolicy.
"""
raise NotImplementedError(
"QuestPolicy does not support prefill. "
"Use FullAttentionPolicy or XAttentionPolicy for prefill."
)
def __repr__(self) -> str:
return (
f"QuestPolicy(topk={self.config.topk_blocks}, "