[WIP] Needs refactoring.
This commit is contained in:
@@ -14,9 +14,9 @@ class Context:
|
||||
context_lens: torch.Tensor | None = None
|
||||
block_tables: torch.Tensor | None = None
|
||||
|
||||
# Sparse prefill attention support (GPU-only path)
|
||||
# When set, uses policy.sparse_prefill_attention() instead of FlashAttention
|
||||
sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True
|
||||
# Attention policy support (GPU-only path)
|
||||
# When set, uses policy.compute_prefill() instead of FlashAttention
|
||||
attention_policy: Any = None # AttentionPolicy instance
|
||||
|
||||
|
||||
_CONTEXT = Context()
|
||||
@@ -35,7 +35,7 @@ def set_context(
|
||||
slot_mapping=None,
|
||||
context_lens=None,
|
||||
block_tables=None,
|
||||
sparse_prefill_policy=None,
|
||||
attention_policy=None,
|
||||
):
|
||||
global _CONTEXT
|
||||
_CONTEXT = Context(
|
||||
@@ -47,7 +47,7 @@ def set_context(
|
||||
slot_mapping=slot_mapping,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
sparse_prefill_policy=sparse_prefill_policy,
|
||||
attention_policy=attention_policy,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user