[d6904d]: / datasets / tjh / preprocess.ipynb

Download this file

1088 lines (1087 with data), 48.9 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "# Import necessary packages\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read raw data\n",
    "raw_excel_path = './raw_data/time_series_375_prerpocess_en.xlsx'\n",
    "df_train: pd.DataFrame = pd.read_excel(raw_excel_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Steps:\n",
    "\n",
    "- fill `patient_id`\n",
    "- keep only the date part (y-m-d) of the `RE_DATE` and `Discharge time` columns\n",
    "- merge lab tests of the same (patient_id, date)\n",
    "- calculate and save each feature's statistics (demographic and lab test features are calculated separately)\n",
    "- normalize data\n",
    "- feature selection\n",
    "- fill missing data (our filling strategy will be described below)\n",
    "- combine the above data into time-series data (one record per patient)\n",
    "- export to Python pickle files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fill `patient_id` rows: forward-fill so every record carries its patient's id\n",
    "# (assign results back instead of calling inplace=True on a column selection: that chained\n",
    "# in-place call is deprecated and does not modify df_train under pandas Copy-on-Write)\n",
    "df_train['PATIENT_ID'] = df_train['PATIENT_ID'].ffill()\n",
    "\n",
    "# gender transformation: 1--male, 0--female\n",
    "df_train['gender'] = df_train['gender'].replace(2, 0)\n",
    "\n",
    "# keep only y-m-d for `RE_DATE` and `Discharge time` columns\n",
    "df_train['RE_DATE'] = df_train['RE_DATE'].dt.strftime('%Y-%m-%d')\n",
    "df_train['Discharge time'] = df_train['Discharge time'].dt.strftime('%Y-%m-%d')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# drop records lacking the patient id, the record date or the discharge date\n",
    "required_columns = ['PATIENT_ID', 'RE_DATE', 'Discharge time']\n",
    "df_train = df_train.dropna(how='any', subset=required_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate raw data's los interval: day gaps between consecutive records of each patient\n",
    "los_interval_list = []\n",
    "los_interval_alive_list = []\n",
    "los_interval_dead_list = []\n",
    "\n",
    "df_grouped = df_train.groupby('PATIENT_ID')\n",
    "for _, group in df_grouped:\n",
    "    sorted_group = group.sort_values(by=['RE_DATE'], ascending=True)\n",
    "    # a patient with a single record has no interval\n",
    "    if len(sorted_group) == 1:\n",
    "        continue\n",
    "    # the patient's first outcome decides the bucket; anything other than 0 counts as dead\n",
    "    is_alive = sorted_group['outcome'].tolist()[0] == 0\n",
    "    visit_dates = [pd.to_datetime(d) for d in sorted_group['RE_DATE'].tolist()]\n",
    "    for previous_date, current_date in zip(visit_dates, visit_dates[1:]):\n",
    "        gap_days = (current_date - previous_date).days\n",
    "        los_interval_list.append(gap_days)\n",
    "        if is_alive:\n",
    "            los_interval_alive_list.append(gap_days)\n",
    "        else:\n",
    "            los_interval_dead_list.append(gap_days)\n",
    "\n",
    "los_interval_list = np.array(los_interval_list)\n",
    "los_interval_alive_list = np.array(los_interval_alive_list)\n",
    "los_interval_dead_list = np.array(los_interval_dead_list)\n",
    "\n",
    "output = {\n",
    "    'overall': los_interval_list,\n",
    "    'alive': los_interval_alive_list,\n",
    "    'dead': los_interval_dead_list,\n",
    "}\n",
    "# pd.to_pickle(output, 'raw_tjh_los_interval_list.pkl')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we have 2 types of prediction tasks: 1) predict mortality outcome, 2) length of stay\n",
    "\n",
    "# below are all lab test features (tab-separated)\n",
    "# NOTE: some names end with a space (e.g. 'platelet large cell ratio '), presumably to match the raw column headers\n",
    "labtest_features_str = \"\"\"\n",
    "Hypersensitive cardiac troponinI\themoglobin\tSerum chloride\tProthrombin time\tprocalcitonin\teosinophils(%)\tInterleukin 2 receptor\tAlkaline phosphatase\talbumin\tbasophil(%)\tInterleukin 10\tTotal bilirubin\tPlatelet count\tmonocytes(%)\tantithrombin\tInterleukin 8\tindirect bilirubin\tRed blood cell distribution width \tneutrophils(%)\ttotal protein\tQuantification of Treponema pallidum antibodies\tProthrombin activity\tHBsAg\tmean corpuscular volume\thematocrit\tWhite blood cell count\tTumor necrosis factorα\tmean corpuscular hemoglobin concentration\tfibrinogen\tInterleukin 1β\tUrea\tlymphocyte count\tPH value\tRed blood cell count\tEosinophil count\tCorrected calcium\tSerum potassium\tglucose\tneutrophils count\tDirect bilirubin\tMean platelet volume\tferritin\tRBC distribution width SD\tThrombin time\t(%)lymphocyte\tHCV antibody quantification\tD-D dimer\tTotal cholesterol\taspartate aminotransferase\tUric acid\tHCO3-\tcalcium\tAmino-terminal brain natriuretic peptide precursor(NT-proBNP)\tLactate dehydrogenase\tplatelet large cell ratio \tInterleukin 6\tFibrin degradation products\tmonocytes count\tPLT distribution width\tglobulin\tγ-glutamyl transpeptidase\tInternational standard ratio\tbasophil count(#)\t2019-nCoV nucleic acid detection\tmean corpuscular hemoglobin \tActivation of partial thromboplastin time\tHypersensitive c-reactive protein\tHIV antibody quantification\tserum sodium\tthrombocytocrit\tESR\tglutamic-pyruvic transaminase\teGFR\tcreatinine\n",
    "\"\"\"\n",
    "\n",
    "# below are 2 demographic features\n",
    "demographic_features_str = \"\"\"\n",
    "age\tgender\n",
    "\"\"\"\n",
    "\n",
    "labtest_features = labtest_features_str.strip().split('\\t')\n",
    "demographic_features = demographic_features_str.strip().split('\\t')\n",
    "target_features = ['outcome', 'LOS']\n",
    "\n",
    "# from our observation, `2019-nCoV nucleic acid detection` (a lab test) is -1 in every record,\n",
    "# so it carries no information and is removed here\n",
    "labtest_features.remove('2019-nCoV nucleic acid detection')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# negative feature values are treated as missing (set to NaN)\n",
    "feature_columns = demographic_features + labtest_features\n",
    "df_train[feature_columns] = df_train[feature_columns].mask(df_train[feature_columns] < 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge lab tests of the same (patient_id, date) by averaging them\n",
    "merge_keys = ['PATIENT_ID', 'RE_DATE', 'Discharge time']\n",
    "df_train = df_train.groupby(merge_keys, as_index=False, dropna=True).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate length-of-stay label: days from the record date to the discharge date\n",
    "discharge_dates = pd.to_datetime(df_train['Discharge time'])\n",
    "record_dates = pd.to_datetime(df_train['RE_DATE'])\n",
    "df_train['LOS'] = (discharge_dates - record_dates).dt.days"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a negative LOS (record dated after discharge) is clamped to 0\n",
    "df_train['LOS'] = df_train['LOS'].mask(df_train['LOS'] < 0, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save features' statistics information\n",
    "\n",
    "def calculate_statistic_info(df, features):\n",
    "    \"\"\"Summary statistics of every feature in `features`, computed over all of its values.\n",
    "\n",
    "    Returns {feature: {'count', 'missing', 'mean', 'max', 'min', 'median', 'std'}}, where\n",
    "    'missing' is the percentage of NaN rows formatted as a string such as '12.5%'.\n",
    "    \"\"\"\n",
    "    statistic_info = {}\n",
    "    len_df = len(df)\n",
    "    for e in features:\n",
    "        h = {}\n",
    "        h['count'] = int(df[e].count())\n",
    "        h['missing'] = str(round(float((100-df[e].count()*100/len_df)),3))+\"%\"\n",
    "        h['mean'] = float(df[e].mean())\n",
    "        h['max'] = float(df[e].max())\n",
    "        h['min'] = float(df[e].min())\n",
    "        h['median'] = float(df[e].median())\n",
    "        h['std'] = float(df[e].std())\n",
    "        statistic_info[e] = h\n",
    "    return statistic_info\n",
    "\n",
    "def calculate_middle_part_statistic_info(df, features):\n",
    "    \"\"\"Same statistics as calculate_statistic_info, restricted to each feature's 5% ~ 95% percentile range.\n",
    "\n",
    "    Values outside the range are excluded, so they also count towards 'missing'.\n",
    "    \"\"\"\n",
    "    statistic_info = {}\n",
    "    len_df = len(df)\n",
    "    # calculate 5% and 95% percentile of the requested features only,\n",
    "    # so other (e.g. non-numeric) columns of df cannot break quantile()\n",
    "    middle_part_df_info = df[features].quantile([.05, .95])\n",
    "\n",
    "    for e in features:\n",
    "        low_value = middle_part_df_info[e][.05]\n",
    "        high_value = middle_part_df_info[e][.95]\n",
    "        middle_part_df_element = df.loc[(df[e] >= low_value) & (df[e] <= high_value)][e]\n",
    "        h = {}\n",
    "        h['count'] = int(middle_part_df_element.count())\n",
    "        h['missing'] = str(round(float((100-middle_part_df_element.count()*100/len_df)),3))+\"%\"\n",
    "        h['mean'] = float(middle_part_df_element.mean())\n",
    "        h['max'] = float(middle_part_df_element.max())\n",
    "        h['min'] = float(middle_part_df_element.min())\n",
    "        h['median'] = float(middle_part_df_element.median())\n",
    "        h['std'] = float(middle_part_df_element.std())\n",
    "        statistic_info[e] = h\n",
    "    return statistic_info\n",
    "\n",
    "# group by patient_id, then calculate lab test/demographic features' statistics information\n",
    "# (one averaged row per patient, so the statistics are patient-wise)\n",
    "# numeric_only=True: RE_DATE and Discharge time are string columns here, and pandas>=2.0\n",
    "# raises a TypeError instead of silently dropping non-numeric columns in mean()\n",
    "groupby_patientid_df = df_train.groupby(['PATIENT_ID'], dropna=True, as_index = False).mean(numeric_only=True)\n",
    "\n",
    "# calculate statistic info (all values calculated)\n",
    "labtest_patientwise_statistic_info = calculate_statistic_info(groupby_patientid_df, labtest_features)\n",
    "demographic_statistic_info = calculate_statistic_info(groupby_patientid_df, demographic_features) # it's also patient-wise\n",
    "\n",
    "# calculate statistic info (5% ~ 95% only)\n",
    "demographic_statistic_info_2 = calculate_middle_part_statistic_info(groupby_patientid_df, demographic_features)\n",
    "labtest_patientwise_statistic_info_2 = calculate_middle_part_statistic_info(groupby_patientid_df, labtest_features)\n",
    "\n",
    "# take 2 statistics information's union; these 5% ~ 95% statistics are the ones used downstream\n",
    "statistic_info = labtest_patientwise_statistic_info_2 | demographic_statistic_info_2\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# observe features, export to csv file [optional]\n",
    "# exported column name -> key inside each statistic_info entry (column order is preserved)\n",
    "export_columns = {'missing_rate': 'missing', 'count': 'count', 'mean': 'mean', 'max': 'max', 'min': 'min', 'median': 'median', 'std': 'std'}\n",
    "to_export_dict = {'name': list(statistic_info)}\n",
    "for column_name, stat_key in export_columns.items():\n",
    "    to_export_dict[column_name] = [detail[stat_key] for detail in statistic_info.values()]\n",
    "to_export_df = pd.DataFrame.from_dict(to_export_dict)\n",
    "# to_export_df.to_csv('statistic_info.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# normalize data: z-score each listed feature with the precomputed mean/std\n",
    "def normalize_data(df, features, statistic_info):\n",
    "    \"\"\"Return the id/label columns plus z-scored `features`; every other column is dropped.\"\"\"\n",
    "    def _zscore(column):\n",
    "        info = statistic_info[column.name]\n",
    "        # 1e-12 keeps a zero standard deviation from dividing by zero\n",
    "        return (column - info['mean']) / (info['std'] + 1e-12)\n",
    "\n",
    "    id_and_label_columns = df[['PATIENT_ID', 'gender', 'RE_DATE', 'outcome', 'LOS']]\n",
    "    return pd.concat([id_and_label_columns, df[features].apply(_zscore)], axis=1)\n",
    "\n",
    "# gender doesn't need to be normalized\n",
    "df_train = normalize_data(df_train, ['age'] + labtest_features, statistic_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# filter outliers: applied after normalization, so `bar` is in standard deviations for z-scored columns\n",
    "def filter_data(df, features, bar=3):\n",
    "    \"\"\"Set values with |value| > bar to NaN in each listed column; modifies df and returns it.\"\"\"\n",
    "    for column in features:\n",
    "        is_outlier = df[column].abs() > bar\n",
    "        df[column] = df[column].where(~is_outlier)\n",
    "    return df\n",
    "df_train = filter_data(df_train, demographic_features + labtest_features, bar=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# drop rows in which every lab test feature is NaN\n",
    "df_train = df_train.dropna(how='all', subset=labtest_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate data statistics after preprocessing steps (before imputation)\n",
    "\n",
    "# Step 1: reverse z-score normalization operation\n",
    "def reverse_normalize_data(df, features, statistic_info):\n",
    "    \"\"\"Map z-scored `features` back to their original scale (inverse of the z-score step).\n",
    "\n",
    "    Only the id/label columns plus `features` are returned.\n",
    "    \"\"\"\n",
    "    df_features = df[features]\n",
    "    df_features = df_features.apply(lambda x: x * (statistic_info[x.name]['std']+1e-12) + statistic_info[x.name]['mean'])\n",
    "    df = pd.concat([df[['PATIENT_ID', 'gender', 'RE_DATE', 'outcome', 'LOS']], df_features], axis=1)\n",
    "    return df\n",
    "df_reverse = reverse_normalize_data(df_train, ['age'] + labtest_features, statistic_info) # gender isn't normalized\n",
    "\n",
    "# statistics[feature][case] -> formatted summary string; case is 'overall', 'alive', 'dead' or 'missing_rate'\n",
    "statistics = {f: {} for f in demographic_features + labtest_features}\n",
    "\n",
    "def calculate_quantile_statistic_info(df, features, case):\n",
    "    \"\"\"Store a summary string for every feature in the global `statistics[feature][case]`.\n",
    "\n",
    "    gender -> percentage of rows with gender 1 (male) among rows with gender 0 or 1;\n",
    "    other features -> 'median [Q1, Q3]', ignoring NaNs.\n",
    "    \"\"\"\n",
    "    for e in features:\n",
    "        if e == 'gender':\n",
    "            unique, count = np.unique(df[e], return_counts=True)\n",
    "            data_count = dict(zip(unique, count)) # key = 1 male, 0 female\n",
    "            print(data_count)\n",
    "            # .get(): a subgroup (e.g. dead patients) may contain no patient of one gender\n",
    "            n_male = data_count.get(1.0, 0)\n",
    "            n_female = data_count.get(0.0, 0)\n",
    "            n_total = n_male + n_female\n",
    "            male_percentage = n_male*100/n_total if n_total > 0 else float('nan')\n",
    "            statistics[e][case] = f\"{male_percentage:.2f}% Male\"\n",
    "            print(statistics[e][case])\n",
    "        else:\n",
    "            lo = round(np.nanpercentile(df[e], 25), 2)\n",
    "            mi = round(np.nanpercentile(df[e], 50), 2)\n",
    "            hi = round(np.nanpercentile(df[e], 75), 2)\n",
    "            statistics[e][case] = f\"{mi:.2f} [{lo:.2f}, {hi:.2f}]\"\n",
    "\n",
    "def calculate_missing_rate(df, features, case='missing_rate'):\n",
    "    \"\"\"Store each feature's percentage of NaN rows in df as statistics[feature][case], e.g. '12.34%'.\"\"\"\n",
    "    for e in features:\n",
    "        missing_rate = round(float(df[e].isnull().sum()*100/df[e].shape[0]), 2)\n",
    "        statistics[e][case] = f\"{missing_rate:.2f}%\"\n",
    "\n",
    "# numeric_only=True: RE_DATE is a string column, and pandas>=2.0 raises on non-numeric columns in mean()\n",
    "tmp_groupby_pid = df_reverse.groupby(['PATIENT_ID'], dropna=True, as_index = False).mean(numeric_only=True)\n",
    "\n",
    "# demographic statistics are patient-wise, lab test statistics are record-wise\n",
    "calculate_quantile_statistic_info(tmp_groupby_pid, demographic_features, 'overall')\n",
    "calculate_quantile_statistic_info(tmp_groupby_pid[tmp_groupby_pid['outcome']==0], demographic_features, 'alive')\n",
    "calculate_quantile_statistic_info(tmp_groupby_pid[tmp_groupby_pid['outcome']==1], demographic_features, 'dead')\n",
    "\n",
    "calculate_quantile_statistic_info(df_reverse, labtest_features, 'overall')\n",
    "calculate_quantile_statistic_info(df_reverse[df_reverse['outcome']==0], labtest_features, 'alive')\n",
    "calculate_quantile_statistic_info(df_reverse[df_reverse['outcome']==1], labtest_features, 'dead')\n",
    "\n",
    "calculate_missing_rate(df_reverse, demographic_features+labtest_features, 'missing_rate')\n",
    "\n",
    "export_quantile_statistics = {'Characteristics':[], 'Overall':[], 'Alive':[], 'Dead':[], 'Missing Rate':[]}\n",
    "for f in demographic_features+labtest_features:\n",
    "    export_quantile_statistics['Characteristics'].append(f)\n",
    "    export_quantile_statistics['Overall'].append(statistics[f]['overall'])\n",
    "    export_quantile_statistics['Alive'].append(statistics[f]['alive'])\n",
    "    export_quantile_statistics['Dead'].append(statistics[f]['dead'])\n",
    "    export_quantile_statistics['Missing Rate'].append(statistics[f]['missing_rate'])\n",
    "\n",
    "# pd.DataFrame.from_dict(export_quantile_statistics).to_csv('statistics.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_data_existing_length(data):\n",
    "    \"\"\"Count the entries of `data` that are not missing (pd.isna is False).\"\"\"\n",
    "    return sum(1 for item in data if not pd.isna(item))\n",
    "\n",
    "# elements in data are sorted in time ascending order\n",
    "def fill_missing_value(data, to_fill_value=0):\n",
    "    \"\"\"Impute the missing entries of a time-ordered sequence in place and return it.\n",
    "\n",
    "    Entries before the first observed value get `to_fill_value`; every later gap\n",
    "    carries the most recent value forward. An all-missing sequence becomes all\n",
    "    `to_fill_value`; a fully observed one is left untouched.\n",
    "    \"\"\"\n",
    "    carry = to_fill_value\n",
    "    for position in range(len(data)):\n",
    "        if pd.isna(data[position]):\n",
    "            data[position] = carry\n",
    "        else:\n",
    "            carry = data[position]\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fill missing data using our strategy and convert to time series records\n",
    "grouped = df_train.groupby('PATIENT_ID')\n",
    "\n",
    "all_x_demographic = []\n",
    "all_x_labtest = []\n",
    "all_y = []\n",
    "all_missing_mask = []\n",
    "\n",
    "for name, group in grouped:\n",
    "    sorted_group = group.sort_values(by=['RE_DATE'], ascending=True)\n",
    "    patient_demographic = []\n",
    "    patient_labtest = []\n",
    "    patient_y = []\n",
    "\n",
    "    # record observed (1) / missing (0) BEFORE imputing: computed after the fill below,\n",
    "    # every imputed value would be finite and the mask would be all ones\n",
    "    all_missing_mask.append((np.isfinite(sorted_group[demographic_features+labtest_features].to_numpy())).astype(int))\n",
    "\n",
    "    for f in demographic_features+labtest_features:\n",
    "        # take median patient as the default to-fill missing value, expressed in z-score units\n",
    "        # NOTE(review): gender is not passed to normalize_data, yet it gets a z-scored fill value - confirm intent\n",
    "        to_fill_value = (statistic_info[f]['median'] - statistic_info[f]['mean'])/(statistic_info[f]['std']+1e-12)\n",
    "        # assign the filled column back explicitly; mutating `sorted_group[f].values` in place only\n",
    "        # works while it is a writable view of the frame (it is read-only under pandas Copy-on-Write)\n",
    "        sorted_group[f] = fill_missing_value(sorted_group[f].to_numpy(dtype=float, copy=True), to_fill_value)\n",
    "\n",
    "    # one [outcome, LOS] target and one feature vector per visit, in time order\n",
    "    for _, v in sorted_group.iterrows():\n",
    "        patient_y.append([v['outcome'], v['LOS']])\n",
    "        demo = []\n",
    "        lab = []\n",
    "        for f in demographic_features:\n",
    "            demo.append(v[f])\n",
    "        for f in labtest_features:\n",
    "            lab.append(v[f])\n",
    "        patient_labtest.append(lab)\n",
    "        patient_demographic.append(demo)\n",
    "    all_y.append(patient_y)\n",
    "    # demographics are taken from the patient's last visit\n",
    "    all_x_demographic.append(patient_demographic[-1])\n",
    "    all_x_labtest.append(patient_labtest)\n",
    "\n",
    "# all_x_demographic (2 dim, record each patients' demographic features)\n",
    "# all_x_labtest (3 dim, record each patients' lab test features)\n",
    "# all_y (3 dim, patients' outcome/los of all visits)\n",
    "# all_missing_mask (per patient: visits x features, 1 = observed, 0 = missing before imputation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pad per-patient visit sequences to the longest one and stack them into batch-first tensors;\n",
    "# tensors are built directly from the per-patient lists, because wrapping ragged nested lists in an\n",
    "# object np.array first can yield a multi-dimensional object array that torch cannot convert\n",
    "x_lab_length = torch.tensor([len(visits) for visits in all_x_labtest], dtype=torch.int)\n",
    "max_length = int(x_lab_length.max())\n",
    "# pad lab test sequence to the same shape\n",
    "all_x_labtest = torch.nn.utils.rnn.pad_sequence([torch.tensor(visits) for visits in all_x_labtest], batch_first=True)\n",
    "\n",
    "all_x_demographic = torch.tensor(all_x_demographic)\n",
    "batch_size, demo_dim = all_x_demographic.shape\n",
    "# repeat each patient's demographic vector at every time step\n",
    "all_x_demographic = torch.reshape(all_x_demographic.repeat(1, max_length), (batch_size, max_length, demo_dim))\n",
    "# demographic tensor concat with lab test tensor along the feature axis\n",
    "all_x = torch.cat((all_x_demographic, all_x_labtest), 2)\n",
    "\n",
    "# pad [outcome/los] sequence as well\n",
    "all_y = torch.nn.utils.rnn.pad_sequence([torch.Tensor(visits_y) for visits_y in all_y], batch_first=True)\n",
    "\n",
    "all_missing_mask = torch.nn.utils.rnn.pad_sequence([torch.tensor(mask) for mask in all_missing_mask], batch_first=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save pickle format dataset (export torch tensor)\n",
    "exports = {\n",
    "    './processed_data/x.pkl': all_x,\n",
    "    './processed_data/y.pkl': all_y,\n",
    "    './processed_data/visits_length.pkl': x_lab_length,\n",
    "    './processed_data/missing_mask.pkl': all_missing_mask,\n",
    "}\n",
    "for path, obj in exports.items():\n",
    "    pd.to_pickle(obj, path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate patients' outcome statistics (patients-wise): each patient's outcome is read from its first visit\n",
    "y_outcome = all_y[:, :, 0]\n",
    "outcome_list = np.array([y_outcome[i][0].item() for i in range(len(x_lab_length))])\n",
    "print(len(outcome_list))\n",
    "unique, count = np.unique(outcome_list, return_counts=True)\n",
    "data_count = dict(zip(unique, count))\n",
    "print(data_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate patients' outcome statistics (records-wise): one outcome per real (unpadded) visit\n",
    "y_outcome = all_y[:, :, 0]\n",
    "outcome_records_list = []\n",
    "for patient_idx, n_visits in enumerate(x_lab_length.tolist()):\n",
    "    outcome_records_list.extend(y_outcome[patient_idx, :n_visits].tolist())\n",
    "outcome_records_list = np.array(outcome_records_list)\n",
    "print(len(outcome_records_list))\n",
    "unique, count = np.unique(outcome_records_list, return_counts=True)\n",
    "data_count = dict(zip(unique, count))\n",
    "print(data_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate patients' mean los and 95% percentile los, using each patient's first-visit LOS\n",
    "y_los = all_y[:, :, 1]\n",
    "los_list = np.array([y_los[i][0].item() for i in range(len(x_lab_length))])\n",
    "# half of the mean and half of the median LOS, then the 95th percentile\n",
    "print(los_list.mean() * 0.5)\n",
    "print(np.median(los_list) * 0.5)\n",
    "print(np.percentile(los_list, 95))\n",
    "\n",
    "print('median:', np.median(los_list))\n",
    "print('Q1:', np.percentile(los_list, 25))\n",
    "print('Q3:', np.percentile(los_list, 75))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split the per-patient LOS by outcome (0 = alive, 1 = dead)\n",
    "los_alive_list = los_list[outcome_list == 0]\n",
    "los_dead_list = los_list[outcome_list == 1]\n",
    "print(len(los_alive_list))\n",
    "print(len(los_dead_list))\n",
    "\n",
    "for title, values in (('[Alive]', los_alive_list), ('[Dead]', los_dead_list)):\n",
    "    print(title)\n",
    "    print('median:', np.median(values))\n",
    "    print('Q1:', np.percentile(values, 25))\n",
    "    print('Q3:', np.percentile(values, 75))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# per-patient (first-visit) LOS arrays, grouped for optional export\n",
    "tjh_los_statistics = dict(overall=los_list, alive=los_alive_list, dead=los_dead_list)\n",
    "# pd.to_pickle(tjh_los_statistics, 'tjh_los_statistics.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate visits length Median [Q1, Q3]\n",
    "visits_list = np.array(x_lab_length)\n",
    "visits_alive_list = np.array([n_visits for n_visits, outcome in zip(x_lab_length, outcome_list) if outcome == 0])\n",
    "visits_dead_list = np.array([n_visits for n_visits, outcome in zip(x_lab_length, outcome_list) if outcome == 1])\n",
    "print(len(visits_alive_list))\n",
    "print(len(visits_dead_list))\n",
    "\n",
    "for title, values in (('[Total]', visits_list), ('[Alive]', visits_alive_list), ('[Dead]', visits_dead_list)):\n",
    "    print(title)\n",
    "    print('median:', np.median(values))\n",
    "    print('Q1:', np.percentile(values, 25))\n",
    "    print('Q3:', np.percentile(values, 75))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Length-of-stay interval (overall/alive/dead): drop in the LOS label between consecutive visits\n",
    "los_interval_list = []\n",
    "los_interval_alive_list = []\n",
    "los_interval_dead_list = []\n",
    "\n",
    "y_los = all_y[:, :, 1]\n",
    "for i, cur_visits_len in enumerate(x_lab_length):\n",
    "    # single-visit patients contribute nothing: the inner range is empty\n",
    "    outcome_bucket = los_interval_alive_list if outcome_list[i] == 0 else los_interval_dead_list\n",
    "    for j in range(1, cur_visits_len):\n",
    "        interval = y_los[i][j-1] - y_los[i][j]\n",
    "        los_interval_list.append(interval)\n",
    "        outcome_bucket.append(interval)\n",
    "\n",
    "los_interval_list = np.array(los_interval_list)\n",
    "los_interval_alive_list = np.array(los_interval_alive_list)\n",
    "los_interval_dead_list = np.array(los_interval_dead_list)\n",
    "\n",
    "output = {\n",
    "    'overall': los_interval_list,\n",
    "    'alive': los_interval_alive_list,\n",
    "    'dead': los_interval_dead_list,\n",
    "}\n",
    "# pd.to_pickle(output, 'tjh_los_interval_list.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tuple(len(intervals) for intervals in (los_interval_list, los_interval_alive_list, los_interval_dead_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_nan(x):\n",
    "    \"\"\"Print whether tensor `x` contains any NaN element.\n",
    "\n",
    "    Elements are tested directly with torch.isnan rather than testing their sum: a sum can\n",
    "    become NaN without any NaN input (e.g. inf + -inf), and no CPU copy is needed.\n",
    "    \"\"\"\n",
    "    if torch.isnan(x).any().item():\n",
    "        print(\"some values from input are nan\")\n",
    "    else:\n",
    "        print(\"no nan\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Draw Charts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import PercentFormatter\n",
    "import matplotlib.font_manager as font_manager\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8 removed the old names,\n",
    "# so use the new name when available and fall back to the old one on older matplotlib\n",
    "whitegrid_style = 'seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available else 'seaborn-whitegrid'\n",
    "plt.style.use(whitegrid_style)\n",
    "color = 'cornflowerblue'\n",
    "ec = 'None'\n",
    "alpha=0.5\n",
    "alive_color = 'olivedrab'\n",
    "dead_color = 'orchid'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# raw TJH table used for the distribution charts below\n",
    "tjh_raw_csv_path = './tjh_data_raw.csv'\n",
    "tj_overall = pd.read_csv(tjh_raw_csv_path)\n",
    "tj_overall.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split the raw records by outcome (0 = alive, 1 = dead)\n",
    "tj_alive, tj_dead = (tj_overall.loc[tj_overall['outcome'] == label] for label in (0, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep describe()'s row labels (count/mean/std/min/quartiles/max): with index=False the exported rows are unlabeled\n",
    "tj_overall.describe().to_csv('tjh_describe.csv', index_label='statistic')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "limit = 0.05\n",
    "\n",
    "from matplotlib.ticker import PercentFormatter\n",
    "import matplotlib.font_manager as font_manager\n",
    "plt.style.use('seaborn-whitegrid')\n",
    "color = 'cornflowerblue'\n",
    "ec = 'None'\n",
    "alive_color = 'olivedrab'\n",
    "# dead_color = 'mediumslateblue'\n",
    "dead_color = 'orchid'\n",
    "alpha=0.5\n",
    "\n",
    "csfont = {'fontname':'Times New Roman', 'fontsize': 18}\n",
    "font = 'Times New Roman'\n",
    "fig=plt.figure(figsize=(16,12), dpi= 500, facecolor='w', edgecolor='k')\n",
    "\n",
    "idx = 1\n",
    "\n",
    "key = 'age'\n",
    "low = tj_overall[key].quantile(limit)\n",
    "high = tj_overall[key].quantile(1 - limit)\n",
    "tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
    "tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
    "tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
    "ax = plt.subplot(4, 4, idx)\n",
    "ax.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha, label='overall')\n",
    "plt.xlabel('Age',**csfont)\n",
    "plt.ylabel('Percentage',**csfont)\n",
    "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
    "ax.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2, label='alive')\n",
    "ax.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2, label='dead')\n",
    "plt.xticks(**csfont)\n",
    "plt.yticks(**csfont)\n",
    "idx += 1\n",
    "\n",
    "# The 15 lab-feature panels that follow the age panel above, each built exactly like it.\n",
    "# Each entry is (column name in tj_overall / tj_alive / tj_dead, x-axis label for the panel).\n",
    "lab_panels = [\n",
    "    ('White blood cell count', 'White blood cell count'),\n",
    "    ('Red blood cell count', 'Red blood cell count'),\n",
    "    ('neutrophils(%)', 'neutrophils %'),\n",
    "    ('(%)lymphocyte', 'lymphocyte %'),\n",
    "    ('monocytes(%)', 'monocytes(%)'),\n",
    "    ('Platelet count', 'Platelet count'),\n",
    "    ('lymphocyte count', 'Lymphocyte count'),\n",
    "    ('hemoglobin', 'Hemoglobin'),\n",
    "    ('calcium', 'Calcium'),\n",
    "    ('hematocrit', 'Hematocrit'),\n",
    "    ('albumin', 'Albumin'),\n",
    "    ('neutrophils count', 'Neutrophils count'),\n",
    "    ('monocytes count', 'Monocytes count'),\n",
    "    ('basophil count(#)', 'Basophil count'),\n",
    "    ('eosinophils(%)', 'Eosinophils %'),\n",
    "]\n",
    "for key, xlabel in lab_panels:\n",
    "    # Keep only values inside the [limit, 1 - limit] quantiles of the overall cohort\n",
    "    # (presumably to trim outliers); the same bounds are applied to all three groups. NaNs are dropped by between().\n",
    "    low = tj_overall[key].quantile(limit)\n",
    "    high = tj_overall[key].quantile(1 - limit)\n",
    "    tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
    "    tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
    "    tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
    "    # Do not rebind `ax` here: the figure legend at the end of this cell reads it from the age panel.\n",
    "    plt.subplot(4, 4, idx)\n",
    "    # Weights of 1/n make each group's bars fractions of that group, shown as percentages.\n",
    "    plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
    "    plt.xlabel(xlabel, **csfont)\n",
    "    plt.ylabel('Percentage', **csfont)\n",
    "    plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
    "    # Alive and dead groups are drawn as unfilled outlines over the filled overall histogram.\n",
    "    # NOTE(review): bins=20 is computed per group, so outline bins need not line up with the overall bars.\n",
    "    plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
    "    plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
    "    plt.xticks(**csfont)\n",
    "    plt.yticks(**csfont)\n",
    "    idx += 1\n",
    "\n",
    "# Figure-level legend built from the age panel's labelled artists ('overall', 'alive', 'dead').\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "# Font family and size both come from `prop`; Matplotlib ignores `fontsize` whenever `prop` is given, so it is omitted.\n",
    "plt.figlegend(handles, labels, loc='upper center', ncol=5, bbox_to_anchor=(0.5, 1.05), prop=font_manager.FontProperties(family='Times New Roman',\n",
    "                                   style='normal', size=18))\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.11 ('python37')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  },
  "vscode": {
   "interpreter": {
    "hash": "a10b846bdc9fc41ee38835cbc29d70b69dd5fd54e1341ea2c410a7804a50447a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}