✨ feat: add configurable stride and chunk_size for XAttention BSA
- Add sparse_chunk_size config option (default: 16384)
- Pass stride, chunk_size, use_triton through factory function
- Add --sparse-stride CLI option to test_ruler.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -274,6 +274,7 @@ def run_ruler_benchmark(
|
||||
sparse_threshold: float = 0.9,
|
||||
sparse_samples: int = 128,
|
||||
sparse_block_size: int = 128,
|
||||
sparse_stride: int = 8,
|
||||
) -> Dict:
|
||||
"""
|
||||
Run RULER benchmark on multiple tasks.
|
||||
@@ -339,6 +340,7 @@ def run_ruler_benchmark(
|
||||
if sparse_policy_type == SparsePolicyType.XATTN_BSA:
|
||||
llm_kwargs["sparse_threshold"] = sparse_threshold
|
||||
llm_kwargs["sparse_samples_per_chunk"] = sparse_samples
|
||||
llm_kwargs["sparse_stride"] = sparse_stride
|
||||
|
||||
# Factory function for fresh_llm mode
|
||||
def create_llm():
|
||||
@@ -485,6 +487,8 @@ if __name__ == "__main__":
|
||||
help="XAttention BSA: samples per chunk for estimation")
|
||||
parser.add_argument("--sparse-block-size", type=int, default=128,
|
||||
help="XAttention BSA: block size for estimation")
|
||||
parser.add_argument("--sparse-stride", type=int, default=8,
|
||||
help="XAttention BSA: stride for Q/K downsampling")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -521,6 +525,7 @@ if __name__ == "__main__":
|
||||
sparse_threshold=args.sparse_threshold,
|
||||
sparse_samples=args.sparse_samples,
|
||||
sparse_block_size=args.sparse_block_size,
|
||||
sparse_stride=args.sparse_stride,
|
||||
)
|
||||
|
||||
# Exit code (skip for json output mode)
|
||||
|
||||
Reference in New Issue
Block a user