|
a |
|
b/src/utils.py |
|
|
1 |
import csv |
|
|
2 |
import json |
|
|
3 |
import logging |
|
|
4 |
import os |
|
|
5 |
import pickle |
|
|
6 |
import random |
|
|
7 |
|
|
|
8 |
import numpy as np |
|
|
9 |
import torch |
|
|
10 |
|
|
|
11 |
""" TODO: update if necessary """ |
|
|
12 |
# path to the root project directory on the local machine |
|
|
13 |
project_path = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)) |
|
|
14 |
# project name, i.e., llemr |
|
|
15 |
project_name = os.path.basename(project_path) |
|
|
16 |
# path to the root project directory on the remote machine (default to local path) |
|
|
17 |
remote_project_path = project_path |
|
|
18 |
# path to the raw data directory on the remote machine |
|
|
19 |
raw_data_path = os.path.join(remote_project_path, "raw_data") |
|
|
20 |
# path to the processed data directory on the remote machine |
|
|
21 |
processed_data_path = os.path.join(remote_project_path, "processed_data") |
|
|
22 |
|
|
|
23 |
|
|
|
24 |
def set_seed(seed):
    """Seed every random number generator used by the project.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and, when available,
    all CUDA devices), pins cuDNN to deterministic kernels, and exports
    ``PYTHONHASHSEED`` for subprocesses.

    Args:
        seed: integer seed applied to every RNG.
    """
    # NOTE: PYTHONHASHSEED only affects interpreters started after this point;
    # the current process's hash randomization is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for reproducible kernel selection.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
|
34 |
|
|
|
35 |
|
|
|
36 |
def read_csv(filename):
    """Read a CSV file into a header list and a list of row dicts.

    Args:
        filename: path to a comma-delimited CSV file with a header row.

    Returns:
        (header, data): ``header`` is the list of column names and ``data``
        is a list of ``dict`` rows keyed by those names (empty list when the
        file has no data rows).
    """
    logging.info(f"Reading from {filename}")
    with open(filename, "r") as file:
        csv_reader = csv.DictReader(file, delimiter=",")
        data = list(csv_reader)
        # fieldnames comes from the header row itself, so this works even
        # when there are zero data rows (the original indexed data[0] and
        # crashed on a header-only file).
        header = list(csv_reader.fieldnames or [])
    return header, data
|
|
45 |
|
|
|
46 |
|
|
|
47 |
def read_txt(filename):
    """Read a text file and return its lines without trailing newlines.

    Args:
        filename: path to the text file.

    Returns:
        List of line strings (empty list for an empty file).
    """
    logging.info(f"Reading from {filename}")
    with open(filename, "r") as file:
        # splitlines() already yields the list; the original copied it
        # element-by-element into a second list for no benefit.
        data = file.read().splitlines()
    return data
|
|
55 |
|
|
|
56 |
|
|
|
57 |
def write_txt(filename, data):
    """Write an iterable of strings to a file, one per line.

    Args:
        filename: destination path (overwritten if it exists).
        data: iterable of line strings; a newline is appended to each.
    """
    logging.info(f"Writing to {filename}")
    with open(filename, "w") as file:
        # writelines batches the small writes instead of one write per line.
        file.writelines(line + "\n" for line in data)
|
|
63 |
|
|
|
64 |
|
|
|
65 |
def read_json(filename):
    """Load and return the JSON document stored at *filename*.

    Args:
        filename: path to a JSON file.

    Returns:
        The deserialized object (dict, list, str, number, bool, or None).
    """
    logging.info(f"Reading from {filename}")
    with open(filename, "r") as file:
        data = json.load(file)
    return data
|
|
70 |
|
|
|
71 |
|
|
|
72 |
def write_json(filename, data):
    """Serialize *data* as JSON to *filename* (overwriting it).

    Args:
        filename: destination path.
        data: any JSON-serializable object.
    """
    logging.info(f"Writing to {filename}")
    with open(filename, "w") as file:
        json.dump(data, file)
|
|
77 |
|
|
|
78 |
|
|
|
79 |
def create_directory(directory):
    """Create *directory* (including parents) if it does not already exist.

    Args:
        directory: path of the directory to create.
    """
    if not os.path.exists(directory):
        logging.info(f"Creating directory {directory}")
        # exist_ok=True closes the race where another process creates the
        # directory between the exists() check and makedirs().
        os.makedirs(directory, exist_ok=True)
|
|
83 |
|
|
|
84 |
|
|
|
85 |
def load_pickle(filename):
    """Load and return the pickled object stored at *filename*.

    SECURITY NOTE: pickle.load executes arbitrary code from the file —
    only use this on trusted, project-generated files.

    Args:
        filename: path to a pickle file.

    Returns:
        The unpickled object.
    """
    with open(filename, "rb") as f:
        data = pickle.load(f)
    # Log after the load succeeds (the original announced success up front).
    logging.info(f"Data loaded from {filename}")
    return data
|
|
89 |
|
|
|
90 |
|
|
|
91 |
def dump_pickle(data, filename):
    """Pickle *data* to *filename* (overwriting it).

    Args:
        data: any picklable object.
        filename: destination path.
    """
    with open(filename, "wb") as f:
        pickle.dump(data, f)
    # Log after the dump succeeds (the original announced success up front).
    logging.info(f"Data saved to {filename}")
|
|
95 |
|
|
|
96 |
|
|
|
97 |
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
99 |
|
|
|
100 |
|
|
|
101 |
if __name__ == "__main__": |
|
|
102 |
set_seed(0) |
|
|
103 |
print(project_path) |
|
|
104 |
print(project_name) |
|
|
105 |
print(remote_root) |
|
|
106 |
print(remote_project_path) |
|
|
107 |
print(raw_data_path) |
|
|
108 |
print(processed_data_path) |