# sybil/utils/device_utils.py
1
import itertools
2
import os
3
from typing import Union
4
5
import torch
6
7
8
def get_default_device():
    """Pick the best available torch device.

    Preference order: the CUDA GPU with the most free memory, then MPS
    (only when PYTORCH_ENABLE_MPS_FALLBACK=1, because some operations are
    not yet implemented on MPS), then CPU.
    """
    if torch.cuda.is_available():
        return get_most_free_gpu()
    # Only opt in to MPS when the user has enabled the CPU fallback,
    # since not all operations are implemented on MPS yet.
    mps_enabled = (
        torch.backends.mps.is_available()
        and os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "0") == "1"
    )
    return torch.device('mps') if mps_enabled else torch.device('cpu')
20
21
22
def get_available_devices(num_devices=None, max_devices=None):
    """Return a list of devices to distribute work over.

    On CUDA systems: one entry per visible GPU, capped at ``max_devices``.
    If ``num_devices`` is given, the GPU list is cycled to exactly that
    length (so devices repeat when more workers than GPUs are requested).
    On non-CUDA systems: the default device repeated ``num_devices`` times
    (defaulting to the CPU count, capped at ``max_devices``).

    Args:
        num_devices: Exact number of device slots wanted, or None for
            "one per available device" (CUDA) / CPU count (otherwise).
        max_devices: Upper bound on the number of distinct devices used.
    """
    device = get_default_device()
    if device.type == "cuda":
        num_gpus = torch.cuda.device_count()
        if max_devices is not None:
            num_gpus = min(num_gpus, max_devices)
        gpu_list = [get_device(i) for i in range(num_gpus)]
        # Bug fix: test "is not None" so num_devices=0 yields an empty
        # list instead of being treated as unspecified.
        if num_devices is not None:
            # Bug fix: cycling an empty list would raise StopIteration
            # (e.g. max_devices=0); return the empty list instead.
            if not gpu_list:
                return []
            cycle_gpu_list = itertools.cycle(gpu_list)
            gpu_list = [next(cycle_gpu_list) for _ in range(num_devices)]
        return gpu_list
    else:
        # Bug fix: "is None" (not truthiness) so an explicit 0 is honored.
        if num_devices is None:
            num_devices = torch.multiprocessing.cpu_count()
        if max_devices is not None:
            num_devices = min(num_devices, max_devices)
        return [device] * num_devices
37
38
39
def get_device(gpu_id: Union[int, None]) -> Union[torch.device, None]:
    """Return the CUDA device for ``gpu_id``, or None if it is unusable.

    Fix: the annotation previously claimed ``gpu_id: int`` even though the
    body explicitly handles ``None``; the signature now reflects that.

    Args:
        gpu_id: Zero-based CUDA device index, or None.

    Returns:
        ``torch.device('cuda:<gpu_id>')`` when CUDA is available and the
        index is in range; otherwise None.
    """
    if gpu_id is not None and torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
        return torch.device(f'cuda:{gpu_id}')
    else:
        return None
44
45
46
def get_device_mem_info(device: Union[int, torch.device]):
    """Query CUDA memory for ``device``.

    Returns a ``(free_bytes, total_bytes)`` tuple, or None when CUDA is
    not available.
    """
    if not torch.cuda.is_available():
        return None
    # torch.cuda.mem_get_info already returns (free, total).
    return torch.cuda.mem_get_info(device=device)
52
53
54
def get_most_free_gpu():
    """Return the CUDA device that currently has the most free memory.

    Returns None when CUDA is unavailable or no GPUs are present.
    """
    if not torch.cuda.is_available():
        return None

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        return None

    # Map each GPU index to its free memory, then take the argmax.
    # max() keeps the first index on ties, matching a strict ">" scan.
    free_by_idx = {idx: get_device_mem_info(idx)[0] for idx in range(num_gpus)}
    best_idx = max(free_by_idx, key=free_by_idx.get)
    return torch.device(f'cuda:{best_idx}')