1088 lines (1087 with data), 48.9 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"# Import necessary packages\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read raw data\n",
"df_train: pd.DataFrame = pd.read_excel('./raw_data/time_series_375_prerpocess_en.xlsx')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Steps:\n",
"\n",
"- fill `patient_id`\n",
"- only reserve y-m-d for `RE_DATE` column\n",
"- merge lab tests of the same (patient_id, date)\n",
"- calculate and save features' statistics information (demographic and lab test data are calculated separately)\n",
"- normalize data\n",
"- feature selection\n",
"- fill missing data (our filling strategy will be described below)\n",
"- combine above data to time series data (one patient one record)\n",
"- export to python pickle file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fill `patient_id` rows\n",
"df_train['PATIENT_ID'].fillna(method='ffill', inplace=True)\n",
"\n",
"# gender transformation: 1--male, 0--female\n",
"df_train['gender'].replace(2, 0, inplace=True)\n",
"\n",
"# only reserve y-m-d for `RE_DATE` and `Discharge time` columns\n",
"df_train['RE_DATE'] = df_train['RE_DATE'].dt.strftime('%Y-%m-%d')\n",
"df_train['Discharge time'] = df_train['Discharge time'].dt.strftime('%Y-%m-%d')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_train = df_train.dropna(subset = ['PATIENT_ID', 'RE_DATE', 'Discharge time'], how='any')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# calculate raw data's los interval\n",
"df_grouped = df_train.groupby('PATIENT_ID')\n",
"\n",
"los_interval_list = []\n",
"los_interval_alive_list = []\n",
"los_interval_dead_list = []\n",
"\n",
"for name, group in df_grouped:\n",
" sorted_group = group.sort_values(by=['RE_DATE'], ascending=True)\n",
" # print(sorted_group['outcome'])\n",
" # print('---')\n",
" # print(type(sorted_group))\n",
" intervals = sorted_group['RE_DATE'].tolist()\n",
" outcome = sorted_group['outcome'].tolist()[0]\n",
" cur_visits_len = len(intervals)\n",
" # print(cur_visits_len)\n",
" if cur_visits_len == 1:\n",
" continue\n",
" for i in range(1, len(intervals)):\n",
" los_interval_list.append((pd.to_datetime(intervals[i])-pd.to_datetime(intervals[i-1])).days)\n",
" if outcome == 0:\n",
" los_interval_alive_list.append((pd.to_datetime(intervals[i])-pd.to_datetime(intervals[i-1])).days)\n",
" else:\n",
" los_interval_dead_list.append((pd.to_datetime(intervals[i])-pd.to_datetime(intervals[i-1])).days)\n",
"\n",
"los_interval_list = np.array(los_interval_list)\n",
"los_interval_alive_list = np.array(los_interval_alive_list)\n",
"los_interval_dead_list = np.array(los_interval_dead_list)\n",
"\n",
"output = {\n",
" 'overall': los_interval_list,\n",
" 'alive': los_interval_alive_list,\n",
" 'dead': los_interval_dead_list,\n",
"}\n",
"# pd.to_pickle(output, 'raw_tjh_los_interval_list.pkl')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# we have 2 types of prediction tasks: 1) predict mortality outcome, 2) length of stay\n",
"\n",
"# below are all lab test features\n",
"labtest_features_str = \"\"\"\n",
"Hypersensitive cardiac troponinI\themoglobin\tSerum chloride\tProthrombin time\tprocalcitonin\teosinophils(%)\tInterleukin 2 receptor\tAlkaline phosphatase\talbumin\tbasophil(%)\tInterleukin 10\tTotal bilirubin\tPlatelet count\tmonocytes(%)\tantithrombin\tInterleukin 8\tindirect bilirubin\tRed blood cell distribution width \tneutrophils(%)\ttotal protein\tQuantification of Treponema pallidum antibodies\tProthrombin activity\tHBsAg\tmean corpuscular volume\thematocrit\tWhite blood cell count\tTumor necrosis factorα\tmean corpuscular hemoglobin concentration\tfibrinogen\tInterleukin 1β\tUrea\tlymphocyte count\tPH value\tRed blood cell count\tEosinophil count\tCorrected calcium\tSerum potassium\tglucose\tneutrophils count\tDirect bilirubin\tMean platelet volume\tferritin\tRBC distribution width SD\tThrombin time\t(%)lymphocyte\tHCV antibody quantification\tD-D dimer\tTotal cholesterol\taspartate aminotransferase\tUric acid\tHCO3-\tcalcium\tAmino-terminal brain natriuretic peptide precursor(NT-proBNP)\tLactate dehydrogenase\tplatelet large cell ratio \tInterleukin 6\tFibrin degradation products\tmonocytes count\tPLT distribution width\tglobulin\tγ-glutamyl transpeptidase\tInternational standard ratio\tbasophil count(#)\t2019-nCoV nucleic acid detection\tmean corpuscular hemoglobin \tActivation of partial thromboplastin time\tHypersensitive c-reactive protein\tHIV antibody quantification\tserum sodium\tthrombocytocrit\tESR\tglutamic-pyruvic transaminase\teGFR\tcreatinine\n",
"\"\"\"\n",
"\n",
"# below are 2 demographic features\n",
"demographic_features_str = \"\"\"\n",
"age\tgender\n",
"\"\"\"\n",
"\n",
"labtest_features = [f for f in labtest_features_str.strip().split('\\t')]\n",
"demographic_features = [f for f in demographic_features_str.strip().split('\\t')]\n",
"target_features = ['outcome', 'LOS']\n",
"\n",
"# from our observation, `2019-nCoV nucleic acid detection` feature (in lab test) are all -1 value\n",
"# so we remove this feature here\n",
"labtest_features.remove('2019-nCoV nucleic acid detection')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# if some values are negative, set it as Null\n",
"df_train[df_train[demographic_features + labtest_features]<0] = np.nan"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# merge lab tests of the same (patient_id, date)\n",
"df_train = df_train.groupby(['PATIENT_ID', 'RE_DATE', 'Discharge time'], dropna=True, as_index = False).mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# calculate length-of-stay lable\n",
"df_train['LOS'] = (pd.to_datetime(df_train['Discharge time']) - pd.to_datetime(df_train['RE_DATE'])).dt.days"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# if los values are negative, set it as 0\n",
"df_train['LOS'] = df_train['LOS'].clip(lower=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# save features' statistics information\n",
"\n",
"def calculate_statistic_info(df, features):\n",
" \"\"\"all values calculated\"\"\"\n",
" statistic_info = {}\n",
" len_df = len(df)\n",
" for _, e in enumerate(features):\n",
" h = {}\n",
" h['count'] = int(df[e].count())\n",
" h['missing'] = str(round(float((100-df[e].count()*100/len_df)),3))+\"%\"\n",
" h['mean'] = float(df[e].mean())\n",
" h['max'] = float(df[e].max())\n",
" h['min'] = float(df[e].min())\n",
" h['median'] = float(df[e].median())\n",
" h['std'] = float(df[e].std())\n",
" statistic_info[e] = h\n",
" return statistic_info\n",
"\n",
"def calculate_middle_part_statistic_info(df, features):\n",
" \"\"\"calculate 5% ~ 95% percentile data\"\"\"\n",
" statistic_info = {}\n",
" len_df = len(df)\n",
" # calculate 5% and 95% percentile of dataframe\n",
" middle_part_df_info = df.quantile([.05, .95])\n",
"\n",
" for _, e in enumerate(features):\n",
" low_value = middle_part_df_info[e][.05]\n",
" high_value = middle_part_df_info[e][.95]\n",
" middle_part_df_element = df.loc[(df[e] >= low_value) & (df[e] <= high_value)][e]\n",
" h = {}\n",
" h['count'] = int(middle_part_df_element.count())\n",
" h['missing'] = str(round(float((100-middle_part_df_element.count()*100/len_df)),3))+\"%\"\n",
" h['mean'] = float(middle_part_df_element.mean())\n",
" h['max'] = float(middle_part_df_element.max())\n",
" h['min'] = float(middle_part_df_element.min())\n",
" h['median'] = float(middle_part_df_element.median())\n",
" h['std'] = float(middle_part_df_element.std())\n",
" statistic_info[e] = h\n",
" return statistic_info\n",
"\n",
"# labtest_statistic_info = calculate_statistic_info(df_train, labtest_features)\n",
"\n",
"\n",
"# group by patient_id, then calculate lab test/demographic features' statistics information\n",
"groupby_patientid_df = df_train.groupby(['PATIENT_ID'], dropna=True, as_index = False).mean()\n",
"\n",
"\n",
"# calculate statistic info (all values calculated)\n",
"labtest_patientwise_statistic_info = calculate_statistic_info(groupby_patientid_df, labtest_features)\n",
"demographic_statistic_info = calculate_statistic_info(groupby_patientid_df, demographic_features) # it's also patient-wise\n",
"\n",
"# calculate statistic info (5% ~ 95% only)\n",
"demographic_statistic_info_2 = calculate_middle_part_statistic_info(groupby_patientid_df, demographic_features) \n",
"labtest_patientwise_statistic_info_2 = calculate_middle_part_statistic_info(groupby_patientid_df, labtest_features) \n",
"\n",
"# take 2 statistics information's union\n",
"statistic_info = labtest_patientwise_statistic_info_2 | demographic_statistic_info_2\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# observe features, export to csv file [optional]\n",
"to_export_dict = {'name': [], 'missing_rate': [], 'count': [], 'mean': [], 'max': [], 'min': [], 'median': [], 'std': []}\n",
"for key in statistic_info:\n",
" detail = statistic_info[key]\n",
" to_export_dict['name'].append(key)\n",
" to_export_dict['count'].append(detail['count'])\n",
" to_export_dict['missing_rate'].append(detail['missing'])\n",
" to_export_dict['mean'].append(detail['mean'])\n",
" to_export_dict['max'].append(detail['max'])\n",
" to_export_dict['min'].append(detail['min'])\n",
" to_export_dict['median'].append(detail['median'])\n",
" to_export_dict['std'].append(detail['std'])\n",
"to_export_df = pd.DataFrame.from_dict(to_export_dict)\n",
"# to_export_df.to_csv('statistic_info.csv')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# normalize data\n",
"def normalize_data(df, features, statistic_info):\n",
" \n",
" df_features = df[features]\n",
" df_features = df_features.apply(lambda x: (x - statistic_info[x.name]['mean']) / (statistic_info[x.name]['std']+1e-12))\n",
" df = pd.concat([df[['PATIENT_ID', 'gender', 'RE_DATE', 'outcome', 'LOS']], df_features], axis=1)\n",
" return df\n",
"df_train = normalize_data(df_train, ['age'] + labtest_features, statistic_info) # gender don't need to be normalized"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# filter outliers\n",
"def filter_data(df, features, bar=3):\n",
" for f in features:\n",
" df[f] = df[f].mask(df[f].abs().gt(bar))\n",
" return df\n",
"df_train = filter_data(df_train, demographic_features + labtest_features, bar=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# drop rows if all labtest_features are recorded nan\n",
"df_train = df_train.dropna(subset = labtest_features, how='all')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Calculate data statistics after preprocessing steps (before imputation)\n",
"\n",
"# Step 1: reverse z-score normalization operation\n",
"df_reverse = df_train\n",
"# reverse normalize data\n",
"def reverse_normalize_data(df, features, statistic_info):\n",
" df_features = df[features]\n",
" df_features = df_features.apply(lambda x: x * (statistic_info[x.name]['std']+1e-12) + statistic_info[x.name]['mean'])\n",
" df = pd.concat([df[['PATIENT_ID', 'gender', 'RE_DATE', 'outcome', 'LOS']], df_features], axis=1)\n",
" return df\n",
"df_reverse = reverse_normalize_data(df_reverse, ['age'] + labtest_features, statistic_info) # gender don't need to be normalized\n",
"\n",
"statistics = {}\n",
"\n",
"for f in demographic_features+labtest_features:\n",
" statistics[f]={}\n",
"\n",
"def calculate_quantile_statistic_info(df, features, case):\n",
" \"\"\"all values calculated\"\"\"\n",
" for _, e in enumerate(features):\n",
" # print(e, lo, mi, hi)\n",
" if e == 'gender':\n",
" unique, count=np.unique(df[e],return_counts=True)\n",
" data_count=dict(zip(unique,count)) # key = 1 male, 0 female\n",
" print(data_count)\n",
" male_percentage = data_count[1.0]*100/(data_count[1.0]+data_count[0.0])\n",
" statistics[e][case] = f\"{male_percentage:.2f}% Male\"\n",
" print(statistics[e][case])\n",
" else:\n",
" lo = round(np.nanpercentile(df[e], 25), 2)\n",
" mi = round(np.nanpercentile(df[e], 50), 2)\n",
" hi = round(np.nanpercentile(df[e], 75), 2)\n",
" statistics[e][case] = f\"{mi:.2f} [{lo:.2f}, {hi:.2f}]\"\n",
"\n",
"def calculate_missing_rate(df, features, case='missing_rate'):\n",
" for _, e in enumerate(features):\n",
" missing_rate = round(float(df[e].isnull().sum()*100/df[e].shape[0]), 2)\n",
" statistics[e][case] = f\"{missing_rate:.2f}%\"\n",
"\n",
"tmp_groupby_pid = df_reverse.groupby(['PATIENT_ID'], dropna=True, as_index = False).mean()\n",
"\n",
"calculate_quantile_statistic_info(tmp_groupby_pid, demographic_features, 'overall')\n",
"calculate_quantile_statistic_info(tmp_groupby_pid[tmp_groupby_pid['outcome']==0], demographic_features, 'alive')\n",
"calculate_quantile_statistic_info(tmp_groupby_pid[tmp_groupby_pid['outcome']==1], demographic_features, 'dead')\n",
"\n",
"calculate_quantile_statistic_info(df_reverse, labtest_features, 'overall')\n",
"calculate_quantile_statistic_info(df_reverse[df_reverse['outcome']==0], labtest_features, 'alive')\n",
"calculate_quantile_statistic_info(df_reverse[df_reverse['outcome']==1], labtest_features, 'dead')\n",
"\n",
"calculate_missing_rate(df_reverse, demographic_features+labtest_features, 'missing_rate')\n",
"\n",
"export_quantile_statistics = {'Characteristics':[], 'Overall':[], 'Alive':[], 'Dead':[], 'Missing Rate':[]}\n",
"for f in demographic_features+labtest_features:\n",
" export_quantile_statistics['Characteristics'].append(f)\n",
" export_quantile_statistics['Overall'].append(statistics[f]['overall'])\n",
" export_quantile_statistics['Alive'].append(statistics[f]['alive'])\n",
" export_quantile_statistics['Dead'].append(statistics[f]['dead'])\n",
" export_quantile_statistics['Missing Rate'].append(statistics[f]['missing_rate'])\n",
"\n",
"# pd.DataFrame.from_dict(export_quantile_statistics).to_csv('statistics.csv')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def calculate_data_existing_length(data):\n",
" res = 0\n",
" for i in data:\n",
" if not pd.isna(i):\n",
" res += 1\n",
" return res\n",
"# elements in data are sorted in time ascending order\n",
"def fill_missing_value(data, to_fill_value=0):\n",
" data_len = len(data)\n",
" data_exist_len = calculate_data_existing_length(data)\n",
" if data_len == data_exist_len:\n",
" return data\n",
" elif data_exist_len == 0:\n",
" # data = [to_fill_value for _ in range(data_len)]\n",
" for i in range(data_len):\n",
" data[i] = to_fill_value\n",
" return data\n",
" if pd.isna(data[0]):\n",
" # find the first non-nan value's position\n",
" not_na_pos = 0\n",
" for i in range(data_len):\n",
" if not pd.isna(data[i]):\n",
" not_na_pos = i\n",
" break\n",
" # fill element before the first non-nan value with median\n",
" for i in range(not_na_pos):\n",
" data[i] = to_fill_value\n",
" # fill element after the first non-nan value\n",
" for i in range(1, data_len):\n",
" if pd.isna(data[i]):\n",
" data[i] = data[i-1]\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fill missing data using our strategy and convert to time series records\n",
"grouped = df_train.groupby('PATIENT_ID')\n",
"\n",
"all_x_demographic = []\n",
"all_x_labtest = []\n",
"all_y = []\n",
"all_missing_mask = []\n",
"\n",
"for name, group in grouped:\n",
" sorted_group = group.sort_values(by=['RE_DATE'], ascending=True)\n",
" patient_demographic = []\n",
" patient_labtest = []\n",
" patient_y = []\n",
" \n",
" for f in demographic_features+labtest_features:\n",
" to_fill_value = (statistic_info[f]['median'] - statistic_info[f]['mean'])/(statistic_info[f]['std']+1e-12)\n",
" # take median patient as the default to-fill missing value\n",
" # print(sorted_group[f].values)\n",
" fill_missing_value(sorted_group[f].values, to_fill_value)\n",
" # print(sorted_group[f].values)\n",
" # print('-----------')\n",
" all_missing_mask.append((np.isfinite(sorted_group[demographic_features+labtest_features].to_numpy())).astype(int))\n",
"\n",
" for _, v in sorted_group.iterrows():\n",
" patient_y.append([v['outcome'], v['LOS']])\n",
" demo = []\n",
" lab = []\n",
" for f in demographic_features:\n",
" demo.append(v[f])\n",
" for f in labtest_features:\n",
" lab.append(v[f])\n",
" patient_labtest.append(lab)\n",
" patient_demographic.append(demo)\n",
" all_y.append(patient_y)\n",
" all_x_demographic.append(patient_demographic[-1])\n",
" all_x_labtest.append(patient_labtest)\n",
"\n",
"# all_x_demographic (2 dim, record each patients' demographic features)\n",
"# all_x_labtest (3 dim, record each patients' lab test features)\n",
"# all_y (3 dim, patients' outcome/los of all visits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"all_x_labtest = np.array(all_x_labtest, dtype=object)\n",
"x_lab_length = [len(_) for _ in all_x_labtest]\n",
"x_lab_length = torch.tensor(x_lab_length, dtype=torch.int)\n",
"max_length = int(x_lab_length.max())\n",
"all_x_labtest = [torch.tensor(_) for _ in all_x_labtest]\n",
"# pad lab test sequence to the same shape\n",
"all_x_labtest = torch.nn.utils.rnn.pad_sequence((all_x_labtest), batch_first=True)\n",
"\n",
"all_x_demographic = torch.tensor(all_x_demographic)\n",
"batch_size, demo_dim = all_x_demographic.shape\n",
"# repeat demographic tensor\n",
"all_x_demographic = torch.reshape(all_x_demographic.repeat(1, max_length), (batch_size, max_length, demo_dim))\n",
"# demographic tensor concat with lab test tensor\n",
"all_x = torch.cat((all_x_demographic, all_x_labtest), 2)\n",
"\n",
"all_y = np.array(all_y, dtype=object)\n",
"all_y = [torch.Tensor(_) for _ in all_y]\n",
"# pad [outcome/los] sequence as well\n",
"all_y = torch.nn.utils.rnn.pad_sequence((all_y), batch_first=True)\n",
"\n",
"all_missing_mask = np.array(all_missing_mask, dtype=object)\n",
"all_missing_mask = [torch.tensor(_) for _ in all_missing_mask]\n",
"all_missing_mask = torch.nn.utils.rnn.pad_sequence((all_missing_mask), batch_first=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# save pickle format dataset (export torch tensor)\n",
"pd.to_pickle(all_x, f'./processed_data/x.pkl')\n",
"pd.to_pickle(all_y, f'./processed_data/y.pkl')\n",
"pd.to_pickle(x_lab_length, f'./processed_data/visits_length.pkl')\n",
"pd.to_pickle(all_missing_mask, f'./processed_data/missing_mask.pkl')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Calculate patients' outcome statistics (patients-wise)\n",
"outcome_list = []\n",
"y_outcome = all_y[:, :, 0]\n",
"indices = torch.arange(len(x_lab_length), dtype=torch.int64)\n",
"for i in indices:\n",
" outcome_list.append(y_outcome[i][0].item())\n",
"outcome_list = np.array(outcome_list)\n",
"print(len(outcome_list))\n",
"unique, count=np.unique(outcome_list,return_counts=True)\n",
"data_count=dict(zip(unique,count))\n",
"print(data_count)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Calculate patients' outcome statistics (records-wise)\n",
"outcome_records_list = []\n",
"y_outcome = all_y[:, :, 0]\n",
"indices = torch.arange(len(x_lab_length), dtype=torch.int64)\n",
"for i in indices:\n",
" outcome_records_list.extend(y_outcome[i][0:x_lab_length[i]].tolist())\n",
"outcome_records_list = np.array(outcome_records_list)\n",
"print(len(outcome_records_list))\n",
"unique, count=np.unique(outcome_records_list,return_counts=True)\n",
"data_count=dict(zip(unique,count))\n",
"print(data_count)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Calculate patients' mean los and 95% percentile los\n",
"los_list = []\n",
"y_los = all_y[:, :, 1]\n",
"indices = torch.arange(len(x_lab_length), dtype=torch.int64)\n",
"for i in indices:\n",
" # los_list.extend(y_los[i][: x_lab_length[i].long()].tolist())\n",
" los_list.append(y_los[i][0].item())\n",
"los_list = np.array(los_list)\n",
"print(los_list.mean() * 0.5)\n",
"print(np.median(los_list) * 0.5)\n",
"print(np.percentile(los_list, 95))\n",
"\n",
"print('median:', np.median(los_list))\n",
"print('Q1:', np.percentile(los_list, 25))\n",
"print('Q3:', np.percentile(los_list, 75))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"los_alive_list = np.array([los_list[i] for i in range(len(los_list)) if outcome_list[i] == 0])\n",
"los_dead_list = np.array([los_list[i] for i in range(len(los_list)) if outcome_list[i] == 1])\n",
"print(len(los_alive_list))\n",
"print(len(los_dead_list))\n",
"\n",
"print('[Alive]')\n",
"print('median:', np.median(los_alive_list))\n",
"print('Q1:', np.percentile(los_alive_list, 25))\n",
"print('Q3:', np.percentile(los_alive_list, 75))\n",
"\n",
"print('[Dead]')\n",
"print('median:', np.median(los_dead_list))\n",
"print('Q1:', np.percentile(los_dead_list, 25))\n",
"print('Q3:', np.percentile(los_dead_list, 75))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tjh_los_statistics = {\n",
" 'overall': los_list,\n",
" 'alive': los_alive_list,\n",
" 'dead': los_dead_list\n",
"}\n",
"# pd.to_pickle(tjh_los_statistics, 'tjh_los_statistics.pkl')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# calculate visits length Median [Q1, Q3]\n",
"visits_list = np.array(x_lab_length)\n",
"visits_alive_list = np.array([x_lab_length[i] for i in range(len(x_lab_length)) if outcome_list[i] == 0])\n",
"visits_dead_list = np.array([x_lab_length[i] for i in range(len(x_lab_length)) if outcome_list[i] == 1])\n",
"print(len(visits_alive_list))\n",
"print(len(visits_dead_list))\n",
"\n",
"print('[Total]')\n",
"print('median:', np.median(visits_list))\n",
"print('Q1:', np.percentile(visits_list, 25))\n",
"print('Q3:', np.percentile(visits_list, 75))\n",
"\n",
"print('[Alive]')\n",
"print('median:', np.median(visits_alive_list))\n",
"print('Q1:', np.percentile(visits_alive_list, 25))\n",
"print('Q3:', np.percentile(visits_alive_list, 75))\n",
"\n",
"print('[Dead]')\n",
"print('median:', np.median(visits_dead_list))\n",
"print('Q1:', np.percentile(visits_dead_list, 25))\n",
"print('Q3:', np.percentile(visits_dead_list, 75))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Length-of-stay interval (overall/alive/dead)\n",
"los_interval_list = []\n",
"los_interval_alive_list = []\n",
"los_interval_dead_list = []\n",
"\n",
"y_los = all_y[:, :, 1]\n",
"indices = torch.arange(len(x_lab_length), dtype=torch.int64)\n",
"for i in indices:\n",
" cur_visits_len = x_lab_length[i]\n",
" if cur_visits_len == 1:\n",
" continue\n",
" for j in range(1, cur_visits_len):\n",
" los_interval_list.append(y_los[i][j-1]-y_los[i][j])\n",
" if outcome_list[i] == 0:\n",
" los_interval_alive_list.append(y_los[i][j-1]-y_los[i][j])\n",
" else:\n",
" los_interval_dead_list.append(y_los[i][j-1]-y_los[i][j])\n",
"\n",
"los_interval_list = np.array(los_interval_list)\n",
"los_interval_alive_list = np.array(los_interval_alive_list)\n",
"los_interval_dead_list = np.array(los_interval_dead_list)\n",
"\n",
"output = {\n",
" 'overall': los_interval_list,\n",
" 'alive': los_interval_alive_list,\n",
" 'dead': los_interval_dead_list,\n",
"}\n",
"# pd.to_pickle(output, 'tjh_los_interval_list.pkl')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(los_interval_list), len(los_interval_alive_list), len(los_interval_dead_list)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def check_nan(x):\n",
" if np.isnan(np.sum(x.cpu().numpy())):\n",
" print(\"some values from input are nan\")\n",
" else:\n",
" print(\"no nan\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Draw Charts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from matplotlib.ticker import PercentFormatter\n",
"import matplotlib.font_manager as font_manager\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"plt.style.use('seaborn-whitegrid')\n",
"color = 'cornflowerblue'\n",
"ec = 'None'\n",
"alpha=0.5\n",
"alive_color = 'olivedrab'\n",
"dead_color = 'orchid'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tj_overall = pd.read_csv('./tjh_data_raw.csv')\n",
"tj_overall.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tj_alive = tj_overall.loc[tj_overall['outcome'] == 0]\n",
"tj_dead = tj_overall.loc[tj_overall['outcome'] == 1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tj_overall.describe().to_csv('tjh_describe.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"limit = 0.05\n",
"\n",
"from matplotlib.ticker import PercentFormatter\n",
"import matplotlib.font_manager as font_manager\n",
"plt.style.use('seaborn-whitegrid')\n",
"color = 'cornflowerblue'\n",
"ec = 'None'\n",
"alive_color = 'olivedrab'\n",
"# dead_color = 'mediumslateblue'\n",
"dead_color = 'orchid'\n",
"alpha=0.5\n",
"\n",
"csfont = {'fontname':'Times New Roman', 'fontsize': 18}\n",
"font = 'Times New Roman'\n",
"fig=plt.figure(figsize=(16,12), dpi= 500, facecolor='w', edgecolor='k')\n",
"\n",
"idx = 1\n",
"\n",
"key = 'age'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"ax = plt.subplot(4, 4, idx)\n",
"ax.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha, label='overall')\n",
"plt.xlabel('Age',**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"ax.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2, label='alive')\n",
"ax.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2, label='dead')\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'White blood cell count'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel(key,**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'Red blood cell count'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel(key,**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'neutrophils(%)'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel('neutrophils %',**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = '(%)lymphocyte'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel('lymphocyte %',**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'monocytes(%)'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel(key,**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'Platelet count'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel(key,**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'lymphocyte count'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel('Lymphocyte count',**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'hemoglobin'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel('Hemoglobin',**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'calcium'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel('Calcium',**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'hematocrit'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel('Hematocrit',**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'albumin'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel('Albumin',**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'neutrophils count'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel('Neutrophils count',**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'monocytes count'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel('Monocytes count',**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'basophil count(#)'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel('Basophil count',**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"key = 'eosinophils(%)'\n",
"low = tj_overall[key].quantile(limit)\n",
"high = tj_overall[key].quantile(1 - limit)\n",
"tj_AGE_overall = tj_overall[tj_overall[key].between(low, high)]\n",
"tj_AGE_dead = tj_dead[tj_dead[key].between(low, high)]\n",
"tj_AGE_alive = tj_alive[tj_alive[key].between(low, high)]\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_overall[key], bins=20, weights=np.ones(len(tj_AGE_overall[key])) / len(tj_AGE_overall[key]), color=color, ec=ec, alpha=alpha)\n",
"plt.xlabel('Eosinophils %',**csfont)\n",
"plt.ylabel('Percentage',**csfont)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.subplot(4, 4, idx)\n",
"plt.hist(tj_AGE_alive[key], bins=20, weights=np.ones(len(tj_AGE_alive[key])) / len(tj_AGE_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.subplot(4, 4, idx)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.hist(tj_AGE_dead[key], bins=20, weights=np.ones(len(tj_AGE_dead[key])) / len(tj_AGE_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n",
"plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n",
"plt.xticks(**csfont)\n",
"plt.yticks(**csfont)\n",
"idx += 1\n",
"\n",
"handles, labels = ax.get_legend_handles_labels()\n",
"print(handles, labels)\n",
"plt.figlegend(handles, labels, loc='upper center', ncol=5, fontsize=18, bbox_to_anchor=(0.5, 1.05), prop=font_manager.FontProperties(family='Times New Roman',\n",
" style='normal', size=18))\n",
"\n",
"fig.tight_layout()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.11 ('python37')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
},
"vscode": {
"interpreter": {
"hash": "a10b846bdc9fc41ee38835cbc29d70b69dd5fd54e1341ea2c410a7804a50447a"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}