✨ feat: add configurable stride and chunk_size for XAttention BSA
- Add sparse_chunk_size config option (default: 16384) - Pass stride, chunk_size, use_triton through factory function - Add --sparse-stride CLI option to test_ruler.py Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -51,6 +51,7 @@ class Config:
|
|||||||
sparse_threshold: float = 0.95 # Cumulative attention threshold (tau in XAttention)
|
sparse_threshold: float = 0.95 # Cumulative attention threshold (tau in XAttention)
|
||||||
sparse_use_triton: bool = True # Use Triton kernels for estimation
|
sparse_use_triton: bool = True # Use Triton kernels for estimation
|
||||||
sparse_stride: int = 8 # Stride for Q/K downsampling
|
sparse_stride: int = 8 # Stride for Q/K downsampling
|
||||||
|
sparse_chunk_size: int = 16384 # Triton kernel chunk size for estimation
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert os.path.isdir(self.model)
|
assert os.path.isdir(self.model)
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
|||||||
'threshold': getattr(config, 'sparse_threshold', 0.9),
|
'threshold': getattr(config, 'sparse_threshold', 0.9),
|
||||||
'use_triton': getattr(config, 'sparse_use_triton', True),
|
'use_triton': getattr(config, 'sparse_use_triton', True),
|
||||||
'stride': getattr(config, 'sparse_stride', 8),
|
'stride': getattr(config, 'sparse_stride', 8),
|
||||||
|
'chunk_size': getattr(config, 'sparse_chunk_size', 16384),
|
||||||
}
|
}
|
||||||
|
|
||||||
sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)
|
sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)
|
||||||
|
|||||||
@@ -61,6 +61,9 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
|
|||||||
block_size=kwargs.get("block_size", 128),
|
block_size=kwargs.get("block_size", 128),
|
||||||
samples_per_chunk=kwargs.get("samples_per_chunk", 128),
|
samples_per_chunk=kwargs.get("samples_per_chunk", 128),
|
||||||
threshold=kwargs.get("threshold", 0.9),
|
threshold=kwargs.get("threshold", 0.9),
|
||||||
|
stride=kwargs.get("stride", 8),
|
||||||
|
chunk_size=kwargs.get("chunk_size", 16384),
|
||||||
|
use_triton=kwargs.get("use_triton", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -274,6 +274,7 @@ def run_ruler_benchmark(
|
|||||||
sparse_threshold: float = 0.9,
|
sparse_threshold: float = 0.9,
|
||||||
sparse_samples: int = 128,
|
sparse_samples: int = 128,
|
||||||
sparse_block_size: int = 128,
|
sparse_block_size: int = 128,
|
||||||
|
sparse_stride: int = 8,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Run RULER benchmark on multiple tasks.
|
Run RULER benchmark on multiple tasks.
|
||||||
@@ -339,6 +340,7 @@ def run_ruler_benchmark(
|
|||||||
if sparse_policy_type == SparsePolicyType.XATTN_BSA:
|
if sparse_policy_type == SparsePolicyType.XATTN_BSA:
|
||||||
llm_kwargs["sparse_threshold"] = sparse_threshold
|
llm_kwargs["sparse_threshold"] = sparse_threshold
|
||||||
llm_kwargs["sparse_samples_per_chunk"] = sparse_samples
|
llm_kwargs["sparse_samples_per_chunk"] = sparse_samples
|
||||||
|
llm_kwargs["sparse_stride"] = sparse_stride
|
||||||
|
|
||||||
# Factory function for fresh_llm mode
|
# Factory function for fresh_llm mode
|
||||||
def create_llm():
|
def create_llm():
|
||||||
@@ -485,6 +487,8 @@ if __name__ == "__main__":
|
|||||||
help="XAttention BSA: samples per chunk for estimation")
|
help="XAttention BSA: samples per chunk for estimation")
|
||||||
parser.add_argument("--sparse-block-size", type=int, default=128,
|
parser.add_argument("--sparse-block-size", type=int, default=128,
|
||||||
help="XAttention BSA: block size for estimation")
|
help="XAttention BSA: block size for estimation")
|
||||||
|
parser.add_argument("--sparse-stride", type=int, default=8,
|
||||||
|
help="XAttention BSA: stride for Q/K downsampling")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -521,6 +525,7 @@ if __name__ == "__main__":
|
|||||||
sparse_threshold=args.sparse_threshold,
|
sparse_threshold=args.sparse_threshold,
|
||||||
sparse_samples=args.sparse_samples,
|
sparse_samples=args.sparse_samples,
|
||||||
sparse_block_size=args.sparse_block_size,
|
sparse_block_size=args.sparse_block_size,
|
||||||
|
sparse_stride=args.sparse_stride,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Exit code (skip for json output mode)
|
# Exit code (skip for json output mode)
|
||||||
|
|||||||
Reference in New Issue
Block a user