[aca2dc]: / neuralcvd / postprocessing / 1_collect_results_models.ipynb

Download this file

342 lines (341 with data), 11.7 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Benchmarks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "from tqdm.auto import tqdm\n",
    "import pathlib\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import lifelines\n",
    "import pandas as pd\n",
    "\n",
    "from joblib import Parallel, delayed\n",
    "from tqdm.notebook import tqdm\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import shutil\n",
    "\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "from plotly.graph_objects import Box\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from lifelines import CRCSplineFitter\n",
    "import warnings\n",
    "from lifelines.utils import CensoringType\n",
    "\n",
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "import math\n",
    "\n",
    "from IPython.display import clear_output\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "from lifelines.utils import concordance_index\n",
    "\n",
    "from pycox.evaluation import EvalSurv\n",
    "# Never commit credentials to the notebook: the Neptune API token is read from the\n",
    "# NEPTUNE_API_TOKEN environment variable. The token that used to be hard-coded here has\n",
    "# been exposed and should be revoked.\n",
    "api_token = os.environ.get(\"NEPTUNE_API_TOKEN\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dask.distributed import Client, LocalCluster\n",
    "# Start a local Dask cluster with a single worker running 10 threads and connect a client to it.\n",
    "# NOTE(review): the joblib 'dask' backend used further down presumably picks up this client - confirm.\n",
    "cluster = LocalCluster(n_workers=1, threads_per_worker=10)\n",
    "client = Client(cluster)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#from utils import get_best_experiment_predictions\n",
    "import time\n",
    "from IPython.display import display, HTML\n",
    "def get_best_experiment(df, module=\"CoxPH\", datamodule=\"UKBBDataModule\", endpoint=\"M_MACE\", feature_set=\"AgeSex\", prediction_available=None, partition=\"0\"):\n",
    "    \"\"\"Pick the run with the highest checkpoint value among the runs matching the filters.\n",
    "\n",
    "    Args:\n",
    "        df: runs table with the columns 'parameters/module', 'parameters/datamodule',\n",
    "            'parameters/train_targets', 'parameters/feature_set', 'parameters/cv_partition',\n",
    "            'logs/checkpoint_value' and 'sys/id'.\n",
    "        module, datamodule, endpoint: patterns matched with Series.str.contains\n",
    "            (substring/regex match, not equality).\n",
    "        feature_set: exact value required in 'parameters/feature_set'.\n",
    "        prediction_available: if truthy, keep only runs whose 'logs/prediction_available'\n",
    "            equals \"TRUE\"; if that column is missing a message is printed and the table\n",
    "            is left unfiltered.\n",
    "        partition: cross-validation partition, compared as an integer.\n",
    "\n",
    "    Returns:\n",
    "        (experiment_name, df): the 'sys/id' of the best run and the filtered table with\n",
    "        'logs/checkpoint_value' converted to float, or (None, None) if nothing matches.\n",
    "    \"\"\"\n",
    "    if df.empty:\n",
    "        print(\"No experiments with given tag!\")\n",
    "        return None, None\n",
    "\n",
    "    # na=False: rows with a missing value simply do not match instead of breaking the boolean mask.\n",
    "    # NOTE(review): str.contains is a substring match, so module=\"DeepSurvivalMachine\" also selects\n",
    "    # runs whose module is \"OLDDeepSurvivalMachine\" - confirm this is intended.\n",
    "    df = df[df['parameters/module'].str.contains(module, na=False)]\n",
    "    df = df[df['parameters/datamodule'].str.contains(datamodule, na=False)]\n",
    "    df = df[df['parameters/train_targets'].str.contains(endpoint, na=False)]\n",
    "    df = df[df['parameters/feature_set'] == feature_set]\n",
    "    df = df[df['parameters/cv_partition'].astype(float).astype(int) == int(partition)]\n",
    "    # Only runs that logged a checkpoint value can be ranked.\n",
    "    df = df[df['logs/checkpoint_value'].notna()]\n",
    "    if df.empty:\n",
    "        print(\"No relevant experiment completed!\")\n",
    "\n",
    "    if prediction_available:\n",
    "        try:\n",
    "            df = df[df['logs/prediction_available'] == \"TRUE\"]\n",
    "        except KeyError:\n",
    "            # The column does not exist in the runs table: report it and continue unfiltered.\n",
    "            print(\"No predictions!\")\n",
    "            display(df)\n",
    "\n",
    "    if df.empty:\n",
    "        print(\"No predictions!\")\n",
    "        return None, None\n",
    "\n",
    "    # Convert on a fresh frame instead of writing into a filtered slice; this guarantees a\n",
    "    # float column (needed by argmax) and avoids chained-assignment problems.\n",
    "    checkpoint_values = df['logs/checkpoint_value'].astype(float)\n",
    "    df = df.assign(**{'logs/checkpoint_value': checkpoint_values})\n",
    "    experiment_name = df[\"sys/id\"].iloc[checkpoint_values.argmax()]\n",
    "    return experiment_name, df\n",
    "\n",
    "def get_best_experiment_paths(df, module=\"CoxPH\", datamodule=\"UKBBDataModule\", endpoint=\"M_MACE\", feature_set=\"AgeSex\", partition=\"0\", api_token=None):\n",
    "    \"\"\"Return (checkpoint_path, prediction_path) of the best matching run.\n",
    "\n",
    "    The best run is selected by get_best_experiment with prediction_available=True.\n",
    "    Returns (None, None) when no run qualifies. `api_token` is accepted but not used.\n",
    "    \"\"\"\n",
    "    best_id, runs = get_best_experiment(df, module, datamodule, endpoint, feature_set, True, partition)\n",
    "    if best_id is None:\n",
    "        return None, None\n",
    "\n",
    "    best_run = runs[runs[\"sys/id\"] == best_id]\n",
    "    checkpoint = best_run['logs/checkpoint_path'].values[0]\n",
    "    prediction = best_run['logs/prediction_path'].values[0]\n",
    "    # A missing (None) prediction path is handed back unchanged.\n",
    "    return checkpoint, prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input data locations, derived from a hard-coded absolute base path.\n",
    "project_name = \"210616_centres_dask\"\n",
    "data_path = \"/data/analysis/ag-reils/steinfej\"\n",
    "data_pre = f\"{data_path}/data/2_datasets_pre/{project_name}\"\n",
    "data_post = f\"{data_path}/data/3_datasets_post/{project_name}\"\n",
    "\n",
    "# Output locations for figures and result tables of this project.\n",
    "project_label = \"21_PGS_Revision\"\n",
    "project_path = f\"/data/analysis/ag-reils/ag-reils-shared/cardioRS/results/projects/{project_label}\"\n",
    "figures_path = f\"{project_path}/figures\"\n",
    "data_results_path = f\"{project_path}/data\"\n",
    "# Create both output directories (including parents) if they do not exist yet.\n",
    "pathlib.Path(figures_path).mkdir(parents=True, exist_ok=True)\n",
    "pathlib.Path(data_results_path).mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Endpoint patterns; each one is matched as a substring of 'parameters/train_targets' when selecting runs.\n",
    "endpoints = ['MACE']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-validation partitions 0..21, kept as strings.\n",
    "partitions = [str(p) for p in range(0, 22)]\n",
    "# NOTE(review): `splits` is not referenced by any other cell of this notebook.\n",
    "splits = [\"train\", \"valid\", \"test\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The project name is kept separate from the `project` handle created below, so that\n",
    "# re-running this cell does not pass an already-resolved project object to get_project.\n",
    "neptune_project_name = \"CardioRS/benchmarks\"\n",
    "#modules = [\"CoxPH\", \"DeepHit\", \"DeepSurvivalMachine\"]\n",
    "modules = [\"OLDDeepSurvivalMachine\", \"DeepSurvivalMachine\"]\n",
    "datamodules = [\"UKBBSurvivalDatamodule\"]\n",
    "feature_sets = [\"CVDCoreVariablesREORDERED\", \"CVDCoreVariablesWithPGS\"]\n",
    "tag=\"210701_REVISION_NEWDSM_CENTERS_fr_6\"\n",
    "import neptune.new as neptune\n",
    "# Fetch the table of all runs carrying `tag` from the Neptune project as a DataFrame.\n",
    "project = neptune.get_project(neptune_project_name, api_token)\n",
    "df_new = project.fetch_runs_table(tag=tag).to_pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the freshly fetched runs table and `df_old` row-wise.\n",
    "# FIXME(review): `df_old` is not defined in any cell of this notebook, so this cell raises\n",
    "# NameError on a fresh kernel (Restart & Run All). Load/define it above or drop it here.\n",
    "df = pd.concat([df_new, df_old], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_predictions(df, module, datamodule, endpoint, feature_set, partition):\n",
    "    \"\"\"Return the prediction file path of the best run for one configuration.\n",
    "\n",
    "    Thin wrapper around get_best_experiment_paths that drops the checkpoint path;\n",
    "    the result is None when no matching run is found.\n",
    "    \"\"\"\n",
    "    # TODO: maybe hand over a dict of filters instead of individual arguments.\n",
    "    best_paths = get_best_experiment_paths(\n",
    "        df=df,\n",
    "        module=module,\n",
    "        datamodule=datamodule,\n",
    "        endpoint=endpoint,\n",
    "        feature_set=feature_set,\n",
    "        partition=partition,\n",
    "    )\n",
    "    return best_paths[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve one prediction path per (endpoint, feature_set, partition, module, datamodule)\n",
    "# combination; entries are None when no matching run exists. Runs with n_jobs=1.\n",
    "# The progress bar wraps only the outermost loop over `endpoints`.\n",
    "pred_paths = Parallel(n_jobs=1)(delayed(get_model_predictions)(df, module, datamodule, endpoint, feature_set, partition) \n",
    "                                for endpoint in tqdm(endpoints) \n",
    "                                for feature_set in feature_sets\n",
    "                                for partition in partitions\n",
    "                                for module in modules \n",
    "                                for datamodule in datamodules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import joblib\n",
    "def get_df(path):\n",
    "    \"\"\"Read one prediction table from the feather file at `path`.\"\"\"\n",
    "    return pd.read_feather(path)\n",
    "\n",
    "# Keep only usable paths (neither None nor NaN), so that the reported status and the files\n",
    "# that are actually loaded refer to the same set of paths.\n",
    "valid_pred_paths = [path for path in pred_paths if path is not None and not pd.isna(path)]\n",
    "# max(..., 1) avoids a ZeroDivisionError when no configuration was requested at all.\n",
    "print(\"Status: \", len(valid_pred_paths) / max(len(pred_paths), 1))\n",
    "with joblib.parallel_backend('dask'):\n",
    "    dfs = Parallel(n_jobs=80)(delayed(get_df)(path) for path in tqdm(valid_pred_paths))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack all loaded prediction tables row-wise and renumber the index from 0.\n",
    "preds_all = pd.concat(dfs, axis=0).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clean_df(df, predictions):\n",
    "    \"\"\"Add readable 'features' and 'endpoint' columns to `predictions`.\n",
    "\n",
    "    `predictions` is modified in place and also returned.\n",
    "\n",
    "    Args:\n",
    "        df: runs table with 'parameters/feature_set' and 'parameters/train_features'.\n",
    "        predictions: table with the columns 'feature_names' and 'event_names'.\n",
    "    \"\"\"\n",
    "    pairs = df[[\"parameters/feature_set\", \"parameters/train_features\"]].drop_duplicates().dropna()\n",
    "    # train_features value -> feature_set name; if a key repeats, the last row wins.\n",
    "    set_by_train_features = dict(zip(pairs[\"parameters/train_features\"], pairs[\"parameters/feature_set\"]))\n",
    "    # Drop the first 2 and the last 8 characters of each event name.\n",
    "    # NOTE(review): presumably strips a list-style wrapper and an '_event' suffix - confirm.\n",
    "    short_endpoint = {name: name[2:-8] for name in predictions.event_names.unique().tolist()}\n",
    "    # Feature sets without an entry here end up as NaN in 'features'.\n",
    "    # NOTE(review): \"CVDCoreVariablesREORDERED\" has no entry - confirm this is intended.\n",
    "    readable_name = {\n",
    "        \"CVDCoreVariables\": \"clinical\",\n",
    "        \"CVDCoreVariablesWithPGS\": \"clinical_pgs_all\",\n",
    "    }\n",
    "\n",
    "    predictions[\"features\"] = predictions[\"feature_names\"].map(set_by_train_features).map(readable_name)\n",
    "    predictions[\"endpoint\"] = predictions[\"event_names\"].map(short_endpoint)\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clean_df adds the 'features' and 'endpoint' columns to `preds_all` in place and returns the\n",
    "# same object, so `preds_all_cleaned` and `preds_all` refer to one DataFrame.\n",
    "preds_all_cleaned = clean_df(df, preds_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The 26 time columns are named 0_1_Ft_native ... 0_26_Ft_native.\n",
    "time_cols = [f\"0_{t}_Ft_native\" for t in range(1, 27)]\n",
    "# Keep the identifier/metadata columns plus the time columns only.\n",
    "preds = preds_all_cleaned[[\"eid\", 'endpoint', 'features', 'split', 'partition', 'module', 'datamodule', 'net'] + time_cols]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert every object-dtype column to 'category' with a single astype call. This builds a\n",
    "# new frame instead of assigning column by column into `preds`, which was created by selecting\n",
    "# columns from another frame (chained assignment / SettingWithCopy).\n",
    "object_cols = [col for col in preds.columns if preds[col].dtype == \"object\"]\n",
    "preds = preds.astype({col: \"category\" for col in object_cols})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fix_column_names(df):\n",
    "    \"\"\"Rename every column whose name contains 'Ft' to Ft_0, Ft_1, ... in column order.\n",
    "\n",
    "    Returns the renamed frame; `df` itself is not modified.\n",
    "    \"\"\"\n",
    "    # rename and fix time bugs!!! -> 0_11_Ft -> Ft at t=10 -> fix earlier\n",
    "    ft_columns = [name for name in df.columns if \"Ft\" in name]\n",
    "    new_names = {old: f\"Ft_{position}\" for position, old in enumerate(ft_columns)}\n",
    "    return df.rename(columns=new_names)\n",
    "# Apply the renaming; afterwards the time columns are called Ft_0, Ft_1, ... in their original order.\n",
    "preds = fix_column_names(preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the final prediction table to the project's data directory in feather format.\n",
    "preds.to_feather(f\"{data_results_path}/predictions_model_210703_FINAL.feather\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:miniconda3-pl1.x]",
   "language": "python",
   "name": "conda-env-miniconda3-pl1.x-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  },
  "toc-autonumbering": false
 },
 "nbformat": 4,
 "nbformat_minor": 4
}