feat: add configurable stride and chunk_size for XAttention BSA

- Add sparse_chunk_size config option (Triton kernel chunk size for estimation; default: 16384)
- Pass stride, chunk_size, and use_triton through the create_sparse_policy factory function
- Add --sparse-stride CLI option to test_ruler.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-23 10:37:04 +08:00
parent f28b500120
commit 7c41032a2e
4 changed files with 10 additions and 0 deletions

View File

@@ -51,6 +51,7 @@ class Config:
sparse_threshold: float = 0.95 # Cumulative attention threshold (tau in XAttention)
sparse_use_triton: bool = True # Use Triton kernels for estimation
sparse_stride: int = 8 # Stride for Q/K downsampling
sparse_chunk_size: int = 16384 # Triton kernel chunk size for estimation
def __post_init__(self):
assert os.path.isdir(self.model)

View File

@@ -79,6 +79,7 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
'threshold': getattr(config, 'sparse_threshold', 0.9),
'use_triton': getattr(config, 'sparse_use_triton', True),
'stride': getattr(config, 'sparse_stride', 8),
'chunk_size': getattr(config, 'sparse_chunk_size', 16384),
}
sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)

View File

@@ -61,6 +61,9 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
block_size=kwargs.get("block_size", 128),
samples_per_chunk=kwargs.get("samples_per_chunk", 128),
threshold=kwargs.get("threshold", 0.9),
stride=kwargs.get("stride", 8),
chunk_size=kwargs.get("chunk_size", 16384),
use_triton=kwargs.get("use_triton", True),
)
else: