[96354c]: / src / config.py

Download this file

110 lines (81 with data), 4.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
import configparser
import os
import time
from src.logging_conf import logger
def create_directory(directory):
    """Create ``directory`` (including missing parents) if it is not there yet.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
        OSError: If the directory cannot be created for any other reason.
    """
    logger.debug(f"Create directory: {directory}")
    # exist_ok=True makes the call idempotent without a separate existence
    # check, so a concurrent process creating the same directory between
    # "check" and "create" can no longer make this fail.
    os.makedirs(directory, exist_ok=True)
def check_path_exists(path):
    """Return ``path`` unchanged when it exists on disk.

    Args:
        path: File or directory path to validate.

    Returns:
        The same ``path`` that was passed in.

    Raises:
        FileNotFoundError: If ``path`` does not exist; the path is the
            exception's only argument.
    """
    if os.path.exists(path):
        return path
    raise FileNotFoundError(path)
def get_correct_path(local_path, server_path):
    """Pick whichever of the two candidate paths exists on this machine.

    ``local_path`` takes precedence: ``server_path`` is only considered when
    ``local_path`` does not exist.

    Args:
        local_path: First candidate path.
        server_path: Fallback candidate path.

    Returns:
        The first of the two paths that exists.

    Raises:
        ValueError: If neither path exists.
    """
    for candidate in (local_path, server_path):
        if os.path.exists(candidate):
            return candidate
    raise ValueError('No path is working')
class BratsConfiguration:
    """Load an INI configuration file and derive run-specific settings from it.

    The file is parsed with ``configparser``; ``prepare_parameters`` then
    rewrites and extends the parsed options in place (resolved model and
    dataset paths, CSV locations, batch size) and parses the patch size.

    Attributes:
        config: The ``ConfigParser`` holding both raw and derived options.
        path: Path of the configuration file that was loaded.
        patch_size: Tuple of ints parsed from ``[dataset] patch_size``.
    """

    def __init__(self, path: str) -> None:
        """Read the configuration at ``path`` and prepare derived parameters.

        Args:
            path: Location of the INI file.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        self.config = configparser.ConfigParser()
        self.path = check_path_exists(path)
        self.config.read(self.path)
        self.prepare_parameters()

    def prepare_parameters(self) -> None:
        """Resolve paths and derived options, updating ``self.config`` in place.

        Steps, in order:
          1. ``[model] model_path`` is set to whichever of
             ``model_path_local`` / ``model_path_server`` exists.
          2. When ``[basics] train_flag`` is true, the run directories are
             created and the configuration is saved into the model directory.
          3. Dataset root, per-split folders and CSV paths are resolved.
          4. ``[dataset] batch_size`` is computed if it is not configured.
          5. ``self.patch_size`` is parsed from ``[dataset] patch_size``.

        Raises:
            ValueError: If none of the candidate model/dataset paths exists,
                or if ``patch_size`` is empty or not made of integers.
        """
        self.config["model"]["model_path"] = get_correct_path(
            self.config.get("model", "model_path_local"),
            self.config.get("model", "model_path_server"),
        )

        train = self.config.getboolean("basics", "train_flag")
        resume = self.config.getboolean("basics", "resume")

        if train:
            self._prepare_training_run(resume)

        self._prepare_dataset_paths(train)
        self._ensure_batch_size()
        self.patch_size = self._parse_patch_size(self.config.get("dataset", "patch_size"))

    def _prepare_training_run(self, resume: bool) -> None:
        """Create the run's output directories and save the configuration there."""
        if resume:
            # Reuse the directory part of the configured checkpoint as the run name.
            model_name, _ = os.path.split(self.config.get("model", "checkpoint"))
        else:
            # Fresh run: name it after the current Unix timestamp.
            model_name = f"model_{round(time.time())}"

        logger.info("Create model directory and save configuration")

        self.config["basics"]["tensorboard_logs"] = os.path.join(
            self.config.get("basics", "tensorboard_logs"), model_name)
        # The checkpoint must be joined BEFORE model_path is extended with
        # model_name below, so it stays relative to the root model path.
        self.config["model"]["checkpoint"] = os.path.join(
            self.config.get("model", "model_path"),
            self.config.get("model", "checkpoint"))
        create_directory(self.config.get("basics", "tensorboard_logs"))

        self.config["model"]["model_path"] = os.path.join(
            self.config.get("model", "model_path"), model_name)
        create_directory(self.config["model"]["model_path"])

        # Save the current configuration there. This happens before the
        # dataset paths are derived, so the saved file does not contain them.
        with open(os.path.join(self.config["model"]["model_path"], "config.ini"), 'w') as configfile:
            self.config.write(configfile)

    def _prepare_dataset_paths(self, train: bool) -> None:
        """Resolve the dataset root, the per-split folders and their CSV files."""
        dataset = self.config["dataset"]

        # Training reads "source_sampling", every other mode "sampling_method";
        # only the part after the last dot is used as the folder name.
        sampling_option = "source_sampling" if train else "sampling_method"
        sampling_method = dataset[sampling_option].split(".")[-1]

        # SETTING PATH!
        dataset["root_path"] = get_correct_path(
            self.config.get("dataset", "dataset_root_path_local"),
            self.config.get("dataset", "dataset_root_path_server"))

        splits = ("train", "val", "test")
        for split in splits:
            dataset[f"path_{split}"] = os.path.join(
                dataset["root_path"], dataset[f"dataset_{split}_folder"], sampling_method)

        # Each CSV option initially holds a file name and is replaced by the
        # full path inside the corresponding split folder.
        for split in splits:
            dataset[f"{split}_csv"] = os.path.join(
                dataset[f"path_{split}"], self.config.get("dataset", f"{split}_csv"))

    def _ensure_batch_size(self) -> None:
        """Derive ``batch_size`` from patients-per-batch and patches when unset."""
        if "batch_size" not in self.config["dataset"]:
            n_patients = self.config.getint("dataset", "n_patients_per_batch")
            n_patches = self.config.getint("dataset", "n_patches")
            # configparser stores strings only, hence the explicit conversion.
            self.config["dataset"]["batch_size"] = str(n_patients * n_patches)

    @staticmethod
    def _parse_patch_size(raw: str) -> tuple:
        """Convert a whitespace/newline separated list of ints into a tuple.

        Splitting on any whitespace (instead of only ``"\\n"``) keeps the
        one-value-per-line format working while tolerating blank lines inside
        the multi-line value.

        Raises:
            ValueError: If ``raw`` holds no values or a value is not an integer.
        """
        items = raw.split()
        if not items:
            raise ValueError("dataset.patch_size is empty; expected one integer per line")
        return tuple(int(item) for item in items)

    def get_model_config(self) -> configparser.SectionProxy:
        """Return the ``[model]`` section."""
        return self.config["model"]

    def get_uncertainty_config(self) -> configparser.SectionProxy:
        """Return the ``[uncertainty]`` section."""
        return self.config["uncertainty"]

    def get_dataset_config(self) -> configparser.SectionProxy:
        """Return the ``[dataset]`` section."""
        return self.config["dataset"]

    def get_basic_config(self) -> configparser.SectionProxy:
        """Return the ``[basics]`` section."""
        return self.config["basics"]