[refactor] Before add sprae policy.
This commit is contained in:
@@ -382,7 +382,7 @@ class ModelRunner:
|
||||
return self.model.compute_logits(graph_vars["outputs"][:bs])
|
||||
|
||||
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
||||
# Check if Chunked Offload mode should be used (all blocks on CPU)
|
||||
#> Check if Chunked Offload mode should be used (all blocks on CPU)
|
||||
if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'):
|
||||
use_chunked_offload = self._should_use_chunked_offload(seqs, is_prefill)
|
||||
if use_chunked_offload:
|
||||
@@ -391,6 +391,7 @@ class ModelRunner:
|
||||
else:
|
||||
return self.run_chunked_offload_decode(seqs)
|
||||
|
||||
#> Following Code will not use Chunked Offload mode
|
||||
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
logits = self.run_model(input_ids, positions, is_prefill)
|
||||
|
||||
Reference in New Issue
Block a user