[WIP] Before refactor policies.

This commit is contained in:
Zijie Tian
2026-01-06 20:47:55 +08:00
parent 7cc8a394a5
commit 690492e074
6 changed files with 112 additions and 237 deletions

View File

@@ -5,86 +5,68 @@ Provides pluggable policies for selecting which KV blocks to load
during chunked attention with CPU offload.
Usage:
from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse import VerticalSlashPolicy, QuestPolicy
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
# Use built-in policy
policy = VerticalSlashPolicy(VerticalSlashConfig())
# Create policy using factory function
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
# Or create custom policy
class MyPolicy(SparsePolicy):
supports_prefill = True
supports_decode = True
def select_blocks(self, available_blocks, ctx):
return available_blocks[:5] # Just first 5 blocks
"""
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
# Built-in policy registry
BUILTIN_SPARSE_POLICIES = {
"full": FullAttentionPolicy,
"vertical_slash": VerticalSlashPolicy,
"streaming_llm": StreamingLLMPolicy,
}
def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy:
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
"""
Get a sparse attention policy instance by name.
Create a sparse policy instance from an enum type.
The returned policy is not yet initialized. Call policy.initialize()
or let the framework call it during KV cache allocation.
Args:
policy_name: Policy name ("full", "vertical_slash", "streaming_llm", "quest")
**kwargs: Policy-specific configuration
policy_type: SparsePolicyType enum value
**kwargs: Policy-specific configuration options
Returns:
SparsePolicy instance
"""
policy_name = policy_name.lower()
SparsePolicy instance (not initialized)
if policy_name == "full":
Example:
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
policy.initialize(num_layers=28, num_kv_heads=8, ...)
"""
if policy_type == SparsePolicyType.FULL:
return FullAttentionPolicy()
elif policy_name == "vertical_slash":
config = VerticalSlashConfig(
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
local_window_blocks=kwargs.get("local_window_blocks", 2),
elif policy_type == SparsePolicyType.QUEST:
config = QuestConfig(
topk_blocks=kwargs.get("topk_blocks", 8),
threshold_blocks=kwargs.get("threshold_blocks", 4),
include_sink_blocks=kwargs.get("include_sink_blocks", 0),
include_recent_blocks=kwargs.get("include_recent_blocks", 0),
)
return VerticalSlashPolicy(config)
elif policy_name == "streaming_llm":
config = StreamingLLMConfig(
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
num_recent_blocks=kwargs.get("num_recent_blocks", 3),
)
return StreamingLLMPolicy(config)
elif policy_name == "quest":
# Quest requires metadata_manager to be passed separately
raise ValueError(
"Quest policy requires BlockMetadataManager. "
"Use QuestPolicy(config, metadata_manager) directly."
)
return QuestPolicy(config)
else:
raise ValueError(
f"Unknown sparse policy '{policy_name}'. "
f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())}"
)
raise ValueError(f"Unknown policy type: {policy_type}")
__all__ = [
"SparsePolicy",
"PolicyContext",
"SparsePolicyType",
"FullAttentionPolicy",
"VerticalSlashPolicy",
"VerticalSlashConfig",
"QuestPolicy",
"QuestConfig",
"BlockMetadataManager",
"StreamingLLMPolicy",
"StreamingLLMConfig",
"HybridPolicy",
"get_sparse_policy",
"BUILTIN_SPARSE_POLICIES",
"create_sparse_policy",
]