Diff of /survival4D/config.py [000000] .. [2cc208]

Switch to side-by-side view

--- a
+++ b/survival4D/config.py
@@ -0,0 +1,120 @@
+import os
+import json
+from pathlib import Path
+from pyhocon import ConfigTree, ConfigFactory
+
+
+def get_conf(conf: ConfigTree, group: str = "", key: str = "", default=None):
+    if group:
+        key = ".".join([group, key])
+    return conf.get(key, default)
+
+
+class BaseConfig:
+    GROUP = ""
+
+    @classmethod
+    def from_conf(cls, conf_path: Path):
+        raise NotImplementedError("Must be implemented by subclasses")
+
+    def save(self, output_dir: Path):
+        assert output_dir.is_dir(), "output_dir has to be a directory."
+        with open(str(output_dir.joinpath(self.__class__.__name__)) + ".json", "w") as file:
+            json.dump(self.__dict__, file, indent=4)
+
+    def to_dict(self):
+        return self.__dict__
+
+
+class ExperimentConfig(BaseConfig):
+    GROUP = "experiment"
+
+    def __init__(self, data_path: Path, output_dir: Path, n_evals: int, n_bootstraps: int, n_folds: int,
+                 search_method: str):
+        self.data_path = data_path
+        self.output_dir = output_dir
+        self.n_evals = n_evals
+        self.n_folds = n_folds
+        self.n_bootstraps = n_bootstraps
+        self.search_method = search_method
+
+
+class NNExperimentConfig(ExperimentConfig):
+    def __init__(
+            self, data_path: Path, output_dir: Path, n_evals: int, n_bootstraps: int, n_folds: int, search_method: str,
+            batch_size: int, n_epochs: int, backend: str
+    ):
+        super().__init__(
+            data_path=data_path, output_dir=output_dir, n_evals=n_evals, n_bootstraps=n_bootstraps, n_folds=n_folds,
+            search_method=search_method
+        )
+        self.batch_size = batch_size
+        self.n_epochs = n_epochs
+        self.backend = backend
+
+    @classmethod
+    def from_conf(cls, conf_path):
+        conf = ConfigFactory.parse_file(str(conf_path))
+        data_path = Path(os.path.abspath(get_conf(conf, group=cls.GROUP, key="data_path")))
+        if get_conf(conf, group=cls.GROUP, key="output_dir") is None:
+            output_dir = data_path.parent.joinpath("output")
+        else:
+            output_dir = Path(get_conf(conf, group=cls.GROUP, key="output_dir"))
+        return cls(
+            data_path=data_path,
+            output_dir=output_dir,
+            batch_size=get_conf(conf, group=cls.GROUP, key="batch_size", default=16),
+            n_epochs=get_conf(conf, group=cls.GROUP, key="n_epochs", default=100),
+            n_evals=get_conf(conf, group=cls.GROUP, key="n_evals", default=50),
+            n_bootstraps=get_conf(conf, group=cls.GROUP, key="n_bootstraps", default=100),
+            n_folds=get_conf(conf, group=cls.GROUP, key="n_folds", default=6),
+            search_method=get_conf(conf, group=cls.GROUP, key="search_method", default="particle swarm"),
+            backend=get_conf(conf, group=cls.GROUP, key="backend", default="torch")
+        )
+
+
+class CoxExperimentConfig(ExperimentConfig):
+    @classmethod
+    def from_conf(cls, conf_path):
+        conf = ConfigFactory.parse_file(str(conf_path))
+        data_path = Path(os.path.abspath(get_conf(conf, group=cls.GROUP, key="data_path")))
+        if get_conf(conf, group=cls.GROUP, key="output_dir") is None:
+            output_dir = data_path.parent.joinpath("output")
+        else:
+            output_dir = Path(get_conf(conf, group=cls.GROUP, key="output_dir"))
+        return cls(
+            data_path=data_path,
+            output_dir=output_dir,
+            n_evals=get_conf(conf, group=cls.GROUP, key="n_evals", default=50),
+            n_bootstraps=get_conf(conf, group=cls.GROUP, key="n_bootstraps", default=100),
+            n_folds=get_conf(conf, group=cls.GROUP, key="n_folds", default=6),
+            search_method=get_conf(conf, group=cls.GROUP, key="search_method", default="particle swarm")
+        )
+
+
+class HypersearchConfig(BaseConfig):
+    GROUP = "hypersearch"
+
+    def __init__(self, **kwargs):
+        for key in kwargs:
+            setattr(self, key, kwargs[key])
+
+    @classmethod
+    def from_conf(cls, conf_path: Path):
+        conf = ConfigFactory.parse_file(str(conf_path))
+        conf = getattr(conf, cls.GROUP)
+        return cls(**conf)
+
+
+class ModelConfig(BaseConfig):
+    GROUP = "model"
+
+    def __init__(self, **kwargs):
+        for key in kwargs:
+            setattr(self, key, kwargs[key])
+
+    @classmethod
+    def from_conf(cls, conf_path: Path):
+        conf = ConfigFactory.parse_file(str(conf_path))
+        conf = getattr(conf, cls.GROUP)
+        return cls(**conf)