src/move/tasks/tune_model.py
__all__ = ["tune_model"]

from pathlib import Path
from typing import Any, Literal, cast

import hydra
import numpy as np
import pandas as pd
import torch
from hydra.core.hydra_config import HydraConfig
from hydra.types import RunMode
from matplotlib.cbook import boxplot_stats
from numpy.typing import ArrayLike
from omegaconf import OmegaConf
from sklearn.metrics.pairwise import cosine_similarity

from move.analysis.metrics import (
    calculate_accuracy,
    calculate_cosine_similarity,
)
from move.conf.schema import (
    MOVEConfig,
    TuneModelConfig,
    TuneModelReconstructionConfig,
    TuneModelStabilityConfig,
)
from move.core.logging import get_logger
from move.core.typing import BoolArray
from move.data import io
from move.data.dataloaders import MOVEDataset, make_dataloader, split_samples
from move.models.vae import VAE

TaskType = Literal["reconstruction", "stability"]


def _get_task_type(
    task_config: TuneModelConfig,
) -> TaskType:
    task_type = OmegaConf.get_type(task_config)
    if task_type is TuneModelReconstructionConfig:
        return "reconstruction"
    if task_type is TuneModelStabilityConfig:
        return "stability"
    raise ValueError("Unsupported type of task!")


def _get_record(values: ArrayLike, **kwargs) -> dict[str, Any]:
    record = kwargs
    bxp_stats, *_ = boxplot_stats(values)
    bxp_stats.pop("fliers")
    record.update(bxp_stats)
    return record


