✨ feat: add --dtype parameter to test_ruler.py
Support models whose default dtype is float32 (e.g., Nemotron). FlashAttention requires fp16/bf16, so the dtype must be specified explicitly. Usage: `--dtype bfloat16`

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -335,6 +335,7 @@ def run_ruler_benchmark(
     sparse_samples: int = 128,
     sparse_block_size: int = 128,
     sparse_stride: int = 8,
+    dtype: Optional[str] = None,
 ) -> Dict:
     """
     Run RULER benchmark on multiple tasks.
@@ -389,6 +390,8 @@ def run_ruler_benchmark(
         "kvcache_block_size": block_size,
         "enable_cpu_offload": enable_cpu_offload,
     }
+    if dtype:
+        llm_kwargs["dtype"] = dtype
     if enable_cpu_offload:
         llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
         llm_kwargs["num_kv_buffers"] = num_kv_buffers
@@ -550,6 +553,8 @@ if __name__ == "__main__":
                         help="XAttention BSA: block size for estimation")
     parser.add_argument("--sparse-stride", type=int, default=8,
                         help="XAttention BSA: stride for Q/K downsampling")
+    parser.add_argument("--dtype", type=str, default=None,
+                        help="Model dtype (bfloat16, float16). Required for models with float32 default.")

     args = parser.parse_args()
@@ -587,6 +592,7 @@ if __name__ == "__main__":
         sparse_samples=args.sparse_samples,
         sparse_block_size=args.sparse_block_size,
         sparse_stride=args.sparse_stride,
+        dtype=args.dtype,
     )

     # Exit code (skip for json output mode)
Reference in New Issue
Block a user