This commit is contained in:
GeeeekExplorer
2025-06-21 17:04:53 +08:00
parent ad4e95fbdc
commit cde3fc22c2
9 changed files with 42 additions and 100 deletions

View File

@@ -1,4 +1,3 @@
from contextlib import contextmanager
from dataclasses import dataclass
import torch

View File

@@ -1,18 +0,0 @@
import os
import torch
from pynvml import *
def get_gpu_memory():
torch.cuda.synchronize()
nvmlInit()
visible_device = list(map(int, os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(',')))
cuda_device_idx = torch.cuda.current_device()
cuda_device_idx = visible_device[cuda_device_idx]
handle = nvmlDeviceGetHandleByIndex(cuda_device_idx)
mem_info = nvmlDeviceGetMemoryInfo(handle)
total_memory = mem_info.total
used_memory = mem_info.used
free_memory = mem_info.free
nvmlShutdown()
return total_memory, used_memory, free_memory