def tune_model(config: MOVEConfig) -> float:
    """Train multiple models to tune the model hyperparameters."""
    hydra_config = HydraConfig.get()

    if hydra_config.mode != RunMode.MULTIRUN:
        raise ValueError("This task must run in multirun mode.")

    # Delete sweep run config
    sweep_config_path = Path(hydra_config.sweep.dir).joinpath("multirun.yaml")
    if sweep_config_path.exists():
        sweep_config_path.unlink()

    job_num = hydra_config.job.num + 1

    logger = get_logger(__name__)
    task_config = cast(TuneModelConfig, config.task)
    task_type = _get_task_type(task_config)

    logger.info(f"Beginning task: tune model {task_type} {job_num}")
    logger.info(f"Job name: {hydra_config.job.override_dirname}")

    interim_path = Path(config.data.interim_data_path)
    output_path = Path(config.data.results_path) / "tune_model"
    output_path.mkdir(exist_ok=True, parents=True)

    logger.debug("Reading data")

    cat_list, _, con_list, _ = io.load_preprocessed_data(
        interim_path,
        config.data.categorical_names,
        config.data.continuous_names,
    )

    assert task_config.model is not None
    device = torch.device("cuda" if task_config.model.cuda is True else "cpu")

    def _tune_stability(
        task_config: TuneModelStabilityConfig,
    ):
        label = [hp.split("=") for hp in hydra_config.job.override_dirname.split(",")]

        train_dataloader = make_dataloader(
            cat_list,
            con_list,
            shuffle=True,
            batch_size=task_config.batch_size,
            drop_last=True,
        )

        test_dataloader = make_dataloader(
            cat_list,
            con_list,
            shuffle=False,
            batch_size=task_config.batch_size,
            drop_last=False,
        )

        train_dataset = cast(MOVEDataset, train_dataloader.dataset)

        logger.info(f"Training {task_config.num_refits} refits")

        cosine_sim0 = None
        cosine_sim_diffs = []
        for j in range(task_config.num_refits):
            logger.debug(f"Refit: {j + 1}/{task_config.num_refits}")
            model: VAE = hydra.utils.instantiate(
                task_config.model,
                continuous_shapes=train_dataset.con_shapes,
                categorical_shapes=train_dataset.cat_shapes,
            )
            model.to(device)

            hydra.utils.call(
                task_config.training_loop,
                model=model,
                train_dataloader=train_dataloader,
            )

            model.eval()
            latent, *_ = model.latent(test_dataloader, kld_weight=1)

            if cosine_sim0 is None:
                cosine_sim0 = cosine_similarity(latent)
            else:
                cosine_sim = cosine_similarity(latent)
                D = np.absolute(cosine_sim - cosine_sim0)
                # removing the diagonal element (cos_sim with itself)
                diff = D[~np.eye(D.shape[0], dtype=bool)].reshape(D.shape[0], -1)
                mean_diff = np.mean(diff)
                cosine_sim_diffs.append(mean_diff)

        record = _get_record(
            cosine_sim_diffs,
            job_num=job_num,
            **dict(label),
            metric="mean_diff_cosine_similarity",
            num_refits=task_config.num_refits,
        )
        logger.info("Writing results")
        df_path = output_path / "stability_stats.tsv"
        header = not df_path.exists()
        df = pd.DataFrame.from_records([record])
        df.to_csv(df_path, sep="\t", mode="a", header=header, index=False)

    def _tune_reconstruction(
        task_config: TuneModelReconstructionConfig,
    ):
        split_path = interim_path / "split_mask.npy"
        if split_path.exists():
            split_mask: BoolArray = np.load(split_path)
        else:
            num_samples = cat_list[0].shape[0] if cat_list else con_list[0].shape[0]
            split_mask = split_samples(num_samples, 0.9)
            np.save(split_path, split_mask)

        train_dataloader = make_dataloader(
            cat_list,
            con_list,
            split_mask,
            shuffle=True,
            batch_size=task_config.batch_size,
            drop_last=True,
        )

        train_dataset = cast(MOVEDataset, train_dataloader.dataset)

        model: VAE = hydra.utils.instantiate(
            task_config.model,
            continuous_shapes=train_dataset.con_shapes,
            categorical_shapes=train_dataset.cat_shapes,
        )
        model.to(device)
        logger.debug(f"Model: {model}")

        logger.debug("Training model")
        hydra.utils.call(
            task_config.training_loop,
            model=model,
            train_dataloader=train_dataloader,
        )
        model.eval()
        logger.info("Reconstructing")
        logger.info("Computing reconstruction metrics")
        label = [hp.split("=") for hp in hydra_config.job.override_dirname.split(",")]
        records = []
        splits = zip(["train", "test"], [split_mask, ~split_mask])
        for split_name, mask in splits:
            dataloader = make_dataloader(
                cat_list,
                con_list,
                mask,
                shuffle=False,
                batch_size=task_config.batch_size,
            )
            cat_recons, con_recons = model.reconstruct(dataloader)
            con_recons = np.split(
                con_recons, np.cumsum(model.continuous_shapes[:-1]), axis=1
            )
            for cat, cat_recon, dataset_name in zip(
                cat_list, cat_recons, config.data.categorical_names
            ):
                logger.debug(f"Computing accuracy: '{dataset_name}'")
                accuracy = calculate_accuracy(cat[mask], cat_recon)
                record = _get_record(
                    accuracy,
                    job_num=job_num,
                    **dict(label),
                    metric="accuracy",
                    dataset=dataset_name,
                    split=split_name,
                )
                records.append(record)
            for con, con_recon, dataset_name in zip(
                con_list, con_recons, config.data.continuous_names
            ):
                logger.debug(f"Computing cosine similarity: '{dataset_name}'")
                cosine_sim = calculate_cosine_similarity(con[mask], con_recon)
                record = _get_record(
                    cosine_sim,
                    job_num=job_num,
                    **dict(label),
                    metric="cosine_similarity",
                    dataset=dataset_name,
                    split=split_name,
                )
                records.append(record)

        logger.info("Writing results")
        df_path = output_path / "reconstruction_stats.tsv"
        header = not df_path.exists()
        df = pd.DataFrame.from_records(records)
        df.to_csv(df_path, sep="\t", mode="a", header=header, index=False)

    if task_type == "reconstruction":
        task_config = cast(TuneModelReconstructionConfig, task_config)
        _tune_reconstruction(task_config)
    elif task_type == "stability":
        task_config = cast(TuneModelStabilityConfig, task_config)
        _tune_stability(task_config)

    return 0.0