Diff of /PhenPred/vae/Hypers.py [000000] .. [305123]

Switch to unified view

a b/PhenPred/vae/Hypers.py
1
import json
2
from PhenPred.vae import data_folder, plot_folder
3
from PhenPred.vae.Losses import CLinesLosses
4
5
6
class Hypers:
7
    @classmethod
8
    def read_json(cls, json_file):
9
        with open(json_file, "r") as f:
10
            hypers = json.load(f)
11
        return hypers
12
13
    @classmethod
14
    def read_hyperparameters(
15
        cls, hypers_json=None, parse_torch_functions=True, timestamp=None
16
    ):
17
        if timestamp is not None:
18
            hypers_json = f"{plot_folder}/files/{timestamp}_hyperparameters.json"
19
        elif hypers_json is None:
20
            hypers_json = f"{plot_folder}/files/hyperparameters.json"
21
22
        hypers = cls.read_json(hypers_json)
23
24
        if timestamp is not None:
25
            hypers["load_run"] = timestamp
26
27
        if "model" not in hypers:
28
            hypers["model"] = "MOSA"
29
30
        if "standardize" not in hypers:
31
            hypers["standardize"] = False
32
33
        if "w_rec" not in hypers:
34
            hypers["w_rec"] = 1
35
36
        if "w_gauss" not in hypers:
37
            hypers["w_gauss"] = 0.01
38
39
        if "w_cat" not in hypers:
40
            hypers["w_cat"] = 0.01
41
42
        if hypers["view_loss_weights"] is None:
43
            hypers["view_loss_weights"] = [1.0] * len(hypers["views"])
44
45
        if hypers["use_conditionals"] is None:
46
            hypers["use_conditionals"] = True
47
48
        if timestamp is None:  # full path is already stored in previous json config
49
            hypers["datasets"] = {
50
                k: f"{data_folder}/{v}" for k, v in hypers["datasets"].items()
51
            }
52
        print(f"# ---- Hyperparameters")
53
        print(json.dumps(hypers, indent=4, sort_keys=True))
54
55
        if parse_torch_functions:
56
            hypers = cls.parse_torch_functions(hypers)
57
58
        return hypers
59
60
    @classmethod
61
    def parse_torch_functions(cls, hypers):
62
        if type(hypers["activation_function"]) == str:
63
            hypers["activation_function"] = CLinesLosses.activation_function(
64
                hypers["activation_function"]
65
            )
66
67
        if type(hypers["reconstruction_loss"]) == str:
68
            hypers["reconstruction_loss"] = CLinesLosses.reconstruction_loss_method(
69
                hypers["reconstruction_loss"]
70
            )
71
72
        if type(hypers["hidden_dims"]) == str:
73
            hypers["hidden_dims"] = [float(l) for l in hypers["hidden_dims"].split(",")]
74
75
        return hypers