[WIP] Needs refactoring.

This commit is contained in:
Zijie Tian
2026-01-22 22:20:34 +08:00
parent 69b779e252
commit 5fb0f67295
11 changed files with 514 additions and 548 deletions

View File

@@ -14,9 +14,9 @@ class Context:
context_lens: torch.Tensor | None = None
block_tables: torch.Tensor | None = None
# Sparse prefill attention support (GPU-only path)
# When set, uses policy.sparse_prefill_attention() instead of FlashAttention
sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True
# Attention policy support (GPU-only path)
# When set, uses policy.compute_prefill() instead of FlashAttention
attention_policy: Any = None # AttentionPolicy instance
_CONTEXT = Context()
@@ -35,7 +35,7 @@ def set_context(
slot_mapping=None,
context_lens=None,
block_tables=None,
sparse_prefill_policy=None,
attention_policy=None,
):
global _CONTEXT
_CONTEXT = Context(
@@ -47,7 +47,7 @@ def set_context(
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
sparse_prefill_policy=sparse_prefill_policy,
attention_policy=attention_policy,
)