From 45efcf0db171b9137f0bb1fe4239a8bd3ad3ae66 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 28 Jan 2026 13:56:15 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20--dtype=20parameter?= =?UTF-8?q?=20to=20test=5Fruler.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Support models with float32 default dtype (e.g., Nemotron). FlashAttention requires fp16/bf16, so dtype must be specified. Usage: --dtype bfloat16 Co-Authored-By: Claude Opus 4.5 --- tests/test_ruler.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_ruler.py b/tests/test_ruler.py index 1f20bf6..1c977c0 100644 --- a/tests/test_ruler.py +++ b/tests/test_ruler.py @@ -335,6 +335,7 @@ def run_ruler_benchmark( sparse_samples: int = 128, sparse_block_size: int = 128, sparse_stride: int = 8, + dtype: Optional[str] = None, ) -> Dict: """ Run RULER benchmark on multiple tasks. @@ -389,6 +390,8 @@ def run_ruler_benchmark( "kvcache_block_size": block_size, "enable_cpu_offload": enable_cpu_offload, } + if dtype: + llm_kwargs["dtype"] = dtype if enable_cpu_offload: llm_kwargs["num_gpu_blocks"] = num_gpu_blocks llm_kwargs["num_kv_buffers"] = num_kv_buffers @@ -550,6 +553,8 @@ if __name__ == "__main__": help="XAttention BSA: block size for estimation") parser.add_argument("--sparse-stride", type=int, default=8, help="XAttention BSA: stride for Q/K downsampling") + parser.add_argument("--dtype", type=str, default=None, + help="Model dtype (bfloat16, float16). Required for models with float32 default.") args = parser.parse_args() @@ -587,6 +592,7 @@ if __name__ == "__main__": sparse_samples=args.sparse_samples, sparse_block_size=args.sparse_block_size, sparse_stride=args.sparse_stride, + dtype=args.dtype, ) # Exit code (skip for json output mode)