Diff of /util/util.py [000000] .. [03464c]

Switch to unified view

a b/util/util.py
1
"""
2
Contain some simple helper functions
3
"""
4
import os
5
import shutil
6
import torch
7
import random
8
import numpy as np
9
10
11
def mkdir(path):
12
    """
13
    Create a empty directory in the disk if it didn't exist
14
15
    Parameters:
16
        path(str) -- a directory path we would like to create
17
    """
18
    if not os.path.exists(path):
19
        os.makedirs(path)
20
21
22
def clear_dir(path):
23
    """
24
    delete all files in a path
25
26
    Parameters:
27
        path(str) -- a directory path that we would like to delete all files in it
28
    """
29
    if os.path.exists(path):
30
        shutil.rmtree(path, ignore_errors=True)
31
        os.makedirs(path, exist_ok=True)
32
33
34
def setup_seed(seed):
35
    """
36
    setup seed to make the experiments deterministic
37
38
    Parameters:
39
        seed(int) -- the random seed
40
    """
41
    torch.manual_seed(seed)
42
    torch.cuda.manual_seed_all(seed)
43
    np.random.seed(seed)
44
    random.seed(seed)
45
    torch.backends.cudnn.deterministic = True
46
47
48
def get_time_points(T_max, time_num, extra_time_percent=0.1):
49
    """
50
    Get time points for the MTLR model
51
    """
52
    # Get time points in the time axis
53
    time_points = np.linspace(0, T_max * (1 + extra_time_percent), time_num + 1)
54
55
    return time_points