--- a/src/move/tasks/identify_associations.py
+++ b/src/move/tasks/identify_associations.py
@@ -0,0 +1,898 @@
+__all__ = ["identify_associations"]
+
+from functools import reduce
+from os.path import exists
+from pathlib import Path
+from typing import Literal, Sized, Union, cast
+
+import hydra
+import numpy as np
+import pandas as pd
+import torch
+from omegaconf import OmegaConf
+from scipy.stats import ks_2samp, pearsonr  # type: ignore
+from torch.utils.data import DataLoader
+
+from move.analysis.metrics import get_2nd_order_polynomial
+from move.conf.schema import (
+    IdentifyAssociationsBayesConfig,
+    IdentifyAssociationsConfig,
+    IdentifyAssociationsKSConfig,
+    IdentifyAssociationsTTestConfig,
+    MOVEConfig,
+)
+from move.core.logging import get_logger
+from move.core.typing import BoolArray, FloatArray, IntArray
+from move.data import io
+from move.data.dataloaders import MOVEDataset, make_dataloader
+from move.data.perturbations import (
+    ContinuousPerturbationType,
+    perturb_categorical_data,
+    perturb_continuous_data_extended,
+)
+from move.data.preprocessing import one_hot_encode_single
+from move.models.vae import VAE
+from move.visualization.dataset_distributions import (
+    plot_correlations,
+    plot_cumulative_distributions,
+    plot_feature_association_graph,
+    plot_reconstruction_movement,
+)
+
+TaskType = Literal["bayes", "ttest", "ks"]
+CONTINUOUS_TARGET_VALUE = ["minimum", "maximum", "plus_std", "minus_std"]
+
+
+def _get_task_type(
+    task_config: IdentifyAssociationsConfig,
+) -> TaskType:
+    task_type = OmegaConf.get_type(task_config)
+    if task_type is IdentifyAssociationsBayesConfig:
+        return "bayes"
+    if task_type is IdentifyAssociationsTTestConfig:
+        return "ttest"
+    if task_type is IdentifyAssociationsKSConfig:
+        return "ks"
+    raise ValueError("Unsupported type of task!")
+
+
+def _validate_task_config(
+    task_config: IdentifyAssociationsConfig, task_type: TaskType
+) -> None:
+    if not (0.0 <= task_config.sig_threshold <= 1.0):
+        raise ValueError("Significance threshold must be within [0, 1].")
+    if task_type == "ttest":
+        task_config = cast(IdentifyAssociationsTTestConfig, task_config)
+        if len(task_config.num_latent) != 4:
+            raise ValueError("4 latent space dimensions required.")
+
+
+def prepare_for_categorical_perturbation(
+    config: MOVEConfig,
+    interim_path: Path,
+    baseline_dataloader: DataLoader,
+    cat_list: list[FloatArray],
+) -> tuple[
+    list[DataLoader],
+    BoolArray,
+    BoolArray,
+]:
+    """
+    Create the dataloaders and masks required for the categorical
+    association analysis.
+
+    Args:
+        config: main configuration file
+        interim_path: path where the intermediate outputs are saved
+        baseline_dataloader: reference dataloader that will be perturbed
+        cat_list: list of arrays with categorical data
+
+    Returns:
+        dataloaders: all dataloaders, including the baseline appended last.
+        nan_mask: mask for NaNs
+        feature_mask: masks the column for the perturbed feature.
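+
+    Example (illustrative shapes only):
+        With N=100 samples, C continuous features, and a perturbed
+        categorical dataset of P=5 features, this returns 6 dataloaders
+        (5 perturbed plus the baseline appended last), an (N, C)
+        `nan_mask`, and an (N, P) `feature_mask`.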
+ """ + + # Read original data and create perturbed datasets + task_config = cast(IdentifyAssociationsConfig, config.task) + logger = get_logger(__name__) + + # Loading mappings: + mappings = io.load_mappings(interim_path / "mappings.json") + target_mapping = mappings[task_config.target_dataset] + target_value = one_hot_encode_single(target_mapping, task_config.target_value) + logger.debug( + f"Target value: {task_config.target_value} => {target_value.astype(int)[0]}" + ) + + dataloaders = perturb_categorical_data( + baseline_dataloader, + config.data.categorical_names, + task_config.target_dataset, + target_value, + ) + dataloaders.append(baseline_dataloader) + + baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) + + assert baseline_dataset.con_all is not None + orig_con = baseline_dataset.con_all + nan_mask = (orig_con == 0).numpy() # NaN values encoded as 0s + logger.debug(f"# NaN values: {np.sum(nan_mask)}/{orig_con.numel()}") + + target_dataset_idx = config.data.categorical_names.index(task_config.target_dataset) + target_dataset = cat_list[target_dataset_idx] + feature_mask = np.all(target_dataset == target_value, axis=2) # 2D: N x P + feature_mask |= np.sum(target_dataset, axis=2) == 0 + + return ( + dataloaders, + nan_mask, + feature_mask, + ) + + +def prepare_for_continuous_perturbation( + config: MOVEConfig, + output_subpath: Path, + baseline_dataloader: DataLoader, +) -> tuple[ + list[DataLoader], + BoolArray, + BoolArray, +]: + """ + This function creates the required dataloaders and masks + for further continuous association analysis. + + Args: + config: + main configuration file. + output_subpath: + path where the output plots for continuous analysis are saved. + baseline_dataloader: + reference dataloader that will be perturbed. + + Returns: + dataloaders: + list with all dataloaders, including baseline appended last. + nan_mask: + mask for NaNs + feature_mask: + same as `nan_mask`, in this case. 
+ """ + + # Read original data and create perturbed datasets + logger = get_logger(__name__) + task_config = cast(IdentifyAssociationsConfig, config.task) + + dataloaders = perturb_continuous_data_extended( + baseline_dataloader, + config.data.continuous_names, + task_config.target_dataset, + cast(ContinuousPerturbationType, task_config.target_value), + output_subpath, + ) + dataloaders.append(baseline_dataloader) + + baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) + + assert baseline_dataset.con_all is not None + orig_con = baseline_dataset.con_all + nan_mask = (orig_con == 0).numpy() # NaN values encoded as 0s + logger.debug(f"# NaN values: {np.sum(nan_mask)}/{orig_con.numel()}") + feature_mask = nan_mask + + return (dataloaders, nan_mask, feature_mask) + + +def _bayes_approach( + config: MOVEConfig, + task_config: IdentifyAssociationsBayesConfig, + train_dataloader: DataLoader, + baseline_dataloader: DataLoader, + dataloaders: list[DataLoader], + models_path: Path, + num_perturbed: int, + num_samples: int, + num_continuous: int, + nan_mask: BoolArray, + feature_mask: BoolArray, +) -> tuple[Union[IntArray, FloatArray], ...]: + + assert task_config.model is not None + device = torch.device("cuda" if task_config.model.cuda else "cpu") + + # Train models + logger = get_logger(__name__) + logger.info("Training models") + mean_diff = np.zeros((num_perturbed, num_samples, num_continuous)) + normalizer = 1 / task_config.num_refits + + # Last appended dataloader is the baseline + baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) + + for j in range(task_config.num_refits): + # Initialize model + model: VAE = hydra.utils.instantiate( + task_config.model, + continuous_shapes=baseline_dataset.con_shapes, + categorical_shapes=baseline_dataset.cat_shapes, + ) + if j == 0: + logger.debug(f"Model: {model}") + + # Train/reload model + model_path = models_path / f"model_{task_config.model.num_latent}_{j}.pt" + if model_path.exists(): + logger.debug(f"Re-loading refit {j + 1}/{task_config.num_refits}") + model.load_state_dict(torch.load(model_path)) + model.to(device) + else: + logger.debug(f"Training refit {j + 1}/{task_config.num_refits}") + model.to(device) + hydra.utils.call( + task_config.training_loop, + model=model, + train_dataloader=train_dataloader, + ) + if task_config.save_refits: + torch.save(model.state_dict(), model_path) + model.eval() + + # Calculate baseline reconstruction + _, baseline_recon = model.reconstruct(baseline_dataloader) + min_feat, max_feat = np.zeros((num_perturbed, num_continuous)), np.zeros( + (num_perturbed, num_continuous) + ) + min_baseline, max_baseline = np.min(baseline_recon, axis=0), np.max( + baseline_recon, axis=0 + ) + + # Calculate perturb reconstruction => keep track of mean difference + for i in range(num_perturbed): + _, perturb_recon = model.reconstruct(dataloaders[i]) + diff = perturb_recon - baseline_recon # 2D: N x C + mean_diff[i, :, :] += diff * normalizer + + min_perturb, max_perturb = np.min(perturb_recon, axis=0), np.max( + perturb_recon, axis=0 + ) + min_feat[i, :], max_feat[i, :] = np.min( + [min_baseline, min_perturb], axis=0 + ), np.max([max_baseline, max_perturb], axis=0) + + # Calculate Bayes factors + logger.info("Identifying significant features") + bayes_k = np.empty((num_perturbed, num_continuous)) + bayes_mask = np.zeros(np.shape(bayes_k)) + for i in range(num_perturbed): + mask = feature_mask[:, [i]] | nan_mask # 2D: N x C + diff = np.ma.masked_array(mean_diff[i, :, :], mask=mask) # 2D: N x C + prob = 
+        if task_config.target_value in CONTINUOUS_TARGET_VALUE:
+            bayes_mask[i, :] = (
+                baseline_dataloader.dataset.con_all[0, :]
+                - dataloaders[i].dataset.con_all[0, :]
+            )
+
+    bayes_mask[bayes_mask != 0] = 1
+    bayes_mask = np.array(bayes_mask, dtype=bool)
+
+    # Calculate Bayes probabilities
+    bayes_abs = np.abs(bayes_k)
+    bayes_p = np.exp(bayes_abs) / (1 + np.exp(bayes_abs))  # 2D: P x C
+    bayes_abs[bayes_mask] = np.min(
+        bayes_abs
+    )  # Bring feature_i-feature_i (self) associations to the minimum
+    sort_ids = np.argsort(bayes_abs, axis=None)[::-1]  # 1D: P x C
+    prob = np.take(bayes_p, sort_ids)  # 1D: P x C
+    logger.debug(f"Bayes proba range: [{prob[-1]:.3f} {prob[0]:.3f}]")
+
+    # Sort Bayes factors
+    bayes_k = np.take(bayes_k, sort_ids)  # 1D: P x C
+
+    # Calculate FDR
+    fdr = np.cumsum(1 - prob) / np.arange(1, prob.size + 1)  # 1D
+    idx = np.argmin(np.abs(fdr - task_config.sig_threshold))
+    logger.debug(f"FDR range: [{fdr[0]:.3f} {fdr[-1]:.3f}]")
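+    # Illustration: with sorted probabilities [0.99, 0.98, 0.90] the
+    # cumulative FDR is [0.010, 0.015, 0.043]; idx picks the cut-off closest
+    # to sig_threshold, and everything ranked above it is reported.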
+
+    return sort_ids[:idx], prob[:idx], fdr[:idx], bayes_k[:idx]
+
+
+def _ttest_approach(
+    task_config: IdentifyAssociationsTTestConfig,
+    train_dataloader: DataLoader,
+    baseline_dataloader: DataLoader,
+    dataloaders: list[DataLoader],
+    models_path: Path,
+    interim_path: Path,
+    num_perturbed: int,
+    num_samples: int,
+    num_continuous: int,
+    nan_mask: BoolArray,
+    feature_mask: BoolArray,
+) -> tuple[Union[IntArray, FloatArray], ...]:
+
+    from scipy.stats import ttest_rel
+
+    assert task_config.model is not None
+    device = torch.device("cuda" if task_config.model.cuda else "cpu")
+
+    # Train models
+    logger = get_logger(__name__)
+    logger.info("Training models")
+    pvalues = np.empty(
+        (
+            len(task_config.num_latent),
+            task_config.num_refits,
+            num_perturbed,
+            num_continuous,
+        )
+    )
+
+    # Last appended dataloader is the baseline
+    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
+
+    for k, num_latent in enumerate(task_config.num_latent):
+        for j in range(task_config.num_refits):
+
+            # Initialize model
+            model: VAE = hydra.utils.instantiate(
+                task_config.model,
+                continuous_shapes=baseline_dataset.con_shapes,
+                categorical_shapes=baseline_dataset.cat_shapes,
+                num_latent=num_latent,
+            )
+            if j == 0:
+                logger.debug(f"Model: {model}")
+
+            # Train model
+            model_path = models_path / f"model_{num_latent}_{j}.pt"
+            if model_path.exists():
+                logger.debug(f"Re-loading refit {j + 1}/{task_config.num_refits}")
+                model.load_state_dict(torch.load(model_path))
+                model.to(device)
+            else:
+                logger.debug(f"Training refit {j + 1}/{task_config.num_refits}")
+                model.to(device)
+                hydra.utils.call(
+                    task_config.training_loop,
+                    model=model,
+                    train_dataloader=train_dataloader,
+                )
+                if task_config.save_refits:
+                    torch.save(model.state_dict(), model_path)
+            model.eval()
+
+            # Get baseline reconstruction and baseline difference, averaged
+            # over 10 stochastic reconstructions to estimate reconstruction
+            # variability
+            _, baseline_recon = model.reconstruct(baseline_dataloader)
+            baseline_diff = np.empty((10, num_samples, num_continuous))
+            for i in range(10):
+                _, recon = model.reconstruct(baseline_dataloader)
+                baseline_diff[i, :, :] = recon - baseline_recon
+            baseline_diff = np.mean(baseline_diff, axis=0)  # 2D: N x C
+            baseline_diff = np.where(nan_mask, np.nan, baseline_diff)
+
+            # T-test between baseline and perturbed differences
+            for i in range(num_perturbed):
+                _, perturb_recon = model.reconstruct(dataloaders[i])
+                perturb_diff = perturb_recon - baseline_recon
+                mask = feature_mask[:, [i]] | nan_mask  # 2D: N x C
+                _, pvalues[k, j, i, :] = ttest_rel(
+                    a=np.where(mask, np.nan, perturb_diff),
+                    b=np.where(mask, np.nan, baseline_diff),
+                    axis=0,
+                    nan_policy="omit",
+                )
+
+    # Correct p-values (Bonferroni)
+    pvalues = np.minimum(pvalues * num_continuous, 1.0)
+    np.save(interim_path / "pvals.npy", pvalues)
+
+    # Find significant hits
+    overlap_thres = task_config.num_refits // 2
+    reject = pvalues <= task_config.sig_threshold  # 4D: L x R x P x C
+    overlap = reject.sum(axis=1) >= overlap_thres  # 3D: L x P x C
+    sig_ids = overlap.sum(axis=0) >= 3  # 2D: P x C
+    sig_ids = np.flatnonzero(sig_ids)  # 1D
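+    # Voting scheme: a (perturbed, target) pair counts as a hit only if it
+    # is rejected in at least half of the refits for a given latent size,
+    # and then in at least 3 of the 4 latent sizes enforced by
+    # _validate_task_config.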
+ """ + + assert task_config.model is not None + device = torch.device("cuda" if task_config.model.cuda else "cpu") + figure_path = output_path / "figures" + figure_path.mkdir(exist_ok=True, parents=True) + + # Data containers + stats = np.empty((task_config.num_refits, num_perturbed, num_continuous)) + stat_signs = np.empty_like(stats) + rec_corr, slope = np.empty((task_config.num_refits, num_continuous)), np.empty( + (task_config.num_refits, num_continuous) + ) + ks_mask = np.zeros((num_perturbed, num_continuous)) + latent_matrix = np.empty( + (num_samples, task_config.model.num_latent, len(dataloaders)) + ) + + # Last appended dataloader is the baseline + baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) + + # Train models + logger = get_logger(__name__) + logger.info("Training models") + + target_dataset_idx = config.data.continuous_names.index(task_config.target_dataset) + perturbed_names = con_names[target_dataset_idx] + + for j in range(task_config.num_refits): # Train num_refits models + + # Initialize model + model: VAE = hydra.utils.instantiate( + task_config.model, + continuous_shapes=baseline_dataset.con_shapes, + categorical_shapes=baseline_dataset.cat_shapes, + ) + if j == 0: + logger.debug(f"Model: {model}") + + # Train/reload model + model_path = models_path / f"model_{task_config.model.num_latent}_{j}.pt" + if model_path.exists(): + logger.debug(f"Re-loading refit {j + 1}/{task_config.num_refits}") + model.load_state_dict(torch.load(model_path)) + model.to(device) + else: + logger.debug(f"Training refit {j + 1}/{task_config.num_refits}") + model.to(device) + hydra.utils.call( + task_config.training_loop, + model=model, + train_dataloader=train_dataloader, + ) + if task_config.save_refits: + torch.save(model.state_dict(), model_path) + model.eval() + + # Calculate baseline reconstruction + _, baseline_recon = model.reconstruct(baseline_dataloader) + min_feat = np.zeros((num_perturbed, num_continuous)) + max_feat = np.zeros((num_perturbed, num_continuous)) + min_baseline = np.min(baseline_recon, axis=0) + max_baseline = np.max(baseline_recon, axis=0) + + # QC of feature's reconstruction ############################## + logger.debug("Calculating quality control of the feature reconstructions") + # Correlation and slope for each feature's reconstruction + feature_names = reduce(list.__add__, con_names) + + for k in range(num_continuous): + x = baseline_dataloader.dataset.con_all.numpy()[:, k] # baseline_recon[:,i] + y = baseline_recon[:, k] + x_pol, y_pol, (a2, a1, a) = get_2nd_order_polynomial(x, y) + slope[j, k] = a1 + rec_corr[j, k] = pearsonr(x, y).statistic + + if ( + feature_names[k] in task_config.perturbed_feature_names + or feature_names[k] in task_config.target_feature_names + ): + + # Plot correlations + fig = plot_correlations(x, y, x_pol, y_pol, a2, a1, a, k) + fig.savefig( + figure_path + / f"Input_vs_reconstruction_correlation_feature_{k}_refit_{j}.png", + dpi=50, + ) + + # Calculate perturbed reconstruction and shifts ############################# + logger.debug("Computing KS scores") + + # Save original latent space for first refit: + if j == 0: + latent = model.project(baseline_dataloader) + latent_matrix[:, :, -1] = latent + + for i, pert_feat in enumerate(perturbed_names): + _, perturb_recon = model.reconstruct(dataloaders[i]) + min_perturb = np.min(perturb_recon, axis=0) + max_perturb = np.max(perturb_recon, axis=0) + min_feat[i, :] = np.min([min_baseline, min_perturb], axis=0) + max_feat[i, :] = np.max([max_baseline, max_perturb], axis=0) 
+
+                if (
+                    pert_feat in task_config.perturbed_feature_names
+                    and targ_feat in task_config.target_feature_names
+                ):
+
+                    # Plot preliminary results:
+                    n_bins = 50
+                    hist_base, edges = np.histogram(
+                        baseline_recon[:, k],
+                        bins=np.linspace(min_feat[i, k], max_feat[i, k], n_bins),
+                        density=True,
+                    )
+                    hist_pert, edges = np.histogram(
+                        perturb_recon[:, k],
+                        bins=np.linspace(min_feat[i, k], max_feat[i, k], n_bins),
+                        density=True,
+                    )
+
+                    # Cumulative distribution:
+                    fig = plot_cumulative_distributions(
+                        edges,
+                        hist_base,
+                        hist_pert,
+                        title=f"Cumulative_perturbed_{i}_measuring_"
+                        f"{k}_stats_{stats[j, i, k]}",
+                    )
+                    fig.savefig(
+                        figure_path
+                        / (
+                            f"Cumulative_refit_{j}_perturbed_{i}_"
+                            f"measuring_{k}_stats_{stats[j, i, k]}.png"
+                        )
+                    )
+
+                    # Feature changes:
+                    fig = plot_reconstruction_movement(
+                        baseline_recon, perturb_recon, k
+                    )
+                    fig.savefig(
+                        figure_path / f"Changes_pert_{i}_on_feat_{k}_refit_{j}.png"
+                    )
+
+    # Save latent space matrix:
+    np.save(output_path / "latent_location.npy", latent_matrix)
+    np.save(output_path / "perturbed_features_list.npy", np.array(perturbed_names))
+
+    # Create a mask for self-associations
+    logger.debug("Creating self-association mask")
+    for i in range(num_perturbed):
+        if task_config.target_value in CONTINUOUS_TARGET_VALUE:
+            ks_mask[i, :] = (
+                baseline_dataloader.dataset.con_all[0, :]
+                - dataloaders[i].dataset.con_all[0, :]
+            )
+    ks_mask[ks_mask != 0] = 1
+    ks_mask = np.array(ks_mask, dtype=bool)
+
+    # Take the median of the KS values (with sign) over refits
+    final_stats = np.nanmedian(stats * stat_signs, axis=0)
+    # Zero out masked (self) associations, placing them at the end of the
+    # ranking
+    final_stats[ks_mask] = 0.0
+
+    # KS threshold:
+    ks_thr = np.sqrt(-np.log(task_config.sig_threshold / 2) / num_samples)
+    logger.info(f"Suggested absolute KS threshold is: {ks_thr}")
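+    # Derivation sketch: the classical two-sample KS critical value is
+    # c(a) * sqrt((n + m) / (n * m)) with c(a) = sqrt(-ln(a / 2) / 2); for
+    # two samples of equal size n = m = num_samples this reduces to the
+    # expression above, e.g. a = 0.05 and n = 500 give a threshold of
+    # roughly 0.086.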
+
+    # Sort associations by absolute KS value
+    sort_ids = np.argsort(abs(final_stats), axis=None)[::-1]  # 1D: P x C
+    ks_distance = np.take(final_stats, sort_ids)  # 1D: P x C
+
+    # Write the quality-control TSV file; the mean slope and correlation
+    # over refits serve as QC metrics.
+    logger.info("Writing QC file")
+    qc_df = pd.DataFrame({"Feature names": feature_names})
+    qc_df["slope"] = np.nanmean(slope, axis=0)
+    qc_df["reconstruction_correlation"] = np.nanmean(rec_corr, axis=0)
+    qc_df.to_csv(output_path / "QC_summary_KS.tsv", sep="\t", index=False)
+
+    # Return only the associations above the suggested KS threshold
+    keep = abs(ks_distance) >= ks_thr
+    return sort_ids[keep], ks_distance[keep]
+
+
+def save_results(
+    config: MOVEConfig,
+    con_shapes: list[int],
+    cat_names: list[list[str]],
+    con_names: list[list[str]],
+    output_path: Path,
+    sig_ids,
+    extra_cols,
+    extra_colnames,
+) -> None:
+    """
+    Save the associations found in a TSV file containing the following
+    columns:
+        feature_a_id
+        feature_b_id
+        feature_a_name
+        feature_b_name
+        feature_b_dataset
+        proba/p_value: number quantifying the significance of the association
+
+    Args:
+        config: main config
+        con_shapes: list with the number of features per continuous dataset
+        cat_names: list of lists of names for the categorical features.
+            Each inner list corresponds to a separate dataset.
+        con_names: list of lists of names for the continuous features.
+            Each inner list corresponds to a separate dataset.
+        output_path: path where the results will be saved
+        sig_ids: ids of the significant features
+        extra_cols: extra data returned by the approach function
+        extra_colnames: names for the extra data columns
+    """
+    logger = get_logger(__name__)
+    logger.info(f"Significant hits found: {sig_ids.size}")
+    task_config = cast(IdentifyAssociationsConfig, config.task)
+    task_type = _get_task_type(task_config)
+
+    num_continuous = sum(con_shapes)  # C
+
+    if sig_ids.size > 0:
+        sig_ids = np.vstack((sig_ids // num_continuous, sig_ids % num_continuous)).T
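+        # Index decoding example: with num_continuous C = 100, the flat id
+        # 512 decodes to feature_a_id = 5 and feature_b_id = 12, i.e. the
+        # row and column of the association in the P x C matrix before
+        # flattening.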
+ """ + # DATA PREPARATION ###################### + # Read original data and create perturbed datasets#### + + logger = get_logger(__name__) + task_config = cast(IdentifyAssociationsConfig, config.task) + task_type = _get_task_type(task_config) + _validate_task_config(task_config, task_type) + + interim_path = Path(config.data.interim_data_path) + + models_path = interim_path / "models" + if task_config.save_refits: + models_path.mkdir(exist_ok=True) + + output_path = Path(config.data.results_path) / "identify_associations" + output_path.mkdir(exist_ok=True, parents=True) + + # Load datasets: + cat_list, cat_names, con_list, con_names = io.load_preprocessed_data( + interim_path, + config.data.categorical_names, + config.data.continuous_names, + ) + + train_dataloader = make_dataloader( + cat_list, + con_list, + shuffle=True, + batch_size=task_config.batch_size, + drop_last=True, + ) + + con_shapes = [con.shape[1] for con in con_list] + + num_samples = len(cast(Sized, train_dataloader.sampler)) # N + num_continuous = sum(con_shapes) # C + logger.debug(f"# continuous features: {num_continuous}") + + # Creating the baseline dataloader: + baseline_dataloader = make_dataloader( + cat_list, con_list, shuffle=False, batch_size=task_config.batch_size + ) + + # Indentify associations between continuous features: + logger.info(f"Perturbing dataset: '{task_config.target_dataset}'") + if task_config.target_value in CONTINUOUS_TARGET_VALUE: + logger.info(f"Beginning task: identify associations continuous ({task_type})") + logger.info(f"Perturbation type: {task_config.target_value}") + output_subpath = Path(output_path) / "perturbation_visualization" + output_subpath.mkdir(exist_ok=True, parents=True) + ( + dataloaders, + nan_mask, + feature_mask, + ) = prepare_for_continuous_perturbation( + config, output_subpath, baseline_dataloader + ) + + # Identify associations between categorical and continuous features: + else: + logger.info("Beginning task: identify associations categorical") + ( + dataloaders, + nan_mask, + feature_mask, + ) = prepare_for_categorical_perturbation( + config, interim_path, baseline_dataloader, cat_list + ) + + num_perturbed = len(dataloaders) - 1 # P + logger.debug(f"# perturbed features: {num_perturbed}") + + # APPROACH EVALUATION ########################## + + if task_type == "bayes": + task_config = cast(IdentifyAssociationsBayesConfig, task_config) + sig_ids, *extra_cols = _bayes_approach( + config, + task_config, + train_dataloader, + baseline_dataloader, + dataloaders, + models_path, + num_perturbed, + num_samples, + num_continuous, + nan_mask, + feature_mask, + ) + + extra_colnames = ["proba", "fdr", "bayes_k"] + + elif task_type == "ttest": + task_config = cast(IdentifyAssociationsTTestConfig, task_config) + sig_ids, *extra_cols = _ttest_approach( + task_config, + train_dataloader, + baseline_dataloader, + dataloaders, + models_path, + interim_path, + num_perturbed, + num_samples, + num_continuous, + nan_mask, + feature_mask, + ) + + extra_colnames = ["p_value"] + + elif task_type == "ks": + task_config = cast(IdentifyAssociationsKSConfig, task_config) + sig_ids, *extra_cols = _ks_approach( + config, + task_config, + train_dataloader, + baseline_dataloader, + dataloaders, + models_path, + num_perturbed, + num_samples, + num_continuous, + con_names, + output_path, + ) + + extra_colnames = ["ks_distance"] + + else: + raise ValueError() + + # RESULTS ################################ + save_results( + config, + con_shapes, + cat_names, + con_names, + output_path, + sig_ids, + 
+
+    # RESULTS ################################
+    save_results(
+        config,
+        con_shapes,
+        cat_names,
+        con_names,
+        output_path,
+        sig_ids,
+        extra_cols,
+        extra_colnames,
+    )
+
+    if exists(output_path / f"results_sig_assoc_{task_type}.tsv"):
+        association_df = pd.read_csv(
+            output_path / f"results_sig_assoc_{task_type}.tsv", sep="\t"
+        )
+        _ = plot_feature_association_graph(association_df, output_path)
+        _ = plot_feature_association_graph(
+            association_df, output_path, layout="spring"
+        )