support CUDA_VISIBLE_DEVICES
@@ -35,7 +35,7 @@ class ModelRunner:
         total, used, _ = get_gpu_memory()
         free = total * gpu_memory_utilization - used
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * hf_config.num_key_value_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
-        config.num_kvcache_blocks = int(free * 1e6) // block_bytes
+        config.num_kvcache_blocks = int(free) // block_bytes
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, hf_config.num_key_value_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
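The dropped * 1e6 is tied to the second hunk below: nvidia-smi queried with --format=csv,nounits reports memory in MiB, which the old code scaled (approximately) to bytes, whereas pynvml's nvmlDeviceGetMemoryInfo already returns bytes, so free can be divided by the per-block size directly. A minimal sketch of the block-count arithmetic follows; every model dimension in it is a hypothetical example, not a value taken from this repository.

# Sketch of the KV-cache block sizing; all numbers are hypothetical examples.
num_hidden_layers = 28        # assumed model depth
block_size = 256              # assumed tokens per KV-cache block
num_key_value_heads = 8       # assumed number of KV heads
head_dim = 128                # assumed head dimension
itemsize = 2                  # bytes per element for float16/bfloat16

# 2x covers the separate key and value tensors in each layer.
block_bytes = 2 * num_hidden_layers * block_size * num_key_value_heads * head_dim * itemsize

free = 20 * 1024**3           # pretend 20 GiB of headroom, already expressed in bytes
num_kvcache_blocks = int(free) // block_bytes   # no MiB-to-bytes conversion needed anymore
print(block_bytes, num_kvcache_blocks)          # 28 MiB per block, 731 blocks in this example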
@@ -1,14 +1,18 @@
 import os
-import subprocess
 import torch
+from pynvml import *


-def get_gpu_memory(device_id: int = 0):
+def get_gpu_memory():
     torch.cuda.synchronize()
-    result = subprocess.check_output(
-        ['nvidia-smi', '-i', str(device_id), '--query-gpu=memory.total,memory.used,memory.free', '--format=csv,nounits,noheader'],
-        encoding='utf-8'
-    )
-    total_memory, used_memory, free_memory = [int(x) for x in result.strip().split(', ')]
+    nvmlInit()
+    visible_device = list(map(int, os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(',')))
+    cuda_device_idx = torch.cuda.current_device()
+    cuda_device_idx = visible_device[cuda_device_idx]
+    handle = nvmlDeviceGetHandleByIndex(cuda_device_idx)
+    mem_info = nvmlDeviceGetMemoryInfo(handle)
+    total_memory = mem_info.total
+    used_memory = mem_info.used
+    free_memory = mem_info.free
+    nvmlShutdown()
     return total_memory, used_memory, free_memory
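The rewritten get_gpu_memory has to translate between two numbering schemes: torch.cuda.current_device() counts only the GPUs left visible by CUDA_VISIBLE_DEVICES, while NVML enumerates every GPU in the machine. The mask read from the environment therefore acts as a lookup table from the logical index to the physical one, and the default "0,1,2,3,4,5,6,7" keeps that lookup an identity mapping when the variable is unset. A standalone sketch of the same mapping, assuming a hypothetical launch with CUDA_VISIBLE_DEVICES=2,3:

# Hypothetical example: the process was started with CUDA_VISIBLE_DEVICES=2,3.
import os
import torch
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlShutdown

visible = list(map(int, os.getenv("CUDA_VISIBLE_DEVICES", "0").split(',')))
logical = torch.cuda.current_device()   # indexes only the visible GPUs: 0 or 1 here
physical = visible[logical]             # machine-wide index that NVML expects: 2 or 3 here

nvmlInit()
mem = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(physical))
print(physical, mem.total, mem.used, mem.free)   # pynvml reports all three values in bytes
nvmlShutdown()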