--- a +++ b/survival4D/config.py @@ -0,0 +1,120 @@ +import os +import json +from pathlib import Path +from pyhocon import ConfigTree, ConfigFactory + + +def get_conf(conf: ConfigTree, group: str = "", key: str = "", default=None): + if group: + key = ".".join([group, key]) + return conf.get(key, default) + + +class BaseConfig: + GROUP = "" + + @classmethod + def from_conf(cls, conf_path: Path): + raise NotImplementedError("Must be implemented by subclasses") + + def save(self, output_dir: Path): + assert output_dir.is_dir(), "output_dir has to be a directory." + with open(str(output_dir.joinpath(self.__class__.__name__)) + ".json", "w") as file: + json.dump(self.__dict__, file, indent=4) + + def to_dict(self): + return self.__dict__ + + +class ExperimentConfig(BaseConfig): + GROUP = "experiment" + + def __init__(self, data_path: Path, output_dir: Path, n_evals: int, n_bootstraps: int, n_folds: int, + search_method: str): + self.data_path = data_path + self.output_dir = output_dir + self.n_evals = n_evals + self.n_folds = n_folds + self.n_bootstraps = n_bootstraps + self.search_method = search_method + + +class NNExperimentConfig(ExperimentConfig): + def __init__( + self, data_path: Path, output_dir: Path, n_evals: int, n_bootstraps: int, n_folds: int, search_method: str, + batch_size: int, n_epochs: int, backend: str + ): + super().__init__( + data_path=data_path, output_dir=output_dir, n_evals=n_evals, n_bootstraps=n_bootstraps, n_folds=n_folds, + search_method=search_method + ) + self.batch_size = batch_size + self.n_epochs = n_epochs + self.backend = backend + + @classmethod + def from_conf(cls, conf_path): + conf = ConfigFactory.parse_file(str(conf_path)) + data_path = Path(os.path.abspath(get_conf(conf, group=cls.GROUP, key="data_path"))) + if get_conf(conf, group=cls.GROUP, key="output_dir") is None: + output_dir = data_path.parent.joinpath("output") + else: + output_dir = Path(get_conf(conf, group=cls.GROUP, key="output_dir")) + return cls( + data_path=data_path, + output_dir=output_dir, + batch_size=get_conf(conf, group=cls.GROUP, key="batch_size", default=16), + n_epochs=get_conf(conf, group=cls.GROUP, key="n_epochs", default=100), + n_evals=get_conf(conf, group=cls.GROUP, key="n_evals", default=50), + n_bootstraps=get_conf(conf, group=cls.GROUP, key="n_bootstraps", default=100), + n_folds=get_conf(conf, group=cls.GROUP, key="n_folds", default=6), + search_method=get_conf(conf, group=cls.GROUP, key="search_method", default="particle swarm"), + backend=get_conf(conf, group=cls.GROUP, key="backend", default="torch") + ) + + +class CoxExperimentConfig(ExperimentConfig): + @classmethod + def from_conf(cls, conf_path): + conf = ConfigFactory.parse_file(str(conf_path)) + data_path = Path(os.path.abspath(get_conf(conf, group=cls.GROUP, key="data_path"))) + if get_conf(conf, group=cls.GROUP, key="output_dir") is None: + output_dir = data_path.parent.joinpath("output") + else: + output_dir = Path(get_conf(conf, group=cls.GROUP, key="output_dir")) + return cls( + data_path=data_path, + output_dir=output_dir, + n_evals=get_conf(conf, group=cls.GROUP, key="n_evals", default=50), + n_bootstraps=get_conf(conf, group=cls.GROUP, key="n_bootstraps", default=100), + n_folds=get_conf(conf, group=cls.GROUP, key="n_folds", default=6), + search_method=get_conf(conf, group=cls.GROUP, key="search_method", default="particle swarm") + ) + + +class HypersearchConfig(BaseConfig): + GROUP = "hypersearch" + + def __init__(self, **kwargs): + for key in kwargs: + setattr(self, key, kwargs[key]) + + @classmethod + def from_conf(cls, conf_path: Path): + conf = ConfigFactory.parse_file(str(conf_path)) + conf = getattr(conf, cls.GROUP) + return cls(**conf) + + +class ModelConfig(BaseConfig): + GROUP = "model" + + def __init__(self, **kwargs): + for key in kwargs: + setattr(self, key, kwargs[key]) + + @classmethod + def from_conf(cls, conf_path: Path): + conf = ConfigFactory.parse_file(str(conf_path)) + conf = getattr(conf, cls.GROUP) + return cls(**conf)