Diff of /util/util.py [000000] .. [03464c]

Switch to side-by-side view

--- a
+++ b/util/util.py
@@ -0,0 +1,55 @@
+"""
+Contain some simple helper functions
+"""
+import os
+import shutil
+import torch
+import random
+import numpy as np
+
+
+def mkdir(path):
+    """
+    Create a empty directory in the disk if it didn't exist
+
+    Parameters:
+        path(str) -- a directory path we would like to create
+    """
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def clear_dir(path):
+    """
+    delete all files in a path
+
+    Parameters:
+        path(str) -- a directory path that we would like to delete all files in it
+    """
+    if os.path.exists(path):
+        shutil.rmtree(path, ignore_errors=True)
+        os.makedirs(path, exist_ok=True)
+
+
+def setup_seed(seed):
+    """
+    setup seed to make the experiments deterministic
+
+    Parameters:
+        seed(int) -- the random seed
+    """
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def get_time_points(T_max, time_num, extra_time_percent=0.1):
+    """
+    Get time points for the MTLR model
+    """
+    # Get time points in the time axis
+    time_points = np.linspace(0, T_max * (1 + extra_time_percent), time_num + 1)
+
+    return time_points