{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Benchmarks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "from tqdm.auto import tqdm\n",
    "import pathlib\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import lifelines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%env MKL_NUM_THREADS=1\n",
    "%env NUMEXPR_NUM_THREADS=1\n",
    "%env OMP_NUM_THREADS=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dask.distributed import Client, LocalCluster\n",
    "cluster = LocalCluster(n_workers=50, threads_per_worker=1)\n",
    "client = Client(cluster)\n",
    "cluster.scheduler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "project_name = \"210616_centres_dask\"\n",
    "data_path = \"/data/analysis/ag-reils/steinfej\"\n",
    "data_pre = f\"{data_path}/data/2_datasets_pre/{project_name}\"\n",
    "data_post = f\"{data_path}/data/3_datasets_post/{project_name}\"\n",
    "\n",
    "project_label = \"21_PGS_Revision\"\n",
    "project_path = f\"/data/analysis/ag-reils/ag-reils-shared/cardioRS/results/projects/{project_label}\"\n",
    "figures_path = f\"{project_path}/figures\"\n",
    "data_results_path = f\"{project_path}/data\"\n",
    "pathlib.Path(figures_path).mkdir(parents=True, exist_ok=True)\n",
    "pathlib.Path(data_results_path).mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "endpoints = ['MACE']\n",
    "endpoint_labels = sorted([f\"{e}_event\" for e in endpoints]+[f\"{e}_event_time\" for e in endpoints])\n",
    "endpoint_data =  pd.read_feather(f\"{data_post}/data_merged.feather\", columns=[\"eid\"] + endpoint_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_temp = pd.read_feather(f\"{data_post}/data_merged.feather\")\n",
    "eids_endpoint = {}\n",
    "for endpoint in tqdm(endpoints):\n",
    "    if endpoint == \"MACE\": eids_incl = data_temp.copy().query(f\"myocardial_infarction==False&stroke==False&statins==False\").eid.to_list()\n",
    "    else: eids_incl=data_temp.copy().eid.to_list()\n",
    "    print(endpoint, len(eids_incl))\n",
    "    eids_endpoint[endpoint] = eids_incl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "partitions = [i for i in range(22)]\n",
    "splits = [\"train\", \"valid\", \"test\"]\n",
    "eids_partition_split = {partition: {split: pd.read_feather(f\"{data_post}/partition_{partition}/{split}/data.feather\").eid.to_list() for split in splits} for partition in tqdm(partitions)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_float32(df):\n",
    "    for col in tqdm(df.columns.to_list()):\n",
    "        if df[col].dtype == \"float64\": \n",
    "            print(col, \"convert\")\n",
    "            df[col]= df[col].astype(\"float32\")\n",
    "    return df\n",
    "\n",
    "def convert_to_category(df):\n",
    "    for col in tqdm(df.columns.to_list()):\n",
    "        if df[col].dtype == \"object\": df[col]= df[col].astype(\"category\")\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_models = pd.read_feather(f\"{data_results_path}/predictions_model_210703_FINAL.feather\").query(\"endpoint=='MACE'\")#\")#&module=='DSM'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_cox = pd.read_feather(f\"{data_results_path}/predictions_cox_210631_REVISION.feather\").query(\"endpoint=='MACE'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_scores = pd.read_csv(f\"{data_results_path}/predictions_scores_210616.csv\")#.assign(endpoint=\"MACE\")#.query(\"endpoint=='MACE'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_scores_list = []\n",
    "for endpoint in endpoints:\n",
    "    for partition in partitions:\n",
    "        for split in splits:\n",
    "            eids_incl = eids_endpoint[endpoint]\n",
    "            eids_split = eids_partition_split[partition][split]\n",
    "            data_temp = preds_scores.query(\"eid==@eids_incl\").query(\"eid==@eids_split\").assign(endpoint=endpoint, partition=partition, split=split)\n",
    "            print(endpoint, partition, split, len(data_temp))\n",
    "            preds_scores_list.append(data_temp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_scores = pd.concat(preds_scores_list, axis=0).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eids_incl = eids_endpoint[\"MACE\"]\n",
    "preds = pd.concat([preds_models, preds_cox, preds_scores], axis=0)#.query(\"eid==@eids_incl\")"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "preds = preds_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del preds[\"Ft_0\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del preds_models\n",
    "del preds_cox\n",
    "del preds_scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calibration DF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_cols = {t: f\"Ft_{t}\" for t in range(1,26)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_train = preds.query(\"split=='train'\")\n",
    "data_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_fit = data_train[[\"eid\"]+[\"endpoint\", \"module\", \"features\", \"partition\"] + list(time_cols.values())]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modules = data_train.module.unique().tolist()\n",
    "features = data_train.features.unique().tolist()\n",
    "partitions = data_train.partition.unique().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lifelines import CRCSplineFitter\n",
    "def get_observed_probability(F_t, events, durations, t):\n",
    "    def ccl(p): return np.log(-np.log(1 - p))\n",
    "    T = \"time\"\n",
    "    E = \"event\"\n",
    "    predictions_at_t = np.clip(F_t, 1e-5, 1 - 1e-5)\n",
    "    prediction_df = pd.DataFrame({f\"ccl_at_{t}\": ccl(predictions_at_t), T: durations, E: events})\n",
    "\n",
    "    index_old = prediction_df.index\n",
    "    prediction_df = prediction_df.dropna()\n",
    "    index_new = prediction_df.index\n",
    "    diff = index_old.difference(index_new)\n",
    "\n",
    "    knots=3\n",
    "    crc = CRCSplineFitter(knots, penalizer=0.00001)\n",
    "    \n",
    "    regressors = {\"beta_\": [f\"ccl_at_{t}\"], **{f\"gamma{i}_\":\"1\" for i in range(knots)}}\n",
    "    crc.fit_right_censoring(prediction_df, T, E, regressors=regressors)\n",
    "    \n",
    "    risk_obs = (1 - crc.predict_survival_function(prediction_df, times=[t])).T.squeeze()\n",
    "    return risk_obs, diff.to_list()"
   ]
  },
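  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function above estimates the *observed* risk at horizon `t` by regressing the complementary log-log of the predicted risks in a flexible spline survival model (the approach used by lifelines' `survival_probability_calibration`). The next cell is an optional, self-contained usage sketch on simulated data; all names (`true_risk`, `observed_time`, ...) are made up for illustration and are not part of the pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: call get_observed_probability on simulated survival data.\n",
    "# With well-calibrated inputs, the returned observed risks should track the predictions.\n",
    "rng = np.random.default_rng(0)\n",
    "n = 1000\n",
    "true_risk = rng.uniform(0.02, 0.3, n)                      # hypothetical predicted 10-year risks\n",
    "event_time = rng.exponential(10 / -np.log(1 - true_risk))  # event times consistent with those risks\n",
    "censor_time = rng.uniform(5, 15, n)\n",
    "observed_time = np.minimum(event_time, censor_time)\n",
    "event = (event_time <= censor_time).astype(int)\n",
    "\n",
    "risk_obs_sim, dropped_idx = get_observed_probability(true_risk, event, observed_time, 10)\n",
    "risk_obs_sim.describe()"
   ]
  },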
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.isotonic import IsotonicRegression\n",
    "\n",
    "def fit_isotonic_regression(df, endpoint, time_cols):\n",
    "    isoreg_dict = {}\n",
    "    for t, col in tqdm(time_cols.items()):\n",
    "        if df[f\"Ft_{t}\"].nunique()>1:\n",
    "            risk_obs = pd.Series(np.nan)\n",
    "            i=1\n",
    "            while risk_obs.nunique()<=1: \n",
    "                df_temp = df.sample(frac=0.5, replace=True).dropna(subset=[col])\n",
    "                try:\n",
    "                    risk_obs, nan = get_observed_probability(df_temp[col].values, \n",
    "                                                                df_temp[f\"{endpoint}_event\"].values, \n",
    "                                                                df_temp[f\"{endpoint}_event_time\"].values,\n",
    "                                                                t)\n",
    "                except:\n",
    "                    if i==20: break\n",
    "                    i+=1\n",
    "            risk_pred = df_temp.drop(df_temp.index[nan])[col].reset_index(drop=True)\n",
    "            \n",
    "            isoreg_dict[t] = IsotonicRegression().fit(risk_pred.values,risk_obs.values)\n",
    "        else: isoreg_dict[t] = None\n",
    "    return isoreg_dict"
   ]
  },
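  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the final calibration step (illustrative, not part of the pipeline): isotonic regression learns a monotone map from predicted to observed risk, which can then be applied to new predictions. `toy_pred` and `toy_obs` are made-up arrays standing in for the bootstrap estimates produced above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: fit an isotonic map from predicted to observed risk and apply it.\n",
    "rng = np.random.default_rng(42)\n",
    "toy_pred = np.sort(rng.uniform(0.0, 0.3, 500))                             # hypothetical predicted risks\n",
    "toy_obs = np.clip(toy_pred * 0.7 + 0.02 + rng.normal(0, 0.01, 500), 0, 1)  # hypothetical observed risks\n",
    "\n",
    "iso = IsotonicRegression(out_of_bounds=\"clip\").fit(toy_pred, toy_obs)\n",
    "iso.predict(np.array([0.05, 0.1, 0.2]))  # recalibrated risks for three new predictions"
   ]
  },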
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dask.diagnostics import ProgressBar\n",
    "\n",
    "with ProgressBar():\n",
    "    ir_dict = {}\n",
    "    for endpoint in endpoints:\n",
    "        temp = data_fit.query(\"endpoint==@endpoint\").merge(endpoint_data[[\"eid\", f\"{endpoint}_event\", f\"{endpoint}_event_time\"]], on=\"eid\", how=\"left\")\n",
    "        for module in tqdm(modules):\n",
    "            temp_module = temp.query(\"module==@module\")\n",
    "            for feature in tqdm(features, leave=False):\n",
    "                temp_feature = temp_module.query(\"features==@feature\").drop([\"endpoint\", \"module\", \"features\"], axis=1)\n",
    "                if len(temp_feature)>0:\n",
    "                    for partition in tqdm(partitions, leave=False):                  \n",
    "                        temp_partition = temp_feature.query(\"partition==@partition\").drop([\"partition\"], axis=1)\n",
    "                        if len(temp_partition)>0:\n",
    "                            data_future = client.scatter(temp_partition)\n",
    "                            ir_dict[f\"{endpoint}_{module}_{feature}_{partition}\"] = client.submit(fit_isotonic_regression, data_future, endpoint, time_cols=time_cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dask.distributed import progress\n",
    "progress(ir_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ir_dict = client.gather(ir_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_calib_list = []\n",
    "\n",
    "for endpoint in endpoints:\n",
    "    temp = preds.query(\"endpoint==@endpoint\")\n",
    "    for module in tqdm(modules):\n",
    "        temp_module = temp.query(\"module==@module\")\n",
    "        for feature in tqdm(features):\n",
    "            temp_feature = temp_module.query(\"features==@feature\")\n",
    "            if len(temp_feature)>0:\n",
    "                for partition in tqdm(partitions): \n",
    "                    if module!=\"SCORE\":\n",
    "                        temp_partition = temp_feature.query(\"partition==@partition\").dropna(subset=list(time_cols.values())) \n",
    "                    else: \n",
    "                        temp_partition = temp_feature.query(\"partition==@partition\")\n",
    "                    isoreg_str = f\"{endpoint}_{module}_{feature}_{partition}\"\n",
    "                    if len(temp_partition)>0:\n",
    "                        if isoreg_str in ir_dict:\n",
    "                            isoreg_dict = ir_dict[isoreg_str]\n",
    "                            for t, col in time_cols.items(): \n",
    "                                if isinstance(isoreg_dict, dict):\n",
    "                                    isoreg = isoreg_dict[t]\n",
    "                                    if isinstance(isoreg, IsotonicRegression):\n",
    "                                        calibrated = isoreg_dict[t].predict(temp_partition[col])\n",
    "                                        #if np.isnan(np.sum(calibrated)): print(np.sum(np.isnan(calibrated)), \"NaNs in calibrated probabilites\")\n",
    "                                        temp_partition[col] = calibrated\n",
    "                                    else:\n",
    "                                        print(\"Not working t\", endpoint, module, feature, isoreg_str, t)\n",
    "                                        temp_partition[col] = np.nan\n",
    "                                else:\n",
    "                                    print(\"Not working\", endpoint, module, feature, isoreg_str, t)\n",
    "                                    temp_partition[col] = np.nan\n",
    "                            preds_calib_list.append(temp_partition)\n",
    "                        else:\n",
    "                            print(isoreg_str, \"not available\")\n",
    "                            pass\n",
    "                    else:\n",
    "                        print(\"No data available for\", isoreg_str)\n",
    "                        pass\n",
    "            else:\n",
    "                print(f\"No data available for {endpoint}_{module}_{feature}\")\n",
    "                pass\n",
    "from IPython.display import clear_output\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_calib = pd.concat(preds_calib_list, axis=0).assign(calibrated=True).query(\"eid==@eids_incl\").reset_index(drop=True)"
   ]
  },
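  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional sanity check (illustrative): compare the mean calibrated risk to the crude observed event rate within risk deciles on the test split. This ignores censoring, so it is only a rough check; the spline-based estimates above are the proper ones. `Ft_10` and the MACE endpoint are used as examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: crude decile-wise reliability check of the calibrated 10-year predictions.\n",
    "check = preds_calib.query(\"split=='test' & endpoint=='MACE'\").merge(\n",
    "    endpoint_data[[\"eid\", \"MACE_event\", \"MACE_event_time\"]], on=\"eid\", how=\"left\").dropna(subset=[\"Ft_10\"])\n",
    "check[\"risk_decile\"] = pd.qcut(check[\"Ft_10\"], 10, labels=False, duplicates=\"drop\")\n",
    "check.groupby(\"risk_decile\").agg(mean_predicted=(\"Ft_10\", \"mean\"), observed_rate=(\"MACE_event\", \"mean\"), n=(\"eid\", \"size\"))"
   ]
  },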
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_calib.to_feather(f\"{data_results_path}/predictions_210703_centres_FINAL.feather\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pl1.x_revision]",
   "language": "python",
   "name": "conda-env-pl1.x_revision-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  },
  "toc-autonumbering": false
 },
 "nbformat": 4,
 "nbformat_minor": 4
}