{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cox model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import os\n",
"from tqdm.auto import tqdm\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import lifelines\n",
"from lifelines import CoxPHFitter\n",
"from sklearn.model_selection import StratifiedKFold\n",
"\n",
"from joblib import Parallel, delayed\n",
"from tqdm.notebook import tqdm\n",
"import neptune\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"import shutil\n",
"import anndata as ad\n",
"import pickle\n",
"import pathlib"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"project_name = \"210616_centres_dask\"\n",
"data_path = \"/data/analysis/ag-reils/steinfej\"\n",
"data_pre = f\"{data_path}/data/2_datasets_pre/{project_name}\"\n",
"data_post = f\"{data_path}/data/3_datasets_post/{project_name}\"\n",
"\n",
"project_label = \"21_PGS_Revision\"\n",
"project_path = f\"/data/analysis/ag-reils/ag-reils-shared/cardioRS/results/projects/{project_label}\"\n",
"figures_path = f\"{project_path}/figures\"\n",
"data_results_path = f\"{project_path}/data\"\n",
"pathlib.Path(figures_path).mkdir(parents=True, exist_ok=True)\n",
"pathlib.Path(data_results_path).mkdir(parents=True, exist_ok=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"endpoints = ['MACE']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dask.distributed import Client, LocalCluster\n",
"cluster = LocalCluster(n_workers=20, threads_per_worker=100)\n",
"client = Client(cluster)\n",
"client"
]
},
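{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (a small sketch, not part of the pipeline itself): confirm the cluster actually came up with the requested number of workers before submitting work to it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# scheduler_info() lists the workers currently connected to the cluster\n",
"print(len(client.scheduler_info()[\"workers\"]), \"workers connected\")"
]
},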
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"partitions = [str(p) for p in range(22)]\n",
"splits = [\"train\", \"valid\", \"test\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create COX and Predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_temp = pd.read_feather(f\"{data_post}/data_merged.feather\")\n",
"eids_dict = {}\n",
"for endpoint in tqdm(endpoints):\n",
" if endpoint == \"MACE\": eids_incl = data_temp.copy().query(f\"myocardial_infarction==False&stroke==False&statins==False\").eid.to_list()\n",
" print(endpoint, len(eids_incl))\n",
" eids_dict[endpoint] = eids_incl"
]
},
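{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short spot check (a sketch, assuming `data_temp` carries the three boolean exclusion flags used above) of how many participants each criterion flags:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# how many participants each exclusion flag marks\n",
"for col in [\"myocardial_infarction\", \"stroke\", \"statins\"]:\n",
"    print(col, int(data_temp[col].sum()))"
]
},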
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_description = pd.read_feather(f\"{data_post}/description.feather\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_data(dataset_path, partition, split, eids_incl):\n",
" return pd.read_feather(f\"{data_post}/partition_{partition}/{split}/data_imputed_normalized.feather\").set_index(\"eid\")\n",
"\n",
"data_all = {partition: {split: client.submit(load_data, data_post, partition, split, eids_incl) for split in splits} for partition in tqdm(partitions)}"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"def load_data(dataset_path, partition, split, eids_incl):\n",
" temp = pd.read_feather(f\"{dataset_path}/partition_{partition}/{split}/data_imputed_normalized.feather\").set_index(\"eid\")\n",
" pgs_cols = [col for col in data_temp.columns.to_list() if \"PGS\" in col]\n",
" for col in tqdm(pgs_cols): temp[f\"age*{col}\"] = data_temp[\"age_at_recruitment\"]*data_temp[col]\n",
"\n",
"data_all = {partition: {split: client.submit(load_data, dataset_path, partition, split, eids_incl) for split in splits} for partition in tqdm(partitions)}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_all = client.gather(data_all)"
]
},
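{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick spot check (a sketch over the first partition) that every split loaded and that train and test share the same columns:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# shapes per split for one partition, plus column alignment across splits\n",
"for split in splits:\n",
"    print(split, data_all[\"0\"][split].shape)\n",
"assert data_all[\"0\"][\"train\"].columns.equals(data_all[\"0\"][\"test\"].columns)"
]
},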
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"basics = [\n",
"'age_at_recruitment',\n",
"'ethnic_background_0.0',\n",
"'ethnic_background_1.0',\n",
"'ethnic_background_2.0',#na 2 -> 5\n",
"'ethnic_background_3.0',\n",
"'ethnic_background_4.0',\n",
"'townsend_deprivation_index_at_recruitment',\n",
"'sex'\n",
"]\n",
"questionnaire = [\n",
"'overall_health_rating_0.0',\n",
"'overall_health_rating_1.0',\n",
"'overall_health_rating_2.0',\n",
"'overall_health_rating_3.0',\n",
"'smoking_status_0.0',\n",
"'smoking_status_1.0',\n",
"'smoking_status_2.0',\n",
"]\n",
"measurements = [\n",
"'body_mass_index_bmi',\n",
"'weight',\n",
"\"standing_height\",\n",
"'systolic_blood_pressure',\n",
"'diastolic_blood_pressure',\n",
"]\n",
"\n",
"labs = [\n",
"\"cholesterol\",\n",
"\"hdl_cholesterol\",\n",
"\"ldl_direct\",\n",
"\"triglycerides\"\n",
"]\n",
"\n",
"family_history = [\n",
"'fh_heart_disease',\n",
"]\n",
"\n",
"diagnoses = [\n",
"'diabetes1',\n",
"'diabetes2',\n",
"'chronic_kidney_disease',\n",
"'atrial_fibrillation',\n",
"'migraine',\n",
"'rheumatoid_arthritis',\n",
"'systemic_lupus_erythematosus',\n",
"'severe_mental_illness',\n",
"'erectile_dysfunction',\n",
"]\n",
"\n",
"medications = [\n",
"\"antihypertensives\",\n",
"\"ass\",\n",
"\"atypical_antipsychotics\",\n",
"\"glucocorticoids\"\n",
"]\n",
"\n",
"pgs_all = [\n",
" 'PGS000011',\n",
" 'PGS000018',\n",
" 'PGS000039',\n",
" 'PGS000057',\n",
" 'PGS000058',\n",
" 'PGS000059'\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"feature_dict = {\n",
"\"basics\": basics,\n",
"\"questionnaire\": questionnaire,\n",
"\"measurements\": measurements,\n",
"\"labs\": labs,\n",
"\"family_history\": family_history,\n",
"\"medications\": medications,\n",
"\"diagnoses\": diagnoses,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"features = {}\n",
"features[\"clinical\"] = feature_dict[\"basics\"]+feature_dict[\"questionnaire\"]+feature_dict[\"measurements\"] + feature_dict[\"labs\"]+feature_dict[\"family_history\"]+feature_dict[\"medications\"]+feature_dict[\"diagnoses\"]\n",
"features[\"clinical_pgs_all\"] = features[\"clinical\"] + pgs_all\n",
"features[\"clinical_pgs_all*age\"] = features[\"clinical_pgs_all\"] \n",
"features[\"sun_pgs\"] = [\"age_at_recruitment\", \"sex\", 'smoking_status_0.0', \"diabetes2\", \"systolic_blood_pressure\", \"diastolic_blood_pressure\", \"cholesterol\", \"hdl_cholesterol\", \"PGS000018\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"formulas = {}\n",
"formulas[\"clinical\"] = \"+\".join(features[\"clinical\"])\n",
"formulas[\"clinical_pgs_all\"] = \"+\".join(features[\"clinical_pgs_all\"])\n",
"formulas[\"clinical_pgs_all*age\"] = \"+\".join([col for col in features[\"clinical\"] if col!=\"age_at_recruitment\"])+\"+\"+\"+\".join([f\"age_at_recruitment*{col}\" for col in pgs_all])\n",
"formulas[\"sun_pgs\"] = [\"age_at_recruitment\", \"sex\", 'smoking_status_0.0', \"diabetes2\", \"systolic_blood_pressure\", \"diastolic_blood_pressure\", \"cholesterol\", \"hdl_cholesterol\", \"PGS000018\", \"PGS000039\"]"
]
},
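{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of what the `age_at_recruitment*PGS…` terms expand to, a toy example on synthetic data (the fitted values are meaningless; only the resulting term names matter):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# toy illustration: `a*b` in a lifelines formula expands to a + b + a:b\n",
"rng = np.random.default_rng(0)\n",
"toy = pd.DataFrame({\n",
"    \"age_at_recruitment\": rng.normal(size=200),\n",
"    \"PGS000018\": rng.normal(size=200),\n",
"    \"T\": rng.exponential(10, size=200),\n",
"    \"E\": rng.integers(0, 2, size=200),\n",
"})\n",
"cph_toy = CoxPHFitter().fit(toy, duration_col=\"T\", event_col=\"E\", formula=\"age_at_recruitment*PGS000018\")\n",
"print(cph_toy.params_.index.to_list())  # main effects plus the interaction term"
]
},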
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#endpoint = \"M_MACE\"; \n",
"events=[endpoint+'_event' for endpoint in endpoints] \n",
"times=[endpoint+'_event_time' for endpoint in endpoints]\n",
"groups = list(features)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = {}\n",
"for group in tqdm(groups): \n",
" data[group] = {\"features\":features[group]+events+times}\n",
" for partition in partitions: \n",
" data[group][partition] = {}\n",
" for split in splits: data[group][partition][split] = data_all[partition][split].loc[:, data[group][\"features\"]].copy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from lifelines.utils import concordance_index\n",
"import pathlib\n",
"\n",
"def fit_predict_coxph(data_h5ad, endpoint, group, partition, time, event, eids_incl, dump_path):\n",
" pathlib.Path(dump_path).mkdir(parents=True, exist_ok=True) \n",
"\n",
" cph = CoxPHFitter()\n",
" train_data = data_h5ad[\"train\"].reset_index().query(\"eid==@eids_incl\").set_index(\"eid\")\n",
" val_data = data_h5ad[\"valid\"].reset_index().query(\"eid==@eids_incl\").set_index(\"eid\")\n",
" test_data = data_h5ad[\"test\"].reset_index().query(\"eid==@eids_incl\").set_index(\"eid\")\n",
"\n",
"\n",
" covariates_with_tte = [col for col in data[group][\"features\"] if \"MACE\" not in col]+[time, event]\n",
" for col in covariates_with_tte:\n",
" if train_data[col].nunique()==1: covariates_with_tte.remove(col)\n",
"\n",
" cph.fit(train_data[covariates_with_tte], duration_col=time, event_col=event, show_progress=True, step_size=0.5, formula=formulas[group])\n",
" pickle.dump(cph, open(f\"{dump_path}/{endpoint}_{group}_{partition}.p\", \"wb\" ) )\n",
" print(concordance_index(val_data[time], -cph.predict_partial_hazard(val_data[covariates_with_tte]), val_data[event]))\n",
"\n",
" surv_train = 1-cph.predict_survival_function(train_data[covariates_with_tte], times=[t for t in range(1,27)])\n",
" surv_val = 1-cph.predict_survival_function(val_data[covariates_with_tte], times=[t for t in range(1,27)]) # as years + 1 \n",
" surv_test = 1-cph.predict_survival_function(test_data[covariates_with_tte], times=[t for t in range(1,27)]) # as years + 1 \n",
"\n",
" pred = {\"train\":train_data.reset_index()[[\"eid\"]],\n",
" \"val\":val_data.reset_index()[[\"eid\"]],\n",
" \"test\":test_data.reset_index()[[\"eid\"]],}\n",
"\n",
" pred[\"train\"][f\"score_COX_{group}\"] = surv_train.iloc[0].to_list()\n",
" pred[\"val\"][f\"score_COX_{group}\"] = surv_val.iloc[0].to_list()\n",
" pred[\"test\"][f\"score_COX_{group}\"] = surv_test.iloc[0].to_list()\n",
"\n",
"\n",
" time_cols = {t: f\"0_{t}_Ft\" for t in range(1, 27)}\n",
" for t, col in time_cols.items():\n",
" pred[\"train\"][col] = surv_train.T[t].to_list()\n",
" pred[\"val\"][col] = surv_val.T[t].to_list()\n",
" pred[\"test\"][col] = surv_test.T[t].to_list()\n",
"\n",
" preds = pd.concat([pred[\"train\"].assign(split=\"train\"), pred[\"val\"].assign(split=\"valid\"), pred[\"test\"].assign(split=\"test\")], axis=0)\\\n",
" .assign(endpoint=endpoint, features=group, partition=partition, module=\"COXPH\", datamodule=\"UKBBSurvivalDatamodule\", net=\"\", calibrated=\"False\")\n",
" preds = preds[[\"eid\", 'endpoint', 'features', 'split', 'partition', 'module', 'datamodule', 'net', 'calibrated'] + list(time_cols.values())].reset_index(drop=True)\n",
" preds.to_feather(f\"{dump_path}/{endpoint}_{group}_{partition}.feather\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dump_path = f\"{data_post}/COXPH/210631_PGS_REVISION\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for endpoint in tqdm(endpoints):\n",
" time = f\"{endpoint}_event_time\"\n",
" event = f\"{endpoint}_event\"\n",
" eids_incl = eids_dict[endpoint]\n",
" for group in tqdm(groups):\n",
" print(group)\n",
" for partition in partitions:\n",
" fit_predict_coxph(data[group][partition], endpoint, group, partition, time, event, eids_incl, dump_path)"
]
},
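{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the loop has run, each pickled fitter can be reloaded for inspection. A sketch (assuming the `MACE`/`clinical` model for partition `0` was written) showing the largest hazard ratios:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# reload one fitted model and inspect its strongest hazard ratios\n",
"with open(f\"{dump_path}/MACE_clinical_0.p\", \"rb\") as f:\n",
"    cph_loaded = pickle.load(f)\n",
"print(cph_loaded.hazard_ratios_.sort_values(ascending=False).head())"
]
},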
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Read and Process Predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"files = sorted(glob.glob(f\"{dump_path}/*.feather\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import joblib\n",
"import pandas as pd\n",
"from joblib import Parallel, delayed\n",
"from tqdm.auto import tqdm\n",
"def get_df(path): return pd.read_feather(path)#return pd.read_csv(f\"{path[:-8]}.csv\", index_col=0)\n",
"with joblib.parallel_backend('dask'):\n",
" dfs = Parallel(n_jobs=80)(delayed(get_df)(path) for path in tqdm(files) if path is not None if not pd.isna(path))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predictions = pd.concat(dfs, axis=0).reset_index(drop=True)"
]
},
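{
"cell_type": "markdown",
"metadata": {},
"source": [
"A spot check that every endpoint and feature-set combination produced predictions for all 22 partitions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# each endpoint x feature-set pair should cover all partitions\n",
"print(predictions.groupby([\"endpoint\", \"features\"]).partition.nunique())"
]
},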
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def convert_to_float32(df):\n",
" for col in tqdm(df.columns.to_list()):\n",
" if df[col].dtype == \"float64\": \n",
" print(col, \"convert\")\n",
" df[col]= df[col].astype(\"float32\")\n",
" return df\n",
"\n",
"for col in tqdm(predictions.columns.to_list()):\n",
" if predictions[col].dtype == \"object\": predictions[col]= predictions[col].astype(\"category\")\n",
" \n",
"predictions[\"partition\"] = predictions[\"partition\"].astype(int)\n",
"predictions = convert_to_float32(predictions)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def fix_column_names(df):\n",
" # rename and fix time bugs!!! -> 0_11_Ft -> Ft at t=10 -> fix earlier\n",
" time_fix_map = dict(zip([col for col in df.columns if \"Ft\" in col], [f\"Ft_{col}\" for col in range(len([col for col in df.columns if \"Ft\" in col]))]))\n",
" df = df.rename(time_fix_map, axis=\"columns\")\n",
" return df\n",
"predictions = fix_column_names(predictions)"
]
},
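{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before writing out, a last sketch of a check that the renamed time-grid columns run contiguously from `Ft_0` upwards:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the renamed grid should be Ft_0 ... Ft_{n-1} with no gaps\n",
"ft_cols = [col for col in predictions.columns if col.startswith(\"Ft_\")]\n",
"assert ft_cols == [f\"Ft_{i}\" for i in range(len(ft_cols))]\n",
"print(len(ft_cols), \"time columns:\", ft_cols[0], \"...\", ft_cols[-1])"
]
},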
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predictions.to_feather(f\"{data_results_path}/predictions_cox_210631_REVISION.feather\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:miniconda3-pl1.x]",
"language": "python",
"name": "conda-env-miniconda3-pl1.x-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}