|
a |
|
b/PhenPred/vae/Hypers.py |
|
|
1 |
import json |
|
|
2 |
from PhenPred.vae import data_folder, plot_folder |
|
|
3 |
from PhenPred.vae.Losses import CLinesLosses |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
class Hypers: |
|
|
7 |
@classmethod |
|
|
8 |
def read_json(cls, json_file): |
|
|
9 |
with open(json_file, "r") as f: |
|
|
10 |
hypers = json.load(f) |
|
|
11 |
return hypers |
|
|
12 |
|
|
|
13 |
@classmethod |
|
|
14 |
def read_hyperparameters( |
|
|
15 |
cls, hypers_json=None, parse_torch_functions=True, timestamp=None |
|
|
16 |
): |
|
|
17 |
if timestamp is not None: |
|
|
18 |
hypers_json = f"{plot_folder}/files/{timestamp}_hyperparameters.json" |
|
|
19 |
elif hypers_json is None: |
|
|
20 |
hypers_json = f"{plot_folder}/files/hyperparameters.json" |
|
|
21 |
|
|
|
22 |
hypers = cls.read_json(hypers_json) |
|
|
23 |
|
|
|
24 |
if timestamp is not None: |
|
|
25 |
hypers["load_run"] = timestamp |
|
|
26 |
|
|
|
27 |
if "model" not in hypers: |
|
|
28 |
hypers["model"] = "MOSA" |
|
|
29 |
|
|
|
30 |
if "standardize" not in hypers: |
|
|
31 |
hypers["standardize"] = False |
|
|
32 |
|
|
|
33 |
if "w_rec" not in hypers: |
|
|
34 |
hypers["w_rec"] = 1 |
|
|
35 |
|
|
|
36 |
if "w_gauss" not in hypers: |
|
|
37 |
hypers["w_gauss"] = 0.01 |
|
|
38 |
|
|
|
39 |
if "w_cat" not in hypers: |
|
|
40 |
hypers["w_cat"] = 0.01 |
|
|
41 |
|
|
|
42 |
if hypers["view_loss_weights"] is None: |
|
|
43 |
hypers["view_loss_weights"] = [1.0] * len(hypers["views"]) |
|
|
44 |
|
|
|
45 |
if hypers["use_conditionals"] is None: |
|
|
46 |
hypers["use_conditionals"] = True |
|
|
47 |
|
|
|
48 |
if timestamp is None: # full path is already stored in previous json config |
|
|
49 |
hypers["datasets"] = { |
|
|
50 |
k: f"{data_folder}/{v}" for k, v in hypers["datasets"].items() |
|
|
51 |
} |
|
|
52 |
print(f"# ---- Hyperparameters") |
|
|
53 |
print(json.dumps(hypers, indent=4, sort_keys=True)) |
|
|
54 |
|
|
|
55 |
if parse_torch_functions: |
|
|
56 |
hypers = cls.parse_torch_functions(hypers) |
|
|
57 |
|
|
|
58 |
return hypers |
|
|
59 |
|
|
|
60 |
@classmethod |
|
|
61 |
def parse_torch_functions(cls, hypers): |
|
|
62 |
if type(hypers["activation_function"]) == str: |
|
|
63 |
hypers["activation_function"] = CLinesLosses.activation_function( |
|
|
64 |
hypers["activation_function"] |
|
|
65 |
) |
|
|
66 |
|
|
|
67 |
if type(hypers["reconstruction_loss"]) == str: |
|
|
68 |
hypers["reconstruction_loss"] = CLinesLosses.reconstruction_loss_method( |
|
|
69 |
hypers["reconstruction_loss"] |
|
|
70 |
) |
|
|
71 |
|
|
|
72 |
if type(hypers["hidden_dims"]) == str: |
|
|
73 |
hypers["hidden_dims"] = [float(l) for l in hypers["hidden_dims"].split(",")] |
|
|
74 |
|
|
|
75 |
return hypers |