본문 바로가기

TIL

[Python] GPU 사용량 확인

사용가능한 여러개의 GPU중에서 사용량 기준으로 어떤 GPU를 모델 학습/추론에 사용할지 자동으로 결정하기 위한 코드를 작성하고 싶어 방법을 찾아보던중, 새로 알게된 정보가 있어 적어놓으려한다.

torch.cuda.memory_allocated()

서두에서 언급한 내용을 검색해보면, 보통 `torch.cuda.memory_allocated(device_idx)`를 사용하면 `device_idx`에 해당하는 gpu에서 어느 정도의 메모리를 사용하고 있는지가 나온다고 한다.

 

물론 맞는 말이지만, 조사하면서 알게된 내용을 덧붙이자면, 해당 코드는 pytorch 프레임워크 내에서만 사용중인 gpu 메모리의 양을 보여준다.

 

예를 들어서 한 컴퓨터에서 그래픽 작업을 위한 툴을 사용하면서 GPU를 사용중이라고 하면, `nvidia-smi` 커맨드를 입력했을 때에는 0이 넘는 gpu-util값이 나오겠지만, `torch.cuda.memory_allocated(device_idx)`의 값은 0으로 아무것도 없다고 나오게 되는 것이다.

 

chatgpt에게 pytorch내에서 `nvidia-smi`와 동일한 역할을 하는 코드가 있는지 물어본 결과, pytorch에서 제공하는 것은 없고 python의 외부 라이브러리인 `pynvml`을 통해서 전체 gpu의 사용량을 확인할 수 있다는 것을 알 수 있었다.

pynvml

사용방법은 다음과 같다.

import pynvml

def get_gpu_info():
    # NVML 초기화
    pynvml.nvmlInit()
    device_count = pynvml.nvmlDeviceGetCount()
    
    gpu_info = []
    
    for i in range(device_count):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        name = pynvml.nvmlDeviceGetName(handle)
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        util_info = pynvml.nvmlDeviceGetUtilizationRates(handle)
        
        gpu_info.append({
            "index": i,
            "name": name.decode("utf-8"),
            "memory_total": mem_info.total / (1024 ** 2),  # MB
            "memory_used": mem_info.used / (1024 ** 2),   # MB
            "memory_free": mem_info.free / (1024 ** 2),   # MB
            "gpu_util": util_info.gpu,                   # GPU utilization (%)
            "memory_util": util_info.memory,             # Memory utilization (%)
        })
    
    # NVML 종료
    pynvml.nvmlShutdown()
    
    return gpu_info

# GPU 정보 출력
for gpu in get_gpu_info():
    print(f"GPU {gpu['index']} ({gpu['name']}):")
    print(f"  - Memory Total: {gpu['memory_total']} MB")
    print(f"  - Memory Used: {gpu['memory_used']} MB")
    print(f"  - Memory Free: {gpu['memory_free']} MB")
    print(f"  - GPU Utilization: {gpu['gpu_util']}%")
    print(f"  - Memory Utilization: {gpu['memory_util']}%")
    print()