feat: add XAttention sparse policy integration
Integrate COMPASS XAttention algorithm into nano-vllm's CPU offload execution path. Uses FlashAttention with native GQA support for offload mode. New files: - nanovllm/kvcache/sparse/utils.py: find_blocks_chunked() utility - nanovllm/kvcache/sparse/kernels.py: Triton kernels for XAttention - nanovllm/kvcache/sparse/xattn.py: XAttentionPolicy implementation Modified: - nanovllm/config.py: Add XATTN configuration parameters - nanovllm/engine/model_runner.py: Support XATTN policy - nanovllm/kvcache/sparse/__init__.py: Register XAttentionPolicy - tests/test_ruler.py: Add --sparse-policy parameter Test results (32k ruler): - NIAH tasks: 12/12 (100%) - QA/Recall tasks: 11/15 (73%) - Overall: 23/27 (85%) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -226,6 +226,7 @@ def run_ruler_benchmark(
|
||||
gpu_utilization: float = 0.9,
|
||||
enforce_eager: bool = True,
|
||||
verbose: bool = True,
|
||||
sparse_policy: Optional[str] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Run RULER benchmark on multiple tasks.
|
||||
@@ -236,6 +237,7 @@ def run_ruler_benchmark(
|
||||
datasets: List of task names to test (None = all)
|
||||
num_samples: Number of samples per task (None = all)
|
||||
...other LLM config params...
|
||||
sparse_policy: Sparse attention policy (FULL, QUEST, MINFERENCE, XATTN)
|
||||
|
||||
Returns:
|
||||
Dict with overall results and per-task results
|
||||
@@ -272,6 +274,10 @@ def run_ruler_benchmark(
|
||||
if enable_cpu_offload:
|
||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
||||
llm_kwargs["num_kv_buffers"] = num_kv_buffers
|
||||
if sparse_policy:
|
||||
from nanovllm.config import SparsePolicyType
|
||||
sparse_policy_type = SparsePolicyType[sparse_policy]
|
||||
llm_kwargs["sparse_policy"] = sparse_policy_type
|
||||
|
||||
llm = LLM(model_path, **llm_kwargs)
|
||||
|
||||
@@ -366,6 +372,8 @@ if __name__ == "__main__":
|
||||
help="Enable CUDA graph")
|
||||
parser.add_argument("--quiet", "-q", action="store_true",
|
||||
help="Quiet mode")
|
||||
parser.add_argument("--sparse-policy", type=str, default="",
|
||||
help="Sparse attention policy (FULL, QUEST, MINFERENCE, XATTN)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -373,6 +381,9 @@ if __name__ == "__main__":
|
||||
datasets = args.datasets.split(",") if args.datasets else None
|
||||
num_samples = args.num_samples if args.num_samples > 0 else None
|
||||
|
||||
# Parse sparse policy
|
||||
sparse_policy_str = args.sparse_policy.upper() if args.sparse_policy else None
|
||||
|
||||
results = run_ruler_benchmark(
|
||||
model_path=os.path.expanduser(args.model),
|
||||
data_dir=Path(args.data_dir),
|
||||
@@ -387,6 +398,7 @@ if __name__ == "__main__":
|
||||
gpu_utilization=args.gpu_utilization,
|
||||
enforce_eager=not args.use_cuda_graph,
|
||||
verbose=not args.quiet,
|
||||
sparse_policy=sparse_policy_str,
|
||||
)
|
||||
|
||||
# Exit code
|
||||
|
||||
Reference in New Issue
Block a user