# src/move/tasks/identify_associations.py
1
__all__ = ["identify_associations"]
2
3
from functools import reduce
4
from os.path import exists
5
from pathlib import Path
6
from typing import Literal, Sized, Union, cast
7
8
import hydra
9
import numpy as np
10
import pandas as pd
11
import torch
12
from omegaconf import OmegaConf
13
from scipy.stats import ks_2samp, pearsonr  # type: ignore
14
from torch.utils.data import DataLoader
15
16
from move.analysis.metrics import get_2nd_order_polynomial
17
from move.conf.schema import (
18
    IdentifyAssociationsBayesConfig,
19
    IdentifyAssociationsConfig,
20
    IdentifyAssociationsKSConfig,
21
    IdentifyAssociationsTTestConfig,
22
    MOVEConfig,
23
)
24
from move.core.logging import get_logger
25
from move.core.typing import BoolArray, FloatArray, IntArray
26
from move.data import io
27
from move.data.dataloaders import MOVEDataset, make_dataloader
28
from move.data.perturbations import (
29
    ContinuousPerturbationType,
30
    perturb_categorical_data,
31
    perturb_continuous_data_extended,
32
)
33
from move.data.preprocessing import one_hot_encode_single
34
from move.models.vae import VAE
35
from move.visualization.dataset_distributions import (
36
    plot_correlations,
37
    plot_cumulative_distributions,
38
    plot_feature_association_graph,
39
    plot_reconstruction_movement,
40
)
41
42
TaskType = Literal["bayes", "ttest", "ks"]
43
CONTINUOUS_TARGET_VALUE = ["minimum", "maximum", "plus_std", "minus_std"]
44
45
46
def _get_task_type(
    task_config: IdentifyAssociationsConfig,
) -> TaskType:
    """Return the task name matching the concrete type of ``task_config``.

    Raises:
        ValueError: if the config is not one of the supported task configs.
    """
    config_type = OmegaConf.get_type(task_config)
    for candidate_type, task_name in (
        (IdentifyAssociationsBayesConfig, "bayes"),
        (IdentifyAssociationsTTestConfig, "ttest"),
        (IdentifyAssociationsKSConfig, "ks"),
    ):
        if config_type is candidate_type:
            return task_name
    raise ValueError("Unsupported type of task!")
57
58
59
def _validate_task_config(
    task_config: IdentifyAssociationsConfig, task_type: TaskType
) -> None:
    """Sanity-check the task configuration, raising ``ValueError`` if invalid."""
    threshold = task_config.sig_threshold
    if threshold < 0.0 or threshold > 1.0:
        raise ValueError("Significance threshold must be within [0, 1].")
    if task_type == "ttest":
        ttest_config = cast(IdentifyAssociationsTTestConfig, task_config)
        # The t-test approach replicates hits across exactly four latent sizes.
        if len(ttest_config.num_latent) != 4:
            raise ValueError("4 latent space dimensions required.")
68
69
70
def prepare_for_categorical_perturbation(
    config: MOVEConfig,
    interim_path: Path,
    baseline_dataloader: DataLoader,
    cat_list: list[FloatArray],
) -> tuple[
    list[DataLoader],
    BoolArray,
    BoolArray,
]:
    """Build the dataloaders and masks needed for categorical association analysis.

    Args:
        config: main configuration file
        interim_path: path where the intermediate outputs are saved
        baseline_dataloader: reference dataloader that will be perturbed
        cat_list: list of arrays with categorical data

    Returns:
        dataloaders: all dataloaders, including baseline appended last.
        nan_mask: mask for NaNs
        feature_mask: masks the column for the perturbed feature.
    """
    logger = get_logger(__name__)
    task_config = cast(IdentifyAssociationsConfig, config.task)

    # Translate the configured target value into its one-hot encoding.
    mappings = io.load_mappings(interim_path / "mappings.json")
    target_mapping = mappings[task_config.target_dataset]
    target_value = one_hot_encode_single(target_mapping, task_config.target_value)
    logger.debug(
        f"Target value: {task_config.target_value} => {target_value.astype(int)[0]}"
    )

    # One perturbed dataloader per feature; baseline is appended last.
    dataloaders = perturb_categorical_data(
        baseline_dataloader,
        config.data.categorical_names,
        task_config.target_dataset,
        target_value,
    )
    dataloaders.append(baseline_dataloader)

    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    assert baseline_dataset.con_all is not None
    con_matrix = baseline_dataset.con_all
    nan_mask = (con_matrix == 0).numpy()  # NaN values encoded as 0s
    logger.debug(f"# NaN values: {np.sum(nan_mask)}/{con_matrix.numel()}")

    # Mask samples that already hold the target value or are fully missing.
    dataset_idx = config.data.categorical_names.index(task_config.target_dataset)
    perturbed_dataset = cat_list[dataset_idx]
    feature_mask = np.all(perturbed_dataset == target_value, axis=2)  # 2D: N x P
    feature_mask |= np.sum(perturbed_dataset, axis=2) == 0

    return dataloaders, nan_mask, feature_mask
