[feat] Added Quest Sparsity Policy.
This commit is contained in:
@@ -12,6 +12,7 @@ os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
||||
|
||||
import argparse
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from nanovllm.config import SparsePolicyType
|
||||
from utils import generate_needle_prompt, check_needle_answer
|
||||
|
||||
|
||||
@@ -29,6 +30,9 @@ def run_needle_test(
|
||||
needle_value: str = "7492",
|
||||
max_new_tokens: int = 32,
|
||||
enable_cpu_offload: bool = False,
|
||||
enable_quest: bool = False,
|
||||
sparse_topk: int = 8,
|
||||
sparse_threshold: int = 4,
|
||||
verbose: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
@@ -44,11 +48,16 @@ def run_needle_test(
|
||||
needle_value: The secret value to find
|
||||
max_new_tokens: Maximum tokens to generate
|
||||
enable_cpu_offload: Enable CPU offload mode
|
||||
enable_quest: Enable Quest sparse attention (decode-only Top-K)
|
||||
sparse_topk: Top-K blocks for Quest
|
||||
sparse_threshold: Apply sparse only when blocks > threshold
|
||||
verbose: Print detailed output
|
||||
|
||||
Returns:
|
||||
True if test passed, False otherwise
|
||||
"""
|
||||
sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL
|
||||
|
||||
if verbose:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Needle-in-Haystack Test")
|
||||
@@ -60,6 +69,8 @@ def run_needle_test(
|
||||
print(f"Needle position: {needle_position:.0%}")
|
||||
print(f"Needle value: {needle_value}")
|
||||
print(f"CPU offload: {enable_cpu_offload}")
|
||||
if enable_cpu_offload:
|
||||
print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# 1. Initialize LLM
|
||||
@@ -72,6 +83,9 @@ def run_needle_test(
|
||||
}
|
||||
if enable_cpu_offload:
|
||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
||||
llm_kwargs["sparse_policy"] = sparse_policy
|
||||
llm_kwargs["sparse_topk_blocks"] = sparse_topk
|
||||
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
|
||||
|
||||
llm = LLM(model_path, **llm_kwargs)
|
||||
|
||||
@@ -167,6 +181,23 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Enable CPU offload (has known bug for long sequences)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-quest",
|
||||
action="store_true",
|
||||
help="Enable Quest sparse attention (decode-only Top-K selection)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sparse-topk",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Top-K blocks for Quest sparse attention"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sparse-threshold",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Apply sparse only when blocks > threshold"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
passed = run_needle_test(
|
||||
@@ -179,6 +210,9 @@ if __name__ == "__main__":
|
||||
needle_value=args.needle_value,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
enable_cpu_offload=args.enable_offload,
|
||||
enable_quest=args.enable_quest,
|
||||
sparse_topk=args.sparse_topk,
|
||||
sparse_threshold=args.sparse_threshold,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user