Diff of /survival4D/config.py [000000] .. [2cc208]

Switch to unified view

a b/survival4D/config.py
1
import os
2
import json
3
from pathlib import Path
4
from pyhocon import ConfigTree, ConfigFactory
5
6
7
def get_conf(conf: ConfigTree, group: str = "", key: str = "", default=None):
8
    if group:
9
        key = ".".join([group, key])
10
    return conf.get(key, default)
11
12
13
class BaseConfig:
14
    GROUP = ""
15
16
    @classmethod
17
    def from_conf(cls, conf_path: Path):
18
        raise NotImplementedError("Must be implemented by subclasses")
19
20
    def save(self, output_dir: Path):
21
        assert output_dir.is_dir(), "output_dir has to be a directory."
22
        with open(str(output_dir.joinpath(self.__class__.__name__)) + ".json", "w") as file:
23
            json.dump(self.__dict__, file, indent=4)
24
25
    def to_dict(self):
26
        return self.__dict__
27
28
29
class ExperimentConfig(BaseConfig):
30
    GROUP = "experiment"
31
32
    def __init__(self, data_path: Path, output_dir: Path, n_evals: int, n_bootstraps: int, n_folds: int,
33
                 search_method: str):
34
        self.data_path = data_path
35
        self.output_dir = output_dir
36
        self.n_evals = n_evals
37
        self.n_folds = n_folds
38
        self.n_bootstraps = n_bootstraps
39
        self.search_method = search_method
40
41
42
class NNExperimentConfig(ExperimentConfig):
43
    def __init__(
44
            self, data_path: Path, output_dir: Path, n_evals: int, n_bootstraps: int, n_folds: int, search_method: str,
45
            batch_size: int, n_epochs: int, backend: str
46
    ):
47
        super().__init__(
48
            data_path=data_path, output_dir=output_dir, n_evals=n_evals, n_bootstraps=n_bootstraps, n_folds=n_folds,
49
            search_method=search_method
50
        )
51
        self.batch_size = batch_size
52
        self.n_epochs = n_epochs
53
        self.backend = backend
54
55
    @classmethod
56
    def from_conf(cls, conf_path):
57
        conf = ConfigFactory.parse_file(str(conf_path))
58
        data_path = Path(os.path.abspath(get_conf(conf, group=cls.GROUP, key="data_path")))
59
        if get_conf(conf, group=cls.GROUP, key="output_dir") is None:
60
            output_dir = data_path.parent.joinpath("output")
61
        else:
62
            output_dir = Path(get_conf(conf, group=cls.GROUP, key="output_dir"))
63
        return cls(
64
            data_path=data_path,
65
            output_dir=output_dir,
66
            batch_size=get_conf(conf, group=cls.GROUP, key="batch_size", default=16),
67
            n_epochs=get_conf(conf, group=cls.GROUP, key="n_epochs", default=100),
68
            n_evals=get_conf(conf, group=cls.GROUP, key="n_evals", default=50),
69
            n_bootstraps=get_conf(conf, group=cls.GROUP, key="n_bootstraps", default=100),
70
            n_folds=get_conf(conf, group=cls.GROUP, key="n_folds", default=6),
71
            search_method=get_conf(conf, group=cls.GROUP, key="search_method", default="particle swarm"),
72
            backend=get_conf(conf, group=cls.GROUP, key="backend", default="torch")
73
        )
74
75
76
class CoxExperimentConfig(ExperimentConfig):
77
    @classmethod
78
    def from_conf(cls, conf_path):
79
        conf = ConfigFactory.parse_file(str(conf_path))
80
        data_path = Path(os.path.abspath(get_conf(conf, group=cls.GROUP, key="data_path")))
81
        if get_conf(conf, group=cls.GROUP, key="output_dir") is None:
82
            output_dir = data_path.parent.joinpath("output")
83
        else:
84
            output_dir = Path(get_conf(conf, group=cls.GROUP, key="output_dir"))
85
        return cls(
86
            data_path=data_path,
87
            output_dir=output_dir,
88
            n_evals=get_conf(conf, group=cls.GROUP, key="n_evals", default=50),
89
            n_bootstraps=get_conf(conf, group=cls.GROUP, key="n_bootstraps", default=100),
90
            n_folds=get_conf(conf, group=cls.GROUP, key="n_folds", default=6),
91
            search_method=get_conf(conf, group=cls.GROUP, key="search_method", default="particle swarm")
92
        )
93
94
95
class HypersearchConfig(BaseConfig):
96
    GROUP = "hypersearch"
97
98
    def __init__(self, **kwargs):
99
        for key in kwargs:
100
            setattr(self, key, kwargs[key])
101
102
    @classmethod
103
    def from_conf(cls, conf_path: Path):
104
        conf = ConfigFactory.parse_file(str(conf_path))
105
        conf = getattr(conf, cls.GROUP)
106
        return cls(**conf)
107
108
109
class ModelConfig(BaseConfig):
110
    GROUP = "model"
111
112
    def __init__(self, **kwargs):
113
        for key in kwargs:
114
            setattr(self, key, kwargs[key])
115
116
    @classmethod
117
    def from_conf(cls, conf_path: Path):
118
        conf = ConfigFactory.parse_file(str(conf_path))
119
        conf = getattr(conf, cls.GROUP)
120
        return cls(**conf)