# src/move/tasks/analyze_latent.py
__all__ = ["analyze_latent"]

import re
from pathlib import Path
from typing import Sized, cast

import hydra
import numpy as np
import pandas as pd
import torch
from sklearn.base import TransformerMixin

import move.visualization as viz
from move.analysis.metrics import (
    calculate_accuracy,
    calculate_cosine_similarity,
)
from move.conf.schema import AnalyzeLatentConfig, MOVEConfig
from move.core.logging import get_logger
from move.core.typing import FloatArray
from move.data import io
from move.data.dataloaders import MOVEDataset, make_dataloader
from move.data.perturbations import (
    perturb_categorical_data,
    perturb_continuous_data,
)
from move.data.preprocessing import one_hot_encode_single
from move.models.vae import VAE
from move.training.training_loop import TrainingLoopOutput


def find_feature_values(
    feature_name: str,
    feature_names_lists: list[list[str]],
    feature_values: list[FloatArray],
) -> tuple[int, FloatArray]:
    """Look for the feature in the list of datasets and return its values.

    Args:
        feature_name: Look-up key
        feature_names_lists: List of lists with feature names for each dataset
        feature_values: List of data arrays, each representing a dataset

    Raises:
        KeyError: If feature does not exist in any dataset

    Returns:
        Tuple containing (1) index of dataset containing feature and (2)
        values corresponding to the feature
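
    Example:
        >>> names = [["age", "weight"], ["height"]]
        >>> values = [np.zeros((2, 2)), np.ones((2, 1))]
        >>> find_feature_values("height", names, values)
        (1, array([1., 1.]))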
    """
    dataset_index = feature_index = None
    # Stop at the first dataset whose feature-name list contains the key
    for dataset_index, feature_names in enumerate(feature_names_lists):
        try:
            feature_index = feature_names.index(feature_name)
        except ValueError:
            continue
        break
    if dataset_index is not None and feature_index is not None:
        return (
            dataset_index,
            np.take(feature_values[dataset_index], feature_index, axis=1),
        )
    raise KeyError(f"Feature '{feature_name}' not in any dataset.")


def _validate_task_config(task_config: AnalyzeLatentConfig) -> None:
    if "_target_" not in task_config.reducer:
        raise ValueError("Reducer class not specified properly.")


def analyze_latent(config: MOVEConfig) -> None:
    """Train one model to inspect its latent space projections."""

    logger = get_logger(__name__)
    logger.info("Beginning task: analyze latent space")
    task_config = cast(AnalyzeLatentConfig, config.task)
    _validate_task_config(task_config)

    raw_data_path = Path(config.data.raw_data_path)
    interim_path = Path(config.data.interim_data_path)
    output_path = Path(config.data.results_path) / "latent_space"
    output_path.mkdir(exist_ok=True, parents=True)

    logger.debug("Reading data")
    sample_names = io.read_names(raw_data_path / f"{config.data.sample_names}.txt")
    cat_list, cat_names, con_list, con_names = io.load_preprocessed_data(
        interim_path,
        config.data.categorical_names,
        config.data.continuous_names,
    )
    test_dataloader = make_dataloader(
        cat_list,
        con_list,
        shuffle=False,
        batch_size=task_config.batch_size,
    )
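    # shuffle=False keeps the dataloader rows aligned with `sample_names`, so
    # all per-sample outputs below can share one index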
    test_dataset = cast(MOVEDataset, test_dataloader.dataset)
    df_index = pd.Index(sample_names, name="sample")

    assert task_config.model is not None
    device = torch.device("cuda" if task_config.model.cuda else "cpu")
    model: VAE = hydra.utils.instantiate(
        task_config.model,
        continuous_shapes=test_dataset.con_shapes,
        categorical_shapes=test_dataset.cat_shapes,
    )
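    # The VAE is built from config via Hydra; the data-dependent input shapes
    # are taken from the test dataset rather than the config file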

    logger.debug(f"Model: {model}")

    model_path = output_path / "model.pt"
    if model_path.exists():
        logger.debug("Re-loading model")
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
    else:
        logger.debug("Training model")

        model.to(device)
        train_dataloader = make_dataloader(
            cat_list,
            con_list,
            shuffle=True,
            batch_size=task_config.batch_size,
            drop_last=True,
        )
        output: TrainingLoopOutput = hydra.utils.call(
            task_config.training_loop,
            model=model,
            train_dataloader=train_dataloader,
        )
        losses = output[:-1]
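        # Note: the trailing element of `output` is dropped above; it is
        # assumed here to be a non-loss value (e.g. the final KLD weight)
        # rather than a loss curve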
        torch.save(model.state_dict(), model_path)
        logger.info("Generating visualizations")
        logger.debug("Generating plot: loss curves")
        fig = viz.plot_loss_curves(losses)
        fig_path = str(output_path / "loss_curve.png")
        fig.savefig(fig_path, bbox_inches="tight")
        fig_df = pd.DataFrame(dict(zip(viz.LOSS_LABELS, losses)))
        fig_df.index.name = "epoch"
        fig_df.to_csv(output_path / "loss_curve.tsv", sep="\t")

    model.eval()

    logger.info("Projecting into latent space")
    latent_space = model.project(test_dataloader)
    reducer: TransformerMixin = hydra.utils.instantiate(task_config.reducer)
    embedding = reducer.fit_transform(latent_space)
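    # The reducer is any sklearn-style transformer named in the task config
    # via `_target_`; a hypothetical config entry could look like:
    #   reducer:
    #     _target_: sklearn.manifold.TSNE
    #     n_components: 2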

    mappings_path = interim_path / "mappings.json"
    if mappings_path.exists():
        mappings = io.load_mappings(mappings_path)
    else:
        mappings = {}
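    # `mappings` holds one {category: code} dictionary per categorical dataset
    # (written during preprocessing); it is inverted below to label plots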

    fig_df = pd.DataFrame(
        np.take(embedding, [0, 1], axis=1),
        columns=["dim0", "dim1"],
        index=df_index,
    )

    for feature_name in task_config.feature_names:
        logger.debug(f"Generating plot: latent space + '{feature_name}'")
        is_categorical = False
        # Search the categorical datasets first, then the continuous ones
        try:
            dataset_index, feature_values = find_feature_values(
                feature_name, cat_names, cat_list
            )
            is_categorical = True
        except KeyError:
            try:
                dataset_index, feature_values = find_feature_values(
                    feature_name, con_names, con_list
                )
            except KeyError:
                logger.warning(f"Feature '{feature_name}' not found in any dataset.")
                continue

        if is_categorical:
            # Convert one-hot encoding to category codes; an all-zero row
            # marks a missing value
            is_nan = feature_values.sum(axis=1) == 0
            feature_values = np.argmax(feature_values, axis=1)

            dataset_name = config.data.categorical_names[dataset_index]
            feature_mapping = {
                str(code): category for category, code in mappings[dataset_name].items()
            }
            fig = viz.plot_latent_space_with_cat(
                embedding,
                feature_name,
                feature_values,
                feature_mapping,
                is_nan,
            )
            fig_df[feature_name] = np.where(is_nan, np.nan, feature_values)
        else:
            fig = viz.plot_latent_space_with_con(
                embedding, feature_name, feature_values
            )
            fig_df[feature_name] = np.where(feature_values == 0, np.nan, feature_values)

        # Strip non-word, non-space characters so the name is safe in a file name
        safe_feature_name = re.sub(r"[^\w\s]", "", feature_name)
        fig_path = str(output_path / f"latent_space_{safe_feature_name}.png")
        fig.savefig(fig_path, bbox_inches="tight")

    fig_df.to_csv(output_path / "latent_space.tsv", sep="\t")

    logger.info("Reconstructing")
    cat_recons, con_recons = model.reconstruct(test_dataloader)
    con_recons = np.split(con_recons, np.cumsum(model.continuous_shapes[:-1]), axis=1)
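    # The continuous reconstructions come back as one concatenated array;
    # split it back into one block per continuous dataset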
    logger.info("Computing reconstruction metrics")
    scores = []
    labels = config.data.categorical_names + config.data.continuous_names
    for cat, cat_recon in zip(cat_list, cat_recons):
        accuracy = calculate_accuracy(cat, cat_recon)
        scores.append(accuracy)
    for con, con_recon in zip(con_list, con_recons):
        cosine_sim = calculate_cosine_similarity(con, con_recon)
        scores.append(cosine_sim)

    logger.debug("Generating plot: reconstruction metrics")

    plot_scores = [np.ma.compressed(np.ma.masked_equal(each, 0)) for each in scores]
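    # Mask zero scores before plotting; they are assumed here to come from
    # samples with no measured features in a dataset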
    fig = viz.plot_metrics_boxplot(plot_scores, labels)
    fig_path = str(output_path / "reconstruction_metrics.png")
    fig.savefig(fig_path, bbox_inches="tight")
    fig_df = pd.DataFrame(dict(zip(labels, scores)), index=df_index)
    fig_df.to_csv(output_path / "reconstruction_metrics.tsv", sep="\t")

    logger.info("Computing feature importance")
    num_samples = len(cast(Sized, test_dataloader.sampler))
    for i, dataset_name in enumerate(config.data.categorical_names):
        logger.debug(f"Generating plot: feature importance '{dataset_name}'")
        na_value = one_hot_encode_single(mappings[dataset_name], None)
        dataloaders = perturb_categorical_data(
            test_dataloader, config.data.categorical_names, dataset_name, na_value
        )
        num_features = len(dataloaders)
        z = model.project(test_dataloader)
        diffs = np.empty((num_samples, num_features))
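        # Importance of feature j for a sample: the change in the sample's
        # latent representation, summed across latent dimensions, when
        # feature j is perturbed to the missing-value encoding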
        for j, dataloader in enumerate(dataloaders):
            z_perturb = model.project(dataloader)
            diffs[:, j] = np.sum(z_perturb - z, axis=1)
        feature_mapping = {
            str(code): category for category, code in mappings[dataset_name].items()
        }
        fig = viz.plot_categorical_feature_importance(
            diffs, cat_list[i], cat_names[i], feature_mapping
        )
        fig_path = str(output_path / f"feat_importance_{dataset_name}.png")
        fig.savefig(fig_path, bbox_inches="tight")
        fig_df = pd.DataFrame(diffs, columns=cat_names[i], index=df_index)
        fig_df.to_csv(output_path / f"feat_importance_{dataset_name}.tsv", sep="\t")

    for i, dataset_name in enumerate(config.data.continuous_names):
        logger.debug(f"Generating plot: feature importance '{dataset_name}'")
        # Perturbing a continuous feature here means setting it to 0.0
        dataloaders = perturb_continuous_data(
            test_dataloader, config.data.continuous_names, dataset_name, 0.0
        )
        num_features = len(dataloaders)
        z = model.project(test_dataloader)
        diffs = np.empty((num_samples, num_features))
        for j, dataloader in enumerate(dataloaders):
            z_perturb = model.project(dataloader)
            diffs[:, j] = np.sum(z_perturb - z, axis=1)
        fig = viz.plot_continuous_feature_importance(diffs, con_list[i], con_names[i])
        fig_path = str(output_path / f"feat_importance_{dataset_name}.png")
        fig.savefig(fig_path, bbox_inches="tight")
        fig_df = pd.DataFrame(diffs, columns=con_names[i], index=df_index)
        fig_df.to_csv(output_path / f"feat_importance_{dataset_name}.tsv", sep="\t")
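
# Note: this task is typically launched through MOVE's Hydra-based CLI, e.g.
# `move-dl data=<data_config> task=analyze_latent` (the exact config names
# depend on the project setup).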