133
134
135
def prepare_for_continuous_perturbation(
    config: MOVEConfig,
    output_subpath: Path,
    baseline_dataloader: DataLoader,
) -> tuple[
    list[DataLoader],
    BoolArray,
    BoolArray,
]:
    """Build the dataloaders and masks needed for continuous association analysis.

    Args:
        config:
            main configuration file.
        output_subpath:
            path where the output plots for continuous analysis are saved.
        baseline_dataloader:
            reference dataloader that will be perturbed.

    Returns:
        dataloaders:
            list with all dataloaders, including baseline appended last.
        nan_mask:
            mask for NaNs
        feature_mask:
            same as `nan_mask`, in this case.
    """
    logger = get_logger(__name__)
    task_config = cast(IdentifyAssociationsConfig, config.task)

    # One perturbed dataloader per feature; baseline is appended last.
    dataloaders = perturb_continuous_data_extended(
        baseline_dataloader,
        config.data.continuous_names,
        task_config.target_dataset,
        cast(ContinuousPerturbationType, task_config.target_value),
        output_subpath,
    )
    dataloaders.append(baseline_dataloader)

    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    assert baseline_dataset.con_all is not None
    con_matrix = baseline_dataset.con_all
    nan_mask = (con_matrix == 0).numpy()  # NaN values encoded as 0s
    logger.debug(f"# NaN values: {np.sum(nan_mask)}/{con_matrix.numel()}")

    # For continuous perturbations there is no categorical column to hide,
    # so the feature mask coincides with the NaN mask.
    return dataloaders, nan_mask, nan_mask
