[WIP] Before integrating the xattn operator.

This commit is contained in:
Zijie Tian
2026-01-19 21:19:21 +08:00
parent 9e6fdc0650
commit b5da802dff
11 changed files with 949 additions and 32 deletions

View File

@@ -31,8 +31,10 @@ def run_needle_test(
max_new_tokens: int = 32,
enable_cpu_offload: bool = False,
enable_quest: bool = False,
enable_xattn_bsa: bool = False,
sparse_topk: int = 8,
sparse_threshold: int = 4,
sparse_samples: int = 128,
verbose: bool = True,
) -> bool:
"""
@@ -49,14 +51,22 @@ def run_needle_test(
max_new_tokens: Maximum tokens to generate
enable_cpu_offload: Enable CPU offload mode
enable_quest: Enable Quest sparse attention (decode-only Top-K)
enable_xattn_bsa: Enable XAttention BSA sparse attention (prefill-only)
sparse_topk: Top-K blocks for Quest
sparse_threshold: Apply sparse only when blocks > threshold
sparse_threshold: Threshold for sparse selection (Quest/XAttention BSA)
sparse_samples: Samples per chunk for XAttention BSA estimation
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL
# Determine sparse policy
if enable_xattn_bsa:
sparse_policy = SparsePolicyType.XATTN_BSA
elif enable_quest:
sparse_policy = SparsePolicyType.QUEST
else:
sparse_policy = SparsePolicyType.FULL
if verbose:
print(f"\n{'='*60}")
@@ -70,7 +80,11 @@ def run_needle_test(
print(f"Needle value: {needle_value}")
print(f"CPU offload: {enable_cpu_offload}")
if enable_cpu_offload:
print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})")
print(f"Sparse policy: {sparse_policy.name}")
if sparse_policy == SparsePolicyType.QUEST:
print(f" Quest: topk={sparse_topk}, threshold={sparse_threshold}")
elif sparse_policy == SparsePolicyType.XATTN_BSA:
print(f" XAttention BSA: threshold={sparse_threshold}, samples={sparse_samples}")
print(f"{'='*60}\n")
# 1. Initialize LLM
@@ -84,8 +98,12 @@ def run_needle_test(
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm_kwargs["sparse_policy"] = sparse_policy
llm_kwargs["sparse_topk_blocks"] = sparse_topk
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
if sparse_policy == SparsePolicyType.QUEST:
llm_kwargs["sparse_topk_blocks"] = sparse_topk
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
elif sparse_policy == SparsePolicyType.XATTN_BSA:
llm_kwargs["sparse_threshold"] = float(sparse_threshold) / 10.0 # Convert to 0.0-1.0 range
llm_kwargs["sparse_samples_per_chunk"] = sparse_samples
llm = LLM(model_path, **llm_kwargs)
@@ -186,6 +204,11 @@ if __name__ == "__main__":
action="store_true",
help="Enable Quest sparse attention (decode-only Top-K selection)"
)
parser.add_argument(
"--enable-xattn-bsa",
action="store_true",
help="Enable XAttention BSA sparse attention (prefill-only)"
)
parser.add_argument(
"--sparse-topk",
type=int,
@@ -196,7 +219,13 @@ if __name__ == "__main__":
"--sparse-threshold",
type=int,
default=4,
help="Apply sparse only when blocks > threshold"
help="Apply sparse only when blocks > threshold (Quest) or attention threshold 0-9 (XAttention BSA)"
)
parser.add_argument(
"--sparse-samples",
type=int,
default=128,
help="Samples per chunk for XAttention BSA estimation"
)
args = parser.parse_args()
@@ -211,8 +240,10 @@ if __name__ == "__main__":
max_new_tokens=args.max_new_tokens,
enable_cpu_offload=args.enable_offload,
enable_quest=args.enable_quest,
enable_xattn_bsa=args.enable_xattn_bsa,
sparse_topk=args.sparse_topk,
sparse_threshold=args.sparse_threshold,
sparse_samples=args.sparse_samples,
verbose=True,
)