{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Benchmarks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import os\n",
"from tqdm.auto import tqdm\n",
"import pathlib\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import lifelines"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%env MKL_NUM_THREADS=1\n",
"%env NUMEXPR_NUM_THREADS=1\n",
"%env OMP_NUM_THREADS=1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dask.distributed import Client, LocalCluster\n",
"cluster = LocalCluster(n_workers=40, threads_per_worker=1)\n",
"client = Client(cluster)\n",
"cluster.scheduler"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"project_name = \"210616_centres_dask\"\n",
"data_path = \"/data/analysis/ag-reils/steinfej\"\n",
"data_pre = f\"{data_path}/data/2_datasets_pre/{project_name}\"\n",
"data_post = f\"{data_path}/data/3_datasets_post/{project_name}\"\n",
"\n",
"project_label = \"21_PGS_Revision\"\n",
"project_path = f\"/data/analysis/ag-reils/ag-reils-shared/cardioRS/results/projects/{project_label}\"\n",
"figures_path = f\"{project_path}/figures\"\n",
"data_results_path = f\"{project_path}/data\"\n",
"pathlib.Path(figures_path).mkdir(parents=True, exist_ok=True)\n",
"pathlib.Path(data_results_path).mkdir(parents=True, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = pd.read_feather(f\"{data_post}/data_merged.feather\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"endpoints = ['MACE']\n",
"endpoint_labels = sorted([f\"{e}_event\" for e in endpoints]+[f\"{e}_event_time\" for e in endpoints])\n",
"endpoint_data = pd.read_feather(f\"{data_post}/data_merged.feather\", columns=[\"eid\"] + endpoint_labels)"
]
},
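{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check before benchmarking, the cell below (a minimal sketch, assuming the `<endpoint>_event` 0/1 and `<endpoint>_event_time` in-years encoding used throughout this notebook) prints the observed MACE event rate and the follow-up time distribution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (assumes 0/1 event indicators and event times in years)\n",
"print(f\"MACE event rate: {endpoint_data['MACE_event'].mean():.3f}\")\n",
"endpoint_data[\"MACE_event_time\"].describe()"
]
},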
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"preds = pd.read_feather(f\"{data_results_path}/predictions_210703_centres_FINAL.feather\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_test = preds[['eid','endpoint', 'module','features','split','partition','Ft_10']].query(\"split=='test'\")\n",
"data_test\n",
"\n",
"modules = data_test.module.unique().tolist()\n",
"features = data_test.features.unique().tolist()\n",
"partitions = data_test.partition.unique().tolist()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"iterations=[i for i in range(1000)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import clear_output\n",
"from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc, integrated_brier_score\n",
"from lifelines.utils import concordance_index\n",
"from dask.diagnostics import ProgressBar\n",
"\n",
"def calculate_per_endpoint(df, endpoint, module, feature, iteration, time): \n",
" event = [0 if (endpoint_event == 0) | (endpoint_event_time > time) else 1 for endpoint_event, endpoint_event_time in zip(df[endpoint+\"_event\"], df[endpoint+\"_event_time\"])]\n",
" event_time = [time if (endpoint_event == 0) | (endpoint_event_time > time) else endpoint_event_time for endpoint_event, endpoint_event_time in zip(df[endpoint+\"_event\"], df[endpoint+\"_event_time\"])]\n",
" df = df.assign(event = event, event_time = event_time)\n",
" df = df.dropna(subset=[\"event_time\", f\"Ft_{time}\", \"event\"], axis=0)\n",
" \n",
" cindex = 1-concordance_index(df[\"event_time\"], df[f\"Ft_{time}\"], df[\"event\"])\n",
" #except: cindex=np.nan\n",
" return {\"endpoint\":endpoint, \"module\": module, \"features\": feature, \"iteration\": iteration, \"time\":time, \"cindex\":cindex}\n",
"\n",
"data_bm = data_test.query(\"endpoint==@endpoint\").merge(endpoint_data[[\"eid\", f\"{endpoint}_event\", f\"{endpoint}_event_time\"]], on=\"eid\", how=\"left\")\n",
"eids = data_test.eid.unique()\n",
"endpoint = \"MACE\"\n",
"with ProgressBar():\n",
" rows = []\n",
" for iteration in tqdm(iterations): \n",
" eids_bs = np.random.choice(eids, size=len(eids))\n",
" for module in modules: \n",
" temp_module = data_bm.query(\"module==@module\")\n",
" for feature in features:\n",
" temp_features = temp_module.query(\"features==@feature\")\n",
" if len(temp_features)>0:\n",
" data_future = client.scatter(temp_features[[\"eid\", \"Ft_10\", \"MACE_event\", \"MACE_event_time\"]].set_index(\"eid\").loc[eids_bs].reset_index())\n",
" rows.append(client.submit(calculate_per_endpoint, data_future, endpoint, module, feature, iteration, time=10))"
]
},
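{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the metric above concrete: the toy example below (made-up numbers, purely illustrative) applies the same censoring rule as `calculate_per_endpoint` and evaluates Harrell's C. Because `Ft_10` is a predicted 10-year risk (higher = worse), the concordance is reported as `1 - concordance_index`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative toy data, not study results\n",
"toy = pd.DataFrame({\n",
"    \"MACE_event\":      [1,    0,    1,   0,   1],\n",
"    \"MACE_event_time\": [2.0, 12.0, 8.0, 4.0, 11.0],  # years\n",
"    \"Ft_10\":           [0.30, 0.05, 0.20, 0.10, 0.15],  # predicted 10-year risk\n",
"})\n",
"time = 10\n",
"# Same rule as above: anyone without an event before the horizon counts as\n",
"# censored at the horizon.\n",
"censored = (toy[\"MACE_event\"] == 0) | (toy[\"MACE_event_time\"] > time)\n",
"event = np.where(censored, 0, 1)\n",
"event_time = np.where(censored, time, toy[\"MACE_event_time\"])\n",
"# concordance_index expects higher scores to go with longer survival,\n",
"# so the risk direction is flipped via 1 - C.\n",
"1 - concordance_index(event_time, toy[\"Ft_10\"], event)"
]
},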
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dask.distributed import progress\n",
"progress(rows)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rows = client.gather(rows)\n",
"benchmark_endpoints_pp = pd.DataFrame({\"endpoint\":[], \"module\": [], \"features\": [], \"partition\": [], \"time\": [], \"cindex\": []}).append(rows, ignore_index=True)\n",
"clear_output()"
]
},
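{
"cell_type": "markdown",
"metadata": {},
"source": [
"With one c-index per bootstrap iteration and model, point estimates and percentile confidence intervals follow directly. The cell below is a minimal sketch; the 2.5%/50%/97.5% quantiles are one common choice, not something the pipeline prescribes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Percentile bootstrap summary: median and 95% CI per model\n",
"cindex_summary = (\n",
"    benchmark_endpoints_pp\n",
"    .groupby([\"endpoint\", \"module\", \"features\"])[\"cindex\"]\n",
"    .quantile([0.025, 0.5, 0.975])\n",
"    .unstack()\n",
"    .rename(columns={0.025: \"ci_lower\", 0.5: \"median\", 0.975: \"ci_upper\"})\n",
")\n",
"cindex_summary"
]
},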
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"name = \"benchmark_cindex_MACE_210703_all_FINAL\"\n",
"benchmark_endpoints_pp.to_feather(f\"{data_results_path}/{name}.feather\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:miniconda3-pl1.x]",
"language": "python",
"name": "conda-env-miniconda3-pl1.x-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.8"
},
"toc-autonumbering": false
},
"nbformat": 4,
"nbformat_minor": 4
}