[aca2dc]: / neuralcvd / postprocessing / 1_collect_results_cox.ipynb

Download this file

491 lines (490 with data), 15.9 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
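    "# Cox model\n",
    "\n",
    "Fit Cox proportional-hazards models per endpoint, feature set and data partition, and collect their predicted failure probabilities $1 - S(t)$ into a single feather file."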
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import lifelines\n",
    "from lifelines import CoxPHFitter\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "\n",
    "from joblib import Parallel, delayed\n",
    "from tqdm.notebook import tqdm\n",
    "import neptune\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import shutil\n",
    "import anndata as ad\n",
    "import pickle\n",
    "import pathlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input dataset locations.\n",
    "project_name = \"210616_centres_dask\"\n",
    "data_path = \"/data/analysis/ag-reils/steinfej\"\n",
    "data_pre = f\"{data_path}/data/2_datasets_pre/{project_name}\"\n",
    "data_post = f\"{data_path}/data/3_datasets_post/{project_name}\"\n",
    "\n",
    "# Output locations for figures and collected result tables (created if missing).\n",
    "project_label = \"21_PGS_Revision\"\n",
    "project_path = f\"/data/analysis/ag-reils/ag-reils-shared/cardioRS/results/projects/{project_label}\"\n",
    "figures_path = f\"{project_path}/figures\"\n",
    "data_results_path = f\"{project_path}/data\"\n",
    "for output_dir in (figures_path, data_results_path):\n",
    "    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Endpoints to model; each needs {endpoint}_event and {endpoint}_event_time columns.\n",
    "endpoints = [\"MACE\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dask.distributed import Client, LocalCluster\n",
    "\n",
    "# Local Dask cluster: runs the partition loading below and serves as the joblib backend\n",
    "# when reading the prediction files.\n",
    "# NOTE(review): 20 workers x 100 threads = 2000 threads in total; confirm this suits the host.\n",
    "N_WORKERS = 20\n",
    "THREADS_PER_WORKER = 100\n",
    "cluster = LocalCluster(n_workers=N_WORKERS, threads_per_worker=THREADS_PER_WORKER)\n",
    "client = Client(cluster)\n",
    "client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data partitions live in partition_0 ... partition_21; each has train/valid/test splits.\n",
    "N_PARTITIONS = 22\n",
    "partitions = list(map(str, range(N_PARTITIONS)))\n",
    "splits = [\"train\", \"valid\", \"test\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fit Cox models and create predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_temp = pd.read_feather(f\"{data_post}/data_merged.feather\")\n",
    "\n",
    "# Cohort inclusion criteria per endpoint, as pandas query strings over data_merged.\n",
    "# MACE: keep participants with myocardial_infarction, stroke and statins all False.\n",
    "inclusion_queries = {\n",
    "    \"MACE\": \"myocardial_infarction==False&stroke==False&statins==False\",\n",
    "}\n",
    "\n",
    "eids_dict = {}\n",
    "for endpoint in tqdm(endpoints):\n",
    "    # Fail loudly for an endpoint without criteria instead of silently reusing the\n",
    "    # previous endpoint's cohort (or raising NameError on the first iteration).\n",
    "    if endpoint not in inclusion_queries:\n",
    "        raise ValueError(f\"No inclusion criteria defined for endpoint {endpoint!r}\")\n",
    "    eids_incl = data_temp.query(inclusion_queries[endpoint]).eid.to_list()\n",
    "    print(endpoint, len(eids_incl))\n",
    "    eids_dict[endpoint] = eids_incl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset description table; not used further in this notebook.\n",
    "description_file = f\"{data_post}/description.feather\"\n",
    "data_description = pd.read_feather(description_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(dataset_path, partition, split, eids_incl=None):\n",
    "    \"\"\"Read one imputed, normalized split of one partition, indexed by eid.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_path : str\n",
    "        Root directory that contains the partition_{k} folders.\n",
    "    partition : str\n",
    "        Partition identifier, e.g. \"0\".\n",
    "    split : str\n",
    "        One of \"train\", \"valid\", \"test\".\n",
    "    eids_incl : optional\n",
    "        Unused; kept for call compatibility. Cohort filtering is done in\n",
    "        fit_predict_coxph.\n",
    "    \"\"\"\n",
    "    # Honour the dataset_path argument (the previous version ignored it and read the\n",
    "    # global data_post instead).\n",
    "    path = f\"{dataset_path}/partition_{partition}/{split}/data_imputed_normalized.feather\"\n",
    "    return pd.read_feather(path).set_index(\"eid\")\n",
    "\n",
    "\n",
    "# One Dask future per (partition, split); gathered in a later cell. The eid list is not\n",
    "# needed for loading, so it is not shipped to every task.\n",
    "data_all = {\n",
    "    partition: {split: client.submit(load_data, data_post, partition, split, None) for split in splits}\n",
    "    for partition in tqdm(partitions)\n",
    "}"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "# Inactive (raw cell): earlier variant of load_data that also added age*PGS columns.\n",
    "# NOTE(review): this would not work if turned back into a code cell as written:\n",
    "#   - load_data never returns `temp`, so every future would resolve to None;\n",
    "#   - the age*PGS products are computed from the global data_temp, not from `temp`,\n",
    "#     and data_temp is not indexed by eid, so the assignment presumably would not align by eid;\n",
    "#   - `dataset_path` is not defined at notebook level for the submit call below.\n",
    "# The age*PGS interactions are instead expressed in formulas[\"clinical_pgs_all*age\"].\n",
    "def load_data(dataset_path, partition, split, eids_incl):\n",
    "    temp = pd.read_feather(f\"{dataset_path}/partition_{partition}/{split}/data_imputed_normalized.feather\").set_index(\"eid\")\n",
    "    pgs_cols = [col for col in data_temp.columns.to_list() if \"PGS\" in col]\n",
    "    for col in tqdm(pgs_cols): temp[f\"age*{col}\"] = data_temp[\"age_at_recruitment\"]*data_temp[col]\n",
    "\n",
    "data_all = {partition: {split: client.submit(load_data, dataset_path, partition, split, eids_incl) for split in splits} for partition in tqdm(partitions)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for all load futures; returns the same nested dict with DataFrames in place of futures.\n",
    "data_all = client.gather(data_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature groups (column names of data_imputed_normalized).\n",
    "# Columns with a _{level}.0 suffix are enumerated by level index.\n",
    "basics = (\n",
    "    [\"age_at_recruitment\"]\n",
    "    + [f\"ethnic_background_{level}.0\" for level in range(5)]  # ethnic_background_2.0: na 2 -> 5\n",
    "    + [\"townsend_deprivation_index_at_recruitment\", \"sex\"]\n",
    ")\n",
    "questionnaire = (\n",
    "    [f\"overall_health_rating_{level}.0\" for level in range(4)]\n",
    "    + [f\"smoking_status_{level}.0\" for level in range(3)]\n",
    ")\n",
    "measurements = [\n",
    "    \"body_mass_index_bmi\",\n",
    "    \"weight\",\n",
    "    \"standing_height\",\n",
    "    \"systolic_blood_pressure\",\n",
    "    \"diastolic_blood_pressure\",\n",
    "]\n",
    "labs = [\"cholesterol\", \"hdl_cholesterol\", \"ldl_direct\", \"triglycerides\"]\n",
    "family_history = [\"fh_heart_disease\"]\n",
    "diagnoses = [\n",
    "    \"diabetes1\",\n",
    "    \"diabetes2\",\n",
    "    \"chronic_kidney_disease\",\n",
    "    \"atrial_fibrillation\",\n",
    "    \"migraine\",\n",
    "    \"rheumatoid_arthritis\",\n",
    "    \"systemic_lupus_erythematosus\",\n",
    "    \"severe_mental_illness\",\n",
    "    \"erectile_dysfunction\",\n",
    "]\n",
    "medications = [\"antihypertensives\", \"ass\", \"atypical_antipsychotics\", \"glucocorticoids\"]\n",
    "\n",
    "# Polygenic scores (presumably PGS Catalog score IDs).\n",
    "pgs_all = [\"PGS000011\", \"PGS000018\", \"PGS000039\", \"PGS000057\", \"PGS000058\", \"PGS000059\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature groups by name; concatenated into model feature sets in the next cell.\n",
    "feature_dict = dict(\n",
    "    basics=basics,\n",
    "    questionnaire=questionnaire,\n",
    "    measurements=measurements,\n",
    "    labs=labs,\n",
    "    family_history=family_history,\n",
    "    medications=medications,\n",
    "    diagnoses=diagnoses,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model feature sets; each name is also a key in `formulas` and the model label in the output.\n",
    "clinical_groups = [\"basics\", \"questionnaire\", \"measurements\", \"labs\", \"family_history\", \"medications\", \"diagnoses\"]\n",
    "features = {\"clinical\": [col for name in clinical_groups for col in feature_dict[name]]}\n",
    "features[\"clinical_pgs_all\"] = features[\"clinical\"] + pgs_all\n",
    "# Same columns as clinical_pgs_all; the age x PGS interactions are added in the formula.\n",
    "features[\"clinical_pgs_all*age\"] = features[\"clinical_pgs_all\"]\n",
    "features[\"sun_pgs\"] = [\"age_at_recruitment\", \"sex\", 'smoking_status_0.0', \"diabetes2\", \"systolic_blood_pressure\", \"diastolic_blood_pressure\", \"cholesterol\", \"hdl_cholesterol\", \"PGS000018\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Right-hand-side formulas for CoxPHFitter.fit(formula=...), one per feature set.\n",
    "formulas = {}\n",
    "formulas[\"clinical\"] = \"+\".join(features[\"clinical\"])\n",
    "formulas[\"clinical_pgs_all\"] = \"+\".join(features[\"clinical_pgs_all\"])\n",
    "# `a*b` expands to a + b + a:b, so age_at_recruitment is left out of the clinical terms\n",
    "# here; its main effect comes from the interaction terms.\n",
    "formulas[\"clinical_pgs_all*age\"] = \"+\".join(\n",
    "    [col for col in features[\"clinical\"] if col != \"age_at_recruitment\"]\n",
    "    + [f\"age_at_recruitment*{col}\" for col in pgs_all]\n",
    ")\n",
    "# Built from features[\"sun_pgs\"] like the other sets. The former hard-coded value was a\n",
    "# list rather than a formula string and also named PGS000039, which is not among the\n",
    "# columns selected for this feature set, so the model matrix could not be built.\n",
    "# NOTE(review): if PGS000039 is intended, add it to features[\"sun_pgs\"] instead.\n",
    "formulas[\"sun_pgs\"] = \"+\".join(features[\"sun_pgs\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Outcome column names for every endpoint, and the feature sets to fit.\n",
    "events = [f\"{name}_event\" for name in endpoints]\n",
    "times = [f\"{name}_event_time\" for name in endpoints]\n",
    "groups = [*features]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per feature set: the selected columns (features + outcomes) under \"features\", and for\n",
    "# each partition a dict of split -> copy of data_all restricted to those columns.\n",
    "data = {}\n",
    "for group in tqdm(groups):\n",
    "    selected_cols = features[group] + events + times\n",
    "    group_data = {\"features\": selected_cols}\n",
    "    for partition in partitions:\n",
    "        group_data[partition] = {\n",
    "            split: data_all[partition][split].loc[:, selected_cols].copy()\n",
    "            for split in splits\n",
    "        }\n",
    "    data[group] = group_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lifelines.utils import concordance_index\n",
    "import pathlib\n",
    "\n",
    "\n",
    "def _restrict_to_cohort(df, eids_incl):\n",
    "    \"\"\"Return the rows of an eid-indexed frame whose eid is in eids_incl, keeping row order.\"\"\"\n",
    "    return df[df.index.isin(eids_incl)]\n",
    "\n",
    "\n",
    "def fit_predict_coxph(data_h5ad, endpoint, group, partition, time, event, eids_incl, dump_path):\n",
    "    \"\"\"Fit a Cox PH model for one endpoint / feature set / partition and dump model and predictions.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data_h5ad : dict\n",
    "        {\"train\": df, \"valid\": df, \"test\": df}; each frame is indexed by eid and holds the\n",
    "        feature-set columns plus the outcome columns (plain DataFrames despite the name).\n",
    "    endpoint, group, partition : str\n",
    "        Used in file names and output columns; group also selects formulas[group].\n",
    "    time, event : str\n",
    "        Names of the duration and event columns.\n",
    "    eids_incl : list\n",
    "        eids of the cohort; every split is restricted to these rows.\n",
    "    dump_path : str\n",
    "        Output directory (created if missing). Writes the pickled fitted model to\n",
    "        {endpoint}_{group}_{partition}.p and the predictions to\n",
    "        {endpoint}_{group}_{partition}.feather.\n",
    "\n",
    "    The prediction columns 0_{t}_Ft hold 1 - S(t | x) for t = 1..26, in the units of `time`.\n",
    "    \"\"\"\n",
    "    pathlib.Path(dump_path).mkdir(parents=True, exist_ok=True)\n",
    "    eval_times = list(range(1, 27))  # as years + 1\n",
    "\n",
    "    split_data = {\n",
    "        \"train\": _restrict_to_cohort(data_h5ad[\"train\"], eids_incl),\n",
    "        \"valid\": _restrict_to_cohort(data_h5ad[\"valid\"], eids_incl),\n",
    "        \"test\": _restrict_to_cohort(data_h5ad[\"test\"], eids_incl),\n",
    "    }\n",
    "    train_data, val_data = split_data[\"train\"], split_data[\"valid\"]\n",
    "\n",
    "    # Candidate covariates: the non-outcome columns of the data passed in (not the global\n",
    "    # `data` dict, and not a hard-coded endpoint name), then this endpoint's outcome columns.\n",
    "    covariates = [col for col in train_data.columns if not col.endswith((\"_event\", \"_event_time\"))]\n",
    "    # Drop columns that are constant in train. Build a new list: removing from the list\n",
    "    # while iterating over it (as before) skips the element after every removed column.\n",
    "    # NOTE(review): columns named in formulas[group] must survive this filter.\n",
    "    covariates_with_tte = [col for col in covariates + [time, event] if train_data[col].nunique() != 1]\n",
    "\n",
    "    cph = CoxPHFitter()\n",
    "    cph.fit(train_data[covariates_with_tte], duration_col=time, event_col=event,\n",
    "            show_progress=True, step_size=0.5, formula=formulas[group])\n",
    "    with open(f\"{dump_path}/{endpoint}_{group}_{partition}.p\", \"wb\") as fh:\n",
    "        pickle.dump(cph, fh)\n",
    "    # Validation C-index; a higher partial hazard means shorter survival, hence the minus sign.\n",
    "    print(concordance_index(val_data[time], -cph.predict_partial_hazard(val_data[covariates_with_tte]), val_data[event]))\n",
    "\n",
    "    time_cols = {t: f\"0_{t}_Ft\" for t in eval_times}\n",
    "    frames = []\n",
    "    for split, df in split_data.items():\n",
    "        # predict_survival_function: rows are eval_times, columns are subjects (in df order);\n",
    "        # transpose once so each time point is a column.\n",
    "        ft_by_subject = (1 - cph.predict_survival_function(df[covariates_with_tte], times=eval_times)).T\n",
    "        pred = df.reset_index()[[\"eid\"]].copy()\n",
    "        for t, col in time_cols.items():\n",
    "            pred[col] = ft_by_subject[t].to_list()\n",
    "        frames.append(pred.assign(split=split))\n",
    "\n",
    "    preds = pd.concat(frames, axis=0).assign(\n",
    "        endpoint=endpoint, features=group, partition=partition, module=\"COXPH\",\n",
    "        datamodule=\"UKBBSurvivalDatamodule\", net=\"\", calibrated=\"False\")\n",
    "    id_cols = [\"eid\", \"endpoint\", \"features\", \"split\", \"partition\", \"module\", \"datamodule\", \"net\", \"calibrated\"]\n",
    "    preds = preds[id_cols + list(time_cols.values())].reset_index(drop=True)\n",
    "    preds.to_feather(f\"{dump_path}/{endpoint}_{group}_{partition}.feather\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for the pickled Cox models and their prediction feather files.\n",
    "dump_path = \"/\".join([data_post, \"COXPH\", \"210631_PGS_REVISION\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit one model per endpoint x feature set x partition; each call writes its own files.\n",
    "for endpoint in tqdm(endpoints):\n",
    "    time, event = f\"{endpoint}_event_time\", f\"{endpoint}_event\"\n",
    "    eids_incl = eids_dict[endpoint]\n",
    "    for group in tqdm(groups):\n",
    "        print(group)\n",
    "        group_data = data[group]\n",
    "        for partition in partitions:\n",
    "            fit_predict_coxph(\n",
    "                group_data[partition], endpoint, group, partition,\n",
    "                time, event, eids_incl, dump_path,\n",
    "            )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read and Process Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "# Prediction files written by fit_predict_coxph, one per endpoint / feature set / partition.\n",
    "# NOTE(review): any other *.feather already in dump_path would be picked up as well.\n",
    "prediction_glob = f\"{dump_path}/*.feather\"\n",
    "files = sorted(glob.glob(prediction_glob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import joblib\n",
    "import pandas as pd\n",
    "from joblib import Parallel, delayed\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "\n",
    "def get_df(path):\n",
    "    \"\"\"Read one prediction feather file.\"\"\"\n",
    "    return pd.read_feather(path)\n",
    "\n",
    "\n",
    "# Read all prediction files in parallel on the Dask cluster, skipping missing paths.\n",
    "valid_paths = (path for path in tqdm(files) if path is not None and not pd.isna(path))\n",
    "with joblib.parallel_backend('dask'):\n",
    "    dfs = Parallel(n_jobs=80)(delayed(get_df)(path) for path in valid_paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack all runs into one long table with a fresh 0..n-1 index.\n",
    "predictions = pd.concat(dfs, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_float32(df):\n",
    "    \"\"\"Downcast every float64 column of df to float32 (in place) and return df.\"\"\"\n",
    "    float64_cols = [col for col in df.columns.to_list() if df[col].dtype == \"float64\"]\n",
    "    for col in tqdm(float64_cols):\n",
    "        print(col, \"convert\")\n",
    "        df[col] = df[col].astype(\"float32\")\n",
    "    return df\n",
    "\n",
    "\n",
    "# Shrink memory: string (object) columns become categoricals, floats become float32.\n",
    "object_cols = [col for col in predictions.columns.to_list() if predictions[col].dtype == \"object\"]\n",
    "for col in tqdm(object_cols):\n",
    "    predictions[col] = predictions[col].astype(\"category\")\n",
    "\n",
    "predictions[\"partition\"] = predictions[\"partition\"].astype(int)\n",
    "predictions = convert_to_float32(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fix_column_names(df):\n",
    "    \"\"\"Rename the prediction columns containing \"Ft\" positionally to Ft_0, Ft_1, ...\n",
    "\n",
    "    The k-th such column (in the frame's column order) becomes Ft_{k}. For the columns\n",
    "    written by fit_predict_coxph (0_1_Ft ... 0_26_Ft, in that order) this maps 0_{t}_Ft\n",
    "    to Ft_{t-1}, e.g. 0_11_Ft -> Ft_10.\n",
    "    \"\"\"\n",
    "    # original note: rename and fix time bugs!!! -> 0_11_Ft -> Ft at t=10 -> fix earlier\n",
    "    ft_cols = [col for col in df.columns if \"Ft\" in col]\n",
    "    return df.rename({old: f\"Ft_{idx}\" for idx, old in enumerate(ft_cols)}, axis=\"columns\")\n",
    "\n",
    "\n",
    "predictions = fix_column_names(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single combined prediction table for downstream evaluation.\n",
    "output_file = f\"{data_results_path}/predictions_cox_210631_REVISION.feather\"\n",
    "predictions.to_feather(output_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:miniconda3-pl1.x]",
   "language": "python",
   "name": "conda-env-miniconda3-pl1.x-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}