187
188
189
def _bayes_approach(
    config: MOVEConfig,
    task_config: IdentifyAssociationsBayesConfig,
    train_dataloader: DataLoader,
    baseline_dataloader: DataLoader,
    dataloaders: list[DataLoader],
    models_path: Path,
    num_perturbed: int,
    num_samples: int,
    num_continuous: int,
    nan_mask: BoolArray,
    feature_mask: BoolArray,
) -> tuple[Union[IntArray, FloatArray], ...]:
    """Identify associations via Bayes factors aggregated over model refits.

    Trains (or reloads) ``num_refits`` VAE models, accumulates the mean
    difference between perturbed and baseline reconstructions, and turns the
    probability of a positive shift into a Bayes factor per (perturbed,
    target) pair. Associations are ranked by absolute Bayes factor and cut
    where the running FDR is closest to ``task_config.sig_threshold``.

    Args:
        config: main MOVE configuration.
        task_config: Bayes-specific task configuration.
        train_dataloader: dataloader used to train the models.
        baseline_dataloader: unperturbed dataloader.
        dataloaders: perturbed dataloaders (baseline appended last).
        models_path: directory where refits are stored/reloaded.
        num_perturbed: number of perturbed features (P).
        num_samples: number of samples (N).
        num_continuous: number of continuous features (C).
        nan_mask: N x C mask of missing continuous values.
        feature_mask: N x P mask of samples to ignore per perturbed feature.

    Returns:
        Sorted association IDs above the threshold, their probabilities,
        FDRs, and Bayes factors.
    """
    assert task_config.model is not None
    device = torch.device("cuda" if task_config.model.cuda else "cpu")

    # Train models
    logger = get_logger(__name__)
    logger.info("Training models")
    mean_diff = np.zeros((num_perturbed, num_samples, num_continuous))
    normalizer = 1 / task_config.num_refits  # each refit contributes equally

    # Last appended dataloader is the baseline
    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)

    for j in range(task_config.num_refits):
        # Initialize model
        model: VAE = hydra.utils.instantiate(
            task_config.model,
            continuous_shapes=baseline_dataset.con_shapes,
            categorical_shapes=baseline_dataset.cat_shapes,
        )
        if j == 0:
            logger.debug(f"Model: {model}")

        # Train/reload model
        model_path = models_path / f"model_{task_config.model.num_latent}_{j}.pt"
        if model_path.exists():
            logger.debug(f"Re-loading refit {j + 1}/{task_config.num_refits}")
            model.load_state_dict(torch.load(model_path))
            model.to(device)
        else:
            logger.debug(f"Training refit {j + 1}/{task_config.num_refits}")
            model.to(device)
            hydra.utils.call(
                task_config.training_loop,
                model=model,
                train_dataloader=train_dataloader,
            )
            if task_config.save_refits:
                torch.save(model.state_dict(), model_path)
        model.eval()

        # Calculate baseline reconstruction
        _, baseline_recon = model.reconstruct(baseline_dataloader)

        # Calculate perturb reconstruction => keep track of mean difference
        # NOTE: the per-feature min/max bookkeeping previously computed here
        # was never read anywhere in this function and has been removed.
        for i in range(num_perturbed):
            _, perturb_recon = model.reconstruct(dataloaders[i])
            diff = perturb_recon - baseline_recon  # 2D: N x C
            mean_diff[i, :, :] += diff * normalizer

    # Calculate Bayes factors
    logger.info("Identifying significant features")
    bayes_k = np.empty((num_perturbed, num_continuous))
    bayes_mask = np.zeros(np.shape(bayes_k))
    for i in range(num_perturbed):
        mask = feature_mask[:, [i]] | nan_mask  # 2D: N x C
        diff = np.ma.masked_array(mean_diff[i, :, :], mask=mask)  # 2D: N x C
        prob = np.ma.compressed(np.mean(diff > 1e-8, axis=0))  # 1D: C
        # Log-odds of observing a positive shift; 1e-8 avoids log(0).
        bayes_k[i, :] = np.log(prob + 1e-8) - np.log(1 - prob + 1e-8)
        if task_config.target_value in CONTINUOUS_TARGET_VALUE:
            # Non-zero exactly where feature i perturbs itself
            # (self-association), since only that column's input changed.
            bayes_mask[i, :] = (
                baseline_dataloader.dataset.con_all[0, :]
                - dataloaders[i].dataset.con_all[0, :]
            )

    bayes_mask[bayes_mask != 0] = 1
    bayes_mask = np.array(bayes_mask, dtype=bool)

    # Calculate Bayes probabilities
    bayes_abs = np.abs(bayes_k)
    bayes_p = np.exp(bayes_abs) / (1 + np.exp(bayes_abs))  # 2D: N x C
    bayes_abs[bayes_mask] = np.min(
        bayes_abs
    )  # Bring feature_i feature_i associations to minimum
    sort_ids = np.argsort(bayes_abs, axis=None)[::-1]  # 1D: N x C
    prob = np.take(bayes_p, sort_ids)  # 1D: N x C
    logger.debug(f"Bayes proba range: [{prob[-1]:.3f} {prob[0]:.3f}]")

    # Sort Bayes
    bayes_k = np.take(bayes_k, sort_ids)  # 1D: N x C

    # Calculate FDR
    fdr = np.cumsum(1 - prob) / np.arange(1, prob.size + 1)  # 1D
    # Cut the ranking where the running FDR is closest to the threshold.
    idx = np.argmin(np.abs(fdr - task_config.sig_threshold))
    logger.debug(f"FDR range: [{fdr[0]:.3f} {fdr[-1]:.3f}]")

    return sort_ids[:idx], prob[:idx], fdr[:idx], bayes_k[:idx]
