[aca2dc]: / neuralcvd / postprocessing / 3_benchmarks_centres.ipynb

Download this file

284 lines (283 with data), 10.6 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "from tqdm.auto import tqdm\n",
    "import pathlib\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import lifelines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dask.distributed import Client, LocalCluster\n",
    "cluster = LocalCluster(n_workers=10, threads_per_worker=5)\n",
    "client = Client(cluster)\n",
    "cluster.scheduler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "project_name = \"210616_centres_dask\"\n",
    "data_path = \"/data/analysis/ag-reils/steinfej\"\n",
    "data_pre = f\"{data_path}/data/2_datasets_pre/{project_name}\"\n",
    "data_post = f\"{data_path}/data/3_datasets_post/{project_name}\"\n",
    "\n",
    "project_label = \"21_PGS_Revision\"\n",
    "project_path = f\"/data/analysis/ag-reils/ag-reils-shared/cardioRS/results/projects/{project_label}\"\n",
    "figures_path = f\"{project_path}/figures\"\n",
    "data_results_path = f\"{project_path}/data\"\n",
    "pathlib.Path(figures_path).mkdir(parents=True, exist_ok=True)\n",
    "pathlib.Path(data_results_path).mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data =  pd.read_feather(f\"{data_post}/data_merged.feather\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "endpoints = ['MACE']\n",
    "endpoint_labels = sorted([f\"{e}_event\" for e in endpoints]+[f\"{e}_event_time\" for e in endpoints])\n",
    "endpoint_data =  pd.read_feather(f\"{data_post}/data_merged.feather\", columns=[\"eid\"] + endpoint_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = pd.read_feather(f\"{data_results_path}/predictions_210703_centres_FINAL.feather\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bootstrapping or even recruitment centers?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_test = preds[['eid','endpoint', 'module','features','split','partition','Ft_10']].query(\"split=='test'\")\n",
    "data_test\n",
    "\n",
    "modules = data_test.module.unique().tolist()\n",
    "features = data_test.features.unique().tolist()\n",
    "partitions = data_test.partition.unique().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iterations=[i for i in range(1000)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "endpoint=\"MACE\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output\n",
    "from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc, integrated_brier_score\n",
    "from lifelines.utils import concordance_index\n",
    "from dask.diagnostics import ProgressBar\n",
    "\n",
    "def calculate_per_endpoint(df, partition, iteration, endpoint, module, feature, time):  \n",
    "    event = [0 if (endpoint_event == 0) | (endpoint_event_time > time) else 1 for endpoint_event, endpoint_event_time in zip(df[endpoint+\"_event\"], df[endpoint+\"_event_time\"])]\n",
    "    event_time = [time if (endpoint_event == 0) | (endpoint_event_time > time) else endpoint_event_time for endpoint_event, endpoint_event_time in zip(df[endpoint+\"_event\"], df[endpoint+\"_event_time\"])]\n",
    "    df = df.assign(event = event, event_time = event_time)\n",
    "    df = df.dropna(subset=[\"event_time\", f\"Ft_{time}\", \"event\"], axis=0)\n",
    "    \n",
    "    cindex = 1-concordance_index(df[\"event_time\"], df[f\"Ft_{time}\"], df[\"event\"])\n",
    "    return {\"endpoint\":endpoint, \"module\": module, \"features\": feature, \"partition\":partition, \"iteration\":iteration, \"n\": len(df), \"time\":time, \"cindex\":cindex}\n",
    "\n",
    "def calculate_per_iteration(data_bm, endpoint, iteration, eids_bs, time):  \n",
    "    results = []\n",
    "    for module in modules: \n",
    "        temp_module = data_bm.query(\"module==@module\").set_index(\"eid\").loc[eids_bs].reset_index()\n",
    "        for feature in features:\n",
    "            temp_features = temp_module.query(\"features==@feature\")\n",
    "            if len(temp_features)>0:\n",
    "                for partition in partitions:\n",
    "                    temp_partition = temp_features.query(\"partition==@partition\")[[\"eid\", \"Ft_10\", \"MACE_event\", \"MACE_event_time\"]]\n",
    "                    results.append(calculate_per_endpoint(temp_partition, partition, iteration, endpoint, module, feature, time=10))\n",
    "    return results\n",
    "\n",
    "data_bm = data_test.query(\"endpoint==@endpoint\").merge(endpoint_data[[\"eid\", f\"{endpoint}_event\", f\"{endpoint}_event_time\"]], on=\"eid\", how=\"left\")\n",
    "eids = data_bm.eid.unique()\n",
    "with ProgressBar():\n",
    "    rows = []\n",
    "    for iteration in tqdm(iterations):\n",
    "        eids_bs = np.random.choice(eids, size=len(eids))\n",
    "        data_future = client.scatter(data_bm)\n",
    "        rows.append(client.submit(calculate_per_iteration, data_future, endpoint, iteration, eids_bs, time=10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dask.distributed import progress\n",
    "progress(rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rows = client.gather(rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rows = [item for sublist in rows for item in sublist]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark_endpoints_pp = pd.DataFrame({\"endpoint\":[], \"module\": [], \"features\": [], \"partition\": [], \"iteration\": [], \"n\":[],\"time\": [], \"cindex\": []}).append(rows, ignore_index=True)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc, integrated_brier_score\n",
    "\n",
    "def to_struc_array(X): return np.core.records.fromarrays(X.values.transpose(), names='event, event_time', formats = \"?, f8\")\n",
    "def to_struc_array_estimate(X): return np.core.records.fromarrays(X.values.transpose(), names='estimate', formats = \"f8\")\n",
    "\n",
    "data_censoring = data_train.loc[:,[\"eid\", \"MACE_event\", \"MACE_event_time\"]].drop_duplicates().reset_index(drop=True)\n",
    "survival_train = to_struc_array(data_censoring.loc[:,[\"MACE_event\", \"MACE_event_time\"]].dropna())\n",
    "survival_test = data_test.loc[:,[\"MACE_event\", \"MACE_event_time\"]+[f\"score_{score}\" for score in scores]].dropna()\n",
    "\n",
    "def calculate_per_endpoint(df, partition, endpoint, survival_train, survival_test, score):  \n",
    "    event = [0 if (endpoint_event == 0) | (endpoint_event_time > 10) else 1 for endpoint_event, endpoint_event_time in zip(df[endpoint+\"_event\"], df[endpoint+\"_event_time\"])]\n",
    "    event_time = [10 if (endpoint_event == 0) | (endpoint_event_time > 10) else endpoint_event_time for endpoint_event, endpoint_event_time in zip(df[endpoint+\"_event\"], df[endpoint+\"_event_time\"])]\n",
    "    df = df.assign(event = event, event_time = event_time)\n",
    "    df = df.dropna(subset=[\"event_time\", \"score_\" + score, \"event\"], axis=0)\n",
    "    \n",
    "    try: cindex = 1-concordance_index(df[\"event_time\"], df[\"score_\" + score], df[\"event\"])\n",
    "    except: cindex=np.nan\n",
    "        \n",
    "    survival = to_struc_array(survival_test.loc[:,[\"MACE_event\", \"MACE_event_time\"]])\n",
    "    estimate = survival_test.loc[:,[f\"score_{score}\"]].T.values[0] \n",
    "    \n",
    "    cindex_ipcw = concordance_index_ipcw(survival_train, survival, estimate)[0]\n",
    "    brier_score_ipcw = brier_score(survival_train, survival, estimate, 10)[1][0]\n",
    "    cum_dyn_auc = cumulative_dynamic_auc(survival_train, survival, estimate, 10)[0][0]\n",
    "    \n",
    "    return {\"endpoint\":endpoint, \"score\":score, \"partition\":partition, \"cindex\":cindex, \"cindex_ipcw\":cindex_ipcw, \"brier_score_ipcw\": brier_score_ipcw, \"cum_dyn_auc\": cum_dyn_auc}\n",
    "        \n",
    "rows = Parallel(n_jobs=10)(delayed(calculate_per_endpoint)(\n",
    "    data_test[data_test.partition==int(partition)], int(partition), endpoint, survival_train, survival_test, score) for partition in tqdm(partitions) for score in tqdm(scores, leave=False) for endpoint in endpoints)\n",
    "benchmark_endpoints_pp = pd.DataFrame({\"endpoint\":[], \"score\":[], \"partition\":[], \"cindex\":[], \"cindex_ipcw\": [], \"brier_score_ipcw\": [], \"cum_dyn_auc\": []}).append(rows, ignore_index=True)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "name = \"benchmark_cindex_MACE_210705_centres_FINAL\"\n",
    "benchmark_endpoints_pp.to_feather(f\"{data_results_path}/{name}.feather\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:miniconda3-pl1.x]",
   "language": "python",
   "name": "conda-env-miniconda3-pl1.x-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  },
  "toc-autonumbering": false
 },
 "nbformat": 4,
 "nbformat_minor": 4
}