commit bc0ad5a116 (parent 7e42fa6f63)
Author: GeeeekExplorer
Date:   2025-06-17 23:15:02 +08:00

6 changed files with 27 additions and 23 deletions


@@ -53,7 +53,7 @@ class ModelRunner:
         dist.barrier()
         if self.rank == 0:
             self.shm.unlink()
-        # dist.destroy_process_group()
+        dist.destroy_process_group()
 
     def loop(self):
         while True:
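This hunk uncomments dist.destroy_process_group() so every rank tears down the default process group explicitly on exit instead of leaving it to interpreter shutdown. A minimal sketch of that cleanup ordering, with the rank and shared-memory handle passed in explicitly (the function and parameter names here are illustrative, not the repository's API):

import torch.distributed as dist

def shutdown(rank: int, shm) -> None:
    dist.barrier()                # make sure no rank still reads the shared memory
    if rank == 0:
        shm.unlink()              # only the owning rank removes the segment
    dist.destroy_process_group()  # then every rank leaves the group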
@@ -92,7 +92,7 @@ class ModelRunner:
         hf_config = config.hf_config
         total, used, _ = get_gpu_memory()
         free = total * gpu_memory_utilization - used
-        num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
+        num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
         config.num_kvcache_blocks = int(free) // block_bytes
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
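This hunk switches from dist.get_world_size() to the runner's cached self.world_size, tying the KV-head split to the runner's own configuration rather than torch.distributed global state. The sizing arithmetic is unchanged: each cache block holds one K and one V tensor (the factor 2) for every layer. A worked sketch of that arithmetic with illustrative shapes (the concrete numbers below are assumptions, not taken from this commit):

num_hidden_layers = 28
block_size = 256          # tokens per KV-cache block
num_key_value_heads = 8
world_size = 2            # tensor-parallel degree
head_dim = 128
itemsize = 2              # bytes per element for bfloat16

num_kv_heads = num_key_value_heads // world_size  # KV heads held by this rank
# factor 2 = one K and one V tensor per layer
block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * itemsize
print(block_bytes)        # 14680064 bytes, i.e. 14 MiB per block on this rank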
@@ -120,7 +120,6 @@ class ModelRunner:
         max_seqlen_q = 0
         max_seqlen_k = 0
         slot_mapping = []
-        context_lens = None
         block_tables = None
         for seq in seqs:
             seqlen = len(seq)
@@ -142,14 +141,13 @@ class ModelRunner:
         assert len(input_ids) == len(slot_mapping)
         assert len(input_ids) == cu_seqlens_q[-1]
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]:  # prefix cache
-            context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
             block_tables = self.prepare_block_tables(seqs)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
+        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
         return input_ids, positions
 
     def prepare_decode(self, seqs: list[Sequence]):
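Together with the removed initialization in the hunk above, this drops context_lens from the prefill path entirely: the prefix-cache branch now relies on cu_seqlens_k and block_tables alone, and set_context receives None in the context_lens slot. The surrounding lines all repeat the same host-to-device idiom, shown in isolation below (the values are placeholders):

import torch

# Build the CPU tensor in page-locked (pinned) memory so the copy launched by
# .cuda(non_blocking=True) can run asynchronously and overlap with host work.
slot_mapping = [0, 1, 2, 256, 257]
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)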
@@ -205,7 +203,7 @@ class ModelRunner:
     def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
         input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
-        temperatures = self.prepare_sample(seqs)
+        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
         logits = self.run_model(input_ids, positions, is_prefill)
         token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
         reset_context()
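Since the sampler only runs on rank 0 (token_ids is already None elsewhere), the other ranks no longer build a temperatures tensor either. A hedged sketch of what temperature sampling over the gathered logits could look like (illustrative only, not the repository's Sampler implementation):

import torch

def sample(logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
    # logits: [num_seqs, vocab_size]; temperatures: [num_seqs]
    greedy = logits.argmax(dim=-1)
    probs = torch.softmax(logits / temperatures.clamp_min(1e-5).unsqueeze(-1), dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    # treat temperature 0 as greedy decoding
    return torch.where(temperatures == 0, greedy, sampled)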