302
303
304
def _ttest_approach(
    task_config: IdentifyAssociationsTTestConfig,
    train_dataloader: DataLoader,
    baseline_dataloader: DataLoader,
    dataloaders: list[DataLoader],
    models_path: Path,
    interim_path: Path,
    num_perturbed: int,
    num_samples: int,
    num_continuous: int,
    nan_mask: BoolArray,
    feature_mask: BoolArray,
) -> tuple[Union[IntArray, FloatArray], ...]:
    """Identify associations with paired t-tests across latent sizes and refits.

    For each (latent size, refit) pair a VAE is trained or reloaded. The
    perturbed-vs-baseline reconstruction difference is compared per feature,
    with a paired t-test, against an estimate of the baseline reconstruction
    noise. P-values are Bonferroni-corrected; a hit must replicate in at
    least half of the refits and in at least 3 of the 4 latent sizes.

    Args:
        task_config: t-test-specific task configuration (4 latent sizes).
        train_dataloader: dataloader used to train the models.
        baseline_dataloader: unperturbed dataloader.
        dataloaders: perturbed dataloaders (baseline appended last).
        models_path: directory where refits are stored/reloaded.
        interim_path: directory where the raw p-value array is saved.
        num_perturbed: number of perturbed features (P).
        num_samples: number of samples (N).
        num_continuous: number of continuous features (C).
        nan_mask: N x C mask of missing continuous values.
        feature_mask: N x P mask of samples to ignore per perturbed feature.

    Returns:
        Flattened IDs of significant associations and their median p-values.
    """

    from scipy.stats import ttest_rel

    assert task_config.model is not None
    device = torch.device("cuda" if task_config.model.cuda else "cpu")

    # Train models
    logger = get_logger(__name__)
    logger.info("Training models")
    # 4D container: latent sizes (L) x refits (R) x perturbed (P) x continuous (C)
    pvalues = np.empty(
        (
            len(task_config.num_latent),
            task_config.num_refits,
            num_perturbed,
            num_continuous,
        )
    )

    # Last appended dataloader is the baseline
    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)

    for k, num_latent in enumerate(task_config.num_latent):
        for j in range(task_config.num_refits):

            # Initialize model
            model: VAE = hydra.utils.instantiate(
                task_config.model,
                continuous_shapes=baseline_dataset.con_shapes,
                categorical_shapes=baseline_dataset.cat_shapes,
                num_latent=num_latent,
            )
            if j == 0:
                logger.debug(f"Model: {model}")

            # Train model
            model_path = models_path / f"model_{num_latent}_{j}.pt"
            if model_path.exists():
                logger.debug(f"Re-loading refit {j + 1}/{task_config.num_refits}")
                model.load_state_dict(torch.load(model_path))
                model.to(device)
            else:
                logger.debug(f"Training refit {j + 1}/{task_config.num_refits}")
                model.to(device)
                hydra.utils.call(
                    task_config.training_loop,
                    model=model,
                    train_dataloader=train_dataloader,
                )
                if task_config.save_refits:
                    torch.save(model.state_dict(), model_path)
            model.eval()

            # Get baseline reconstruction and baseline difference
            # Reconstruction appears stochastic (repeated calls differ), so the
            # mean over 10 draws estimates the model's intrinsic reconstruction
            # noise — presumably sampling in the VAE; confirm with VAE.reconstruct.
            _, baseline_recon = model.reconstruct(baseline_dataloader)
            baseline_diff = np.empty((10, num_samples, num_continuous))
            for i in range(10):
                _, recon = model.reconstruct(baseline_dataloader)
                baseline_diff[i, :, :] = recon - baseline_recon
            baseline_diff = np.mean(baseline_diff, axis=0)  # 2D: N x C
            baseline_diff = np.where(nan_mask, np.nan, baseline_diff)

            # T-test between baseline and perturb difference
            for i in range(num_perturbed):
                _, perturb_recon = model.reconstruct(dataloaders[i])
                perturb_diff = perturb_recon - baseline_recon
                mask = feature_mask[:, [i]] | nan_mask  # 2D: N x C
                # Masked entries become NaN and are dropped via nan_policy="omit".
                _, pvalues[k, j, i, :] = ttest_rel(
                    a=np.where(mask, np.nan, perturb_diff),
                    b=np.where(mask, np.nan, baseline_diff),
                    axis=0,
                    nan_policy="omit",
                )

    # Correct p-values (Bonferroni)
    pvalues = np.minimum(pvalues * num_continuous, 1.0)
    np.save(interim_path / "pvals.npy", pvalues)

    # Find significant hits
    overlap_thres = task_config.num_refits // 2
    reject = pvalues <= task_config.sig_threshold  # 4D: L x R x P x C
    overlap = reject.sum(axis=1) >= overlap_thres  # 3D: L x P x C
    # A hit must be significant for at least 3 of the 4 latent sizes.
    sig_ids = overlap.sum(axis=0) >= 3  # 2D: P x C
    sig_ids = np.flatnonzero(sig_ids)  # 1D

    # Report median p-value
    masked_pvalues = np.ma.masked_array(pvalues, mask=~reject)  # 4D
    masked_pvalues = np.ma.median(masked_pvalues, axis=1)  # 3D
    masked_pvalues = np.ma.median(masked_pvalues, axis=0)  # 2D
    sig_pvalues = np.ma.compressed(np.take(masked_pvalues, sig_ids))  # 1D

    return sig_ids, sig_pvalues
