[2cc208]: / survival4D / config.py

Download this file

121 lines (97 with data), 4.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import json
from pathlib import Path
from pyhocon import ConfigTree, ConfigFactory
def get_conf(conf: ConfigTree, group: str = "", key: str = "", default=None):
if group:
key = ".".join([group, key])
return conf.get(key, default)
class BaseConfig:
GROUP = ""
@classmethod
def from_conf(cls, conf_path: Path):
raise NotImplementedError("Must be implemented by subclasses")
def save(self, output_dir: Path):
assert output_dir.is_dir(), "output_dir has to be a directory."
with open(str(output_dir.joinpath(self.__class__.__name__)) + ".json", "w") as file:
json.dump(self.__dict__, file, indent=4)
def to_dict(self):
return self.__dict__
class ExperimentConfig(BaseConfig):
GROUP = "experiment"
def __init__(self, data_path: Path, output_dir: Path, n_evals: int, n_bootstraps: int, n_folds: int,
search_method: str):
self.data_path = data_path
self.output_dir = output_dir
self.n_evals = n_evals
self.n_folds = n_folds
self.n_bootstraps = n_bootstraps
self.search_method = search_method
class NNExperimentConfig(ExperimentConfig):
def __init__(
self, data_path: Path, output_dir: Path, n_evals: int, n_bootstraps: int, n_folds: int, search_method: str,
batch_size: int, n_epochs: int, backend: str
):
super().__init__(
data_path=data_path, output_dir=output_dir, n_evals=n_evals, n_bootstraps=n_bootstraps, n_folds=n_folds,
search_method=search_method
)
self.batch_size = batch_size
self.n_epochs = n_epochs
self.backend = backend
@classmethod
def from_conf(cls, conf_path):
conf = ConfigFactory.parse_file(str(conf_path))
data_path = Path(os.path.abspath(get_conf(conf, group=cls.GROUP, key="data_path")))
if get_conf(conf, group=cls.GROUP, key="output_dir") is None:
output_dir = data_path.parent.joinpath("output")
else:
output_dir = Path(get_conf(conf, group=cls.GROUP, key="output_dir"))
return cls(
data_path=data_path,
output_dir=output_dir,
batch_size=get_conf(conf, group=cls.GROUP, key="batch_size", default=16),
n_epochs=get_conf(conf, group=cls.GROUP, key="n_epochs", default=100),
n_evals=get_conf(conf, group=cls.GROUP, key="n_evals", default=50),
n_bootstraps=get_conf(conf, group=cls.GROUP, key="n_bootstraps", default=100),
n_folds=get_conf(conf, group=cls.GROUP, key="n_folds", default=6),
search_method=get_conf(conf, group=cls.GROUP, key="search_method", default="particle swarm"),
backend=get_conf(conf, group=cls.GROUP, key="backend", default="torch")
)
class CoxExperimentConfig(ExperimentConfig):
@classmethod
def from_conf(cls, conf_path):
conf = ConfigFactory.parse_file(str(conf_path))
data_path = Path(os.path.abspath(get_conf(conf, group=cls.GROUP, key="data_path")))
if get_conf(conf, group=cls.GROUP, key="output_dir") is None:
output_dir = data_path.parent.joinpath("output")
else:
output_dir = Path(get_conf(conf, group=cls.GROUP, key="output_dir"))
return cls(
data_path=data_path,
output_dir=output_dir,
n_evals=get_conf(conf, group=cls.GROUP, key="n_evals", default=50),
n_bootstraps=get_conf(conf, group=cls.GROUP, key="n_bootstraps", default=100),
n_folds=get_conf(conf, group=cls.GROUP, key="n_folds", default=6),
search_method=get_conf(conf, group=cls.GROUP, key="search_method", default="particle swarm")
)
class HypersearchConfig(BaseConfig):
GROUP = "hypersearch"
def __init__(self, **kwargs):
for key in kwargs:
setattr(self, key, kwargs[key])
@classmethod
def from_conf(cls, conf_path: Path):
conf = ConfigFactory.parse_file(str(conf_path))
conf = getattr(conf, cls.GROUP)
return cls(**conf)
class ModelConfig(BaseConfig):
GROUP = "model"
def __init__(self, **kwargs):
for key in kwargs:
setattr(self, key, kwargs[key])
@classmethod
def from_conf(cls, conf_path: Path):
conf = ConfigFactory.parse_file(str(conf_path))
conf = getattr(conf, cls.GROUP)
return cls(**conf)