--- a +++ b/mlp/mlp_classifier.ipynb @@ -0,0 +1,1263 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Imminent ICU Admission and Prolonged Stay Prediction using Neural Networks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports & Inits" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:23:54.544191Z", + "start_time": "2019-08-10T14:23:54.531466Z" + } + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:23:56.643514Z", + "start_time": "2019-08-10T14:23:54.545371Z" + }, + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'workdir': PosixPath('../data/workdir/mlp'),\n", + " 'dataset_csv': PosixPath('../data/proc_dataset.csv'),\n", + " 'cols': ['hadm_id',\n", + " 'imminent_adm_label',\n", + " 'prolonged_stay_label',\n", + " 'processed_note',\n", + " 'charttime',\n", + " 'intime',\n", + " 'chartinterval'],\n", + " 'imminent_adm_cols': ['hadm_id', 'processed_note', 'imminent_adm_label'],\n", + " 'prolonged_stay_cols': ['hadm_id', 'processed_note', 'prolonged_stay_label'],\n", + " 'dates': ['charttime', 'intime'],\n", + " 'device': 'cuda:2',\n", + " 'start_seed': 127,\n", + " 'min_freq': 3,\n", + " 'batch_size': 128,\n", + " 'hidden_dim': 100,\n", + " 'dropout_p': 0.1,\n", + " 'lr': 0.001,\n", + " 'wd': 0.001,\n", + " 'max_lr': 0.1,\n", + " 'max_epochs': 100,\n", + " 'ia_thresh': 0.2,\n", + " 'ps_thresh': 0.25}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import sys\n", + "sys.path.append('../')\n", + "\n", + "import pdb\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "sns.set_style(\"darkgrid\")\n", + "%matplotlib inline\n", + "\n", + "import numpy as np\n", 
+ "np.set_printoptions(precision=5)\n", + "\n", + "import pandas as pd\n", + "import pickle\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from torch import optim\n", + "from torch.nn import functional as F\n", + "from torch.utils.data import DataLoader, TensorDataset\n", + "\n", + "from sklearn.feature_extraction.text import TfidfVectorizer\n", + "# Metric helpers called by the evaluation cells; imported by name instead of relying on a wildcard re-export.\n", + "# Kept above the wildcard import so a same-named re-export from utils.plots still takes precedence.\n", + "from sklearn.metrics import confusion_matrix, roc_auc_score\n", + "\n", + "from skorch import NeuralNetBinaryClassifier\n", + "from skorch.toy import MLPModule\n", + "from skorch.dataset import CVSplit\n", + "from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint\n", + "\n", + "from mlp_model import MLPModel\n", + "from utils.splits import set_group_splits\n", + "from utils.metrics import BinaryAvgMetrics, get_best_model\n", + "from utils.plots import *\n", + "\n", + "from args import args\n", + "vars(args)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## NN Dev" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:23:58.527239Z", + "start_time": "2019-08-10T14:23:56.645116Z" + } + }, + "outputs": [], + "source": [ + "# Seed handed to set_group_splits below so the train/test partition is repeatable.\n", + "seed = 643\n", + "full_df = pd.read_csv(args.dataset_csv, usecols=args.cols, parse_dates=args.dates)\n", + "# One frame per task: rows are filtered on that task's label/interval column and trimmed to the task's columns.\n", + "ia_df = full_df.loc[(full_df['imminent_adm_label'] != -1)][args.imminent_adm_cols].reset_index(drop=True)\n", + "ps_df = full_df.loc[(full_df['chartinterval'] != 0)][args.prolonged_stay_cols].reset_index(drop=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imminent ICU Admission" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:23:58.587967Z", + "start_time": "2019-08-10T14:23:58.528632Z" + } + }, + "outputs": [], + "source": [ + "# Split is assigned per hadm_id group via the project helper set_group_splits.\n", + "ori_df = set_group_splits(ia_df.copy(), group_col='hadm_id', seed=seed)\n", + "df = ori_df\n", + "# df = ori_df.sample(1000).reset_index(drop=True)" + ] + }, + { + "cell_type": "code", +
"execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:24:24.952218Z", + "start_time": "2019-08-10T14:23:58.589218Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "((50809, 4), (42683, 60000), (8126, 60000), (42683,), (8126,))" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vectorizer = TfidfVectorizer(sublinear_tf=True, ngram_range=(1,2), binary=True, max_features=60_000)\n", + "\n", + "x_train = vectorizer.fit_transform(df.loc[(df['split'] == 'train')]['processed_note']).astype(np.float32)\n", + "x_test = vectorizer.transform(df.loc[(df['split'] == 'test')]['processed_note']).astype(np.float32)\n", + "\n", + "x_train = np.asarray(x_train.todense())\n", + "x_test = np.asarray(x_test.todense())\n", + "\n", + "vocab_sz = len(vectorizer.vocabulary_)\n", + "\n", + "y_train = df.loc[(df['split'] == 'train')]['imminent_adm_label'].to_numpy()\n", + "y_test = df.loc[(df['split'] == 'test')]['imminent_adm_label'].to_numpy()\n", + "df.shape, x_train.shape, x_test.shape, y_train.shape, y_test.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:25:10.219798Z", + "start_time": "2019-08-10T14:25:09.574881Z" + } + }, + "outputs": [], + "source": [ + "train_ds = TensorDataset(torch.tensor(x_train), torch.tensor(y_train.astype(np.float32)))\n", + "train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True)\n", + "itr = iter(train_dl)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:25:10.376434Z", + "start_time": "2019-08-10T14:25:10.222543Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.6790, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "clf = 
MLPModule(input_units=vocab_sz, output_units=1, hidden_units=args.hidden_dim, num_hidden=1, dropout=args.dropout_p, squeeze_output=True)\n", + "# Hidden width is read from args.hidden_dim, like the other MLPModule cells in this notebook.\n", + "\n", + "# Smoke test: run one batch through the untrained module and display its BCE-with-logits loss.\n", + "loss_fn = nn.BCEWithLogitsLoss()\n", + "x, y = next(itr)\n", + "y_pred = clf(x)\n", + "\n", + "loss_fn(y_pred, y)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:25:11.452893Z", + "start_time": "2019-08-10T14:25:11.426771Z" + } + }, + "outputs": [], + "source": [ + "# Halve the learning rate when the monitored loss stops improving for more than one epoch.\n", + "reduce_lr = LRScheduler(\n", + " policy='ReduceLROnPlateau',\n", + " mode='min',\n", + " factor=0.5,\n", + " patience=1,\n", + ")\n", + "\n", + "# Checkpoints for this dev run are written under the work directory.\n", + "checkpoint = Checkpoint(\n", + " dirname=args.workdir/'models/dev3',\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:25:11.480936Z", + "start_time": "2019-08-10T14:25:11.454266Z" + } + }, + "outputs": [], + "source": [ + "# skorch wrapper: Adam with weight decay, 15% stratified validation split, task-specific decision threshold.\n", + "net = NeuralNetBinaryClassifier(\n", + " clf,\n", + " max_epochs=args.max_epochs,\n", + " lr=args.lr,\n", + " device=args.device,\n", + " optimizer=optim.Adam,\n", + " optimizer__weight_decay=args.wd,\n", + " batch_size=args.batch_size,\n", + " verbose=1,\n", + " callbacks=[EarlyStopping, checkpoint, reduce_lr],\n", + " train_split=CVSplit(cv=0.15, stratified=True),\n", + " iterator_train__shuffle=True, \n", + " threshold=args.ia_thresh,\n", + ")\n", + "\n", + "# Setting a callback parameter to None removes that callback (here the valid_acc scorer).\n", + "net.set_params(callbacks__valid_acc=None);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:23:12.121340Z", + "start_time": "2019-08-10T14:18:03.575272Z" + } + }, + "outputs": [], + "source": [ + "# net.fit(x_train, y_train.astype(np.float32))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:25:28.804887Z", + "start_time": "2019-08-10T14:25:22.921517Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "<class 
'skorch.classifier.NeuralNetBinaryClassifier'>[initialized](\n", + " module_=MLPModule(\n", + " (nonlin): ReLU()\n", + " (sequential): Sequential(\n", + " (0): Linear(in_features=60000, out_features=100, bias=True)\n", + " (1): ReLU()\n", + " (2): Dropout(p=0.1)\n", + " (3): Linear(in_features=100, out_features=1, bias=True)\n", + " )\n", + " ),\n", + ")" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.initialize()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:25:32.359374Z", + "start_time": "2019-08-10T14:25:32.263137Z" + } + }, + "outputs": [], + "source": [ + "# Restore the state saved by the Checkpoint callback instead of retraining.\n", + "net.load_params(checkpoint=checkpoint)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T14:25:34.483957Z", + "start_time": "2019-08-10T14:25:34.160864Z" + }, + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "<matplotlib.legend.Legend at 0x7f388e84a4e0>" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "<Figure size 720x576 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Per-epoch train/validation loss from the skorch history.\n", + "# x/y are passed by keyword: newer seaborn releases no longer accept them positionally.\n", + "fig, ax = plt.subplots(figsize=(10, 8))\n", + "n_epochs = len(net.history)\n", + "sns.lineplot(x=range(n_epochs), y=net.history[:, 'train_loss'], ax=ax)\n", + "sns.lineplot(x=range(n_epochs), y=net.history[:, 'valid_loss'], ax=ax)\n", + "ax.legend(['train_loss', 'valid_loss'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:07:54.326248Z", + "start_time": "2019-08-10T02:07:51.779447Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "prob = net.predict_proba(x_test)\n", + "y_pred = net.predict(x_test)" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:07:55.088517Z", + "start_time": "2019-08-10T02:07:54.327862Z" + } + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(20, 10))\n", + "plot_youden(ax, y_test, prob, 0.1, 0.9, 40)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:09:14.799918Z", + "start_time": "2019-08-10T02:09:13.571047Z" + } + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(20, 10))\n", + "plot_thresh_range(ax, y_test, prob, 0.1, 0.9, 40)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:09:22.235184Z", + "start_time": "2019-08-10T02:09:21.973392Z" + }, + "scrolled": false + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(10,8))\n", + "plot_roc(ax, y_test, prob)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:09:25.384728Z", + "start_time": "2019-08-10T02:09:25.127586Z" + } + }, + "outputs": [], + "source": [ + "# Threshold is read from args.ia_thresh, the same value given to the classifier, so the two cannot drift apart.\n", + "threshold = args.ia_thresh\n", + "y_pred = (prob > threshold).astype(np.int64)\n", + "cm = confusion_matrix(y_test, y_pred)\n", + "# Rows of cm are true classes, columns are predicted classes.\n", + "tn,fp,fn,tp = cm[0][0],cm[0][1],cm[1][0],cm[1][1]\n", + "sensitivity = tp/(tp+fn)\n", + "specificity = tn/(tn+fp)\n", + "ppv = tp/(tp+fp)\n", + "npv = tn/(tn+fn)\n", + "f1 = (2*ppv*sensitivity)/(ppv+sensitivity)\n", + "auroc = roc_auc_score(y_test, prob)\n", + "\n", + "d = {\n", + " 'sensitivity': np.round(sensitivity, 3),\n", + " 'specificity': np.round(specificity, 3),\n", + " 'ppv': np.round(ppv, 3),\n", + " 'npv': np.round(npv, 3),\n", + " 'f1': np.round(f1, 3),\n", + " 'auroc': np.round(auroc, 3),\n", + "}\n", + "\n", + "metrics = pd.DataFrame(d.values(), index=d.keys(), columns=['Value'])\n", + "metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + 
"ExecuteTime": { + "end_time": "2019-08-10T02:10:59.034655Z", + "start_time": "2019-08-10T02:10:58.573846Z" + } + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(10, 8))\n", + "plot_confusion_matrix(ax, cm, classes=['Delayed', 'Imminent'], normalize=False, title='Confusion matrix')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prolonged ICU Stay" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:11:48.636198Z", + "start_time": "2019-08-10T02:11:48.514559Z" + } + }, + "outputs": [], + "source": [ + "ori_df = set_group_splits(ps_df.copy(), group_col='hadm_id', seed=seed)\n", + "df = ori_df\n", + "# df = ori_df.sample(1000).reset_index(drop=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:12:51.529544Z", + "start_time": "2019-08-10T02:12:18.693490Z" + } + }, + "outputs": [], + "source": [ + "vectorizer = TfidfVectorizer(sublinear_tf=True, ngram_range=(1,2), binary=True, max_features=60_000)\n", + "\n", + "x_train = vectorizer.fit_transform(df.loc[(df['split'] == 'train')]['processed_note']).astype(np.float32)\n", + "x_test = vectorizer.transform(df.loc[(df['split'] == 'test')]['processed_note']).astype(np.float32)\n", + "\n", + "x_train = np.asarray(x_train.todense())\n", + "x_test = np.asarray(x_test.todense())\n", + "\n", + "vocab_sz = len(vectorizer.vocabulary_)\n", + "\n", + "y_train = df.loc[(df['split'] == 'train')]['prolonged_stay_label'].to_numpy()\n", + "y_test = df.loc[(df['split'] == 'test')]['prolonged_stay_label'].to_numpy()\n", + "df.shape, x_train.shape, x_test.shape, y_train.shape, y_test.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:13:13.127114Z", + "start_time": "2019-08-10T02:13:12.459632Z" + } + }, + "outputs": [], + "source": [ + "train_ds = 
TensorDataset(torch.tensor(x_train), torch.tensor(y_train.astype(np.float32)))\n", + "train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True)\n", + "itr = iter(train_dl)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:13:13.727112Z", + "start_time": "2019-08-10T02:13:13.663448Z" + } + }, + "outputs": [], + "source": [ + "clf = MLPModule(input_units=vocab_sz, output_units=1, hidden_units=args.hidden_dim, num_hidden=1, dropout=args.dropout_p, squeeze_output=True)\n", + "loss_fn = nn.BCEWithLogitsLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:13:14.414141Z", + "start_time": "2019-08-10T02:13:14.370681Z" + } + }, + "outputs": [], + "source": [ + "x, y = next(itr)\n", + "y_pred = clf(x)\n", + "loss_fn(y_pred, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:13:15.663594Z", + "start_time": "2019-08-10T02:13:15.631754Z" + } + }, + "outputs": [], + "source": [ + "reduce_lr = LRScheduler(\n", + " policy='ReduceLROnPlateau',\n", + " mode='min',\n", + " factor=0.5,\n", + " patience=1,\n", + ")\n", + "\n", + "checkpoint = Checkpoint(\n", + " dirname=args.workdir/'models/ps_dev_run1',\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:13:16.561491Z", + "start_time": "2019-08-10T02:13:16.528937Z" + } + }, + "outputs": [], + "source": [ + "net = NeuralNetBinaryClassifier(\n", + " clf,\n", + " max_epochs=args.max_epochs,\n", + " lr=args.lr,\n", + " device=args.device,\n", + " optimizer=optim.Adam,\n", + " optimizer__weight_decay=args.wd,\n", + " batch_size=args.batch_size,\n", + " verbose=1,\n", + " callbacks=[EarlyStopping, checkpoint, reduce_lr],\n", + " train_split=CVSplit(cv=0.15, stratified=True),\n", + " 
iterator_train__shuffle=True, \n", + " # Prolonged-stay threshold from args, so net.predict does not fall back to skorch's default cut-off.\n", + " threshold=args.ps_thresh,\n", + ")\n", + "\n", + "# Setting a callback parameter to None removes that callback (here the valid_acc scorer).\n", + "net.set_params(callbacks__valid_acc=None);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:19:50.994267Z", + "start_time": "2019-08-10T02:13:18.983380Z" + }, + "scrolled": false + }, + "outputs": [], + "source": [ + "net.fit(x_train, y_train.astype(np.float32))\n", + "# net.initialize()\n", + "# net.load_params(checkpoint=checkpoint)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:19:51.519879Z", + "start_time": "2019-08-10T02:19:50.996110Z" + } + }, + "outputs": [], + "source": [ + "# Per-epoch train/validation loss from the skorch history.\n", + "# x/y are passed by keyword: newer seaborn releases no longer accept them positionally.\n", + "fig, ax = plt.subplots(figsize=(10, 8))\n", + "n_epochs = len(net.history)\n", + "sns.lineplot(x=range(n_epochs), y=net.history[:, 'train_loss'], ax=ax)\n", + "sns.lineplot(x=range(n_epochs), y=net.history[:, 'valid_loss'], ax=ax)\n", + "ax.legend(['train_loss', 'valid_loss'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:19:54.687943Z", + "start_time": "2019-08-10T02:19:51.521325Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "prob = net.predict_proba(x_test)\n", + "y_pred = net.predict(x_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:19:55.501191Z", + "start_time": "2019-08-10T02:19:54.689503Z" + } + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(20, 10))\n", + "plot_youden(ax, y_test, prob, 0.1, 0.9, 40)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:29:33.089367Z", + "start_time": "2019-08-10T02:29:31.801492Z" + } + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(20, 10))\n", + "plot_thresh_range(ax, y_test, prob, 0.1, 0.9, 40)" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:29:40.317871Z", + "start_time": "2019-08-10T02:29:40.049324Z" + }, + "scrolled": false + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(10,8))\n", + "plot_roc(ax, y_test, prob)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-08-10T02:30:01.696081Z", + "start_time": "2019-08-10T02:30:01.657269Z" + } + }, + "outputs": [], + "source": [ + "# Prolonged-stay decision threshold is read from the shared config rather than repeated as a literal.\n", + "threshold = args.ps_thresh\n", + "y_pred = (prob > threshold).astype(np.int64)\n", + "cm = confusion_matrix(y_test, y_pred)\n", + "# Rows of cm are true classes, columns are predicted classes.\n", + "tn,fp,fn,tp = cm[0][0],cm[0][1],cm[1][0],cm[1][1]\n", + "sensitivity = tp/(tp+fn)\n", + "specificity = tn/(tn+fp)\n", + "ppv = tp/(tp+fp)\n", + "npv = tn/(tn+fn)\n", + "f1 = (2*ppv*sensitivity)/(ppv+sensitivity)\n", + "auroc = roc_auc_score(y_test, prob)\n", + "\n", + "d = {\n", + " 'sensitivity': np.round(sensitivity, 3),\n", + " 'specificity': np.round(specificity, 3),\n", + " 'ppv': np.round(ppv, 3),\n", + " 'npv': np.round(npv, 3),\n", + " 'f1': np.round(f1, 3),\n", + " 'auroc': np.round(auroc, 3),\n", + "}\n", + "\n", + "metrics = pd.DataFrame(d.values(), index=d.keys(), columns=['Value'])\n", + "metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-16T17:27:21.639571Z", + "start_time": "2019-07-16T17:27:21.397075Z" + } + }, + "outputs": [], + "source": [ + "# Class names for the prolonged-stay task, matching the averaged-metrics plot for this task.\n", + "fig, ax = plt.subplots(figsize=(10, 8))\n", + "plot_confusion_matrix(ax, cm, classes=['Discharge within 5 days', 'Discharge after 5 days'], normalize=False, title='Confusion matrix')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "heading_collapsed": true + }, + "source": [ + "## Metrics" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "hidden": true + }, + "source": [ + "### Imminent ICU Admission" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": 
"2019-07-22T11:55:29.364845Z", + "start_time": "2019-07-22T11:55:28.722095Z" + }, + "hidden": true + }, + "outputs": [], + "source": [ + "with open(args.workdir/f'imminent_adm_preds.pkl', 'rb') as f:\n", + " targs = pickle.load(f)\n", + " preds = pickle.load(f)\n", + " probs = pickle.load(f)\n", + "\n", + "fnames = [f'imminent_adm_seed_{seed}.pkl' for seed in range(args.start_seed, args.start_seed + 100)]\n", + "bam = BinaryAvgMetrics(targs, preds, probs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T11:55:29.586114Z", + "start_time": "2019-07-22T11:55:29.366363Z" + }, + "hidden": true + }, + "outputs": [], + "source": [ + "bam.get_avg_metrics(conf=0.95)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T11:55:29.818141Z", + "start_time": "2019-07-22T11:55:29.587526Z" + }, + "hidden": true + }, + "outputs": [], + "source": [ + "get_best_model(bam, fnames)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T11:55:30.206128Z", + "start_time": "2019-07-22T11:55:29.819466Z" + }, + "hidden": true + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(10, 8))\n", + "plot_mean_roc(ax, bam.targs, bam.probs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T11:55:30.515657Z", + "start_time": "2019-07-22T11:55:30.207420Z" + }, + "hidden": true + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 2, figsize=(15, 6))\n", + "\n", + "plot_confusion_matrix(ax[0], bam.cm_avg, classes=['not imminent', 'imminent'], normalize=False,\\\n", + " title='Confusion Matrix Over Runs')\n", + "plot_confusion_matrix(ax[1], bam.cm_avg, classes=['not imminent', 'imminent'], normalize=True,\\\n", + " title='Normalized Confusion Matrix Over Runs')\n", + "plt.show()" + ] + }, 
+ { + "cell_type": "markdown", + "metadata": { + "hidden": true + }, + "source": [ + "### Prolonged ICU Stay" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T11:55:55.670109Z", + "start_time": "2019-07-22T11:55:54.957696Z" + }, + "hidden": true + }, + "outputs": [], + "source": [ + "with open(args.workdir/f'prolonged_stay_preds.pkl', 'rb') as f:\n", + " targs = pickle.load(f)\n", + " preds = pickle.load(f)\n", + " probs = pickle.load(f)\n", + "\n", + "fnames = [f'prolonged_stay_seed_{seed}.pkl' for seed in range(args.start_seed, args.start_seed + 100)]\n", + "bam = BinaryAvgMetrics(targs, preds, probs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T11:55:59.488758Z", + "start_time": "2019-07-22T11:55:59.229856Z" + }, + "hidden": true + }, + "outputs": [], + "source": [ + "bam.get_avg_metrics(conf=0.95)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T11:56:04.618880Z", + "start_time": "2019-07-22T11:56:04.366173Z" + }, + "hidden": true + }, + "outputs": [], + "source": [ + "get_best_model(bam, fnames)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T11:56:08.900406Z", + "start_time": "2019-07-22T11:56:08.487768Z" + }, + "hidden": true + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(10, 8))\n", + "plot_mean_roc(ax, bam.targs, bam.probs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T11:56:16.786590Z", + "start_time": "2019-07-22T11:56:16.425934Z" + }, + "hidden": true + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 2, figsize=(15, 6))\n", + "\n", + "plot_confusion_matrix(ax[0], bam.cm_avg, classes=['Discharge within 5 days', 'Discharge after 5 
days'], normalize=False, title='Confusion Matrix Over Runs')\n", + "plot_confusion_matrix(ax[1], bam.cm_avg, classes=['Discharge within 5 days', 'Discharge after 5 days'], normalize=True, title='Normalized Confusion Matrix Over Runs')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Full Run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T12:06:04.466706Z", + "start_time": "2019-07-22T12:06:02.766145Z" + } + }, + "outputs": [], + "source": [ + "seed = 643\n", + "ori_df = pd.read_csv(args.dataset_csv, usecols=args.cols, parse_dates=args.dates)\n", + "ori_df['relative_charttime'] = (ori_df['charttime'] - ori_df['intime'])\n", + "\n", + "ia_df = ori_df.loc[(ori_df['imminent_adm_label'] != -1)][args.imminent_adm_cols + ['relative_charttime']].reset_index(drop=True)\n", + "\n", + "ps_df = ori_df.loc[(ori_df['chartinterval'] != 0)][args.prolonged_stay_cols + ['relative_charttime']].reset_index(drop=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T12:06:04.491622Z", + "start_time": "2019-07-22T12:06:04.468225Z" + } + }, + "outputs": [], + "source": [ + "interval_hours = 12\n", + "starting_day = -20\n", + "ending_day = -1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imminent ICU Admission" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T12:06:04.550287Z", + "start_time": "2019-07-22T12:06:04.492938Z" + } + }, + "outputs": [], + "source": [ + "df = set_group_splits(ia_df.copy(), pct=0.25, group_col='hadm_id', seed=seed)\n", + "df['prob'] = -1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T12:06:13.284651Z", + "start_time": "2019-07-22T12:06:04.551535Z" + } + }, + "outputs": [], + 
"source": [ + "vectorizer = TfidfVectorizer(min_df=args.min_freq, binary=True, analyzer=str.split, sublinear_tf=True)\n", + "\n", + "x_train = vectorizer.fit_transform(df.loc[(df['split'] == 'train')]['processed_note']).astype(np.float32)\n", + "x_test = vectorizer.transform(df.loc[(df['split'] == 'test')]['processed_note']).astype(np.float32)\n", + "\n", + "x_train = np.asarray(x_train.todense())\n", + "x_test = np.asarray(x_test.todense())\n", + "\n", + "vocab_sz = len(vectorizer.vocabulary_)\n", + "\n", + "y_train = df.loc[(df['split'] == 'train')]['imminent_adm_label'].to_numpy()\n", + "y_test = df.loc[(df['split'] == 'test')]['imminent_adm_label'].to_numpy()\n", + "\n", + "df.shape, x_train.shape, x_test.shape, y_train.shape, y_test.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T12:04:38.168048Z", + "start_time": "2019-07-22T12:04:37.823108Z" + } + }, + "outputs": [], + "source": [ + "train_ds = TensorDataset(torch.tensor(x_train), torch.tensor(y_train.astype(np.float32)))\n", + "train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True)\n", + "itr = iter(train_dl)\n", + "\n", + "clf = MLPModule(input_units=vocab_sz, output_units=1, hidden_units=args.hidden_dim, num_hidden=1, dropout=args.dropout_p, squeeze_output=True)\n", + "loss_fn = nn.BCEWithLogitsLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T12:04:38.239707Z", + "start_time": "2019-07-22T12:04:38.170461Z" + } + }, + "outputs": [], + "source": [ + "x, y = next(itr)\n", + "y_pred = clf(x)\n", + "\n", + "loss_fn(y_pred, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T12:06:26.700111Z", + "start_time": "2019-07-22T12:06:26.457513Z" + } + }, + "outputs": [], + "source": [ + "reduce_lr = LRScheduler(\n", + " 
policy='ReduceLROnPlateau',\n", + " mode='min',\n", + " factor=0.5,\n", + " patience=1,\n", + ")\n", + "\n", + "checkpoint = Checkpoint(\n", + " dirname=args.workdir/'models/ia_full_run_01',\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T12:06:51.913709Z", + "start_time": "2019-07-22T12:06:51.852150Z" + } + }, + "outputs": [], + "source": [ + "clf = MLPModule(input_units=vocab_sz, output_units=1, hidden_units=args.hidden_dim, num_hidden=1, dropout=args.dropout_p, squeeze_output=True)\n", + "\n", + "args.batch_size=64\n", + "args.device='cuda:2'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T12:06:51.940021Z", + "start_time": "2019-07-22T12:06:51.915151Z" + } + }, + "outputs": [], + "source": [ + "net = NeuralNetBinaryClassifier(\n", + " clf,\n", + " max_epochs=args.max_epochs,\n", + " lr=args.lr,\n", + " device=args.device,\n", + " optimizer=optim.Adam,\n", + " optimizer__weight_decay=args.wd,\n", + " batch_size=args.batch_size,\n", + " verbose=1,\n", + " callbacks=[EarlyStopping, checkpoint, reduce_lr],\n", + " train_split=CVSplit(cv=0.15, stratified=True),\n", + " iterator_train__shuffle=True, \n", + "# iterator_train__num_workers=4,\n", + "# iterator_train__pin_memory=True,\n", + "# iterator_train__drop_last=True,\n", + "# iterator_valid__num_workers=4,\n", + "# iterator_valid__pin_memory=True,\n", + " threshold=args.ia_thresh,\n", + ")\n", + "\n", + "net.set_params(callbacks__valid_acc=None);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-22T12:07:07.392335Z", + "start_time": "2019-07-22T12:06:53.182073Z" + } + }, + "outputs": [], + "source": [ + "net.fit(x_train, y_train.astype(np.float32))\n", + "# net.initialize()\n", + "# net.load_params(checkpoint=checkpoint)" + ] + } + ], + "metadata": { + "kernelspec": { + 
"display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": true, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}