Diff of /src/utils.py [000000] .. [780764]

Switch to unified view

a b/src/utils.py
1
import csv
2
import json
3
import logging
4
import os
5
import pickle
6
import random
7
8
import numpy as np
9
import torch
10
11
""" TODO: update if necessary """
12
# path to the root project directory on the local machine
13
project_path = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir))
14
# project name, i.e., llemr
15
project_name = os.path.basename(project_path)
16
# path to the root project directory on the remote machine (default to local path)
17
remote_project_path = project_path
18
# path to the raw data directory on the remote machine
19
raw_data_path = os.path.join(remote_project_path, "raw_data")
20
# path to the processed data directory on the remote machine
21
processed_data_path = os.path.join(remote_project_path, "processed_data")
22
23
24
def set_seed(seed):
    """Seed every RNG used in the project (hash, stdlib, NumPy, PyTorch).

    Args:
        seed: integer seed applied to all random number generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    # CUDA-specific determinism settings only make sense when a GPU is present.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
34
35
36
def read_csv(filename):
    """Read a CSV file with a header row.

    Args:
        filename: path to the CSV file.

    Returns:
        (header, data): the list of column names and a list of row dicts
        keyed by column name (all values are strings, as csv provides them).
    """
    # Lazy %-style args avoid formatting when INFO is disabled; the original
    # f-string had no placeholder and never logged the filename.
    logging.info("Reading from %s", filename)
    # newline="" is the documented way to open files for the csv module.
    with open(filename, "r", newline="") as file:
        csv_reader = csv.DictReader(file, delimiter=",")
        data = list(csv_reader)
        # fieldnames is populated from the header row, so this also works
        # for a file that has a header but no data rows (the original
        # data[0] raised IndexError in that case).
        header = list(csv_reader.fieldnames or [])
    return header, data
45
46
47
def read_txt(filename):
    """Read a text file into a list of lines (newlines stripped).

    Args:
        filename: path to the text file.

    Returns:
        List of lines without trailing line terminators.
    """
    # Fixed: the original f-string had no placeholder, so the filename
    # was never logged.
    logging.info("Reading from %s", filename)
    with open(filename, "r") as file:
        # splitlines() already yields a list; the manual append loop
        # was redundant.
        return file.read().splitlines()
55
56
57
def write_txt(filename, data):
    """Write an iterable of strings to a text file, one per line.

    Args:
        filename: destination path (overwritten if it exists).
        data: iterable of strings; a newline is appended to each.
    """
    # Fixed: the original f-string had no placeholder, so the filename
    # was never logged.
    logging.info("Writing to %s", filename)
    with open(filename, "w") as file:
        # Single writelines call instead of many tiny write calls.
        file.writelines(line + "\n" for line in data)
    return
63
64
65
def read_json(filename):
    """Load a JSON file and return the parsed object.

    Args:
        filename: path to the JSON file.

    Returns:
        The deserialized Python object.
    """
    # Fixed: the original f-string had no placeholder, so the filename
    # was never logged.
    logging.info("Reading from %s", filename)
    with open(filename, "r") as file:
        data = json.load(file)
    return data
70
71
72
def write_json(filename, data):
    """Serialize an object to a JSON file.

    Args:
        filename: destination path (overwritten if it exists).
        data: JSON-serializable object.
    """
    # Fixed: the original f-string had no placeholder, so the filename
    # was never logged.
    logging.info("Writing to %s", filename)
    with open(filename, "w") as file:
        json.dump(data, file)
    return
77
78
79
def create_directory(directory):
    """Create a directory (including parents) if it does not already exist.

    Args:
        directory: path of the directory to create.
    """
    if not os.path.exists(directory):
        logging.info(f"Creating directory {directory}")
        # exist_ok=True closes the race between the exists() check and
        # makedirs(): another process creating the directory in between
        # no longer raises FileExistsError.
        os.makedirs(directory, exist_ok=True)
83
84
85
def load_pickle(filename):
    """Load and return a pickled object from a file.

    Args:
        filename: path to the pickle file.

    Returns:
        The unpickled object.

    NOTE(review): pickle.load on untrusted files can execute arbitrary
    code — only use with files this project produced.
    """
    # Fixed: the original f-string had no placeholder, so the filename
    # was never logged.
    logging.info("Data loaded from %s", filename)
    with open(filename, "rb") as f:
        return pickle.load(f)
89
90
91
def dump_pickle(data, filename):
    """Pickle an object to a file.

    Args:
        data: object to serialize.
        filename: destination path (overwritten if it exists).
    """
    # Fixed: the original f-string had no placeholder, so the filename
    # was never logged.
    logging.info("Data saved to %s", filename)
    with open(filename, "wb") as f:
        pickle.dump(data, f)
95
96
97
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
99
100
101
if __name__ == "__main__":
102
    set_seed(0)
103
    print(project_path)
104
    print(project_name)
105
    print(remote_root)
106
    print(remote_project_path)
107
    print(raw_data_path)
108
    print(processed_data_path)