|
a |
|
b/src/move/tasks/tune_model.py |
|
|
1 |
__all__ = ["tune_model"] |
|
|
2 |
|
|
|
3 |
from pathlib import Path |
|
|
4 |
from typing import Any, Literal, cast |
|
|
5 |
|
|
|
6 |
import hydra |
|
|
7 |
import numpy as np |
|
|
8 |
import pandas as pd |
|
|
9 |
import torch |
|
|
10 |
from hydra.core.hydra_config import HydraConfig |
|
|
11 |
from hydra.types import RunMode |
|
|
12 |
from matplotlib.cbook import boxplot_stats |
|
|
13 |
from numpy.typing import ArrayLike |
|
|
14 |
from omegaconf import OmegaConf |
|
|
15 |
from sklearn.metrics.pairwise import cosine_similarity |
|
|
16 |
|
|
|
17 |
from move.analysis.metrics import ( |
|
|
18 |
calculate_accuracy, |
|
|
19 |
calculate_cosine_similarity, |
|
|
20 |
) |
|
|
21 |
from move.conf.schema import ( |
|
|
22 |
MOVEConfig, |
|
|
23 |
TuneModelConfig, |
|
|
24 |
TuneModelReconstructionConfig, |
|
|
25 |
TuneModelStabilityConfig, |
|
|
26 |
) |
|
|
27 |
from move.core.logging import get_logger |
|
|
28 |
from move.core.typing import BoolArray |
|
|
29 |
from move.data import io |
|
|
30 |
from move.data.dataloaders import MOVEDataset, make_dataloader, split_samples |
|
|
31 |
from move.models.vae import VAE |
|
|
32 |
|
|
|
33 |
TaskType = Literal["reconstruction", "stability"] |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
def _get_task_type( |
|
|
37 |
task_config: TuneModelConfig, |
|
|
38 |
) -> TaskType: |
|
|
39 |
task_type = OmegaConf.get_type(task_config) |
|
|
40 |
if task_type is TuneModelReconstructionConfig: |
|
|
41 |
return "reconstruction" |
|
|
42 |
if task_type is TuneModelStabilityConfig: |
|
|
43 |
return "stability" |
|
|
44 |
raise ValueError("Unsupported type of task!") |
|
|
45 |
|
|
|
46 |
|
|
|
47 |
def _get_record(values: ArrayLike, **kwargs) -> dict[str, Any]: |
|
|
48 |
record = kwargs |
|
|
49 |
bxp_stats, *_ = boxplot_stats(values) |
|
|
50 |
bxp_stats.pop("fliers") |
|
|
51 |
record.update(bxp_stats) |
|
|
52 |
return record |
|
|
53 |
|
|
|
54 |
|
|
|
55 |
def tune_model(config: MOVEConfig) -> float: |
|
|
56 |
"""Train multiple models to tune the model hyperparameters.""" |
|
|
57 |
hydra_config = HydraConfig.get() |
|
|
58 |
|
|
|
59 |
if hydra_config.mode != RunMode.MULTIRUN: |
|
|
60 |
raise ValueError("This task must run in multirun mode.") |
|
|
61 |
|
|
|
62 |
# Delete sweep run config |
|
|
63 |
sweep_config_path = Path(hydra_config.sweep.dir).joinpath("multirun.yaml") |
|
|
64 |
if sweep_config_path.exists(): |
|
|
65 |
sweep_config_path.unlink() |
|
|
66 |
|
|
|
67 |
job_num = hydra_config.job.num + 1 |
|
|
68 |
|
|
|
69 |
logger = get_logger(__name__) |
|
|
70 |
task_config = cast(TuneModelConfig, config.task) |
|
|
71 |
task_type = _get_task_type(task_config) |
|
|
72 |
|
|
|
73 |
logger.info(f"Beginning task: tune model {task_type} {job_num}") |
|
|
74 |
logger.info(f"Job name: {hydra_config.job.override_dirname}") |
|
|
75 |
|
|
|
76 |
interim_path = Path(config.data.interim_data_path) |
|
|
77 |
output_path = Path(config.data.results_path) / "tune_model" |
|
|
78 |
output_path.mkdir(exist_ok=True, parents=True) |
|
|
79 |
|
|
|
80 |
logger.debug("Reading data") |
|
|
81 |
|
|
|
82 |
cat_list, _, con_list, _ = io.load_preprocessed_data( |
|
|
83 |
interim_path, |
|
|
84 |
config.data.categorical_names, |
|
|
85 |
config.data.continuous_names, |
|
|
86 |
) |
|
|
87 |
|
|
|
88 |
assert task_config.model is not None |
|
|
89 |
device = torch.device("cuda" if task_config.model.cuda is True else "cpu") |
|
|
90 |
|
|
|
91 |
def _tune_stability( |
|
|
92 |
task_config: TuneModelStabilityConfig, |
|
|
93 |
): |
|
|
94 |
label = [hp.split("=") for hp in hydra_config.job.override_dirname.split(",")] |
|
|
95 |
|
|
|
96 |
train_dataloader = make_dataloader( |
|
|
97 |
cat_list, |
|
|
98 |
con_list, |
|
|
99 |
shuffle=True, |
|
|
100 |
batch_size=task_config.batch_size, |
|
|
101 |
drop_last=True, |
|
|
102 |
) |
|
|
103 |
|
|
|
104 |
test_dataloader = make_dataloader( |
|
|
105 |
cat_list, |
|
|
106 |
con_list, |
|
|
107 |
shuffle=False, |
|
|
108 |
batch_size=task_config.batch_size, |
|
|
109 |
drop_last=False, |
|
|
110 |
) |
|
|
111 |
|
|
|
112 |
train_dataset = cast(MOVEDataset, train_dataloader.dataset) |
|
|
113 |
|
|
|
114 |
logger.info(f"Training {task_config.num_refits} refits") |
|
|
115 |
|
|
|
116 |
cosine_sim0 = None |
|
|
117 |
cosine_sim_diffs = [] |
|
|
118 |
for j in range(task_config.num_refits): |
|
|
119 |
logger.debug(f"Refit: {j + 1}/{task_config.num_refits}") |
|
|
120 |
model: VAE = hydra.utils.instantiate( |
|
|
121 |
task_config.model, |
|
|
122 |
continuous_shapes=train_dataset.con_shapes, |
|
|
123 |
categorical_shapes=train_dataset.cat_shapes, |
|
|
124 |
) |
|
|
125 |
model.to(device) |
|
|
126 |
|
|
|
127 |
hydra.utils.call( |
|
|
128 |
task_config.training_loop, |
|
|
129 |
model=model, |
|
|
130 |
train_dataloader=train_dataloader, |
|
|
131 |
) |
|
|
132 |
|
|
|
133 |
model.eval() |
|
|
134 |
latent, *_ = model.latent(test_dataloader, kld_weight=1) |
|
|
135 |
|
|
|
136 |
if cosine_sim0 is None: |
|
|
137 |
cosine_sim0 = cosine_similarity(latent) |
|
|
138 |
else: |
|
|
139 |
cosine_sim = cosine_similarity(latent) |
|
|
140 |
D = np.absolute(cosine_sim - cosine_sim0) |
|
|
141 |
# removing the diagonal element (cos_sim with itself) |
|
|
142 |
diff = D[~np.eye(D.shape[0], dtype=bool)].reshape(D.shape[0], -1) |
|
|
143 |
mean_diff = np.mean(diff) |
|
|
144 |
cosine_sim_diffs.append(mean_diff) |
|
|
145 |
|
|
|
146 |
record = _get_record( |
|
|
147 |
cosine_sim_diffs, |
|
|
148 |
job_num=job_num, |
|
|
149 |
**dict(label), |
|
|
150 |
metric="mean_diff_cosine_similarity", |
|
|
151 |
num_refits=task_config.num_refits, |
|
|
152 |
) |
|
|
153 |
logger.info("Writing results") |
|
|
154 |
df_path = output_path / "stability_stats.tsv" |
|
|
155 |
header = not df_path.exists() |
|
|
156 |
df = pd.DataFrame.from_records([record]) |
|
|
157 |
df.to_csv(df_path, sep="\t", mode="a", header=header, index=False) |
|
|
158 |
|
|
|
159 |
def _tune_reconstruction( |
|
|
160 |
task_config: TuneModelReconstructionConfig, |
|
|
161 |
): |
|
|
162 |
split_path = interim_path / "split_mask.npy" |
|
|
163 |
if split_path.exists(): |
|
|
164 |
split_mask: BoolArray = np.load(split_path) |
|
|
165 |
else: |
|
|
166 |
num_samples = cat_list[0].shape[0] if cat_list else con_list[0].shape[0] |
|
|
167 |
split_mask = split_samples(num_samples, 0.9) |
|
|
168 |
np.save(split_path, split_mask) |
|
|
169 |
|
|
|
170 |
train_dataloader = make_dataloader( |
|
|
171 |
cat_list, |
|
|
172 |
con_list, |
|
|
173 |
split_mask, |
|
|
174 |
shuffle=True, |
|
|
175 |
batch_size=task_config.batch_size, |
|
|
176 |
drop_last=True, |
|
|
177 |
) |
|
|
178 |
|
|
|
179 |
train_dataset = cast(MOVEDataset, train_dataloader.dataset) |
|
|
180 |
|
|
|
181 |
model: VAE = hydra.utils.instantiate( |
|
|
182 |
task_config.model, |
|
|
183 |
continuous_shapes=train_dataset.con_shapes, |
|
|
184 |
categorical_shapes=train_dataset.cat_shapes, |
|
|
185 |
) |
|
|
186 |
model.to(device) |
|
|
187 |
logger.debug(f"Model: {model}") |
|
|
188 |
|
|
|
189 |
logger.debug("Training model") |
|
|
190 |
hydra.utils.call( |
|
|
191 |
task_config.training_loop, |
|
|
192 |
model=model, |
|
|
193 |
train_dataloader=train_dataloader, |
|
|
194 |
) |
|
|
195 |
model.eval() |
|
|
196 |
logger.info("Reconstructing") |
|
|
197 |
logger.info("Computing reconstruction metrics") |
|
|
198 |
label = [hp.split("=") for hp in hydra_config.job.override_dirname.split(";")] |
|
|
199 |
records = [] |
|
|
200 |
splits = zip(["train", "test"], [split_mask, ~split_mask]) |
|
|
201 |
for split_name, mask in splits: |
|
|
202 |
dataloader = make_dataloader( |
|
|
203 |
cat_list, |
|
|
204 |
con_list, |
|
|
205 |
mask, |
|
|
206 |
shuffle=False, |
|
|
207 |
batch_size=task_config.batch_size, |
|
|
208 |
) |
|
|
209 |
cat_recons, con_recons = model.reconstruct(dataloader) |
|
|
210 |
con_recons = np.split( |
|
|
211 |
con_recons, np.cumsum(model.continuous_shapes[:-1]), axis=1 |
|
|
212 |
) |
|
|
213 |
for cat, cat_recon, dataset_name in zip( |
|
|
214 |
cat_list, cat_recons, config.data.categorical_names |
|
|
215 |
): |
|
|
216 |
logger.debug(f"Computing accuracy: '{dataset_name}'") |
|
|
217 |
accuracy = calculate_accuracy(cat[mask], cat_recon) |
|
|
218 |
record = _get_record( |
|
|
219 |
accuracy, |
|
|
220 |
job_num=job_num, |
|
|
221 |
**dict(label), |
|
|
222 |
metric="accuracy", |
|
|
223 |
dataset=dataset_name, |
|
|
224 |
split=split_name, |
|
|
225 |
) |
|
|
226 |
records.append(record) |
|
|
227 |
for con, con_recon, dataset_name in zip( |
|
|
228 |
con_list, con_recons, config.data.continuous_names |
|
|
229 |
): |
|
|
230 |
logger.debug(f"Computing cosine similarity: '{dataset_name}'") |
|
|
231 |
cosine_sim = calculate_cosine_similarity(con[mask], con_recon) |
|
|
232 |
record = _get_record( |
|
|
233 |
cosine_sim, |
|
|
234 |
job_num=job_num, |
|
|
235 |
**dict(label), |
|
|
236 |
metric="cosine_similarity", |
|
|
237 |
dataset=dataset_name, |
|
|
238 |
split=split_name, |
|
|
239 |
) |
|
|
240 |
records.append(record) |
|
|
241 |
|
|
|
242 |
logger.info("Writing results") |
|
|
243 |
df_path = output_path / "reconstruction_stats.tsv" |
|
|
244 |
header = not df_path.exists() |
|
|
245 |
df = pd.DataFrame.from_records(records) |
|
|
246 |
df.to_csv(df_path, sep="\t", mode="a", header=header, index=False) |
|
|
247 |
|
|
|
248 |
if task_type == "reconstruction": |
|
|
249 |
task_config = cast(TuneModelReconstructionConfig, task_config) |
|
|
250 |
_tune_reconstruction(task_config) |
|
|
251 |
elif task_type == "stability": |
|
|
252 |
task_config = cast(TuneModelStabilityConfig, task_config) |
|
|
253 |
_tune_stability(task_config) |
|
|
254 |
|
|
|
255 |
return 0.0 |