[feat] Added Quest Sparsity Policy.
This commit is contained in:
@@ -156,22 +156,19 @@ class ModelRunner:
|
||||
dtype=hf_config.torch_dtype,
|
||||
)
|
||||
|
||||
# Initialize sparse policies if manager has them (CPU offload mode)
|
||||
if hasattr(self.kvcache_manager, 'prefill_policy') and hasattr(self.kvcache_manager, 'decode_policy'):
|
||||
# Initialize both policies with model config
|
||||
for policy in [self.kvcache_manager.prefill_policy, self.kvcache_manager.decode_policy]:
|
||||
if policy is not None:
|
||||
policy.initialize(
|
||||
num_layers=hf_config.num_hidden_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
num_cpu_blocks=config.num_cpu_kvcache_blocks,
|
||||
dtype=hf_config.torch_dtype,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
# Initialize sparse policy if manager has one (CPU offload mode)
|
||||
if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
|
||||
self.kvcache_manager.sparse_policy.initialize(
|
||||
num_layers=hf_config.num_hidden_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
num_cpu_blocks=config.num_cpu_kvcache_blocks,
|
||||
dtype=hf_config.torch_dtype,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Sparse policies initialized: prefill={config.prefill_policy}, decode={config.decode_policy} "
|
||||
f"Sparse policy initialized: {config.sparse_policy.name} "
|
||||
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user