diff --git a/nanovllm/config.py b/nanovllm/config.py index 2766654..faaeb34 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -51,6 +51,7 @@ class Config: sparse_threshold: float = 0.95 # Cumulative attention threshold (tau in XAttention) sparse_use_triton: bool = True # Use Triton kernels for estimation sparse_stride: int = 8 # Stride for Q/K downsampling + sparse_chunk_size: int = 16384 # Triton kernel chunk size for estimation def __post_init__(self): assert os.path.isdir(self.model) diff --git a/nanovllm/kvcache/__init__.py b/nanovllm/kvcache/__init__.py index 155697d..fe0456d 100644 --- a/nanovllm/kvcache/__init__.py +++ b/nanovllm/kvcache/__init__.py @@ -79,6 +79,7 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager: 'threshold': getattr(config, 'sparse_threshold', 0.9), 'use_triton': getattr(config, 'sparse_use_triton', True), 'stride': getattr(config, 'sparse_stride', 8), + 'chunk_size': getattr(config, 'sparse_chunk_size', 16384), } sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs) diff --git a/nanovllm/kvcache/sparse/__init__.py b/nanovllm/kvcache/sparse/__init__.py index 6c947fe..c4ec9be 100644 --- a/nanovllm/kvcache/sparse/__init__.py +++ b/nanovllm/kvcache/sparse/__init__.py @@ -61,6 +61,9 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic block_size=kwargs.get("block_size", 128), samples_per_chunk=kwargs.get("samples_per_chunk", 128), threshold=kwargs.get("threshold", 0.9), + stride=kwargs.get("stride", 8), + chunk_size=kwargs.get("chunk_size", 16384), + use_triton=kwargs.get("use_triton", True), ) else: diff --git a/tests/test_ruler.py b/tests/test_ruler.py index c75532d..5db2638 100644 --- a/tests/test_ruler.py +++ b/tests/test_ruler.py @@ -274,6 +274,7 @@ def run_ruler_benchmark( sparse_threshold: float = 0.9, sparse_samples: int = 128, sparse_block_size: int = 128, + sparse_stride: int = 8, ) -> Dict: """ Run RULER benchmark on multiple tasks. @@ -339,6 +340,7 @@ def run_ruler_benchmark( if sparse_policy_type == SparsePolicyType.XATTN_BSA: llm_kwargs["sparse_threshold"] = sparse_threshold llm_kwargs["sparse_samples_per_chunk"] = sparse_samples + llm_kwargs["sparse_stride"] = sparse_stride # Factory function for fresh_llm mode def create_llm(): @@ -485,6 +487,8 @@ if __name__ == "__main__": help="XAttention BSA: samples per chunk for estimation") parser.add_argument("--sparse-block-size", type=int, default=128, help="XAttention BSA: block size for estimation") + parser.add_argument("--sparse-stride", type=int, default=8, + help="XAttention BSA: stride for Q/K downsampling") args = parser.parse_args() @@ -521,6 +525,7 @@ if __name__ == "__main__": sparse_threshold=args.sparse_threshold, sparse_samples=args.sparse_samples, sparse_block_size=args.sparse_block_size, + sparse_stride=args.sparse_stride, ) # Exit code (skip for json output mode)