[WIP] Before refactor policies.
This commit is contained in:
@@ -5,86 +5,68 @@ Provides pluggable policies for selecting which KV blocks to load
|
||||
during chunked attention with CPU offload.
|
||||
|
||||
Usage:
|
||||
from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse import VerticalSlashPolicy, QuestPolicy
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
|
||||
|
||||
# Use built-in policy
|
||||
policy = VerticalSlashPolicy(VerticalSlashConfig())
|
||||
# Create policy using factory function
|
||||
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
|
||||
|
||||
# Or create custom policy
|
||||
class MyPolicy(SparsePolicy):
|
||||
supports_prefill = True
|
||||
supports_decode = True
|
||||
|
||||
def select_blocks(self, available_blocks, ctx):
|
||||
return available_blocks[:5] # Just first 5 blocks
|
||||
"""
|
||||
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
|
||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
|
||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
|
||||
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
|
||||
|
||||
# Built-in policy registry
|
||||
BUILTIN_SPARSE_POLICIES = {
|
||||
"full": FullAttentionPolicy,
|
||||
"vertical_slash": VerticalSlashPolicy,
|
||||
"streaming_llm": StreamingLLMPolicy,
|
||||
}
|
||||
|
||||
|
||||
def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy:
|
||||
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
||||
"""
|
||||
Get a sparse attention policy instance by name.
|
||||
Create a sparse policy instance from an enum type.
|
||||
|
||||
The returned policy is not yet initialized. Call policy.initialize()
|
||||
or let the framework call it during KV cache allocation.
|
||||
|
||||
Args:
|
||||
policy_name: Policy name ("full", "vertical_slash", "streaming_llm", "quest")
|
||||
**kwargs: Policy-specific configuration
|
||||
policy_type: SparsePolicyType enum value
|
||||
**kwargs: Policy-specific configuration options
|
||||
|
||||
Returns:
|
||||
SparsePolicy instance
|
||||
"""
|
||||
policy_name = policy_name.lower()
|
||||
SparsePolicy instance (not initialized)
|
||||
|
||||
if policy_name == "full":
|
||||
Example:
|
||||
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
|
||||
policy.initialize(num_layers=28, num_kv_heads=8, ...)
|
||||
"""
|
||||
if policy_type == SparsePolicyType.FULL:
|
||||
return FullAttentionPolicy()
|
||||
elif policy_name == "vertical_slash":
|
||||
config = VerticalSlashConfig(
|
||||
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
|
||||
local_window_blocks=kwargs.get("local_window_blocks", 2),
|
||||
|
||||
elif policy_type == SparsePolicyType.QUEST:
|
||||
config = QuestConfig(
|
||||
topk_blocks=kwargs.get("topk_blocks", 8),
|
||||
threshold_blocks=kwargs.get("threshold_blocks", 4),
|
||||
include_sink_blocks=kwargs.get("include_sink_blocks", 0),
|
||||
include_recent_blocks=kwargs.get("include_recent_blocks", 0),
|
||||
)
|
||||
return VerticalSlashPolicy(config)
|
||||
elif policy_name == "streaming_llm":
|
||||
config = StreamingLLMConfig(
|
||||
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
|
||||
num_recent_blocks=kwargs.get("num_recent_blocks", 3),
|
||||
)
|
||||
return StreamingLLMPolicy(config)
|
||||
elif policy_name == "quest":
|
||||
# Quest requires metadata_manager to be passed separately
|
||||
raise ValueError(
|
||||
"Quest policy requires BlockMetadataManager. "
|
||||
"Use QuestPolicy(config, metadata_manager) directly."
|
||||
)
|
||||
return QuestPolicy(config)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown sparse policy '{policy_name}'. "
|
||||
f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())}"
|
||||
)
|
||||
raise ValueError(f"Unknown policy type: {policy_type}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SparsePolicy",
|
||||
"PolicyContext",
|
||||
"SparsePolicyType",
|
||||
"FullAttentionPolicy",
|
||||
"VerticalSlashPolicy",
|
||||
"VerticalSlashConfig",
|
||||
"QuestPolicy",
|
||||
"QuestConfig",
|
||||
"BlockMetadataManager",
|
||||
"StreamingLLMPolicy",
|
||||
"StreamingLLMConfig",
|
||||
"HybridPolicy",
|
||||
"get_sparse_policy",
|
||||
"BUILTIN_SPARSE_POLICIES",
|
||||
"create_sparse_policy",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user