[feat] Added Quest Sparsity Policy.

This commit is contained in:
Zijie Tian
2026-01-07 03:29:21 +08:00
parent c99a6f3d3f
commit 2a6e0a2c02
9 changed files with 92 additions and 92 deletions

View File

@@ -156,22 +156,19 @@ class ModelRunner:
dtype=hf_config.torch_dtype,
)
# Initialize sparse policies if manager has them (CPU offload mode)
if hasattr(self.kvcache_manager, 'prefill_policy') and hasattr(self.kvcache_manager, 'decode_policy'):
# Initialize both policies with model config
for policy in [self.kvcache_manager.prefill_policy, self.kvcache_manager.decode_policy]:
if policy is not None:
policy.initialize(
num_layers=hf_config.num_hidden_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
num_cpu_blocks=config.num_cpu_kvcache_blocks,
dtype=hf_config.torch_dtype,
device=torch.device("cuda"),
)
# Initialize sparse policy if manager has one (CPU offload mode)
if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
self.kvcache_manager.sparse_policy.initialize(
num_layers=hf_config.num_hidden_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
num_cpu_blocks=config.num_cpu_kvcache_blocks,
dtype=hf_config.torch_dtype,
device=torch.device("cuda"),
)
logger.info(
f"Sparse policies initialized: prefill={config.prefill_policy}, decode={config.decode_policy} "
f"Sparse policy initialized: {config.sparse_policy.name} "
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
)