--- /dev/null
+++ b/ndl_train_100.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python
+
+import datetime
+import logging
+import sys
+
+import warnings
+warnings.filterwarnings("ignore")
+
+import pickle
+import lightgbm
+import pandas as pd
+import numpy as np
+from tqdm import tqdm
+from pathlib import Path
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.linear_model import LogisticRegression
+from sklearn.ensemble import RandomForestClassifier
+
+from utils.splits import set_group_splits
+import lr.args
+import rf.args
+import gbm.args
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+sh = logging.StreamHandler()
+sh.setFormatter(logging.Formatter('%(levelname)s:%(name)s: %(message)s'))
+logger.addHandler(sh)
+
+def run_100(task, task_df, clf_model, params, args, threshold):
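+  """Train and evaluate clf_model over 100 seeded train/test splits.
+
+  For each seed the data is re-split by hadm_id, a TF-IDF vectorizer is
+  refit on the training notes, and the fitted classifier is pickled.
+  Targets, thresholded predictions, and positive-class probabilities
+  from all runs are saved together at the end.
+  """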
+  preds = []
+  targs = []
+  probs = []
+
+  seeds = list(range(args.start_seed, args.start_seed + 100))
+  for seed in tqdm(seeds, desc=f'{task} Runs'):
+    df = set_group_splits(task_df.copy(), group_col='hadm_id', seed=seed)
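+    # Refit TF-IDF on the training split only so no test vocabulary leaks in.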
+    vectorizer = TfidfVectorizer(sublinear_tf=True, ngram_range=(1,2), binary=True, max_features=60_000)
+
+    train_df = df.loc[df['split'] == 'train']
+    test_df = df.loc[df['split'] == 'test']
+
+    x_train = vectorizer.fit_transform(train_df['processed_note'])
+    x_test = vectorizer.transform(test_df['processed_note'])
+
+    y_train = train_df[f'{task}_label'].to_numpy()
+    y_test = test_df[f'{task}_label'].to_numpy()
+    targs.append(y_test)
+
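+    # Train a fresh classifier for this split and checkpoint it to disk.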
+    clf = clf_model(**params)
+    clf.fit(x_train, y_train)
+    with open(args.modeldir/f'{task}_seed_{seed}.pkl', 'wb') as f:
+      pickle.dump(clf, f)
+
+    pos_prob = clf.predict_proba(x_test)[:, 1]
+    probs.append(pos_prob)
+
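+    # Binarize with the task-specific operating threshold rather than a fixed 0.5.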
+    y_pred = (pos_prob > threshold).astype(np.int64)
+    preds.append(y_pred)
+
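+  # The three lists are pickled in sequence; read them back with three
+  # successive pickle.load() calls on the same file handle.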
+  with open(args.workdir/f'{task}_preds.pkl', 'wb') as f:
+    pickle.dump(targs, f)
+    pickle.dump(preds, f)
+    pickle.dump(probs, f)
+
+if __name__ == '__main__':
+  if len(sys.argv) != 3:
+    logger.error(f"Usage: {sys.argv[0]} task_name (ia|ps) model_name (lr|rf|gbm)")
+    sys.exit(1)
+
+  task = sys.argv[1]
+  if task not in ['ia', 'ps']:
+    logger.error("Task values are either ia (imminent admission) or ps (prolonged stay)")
+    sys.exit(1)
+
+  clf_name = sys.argv[2]
+  if clf_name not in ['lr', 'rf', 'gbm']:
+    logger.error("Allowed models are lr (logistic regression), rf (random forest), or gbm (gradient boosting machines)")
+    sys.exit(1)
+
+  if clf_name == 'lr':
+    clf_model = LogisticRegression
+    args = lr.args.args
+    ia_params = lr.args.ia_params
+    ps_params = lr.args.ps_params
+  elif clf_name == 'rf':
+    clf_model = RandomForestClassifier
+    args = rf.args.args
+    ia_params = rf.args.ia_params
+    ps_params = rf.args.ps_params
+  else:
+    clf_model = lightgbm.LGBMClassifier
+    args = gbm.args.args
+    ia_params = gbm.args.ia_params
+    ps_params = gbm.args.ps_params
+
+  args.dataset_csv = Path('./data/proc_dataset.csv')
+  args.workdir = Path(f'./data/workdir/{clf_name}')
+  args.modeldir = args.workdir/'models'
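+  # The output directories may not exist on a fresh checkout; create them so
+  # the per-seed pickle dumps below don't fail.
+  args.modeldir.mkdir(parents=True, exist_ok=True)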
+
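+  # Load the preprocessed dataset, keeping only the configured columns.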
+  ori_df = pd.read_csv(args.dataset_csv, usecols=args.cols, parse_dates=args.dates)
+  if task == 'ia':
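+    # Rows labelled -1 carry no imminent-admission annotation; exclude them.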
+    task_df = ori_df.loc[(ori_df['imminent_adm_label'] != -1)][args.imminent_adm_cols].reset_index(drop=True)
+    prefix = 'imminent_adm'
+    params = ia_params
+    threshold = args.ia_thresh
+  elif task == 'ps':
+    task_df = ori_df.loc[(ori_df['chartinterval'] != 0)][args.prolonged_stay_cols].reset_index(drop=True)
+    prefix = 'prolonged_stay'
+    params = ps_params
+    threshold = args.ps_thresh
+
+  logger.info(args.workdir)
+  logger.info(args.modeldir)
+  logger.info(f"Running 100 seed test run for task {task} with model {clf_name}")
+  t1 = datetime.datetime.now()
+  run_100(prefix, task_df, clf_model, params, args, threshold)
+  dt = datetime.datetime.now() - t1
+  logger.info(f"100 seed test run completed. Took {dt.days} days, {dt.seconds//3600} hours, and {(dt.seconds//60)%60} minutes")