408
409
410
def _ks_approach(
    config: MOVEConfig,
    task_config: IdentifyAssociationsKSConfig,
    train_dataloader: DataLoader,
    baseline_dataloader: DataLoader,
    dataloaders: list[DataLoader],
    models_path: Path,
    num_perturbed: int,
    num_samples: int,
    num_continuous: int,
    con_names: list[list[str]],
    output_path: Path,
) -> tuple[Union[IntArray, FloatArray], ...]:
    """
    Find associations between continuous features using Kolmogorov-Smirnov distances.
    When perturbing feature A, this function measures the shift of the reconstructed
    distribution for feature B (over samples) from 1) the baseline reconstruction to 2)
    the reconstruction when perturbing A.

    If A and B are related the perturbation of A in the input will lead to a change in
    feature B's reconstruction, that will be measured by KS distance.

    Associations are then ranked according to KS distance (absolute value).


    Args:
        config: MOVE main configuration.
        task_config: IdentifyAssociationsKSConfig configuration.
        train_dataloader: training DataLoader.
        baseline_dataloader: unperturbed DataLoader.
        dataloaders: list of DataLoaders where DataLoader[i] is obtained by perturbing
                     feature i in the target dataset.
        models_path: path to the models.
        num_perturbed: number of perturbed features.
        num_samples: total number of samples
        num_continuous: number of continuous features
                        (all continuous datasets concatenated).
        con_names: list of lists where each inner list
                   contains the feature names of a specific continuous dataset
        output_path: path where QC summary metrics will be saved.

    Returns:
        sort_ids: list with flattened IDs of the associations
                  above the significance threshold.
        ks_distance: Ordered list with signed KS scores. KS scores quantify the
                    direction and magnitude of the shift in feature B's reconstruction
                    when perturbing feature A.


    !!! Note !!!:

    The sign of the KS score can be misleading: negative sign means positive shift,
    since the cumulative distribution starts growing later and is found below
    the reference (baseline). Hence:
    a) with plus_std, negative sign means a positive correlation.
    b) with minus_std, negative sign means a negative correlation.
    """

    assert task_config.model is not None
    device = torch.device("cuda" if task_config.model.cuda else "cpu")
    figure_path = output_path / "figures"
    figure_path.mkdir(exist_ok=True, parents=True)

    # Data containers
    # KS statistics and their signs: refits (R) x perturbed (P) x continuous (C)
    stats = np.empty((task_config.num_refits, num_perturbed, num_continuous))
    stat_signs = np.empty_like(stats)
    # Per-refit QC metrics: reconstruction correlation and fit slope per feature.
    rec_corr, slope = np.empty((task_config.num_refits, num_continuous)), np.empty(
        (task_config.num_refits, num_continuous)
    )
    ks_mask = np.zeros((num_perturbed, num_continuous))
    # Latent projections: samples x latent dims x (perturbations + baseline last).
    latent_matrix = np.empty(
        (num_samples, task_config.model.num_latent, len(dataloaders))
    )

    # Last appended dataloader is the baseline
    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)

    # Train models
    logger = get_logger(__name__)
    logger.info("Training models")

    target_dataset_idx = config.data.continuous_names.index(task_config.target_dataset)
    perturbed_names = con_names[target_dataset_idx]

    for j in range(task_config.num_refits):  # Train num_refits models

        # Initialize model
        model: VAE = hydra.utils.instantiate(
            task_config.model,
            continuous_shapes=baseline_dataset.con_shapes,
            categorical_shapes=baseline_dataset.cat_shapes,
        )
        if j == 0:
            logger.debug(f"Model: {model}")

        # Train/reload model
        model_path = models_path / f"model_{task_config.model.num_latent}_{j}.pt"
        if model_path.exists():
            logger.debug(f"Re-loading refit {j + 1}/{task_config.num_refits}")
            model.load_state_dict(torch.load(model_path))
            model.to(device)
        else:
            logger.debug(f"Training refit {j + 1}/{task_config.num_refits}")
            model.to(device)
            hydra.utils.call(
                task_config.training_loop,
                model=model,
                train_dataloader=train_dataloader,
            )
            if task_config.save_refits:
                torch.save(model.state_dict(), model_path)
        model.eval()

        # Calculate baseline reconstruction
        _, baseline_recon = model.reconstruct(baseline_dataloader)
        # Per-perturbation feature ranges, used below to bin the histograms.
        min_feat = np.zeros((num_perturbed, num_continuous))
        max_feat = np.zeros((num_perturbed, num_continuous))
        min_baseline = np.min(baseline_recon, axis=0)
        max_baseline = np.max(baseline_recon, axis=0)

        # QC of feature's reconstruction ##############################
        logger.debug("Calculating quality control of the feature reconstructions")
        # Correlation and slope for each feature's reconstruction
        feature_names = reduce(list.__add__, con_names)

        for k in range(num_continuous):
            x = baseline_dataloader.dataset.con_all.numpy()[:, k]  # baseline_recon[:,i]
            y = baseline_recon[:, k]
            # a1 is the linear coefficient of the 2nd-order fit (QC slope).
            x_pol, y_pol, (a2, a1, a) = get_2nd_order_polynomial(x, y)
            slope[j, k] = a1
            rec_corr[j, k] = pearsonr(x, y).statistic

            if (
                feature_names[k] in task_config.perturbed_feature_names
                or feature_names[k] in task_config.target_feature_names
            ):

                # Plot correlations
                fig = plot_correlations(x, y, x_pol, y_pol, a2, a1, a, k)
                fig.savefig(
                    figure_path
                    / f"Input_vs_reconstruction_correlation_feature_{k}_refit_{j}.png",
                    dpi=50,
                )

        # Calculate perturbed reconstruction and shifts #############################
        logger.debug("Computing KS scores")

        # Save original latent space for first refit:
        if j == 0:
            latent = model.project(baseline_dataloader)
            latent_matrix[:, :, -1] = latent

        for i, pert_feat in enumerate(perturbed_names):
            _, perturb_recon = model.reconstruct(dataloaders[i])
            min_perturb = np.min(perturb_recon, axis=0)
            max_perturb = np.max(perturb_recon, axis=0)
            # Combined range over baseline and perturbed reconstructions.
            min_feat[i, :] = np.min([min_baseline, min_perturb], axis=0)
            max_feat[i, :] = np.max([max_baseline, max_perturb], axis=0)

            # Save latent representation for perturbed samples
            if j == 0:
                latent_pert = model.project(dataloaders[i])
                latent_matrix[:, :, i] = latent_pert

            for k, targ_feat in enumerate(feature_names):
                # Calculate ks factors: measure distance between baseline and perturbed
                # reconstruction distributions per feature (k)
                res = ks_2samp(perturb_recon[:, k], baseline_recon[:, k])
                stats[j, i, k] = res.statistic
                stat_signs[j, i, k] = res.statistic_sign

                if (
                    pert_feat in task_config.perturbed_feature_names
                    and targ_feat in task_config.target_feature_names
                ):

                    # Plotting preliminary results:
                    n_bins = 50
                    hist_base, edges = np.histogram(
                        baseline_recon[:, k],
                        bins=np.linspace(min_feat[i, k], max_feat[i, k], n_bins),
                        density=True,
                    )
                    hist_pert, edges = np.histogram(
                        perturb_recon[:, k],
                        bins=np.linspace(min_feat[i, k], max_feat[i, k], n_bins),
                        density=True,
                    )

                    # Cumulative distribution:
                    fig = plot_cumulative_distributions(
                        edges,
                        hist_base,
                        hist_pert,
                        title=f"Cumulative_perturbed_{i}_measuring_"
                        f"{k}_stats_{stats[j, i, k]}",
                    )
                    fig.savefig(
                        figure_path
                        / (
                            f"Cumulative_refit_{j}_perturbed_{i}_"
                            f"measuring_{k}_stats_{stats[j, i, k]}.png"
                        )
                    )

                    # Feature changes:
                    fig = plot_reconstruction_movement(baseline_recon, perturb_recon, k)
                    fig.savefig(
                        figure_path / f"Changes_pert_{i}_on_feat_{k}_refit_{j}.png"
                    )

    # Save latent space matrix:
    np.save(output_path / "latent_location.npy", latent_matrix)
    np.save(output_path / "perturbed_features_list.npy", np.array(perturbed_names))

    # Creating a mask for self associations
    logger.debug("Creating self-association mask")
    for i in range(num_perturbed):
        if task_config.target_value in CONTINUOUS_TARGET_VALUE:
            # Non-zero exactly where feature i perturbs its own input column.
            ks_mask[i, :] = (
                baseline_dataloader.dataset.con_all[0, :]
                - dataloaders[i].dataset.con_all[0, :]
            )
    ks_mask[ks_mask != 0] = 1
    ks_mask = np.array(ks_mask, dtype=bool)

    # Take the median of KS values (with sign) over refits.
    final_stats = np.nanmedian(stats * stat_signs, axis=0)
    final_stats[ks_mask] = (
        0.0  # Zero all masked values, placing them at end of the ranking
    )

    # KS-threshold:
    ks_thr = np.sqrt(-np.log(task_config.sig_threshold / 2) * 1 / (num_samples))
    logger.info(f"Suggested absolute KS threshold is: {ks_thr}")

    # Sort associations by absolute KS value
    sort_ids = np.argsort(abs(final_stats), axis=None)[::-1]  # 1D: N x C
    ks_distance = np.take(final_stats, sort_ids)  # 1D: N x C

    # Writing Quality control csv file.
    # Mean slope and correlation over refits as qc metrics.
    logger.info("Writing QC file")
    qc_df = pd.DataFrame({"Feature names": feature_names})
    qc_df["slope"] = np.nanmean(slope, axis=0)
    qc_df["reconstruction_correlation"] = np.nanmean(rec_corr, axis=0)
    qc_df.to_csv(output_path / "QC_summary_KS.tsv", sep="\t", index=False)

    # Return first idx associations: redefined for reasonable threshold

    return sort_ids[abs(ks_distance) >= ks_thr], ks_distance[abs(ks_distance) >= ks_thr]
