Diff of /mlp/predict.py [000000] .. [3f1788]

Switch to side-by-side view

--- a
+++ b/mlp/predict.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python
+
+import datetime
+import logging
+import sys
+sys.path.append('../')
+
+import warnings
+warnings.filterwarnings("ignore")
+
+import pickle
+import lightgbm
+import pandas as pd
+import numpy as np
+from tqdm import tqdm
+from pathlib import Path
+from sklearn.feature_extraction.text import TfidfVectorizer
+
+from torch import optim
+
+from skorch import NeuralNetBinaryClassifier
+from skorch.toy import MLPModule
+from skorch.dataset import CVSplit
+from skorch.callbacks import *
+
+from utils.splits import set_group_splits
+from args import args
+
# Module-level logger: StreamHandler defaults to stderr; messages are
# formatted as "LEVEL:module.name: message". DEBUG level so everything
# emitted by this script is shown.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter('%(levelname)s:%(name)s: %(message)s'))
logger.addHandler(sh)
+
def run_100(task, task_df, args, threshold):
  """Run test-set inference for `task` over 100 random group splits.

  For each seed, the data is re-split by `hadm_id`, a TF-IDF vectorizer is
  re-fit on that seed's train split, and a previously trained MLP is restored
  from the seed's checkpoint directory to predict on the test split. The
  collected targets, hard predictions, and predicted probabilities are written
  to `args.workdir/f'{task}_preds.pkl'` as three back-to-back pickle records
  (load them in the same order: targs, preds, probs).

  Args:
    task: task prefix ('imminent_adm' or 'prolonged_stay'); used for the
      checkpoint directory names, the label column, and the output filename.
    task_df: DataFrame with at least 'hadm_id', 'processed_note', 'split'
      (added by set_group_splits), and f'{task}_label' columns.
    args: config namespace (modeldir, workdir, start_seed, hidden_dim,
      dropout_p, max_epochs, lr, device, wd, batch_size).
    threshold: decision threshold for converting probabilities to labels.
  """
  preds = []
  targs = []
  probs = []

  # Training-time callback; included so the restored net's callback
  # configuration matches what it was trained/checkpointed with.
  reduce_lr = LRScheduler(
    policy='ReduceLROnPlateau',
    mode='min',
    factor=0.5,
    patience=1,
  )

  seeds = list(range(args.start_seed, args.start_seed + 100))
  for seed in tqdm(seeds, desc=f'{task} Runs'):
    checkpoint = Checkpoint(
      dirname=args.modeldir/f'{task}_seed_{seed}',
    )
    df = set_group_splits(task_df.copy(), group_col='hadm_id', seed=seed)
    vectorizer = TfidfVectorizer(sublinear_tf=True, ngram_range=(1,2), binary=True, max_features=60_000)

    # Fit the vocabulary on the train split only. The train matrix itself is
    # never used here (the model is loaded from its checkpoint, not re-fit),
    # so we avoid materializing a dense (n_train x 60k) float array.
    vectorizer.fit(df.loc[(df['split'] == 'train')]['processed_note'])
    x_test = vectorizer.transform(df.loc[(df['split'] == 'test')]['processed_note']).astype(np.float32)
    x_test = np.asarray(x_test.todense())
    vocab_sz = len(vectorizer.vocabulary_)

    y_test = df.loc[(df['split'] == 'test')][f'{task}_label'].to_numpy()
    targs.append(y_test)

    clf = MLPModule(input_units=vocab_sz, output_units=1, hidden_units=args.hidden_dim, num_hidden=1, dropout=args.dropout_p, squeeze_output=True)

    net = NeuralNetBinaryClassifier(
      clf,
      max_epochs=args.max_epochs,
      lr=args.lr,
      device=args.device,
      optimizer=optim.Adam,
      optimizer__weight_decay=args.wd,
      batch_size=args.batch_size,
      verbose=1,
      callbacks=[EarlyStopping, ProgressBar, checkpoint, reduce_lr],
      train_split=CVSplit(cv=0.15, stratified=True),
      iterator_train__shuffle=True,
      threshold=threshold,
    )
    # Drop skorch's default valid_acc scoring callback to match the
    # training-time configuration of the checkpointed net.
    net.set_params(callbacks__valid_acc=None)
    net.initialize()
    net.load_params(checkpoint=checkpoint)

    # NOTE(review): predict_proba on NeuralNetBinaryClassifier returns the
    # full probability output, not only the positive class — downstream
    # consumers of the pickle presumably slice it; confirm before renaming.
    pos_prob = net.predict_proba(x_test)
    probs.append(pos_prob)

    y_pred = net.predict(x_test)
    preds.append(y_pred)

  # Three sequential pickle records in one file; read back in this order.
  with open(args.workdir/f'{task}_preds.pkl', 'wb') as f:
    pickle.dump(targs, f)
    pickle.dump(preds, f)
    pickle.dump(probs, f)
+
if __name__=='__main__':
  # CLI: exactly one positional argument selecting the task.
  if len(sys.argv) != 2:
    logger.error(f"Usage: {sys.argv[0]} task_name (ia|ps)")
    sys.exit(1)

  task = sys.argv[1]
  if task not in ('ia', 'ps'):
    logger.error("Task values are either ia (imminent admission) or ps (prolonged stay)")
    sys.exit(1)

  args.modeldir = args.workdir/'models'
  ori_df = pd.read_csv(args.dataset_csv, usecols=args.cols, parse_dates=args.dates)

  # Select the rows and columns relevant to the chosen task. `task` has been
  # validated above, so a plain if/else covers both cases.
  if task == 'ia':
    # Imminent admission: drop rows whose label is the undefined sentinel -1.
    task_df = ori_df.loc[(ori_df['imminent_adm_label'] != -1)][args.imminent_adm_cols].reset_index(drop=True)
    prefix = 'imminent_adm'
    threshold = args.ia_thresh
  else:  # task == 'ps'
    # Prolonged stay: drop rows with chartinterval == 0.
    task_df = ori_df.loc[(ori_df['chartinterval'] != 0)][args.prolonged_stay_cols].reset_index(drop=True)
    prefix = 'prolonged_stay'
    threshold = args.ps_thresh

  logger.info(f"Running 100 seed test run for task {task}")
  t1 = datetime.datetime.now()
  run_100(prefix, task_df, args, threshold)
  dt = datetime.datetime.now() - t1
  logger.info(f"100 seed test run completed. Took {dt.days} days, {dt.seconds//3600} hours, and {(dt.seconds//60)%60} minutes")