[feat] Added Quest Sparsity Policy.

This commit is contained in:
Zijie Tian
2026-01-07 03:29:21 +08:00
parent c99a6f3d3f
commit 2a6e0a2c02
9 changed files with 92 additions and 92 deletions

View File

@@ -12,6 +12,7 @@ os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
from nanovllm.config import SparsePolicyType
from utils import generate_needle_prompt, check_needle_answer
@@ -29,6 +30,9 @@ def run_needle_test(
needle_value: str = "7492",
max_new_tokens: int = 32,
enable_cpu_offload: bool = False,
enable_quest: bool = False,
sparse_topk: int = 8,
sparse_threshold: int = 4,
verbose: bool = True,
) -> bool:
"""
@@ -44,11 +48,16 @@ def run_needle_test(
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
enable_cpu_offload: Enable CPU offload mode
enable_quest: Enable Quest sparse attention (decode-only Top-K)
sparse_topk: Top-K blocks for Quest
sparse_threshold: Apply sparse only when blocks > threshold
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL
if verbose:
print(f"\n{'='*60}")
print(f"Needle-in-Haystack Test")
@@ -60,6 +69,8 @@ def run_needle_test(
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"CPU offload: {enable_cpu_offload}")
if enable_cpu_offload:
print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})")
print(f"{'='*60}\n")
# 1. Initialize LLM
@@ -72,6 +83,9 @@ def run_needle_test(
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm_kwargs["sparse_policy"] = sparse_policy
llm_kwargs["sparse_topk_blocks"] = sparse_topk
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
llm = LLM(model_path, **llm_kwargs)
@@ -167,6 +181,23 @@ if __name__ == "__main__":
action="store_true",
help="Enable CPU offload (has known bug for long sequences)"
)
parser.add_argument(
"--enable-quest",
action="store_true",
help="Enable Quest sparse attention (decode-only Top-K selection)"
)
parser.add_argument(
"--sparse-topk",
type=int,
default=8,
help="Top-K blocks for Quest sparse attention"
)
parser.add_argument(
"--sparse-threshold",
type=int,
default=4,
help="Apply sparse only when blocks > threshold"
)
args = parser.parse_args()
passed = run_needle_test(
@@ -179,6 +210,9 @@ if __name__ == "__main__":
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,
enable_cpu_offload=args.enable_offload,
enable_quest=args.enable_quest,
sparse_topk=args.sparse_topk,
sparse_threshold=args.sparse_threshold,
verbose=True,
)