662
663
664
def save_results(
    config: MOVEConfig,
    con_shapes: list[int],
    cat_names: list[list[str]],
    con_names: list[list[str]],
    output_path: Path,
    sig_ids,
    extra_cols,
    extra_colnames,
) -> None:
    """Write the significant associations to a TSV file.

    The output has the columns: feature_a_id, feature_b_id, feature_a_name,
    feature_b_name, feature_b_dataset, plus one column per entry of
    ``extra_colnames`` (e.g. proba/p_value quantifying the association).

    Args:
        config: main config
        con_shapes: tuple with the number of features per continuous dataset
        cat_names: list of lists of names for the categorical features.
                   Each inner list corresponds to a separate dataset.
        con_names: list of lists of names for the continuous features.
                   Each inner list corresponds to a separate dataset.
        output_path: path where the results will be saved
        sig_ids: ids for the significant features
        extra_cols: extra data when calling the approach function
        extra_colnames: names for the extra data columns
    """
    logger = get_logger(__name__)
    logger.info(f"Significant hits found: {sig_ids.size}")
    task_config = cast(IdentifyAssociationsConfig, config.task)
    task_type = _get_task_type(task_config)

    num_continuous = sum(con_shapes)  # C

    if sig_ids.size == 0:
        return

    # Decode flat association IDs into (feature_a, feature_b) index pairs.
    id_pairs = np.vstack((sig_ids // num_continuous, sig_ids % num_continuous)).T
    logger.info("Writing results")
    results = pd.DataFrame(id_pairs, columns=["feature_a_id", "feature_b_id"])

    # Feature A names come from the perturbed dataset, which is continuous
    # or categorical depending on the configured target value.
    if task_config.target_value in CONTINUOUS_TARGET_VALUE:
        dataset_idx = config.data.continuous_names.index(task_config.target_dataset)
        a_df = pd.DataFrame(dict(feature_a_name=con_names[dataset_idx]))
    else:
        dataset_idx = config.data.categorical_names.index(task_config.target_dataset)
        a_df = pd.DataFrame(dict(feature_a_name=cat_names[dataset_idx]))
    a_df.index.name = "feature_a_id"
    a_df.reset_index(inplace=True)

    # Feature B is any continuous feature (all datasets concatenated).
    b_df = pd.DataFrame(dict(feature_b_name=reduce(list.__add__, con_names)))
    b_df.index.name = "feature_b_id"
    b_df.reset_index(inplace=True)

    results = results.merge(a_df, on="feature_a_id", how="left")
    results = results.merge(b_df, on="feature_b_id", how="left")
    # Assign each target feature to its source dataset by index range.
    results["feature_b_dataset"] = pd.cut(
        results["feature_b_id"],
        bins=cast(list[int], np.cumsum([0] + con_shapes)),
        right=False,
        labels=config.data.continuous_names,
    )
    for col, colname in zip(extra_cols, extra_colnames):
        results[colname] = col
    results.to_csv(
        output_path / f"results_sig_assoc_{task_type}.tsv", sep="\t", index=False
    )
739
740
741
def identify_associations(config: MOVEConfig) -> None:
    """Run the association-identification task selected in the config.

    Leads to the execution of the appropriate association
    identification task. The function is organized in three blocks:
        1) Prepare the data and create the dataloaders with their masks.
        2) Evaluate associations using the bayes, ttest, or ks approach.
        3) Save results and, if significant hits were written, plot the
           feature-association graph.

    Args:
        config: main MOVE configuration; ``config.task`` determines which
            approach (bayes / ttest / ks) is run and its hyperparameters.

    Raises:
        ValueError: if the task type resolved from the config is not one of
            the supported approaches (defensive; ``_get_task_type`` should
            already have rejected it).
    """
    # DATA PREPARATION ######################
    # Read original data and create perturbed datasets ####

    logger = get_logger(__name__)
    task_config = cast(IdentifyAssociationsConfig, config.task)
    task_type = _get_task_type(task_config)
    _validate_task_config(task_config, task_type)

    interim_path = Path(config.data.interim_data_path)

    # Refitted models are only persisted when explicitly requested.
    models_path = interim_path / "models"
    if task_config.save_refits:
        models_path.mkdir(exist_ok=True)

    output_path = Path(config.data.results_path) / "identify_associations"
    output_path.mkdir(exist_ok=True, parents=True)

    # Load datasets:
    cat_list, cat_names, con_list, con_names = io.load_preprocessed_data(
        interim_path,
        config.data.categorical_names,
        config.data.continuous_names,
    )

    # Shuffled loader used for (re)fitting models; drops the last partial batch.
    train_dataloader = make_dataloader(
        cat_list,
        con_list,
        shuffle=True,
        batch_size=task_config.batch_size,
        drop_last=True,
    )

    # Width (feature count) of each continuous dataset.
    con_shapes = [con.shape[1] for con in con_list]

    num_samples = len(cast(Sized, train_dataloader.sampler))  # N
    num_continuous = sum(con_shapes)  # C
    logger.debug(f"# continuous features: {num_continuous}")

    # Creating the baseline (unperturbed, unshuffled) dataloader:
    baseline_dataloader = make_dataloader(
        cat_list, con_list, shuffle=False, batch_size=task_config.batch_size
    )

    # Identify associations between continuous features:
    logger.info(f"Perturbing dataset: '{task_config.target_dataset}'")
    if task_config.target_value in CONTINUOUS_TARGET_VALUE:
        logger.info(f"Beginning task: identify associations continuous ({task_type})")
        logger.info(f"Perturbation type: {task_config.target_value}")
        output_subpath = Path(output_path) / "perturbation_visualization"
        output_subpath.mkdir(exist_ok=True, parents=True)
        (
            dataloaders,
            nan_mask,
            feature_mask,
        ) = prepare_for_continuous_perturbation(
            config, output_subpath, baseline_dataloader
        )

    # Identify associations between categorical and continuous features:
    else:
        logger.info("Beginning task: identify associations categorical")
        (
            dataloaders,
            nan_mask,
            feature_mask,
        ) = prepare_for_categorical_perturbation(
            config, interim_path, baseline_dataloader, cat_list
        )

    # Last dataloader is the baseline, hence the -1.
    num_perturbed = len(dataloaders) - 1  # P
    logger.debug(f"# perturbed features: {num_perturbed}")

    # APPROACH EVALUATION ##########################
    # Each approach returns the flat indices of significant (feature_a,
    # feature_b) pairs plus approach-specific extra columns for the output.

    if task_type == "bayes":
        task_config = cast(IdentifyAssociationsBayesConfig, task_config)
        sig_ids, *extra_cols = _bayes_approach(
            config,
            task_config,
            train_dataloader,
            baseline_dataloader,
            dataloaders,
            models_path,
            num_perturbed,
            num_samples,
            num_continuous,
            nan_mask,
            feature_mask,
        )

        extra_colnames = ["proba", "fdr", "bayes_k"]

    elif task_type == "ttest":
        task_config = cast(IdentifyAssociationsTTestConfig, task_config)
        sig_ids, *extra_cols = _ttest_approach(
            task_config,
            train_dataloader,
            baseline_dataloader,
            dataloaders,
            models_path,
            interim_path,
            num_perturbed,
            num_samples,
            num_continuous,
            nan_mask,
            feature_mask,
        )

        extra_colnames = ["p_value"]

    elif task_type == "ks":
        task_config = cast(IdentifyAssociationsKSConfig, task_config)
        sig_ids, *extra_cols = _ks_approach(
            config,
            task_config,
            train_dataloader,
            baseline_dataloader,
            dataloaders,
            models_path,
            num_perturbed,
            num_samples,
            num_continuous,
            con_names,
            output_path,
        )

        extra_colnames = ["ks_distance"]

    else:
        # Defensive: _get_task_type should have raised already.
        raise ValueError(f"Unsupported task type: {task_type!r}")

    # RESULTS ################################
    save_results(
        config,
        con_shapes,
        cat_names,
        con_names,
        output_path,
        sig_ids,
        extra_cols,
        extra_colnames,
    )

    # save_results only writes the TSV when there are significant hits;
    # plot the association graph only if the file exists.
    if exists(output_path / f"results_sig_assoc_{task_type}.tsv"):
        association_df = pd.read_csv(
            output_path / f"results_sig_assoc_{task_type}.tsv", sep="\t"
        )
        _ = plot_feature_association_graph(association_df, output_path)
        _ = plot_feature_association_graph(association_df, output_path, layout="spring")