src/utils.py

import csv
import json
import logging
import os
import pickle
import random
import numpy as np
import torch
""" TODO: update if necessary """
# path to the root project directory on the local machine
project_path = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir))
# project name, i.e., llemr
project_name = os.path.basename(project_path)
# path to the root project directory on the remote machine (default to local path)
remote_project_path = project_path
# path to the raw data directory on the remote machine
raw_data_path = os.path.join(remote_project_path, "raw_data")
# path to the processed data directory on the remote machine
processed_data_path = os.path.join(remote_project_path, "processed_data")
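# e.g., with the repository checked out at the hypothetical location
# /home/user/llemr, raw_data_path resolves to /home/user/llemr/raw_data
# and processed_data_path to /home/user/llemr/processed_data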


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)


def read_csv(filename):
    logging.info(f"Reading from {filename}")
    data = []
    with open(filename, "r") as file:
        csv_reader = csv.DictReader(file, delimiter=",")
        # fieldnames yields the header even when the file has no data rows,
        # unlike indexing into data, which would raise on an empty file
        header = list(csv_reader.fieldnames or [])
        for row in csv_reader:
            data.append(row)
    return header, data


def read_txt(filename):
    logging.info(f"Reading from {filename}")
    with open(filename, "r") as file:
        data = file.read().splitlines()
    return data


def write_txt(filename, data):
    logging.info(f"Writing to {filename}")
    with open(filename, "w") as file:
        for line in data:
            file.write(line + "\n")


def read_json(filename):
    logging.info(f"Reading from {filename}")
    with open(filename, "r") as file:
        data = json.load(file)
    return data


def write_json(filename, data):
    logging.info(f"Writing to {filename}")
    with open(filename, "w") as file:
        json.dump(data, file)


def create_directory(directory):
    if not os.path.exists(directory):
        logging.info(f"Creating directory {directory}")
        os.makedirs(directory)


def load_pickle(filename):
    logging.info(f"Loading data from {filename}")
    with open(filename, "rb") as f:
        return pickle.load(f)


def dump_pickle(data, filename):
    logging.info(f"Saving data to {filename}")
    with open(filename, "wb") as f:
        pickle.dump(data, f)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
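# Illustration (hypothetical model): for a torch.nn.Linear(10, 5),
# count_parameters returns 10 * 5 weights + 5 biases = 55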


if __name__ == "__main__":
    set_seed(0)
    print(project_path)
    print(project_name)
    print(remote_project_path)
    print(raw_data_path)
    print(processed_data_path)
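

# A minimal usage sketch from a hypothetical caller script (the file names
# below are made up, for illustration only):
#
#   set_seed(42)
#   create_directory(processed_data_path)
#   header, rows = read_csv(os.path.join(raw_data_path, "events.csv"))
#   write_json(os.path.join(processed_data_path, "events.json"), rows)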