Files
nano-vllm/nanovllm/kvcache/sparse/__init__.py
2025-12-22 08:51:02 +08:00

91 lines
2.9 KiB
Python

"""
Sparse Attention Policy module.
Provides pluggable policies for selecting which KV blocks to load
during chunked attention with CPU offload.
Usage:
from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse import VerticalSlashPolicy, QuestPolicy
# Use built-in policy
policy = VerticalSlashPolicy(VerticalSlashConfig())
# Or create custom policy
class MyPolicy(SparsePolicy):
def select_blocks(self, available_blocks, ctx):
return available_blocks[:5] # Just first 5 blocks
"""
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
# Registry of built-in policies, keyed by name -> policy class.
# Only policies that get_sparse_policy() can build from a name alone are
# listed. "quest" is deliberately absent because get_sparse_policy() rejects it:
# QuestPolicy needs a BlockMetadataManager. get_sparse_policy() also reports
# these keys as the available names in its "Unknown sparse policy" error.
BUILTIN_SPARSE_POLICIES = dict(
    full=FullAttentionPolicy,
    vertical_slash=VerticalSlashPolicy,
    streaming_llm=StreamingLLMPolicy,
)
def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy:
    """
    Construct a built-in sparse attention policy by name.

    Args:
        policy_name: Policy name, matched case-insensitively. Buildable
            names are "full", "vertical_slash" and "streaming_llm".
            "quest" is recognized but always rejected (see Raises).
        **kwargs: Policy-specific settings. Only the keys below are read;
            any other key is silently ignored.
            - "vertical_slash": num_sink_blocks (default 1),
              local_window_blocks (default 2), threshold_blocks (default 4)
            - "streaming_llm": num_sink_blocks (default 1),
              num_recent_blocks (default 3)
            - "full": reads no settings

    Returns:
        A newly constructed SparsePolicy instance.

    Raises:
        ValueError: If the name is "quest", because QuestPolicy needs a
            BlockMetadataManager and must be built with
            QuestPolicy(config, metadata_manager) instead. Also raised if
            the name is not recognized; that message lists the keys of
            BUILTIN_SPARSE_POLICIES.
    """
    requested = policy_name.lower()

    # Quest is a known name but cannot be built here. Its metadata manager
    # cannot be supplied through **kwargs.
    if requested == "quest":
        raise ValueError(
            "Quest policy requires BlockMetadataManager. "
            "Use QuestPolicy(config, metadata_manager) directly."
        )

    def build_vertical_slash():
        vs_cfg = VerticalSlashConfig(
            num_sink_blocks=kwargs.get("num_sink_blocks", 1),
            local_window_blocks=kwargs.get("local_window_blocks", 2),
            threshold_blocks=kwargs.get("threshold_blocks", 4),
        )
        return VerticalSlashPolicy(vs_cfg)

    def build_streaming_llm():
        sl_cfg = StreamingLLMConfig(
            num_sink_blocks=kwargs.get("num_sink_blocks", 1),
            num_recent_blocks=kwargs.get("num_recent_blocks", 3),
        )
        return StreamingLLMPolicy(sl_cfg)

    # Dispatch table: each value is a zero-argument factory for that policy.
    factories = {
        "full": FullAttentionPolicy,
        "vertical_slash": build_vertical_slash,
        "streaming_llm": build_streaming_llm,
    }

    factory = factories.get(requested)
    if factory is None:
        raise ValueError(
            f"Unknown sparse policy '{requested}'. "
            f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())}"
        )
    return factory()
# Public API of the sparse-policy package. This list defines what
# `from nanovllm.kvcache.sparse import *` exports.
__all__ = [
    # Policy interface types (custom policies subclass SparsePolicy)
    "SparsePolicy",
    "PolicyContext",
    # Built-in policies and their configuration objects
    "FullAttentionPolicy",
    "VerticalSlashPolicy",
    "VerticalSlashConfig",
    # Quest policy; it must be constructed directly with a BlockMetadataManager
    "QuestPolicy",
    "QuestConfig",
    "BlockMetadataManager",
    "StreamingLLMPolicy",
    "StreamingLLMConfig",
    "HybridPolicy",
    # Name-based construction helpers
    "get_sparse_policy",
    "BUILTIN_SPARSE_POLICIES",
]