[f123c6]: / notebooks / two_vs_seven_omics_analysis.ipynb

Download this file

2224 lines (2223 with data), 78.0 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a temporary notebook for shap value analysis and plots."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "proj_dir = \"/home/scai/PhenPred\"\n",
    "if not os.path.exists(proj_dir):\n",
    "    proj_dir = \"/Users/emanuel/Projects/PhenPred\"\n",
    "sys.path.extend([proj_dir])\n",
    "\n",
    "import json\n",
    "import PhenPred\n",
    "import argparse\n",
    "import pandas as pd\n",
    "from PhenPred.vae import plot_folder\n",
    "from PhenPred.vae.Hypers import Hypers\n",
    "from PhenPred.vae.Train import CLinesTrain\n",
    "from PhenPred.vae.DatasetDepMap23Q2 import CLinesDatasetDepMap23Q2\n",
    "from PhenPred.vae.DatasetMOFA import CLinesDatasetMOFA\n",
    "from PhenPred.vae.DatasetMOVE import CLinesDatasetMOVE\n",
    "from PhenPred.vae.DatasetJAMIE import CLinesDatasetJAMIE\n",
    "from PhenPred.vae.DatasetIClusterPlus import CLinesDatasetIClusterPlus\n",
    "from PhenPred.vae.DatasetMoCluster import CLinesDatasetMoCluster\n",
    "from PhenPred.vae.DatasetMixOmics import CLinesDatasetMixOmics\n",
    "from PhenPred.vae.DatasetSCVAEIT import CLinesDatasetSCVAEIT\n",
    "from PhenPred.Utils import two_vars_correlation\n",
    "from sklearn.discriminant_analysis import StandardScaler\n",
    "from sklearn.metrics import calinski_harabasz_score, davies_bouldin_score\n",
    "\n",
    "\n",
    "pd.set_option(\"display.max_rows\", 100)\n",
    "pd.set_option(\"display.max_columns\", 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import ttest_ind, ttest_rel, wilcoxon\n",
    "from scipy.stats import shapiro"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "plt.rcParams[\"font.family\"] = \"Arial\"\n",
    "plt.rcParams[\"font.size\"] = 4\n",
    "plt.rcParams[\"axes.linewidth\"] = 0.25\n",
    "plt.rcParams[\"figure.figsize\"] = (2.5, 2.5)\n",
    "plt.rcParams[\"pdf.fonttype\"] = 42\n",
    "plt.rcParams[\"ps.fonttype\"] = 42\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"axes.linewidth\"] = 0.25\n",
    "plt.rcParams[\"legend.fontsize\"] = 4\n",
    "\n",
    "sns.set(style=\"ticks\", context=\"paper\", font_scale=1, font=\"Arial\")\n",
    "sns.set_context(\n",
    "    \"paper\",\n",
    "    rc={\n",
    "        \"axes.linewidth\": 0.25,\n",
    "        \"xtick.major.size\": 2,\n",
    "        \"xtick.major.width\": 0.25,\n",
    "        \"ytick.major.size\": 2,\n",
    "        \"ytick.major.width\": 0.25,\n",
    "        \"xtick.labelsize\": 6,\n",
    "        \"ytick.labelsize\": 6,\n",
    "        \"axes.labelsize\": 7,\n",
    "        \"legend.fontsize\": 6,\n",
    "        \"legend.title_fontsize\": 6,\n",
    "    },\n",
    ")\n",
    "\n",
    "import matplotlib.patches as mpatches\n",
    "import umap\n",
    "\n",
    "pd.set_option(\"display.max_rows\", 100)\n",
    "pd.set_option(\"display.max_columns\", 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shap\n",
    "import pickle\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "OMIC_PALLETS = {\n",
    "    \"conditionals\": \"#4c72b0\",\n",
    "    \"copynumber\": \"#dd8452\",\n",
    "    \"drugresponse\": \"#55a868\",\n",
    "    \"metabolomics\": \"#c44e52\",\n",
    "    \"proteomics\": \"#8172b3\",\n",
    "    \"crisprcas9\": \"#937860\",\n",
    "    \"transcriptomics\": \"#da8bc3\",\n",
    "    \"methylation\": \"#8c8c8c\",\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Latent comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "TIMESTAMP = \"20240830_110319\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "clustering_score_df = pd.read_csv(\n",
    "    f\"./reports/vae/latent/{TIMESTAMP}_clustering_score.csv\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, ax = plt.subplots(1, 1, figsize=(3, 3), dpi=600)\n",
    "sns.barplot(data=clustering_score_df, x=\"metric\", y=\"score\", ax=ax, hue=\"model\")\n",
    "PhenPred.save_figure(f\"{plot_folder}/latent/{TIMESTAMP}_clustering_score_barplot_2vs7omcis\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "TIMESTAMP_2omics = \"20240830_110319\"\n",
    "TIMESTAMP_7omics = \"20231023_092657\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"<not serializable>\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"copynumber\": \"data/clines//cnv_summary_20230303_matrix.csv\",\n",
      "        \"crisprcas9\": \"data/clines//depmap23Q2/CRISPRGeneEffect.csv\",\n",
      "        \"drugresponse\": \"data/clines//drugresponse.csv\",\n",
      "        \"metabolomics\": \"data/clines//metabolomics.csv\",\n",
      "        \"methylation\": \"data/clines//methylation.csv\",\n",
      "        \"proteomics\": \"data/clines//proteomics.csv\",\n",
      "        \"transcriptomics\": \"data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": \"20231023_092657\",\n",
      "    \"model\": \"MOVE\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"<not serializable>\",\n",
      "    \"save_model\": true,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_cv\": false,\n",
      "    \"standardize\": true,\n",
      "    \"use_conditionals\": true,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\",\n",
      "        \"mean\",\n",
      "        \"mean\",\n",
      "        \"mean\",\n",
      "        \"mean\",\n",
      "        \"macro\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "DepMap23Q2 | Samples = 1,656 | Proteomics = 4,922 (0 masked) | Metabolomics = 225 (0 masked) | Drug response = 810 (0 masked) | CRISPR-Cas9 = 17,931 (12,714 masked) | Methylation = 14,608 (7,014 masked) | Transcriptomics = 15,278 (7,193 masked) | Copy number = 777 (0 masked) | Labels = 237\n"
     ]
    }
   ],
   "source": [
    "hyperparameters_7omics = Hypers.read_hyperparameters(timestamp=\"20231023_092657\")\n",
    "clines_db_7omics = CLinesDatasetDepMap23Q2(\n",
    "    datasets=hyperparameters_7omics[\"datasets\"],\n",
    "    labels_names=hyperparameters_7omics[\"labels\"],\n",
    "    standardize=hyperparameters_7omics[\"standardize\"],\n",
    "    filter_features=hyperparameters_7omics[\"filter_features\"],\n",
    "    filtered_encoder_only=hyperparameters_7omics[\"filtered_encoder_only\"],\n",
    "    feature_miss_rate_thres=hyperparameters_7omics[\"feature_miss_rate_thres\"],\n",
    ")\n",
    "\n",
    "train_7omics = CLinesTrain(\n",
    "    clines_db_7omics,\n",
    "    hyperparameters_7omics,\n",
    "    verbose=hyperparameters_7omics[\"verbose\"],\n",
    "    stratify_cv_by=clines_db_7omics.samples_by_tissue(\n",
    "        \"Haematopoietic and Lymphoid\"\n",
    "    ),\n",
    ")\n",
    "train_7omics.run(run_timestamp=hyperparameters_7omics[\"load_run\"])\n",
    "mosa_7omics_imputed, mosa_7omics_latent = train_7omics.load_vae_reconstructions()\n",
    "mosa_7omics_predicted, _ = train_7omics.load_vae_reconstructions(mode=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"<not serializable>\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": \"20240830_110319\",\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"<not serializable>\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": true,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "DepMap23Q2 | Samples = 1,590 | Transcriptomics = 15,278 (7,193 masked) | Drug response = 810 (0 masked) | Labels = 237\n"
     ]
    }
   ],
   "source": [
    "hyperparameters = Hypers.read_hyperparameters(timestamp=TIMESTAMP_2omics)\n",
    "clines_db = CLinesDatasetDepMap23Q2(\n",
    "    labels_names=hyperparameters[\"labels\"],\n",
    "    datasets=hyperparameters[\"datasets\"],\n",
    "    feature_miss_rate_thres=hyperparameters[\"feature_miss_rate_thres\"],\n",
    "    standardize=hyperparameters[\"standardize\"],\n",
    "    filter_features=hyperparameters[\"filter_features\"],\n",
    "    filtered_encoder_only=hyperparameters[\"filtered_encoder_only\"],\n",
    ")\n",
    "train = CLinesTrain(\n",
    "    clines_db,\n",
    "    hyperparameters,\n",
    "    verbose=hyperparameters[\"verbose\"],\n",
    "    stratify_cv_by=clines_db.samples_by_tissue(\"Haematopoietic and Lymphoid\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "train.run(run_timestamp=hyperparameters[\"load_run\"])\n",
    "vae_imputed, vae_latent = train.load_vae_reconstructions()\n",
    "vae_predicted, _ = train.load_vae_reconstructions(mode=\"all\")\n",
    "\n",
    "mofa_imputed, mofa_latent = CLinesDatasetMOFA.load_reconstructions(clines_db)\n",
    "move_diabetes_imputed, move_diabetes_latent = CLinesDatasetMOVE.load_reconstructions(\n",
    "    clines_db\n",
    ")\n",
    "jamie_imputed, jamie_latent = CLinesDatasetJAMIE.load_reconstructions(clines_db)\n",
    "\n",
    "_, mixOmics_latent = CLinesDatasetMixOmics.load_reconstructions(clines_db)\n",
    "_, iClusterPlus_latent = CLinesDatasetIClusterPlus.load_reconstructions(clines_db)\n",
    "_, moCluster_latent = CLinesDatasetMoCluster.load_reconstructions(clines_db)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "common_cell_lines = list(\n",
    "    set(vae_latent.index)\n",
    "    & set(mosa_7omics_latent.index)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "mosa_7omics_latent = mosa_7omics_latent.loc[common_cell_lines]\n",
    "vae_latent = vae_latent.loc[common_cell_lines]\n",
    "mofa_latent[\"factors\"] = mofa_latent[\"factors\"].loc[common_cell_lines]\n",
    "move_diabetes_latent[\"factors\"] = move_diabetes_latent[\"factors\"].loc[common_cell_lines]\n",
    "jamie_latent[\"factors\"] = jamie_latent[\"factors\"].loc[common_cell_lines]\n",
    "mixOmics_latent[\"factors\"] = mixOmics_latent[\"factors\"].loc[common_cell_lines]\n",
    "iClusterPlus_latent[\"factors\"] = iClusterPlus_latent[\"factors\"].loc[common_cell_lines]\n",
    "moCluster_latent[\"factors\"] = moCluster_latent[\"factors\"].loc[common_cell_lines]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "samplesheet = clines_db.samplesheet[\"tissue\"].fillna(\"Other tissue\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "clustering_score_df = {\"model\": [], \"metric\": [], \"score\": []}\n",
    "for n, z_joint in [\n",
    "    (\"MOSA (7 omics)\", mosa_7omics_latent),\n",
    "    (\"MOSA (2 omics)\", vae_latent),\n",
    "    (\"MOFA\", mofa_latent[\"factors\"]),\n",
    "    (\"MOVE\", move_diabetes_latent[\"factors\"]),\n",
    "    (\"JAMIE\", jamie_latent[\"factors\"]),\n",
    "    (\"mixOmics\", mixOmics_latent[\"factors\"]),\n",
    "    (\"iClusterPlus\", iClusterPlus_latent[\"factors\"]),\n",
    "    (\"moCluster\", moCluster_latent[\"factors\"]),\n",
    "]:\n",
    "    cluster_labels = samplesheet[z_joint.index]\n",
    "    clustering_score_df[\"model\"].append(n)\n",
    "    clustering_score_df[\"metric\"].append(\"calinski_harabasz\")\n",
    "    clustering_score_df[\"score\"].append(\n",
    "        calinski_harabasz_score(\n",
    "            StandardScaler().fit_transform(z_joint), cluster_labels\n",
    "        )\n",
    "    )\n",
    "    clustering_score_df[\"model\"].append(n)\n",
    "    clustering_score_df[\"metric\"].append(\"davies_bouldin\")\n",
    "    clustering_score_df[\"score\"].append(\n",
    "        davies_bouldin_score(\n",
    "            StandardScaler().fit_transform(z_joint), cluster_labels\n",
    "        )\n",
    "    )\n",
    "clustering_score_df = pd.DataFrame(clustering_score_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, ax = plt.subplots(1, 1, figsize=(3, 3), dpi=600)\n",
    "sns.barplot(data=clustering_score_df, x=\"metric\", y=\"score\", ax=ax, hue=\"model\")\n",
    "PhenPred.save_figure(\n",
    "    f\"{plot_folder}/latent/{TIMESTAMP_2omics}_clustering_score_barplot_2vs7omics\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "TIMESTAMP_2omics = \"20240830_110319\"\n",
    "TIMESTAMP_7omics = \"20231023_092657\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"<not serializable>\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"copynumber\": \"data/clines//cnv_summary_20230303_matrix.csv\",\n",
      "        \"crisprcas9\": \"data/clines//depmap23Q2/CRISPRGeneEffect.csv\",\n",
      "        \"drugresponse\": \"data/clines//drugresponse.csv\",\n",
      "        \"metabolomics\": \"data/clines//metabolomics.csv\",\n",
      "        \"methylation\": \"data/clines//methylation.csv\",\n",
      "        \"proteomics\": \"data/clines//proteomics.csv\",\n",
      "        \"transcriptomics\": \"data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": \"20231023_092657\",\n",
      "    \"model\": \"MOVE\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"<not serializable>\",\n",
      "    \"save_model\": true,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_cv\": false,\n",
      "    \"standardize\": true,\n",
      "    \"use_conditionals\": true,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\",\n",
      "        \"mean\",\n",
      "        \"mean\",\n",
      "        \"mean\",\n",
      "        \"mean\",\n",
      "        \"macro\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "DepMap23Q2 | Samples = 1,656 | Proteomics = 4,922 (0 masked) | Metabolomics = 225 (0 masked) | Drug response = 810 (0 masked) | CRISPR-Cas9 = 17,931 (12,714 masked) | Methylation = 14,608 (7,014 masked) | Transcriptomics = 15,278 (7,193 masked) | Copy number = 777 (0 masked) | Labels = 237\n"
     ]
    }
   ],
   "source": [
    "hyperparameters = Hypers.read_hyperparameters(timestamp=TIMESTAMP_7omics)\n",
    "clines_db = CLinesDatasetDepMap23Q2(\n",
    "    labels_names=hyperparameters[\"labels\"],\n",
    "    datasets=hyperparameters[\"datasets\"],\n",
    "    feature_miss_rate_thres=hyperparameters[\"feature_miss_rate_thres\"],\n",
    "    standardize=hyperparameters[\"standardize\"],\n",
    "    filter_features=hyperparameters[\"filter_features\"],\n",
    "    filtered_encoder_only=hyperparameters[\"filtered_encoder_only\"],\n",
    ")\n",
    "train = CLinesTrain(\n",
    "    clines_db,\n",
    "    hyperparameters,\n",
    "    verbose=hyperparameters[\"verbose\"],\n",
    "    stratify_cv_by=clines_db.samples_by_tissue(\"Haematopoietic and Lymphoid\"),\n",
    ")\n",
    "\n",
    "train.run(run_timestamp=hyperparameters[\"load_run\"])\n",
    "mosa_7omics_imputed, mosa_7omics_latent = train.load_vae_reconstructions()\n",
    "mosa_7omics_predicted, _ = train.load_vae_reconstructions(mode=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"<not serializable>\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": \"20240830_110319\",\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"<not serializable>\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": true,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DepMap23Q2 | Samples = 1,590 | Transcriptomics = 15,278 (7,193 masked) | Drug response = 810 (0 masked) | Labels = 237\n"
     ]
    }
   ],
   "source": [
    "hyperparameters = Hypers.read_hyperparameters(timestamp=TIMESTAMP_2omics)\n",
    "clines_db = CLinesDatasetDepMap23Q2(\n",
    "    labels_names=hyperparameters[\"labels\"],\n",
    "    datasets=hyperparameters[\"datasets\"],\n",
    "    feature_miss_rate_thres=hyperparameters[\"feature_miss_rate_thres\"],\n",
    "    standardize=hyperparameters[\"standardize\"],\n",
    "    filter_features=hyperparameters[\"filter_features\"],\n",
    "    filtered_encoder_only=hyperparameters[\"filtered_encoder_only\"],\n",
    ")\n",
    "train = CLinesTrain(\n",
    "    clines_db,\n",
    "    hyperparameters,\n",
    "    verbose=hyperparameters[\"verbose\"],\n",
    "    stratify_cv_by=clines_db.samples_by_tissue(\"Haematopoietic and Lymphoid\"),\n",
    ")\n",
    "\n",
    "train.run(run_timestamp=hyperparameters[\"load_run\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n",
      "# ---- Hyperparameters\n",
      "{\n",
      "    \"activation_function\": \"prelu\",\n",
      "    \"batch_norm\": false,\n",
      "    \"batch_size\": 256,\n",
      "    \"contrastive_neg_margin\": 0.15,\n",
      "    \"contrastive_pos_margin\": 0.85,\n",
      "    \"dataname\": \"depmap23Q2\",\n",
      "    \"datasets\": {\n",
      "        \"drugresponse\": \"/home/scai/PhenPred/data/clines//drugresponse.csv\",\n",
      "        \"transcriptomics\": \"/home/scai/PhenPred/data/clines//depmap23Q2/OmicsExpressionGenesExpectedCountProfileVoom.csv\"\n",
      "    },\n",
      "    \"feature_dropout\": 0,\n",
      "    \"feature_miss_rate_thres\": 0.85,\n",
      "    \"filter_features\": [\n",
      "        \"transcriptomics\",\n",
      "        \"crisprcas9\",\n",
      "        \"methylation\"\n",
      "    ],\n",
      "    \"filtered_encoder_only\": true,\n",
      "    \"gmvae_decay_temp\": true,\n",
      "    \"gmvae_decay_temp_rate\": 0.013862944,\n",
      "    \"gmvae_hard_gumbel\": 0.7936881144482251,\n",
      "    \"gmvae_hidden_size\": 935,\n",
      "    \"gmvae_init_temp\": 1.0,\n",
      "    \"gmvae_k\": 51,\n",
      "    \"gmvae_min_temp\": 0.5,\n",
      "    \"gmvae_views_logits\": 726,\n",
      "    \"hidden_dims\": [\n",
      "        0.7\n",
      "    ],\n",
      "    \"labels\": [\n",
      "        \"tissue\",\n",
      "        \"mutations\",\n",
      "        \"fussions\",\n",
      "        \"msi\",\n",
      "        \"growth\"\n",
      "    ],\n",
      "    \"latent_dim\": 200,\n",
      "    \"learning_rate\": 0.0003,\n",
      "    \"load_run\": null,\n",
      "    \"model\": \"MOSA\",\n",
      "    \"n_folds\": 3,\n",
      "    \"num_epochs\": 500,\n",
      "    \"optimizer_type\": \"adam\",\n",
      "    \"probability\": 0.4,\n",
      "    \"reconstruction_loss\": \"mse\",\n",
      "    \"save_model\": false,\n",
      "    \"scheduler\": \"plateau\",\n",
      "    \"scheduler_factor\": 0.6,\n",
      "    \"scheduler_min_lr\": 1e-07,\n",
      "    \"scheduler_patience\": 7,\n",
      "    \"scheduler_threshold\": 0.0001,\n",
      "    \"skip_benchmarks\": false,\n",
      "    \"skip_cv\": true,\n",
      "    \"standardize\": true,\n",
      "    \"two_omics_benchmark\": true,\n",
      "    \"use_conditionals\": false,\n",
      "    \"verbose\": 0,\n",
      "    \"view_dropout\": 0.3,\n",
      "    \"view_latent_dim\": 0.25,\n",
      "    \"view_loss_recon_type\": [\n",
      "        \"mean\",\n",
      "        \"mean\"\n",
      "    ],\n",
      "    \"view_loss_weights\": [\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0,\n",
      "        1.0\n",
      "    ],\n",
      "    \"w_cat\": 0.01,\n",
      "    \"w_contrastive\": 0.005,\n",
      "    \"w_decay\": 0.0005,\n",
      "    \"w_gauss\": 0.0001,\n",
      "    \"w_kl\": 0.0001,\n",
      "    \"w_rec\": 1\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "vae_imputed, vae_latent = train.load_vae_reconstructions()\n",
    "vae_predicted, _ = train.load_vae_reconstructions(mode=\"all\")\n",
    "\n",
    "mofa_imputed, mofa_latent = CLinesDatasetMOFA.load_reconstructions(clines_db)\n",
    "move_diabetes_imputed, move_diabetes_latent = (\n",
    "    CLinesDatasetMOVE.load_reconstructions(clines_db)\n",
    ")\n",
    "jamie_imputed, jamie_latent = CLinesDatasetJAMIE.load_reconstructions(clines_db)\n",
    "scvaeit_imputed, scvaeit_latent = CLinesDatasetSCVAEIT.load_reconstructions(clines_db)\n",
    "\n",
    "_, mixOmics_latent = CLinesDatasetMixOmics.load_reconstructions(clines_db)\n",
    "_, iClusterPlus_latent = CLinesDatasetIClusterPlus.load_reconstructions(clines_db)\n",
    "_, moCluster_latent = CLinesDatasetMoCluster.load_reconstructions(clines_db)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples_mgexp = ~clines_db.dfs[\"transcriptomics\"].isnull().all(axis=1)\n",
    "\n",
    "gexp_gdsc = pd.read_csv(f\"./data/clines/transcriptomics.csv\", index_col=0).T\n",
    "gexp_mosa = vae_imputed[\"transcriptomics\"]\n",
    "gexp_mosa_7omics = mosa_7omics_imputed[\"transcriptomics\"]\n",
    "gexp_move = move_diabetes_imputed[\"transcriptomics\"]\n",
    "gexp_jamie = jamie_imputed[\"transcriptomics\"]\n",
    "gexp_mofa = mofa_imputed[\"transcriptomics\"]\n",
    "gexp_scvaeit = scvaeit_imputed[\"transcriptomics\"]\n",
    "\n",
    "gexp_dfs = dict(\n",
    "    [\n",
    "        (\"MOSA_7omics\", gexp_mosa_7omics),\n",
    "        (\"MOSA_2omics\", gexp_mosa),\n",
    "        (\"MOFA\", gexp_mofa),\n",
    "        (\"MOVE\", gexp_move),\n",
    "        (\"JAMIE\", gexp_jamie),\n",
    "        (\"scVAEIT\", gexp_scvaeit),\n",
    "    ]\n",
    ")\n",
    "samples = set(gexp_gdsc.index).intersection(gexp_mosa.index)\n",
    "genes = list(set(gexp_gdsc.columns).intersection(gexp_mosa.columns))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "gexp_corr_dfs = []\n",
    "for name in gexp_dfs:\n",
    "    gexp_corr = pd.DataFrame(\n",
    "        [\n",
    "            two_vars_correlation(\n",
    "                gexp_gdsc.loc[s, genes],\n",
    "                gexp_dfs[name].loc[s, genes],\n",
    "                method=\"pearson\",\n",
    "                extra_fields=dict(sample=s, with_gexp=samples_mgexp.loc[s]),\n",
    "            )\n",
    "            for s in samples\n",
    "        ]\n",
    "    )\n",
    "    gexp_corr['model'] = name\n",
    "    gexp_corr_dfs.append(gexp_corr)\n",
    "gexp_corr_dfs = pd.concat(gexp_corr_dfs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, ax = plt.subplots(1, 1, figsize=(1.5, 2), dpi=600)\n",
    "\n",
    "sns.boxplot(\n",
    "    data=gexp_corr_dfs[~gexp_corr_dfs[\"with_gexp\"]],\n",
    "    x=\"model\",\n",
    "    y=\"corr\",\n",
    "    hue=\"model\",\n",
    "    palette=\"tab10\",\n",
    "    linewidth=0.3,\n",
    "    fliersize=1,\n",
    "    notch=True,\n",
    "    saturation=1.0,\n",
    "    showcaps=False,\n",
    "    boxprops=dict(linewidth=0.5, edgecolor=\"black\"),\n",
    "    whiskerprops=dict(linewidth=0.5, color=\"black\"),\n",
    "    flierprops=dict(\n",
    "        marker=\"o\",\n",
    "        markerfacecolor=\"black\",\n",
    "        markersize=1.0,\n",
    "        linestyle=\"none\",\n",
    "        markeredgecolor=\"none\",\n",
    "        alpha=0.6,\n",
    "    ),\n",
    "    medianprops=dict(linestyle=\"-\", linewidth=0.5),\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha=\"right\")\n",
    "ax.set(\n",
    "    title=f\"\",\n",
    "    ylabel=\"Correlation between reconstructed\\nand GDSC transcriptomics (Pearson's r)\",\n",
    "    xlabel=f\"Sample without transcriptomics\\nduring training\",\n",
    ")\n",
    "\n",
    "PhenPred.save_figure(\n",
    "    f\"{plot_folder}/{hyperparameters['load_run']}_reconstructed_gexp_correlation_boxplot_2vs7omics\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "gexp_corr_out_sample_dfs = gexp_corr_dfs[~gexp_corr_dfs[\"with_gexp\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ShapiroResult(statistic=0.9753528740783312, pvalue=0.00012429713994164323)"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "shapiro(\n",
    "    gexp_corr_out_sample_dfs.query(\"model == 'MOSA_7omics'\")[\"corr\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TtestResult(statistic=21.43230547868578, pvalue=3.3627101251530766e-74, df=540.0)\n",
      "WilcoxonResult(statistic=65.0, pvalue=6.947921853872344e-46)\n"
     ]
    }
   ],
   "source": [
    "print(ttest_ind(\n",
    "    gexp_corr_out_sample_dfs.query(\"model == 'MOSA_7omics'\")[\"corr\"],\n",
    "    gexp_corr_out_sample_dfs.query(\"model == 'MOSA_2omics'\")[\"corr\"],\n",
    "))\n",
    "print(wilcoxon(\n",
    "    gexp_corr_out_sample_dfs.query(\"model == 'MOSA_7omics'\")[\"corr\"],\n",
    "    gexp_corr_out_sample_dfs.query(\"model == 'MOSA_2omics'\")[\"corr\"],\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TtestResult(statistic=18.24570154229317, pvalue=2.6836218299727367e-58, df=540.0)\n",
      "WilcoxonResult(statistic=74.0, pvalue=7.675247614114657e-46)\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    ttest_ind(\n",
    "        gexp_corr_out_sample_dfs.query(\"model == 'MOSA_7omics'\")[\"corr\"],\n",
    "        gexp_corr_out_sample_dfs.query(\"model == 'MOFA'\")[\"corr\"],\n",
    "    )\n",
    ")\n",
    "print(\n",
    "    wilcoxon(\n",
    "        gexp_corr_out_sample_dfs.query(\"model == 'MOSA_7omics'\")[\"corr\"],\n",
    "        gexp_corr_out_sample_dfs.query(\"model == 'MOFA'\")[\"corr\"],\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TtestResult(statistic=-2.282000726313907, pvalue=0.022877465018056837, df=540.0)\n",
      "WilcoxonResult(statistic=7739.0, pvalue=1.2630246307937823e-16)\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    ttest_ind(\n",
    "        gexp_corr_out_sample_dfs.query(\"model == 'MOSA_2omics'\")[\"corr\"],\n",
    "        gexp_corr_out_sample_dfs.query(\"model == 'MOFA'\")[\"corr\"],\n",
    "    )\n",
    ")\n",
    "print(\n",
    "    wilcoxon(\n",
    "        gexp_corr_out_sample_dfs.query(\"model == 'MOSA_2omics'\")[\"corr\"],\n",
    "        gexp_corr_out_sample_dfs.query(\"model == 'MOFA'\")[\"corr\"],\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TtestResult(statistic=29.191437824112302, pvalue=1.6706086209243048e-85, df=270)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ttest_rel(\n",
    "    gexp_corr_out_sample_dfs.query(\"model == 'MOSA_7omics'\")[\"corr\"],\n",
    "    gexp_corr_out_sample_dfs.query(\"model == 'JAMIE'\")[\"corr\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Drug"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_df = pd.read_csv(\n",
    "    \"./reports/vae/drugresponse/20240830_110319_predicted_ctd2_corr.csv\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_out_df = plot_df[plot_df[\"MOSA_outofsample\"] == \"Out-of-sample\"]\n",
    "plot_in_df = plot_df[plot_df[\"MOSA_outofsample\"] == \"In-sample\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TtestResult(statistic=3.44106965740751, pvalue=0.0006309014889326985, df=472.0)\n",
      "WilcoxonResult(statistic=3480.0, pvalue=8.939902145451377e-24)\n"
     ]
    }
   ],
   "source": [
    "print(ttest_ind(\n",
    "    plot_out_df[plot_out_df[\"method\"] == \"MOSA_corr\"][\"corr\"],\n",
    "    plot_out_df[plot_out_df[\"method\"] == \"MOFA_corr\"][\"corr\"],\n",
    "))\n",
    "print(wilcoxon(\n",
    "    plot_out_df[plot_out_df[\"method\"] == \"MOSA_corr\"][\"corr\"],\n",
    "    plot_out_df[plot_out_df[\"method\"] == \"MOFA_corr\"][\"corr\"],\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TtestResult(statistic=-0.10472369634495242, pvalue=0.916639544284535, df=472.0)\n",
      "WilcoxonResult(statistic=13553.0, pvalue=0.6036728652783327)\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    ttest_ind(\n",
    "        plot_out_df[plot_out_df[\"method\"] == \"MOSA_corr\"][\"corr\"],\n",
    "        plot_out_df[plot_out_df[\"method\"] == \"MOSA_7omics_corr\"][\"corr\"],\n",
    "    )\n",
    ")\n",
    "print(\n",
    "    wilcoxon(\n",
    "        plot_out_df[plot_out_df[\"method\"] == \"MOSA_corr\"][\"corr\"],\n",
    "        plot_out_df[plot_out_df[\"method\"] == \"MOSA_7omics_corr\"][\"corr\"],\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mosa",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}