diff --git a/tests/test_ruler.py b/tests/test_ruler.py index 1f20bf6..1c977c0 100644 --- a/tests/test_ruler.py +++ b/tests/test_ruler.py @@ -335,6 +335,7 @@ def run_ruler_benchmark( sparse_samples: int = 128, sparse_block_size: int = 128, sparse_stride: int = 8, + dtype: Optional[str] = None, ) -> Dict: """ Run RULER benchmark on multiple tasks. @@ -389,6 +390,8 @@ def run_ruler_benchmark( "kvcache_block_size": block_size, "enable_cpu_offload": enable_cpu_offload, } + if dtype: + llm_kwargs["dtype"] = dtype if enable_cpu_offload: llm_kwargs["num_gpu_blocks"] = num_gpu_blocks llm_kwargs["num_kv_buffers"] = num_kv_buffers @@ -550,6 +553,8 @@ if __name__ == "__main__": help="XAttention BSA: block size for estimation") parser.add_argument("--sparse-stride", type=int, default=8, help="XAttention BSA: stride for Q/K downsampling") + parser.add_argument("--dtype", type=str, default=None, + help="Model dtype (bfloat16, float16). Required for models with float32 default.") args = parser.parse_args() @@ -587,6 +592,7 @@ if __name__ == "__main__": sparse_samples=args.sparse_samples, sparse_block_size=args.sparse_block_size, sparse_stride=args.sparse_stride, + dtype=args.dtype, ) # Exit code (skip for json output mode)