[WIP] need refactor.

This commit is contained in:
Zijie Tian
2026-01-22 22:20:34 +08:00
parent 69b779e252
commit 5fb0f67295
11 changed files with 514 additions and 548 deletions

View File

@@ -32,11 +32,14 @@ def run_needle_test(
enable_cpu_offload: bool = False,
enable_quest: bool = False,
enable_minference: bool = False,
enable_xattn: bool = False,
sparse_topk: int = 8,
sparse_threshold: int = 4,
minference_budget: float = 0.3,
minference_vertical: int = 1000,
minference_slash: int = 6096,
xattn_threshold: float = 0.9,
xattn_use_bsa: bool = True,
gpu_utilization: float = 0.9,
enforce_eager: bool = True,
verbose: bool = True,
@@ -56,11 +59,14 @@ def run_needle_test(
enable_cpu_offload: Enable CPU offload mode
enable_quest: Enable Quest sparse attention (decode-only Top-K)
enable_minference: Enable MInference sparse prefill (GPU-only)
enable_xattn: Enable XAttention sparse prefill with BSA
sparse_topk: Top-K blocks for Quest
sparse_threshold: Apply sparse only when blocks > threshold
minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode)
minference_vertical: Fixed vertical_size (only used when budget=None)
minference_slash: Fixed slash_size (only used when budget=None)
xattn_threshold: XAttention block selection threshold (0-1)
xattn_use_bsa: Use Block Sparse Attention library
gpu_utilization: GPU memory utilization fraction
verbose: Print detailed output
@@ -68,7 +74,9 @@ def run_needle_test(
True if test passed, False otherwise
"""
# Determine sparse policy
if enable_minference:
if enable_xattn:
sparse_policy = SparsePolicyType.XATTN
elif enable_minference:
sparse_policy = SparsePolicyType.MINFERENCE
elif enable_quest:
sparse_policy = SparsePolicyType.QUEST
@@ -94,6 +102,8 @@ def run_needle_test(
print(f" MInference: adaptive (budget={minference_budget})")
else:
print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})")
if enable_xattn:
print(f" XAttention: threshold={xattn_threshold}, use_bsa={xattn_use_bsa}")
print(f"{'='*60}\n")
# 1. Initialize LLM
@@ -111,7 +121,7 @@ def run_needle_test(
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
# Set sparse policy (can be used with or without offload)
if enable_minference or enable_quest:
if enable_minference or enable_quest or enable_xattn:
llm_kwargs["sparse_policy"] = sparse_policy
# MInference params (works with both GPU-only and offload mode)
@@ -120,6 +130,11 @@ def run_needle_test(
llm_kwargs["minference_vertical_size"] = minference_vertical
llm_kwargs["minference_slash_size"] = minference_slash
# XAttention params
if enable_xattn:
llm_kwargs["xattn_threshold"] = xattn_threshold
llm_kwargs["xattn_use_bsa"] = xattn_use_bsa
llm = LLM(model_path, **llm_kwargs)
# 2. Generate needle prompt
@@ -224,6 +239,11 @@ if __name__ == "__main__":
action="store_true",
help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)"
)
parser.add_argument(
"--enable-xattn",
action="store_true",
help="Enable XAttention sparse prefill with Block Sparse Attention"
)
parser.add_argument(
"--sparse-topk",
type=int,
@@ -254,6 +274,17 @@ if __name__ == "__main__":
default=6096,
help="Fixed slash_size (only used when budget=0)"
)
parser.add_argument(
"--xattn-threshold",
type=float,
default=0.9,
help="XAttention block selection threshold (0-1, higher=more blocks)"
)
parser.add_argument(
"--xattn-no-bsa",
action="store_true",
help="Disable Block Sparse Attention (use FlashAttention fallback)"
)
parser.add_argument(
"--gpu-utilization",
type=float,
@@ -291,11 +322,14 @@ if __name__ == "__main__":
enable_cpu_offload=args.enable_offload,
enable_quest=args.enable_quest,
enable_minference=args.enable_minference,
enable_xattn=args.enable_xattn,
sparse_topk=args.sparse_topk,
sparse_threshold=args.sparse_threshold,
minference_budget=minference_budget,
minference_vertical=args.minference_vertical,
minference_slash=args.minference_slash,
xattn_threshold=args.xattn_threshold,
xattn_use_bsa=not args.xattn_no_bsa,
gpu_utilization=args.gpu_utilization,
enforce_eager=enforce_eager,
verbose=True,