--- a +++ b/datasets/cdsl/preprocess.ipynb @@ -0,0 +1,2221 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# hm dataset pre-processing\n", + "\n", + "import packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pickle as pkl\n", + "import torch\n", + "import math\n", + "import datetime\n", + "from tqdm import tqdm\n", + "import datetime\n", + "import re\n", + "from functools import reduce" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demographic data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "demographic = pd.read_csv('./raw_data/19_04_2021/COVID_DSL_01.CSV', encoding='ISO-8859-1', sep='|')\n", + "print(len(demographic))\n", + "demographic.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "med = pd.read_csv('./raw_data/19_04_2021/COVID_DSL_04.CSV', encoding='ISO-8859-1', sep='|')\n", + "print(len(med))\n", + "med.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(med['ID_ATC7'].unique())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "get rid of patient with missing label" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(len(demographic))\n", + "demographic = demographic.dropna(axis=0, how='any', subset=['IDINGRESO', 'F_INGRESO_ING', 'F_ALTA_ING', 'MOTIVO_ALTA_ING'])\n", + "print(len(demographic))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def outcome2num(x):\n", + " if x == 'Fallecimiento':\n", + " return 1\n", + " else:\n", + " return 0\n", + "\n", + "def to_one_hot(x, feature):\n", + " if x == feature:\n", + " return 1\n", + " else:\n", + " return 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# select necessary columns from demographic\n", + "demographic = demographic[\n", + " [\n", + " 'IDINGRESO', \n", + " 'EDAD',\n", + " 'SEX',\n", + " 'F_INGRESO_ING', \n", + " 'F_ALTA_ING', \n", + " 'MOTIVO_ALTA_ING', \n", + " 'ESPECIALIDAD_URGENCIA', \n", + " 'DIAG_URG'\n", + " ]\n", + " ]\n", + "\n", + "# rename column\n", + "demographic = demographic.rename(columns={\n", + " 'IDINGRESO': 'PATIENT_ID',\n", + " 'EDAD': 'AGE',\n", + " 'SEX': 'SEX',\n", + " 'F_INGRESO_ING': 'ADMISSION_DATE',\n", + " 'F_ALTA_ING': 'DEPARTURE_DATE',\n", + " 'MOTIVO_ALTA_ING': 'OUTCOME',\n", + " 'ESPECIALIDAD_URGENCIA': 'DEPARTMENT_OF_EMERGENCY',\n", + " 'DIAG_URG': 'DIAGNOSIS_AT_EMERGENCY_VISIT'\n", + "})\n", + "\n", + "# SEX: male: 1; female: 0\n", + "demographic['SEX'].replace('MALE', 1, inplace=True)\n", + "demographic['SEX'].replace('FEMALE', 0, inplace=True)\n", + "\n", + "# outcome: Fallecimiento(dead): 1; others: 0\n", + "demographic['OUTCOME'] = demographic['OUTCOME'].map(outcome2num)\n", + "\n", + "# diagnosis at emergency visit (loss rate < 10%)\n", + "# demographic['DIFFICULTY_BREATHING'] = demographic['DIAGNOSIS_AT_EMERGENCY_VISIT'].map(lambda x: to_one_hot(x, 'DIFICULTAD RESPIRATORIA')) # 1674\n", + "# demographic['SUSPECT_COVID'] = demographic['DIAGNOSIS_AT_EMERGENCY_VISIT'].map(lambda x: to_one_hot(x, 'SOSPECHA COVID-19')) # 960\n", + "# demographic['FEVER'] = demographic['DIAGNOSIS_AT_EMERGENCY_VISIT'].map(lambda x: to_one_hot(x, 'FIEBRE')) # 455\n", + "\n", + "# department of emergency (loss rate < 10%)\n", + "# demographic['EMERGENCY'] = demographic['DEPARTMENT_OF_EMERGENCY'].map(lambda x: to_one_hot(x, 'Medicina de Urgencias')) # 3914" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# del useless data\n", + "demographic = demographic[\n", + " [\n", + " 'PATIENT_ID',\n", + " 'AGE',\n", + " 'SEX',\n", + " 'ADMISSION_DATE',\n", + " 'DEPARTURE_DATE',\n", + " 'OUTCOME',\n", + " # 'DIFFICULTY_BREATHING',\n", + " # 'SUSPECT_COVID',\n", + " # 'FEVER',\n", + " # 'EMERGENCY'\n", + " ]\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "demographic.describe().to_csv('demographic_overview.csv', mode='w', index=False)\n", + "demographic.describe()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Analyze data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.scatter(demographic['PATIENT_ID'], demographic['AGE'], s=1)\n", + "plt.xlabel('Patient Id')\n", + "plt.ylabel('Age')\n", + "plt.title('Patient-Age Scatter Plot')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.scatter(demographic['PATIENT_ID'], demographic['AGE'], s=1)\n", + "plt.xlabel('Patient Id')\n", + "plt.ylabel('Age')\n", + "plt.title('Patient-Age Scatter Plot')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "demographic.to_csv('demographic.csv', mode='w', index=False)\n", + "demographic.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Vital Signal" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vital_signs = pd.read_csv('./raw_data/19_04_2021/COVID_DSL_02.CSV', encoding='ISO-8859-1', sep='|')\n", + "print(len(vital_signs))\n", + "vital_signs.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vital_signs = vital_signs.rename(columns={\n", + " 'IDINGRESO': 'PATIENT_ID',\n", + " 'CONSTANTS_ING_DATE': 'RECORD_DATE',\n", + " 'CONSTANTS_ING_TIME': 'RECORD_TIME',\n", + " 'FC_HR_ING': 'HEART_RATE',\n", + " 'GLU_GLY_ING': 'BLOOD_GLUCOSE',\n", + " 'SAT_02_ING': 'OXYGEN_SATURATION',\n", + " 'TA_MAX_ING': 'MAX_BLOOD_PRESSURE',\n", + " 'TA_MIN_ING': 'MIN_BLOOD_PRESSURE',\n", + " 'TEMP_ING': 'TEMPERATURE'\n", + "})\n", + "vital_signs['RECORD_TIME'] = vital_signs['RECORD_DATE'] + ' ' + vital_signs['RECORD_TIME']\n", + "vital_signs['RECORD_TIME'] = vital_signs['RECORD_TIME'].map(lambda x: str(datetime.datetime.strptime(x, '%Y-%m-%d %H:%M')))\n", + "vital_signs = vital_signs.drop(['RECORD_DATE', 'SAT_02_ING_OBS', 'BLOOD_GLUCOSE'], axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vital_signs.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vital_signs.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def format_temperature(x):\n", + " if type(x) == str:\n", + " return float(x.replace(',', '.'))\n", + " else:\n", + " return float(x)\n", + "\n", + "def format_oxygen(x):\n", + " x = float(x)\n", + " if x > 100:\n", + " return np.nan\n", + " else:\n", + " return x\n", + "\n", + "def format_heart_rate(x):\n", + " x = int(x)\n", + " if x > 220:\n", + " return np.nan\n", + " else:\n", + " return x\n", + "\n", + "vital_signs['TEMPERATURE'] = vital_signs['TEMPERATURE'].map(lambda x: format_temperature(x))\n", + "vital_signs['OXYGEN_SATURATION'] = vital_signs['OXYGEN_SATURATION'].map(lambda x: format_oxygen(x))\n", + "vital_signs['HEART_RATE'] = vital_signs['HEART_RATE'].map(lambda x: format_heart_rate(x))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vital_signs = vital_signs.replace(0, np.NAN)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vital_signs = vital_signs.groupby(['PATIENT_ID', 'RECORD_TIME'], dropna=True, as_index = False).mean()\n", + "vital_signs.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vital_signs.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vital_signs.describe().to_csv('vital_signs_overview.csv', index=False, mode='w')\n", + "vital_signs.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#plt.rcParams['figure.figsize'] = [10, 5]\n", + "fig=plt.figure(figsize=(16,10), dpi= 100, facecolor='w', edgecolor='k')\n", + "\n", + "plt.subplot(2, 3, 1)\n", + "plt.scatter(vital_signs.index, vital_signs['MAX_BLOOD_PRESSURE'], s=1)\n", + "plt.xlabel('Index')\n", + "plt.ylabel('Max Blood Pressure')\n", + "plt.title('Visit-Max Blood Pressure Scatter Plot')\n", + "\n", + "plt.subplot(2, 3, 2)\n", + "plt.scatter(vital_signs.index, vital_signs['MIN_BLOOD_PRESSURE'], s=1)\n", + "plt.xlabel('Index')\n", + "plt.ylabel('Min Blood Pressure')\n", + "plt.title('Visit-Min Blood Pressure Scatter Plot')\n", + "\n", + "plt.subplot(2, 3, 3)\n", + "plt.scatter(vital_signs.index, vital_signs['TEMPERATURE'], s=1)\n", + "plt.xlabel('Index')\n", + "plt.ylabel('Temperature')\n", + "plt.title('Visit-Temperature Scatter Plot')\n", + "\n", + "plt.subplot(2, 3, 4)\n", + "plt.scatter(vital_signs.index, vital_signs['HEART_RATE'], s=1)\n", + "plt.xlabel('Index')\n", + "plt.ylabel('Heart Rate')\n", + "plt.title('Visit-Heart Rate Scatter Plot')\n", + "\n", + "plt.subplot(2, 3, 5)\n", + "plt.scatter(vital_signs.index, vital_signs['OXYGEN_SATURATION'], s=1)\n", + "plt.xlabel('Index')\n", + "plt.ylabel('Oxygen Saturation')\n", + "plt.title('Visit-Oxygen Saturation Scatter Plot')\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#plt.rcParams['figure.figsize'] = [10, 5]\n", + "fig=plt.figure(figsize=(16,10), dpi= 100, facecolor='w', edgecolor='k')\n", + "\n", + "plt.subplot(2, 3, 1)\n", + "plt.hist(vital_signs['MAX_BLOOD_PRESSURE'], bins=30)\n", + "plt.xlabel('Index')\n", + "plt.ylabel('Max Blood Pressure')\n", + "plt.title('Visit-Max Blood Pressure Histogram')\n", + "\n", + "plt.subplot(2, 3, 2)\n", + "plt.hist(vital_signs['MIN_BLOOD_PRESSURE'], bins=30)\n", + "plt.xlabel('Index')\n", + "plt.ylabel('Min Blood Pressure')\n", + "plt.title('Visit-Min Blood Pressure Histogram')\n", + "\n", + "plt.subplot(2, 3, 3)\n", + "plt.hist(vital_signs['TEMPERATURE'], bins=30)\n", + "plt.xlabel('Index')\n", + "plt.ylabel('Temperature')\n", + "plt.title('Visit-Temperature Histogram')\n", + "\n", + "plt.subplot(2, 3, 4)\n", + "plt.hist(vital_signs['HEART_RATE'], bins=30)\n", + "plt.xlabel('Index')\n", + "plt.ylabel('Heart Rate')\n", + "plt.title('Visit-Heart Rate Histogram')\n", + "\n", + "plt.subplot(2, 3, 5)\n", + "plt.hist(vital_signs['OXYGEN_SATURATION'], bins=30)\n", + "plt.xlabel('Index')\n", + "plt.ylabel('Oxygen Saturation')\n", + "plt.title('Visit-Oxygen Saturation Histogram')\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Missing rate of each visit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sum(vital_signs.T.isnull().sum()) / ((len(vital_signs.T) - 2) * len(vital_signs))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Normalize data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "for key in vital_signs.keys()[2:]:\n", + " vital_signs[key] = (vital_signs[key] - vital_signs[key].mean()) / (vital_signs[key].std() + 1e-12)\n", + "\n", + "vital_signs.describe()\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vital_signs.to_csv('visual_signs.csv', mode='w', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(vital_signs) / len(vital_signs['PATIENT_ID'].unique())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Lab Tests" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lab_tests = pd.read_csv('./raw_data/19_04_2021/COVID_DSL_06_v2.CSV', encoding='ISO-8859-1', sep=';')\n", + "lab_tests = lab_tests.rename(columns={'IDINGRESO': 'PATIENT_ID'})\n", + "print(len(lab_tests))\n", + "\n", + "# del useless data\n", + "lab_tests = lab_tests[\n", + " [\n", + " 'PATIENT_ID',\n", + " 'LAB_NUMBER',\n", + " 'LAB_DATE',\n", + " 'TIME_LAB',\n", + " 'ITEM_LAB',\n", + " 'VAL_RESULT'\n", + " # UD_RESULT: unit\n", + " # REF_VALUES: reference values\n", + " ]\n", + " ]\n", + "\n", + "lab_tests.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lab_tests = lab_tests.groupby(['PATIENT_ID', 'LAB_NUMBER', 'LAB_DATE', 'TIME_LAB', 'ITEM_LAB'], dropna=True, as_index = False).first()\n", + "lab_tests = lab_tests.set_index(['PATIENT_ID', 'LAB_NUMBER', 'LAB_DATE', 'TIME_LAB', 'ITEM_LAB'], drop = True).unstack('ITEM_LAB')['VAL_RESULT'].reset_index()\n", + "\n", + "lab_tests = lab_tests.drop([\n", + " 'CFLAG -- ALARMA HEMOGRAMA', \n", + " 'CORONA -- PCR CORONAVIRUS 2019nCoV', \n", + " 'CRIOGLO -- CRIOGLOBULINAS',\n", + " 'EGCOVID -- ESTUDIO GENETICO COVID-19',\n", + " 'FRO1 -- ',\n", + " 'FRO1 -- FROTIS EN SANGRE PERIFERICA',\n", + " 'FRO2 -- ',\n", + " 'FRO2 -- FROTIS EN SANGRE PERIFERICA',\n", + " 'FRO3 -- ',\n", + " 'FRO3 -- FROTIS EN SANGRE PERIFERICA',\n", + " 'FRO_COMEN -- ',\n", + " 'FRO_COMEN -- FROTIS EN SANGRE PERIFERICA',\n", + " 'G-CORONAV (RT-PCR) -- Tipo de muestra: ASPIRADO BRONCOALVEOLAR',\n", + " 'G-CORONAV (RT-PCR) -- Tipo de muestra: EXUDADO',\n", + " 'GRRH -- GRUPO SANGUÖNEO Y FACTOR Rh',\n", + " 'HEML -- RECUENTO CELULAR LIQUIDO',\n", + " 'HEML -- Recuento Hemat¡es',\n", + " 'IFSUERO -- INMUNOFIJACION EN SUERO',\n", + " 'OBS_BIOMOL -- OBSERVACIONES GENETICA MOLECULAR',\n", + " 'OBS_BIOO -- Observaciones Bioqu¡mica Orina',\n", + " 'OBS_CB -- Observaciones Coagulaci¢n',\n", + " 'OBS_GASES -- Observaciones Gasometr¡a Arterial',\n", + " 'OBS_GASV -- Observaciones Gasometr¡a Venosa',\n", + " 'OBS_GEN2 -- OBSERVACIONES GENETICA',\n", + " 'OBS_HOR -- Observaciones Hormonas',\n", + " 'OBS_MICRO -- Observaciones Microbiolog¡a',\n", + " 'OBS_NULA2 -- Observaciones Bioqu¡mica',\n", + " 'OBS_NULA3 -- Observaciones Hematolog¡a',\n", + " 'OBS_PESP -- Observaciones Pruebas especiales',\n", + " 'OBS_SERO -- Observaciones Serolog¡a',\n", + " 'OBS_SIS -- Observaciones Orina',\n", + " 'PCR VIRUS RESPIRATORIOS -- Tipo de muestra: ASPIRADO BRONCOALVEOLAR',\n", + " 'PCR VIRUS RESPIRATORIOS -- Tipo de muestra: BAS',\n", + " 'PCR VIRUS RESPIRATORIOS -- Tipo de muestra: ESPUTO',\n", + " 'PCR VIRUS RESPIRATORIOS -- Tipo de muestra: EXUDADO',\n", + " 'PCR VIRUS RESPIRATORIOS -- Tipo de muestra: LAVADO BRONCOALVEOLAR',\n", + " 'PCR VIRUS RESPIRATORIOS -- Tipo de muestra: LAVADO NASOFARÖNGEO',\n", + " 'PTGOR -- PROTEINOGRAMA ORINA',\n", + " 'RESUL_IFT -- ESTUDIO DE INMUNOFENOTIPO',\n", + " 'RESUL_IFT -- Resultado',\n", + " 'Resultado -- Resultado',\n", + " 'SED1 -- ',\n", + " 'SED1 -- SEDIMENTO',\n", + " 'SED2 -- ',\n", + " 'SED2 -- SEDIMENTO',\n", + " 'SED3 -- ',\n", + " 'SED3 -- SEDIMENTO',\n", + " 'TIPOL -- TIPO DE LIQUIDO',\n", + " 'Tecnica -- T\\x82cnica',\n", + " 'TpMues -- Tipo de muestra',\n", + " 'VHCBLOT -- INMUNOBLOT VIRUS HEPATITIS C',\n", + " 'VIR_TM -- VIRUS TIPO DE MUESTRA',\n", + " 'LEGIORI -- AG. LEGIONELA PNEUMOPHILA EN ORINA',\n", + " 'NEUMOORI -- AG NEUMOCOCO EN ORINA',\n", + " 'VIHAC -- VIH AC'\n", + " ], axis=1)\n", + "\n", + " \n", + "lab_tests.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lab_tests = lab_tests.replace('Sin resultado.', np.nan)\n", + "lab_tests = lab_tests.replace('Sin resultado', np.nan)\n", + "lab_tests = lab_tests.replace('----', np.nan).replace('---', np.nan)\n", + "lab_tests = lab_tests.replace('> ', '').replace('< ', '')\n", + "\n", + "def change_format(x):\n", + " if x is None:\n", + " return np.nan\n", + " elif type(x) == str:\n", + " if x.startswith('Negativo ('):\n", + " return x.replace('Negativo (', '-')[:-1]\n", + " elif x.startswith('Positivo ('):\n", + " return x.replace('Positivo (', '')[:-1]\n", + " elif x.startswith('Zona limite ('):\n", + " return x.replace('Zona limite (', '')[:-1]\n", + " elif x.startswith('>'):\n", + " return x.replace('> ', '').replace('>', '')\n", + " elif x.startswith('<'):\n", + " return x.replace('< ', '').replace('<', '')\n", + " elif x.endswith(' mg/dl'):\n", + " return x.replace(' mg/dl', '')\n", + " elif x.endswith('/æl'):\n", + " return x.replace('/æl', '')\n", + " elif x.endswith(' copias/mL'):\n", + " return x.replace(' copias/mL', '')\n", + " elif x == 'Numerosos':\n", + " return 1.5\n", + " elif x == 'Aislados':\n", + " return 0.5\n", + " elif x == 'Se detecta' or x == 'Se observan' or x == 'Normal' or x == 'Positivo':\n", + " return 1\n", + " elif x == 'No se detecta' or x == 'No se observan' or x == 'Negativo':\n", + " return 0\n", + " elif x == 'Indeterminado':\n", + " return np.nan\n", + " else:\n", + " num = re.findall(\"[-+]?\\d+\\.\\d+\", x)\n", + " if len(num) == 0:\n", + " return np.nan\n", + " else:\n", + " return num[0]\n", + " else:\n", + " return x\n", + "\n", + "feature_value_dict = dict()\n", + "\n", + "for k in tqdm(lab_tests.keys()[4:]):\n", + " lab_tests[k] = lab_tests[k].map(lambda x: change_format(change_format(x)))\n", + " feature_value_dict[k] = lab_tests[k].unique()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def nan_and_not_nan(x):\n", + " if x == x:\n", + " return 1\n", + " else: # nan\n", + " return 0\n", + "\n", + "def is_float(num):\n", + " try:\n", + " float(num)\n", + " return True\n", + " except ValueError:\n", + " return False\n", + "\n", + "def is_all_float(x):\n", + " for i in x:\n", + " if i == i and (i != None):\n", + " if not is_float(i):\n", + " return False\n", + " return True\n", + "\n", + "def to_float(x):\n", + " if x != None:\n", + " return float(x)\n", + " else:\n", + " return np.nan\n", + "\n", + "other_feature_dict = dict()\n", + "\n", + "for feature in tqdm(feature_value_dict.keys()):\n", + " values = feature_value_dict[feature]\n", + " if is_all_float(values):\n", + " lab_tests[feature] = lab_tests[feature].map(lambda x: to_float(x))\n", + " elif len(values) == 2:\n", + " lab_tests[feature] = lab_tests[feature].map(lambda x: nan_and_not_nan(x))\n", + " else:\n", + " other_feature_dict[feature] = values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "other_feature_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def format_time(t):\n", + " if '/' in t:\n", + " return str(datetime.datetime.strptime(t, '%d/%m/%Y %H:%M'))\n", + " else:\n", + " return str(datetime.datetime.strptime(t, '%d-%m-%Y %H:%M'))\n", + "\n", + "lab_tests['RECORD_TIME'] = lab_tests['LAB_DATE'] + ' ' + lab_tests['TIME_LAB']\n", + "lab_tests['RECORD_TIME'] = lab_tests['RECORD_TIME'].map(lambda x: format_time(x))\n", + "lab_tests = lab_tests.drop(['LAB_NUMBER', 'LAB_DATE', 'TIME_LAB'], axis=1)\n", + "# lab_tests = lab_tests.drop(['LAB_NUMBER', 'TIME_LAB'], axis=1)\n", + "lab_tests.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lab_tests_patient = lab_tests.groupby(['PATIENT_ID'], dropna=True, as_index = False).mean()\n", + "print(len(lab_tests_patient))\n", + "count = [i for i in lab_tests_patient.count()[1:]]\n", + "plt.hist(count)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "patient_total = len(lab_tests_patient)\n", + "threshold = patient_total * 0.1\n", + "reserved_keys = []\n", + "\n", + "for key in lab_tests_patient.keys():\n", + " if lab_tests_patient[key].count() > threshold:\n", + " reserved_keys.append(key)\n", + "\n", + "print(len(reserved_keys))\n", + "reserved_keys" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reserved_keys.insert(1, 'RECORD_TIME')\n", + "\n", + "lab_tests = lab_tests.groupby(['PATIENT_ID', 'RECORD_TIME'], dropna=True, as_index = False).mean()\n", + "\n", + "lab_tests = lab_tests[reserved_keys]\n", + "lab_tests.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Missing rate of each visit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sum(lab_tests.T.isnull().sum()) / ((len(lab_tests.T) - 2) * len(lab_tests))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Scatter Plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig=plt.figure(figsize=(16,200), dpi= 100, facecolor='w', edgecolor='k')\n", + "\n", + "i = 1\n", + "for key in lab_tests.keys()[2:]:\n", + " plt.subplot(33, 3, i)\n", + " plt.scatter(lab_tests.index, lab_tests[key], s=1)\n", + " plt.ylabel(key)\n", + " i += 1\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig=plt.figure(figsize=(20,120), dpi= 100, facecolor='w', edgecolor='k')\n", + "\n", + "i = 1\n", + "for key in lab_tests.keys()[2:]:\n", + " plt.subplot(23, 4, i)\n", + " plt.hist(lab_tests[key], bins=30)\n", + " q3 = lab_tests[key].quantile(0.75)\n", + " q1 = lab_tests[key].quantile(0.25)\n", + " qh = q3 + 3 * (q3 - q1)\n", + " ql = q1 - 3 * (q3 - q1)\n", + " sigma = 5\n", + " plt.axline([sigma*lab_tests[key].std() + lab_tests[key].mean(), 0], [sigma*lab_tests[key].std() + lab_tests[key].mean(), 1], color = \"r\", linestyle=(0, (5, 5)))\n", + " plt.axline([-sigma*lab_tests[key].std() + lab_tests[key].mean(), 0], [-sigma*lab_tests[key].std() + lab_tests[key].mean(), 1], color = \"r\", linestyle=(0, (5, 5)))\n", + " #plt.axline([lab_tests[key].quantile(0.25), 0], [lab_tests[key].quantile(0.25), 1], color = \"k\", linestyle=(0, (5, 5)))\n", + " #plt.axline([lab_tests[key].quantile(0.75), 0], [lab_tests[key].quantile(0.75), 1], color = \"k\", linestyle=(0, (5, 5)))\n", + " plt.axline([qh, 0], [qh, 1], color='k', linestyle=(0, (5, 5)))\n", + " plt.axline([ql, 0], [ql, 1], color='k', linestyle=(0, (5, 5)))\n", + " plt.ylabel(key)\n", + " i += 1\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Normalize data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "for key in lab_tests.keys()[2:]:\n", + " lab_tests[key] = (lab_tests[key] - lab_tests[key].mean()) / (lab_tests[key].std() + 1e-12)\n", + "\n", + "lab_tests.describe()\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 【del normalization】\n", + "# for key in lab_tests.keys()[2:]:\n", + "# r = lab_tests[lab_tests[key].between(lab_tests[key].quantile(0.05), lab_tests[key].quantile(0.95))]\n", + "# lab_tests[key] = (lab_tests[key] - r[key].mean()) / (r[key].std() + 1e-12)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lab_tests.to_csv('lab_test.csv', mode='w', index=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Concat data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "demographic['PATIENT_ID'] = demographic['PATIENT_ID'].map(lambda x: str(int(x)))\n", + "vital_signs['PATIENT_ID'] = vital_signs['PATIENT_ID'].map(lambda x: str(int(x)))\n", + "lab_tests['PATIENT_ID'] = lab_tests['PATIENT_ID'].map(lambda x: str(int(x)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(demographic['PATIENT_ID'].unique()), len(vital_signs['PATIENT_ID'].unique()), len(lab_tests['PATIENT_ID'].unique())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df = pd.merge(vital_signs, lab_tests, on=['PATIENT_ID', 'RECORD_TIME'], how='outer')\n", + "\n", + "train_df = train_df.groupby(['PATIENT_ID', 'RECORD_TIME'], dropna=True, as_index = False).mean()\n", + "\n", + "train_df = pd.merge(demographic, train_df, on=['PATIENT_ID'], how='left')\n", + "\n", + "train_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# del rows without patient_id, admission_date, record_time, or outcome\n", + "train_df = train_df.dropna(axis=0, how='any', subset=['PATIENT_ID', 'ADMISSION_DATE', 'RECORD_TIME', 'OUTCOME'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df.to_csv('train.csv', mode='w', index=False)\n", + "train_df.describe()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Missing rate of each visit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sum(train_df.T.isnull().sum()) / ((len(train_df.T) - 2) * len(train_df))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Split and save data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* demo: demographic data\n", + "* x: lab test & vital signs\n", + "* y: outcome & length of stay" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "patient_ids = train_df['PATIENT_ID'].unique()\n", + "\n", + "demo_cols = ['AGE', 'SEX'] # , 'DIFFICULTY_BREATHING', 'FEVER', 'SUSPECT_COVID', 'EMERGENCY'\n", + "test_cols = []\n", + "\n", + "# get column names\n", + "for k in train_df.keys():\n", + " if not k in demographic.keys():\n", + " if not k == 'RECORD_TIME':\n", + " test_cols.append(k)\n", + "\n", + "test_median = train_df[test_cols].median()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_cols" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df['RECORD_TIME_DAY'] = train_df['RECORD_TIME'].map(lambda x: datetime.datetime.strptime(x, '%Y-%m-%d %H:%M:%S').strftime('%Y-%m-%d'))\n", + "train_df['RECORD_TIME_HOUR'] = train_df['RECORD_TIME'].map(lambda x: datetime.datetime.strptime(x, '%Y-%m-%d %H:%M:%S').strftime('%Y-%m-%d %H'))\n", + "train_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df_day = train_df.groupby(['PATIENT_ID', 'ADMISSION_DATE', 'DEPARTURE_DATE', 'RECORD_TIME_DAY'], dropna=True, as_index = False).mean()\n", + "train_df_hour = train_df.groupby(['PATIENT_ID', 'ADMISSION_DATE', 'DEPARTURE_DATE', 'RECORD_TIME_HOUR'], dropna=True, as_index = False).mean()\n", + "\n", + "len(train_df), len(train_df_day), len(train_df_hour)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "```\n", + "number of visits (total)\n", + "- Original data: 168777\n", + "- Merge by hour: 130141\n", + "- Merge by day: 42204\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(train_df['PATIENT_ID'].unique())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_visit_intervals(df):\n", + " ls = []\n", + " for pat in df['PATIENT_ID'].unique():\n", + " ls.append(len(df[df['PATIENT_ID'] == pat]))\n", + " return ls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ls_org = get_visit_intervals(train_df)\n", + "ls_hour = get_visit_intervals(train_df_hour)\n", + "ls_day = get_visit_intervals(train_df_day)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "from matplotlib.ticker import PercentFormatter\n", + "import matplotlib.font_manager as font_manager\n", + "import pandas as pd\n", + "import numpy as np\n", + "csfont = {'fontname':'Times New Roman', 'fontsize': 18}\n", + "font = 'Times New Roman'\n", + "fig=plt.figure(figsize=(18,4), dpi= 100, facecolor='w', edgecolor='k')\n", + "plt.style.use('seaborn-whitegrid')\n", + "color = 'cornflowerblue'\n", + "ec = 'None'\n", + "alpha=0.5\n", + "\n", + "ax = plt.subplot(1, 3, 1)\n", + "ax.hist(ls_org, bins=20, weights=np.ones(len(ls_org)) / len(ls_org), color=color, ec=ec, alpha=alpha, label='overall')\n", + "plt.xlabel('Num of visits (org)',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "\n", + "ax = plt.subplot(1, 3, 2)\n", + "ax.hist(ls_hour, bins=20, weights=np.ones(len(ls_hour)) / len(ls_hour), color=color, ec=ec, alpha=alpha, label='overall')\n", + "plt.xlabel('Num of visits (hour)',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "\n", + "ax = plt.subplot(1, 3, 3)\n", + "ax.hist(ls_day, bins=20, weights=np.ones(len(ls_day)) / len(ls_day), color=color, ec=ec, alpha=alpha, label='overall')\n", + "plt.xlabel('Num of visits (day)',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_statistic(lst, name):\n", + " print(f'[{name}]\\tMax:\\t{max(lst)}, Min:\\t{min(lst)}, Median:\\t{np.median(lst)}, Mean:\\t{np.mean(lst)}, 80%:\\t{np.quantile(lst, 0.8)}, 90%:\\t{np.quantile(lst, 0.9)}, 95%:\\t{np.quantile(lst, 0.95)}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_statistic(ls_org, 'ls_org')\n", + "get_statistic(ls_hour, 'ls_hour')\n", + "get_statistic(ls_day, 'ls_day')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df_hour['LOS'] = train_df_hour['ADMISSION_DATE']\n", + "train_df_hour['LOS_HOUR'] = train_df_hour['ADMISSION_DATE']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df_hour = train_df_hour.reset_index()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for idx in tqdm(range(len(train_df_hour))):\n", + " info = train_df_hour.loc[idx]\n", + " admission = datetime.datetime.strptime(info['ADMISSION_DATE'], '%Y-%m-%d %H:%M:%S')\n", + " departure = datetime.datetime.strptime(info['DEPARTURE_DATE'], '%Y-%m-%d %H:%M:%S')\n", + " visit_hour = datetime.datetime.strptime(info['RECORD_TIME_HOUR'], '%Y-%m-%d %H')\n", + " hour = (departure - visit_hour).seconds / (24 * 60 * 60) + (departure - visit_hour).days\n", + " los = (departure - admission).seconds / (24 * 60 * 60) + (departure - admission).days\n", + " train_df_hour.at[idx, 'LOS'] = float(los)\n", + " train_df_hour.at[idx, 'LOS_HOUR'] = float(hour)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df_hour['LOS']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "los = []\n", + "for pat in tqdm(train_df_hour['PATIENT_ID'].unique()):\n", + " los.append(float(train_df_hour[train_df_hour['PATIENT_ID'] == pat]['LOS'].head(1)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_statistic(los, 'los')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "from matplotlib.ticker import PercentFormatter\n", + "import matplotlib.font_manager as font_manager\n", + "import pandas as pd\n", + "import numpy as np\n", + "csfont = {'fontname':'Times New Roman', 'fontsize': 18}\n", + "font = 'Times New Roman'\n", + "fig=plt.figure(figsize=(6,6), dpi= 100, facecolor='w', edgecolor='k')\n", + "plt.style.use('seaborn-whitegrid')\n", + "color = 'cornflowerblue'\n", + "ec = 'None'\n", + "alpha=0.5\n", + "\n", + "ax = plt.subplot(1, 1, 1)\n", + "ax.hist(los, bins=20, weights=np.ones(len(los)) / len(los), color=color, ec=ec, alpha=alpha, label='overall')\n", + "plt.xlabel('Length of stay',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df_hour_idx = train_df_hour.reset_index()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df_hour_idx['LOS'] = train_df_hour_idx['ADMISSION_DATE']\n", + "\n", + "for idx in tqdm(range(len(train_df_hour_idx))):\n", + " info = train_df_hour_idx.loc[idx]\n", + " # admission = datetime.datetime.strptime(info['ADMISSION_DATE'], '%Y-%m-%d %H:%M:%S')\n", + " departure = datetime.datetime.strptime(info['DEPARTURE_DATE'], '%Y-%m-%d %H:%M:%S')\n", + " visit_hour = datetime.datetime.strptime(info['RECORD_TIME_HOUR'], '%Y-%m-%d %H')\n", + " hour = (departure - visit_hour).seconds / (24 * 60 * 60) + (departure - visit_hour).days\n", + " train_df_hour_idx.at[idx, 'LOS'] = float(hour)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df_hour['LOS'] = train_df_hour['LOS_HOUR']\n", + "train_df_hour.drop(columns=['LOS_HOUR'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# los_threshold = 13.0\n", + "\n", + "# visit_num_hour = []\n", + "\n", + "# for pat in tqdm(train_df_hour_idx['PATIENT_ID'].unique()):\n", + "# pat_records = train_df_hour_idx[train_df_hour_idx['PATIENT_ID'] == pat]\n", + "# hour = 0\n", + "# for vis in pat_records.index:\n", + "# pat_visit = pat_records.loc[vis]\n", + "# if pat_visit['LOS_HOUR'] <= los_threshold:\n", + "# hour += 1\n", + "# visit_num_hour.append(hour)\n", + "# if hour == 0:\n", + "# print(pat)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# import matplotlib.pyplot as plt\n", + "# from matplotlib.ticker import PercentFormatter\n", + "# import matplotlib.font_manager as font_manager\n", + "# import pandas as pd\n", + "# import numpy as np\n", + "# csfont = {'fontname':'Times New Roman', 'fontsize': 18}\n", + "# font = 'Times New Roman'\n", + "# fig=plt.figure(figsize=(6,6), dpi= 100, facecolor='w', edgecolor='k')\n", + "# plt.style.use('seaborn-whitegrid')\n", + "# color = 'cornflowerblue'\n", + "# ec = 'None'\n", + "# alpha=0.5\n", + "\n", + "# ax = plt.subplot(1, 1, 1)\n", + "# ax.hist(visit_num_hour, bins=20, weights=np.ones(len(visit_num_hour)) / len(visit_num_hour), color=color, ec=ec, alpha=alpha, label='overall')\n", + "# plt.xlabel('Visit num (80% los)',**csfont)\n", + "# plt.ylabel('Percentage',**csfont)\n", + "# plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.xticks(**csfont)\n", + "# plt.yticks(**csfont)\n", + "\n", + "# plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df = train_df_hour\n", + "train_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_statistic(train_df['LOS'], 'los')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_df['LOS'] = train_df['LOS'].clip(lower=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_statistic(train_df['LOS'], 'los')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# the first visit of each person\n", + "def init_prev(prev):\n", + " miss = []\n", + " l = len(prev)\n", + " for idx in range(l):\n", + " #print(prev[idx])\n", + " #print(type(prev[idx]))\n", + " if np.isnan(prev[idx]): # there is no previous record\n", + " prev[idx] = test_median[idx] # replace nan to median\n", + " miss.append(1) # mark miss as 1\n", + " else: # there is a previous record\n", + " miss.append(0)\n", + " return miss\n", + "\n", + "# the rest of the visits\n", + "def fill_nan(cur, prev):\n", + " l = len(prev)\n", + " miss = []\n", + " for idx in range(l):\n", + " #print(cur[idx])\n", + " if np.isnan(cur[idx]): # there is no record in current timestep\n", + " cur[idx] = prev[idx] # cur <- prev\n", + " miss.append(1)\n", + " else: # there is a record in current timestep\n", + " miss.append(0)\n", + " return miss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "index = train_df.loc[0].index\n", + "\n", + "csv = dict()\n", + "for key in ['PatientID', 'RecordTime', 'AdmissionTime', 'DischargeTime', 'Outcome', 'LOS', 'Sex', 'Age']:\n", + " csv[key] = []\n", + "for key in index[8:-2]:\n", + " csv[key] = []\n", + " \n", + "for pat in tqdm(patient_ids): # for all patients\n", + " # get visits for pat.id == PATIENT_ID\n", + " info = train_df[train_df['PATIENT_ID'] == pat]\n", + " info = info[max(0, len(info) - 76):]\n", + " idxs = info.index\n", + " for i in idxs:\n", + " visit = info.loc[i]\n", + " for key in index[8:-2]:\n", + " csv[key].append(visit[key])\n", + " # ['PatientID', 'RecordTime', 'AdmissionTime', 'DischargeTime', 'Outcome', 'LOS', 'Sex', 'Age']\n", + " csv['PatientID'].append(visit['PATIENT_ID'])\n", + " t, h = visit['RECORD_TIME_HOUR'].split()\n", + " t = t.split('-')\n", + " csv['RecordTime'].append(t[1]+'/'+t[2]+'/'+t[0]+' '+h) # 2020-04-06 10 -> 04/06/2020 10\n", + " t = visit['ADMISSION_DATE'][:10].split('-')\n", + " csv['AdmissionTime'].append(t[1]+'/'+t[2]+'/'+t[0])\n", + " t = visit['DEPARTURE_DATE'][:10].split('-')\n", + " csv['DischargeTime'].append(t[1]+'/'+t[2]+'/'+t[0])\n", + " csv['Outcome'].append(visit['OUTCOME'])\n", + " csv['LOS'].append(visit['LOS_HOUR'])\n", + " csv['Sex'].append(visit['SEX'])\n", + " csv['Age'].append(visit['AGE'])\n", + " \n", + "pd.DataFrame(csv).to_csv('processed_data/CDSL.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x, y, demo, x_lab_len, missing_mask = [], [], [], [], []\n", + "\n", + "for pat in tqdm(patient_ids): # for all patients\n", + " # get visits for pat.id == PATIENT_ID\n", + " info = train_df[train_df['PATIENT_ID'] == pat]\n", + " info = info[max(0, len(info) - 76):]\n", + " indexes = info.index\n", + " visit = info.loc[indexes[0]] # get the first visit\n", + "\n", + " # demographic data\n", + " demo.append([visit[k] for k in demo_cols])\n", + " \n", + " # label\n", + " outcome = visit['OUTCOME']\n", + " los = []\n", + "\n", + " # lab test & vital signs\n", + " tests = []\n", + " prev = visit[test_cols]\n", + " miss = [] # missing matrix\n", + " miss.append(init_prev(prev)) # fill nan for the first visit for every patient and add missing status to missing matrix\n", + " # leave = datetime.datetime.strptime(visit['DEPARTURE_DATE'], '%Y-%m-%d %H:%M:%S')\n", + " \n", + " first = True\n", + " for i in indexes:\n", + " visit = info.loc[i]\n", + " # now = datetime.datetime.strptime(visit['RECORD_TIME'], '%Y-%m-%d %H')\n", + " cur = visit[test_cols]\n", + " tmp = fill_nan(cur, prev) # fill nan for the rest of the visits\n", + " if not first:\n", + " miss.append(tmp) # add missing status to missing matrix\n", + " tests.append(cur)\n", + " # los_visit = (leave - now).days\n", + " # if los_visit < 0:\n", + " # los_visit = 0\n", + " los.append(visit['LOS'])\n", + " prev = cur\n", + " first = False\n", + "\n", + " valid_visit = len(los)\n", + " # outcome = [outcome] * valid_visit\n", + " x_lab_len.append(valid_visit)\n", + " missing_mask.append(miss) # append the patient's missing matrix to the total missing matrix\n", + "\n", + " # tests = np.pad(tests, ((0, max_visit - valid_visit), (0, 0)))\n", + " # outcome = np.pad(outcome, (0, max_visit - valid_visit))\n", + " # los = np.pad(los, (0, max_visit - valid_visit))\n", + " \n", + " y.append([outcome, los])\n", + " x.append(tests)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "all_x = x\n", + "all_x_demo = demo\n", + "all_y = y\n", + "all_missing_mask = missing_mask" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "all_x_labtest = np.array(all_x, dtype=object)\n", + "x_lab_length = [len(_) for _ in all_x_labtest]\n", + "x_lab_length = torch.tensor(x_lab_length, dtype=torch.int)\n", + "max_length = int(x_lab_length.max())\n", + "all_x_labtest = [torch.tensor(_) for _ in all_x_labtest]\n", + "all_x_labtest = torch.nn.utils.rnn.pad_sequence((all_x_labtest), batch_first=True)\n", + "all_x_demographic = torch.tensor(all_x_demo)\n", + "batch_size, demo_dim = all_x_demographic.shape\n", + "all_x_demographic = torch.reshape(all_x_demographic.repeat(1, max_length), (batch_size, max_length, demo_dim))\n", + "all_x = torch.cat((all_x_demographic, all_x_labtest), 2)\n", + "\n", + "all_y = np.array(all_y, dtype=object)\n", + "patient_list = []\n", + "for pat in all_y:\n", + " visits = []\n", + " for i in pat[1]:\n", + " visits.append([pat[0], i])\n", + " patient_list.append(visits)\n", + "new_all_y = np.array(patient_list, dtype=object)\n", + "output_all_y = [torch.Tensor(_) for _ in new_all_y]\n", + "output_all_y = torch.nn.utils.rnn.pad_sequence((output_all_y), batch_first=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "all_missing_mask = np.array(all_missing_mask, dtype=object)\n", + "all_missing_mask = [torch.tensor(_) for _ in all_missing_mask]\n", + "all_missing_mask = torch.nn.utils.rnn.pad_sequence((all_missing_mask), batch_first=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "all_x.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "all_missing_mask.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# save pickle format dataset (torch)\n", + "pd.to_pickle(all_x,f'./processed_data/x.pkl' )\n", + "pd.to_pickle(all_missing_mask,f'./processed_data/missing_mask.pkl' )\n", + "pd.to_pickle(output_all_y,f'./processed_data/y.pkl' )\n", + "pd.to_pickle(x_lab_length,f'./processed_data/visits_length.pkl' )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Calculate patients' outcome statistics (patients-wise)\n", + "outcome_list = []\n", + "y_outcome = output_all_y[:, :, 0]\n", + "indices = torch.arange(len(x_lab_length), dtype=torch.int64)\n", + "for i in indices:\n", + " outcome_list.append(y_outcome[i][0].item())\n", + "outcome_list = np.array(outcome_list)\n", + "print(len(outcome_list))\n", + "unique, count=np.unique(outcome_list,return_counts=True)\n", + "data_count=dict(zip(unique,count))\n", + "print(data_count)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Calculate patients' outcome statistics (records-wise)\n", + "outcome_records_list = []\n", + "y_outcome = output_all_y[:, :, 0]\n", + "indices = torch.arange(len(x_lab_length), dtype=torch.int64)\n", + "for i in indices:\n", + " outcome_records_list.extend(y_outcome[i][0:x_lab_length[i]].tolist())\n", + "outcome_records_list = np.array(outcome_records_list)\n", + "print(len(outcome_records_list))\n", + "unique, count=np.unique(outcome_records_list,return_counts=True)\n", + "data_count=dict(zip(unique,count))\n", + "print(data_count)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Calculate patients' mean los and 95% percentile los\n", + "los_list = []\n", + "y_los = output_all_y[:, :, 1]\n", + "indices = torch.arange(len(x_lab_length), dtype=torch.int64)\n", + "for i in indices:\n", + " # los_list.extend(y_los[i][: x_lab_length[i].long()].tolist())\n", + " los_list.append(y_los[i][0].item())\n", + "los_list = np.array(los_list)\n", + "print(los_list.mean() * 0.5)\n", + "print(np.median(los_list) * 0.5)\n", + "print(np.percentile(los_list, 95))\n", + "\n", + "print('median:', np.median(los_list))\n", + "print('Q1:', np.percentile(los_list, 25))\n", + "print('Q3:', np.percentile(los_list, 75))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "los_alive_list = np.array([los_list[i] for i in range(len(los_list)) if outcome_list[i] == 0])\n", + "los_dead_list = np.array([los_list[i] for i in range(len(los_list)) if outcome_list[i] == 1])\n", + "print(len(los_alive_list))\n", + "print(len(los_dead_list))\n", + "\n", + "print('[Alive]')\n", + "print('median:', np.median(los_alive_list))\n", + "print('Q1:', np.percentile(los_alive_list, 25))\n", + "print('Q3:', np.percentile(los_alive_list, 75))\n", + "\n", + "print('[Dead]')\n", + "print('median:', np.median(los_dead_list))\n", + "print('Q1:', np.percentile(los_dead_list, 25))\n", + "print('Q3:', np.percentile(los_dead_list, 75))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cdsl_los_statistics = {\n", + " 'overall': los_list,\n", + " 'alive': los_alive_list,\n", + " 'dead': los_dead_list\n", + "}\n", + "pd.to_pickle(cdsl_los_statistics, 'cdsl_los_statistics.pkl')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# calculate visits length Median [Q1, Q3]\n", + "visits_list = np.array(x_lab_length)\n", + "visits_alive_list = np.array([x_lab_length[i] for i in range(len(x_lab_length)) if outcome_list[i] == 0])\n", + "visits_dead_list = np.array([x_lab_length[i] for i in range(len(x_lab_length)) if outcome_list[i] == 1])\n", + "print(len(visits_alive_list))\n", + "print(len(visits_dead_list))\n", + "\n", + "print('[Total]')\n", + "print('median:', np.median(visits_list))\n", + "print('Q1:', np.percentile(visits_list, 25))\n", + "print('Q3:', np.percentile(visits_list, 75))\n", + "\n", + "print('[Alive]')\n", + "print('median:', np.median(visits_alive_list))\n", + "print('Q1:', np.percentile(visits_alive_list, 25))\n", + "print('Q3:', np.percentile(visits_alive_list, 75))\n", + "\n", + "print('[Dead]')\n", + "print('median:', np.median(visits_dead_list))\n", + "print('Q1:', np.percentile(visits_dead_list, 25))\n", + "print('Q3:', np.percentile(visits_dead_list, 75))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def check_nan(x):\n", + " if np.isnan(np.sum(x.cpu().numpy())):\n", + " print(\"some values from input are nan\")\n", + " else:\n", + " print(\"no nan\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "check_nan(all_x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Draw Charts" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "from matplotlib.ticker import PercentFormatter\n", + "import matplotlib.font_manager as font_manager\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "plt.style.use('seaborn-whitegrid')\n", + "color = 'cornflowerblue'\n", + "ec = 'None'\n", + "alpha=0.5\n", + "alive_color = 'olivedrab'\n", + "dead_color = 'orchid'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Read data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "demographic.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train = pd.read_csv('./train.csv')\n", + "train['PATIENT_ID']=train['PATIENT_ID'].astype(str)\n", + "demographic['PATIENT_ID']=demographic['PATIENT_ID'].astype(str)\n", + "pat = {\n", + " 'PATIENT_ID': train['PATIENT_ID'].unique()\n", + "}\n", + "pat = pd.DataFrame(pat)\n", + "demo = pd.merge(demographic, pat, on='PATIENT_ID', how='inner')\n", + "\n", + "demo_alive = demo.loc[demo['OUTCOME'] == 0]\n", + "demo_dead = demo.loc[demo['OUTCOME'] == 1]\n", + "demo_overall = demo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "demo.to_csv('demo_overall.csv', index=False)\n", + "demo_alive.to_csv('demo_alive.csv', index=False)\n", + "demo_dead.to_csv('demo_dead.csv', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "patient = pd.DataFrame({\"PATIENT_ID\": (demo_alive['PATIENT_ID'].unique())})\n", + "lab_tests_alive = pd.merge(lab_tests, patient, how='inner', on='PATIENT_ID')\n", + "print(len(lab_tests_alive['PATIENT_ID'].unique()))\n", + "\n", + "patient = pd.DataFrame({\"PATIENT_ID\": (demo_dead['PATIENT_ID'].unique())})\n", + "lab_tests_dead = pd.merge(lab_tests, patient, how='inner', on='PATIENT_ID')\n", + "print(len(lab_tests_dead['PATIENT_ID'].unique()))\n", + "\n", + "patient = pd.DataFrame({\"PATIENT_ID\": (demo_overall['PATIENT_ID'].unique())})\n", + "lab_tests_overall = pd.merge(lab_tests, patient, how='inner', on='PATIENT_ID')\n", + "print(len(lab_tests_overall['PATIENT_ID'].unique()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "patient = pd.DataFrame({\"PATIENT_ID\": (demo_alive['PATIENT_ID'].unique())})\n", + "vital_signs_alive = pd.merge(vital_signs, patient, how='inner', on='PATIENT_ID')\n", + "len(vital_signs_alive['PATIENT_ID'].unique())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "patient = pd.DataFrame({\"PATIENT_ID\": (demo_dead['PATIENT_ID'].unique())})\n", + "vital_signs_dead = pd.merge(vital_signs, patient, how='inner', on='PATIENT_ID')\n", + "len(vital_signs_dead['PATIENT_ID'].unique())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "patient = pd.DataFrame({\"PATIENT_ID\": (demo_overall['PATIENT_ID'].unique())})\n", + "vital_signs_overall = pd.merge(vital_signs, patient, how='inner', on='PATIENT_ID')\n", + "len(vital_signs_overall['PATIENT_ID'].unique())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "limit = 0.05\n", + "\n", + "csfont = {'fontname':'Times New Roman', 'fontsize': 18}\n", + "font = 'Times New Roman'\n", + "fig=plt.figure(figsize=(16,12), dpi= 100, facecolor='w', edgecolor='k')\n", + "\n", + "idx = 1\n", + "\n", + "key = 'AGE'\n", + "low = demo_overall[key].quantile(limit)\n", + "high = demo_overall[key].quantile(1 - limit)\n", + "demo_AGE_overall = demo_overall[demo_overall[key].between(low, high)]\n", + "demo_AGE_dead = demo_dead[demo_dead[key].between(low, high)]\n", + "demo_AGE_alive = demo_alive[demo_alive[key].between(low, high)]\n", + "ax = plt.subplot(4, 4, idx)\n", + "ax.hist(demo_AGE_overall[key], bins=20, weights=np.ones(len(demo_AGE_overall[key])) / len(demo_AGE_overall), color=color, ec=ec, alpha=alpha, label='overall')\n", + "plt.xlabel('Age',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# ax.title('Age Histogram', **csfont)\n", + "ax.hist(demo_AGE_alive[key], bins=20, weights=np.ones(len(demo_AGE_alive[key])) / len(demo_AGE_alive), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2, label='alive')\n", + "ax.hist(demo_AGE_dead[key], bins=20, weights=np.ones(len(demo_AGE_dead[key])) / len(demo_AGE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2, label='dead')\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'TEMPERATURE'\n", + "low = vital_signs_overall[key].quantile(limit)\n", + "high = vital_signs_overall[key].quantile(1 - limit)\n", + "vs_TEMPERATURE_overall = vital_signs_overall[vital_signs_overall[key].between(low, high)]\n", + "vs_TEMPERATURE_dead = vital_signs_dead[vital_signs_dead[key].between(low, high)]\n", + "vs_TEMPERATURE_alive = vital_signs_alive[vital_signs_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(vs_TEMPERATURE_overall['TEMPERATURE'], bins=20, weights=np.ones(len(vs_TEMPERATURE_overall)) / len(vs_TEMPERATURE_overall), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('Temperature',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(vs_TEMPERATURE_alive['TEMPERATURE'], bins=20, weights=np.ones(len(vs_TEMPERATURE_alive)) / len(vs_TEMPERATURE_alive), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(vs_TEMPERATURE_dead['TEMPERATURE'], bins=20, weights=np.ones(len(vs_TEMPERATURE_dead)) / len(vs_TEMPERATURE_dead), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "# plt.subplot(4, 4, 3)\n", + "# plt.hist(lab_tests_overall['CREA -- CREATININA'], bins=20, density=True, color=color, ec=ec, alpha=alpha)\n", + "# plt.xlabel('CREA -- CREATININA',**csfont)\n", + "# plt.ylabel('Percentage',**csfont)\n", + "# plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# # plt.title('Temperature Histogram', **csfont)\n", + "# plt.hist(lab_tests_alive['CREA -- CREATININA'], bins=20, density=True, color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "# plt.hist(lab_tests_dead['CREA -- CREATININA'], bins=20, density=True, color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "# plt.xticks(**csfont)\n", + "# plt.yticks(**csfont)\n", + "\n", + "key = 'CREA -- CREATININA'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('CREA -- CREATININA',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'HEM -- Hemat¡es'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('HEM -- Hemat¡es',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'LEUC -- Leucocitos'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('LEUC -- Leucocitos',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'PLAQ -- Recuento de plaquetas'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('PLAQ -- Recuento de plaquetas',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'CHCM -- Conc. Hemoglobina Corpuscular Media'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('CHCM',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'HCTO -- Hematocrito'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('HCTO -- Hematocrito',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'VCM -- Volumen Corpuscular Medio'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('VCM -- Volumen Corpuscular Medio',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'HGB -- Hemoglobina'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('HGB -- Hemoglobina',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'HCM -- Hemoglobina Corpuscular Media'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('HCM -- Hemoglobina Corpuscular Media',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'NEU -- Neutr¢filos'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('NEU -- Neutr¢filos',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'NEU% -- Neutr¢filos %'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('NEU% -- Neutr¢filos%',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'LIN -- Linfocitos'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('LIN -- Linfocitos',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'LIN% -- Linfocitos %'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('LIN% -- Linfocitos%',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "key = 'ADW -- Coeficiente de anisocitosis'\n", + "low = lab_tests_overall[key].quantile(limit)\n", + "high = lab_tests_overall[key].quantile(1 - limit)\n", + "lt_key_overall = lab_tests_overall[lab_tests_overall[key].between(low, high)]\n", + "lt_key_dead = lab_tests_dead[lab_tests_dead[key].between(low, high)]\n", + "lt_key_alive = lab_tests_alive[lab_tests_alive[key].between(low, high)]\n", + "plt.subplot(4, 4, idx)\n", + "plt.hist(lt_key_overall[key], bins=20, weights=np.ones(len(lt_key_overall[key])) / len(lt_key_overall[key]), color=color, ec=ec, alpha=alpha)\n", + "plt.xlabel('ADW -- Coeficiente de anisocitosis',**csfont)\n", + "plt.ylabel('Percentage',**csfont)\n", + "plt.gca().yaxis.set_major_formatter(PercentFormatter(1))\n", + "# plt.title('Temperature Histogram', **csfont)\n", + "plt.hist(lt_key_alive[key], bins=20, weights=np.ones(len(lt_key_alive[key])) / len(lt_key_alive[key]), color='green', ec=alive_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.hist(lt_key_dead[key], bins=20, weights=np.ones(len(lt_key_dead[key])) / len(lt_key_dead[key]), color='green', ec=dead_color, alpha=1, histtype=\"step\", linewidth=2)\n", + "plt.xticks(**csfont)\n", + "plt.yticks(**csfont)\n", + "idx += 1\n", + "\n", + "handles, labels = ax.get_legend_handles_labels()\n", + "print(handles, labels)\n", + "# fig.legend(handles, labels, loc='upper center')\n", + "plt.figlegend(handles, labels, loc='upper center', ncol=5, fontsize=18, bbox_to_anchor=(0.5, 1.05), prop=font_manager.FontProperties(family='Times New Roman',\n", + " style='normal', size=18))\n", + "# fig.legend(, [], loc='upper center')\n", + "\n", + "fig.tight_layout()\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.11 ('python37')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.11" + }, + "vscode": { + "interpreter": { + "hash": "a10b846bdc9fc41ee38835cbc29d70b69dd5fd54e1341ea2c410a7804a50447a" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}