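Better distributed teardown and sampling: call dist.destroy_process_group() on exit, use self.world_size when sharding KV heads, stop passing context_lens in prefill, and prepare sampling temperatures only on rank 0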
This commit is contained in:
@@ -53,7 +53,7 @@ class ModelRunner:
|
||||
dist.barrier()
|
||||
if self.rank == 0:
|
||||
self.shm.unlink()
|
||||
# dist.destroy_process_group()
|
||||
dist.destroy_process_group()
|
||||
|
||||
def loop(self):
|
||||
while True:
|
||||
@@ -92,7 +92,7 @@ class ModelRunner:
|
||||
hf_config = config.hf_config
|
||||
total, used, _ = get_gpu_memory()
|
||||
free = total * gpu_memory_utilization - used
|
||||
num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
|
||||
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
||||
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
|
||||
config.num_kvcache_blocks = int(free) // block_bytes
|
||||
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
|
||||
@@ -120,7 +120,6 @@ class ModelRunner:
|
||||
max_seqlen_q = 0
|
||||
max_seqlen_k = 0
|
||||
slot_mapping = []
|
||||
context_lens = None
|
||||
block_tables = None
|
||||
for seq in seqs:
|
||||
seqlen = len(seq)
|
||||
@@ -142,14 +141,13 @@ class ModelRunner:
|
||||
assert len(input_ids) == len(slot_mapping)
|
||||
assert len(input_ids) == cu_seqlens_q[-1]
|
||||
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
|
||||
context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
block_tables = self.prepare_block_tables(seqs)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
|
||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_decode(self, seqs: list[Sequence]):
|
||||
@@ -205,7 +203,7 @@ class ModelRunner:
|
||||
|
||||
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
||||
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
|
||||
temperatures = self.prepare_sample(seqs)
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
logits = self.run_model(input_ids, positions, is_prefill)
|
||||
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
|
||||
reset_context()
|
||||
|
||||
Reference in New Issue
Block a user